1
use merc_collections::ByteCompressedVec;
2
use merc_collections::bytevec;
3

            
4
use crate::LTS;
5
use crate::LabelIndex;
6
use crate::StateIndex;
7
use crate::Transition;
8

            
9
/// Stores the incoming transitions for a given labelled transition system.
10
pub struct IncomingTransitions {
11
    transition_labels: ByteCompressedVec<LabelIndex>,
12
    transition_from: ByteCompressedVec<StateIndex>,
13
    state2incoming: ByteCompressedVec<usize>,
14
}
15

            
16
impl IncomingTransitions {
17
600
    pub fn new(lts: &impl LTS) -> Self {
18
600
        let mut transition_labels = bytevec![LabelIndex::new(0); lts.num_of_transitions()];
19
600
        let mut transition_from = bytevec![StateIndex::new(0); lts.num_of_transitions()];
20
600
        let mut state2incoming = bytevec![0usize; lts.num_of_states()];
21

            
22
        // Count the number of incoming transitions for each state
23
20066
        for state_index in lts.iter_states() {
24
94340
            for transition in lts.outgoing_transitions(state_index) {
25
94340
                state2incoming.update(transition.to.value(), |start| *start += 1);
26
            }
27
        }
28

            
29
        // Compute the start offsets (prefix sum)
30
20066
        state2incoming.fold(0, |offset, start| {
31
20066
            let new_offset = offset + *start;
32
20066
            *start = offset;
33
20066
            new_offset
34
20066
        });
35

            
36
        // Place the transitions
37
20066
        for state_index in lts.iter_states() {
38
94340
            for transition in lts.outgoing_transitions(state_index) {
39
94340
                state2incoming.update(transition.to.value(), |start| {
40
94340
                    transition_labels.set(*start, transition.label);
41
94340
                    transition_from.set(*start, state_index);
42
94340
                    *start += 1;
43
94340
                });
44
            }
45
        }
46

            
47
20066
        state2incoming.fold(0, |previous, start| {
48
20066
            let result = *start;
49
20066
            *start = previous;
50
20066
            result
51
20066
        });
52

            
53
        // Add sentinel state
54
600
        state2incoming.push(transition_labels.len());
55

            
56
        // Sort the incoming transitions such that silent transitions come first.
57
        //
58
        // TODO: This could be more efficient by simply grouping them instead of sorting, perhaps some group using a predicate.
59
600
        let mut pairs = Vec::new();
60
20066
        for state_index in 0..lts.num_of_states() {
61
20066
            let start = state2incoming.index(state_index);
62
20066
            let end = state2incoming.index(state_index + 1);
63

            
64
            // Extract, sort, and put back
65
20066
            pairs.clear();
66
94340
            pairs.extend((start..end).map(|i| (transition_labels.index(i), transition_from.index(i))));
67
20066
            pairs.sort_unstable_by_key(|(label, _)| *label);
68

            
69
94340
            for (i, (label, from)) in pairs.iter().enumerate() {
70
94340
                transition_labels.set(start + i, *label);
71
94340
                transition_from.set(start + i, *from);
72
94340
            }
73
        }
74

            
75
600
        Self {
76
600
            transition_labels,
77
600
            transition_from,
78
600
            state2incoming,
79
600
        }
80
600
    }
81

            
82
    /// Returns an iterator over the incoming transitions for the given state.
83
65623
    pub fn incoming_transitions(&self, state_index: StateIndex) -> impl Iterator<Item = Transition> + '_ {
84
65623
        let start = self.state2incoming.index(state_index.value());
85
65623
        let end = self.state2incoming.index(state_index.value() + 1);
86
149110
        (start..end).map(move |i| Transition::new(self.transition_labels.index(i), self.transition_from.index(i)))
87
65623
    }
88

            
89
    // Return an iterator over the incoming silent transitions for the given state.
90
6195
    pub fn incoming_silent_transitions(&self, state_index: StateIndex) -> impl Iterator<Item = Transition> + '_ {
91
6195
        let start = self.state2incoming.index(state_index.value());
92
6195
        let end = self.state2incoming.index(state_index.value() + 1);
93
6195
        (start..end)
94
6861
            .map(move |i| Transition::new(self.transition_labels.index(i), self.transition_from.index(i)))
95
6861
            .take_while(|transition| transition.label == 0)
96
6195
    }
97
}
98

            
99
#[cfg(test)]
mod tests {
    use super::*;

    use merc_io::DumpFiles;
    use merc_utilities::random_test;

    use crate::random_lts;

    /// Checks on random transition systems that the incoming transitions are
    /// exactly the reversed outgoing transitions, and that the silent subset
    /// agrees with the full set.
    #[test]
    fn test_random_incoming_transitions() {
        random_test(100, |rng| {
            let mut files = DumpFiles::new("test_random_incoming_transitions");

            let lts = random_lts(rng, 10, 3, 3);
            files.dump("input.aut", |f| crate::write_aut(f, &lts)).unwrap();
            let incoming = IncomingTransitions::new(&lts);

            // Check that for every outgoing transition there is an incoming transition.
            for state_index in lts.iter_states() {
                for transition in lts.outgoing_transitions(state_index) {
                    // An incoming transition carries its source state in `to`.
                    let found = incoming
                        .incoming_transitions(transition.to)
                        .any(|candidate| candidate.label == transition.label && candidate.to == state_index);
                    assert!(
                        found,
                        "Outgoing transition ({state_index}, {transition:?}) should have an incoming transition"
                    );
                }
            }

            // Check that all incoming transitions belong to some outgoing transition.
            for state_index in lts.iter_states() {
                for transition in incoming.incoming_transitions(state_index) {
                    let found = lts
                        .outgoing_transitions(transition.to)
                        .any(|outgoing| outgoing.label == transition.label && outgoing.to == state_index);
                    assert!(
                        found,
                        "Incoming transition ({transition:?}, {state_index}) should have an outgoing transition"
                    );
                }
            }

            // The existence checks above cannot detect dropped or duplicated
            // entries, so also compare the total number of transitions.
            let total_incoming: usize = lts
                .iter_states()
                .map(|state_index| incoming.incoming_transitions(state_index).count())
                .sum();
            assert_eq!(
                total_incoming,
                lts.num_of_transitions(),
                "The number of incoming transitions should equal the number of transitions"
            );

            // The silent transitions must be exactly the incoming transitions with label 0.
            for state_index in lts.iter_states() {
                let expected = incoming
                    .incoming_transitions(state_index)
                    .filter(|transition| transition.label == 0)
                    .count();
                let actual = incoming.incoming_silent_transitions(state_index).count();
                assert_eq!(
                    expected, actual,
                    "State {state_index} should yield all of its incoming silent transitions"
                );
            }
        });
    }
}