1
use std::io::BufWriter;
2
use std::io::Read;
3
use std::io::Write;
4

            
5
use log::info;
6
use regex::Regex;
7
use streaming_iterator::StreamingIterator;
8
use thiserror::Error;
9

            
10
use merc_io::LineIterator;
11
use merc_io::TimeProgress;
12
use merc_utilities::MercError;
13
use merc_utilities::debug_trace;
14

            
15
use crate::LTS;
16
use crate::LabelledTransitionSystem;
17
use crate::LtsBuilder;
18
use crate::StateIndex;
19
use crate::TransitionLabel;
20

            
21
#[derive(Error, Debug)]
22
pub enum IOError {
23
    #[error("Invalid .aut header {0}")]
24
    InvalidHeader(&'static str),
25

            
26
    #[error("Invalid transition {0}")]
27
    InvalidTransition(String),
28
}
29

            
30
/// Loads a labelled transition system in the [Aldebaran
31
/// format](https://cadp.inria.fr/man/aldebaran.html) from the given reader.
32
///
33
/// Note that the reader has a buffer in the form of  `BufReader`` internally.
34
///
35
/// The Aldebaran format consists of a header: `des (<initial>: Nat,
36
///     <num_of_transitions>: Nat, <num_of_states>: Nat)`
37
///     
38
/// And one line for every transition either one of these cases:
39
///  `(<from>: Nat, "<label>": Str, <to>: Nat)`
40
///  `(<from>: Nat, <label>: Str, <to>: Nat)`
41
108
pub fn read_aut(reader: impl Read, hidden_labels: Vec<String>) -> Result<LabelledTransitionSystem<String>, MercError> {
42
108
    info!("Reading LTS in .aut format...");
43

            
44
108
    let mut lines = LineIterator::new(reader);
45
108
    lines.advance();
46
108
    let header = lines
47
108
        .get()
48
108
        .ok_or(IOError::InvalidHeader("The first line should be the header"))?;
49

            
50
    // Regex for des (<initial>: Nat, <num_of_states>: Nat, <num_of_transitions>: Nat)
51
108
    let header_regex = Regex::new(r#"des\s*\(\s*([0-9]*)\s*,\s*([0-9]*)\s*,\s*([0-9]*)\s*\)\s*"#)
52
108
        .expect("Regex compilation should not fail");
53

            
54
108
    let (_, [initial_txt, num_of_transitions_txt, num_of_states_txt]) = header_regex
55
108
        .captures(header)
56
108
        .ok_or(IOError::InvalidHeader(
57
108
            "does not match des (<init>, <num_of_transitions>, <num_of_states>)",
58
108
        ))?
59
106
        .extract();
60

            
61
106
    let initial_state = StateIndex::new(initial_txt.parse()?);
62
106
    let num_of_transitions: usize = num_of_transitions_txt.parse()?;
63
106
    let num_of_states: usize = num_of_states_txt.parse()?;
64

            
65
106
    let mut builder = LtsBuilder::with_capacity(Vec::new(), hidden_labels, num_of_states, 16, num_of_transitions);
66
106
    let progress = TimeProgress::new(|percentage: usize| info!("Reading transitions {}%...", percentage), 1);
67

            
68
94663
    while let Some(line) = lines.next() {
69
94557
        let (from_txt, label_txt, to_txt) =
70
94557
            read_transition(line).ok_or_else(|| IOError::InvalidTransition(line.clone()))?;
71

            
72
        // Parse the from and to states, with the given label.
73
94557
        let from = StateIndex::new(from_txt.parse()?);
74
94557
        let to = StateIndex::new(to_txt.parse()?);
75

            
76
94557
        debug_trace!("Read transition {from} --[{label_txt}]-> {to}");
77

            
78
94557
        builder.add_transition(from, label_txt, to);
79

            
80
94557
        progress.print(builder.num_of_transitions() * 100 / num_of_transitions);
81
    }
82

            
83
106
    info!("Finished reading LTS");
84

            
85
106
    Ok(builder.finish(initial_state))
86
108
}
87

            
88
/// Write a labelled transition system in plain text in Aldebaran format to the
89
/// given writer, see [read_aut].
90
///
91
/// Note that the writer is buffered internally using a `BufWriter`.
92
101
pub fn write_aut(writer: &mut impl Write, lts: &impl LTS) -> Result<(), MercError> {
93
101
    let mut writer = BufWriter::new(writer);
94
101
    writeln!(
95
101
        writer,
96
        "des ({}, {}, {})",
97
101
        lts.initial_state_index(),
98
101
        lts.num_of_transitions(),
99
101
        lts.num_of_states()
100
    )?;
101

            
102
101
    let progress = TimeProgress::new(|percentage: usize| info!("Writing transitions {}%...", percentage), 1);
103
101
    let mut transitions_written = 0usize;
104
10074
    for state_index in lts.iter_states() {
105
92902
        for transition in lts.outgoing_transitions(state_index) {
106
92902
            writeln!(
107
92902
                writer,
108
                "({}, \"{}\", {})",
109
                state_index,
110
92902
                lts.labels()[transition.label.value()],
111
                transition.to
112
            )?;
113

            
114
92902
            progress.print(transitions_written * 100 / lts.num_of_transitions());
115
92902
            transitions_written += 1;
116
        }
117
    }
118

            
119
101
    Ok(())
120
101
}
121

            
122
/// Dedicated function to parse a single transition line.
///
/// # Details
///
/// One of the following formats:
///     `(<from>: Nat, "<label>": Str, <to>: Nat)`
///     `(<from>: Nat, <label>: Str, <to>: Nat)`
///
/// The label is everything between the first and the last comma, so quoted
/// labels may themselves contain commas and parentheses. One pair of
/// surrounding double quotes is stripped from the label, if present. Returns
/// `None` when the line does not have this shape.
///
/// This was generally faster than the regex variant, since the regex has to
/// backtrack.
fn read_transition(input: &str) -> Option<(&str, &str, &str)> {
    let start_paren = input.find('(')?;
    let first_comma = input.find(',')?;

    // The last comma separates the label from the target state.
    let last_comma = input.rfind(',')?;
    let end_paren = input.rfind(')')?;

    // `get` returns `None` for out-of-order delimiters (e.g. only one comma), so
    // malformed lines are rejected instead of causing a panic.
    let from = input.get(start_paren + 1..first_comma)?.trim();
    let label = input.get(first_comma + 1..last_comma)?.trim();
    let to = input.get(last_comma + 1..end_paren)?.trim();

    // Handle the special case where the label is quoted. A lone `"` is not a quoted
    // label, so it is returned unchanged rather than being sliced out of bounds.
    let label = label
        .strip_prefix('"')
        .and_then(|inner| inner.strip_suffix('"'))
        .unwrap_or(label);

    Some((from, label, to))
}
149

            
150
/// A trait for labels that can be used in transitions.
151
impl TransitionLabel for String {
152
113700
    fn is_tau_label(&self) -> bool {
153
113700
        self == "tau"
154
113700
    }
155

            
156
30120
    fn tau_label() -> Self {
157
30120
        "tau".to_string()
158
30120
    }
159

            
160
    fn matches_label(&self, label: &str) -> bool {
161
        self == label
162
    }
163

            
164
44300
    fn from_index(i: usize) -> Self {
165
44300
        char::from_digit(i as u32, 36)
166
44300
            .expect("Radix is less than 37, so should not panic")
167
44300
            .to_string()
168
44300
    }
169
}
170

            
171
#[cfg(test)]
mod tests {
    use crate::random_lts_monolithic;

    use super::*;

    use merc_utilities::random_test;
    use test_log::test;

    #[test]
    fn test_reading_aut() {
        let file = include_str!("../../../examples/lts/abp.aut");

        let lts = read_aut(file.as_bytes(), vec![]).unwrap();

        assert_eq!(lts.initial_state_index().value(), 0);
        assert_eq!(lts.num_of_transitions(), 92);
    }

    #[test]
    fn test_lts_failure() {
        let wrong_header = "
        des (0,2,
            (0,\"r1(d1)\",1)
            (0,\"r1(d2)\",2)
        ";

        // `assert!` rather than `debug_assert!`: the latter is compiled out in
        // release builds, which would silently turn this test into a no-op.
        assert!(read_aut(wrong_header.as_bytes(), vec![]).is_err());

        let wrong_transition = "
        des (0,2,3)
            (0,\"r1(d1),1)
            (0,\"r1(d2)\",2)
        ";

        assert!(read_aut(wrong_transition.as_bytes(), vec![]).is_err());
    }

    #[test]
    fn test_traversal_lts() {
        let file = include_str!("../../../examples/lts/abp.aut");

        let lts = read_aut(file.as_bytes(), vec![]).unwrap();

        // Check the number of outgoing transitions of the initial state
        assert_eq!(lts.outgoing_transitions(lts.initial_state_index()).count(), 2);
    }

    #[test]
    fn test_writing_lts() {
        let file = include_str!("../../../examples/lts/abp.aut");
        let lts_original = read_aut(file.as_bytes(), vec![]).unwrap();

        // Check that it can be read after writing, and results in the same LTS.
        let mut buffer: Vec<u8> = Vec::new();
        write_aut(&mut buffer, &lts_original).unwrap();

        let lts = read_aut(&buffer[0..], vec![]).unwrap();

        assert_eq!(lts.num_of_states(), lts_original.num_of_states());
        assert_eq!(lts.num_of_labels(), lts_original.num_of_labels());
        assert_eq!(lts.num_of_transitions(), lts_original.num_of_transitions());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_random_aut_io() {
        random_test(100, |rng| {
            let lts = random_lts_monolithic::<String>(rng, 100, 3, 20);

            let mut buffer: Vec<u8> = Vec::new();
            write_aut(&mut buffer, &lts).unwrap();

            let lts_read = read_aut(&buffer[0..], vec![]).unwrap();

            println!("LTS labels: {:?}", lts.labels());
            println!("Read LTS labels: {:?}", lts_read.labels());

            // If labels are not used, the number of labels may be less. So find a remapping of old labels to new labels.
            let mapping = lts
                .labels()
                .iter()
                .map(|label| lts_read.labels().iter().position(|l| l == label))
                .collect::<Vec<_>>();

            // Print the mapping
            for (i, m) in mapping.iter().enumerate() {
                println!("Label {} mapped to {:?}", i, m);
            }

            assert_eq!(lts.num_of_states(), lts_read.num_of_states());
            assert_eq!(lts.num_of_transitions(), lts_read.num_of_transitions());

            // Check that all the outgoing transitions are the same.
            for state_index in lts.iter_states() {
                let transitions: Vec<_> = lts.outgoing_transitions(state_index).collect();
                let transitions_read: Vec<_> = lts_read.outgoing_transitions(state_index).collect();

                // Check that transitions are the same, modulo label remapping.
                transitions.iter().for_each(|t| {
                    // Build the panic message lazily instead of formatting it for every transition.
                    let mapped_label = mapping[t.label.value()]
                        .unwrap_or_else(|| panic!("Label {} should be found", t.label));
                    assert!(
                        transitions_read
                            .iter()
                            .any(|tr| tr.to == t.to && tr.label.value() == mapped_label)
                    );
                });
            }
        })
    }
}