1
#![forbid(unsafe_code)]
2

            
3
use std::collections::HashMap;
4
use std::fmt;
5

            
6
use merc_collections::ByteCompressedVec;
7
use merc_collections::CompressedEntry;
8
use merc_collections::CompressedVecMetrics;
9
use merc_collections::bytevec;
10
use merc_io::LargeFormatter;
11
use merc_utilities::MercError;
12
use merc_utilities::TagIndex;
13

            
14
use crate::LTS;
15
use crate::LabelIndex;
16
use crate::LabelTag;
17
use crate::StateIndex;
18
use crate::Transition;
19
use crate::TransitionLabel;
20

            
21
/// Represents a labelled transition system consisting of states with directed
22
/// labelled transitions between them.
23
///
24
/// # Details
25
///
26
/// Uses byte compressed vectors to store the states and their outgoing
27
/// transitions efficiently in memory.
28
#[derive(PartialEq, Eq, Clone)]
29
pub struct LabelledTransitionSystem<Label> {
30
    /// Encodes the states and their outgoing transitions.
31
    states: ByteCompressedVec<usize>,
32
    transition_labels: ByteCompressedVec<LabelIndex>,
33
    transition_to: ByteCompressedVec<StateIndex>,
34

            
35
    /// Keeps track of the labels for every index, and which of them are hidden.
36
    labels: Vec<Label>,
37

            
38
    /// The index of the initial state.
39
    initial_state: StateIndex,
40
}
41

            
42
impl<Label: TransitionLabel> LabelledTransitionSystem<Label> {
43
    /// Creates a new labelled transition system with the given transitions,
44
    /// labels, and hidden labels.
45
    ///
46
    /// The initial state is the state with the given index. `num_of_states` is
47
    /// the number of states in the LTS, if known. If it is not known, pass
48
    /// `None`. However, in that case the number of states will be determined
49
    /// based on the maximum state index in the transitions. And all states that
50
    /// do not have any outgoing transitions will simply be created as deadlock
51
    /// states.
52
20417
    pub fn new<I, F>(
53
20417
        initial_state: StateIndex,
54
20417
        num_of_states: Option<usize>,
55
20417
        mut transition_iter: F,
56
20417
        labels: Vec<Label>,
57
20417
    ) -> LabelledTransitionSystem<Label>
58
20417
    where
59
20417
        F: FnMut() -> I,
60
20417
        I: Iterator<Item = (StateIndex, LabelIndex, StateIndex)>,
61
    {
62
20417
        let mut states = ByteCompressedVec::new();
63
20417
        if let Some(num_of_states) = num_of_states {
64
20416
            states.resize_with(num_of_states, Default::default);
65
20416
        }
66

            
67
        // Count the number of transitions for every state
68
20417
        let mut num_of_transitions = 0;
69
17616510
        for (from, _, to) in transition_iter() {
70
            // Ensure that the states vector is large enough.
71
17616510
            if states.len() <= *from.max(to) {
72
2
                states.resize_with(*from.max(to) + 1, || 0);
73
17616508
            }
74

            
75
17616510
            states.update(*from, |start| *start += 1);
76
17616510
            num_of_transitions += 1;
77

            
78
17616510
            if let Some(num_of_states) = num_of_states {
79
17616505
                debug_assert!(
80
17616505
                    *from < num_of_states && *to < num_of_states,
81
                    "State index out of bounds: from {:?}, to {:?}, num_of_states {}",
82
                    from,
83
                    to,
84
                    num_of_states
85
                );
86
5
            }
87
        }
88

            
89
20417
        if initial_state.value() >= states.len() {
90
1080
            // Ensure that the initial state is a valid state (and all states before it exist).
91
1080
            states.resize_with(initial_state.value() + 1, Default::default);
92
19337
        }
93

            
94
        // Track the number of transitions before every state.
95
5252468
        states.fold(0, |count, start| {
96
5252468
            let result = count + *start;
97
5252468
            *start = count;
98
5252468
            result
99
5252468
        });
100

            
101
        // Place the transitions, and increment the end for every state.
102
20417
        let mut transition_labels = bytevec![LabelIndex::new(labels.len()); num_of_transitions];
103
20417
        let mut transition_to = bytevec![StateIndex::new(states.len()); num_of_transitions];
104
17616510
        for (from, label, to) in transition_iter() {
105
17616510
            states.update(*from, |start| {
106
17616510
                transition_labels.set(*start, label);
107
17616510
                transition_to.set(*start, to);
108
17616510
                *start += 1
109
17616510
            });
110
        }
111

            
112
        // Reset the offset.
113
5252468
        states.fold(0, |previous, start| {
114
5252468
            let result = *start;
115
5252468
            *start = previous;
116
5252468
            result
117
5252468
        });
118

            
119
        // Add the sentinel state.
120
20417
        states.push(transition_labels.len());
121

            
122
20417
        LabelledTransitionSystem::from_raw_parts(initial_state, states, transition_labels, transition_to, labels)
123
20417
    }
124

            
125
    /// Constructs a LTS by a successor function for every state.
126
    pub fn with_successors<F, I>(
127
        initial_state: StateIndex,
128
        num_of_states: usize,
129
        labels: Vec<Label>,
130
        mut successors: F,
131
    ) -> Self
132
    where
133
        F: FnMut(StateIndex) -> I,
134
        I: Iterator<Item = (LabelIndex, StateIndex)>,
135
    {
136
        let mut states = ByteCompressedVec::new();
137
        states.resize_with(num_of_states, Default::default);
138

            
139
        let mut transition_labels = ByteCompressedVec::with_capacity(num_of_states, 16usize.bytes_required());
140
        let mut transition_to = ByteCompressedVec::with_capacity(num_of_states, num_of_states.bytes_required());
141

            
142
        for state_index in 0..num_of_states {
143
            let state_index = StateIndex::new(state_index);
144
            states.update(*state_index, |entry| {
145
                *entry = transition_labels.len();
146
            });
147

            
148
            for (label, to) in successors(state_index) {
149
                transition_labels.push(label);
150
                transition_to.push(to);
151
            }
152
        }
153

            
154
        // Add the sentinel state.
155
        states.push(transition_labels.len());
156

            
157
        Self::from_raw_parts(initial_state, states, transition_labels, transition_to, labels)
158
    }
159

            
160
    /// Consumes the current LTS and merges it with another one, returning the merged LTS.
161
    ///
162
    /// # Details
163
    ///
164
    /// Internally this works by offsetting the state indices of the other LTS by the number of states
165
    /// in the current LTS, and combining the action labels. The offset is returned such that
166
    /// can find the states of the other LTS in the merged LTS as the initial state of the other LTS.
167
4409
    fn merge_disjoint_impl<L: LTS<Label = Label>>(mut self, other: &L) -> (Self, StateIndex) {
168
        // Determine the combination of action labels
169
4409
        let mut all_labels = self.labels().to_vec();
170
15147
        for label in other.labels() {
171
15147
            if !all_labels.contains(label) {
172
3
                all_labels.push(label.clone());
173
15144
            }
174
        }
175

            
176
4409
        let label_indices: HashMap<Label, TagIndex<usize, LabelTag>> = HashMap::from_iter(
177
4409
            all_labels
178
4409
                .iter()
179
4409
                .enumerate()
180
15147
                .map(|(i, label)| (label.clone(), LabelIndex::new(i))),
181
        );
182

            
183
4409
        let total_number_of_states = self.num_of_states() + other.num_of_states();
184

            
185
        // Reserve space for the right LTS.
186
4409
        self.states
187
4409
            .reserve(other.num_of_states(), total_number_of_states.bytes_required());
188
4409
        self.transition_labels
189
4409
            .reserve(other.num_of_transitions(), all_labels.len().bytes_required());
190
4409
        self.transition_to
191
4409
            .reserve(other.num_of_transitions(), total_number_of_states.bytes_required());
192

            
193
4409
        let offset = self.num_of_states();
194

            
195
        // Remove the sentinel state temporarily. This breaks the state invariant, but we will add it back later.
196
4409
        self.states.pop();
197

            
198
        // Add vertices for the other LTS that are offset by the number of states in self
199
563403
        for state_index in other.iter_states() {
200
            // Add a new state for every state in the other LTS
201
563403
            self.states.push(self.num_of_transitions());
202
1754285
            for transition in other.outgoing_transitions(state_index) {
203
1754181
                // Add the transitions of the other LTS, offsetting the state indices
204
1754181
                self.transition_to.push(StateIndex::new(transition.to.value() + offset));
205
1754181

            
206
1754181
                // Map the label to the new index in all_labels
207
1754181
                let label_name = &other.labels()[transition.label.value()];
208
1754181
                self.transition_labels
209
1754181
                    .push(*label_indices.get(label_name).expect("Label should exist in all_labels"));
210
1754181
            }
211
        }
212

            
213
        // Add back the sentinel state
214
4409
        self.states.push(self.num_of_transitions());
215
4409
        debug_assert_eq!(self.num_of_states(), total_number_of_states);
216

            
217
4409
        (
218
4409
            Self::from_raw_parts(
219
4409
                self.initial_state,
220
4409
                self.states,
221
4409
                self.transition_labels,
222
4409
                self.transition_to,
223
4409
                all_labels,
224
4409
            ),
225
4409
            StateIndex::new(offset + other.initial_state_index().value()),
226
4409
        )
227
4409
    }
228

            
229
    /// Creates a labelled transition system from another one, given the permutation of state indices.
230
    ///
231
    /// The permutation maps old state indices to new state indices, i.e.,
232
    /// `permutation(old) = new`. The transition arrays are rebuilt so that
233
    /// transitions are contiguous per new state index, and all transition
234
    /// targets are updated to reference the new state indices.
235
5303
    pub fn new_from_permutation<P>(lts: Self, permutation: P) -> Self
236
5303
    where
237
5303
        P: Fn(StateIndex) -> StateIndex + Copy,
238
    {
239
        // Build the inverse permutation: inverse[new_index] = old_index
240
5303
        let mut inverse = vec![StateIndex::new(0); lts.num_of_states()];
241
1731971
        for state_index in lts.iter_states() {
242
1731971
            inverse[*permutation(state_index)] = state_index;
243
1731971
        }
244

            
245
        // Rebuild transition arrays in the order of the new state indices.
246
5303
        let mut states = ByteCompressedVec::new();
247
5303
        let mut transition_labels = ByteCompressedVec::new();
248
5303
        let mut transition_to = ByteCompressedVec::new();
249

            
250
1731971
        for old_index in &inverse {
251
1731971
            states.push(transition_labels.len());
252

            
253
1731971
            let start = lts.states.index(**old_index);
254
1731971
            let end = lts.states.index(**old_index + 1);
255

            
256
1767851
            for i in start..end {
257
1767851
                transition_labels.push(lts.transition_labels.index(i));
258
1767851
                transition_to.push(permutation(lts.transition_to.index(i)));
259
1767851
            }
260
        }
261

            
262
        // Add the sentinel state.
263
5303
        states.push(transition_labels.len());
264

            
265
5303
        Self::from_raw_parts(
266
5303
            permutation(lts.initial_state),
267
5303
            states,
268
5303
            transition_labels,
269
5303
            transition_to,
270
5303
            lts.labels,
271
        )
272
5303
    }
273

            
274
    /// Consumes the LTS and relabels its transition labels according to the
275
    /// given mapping.
276
    ///
277
    /// Note that this only relabels the visible labels, since the hidden label
278
    /// must be kept consistent.
279
    pub fn relabel<L, F>(self, labelling: F) -> Result<LabelledTransitionSystem<L>, MercError>
280
    where
281
        F: Fn(Label) -> Result<L, MercError>,
282
        L: TransitionLabel,
283
    {
284
        let new_labels: Vec<L> = self
285
            .labels
286
            .into_iter()
287
            .enumerate()
288
            .map(|(index, label)| {
289
                if index == 0 {
290
                    Ok(L::tau_label())
291
                } else {
292
                    labelling(label)
293
                }
294
            })
295
            .collect::<Result<_, _>>()?;
296

            
297
        Ok(LabelledTransitionSystem::from_raw_parts(
298
            self.initial_state,
299
            self.states,
300
            self.transition_labels,
301
            self.transition_to,
302
            new_labels,
303
        ))
304
    }
305

            
306
    /// A [Self::relabel] variant that also applies to the tau label. This is
307
    /// useful when the tau label cannot be constructed from `L::tau_label()`.
308
2
    pub fn relabel_all<L, F>(self, labelling: F) -> Result<LabelledTransitionSystem<L>, MercError>
309
2
    where
310
2
        F: Fn(Label) -> Result<L, MercError>,
311
2
        L: TransitionLabel,
312
    {
313
2
        let new_labels: Vec<L> = self.labels.into_iter().map(labelling).collect::<Result<_, _>>()?;
314

            
315
2
        Ok(LabelledTransitionSystem::from_raw_parts(
316
2
            self.initial_state,
317
2
            self.states,
318
2
            self.transition_labels,
319
2
            self.transition_to,
320
2
            new_labels,
321
2
        ))
322
2
    }
323

            
324
    /// Constructs a [LabelledTransitionSystem] directly from its raw internal arrays.
325
    ///
326
    /// The `states` array must contain one entry per state holding the start offset of that
327
    /// state's transitions in the transition arrays, plus a sentinel entry at the end equal
328
    /// to the total number of transitions. `transition_labels` and `transition_to` must have
329
    /// equal length and all indices they contain must be in bounds.
330
    ///
331
    /// # Panics
332
    ///
333
    /// Panics (in debug mode) if the invariants of the internal representation are violated.
334
293075
    pub fn from_raw_parts(
335
293075
        initial_state: StateIndex,
336
293075
        states: ByteCompressedVec<usize>,
337
293075
        transition_labels: ByteCompressedVec<LabelIndex>,
338
293075
        transition_to: ByteCompressedVec<StateIndex>,
339
293075
        labels: Vec<Label>,
340
293075
    ) -> Self {
341
293075
        let lts = LabelledTransitionSystem {
342
293075
            initial_state,
343
293075
            states,
344
293075
            transition_labels,
345
293075
            transition_to,
346
293075
            labels,
347
293075
        };
348
293075
        lts.assert_valid();
349
293075
        lts
350
293075
    }
351

            
352
    /// Checks that the internal representation satisfies all structural invariants.
353
293075
    pub fn assert_valid(&self) {
354
293075
        let num_states = self.num_of_states();
355
293075
        let num_transitions = self.num_of_transitions();
356

            
357
293075
        debug_assert!(
358
293075
            !self.states.is_empty(),
359
            "states array must have at least one entry (the sentinel)"
360
        );
361

            
362
293075
        debug_assert!(
363
293075
            self.initial_state.value() < num_states,
364
            "initial_state {:?} is out of bounds (num_states: {})",
365
            self.initial_state,
366
            num_states
367
        );
368

            
369
293075
        debug_assert_eq!(
370
293075
            self.states.index(num_states),
371
            num_transitions,
372
            "sentinel value must equal the number of transitions"
373
        );
374

            
375
293075
        debug_assert_eq!(
376
293075
            self.transition_labels.len(),
377
293075
            self.transition_to.len(),
378
            "transition_labels and transition_to must have equal length"
379
        );
380

            
381
293075
        assert!(
382
293075
            self.labels
383
293075
                .first()
384
293075
                .expect("At least one label (the hidden label) must be provided")
385
293075
                .is_tau_label(),
386
            "The first label must be the hidden label."
387
        );
388

            
389
76803324
        for i in 0..num_states {
390
76803324
            debug_assert!(
391
76803324
                self.states.index(i) <= self.states.index(i + 1),
392
                "state {i} has offset {} which is greater than successor offset {}",
393
                self.states.index(i),
394
                self.states.index(i + 1)
395
            );
396
        }
397

            
398
105732716
        for i in 0..num_transitions {
399
105732716
            let label = self.transition_labels.index(i);
400
105732716
            debug_assert!(
401
105732716
                label.value() < self.labels.len(),
402
                "transition {i} references label index {} which is out of bounds (num_labels: {})",
403
                label.value(),
404
                self.labels.len()
405
            );
406

            
407
105732716
            let to = self.transition_to.index(i);
408
105732716
            debug_assert!(
409
105732716
                to.value() < num_states,
410
                "transition {i} references target state {} which is out of bounds (num_states: {})",
411
                to.value(),
412
                num_states
413
            );
414
        }
415
293075
    }
416

            
417
    /// Returns metrics about the LTS.
418
    pub fn metrics(&self) -> LtsMetrics {
419
        LtsMetrics {
420
            num_of_states: self.num_of_states(),
421
            num_of_labels: self.num_of_labels(),
422
            num_of_transitions: self.num_of_transitions(),
423
            state_metrics: self.states.metrics(),
424
            transition_labels_metrics: self.transition_labels.metrics(),
425
            transition_to_metrics: self.transition_to.metrics(),
426
        }
427
    }
428
}
429

            
430
impl<L: TransitionLabel> LTS for LabelledTransitionSystem<L> {
431
    type Label = L;
432

            
433
22818
    fn initial_state_index(&self) -> StateIndex {
434
22818
        self.initial_state
435
22818
    }
436

            
437
214935395
    fn outgoing_transitions(&self, state_index: StateIndex) -> impl Iterator<Item = Transition> + '_ {
438
214935395
        let start = self.states.index(*state_index);
439
214935395
        let end = self.states.index(*state_index + 1);
440

            
441
214935395
        (start..end).map(move |i| Transition {
442
227025104
            label: self.transition_labels.index(i),
443
227025104
            to: self.transition_to.index(i),
444
227025104
        })
445
214935395
    }
