1
#![forbid(unsafe_code)]
2

            
3
use std::marker::PhantomData;
4

            
5
use merc_collections::ByteCompressedVec;
6
use merc_collections::bytevec;
7

            
8
use crate::LTS;
9
use crate::LabelIndex;
10
use crate::StateIndex;
11

            
12
/// Stores the incoming transitions for a given labelled transition system.
13
pub struct IncomingTransitions<'a> {
14
    /// A flat list of all incoming transition labels in the LTS. They are stored in two separate
15
    /// arrays since the compression is based on the highest value.
16
    transition_labels: ByteCompressedVec<LabelIndex>,
17
    transition_from: ByteCompressedVec<StateIndex>,
18

            
19
    /// A mapping from the state to the `transition_labels` and
20
    /// `transition_from` that stores its incoming transitions.
21
    state2incoming: ByteCompressedVec<usize>,
22

            
23
    /// Marker to tie the lifetime of the incoming transitions to the LTS.
24
    _marker: PhantomData<&'a ()>,
25
}
26

            
27
impl<'a> IncomingTransitions<'a> {
28
2000
    pub fn new<L: LTS>(lts: &'a L) -> Self {
29
2000
        let mut transition_labels = bytevec![LabelIndex::new(0); lts.num_of_transitions()];
30
2000
        let mut transition_from = bytevec![StateIndex::new(0); lts.num_of_transitions()];
31
2000
        let mut state2incoming = bytevec![0usize; lts.num_of_states()];
32

            
33
        // Count the number of incoming transitions for each state
34
626397
        for state_index in lts.iter_states() {
35
659495
            for transition in lts.outgoing_transitions(state_index) {
36
659495
                state2incoming.update(transition.to.value(), |start| *start += 1);
37
            }
38
        }
39

            
40
        // Compute the start offsets (prefix sum)
41
626397
        state2incoming.fold(0, |offset, start| {
42
626397
            let new_offset = offset + *start;
43
626397
            *start = offset;
44
626397
            new_offset
45
626397
        });
46

            
47
        // Place the transitions
48
626397
        for state_index in lts.iter_states() {
49
659495
            for transition in lts.outgoing_transitions(state_index) {
50
659495
                state2incoming.update(transition.to.value(), |start| {
51
659495
                    transition_labels.set(*start, transition.label);
52
659495
                    transition_from.set(*start, state_index);
53
659495
                    *start += 1;
54
659495
                });
55
            }
56
        }
57

            
58
626397
        state2incoming.fold(0, |previous, start| {
59
626397
            let result = *start;
60
626397
            *start = previous;
61
626397
            result
62
626397
        });
63

            
64
        // Add sentinel state
65
2000
        state2incoming.push(transition_labels.len());
66

            
67
        // Sort the incoming transitions such that silent transitions come first.
68
        //
69
        // TODO: This could be more efficient by simply grouping them instead of sorting, perhaps some group using a predicate.
70
2000
        let mut pairs = Vec::new();
71
626397
        for state_index in 0..lts.num_of_states() {
72
626397
            let start = state2incoming.index(state_index);
73
626397
            let end = state2incoming.index(state_index + 1);
74

            
75
            // Extract, sort, and put back
76
626397
            pairs.clear();
77
659495
            pairs.extend((start..end).map(|i| (transition_labels.index(i), transition_from.index(i))));
78
626397
            pairs.sort_unstable_by_key(|(label, _)| *label);
79

            
80
659495
            for (i, (label, from)) in pairs.iter().enumerate() {
81
659495
                transition_labels.set(start + i, *label);
82
659495
                transition_from.set(start + i, *from);
83
659495
            }
84
        }
85

            
86
2000
        Self {
87
2000
            transition_labels,
88
2000
            transition_from,
89
2000
            state2incoming,
90
2000
            _marker: PhantomData,
91
2000
        }
92
2000
    }
93

            
94
    /// Returns an iterator over the incoming transitions for the given state.
95
3108937
    pub fn incoming_transitions(&self, state_index: StateIndex) -> impl Iterator<Item = FromTransition> + '_ {
96
3108937
        let start = self.state2incoming.index(state_index.value());
97
3108937
        let end = self.state2incoming.index(state_index.value() + 1);
98
3113541
        (start..end).map(move |i| FromTransition::new(self.transition_labels.index(i), self.transition_from.index(i)))
99
3108937
    }
100

            
101
    // Return an iterator over the incoming silent transitions for the given state.
102
1509576
    pub fn incoming_silent_transitions(&self, state_index: StateIndex) -> impl Iterator<Item = FromTransition> + '_ {
103
1509576
        let start = self.state2incoming.index(state_index.value());
104
1509576
        let end = self.state2incoming.index(state_index.value() + 1);
105
1509576
        (start..end)
106
1509576
            .map(move |i| FromTransition::new(self.transition_labels.index(i), self.transition_from.index(i)))
107
1509576
            .take_while(|transition| transition.label == 0)
108
1509576
    }
109
}
110

            
111
/// Represents an incoming transition in the LTS going to a known state.
112
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
113
pub struct FromTransition {
114
    pub label: LabelIndex,
115
    pub from: StateIndex,
116
}
117

            
118
impl FromTransition {
119
    /// Constructs a new transition.
120
4846321
    pub fn new(label: LabelIndex, from: StateIndex) -> Self {
121
4846321
        Self { label, from }
122
4846321
    }
123
}
124

            
125
#[cfg(test)]
mod tests {
    use merc_io::DumpFiles;
    use merc_utilities::random_test;

    use crate::IncomingTransitions;
    use crate::LTS;
    use crate::random_lts;
    use crate::write_aut;

    /// Checks on random LTSs that the incoming transitions are exactly the
    /// reversed outgoing transitions, and that the silent subset is reported
    /// completely.
    #[test]
    fn test_random_incoming_transitions() {
        random_test(100, |rng| {
            let mut files = DumpFiles::new("test_random_incoming_transitions");

            let lts = random_lts(rng, 10, 3, 3);
            files.dump("input.aut", |f| write_aut(f, &lts)).unwrap();
            let incoming = IncomingTransitions::new(&lts);

            // Check that for every outgoing transition there is an incoming transition.
            for state_index in lts.iter_states() {
                for transition in lts.outgoing_transitions(state_index) {
                    let found = incoming
                        .incoming_transitions(transition.to)
                        .any(|incoming| incoming.label == transition.label && incoming.from == state_index);
                    assert!(
                        found,
                        "Outgoing transition ({state_index}, {transition:?}) should have an incoming transition"
                    );
                }
            }

            // Check that all incoming transitions belong to some outgoing transition.
            for state_index in lts.iter_states() {
                for transition in incoming.incoming_transitions(state_index) {
                    let found = lts
                        .outgoing_transitions(transition.from)
                        .any(|outgoing| outgoing.label == transition.label && outgoing.to == state_index);
                    assert!(
                        found,
                        "Incoming transition ({transition:?}, {state_index}) should have an outgoing transition"
                    );
                }
            }

            // Check that the silent iterator yields exactly the incoming transitions with label 0.
            for state_index in lts.iter_states() {
                let expected = incoming
                    .incoming_transitions(state_index)
                    .filter(|transition| transition.label == 0)
                    .count();

                let mut actual = 0;
                for transition in incoming.incoming_silent_transitions(state_index) {
                    assert!(
                        transition.label == 0,
                        "Silent iterator of state {state_index} yielded non-silent transition {transition:?}"
                    );
                    actual += 1;
                }

                assert_eq!(
                    actual, expected,
                    "State {state_index} should report all of its incoming silent transitions"
                );
            }
        });
    }
}