1
use std::borrow::Borrow;
2
use std::f64::consts::LOG10_2;
3
use std::hash::Hash;
4
use std::io::BufWriter;
5
use std::io::Seek;
6
use std::io::SeekFrom;
7
use std::io::Write;
8
use std::marker::PhantomData;
9

            
10
use merc_utilities::MercError;
11

            
12
use crate::LtsBuilder;
13
use crate::StateIndex;
14
use crate::TransitionLabel;
15

            
16
/// Streaming writer that emits transitions in the AUT text format.
///
/// Transitions are written as soon as they are added. The `des (...)` header
/// depends on totals that are only known at the end, so space for it is
/// reserved up front and overwritten later.
pub struct AutStream<W: Write, L> {
    /// Buffered sink receiving the header placeholder and the transition lines.
    writer: BufWriter<W>,

    /// Number of transitions written so far.
    number_of_transitions: usize,

    /// Number of states, derived from the largest state index seen so far.
    number_of_states: usize,

    /// Associates the label type `L` with the stream without storing a label.
    _marker: PhantomData<L>,
}

impl<W: Write, L> AutStream<W, L> {
    /// Creates a new AUT stream writer on top of `writer`.
    ///
    /// The writer is wrapped in a `BufWriter` internally. A run of spaces is
    /// emitted immediately as a placeholder for the header. It is wide enough
    /// for three `usize` values in decimal, so a longer final header can never
    /// overwrite transition bytes.
    ///
    /// # Panics
    ///
    /// Panics if writing the placeholder into the buffered writer fails.
    pub fn new(writer: W) -> Self {
        // Upper bound on the decimal digit count of any usize value.
        let digits = (usize::BITS as f64 * LOG10_2).ceil() as usize;
        // Widest header: "des (" + three numbers + two ", " separators + ")\n".
        let reserved = "des (".len() + 3 * digits + 2 * ", ".len() + ")\n".len();

        let mut buffered = BufWriter::new(writer);
        buffered.write_all(&vec![b' '; reserved]).unwrap();

        AutStream {
            writer: buffered,
            number_of_transitions: 0,
            number_of_states: 0,
            _marker: PhantomData,
        }
    }
}
50

            
51
impl<W: Write + Seek, L: TransitionLabel> LtsBuilder<L> for AutStream<W, L> {
52
    type LTS = ();
53

            
54
92762
    fn add_transition<Q>(&mut self, from: StateIndex, label: &Q, to: StateIndex) -> Result<(), MercError>
55
92762
    where
56
92762
        L: Borrow<Q>,
57
92762
        Q: ?Sized + ToOwned<Owned = L> + Eq + Hash,
58
    {
59
92762
        self.number_of_transitions += 1;
60
92762
        self.number_of_states = self.number_of_states.max(from.value() + 1).max(to.value() + 1);
61

            
62
92762
        writeln!(self.writer, "({}, \"{}\", {})", from, label.to_owned(), to)?;
63
92762
        Ok(())
64
92762
    }
65

            
66
100
    fn finish(&mut self, initial_state: StateIndex) -> Result<Self::LTS, MercError> {
67
        // Flush to ensure all buffered transitions are written
68
100
        self.writer.flush()?;
69

            
70
        // Seek to the start and overwrite the header
71
100
        self.writer.seek(SeekFrom::Start(0))?;
72
100
        writeln!(
73
100
            self.writer,
74
            "des ({}, {}, {})",
75
            initial_state, self.number_of_transitions, self.number_of_states
76
        )?;
77

            
78
        // Flush the updated header
79
100
        self.writer.flush()?;
80
100
        Ok(())
81
100
    }
82

            
83
    /// Returns the number of transitions added to the builder.
84
    fn num_of_transitions(&self) -> usize {
85
        self.number_of_transitions
86
    }
87

            
88
    /// Returns the number of states added to the builder.
89
    fn num_of_states(&self) -> usize {
90
        self.number_of_states
91
    }
92
}
93

            
94
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use merc_utilities::random_test;

    use crate::AutStream;
    use crate::LTS;
    use crate::LtsBuilder;
    use crate::random_lts_monolithic;
    use crate::read_aut;

    /// Round-trips random LTSs through `AutStream` and `read_aut`, then checks
    /// that the parsed LTS is equivalent to the generated one.
    #[test]
    #[cfg_attr(miri, ignore)] // Test is too slow under Miri
    fn test_random_aut_stream_io() {
        random_test(100, |rng| {
            let original = random_lts_monolithic::<String>(rng, 100, 3, 20);

            let mut output = Cursor::new(Vec::new());
            let mut stream = AutStream::new(&mut output);

            // Stream every outgoing transition of every state.
            for source in original.iter_states() {
                for transition in original.outgoing_transitions(source) {
                    stream
                        .add_transition(source, &original.labels()[transition.label.value()], transition.to)
                        .unwrap();
                }
            }
            stream.finish(original.initial_state_index()).unwrap();

            // Release the mutable borrow of `output` before reading it back.
            drop(stream);

            output.set_position(0);
            let parsed = read_aut(&mut output, vec![]).unwrap();

            crate::check_equivalent(&original, &parsed);
        })
    }
}