446

            
447
1660364
    fn iter_states(&self) -> impl Iterator<Item = StateIndex> + '_ {
448
1660364
        (0..self.num_of_states()).map(StateIndex::new)
449
1660364
    }
450

            
451
52741288
    fn num_of_states(&self) -> usize {
452
        // Remove the sentinel state.
453
52741288
        self.states.len() - 1
454
52741288
    }
455

            
456
4794187
    fn num_of_labels(&self) -> usize {
457
4794187
        self.labels.len()
458
4794187
    }
459

            
460
6240137
    fn num_of_transitions(&self) -> usize {
461
6240137
        self.transition_labels.len()
462
6240137
    }
463

            
464
4431400
    fn labels(&self) -> &[Self::Label] {
465
4431400
        &self.labels[0..]
466
4431400
    }
467

            
468
473630324
    fn is_hidden_label(&self, label_index: LabelIndex) -> bool {
469
473630324
        label_index.value() == 0
470
473630324
    }
471

            
472
4409
    fn merge_disjoint<T: LTS<Label = Self::Label>>(self, other: &T) -> (Self, StateIndex) {
473
4409
        self.merge_disjoint_impl(other)
474
4409
    }
475
}
476

            
477
/// Metrics for a labelled transition system.
478
#[derive(Debug, Clone)]
479
pub struct LtsMetrics {
480
    /// The number of states in the LTS.
481
    pub num_of_states: usize,
482
    pub state_metrics: CompressedVecMetrics,
483
    /// The number of transitions in the LTS.
484
    pub num_of_transitions: usize,
485
    pub transition_labels_metrics: CompressedVecMetrics,
486
    pub transition_to_metrics: CompressedVecMetrics,
487
    /// The number of action labels in the LTS.
488
    pub num_of_labels: usize,
489
}
490

            
491
impl fmt::Display for LtsMetrics {
492
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
493
        // Print some information about the LTS.
494
        writeln!(f, "Number of states: {}", LargeFormatter(self.num_of_states))?;
495
        writeln!(f, "Number of action labels: {}", LargeFormatter(self.num_of_labels))?;
496
        writeln!(
497
            f,
498
            "Number of transitions: {}\n",
499
            LargeFormatter(self.num_of_transitions)
500
        )?;
501
        writeln!(f, "Memory usage:")?;
502
        writeln!(f, "States {}", self.state_metrics)?;
503
        writeln!(f, "Transition labels {}", self.transition_labels_metrics)?;
504
        write!(f, "Transition to {}", self.transition_to_metrics)
505
    }
