1
use std::borrow::Borrow;
2
use std::f64::consts::LOG10_2;
3
use std::hash::Hash;
4
use std::io::BufWriter;
5
use std::io::Seek;
6
use std::io::SeekFrom;
7
use std::io::Write;
8
use std::marker::PhantomData;
9

            
10
use merc_utilities::MercError;
11

            
12
use crate::LtsBuilder;
13
use crate::StateIndex;
14
use crate::TransitionLabel;
15

            
16
/// Streaming writer for the AUT format.
///
/// Transitions are written out one at a time, so an LTS can be stored in a
/// file without ever holding the complete LTS in memory. Because the number
/// of states and transitions is only known once every transition has been
/// added, the header of the AUT file is written last.
pub struct AutStream<W: Write, L> {
    /// Buffered sink that receives the AUT output.
    writer: BufWriter<W>,

    /// Running count of the transitions added so far.
    number_of_transitions: usize,

    /// Running count of the states added so far.
    number_of_states: usize,

    /// Binds the stream to its label type `L` without storing a label.
    _marker: PhantomData<L>,
}
31

            
32
impl<W: Write, L> AutStream<W, L> {
33
    /// Creates a new AUT stream writer.
34
    ///
35
    /// Note that the writer is buffered internally using a `BufWriter`.
36
100
    pub fn new(writer: W) -> Self {
37
100
        let mut writer = BufWriter::new(writer);
38
        // Write a placeholder for the header, which will be filled in later.
39
        // Reserve enough space for the header using the number of bits of
40
        // usize. This avoids overwriting transition bytes when the final header
41
        // is longer.
42
100
        let max_usize_digits = (usize::BITS as f64 * LOG10_2).ceil() as usize;
43
100
        let header_len = format!("des ({0:<1$}, {0:<1$}, {0:<1$})\n", " ", max_usize_digits).len();
44
100
        writer.write_all(" ".repeat(header_len).as_bytes()).unwrap();
45

            
46
100
        Self {
47
100
            writer,
48
100
            number_of_transitions: 0,
49
100
            number_of_states: 0,
50
100
            _marker: PhantomData,
51
100
        }
52
100
    }
53

            
54
    /// Sets the number of states to at least the given number. All states without transitions
55
    /// will simply become deadlock states.
56
100
    pub fn require_num_of_states(&mut self, num_states: usize) {
57
100
        if num_states > self.number_of_states {
58
12
            self.number_of_states = num_states;
59
88
        }
60
100
    }
61
}
62

            
63
impl<W: Write + Seek, L: TransitionLabel> LtsBuilder<L> for AutStream<W, L> {
64
    type LTS = ();
65

            
66
99938
    fn add_transition<Q>(&mut self, from: StateIndex, label: &Q, to: StateIndex) -> Result<(), MercError>
67
99938
    where
68
99938
        L: Borrow<Q>,
69
99938
        Q: ?Sized + ToOwned<Owned = L> + Eq + Hash,
70
    {
71
99938
        self.number_of_transitions += 1;
72
99938
        self.number_of_states = self.number_of_states.max(from.value() + 1).max(to.value() + 1);
73

            
74
99938
        writeln!(self.writer, "({}, \"{}\", {})", from, label.to_owned(), to)?;
75
99938
        Ok(())
76
99938
    }
77

            
78
100
    fn finish(&mut self, initial_state: StateIndex) -> Result<Self::LTS, MercError> {
79
        // Flush to ensure all buffered transitions are written
80
100
        self.writer.flush()?;
81

            
82
        // Seek to the start and overwrite the header
83
100
        self.writer.seek(SeekFrom::Start(0))?;
84
100
        writeln!(
85
100
            self.writer,
86
            "des ({}, {}, {})",
87
            initial_state, self.number_of_transitions, self.number_of_states
88
        )?;
89

            
90
        // Flush the updated header
91
100
        self.writer.flush()?;
92
100
        Ok(())
93
100
    }
94

            
95
    /// Returns the number of transitions added to the builder.
96
    fn num_of_transitions(&self) -> usize {
97
        self.number_of_transitions
98
    }
99

            
100
    /// Returns the number of states added to the builder.
101
    fn num_of_states(&self) -> usize {
102
        self.number_of_states
103
    }
104
}
105

            
106
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use merc_utilities::random_test;

    use crate::AutStream;
    use crate::LTS;
    use crate::LtsBuilder;
    use crate::random_lts;
    use crate::read_aut;

    /// Streams randomly generated LTSs into an in-memory buffer, reads them
    /// back with `read_aut` and checks that the result matches the input.
    #[test]
    #[cfg_attr(miri, ignore)] // Test is too slow under Miri
    fn test_random_aut_stream_io() {
        random_test(100, |rng| {
            let lts = random_lts::<String, _>(rng, 1000, 3);

            let mut buffer = Cursor::new(Vec::new());
            {
                // The stream borrows the buffer for the duration of this scope.
                let mut stream = AutStream::new(&mut buffer);

                // Feed every outgoing transition of every state to the stream.
                for source in lts.iter_states() {
                    for transition in lts.outgoing_transitions(source) {
                        let label = &lts.labels()[transition.label.value()];
                        stream.add_transition(source, label, transition.to).unwrap();
                    }
                }

                // States without transitions must be counted as well.
                stream.require_num_of_states(lts.num_of_states());
                stream.finish(lts.initial_state_index()).unwrap();
            }

            // Read the written output back from the very beginning.
            buffer.set_position(0);
            let result_lts = read_aut(&mut buffer).unwrap();

            crate::check_equivalent(&lts, &result_lts);
        })
    }
}