1
#![forbid(unsafe_code)]
2

            
3
use std::marker::PhantomData;
4

            
5
use merc_collections::ByteCompressedVec;
6
use merc_collections::bytevec;
7

            
8
use crate::LTS;
9
use crate::LabelIndex;
10
use crate::StateIndex;
11
use crate::Transition;
12

            
13
/// Stores the incoming transitions for a given labelled transition system.
14
pub struct IncomingTransitions<'a> {
15
    /// A flat list of all incoming transition labels in the LTS. They are stored in two separate
16
    /// arrays since the compression is based on the highest value.
17
    transition_labels: ByteCompressedVec<LabelIndex>,
18
    transition_from: ByteCompressedVec<StateIndex>,
19

            
20
    /// A mapping from the state to the `transition_labels` and
21
    /// `transition_from` that stores its incoming transitions.
22
    state2incoming: ByteCompressedVec<usize>,
23

            
24
    /// Marker to tie the lifetime of the incoming transitions to the LTS.
25
    _marker: PhantomData<&'a ()>,
26
}
27

            
28
impl<'a> IncomingTransitions<'a> {
29
600
    pub fn new(lts: &'a impl LTS) -> Self {
30
600
        let mut transition_labels = bytevec![LabelIndex::new(0); lts.num_of_transitions()];
31
600
        let mut transition_from = bytevec![StateIndex::new(0); lts.num_of_transitions()];
32
600
        let mut state2incoming = bytevec![0usize; lts.num_of_states()];
33

            
34
        // Count the number of incoming transitions for each state
35
6143
        for state_index in lts.iter_states() {
36
13460
            for transition in lts.outgoing_transitions(state_index) {
37
13460
                state2incoming.update(transition.to.value(), |start| *start += 1);
38
            }
39
        }
40

            
41
        // Compute the start offsets (prefix sum)
42
6143
        state2incoming.fold(0, |offset, start| {
43
6143
            let new_offset = offset + *start;
44
6143
            *start = offset;
45
6143
            new_offset
46
6143
        });
47

            
48
        // Place the transitions
49
6143
        for state_index in lts.iter_states() {
50
13460
            for transition in lts.outgoing_transitions(state_index) {
51
13460
                state2incoming.update(transition.to.value(), |start| {
52
13460
                    transition_labels.set(*start, transition.label);
53
13460
                    transition_from.set(*start, state_index);
54
13460
                    *start += 1;
55
13460
                });
56
            }
57
        }
58

            
59
6143
        state2incoming.fold(0, |previous, start| {
60
6143
            let result = *start;
61
6143
            *start = previous;
62
6143
            result
63
6143
        });
64

            
65
        // Add sentinel state
66
600
        state2incoming.push(transition_labels.len());
67

            
68
        // Sort the incoming transitions such that silent transitions come first.
69
        //
70
        // TODO: This could be more efficient by simply grouping them instead of sorting, perhaps some group using a predicate.
71
600
        let mut pairs = Vec::new();
72
6143
        for state_index in 0..lts.num_of_states() {
73
6143
            let start = state2incoming.index(state_index);
74
6143
            let end = state2incoming.index(state_index + 1);
75

            
76
            // Extract, sort, and put back
77
6143
            pairs.clear();
78
13460
            pairs.extend((start..end).map(|i| (transition_labels.index(i), transition_from.index(i))));
79
6143
            pairs.sort_unstable_by_key(|(label, _)| *label);
80

            
81
13460
            for (i, (label, from)) in pairs.iter().enumerate() {
82
13460
                transition_labels.set(start + i, *label);
83
13460
                transition_from.set(start + i, *from);
84
13460
            }
85
        }
86

            
87
600
        Self {
88
600
            transition_labels,
89
600
            transition_from,
90
600
            state2incoming,
91
600
            _marker: PhantomData,
92
600
        }
93
600
    }
94

            
95
    /// Returns an iterator over the incoming transitions for the given state.
96
22049
    pub fn incoming_transitions(&self, state_index: StateIndex) -> impl Iterator<Item = Transition> + '_ {
97
22049
        let start = self.state2incoming.index(state_index.value());
98
22049
        let end = self.state2incoming.index(state_index.value() + 1);
99
29655
        (start..end).map(move |i| Transition::new(self.transition_labels.index(i), self.transition_from.index(i)))
100
22049
    }
101

            
102
    // Return an iterator over the incoming silent transitions for the given state.
103
6976
    pub fn incoming_silent_transitions(&self, state_index: StateIndex) -> impl Iterator<Item = Transition> + '_ {
104
6976
        let start = self.state2incoming.index(state_index.value());
105
6976
        let end = self.state2incoming.index(state_index.value() + 1);
106
6976
        (start..end)
107
7473
            .map(move |i| Transition::new(self.transition_labels.index(i), self.transition_from.index(i)))
108
7473
            .take_while(|transition| transition.label == 0)
109
6976
    }
110
}
111

            
112
#[cfg(test)]
mod tests {
    use super::*;

    use merc_io::DumpFiles;
    use merc_utilities::random_test;

    use crate::random_lts;
    use crate::write_aut;

    /// Builds the incoming-transition index for random transition systems and
    /// checks that it mirrors the outgoing transitions in both directions.
    #[test]
    fn test_random_incoming_transitions() {
        random_test(100, |rng| {
            let mut files = DumpFiles::new("test_random_incoming_transitions");

            let lts = random_lts(rng, 10, 3, 3);
            // Dump the generated input next to the test name.
            files.dump("input.aut", |f| write_aut(f, &lts)).unwrap();
            let incoming = IncomingTransitions::new(&lts);

            // Forward direction: every outgoing transition must show up among
            // the incoming transitions of its target state.
            for state_index in lts.iter_states() {
                for transition in lts.outgoing_transitions(state_index) {
                    assert!(
                        incoming
                            .incoming_transitions(transition.to)
                            .any(|candidate| candidate.label == transition.label && candidate.to == state_index),
                        "Outgoing transition ({state_index}, {transition:?}) should have an incoming transition"
                    );
                }
            }

            // Backward direction: every stored incoming transition must stem
            // from an outgoing transition of the state it names as its source.
            for state_index in lts.iter_states() {
                for transition in incoming.incoming_transitions(state_index) {
                    assert!(
                        lts.outgoing_transitions(transition.to)
                            .any(|candidate| candidate.label == transition.label && candidate.to == state_index),
                        "Incoming transition ({transition:?}, {state_index}) should have an outgoing transition"
                    );
                }
            }
        });
    }
}