506
}
507

            
508
/// Checks that two LTSs are equivalent, for testing purposes.
///
/// Labels are matched by value, so `lts_read` may store its labels in a different
/// order (or omit unused ones). Every outgoing transition of `lts` must also occur
/// in `lts_read`; states of `lts` beyond the states of `lts_read` are treated as
/// deadlocks there.
///
/// # Panics
///
/// Panics if the LTSs have a different number of transitions, if a label used by a
/// transition of `lts` does not occur in `lts_read`, or if a transition of `lts` is
/// missing in `lts_read`.
#[cfg(test)]
pub fn check_equivalent<L: LTS>(lts: &L, lts_read: &L) {
    println!("LTS labels: {:?}", lts.labels());
    println!("Read LTS labels: {:?}", lts_read.labels());

    // If labels are not used, the number of labels may be less. So find a remapping of old labels to new labels.
    let mapping: Vec<Option<usize>> = lts
        .labels()
        .iter()
        .map(|label| lts_read.labels().iter().position(|l| l == label))
        .collect();

    // Print the mapping
    for (i, m) in mapping.iter().enumerate() {
        println!("Label {} mapped to {:?}", i, m);
    }

    assert_eq!(lts.num_of_transitions(), lts_read.num_of_transitions());

    // Check that all the outgoing transitions are the same.
    for state_index in lts.iter_states() {
        let transitions_read: Vec<_> = if state_index.value() < lts_read.num_of_states() {
            lts_read.outgoing_transitions(state_index).collect()
        } else {
            // Treat as deadlock if state_index is out of bounds
            Vec::new()
        };

        // Check that transitions are the same, modulo label remapping.
        for t in lts.outgoing_transitions(state_index) {
            // The panic message is only formatted when the label is actually missing.
            let mapped_label = mapping[t.label.value()].unwrap_or_else(|| panic!("Label {} should be found", t.label));
            assert!(
                transitions_read
                    .iter()
                    .any(|tr| tr.to == t.to && tr.label.value() == mapped_label),
                "Transition from {:?} with label {} to {:?} is missing in the read LTS",
                state_index,
                t.label,
                t.to
            );
        }
    }
}
550

            
551
#[cfg(test)]
mod tests {
    use merc_io::DumpFiles;
    use merc_utilities::random_test;

    use super::LabelledTransitionSystem;
    use crate::LTS;
    use crate::LabelIndex;
    use crate::StateIndex;
    use crate::TransitionLabel;
    use crate::num_reachable_states;
    use crate::random_lts;
    use crate::write_aut;

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_random_labelled_transition_system_merge_disjoint() {
        random_test(100, |rng| {
            let mut files = DumpFiles::new("test_random_merge_disjoint");

            let left = random_lts::<String, _>(rng, 1000, 20);
            files.dump("left.aut", |w| write_aut(w, &left)).unwrap();

            let right = random_lts::<String, _>(rng, 1000, 20);
            files.dump("right.aut", |w| write_aut(w, &right)).unwrap();

            let (merged, right_initial) = left.clone().merge_disjoint(&right);
            files.dump("merged.aut", |w| write_aut(w, &merged)).unwrap();

            assert_eq!(
                merged.num_of_states(),
                left.num_of_states() + right.num_of_states(),
                "The merged LTS should contain the states of both LTSs"
            );
            assert_eq!(
                merged.num_of_transitions(),
                left.num_of_transitions() + right.num_of_transitions(),
                "The merged LTS should contain the transitions of both LTSs"
            );
            assert_eq!(
                merged.initial_state_index().value(),
                left.initial_state_index().value(),
                "The merged LTS should keep the initial state of the left LTS"
            );
            assert_eq!(
                right_initial.value(),
                left.num_of_states() + right.initial_state_index().value(),
                "The right initial state should be offset by the number of left states"
            );

            assert_eq!(
                num_reachable_states(&left, left.initial_state_index()),
                num_reachable_states(&merged, merged.initial_state_index()),
                "The left LTS should be fully reachable in the merged LTS"
            );
            assert_eq!(
                num_reachable_states(&right, right.initial_state_index()),
                num_reachable_states(&merged, right_initial),
                "The right LTS should be fully reachable in the merged LTS"
            );
        });
    }

    #[test]
    fn test_new_from_permutation_reverses_states() {
        // 0 --a--> 1 --tau--> 2
        let lts = LabelledTransitionSystem::new(
            StateIndex::new(0),
            Some(3),
            || {
                [
                    (StateIndex::new(0), LabelIndex::new(1), StateIndex::new(1)),
                    (StateIndex::new(1), LabelIndex::new(0), StateIndex::new(2)),
                ]
                .into_iter()
            },
            vec![String::tau_label(), "a".to_string()],
        );

        let num_of_states = lts.num_of_states();
        let reversed = LabelledTransitionSystem::new_from_permutation(lts, |state| {
            StateIndex::new(num_of_states - 1 - state.value())
        });

        assert_eq!(reversed.num_of_states(), 3);
        assert_eq!(reversed.num_of_transitions(), 2);
        assert_eq!(reversed.initial_state_index().value(), 2);

        let successors = |state: usize| -> Vec<(usize, usize)> {
            reversed
                .outgoing_transitions(StateIndex::new(state))
                .map(|t| (t.label.value(), t.to.value()))
                .collect()
        };
        assert_eq!(successors(2), vec![(1, 1)]);
        assert_eq!(successors(1), vec![(0, 0)]);
        assert!(successors(0).is_empty());
    }
}