1
use merc_utilities::MercError;
2

            
3
use crate::RegFrm;
4
use crate::SortExpression;
5
use crate::StateFrm;
6

            
7
/// Applies the given function recursively to the state formula.
8
///
9
/// The substitution `function` takes a state formula and returns an optional new
10
/// formula. If it returns `Some(new_formula)`, the substitution is applied and
11
/// the new formula is returned. If it returns `None`, the substitution is not
12
/// applied and the function continues to traverse the formula tree.
13
4689
pub fn apply_statefrm<F>(formula: StateFrm, mut function: F) -> Result<StateFrm, MercError>
14
4689
where
15
4689
    F: FnMut(&StateFrm) -> Result<Option<StateFrm>, MercError>,
16
{
17
4689
    apply_statefrm_rec(formula, &mut function)
18
4689
}
19

            
20
pub fn apply_sort_expression<F>(sort_expr: SortExpression, mut function: F) -> Result<SortExpression, MercError>
21
where
22
    F: FnMut(&SortExpression) -> Result<Option<SortExpression>, MercError>,
23
{
24
    apply_sort_expression_rec(sort_expr, &mut function)
25
}
26

            
27
/// Applies the given `function` recursively to the regular formula.
28
///
29
/// # Details
30
///
31
/// The substitution function is a partial function, where `Some(formula)`
32
/// indicates that substitution should be applied.
33
pub fn apply_regular_formula<F>(formula: RegFrm, mut function: F) -> Result<RegFrm, MercError>
34
where
35
    F: FnMut(&RegFrm) -> Result<Option<RegFrm>, MercError>,
36
{
37
    apply_regular_formula_rec(formula, &mut function)
38
}
39

            
40
/// See [apply_regular_formula].
41
fn apply_regular_formula_rec<F>(formula: RegFrm, apply: &mut F) -> Result<RegFrm, MercError>
42
where
43
    F: FnMut(&RegFrm) -> Result<Option<RegFrm>, MercError>,
44
{
45
    if let Some(formula) = apply(&formula)? {
46
        // A substitution was made, return the new formula.
47
        return Ok(formula);
48
    }
49

            
50
    match formula {
51
        RegFrm::Iteration(reg_frm) => {
52
            let new_reg_frm = apply_regular_formula_rec(*reg_frm, apply)?;
53
            Ok(RegFrm::Iteration(Box::new(new_reg_frm)))
54
        }
55
        RegFrm::Plus(reg_frm) => {
56
            let new_reg_frm = apply_regular_formula_rec(*reg_frm, apply)?;
57
            Ok(RegFrm::Plus(Box::new(new_reg_frm)))
58
        }
59
        RegFrm::Sequence { lhs, rhs } => {
60
            let new_lhs = apply_regular_formula_rec(*lhs, apply)?;
61
            let new_rhs = apply_regular_formula_rec(*rhs, apply)?;
62
            Ok(RegFrm::Sequence {
63
                lhs: Box::new(new_lhs),
64
                rhs: Box::new(new_rhs),
65
            })
66
        }
67
        RegFrm::Choice { lhs, rhs } => {
68
            let new_lhs = apply_regular_formula_rec(*lhs, apply)?;
69
            let new_rhs = apply_regular_formula_rec(*rhs, apply)?;
70
            Ok(RegFrm::Choice {
71
                lhs: Box::new(new_lhs),
72
                rhs: Box::new(new_rhs),
73
            })
74
        }
75
        _ => Ok(formula),
76
    }
77
}
78

            
79
/// See [`apply_statefrm`].
80
25545
fn apply_statefrm_rec<F>(formula: StateFrm, apply: &mut F) -> Result<StateFrm, MercError>
81
25545
where
82
25545
    F: FnMut(&StateFrm) -> Result<Option<StateFrm>, MercError>,
83
{
84
25545
    if let Some(formula) = apply(&formula)? {
85
        // A substitution was made, return the new formula.
86
3654
        return Ok(formula);
87
21891
    }
88

            
89
21891
    match formula {
90
4949
        StateFrm::Binary { op, lhs, rhs } => {
91
4949
            let new_lhs = apply_statefrm_rec(*lhs, apply)?;
92
4949
            let new_rhs = apply_statefrm_rec(*rhs, apply)?;
93
4949
            Ok(StateFrm::Binary {
94
4949
                op,
95
4949
                lhs: Box::new(new_lhs),
96
4949
                rhs: Box::new(new_rhs),
97
4949
            })
98
        }
99
        StateFrm::FixedPoint {
100
2028
            operator,
101
2028
            variable,
102
2028
            body,
103
        } => {
104
2028
            let new_body = apply_statefrm_rec(*body, apply)?;
105
2028
            Ok(StateFrm::FixedPoint {
106
2028
                operator,
107
2028
                variable,
108
2028
                body: Box::new(new_body),
109
2028
            })
110
        }
111
        StateFrm::Bound { bound, variables, body } => {
112
            let new_body = apply_statefrm_rec(*body, apply)?;
113
            Ok(StateFrm::Bound {
114
                bound,
115
                variables,
116
                body: Box::new(new_body),
117
            })
118
        }
119
        StateFrm::Modality {
120
8930
            operator,
121
8930
            formula,
122
8930
            expr,
123
        } => {
124
8930
            let expr = apply_statefrm_rec(*expr, apply)?;
125
8930
            Ok(StateFrm::Modality {
126
8930
                operator,
127
8930
                formula,
128
8930
                expr: Box::new(expr),
129
8930
            })
130
        }
131
        StateFrm::Quantifier {
132
            quantifier,
133
            variables,
134
            body,
135
        } => {
136
            let new_body = apply_statefrm_rec(*body, apply)?;
137
            Ok(StateFrm::Quantifier {
138
                quantifier,
139
                variables,
140
                body: Box::new(new_body),
141
            })
142
        }
143
        StateFrm::DataValExprRightMult(expr, data_val) => {
144
            let new_expr = apply_statefrm_rec(*expr, apply)?;
145
            Ok(StateFrm::DataValExprRightMult(Box::new(new_expr), data_val))
146
        }
147
        StateFrm::DataValExprLeftMult(data_val, expr) => {
148
            let new_expr = apply_statefrm_rec(*expr, apply)?;
149
            Ok(StateFrm::DataValExprLeftMult(data_val, Box::new(new_expr)))
150
        }
151
        StateFrm::Unary { op, expr } => {
152
            let new_expr = apply_statefrm_rec(*expr, apply)?;
153
            Ok(StateFrm::Unary {
154
                op,
155
                expr: Box::new(new_expr),
156
            })
157
        }
158
        StateFrm::Id(_, _)
159
        | StateFrm::True
160
        | StateFrm::False
161
        | StateFrm::Delay(_)
162
        | StateFrm::Yaled(_)
163
5984
        | StateFrm::DataValExpr(_) => Ok(formula),
164
    }
165
25545
}
166

            
167
fn apply_sort_expression_rec<F>(sort_expr: SortExpression, apply: &mut F) -> Result<SortExpression, MercError>
168
where
169
    F: FnMut(&SortExpression) -> Result<Option<SortExpression>, MercError>,
170
{
171
    if let Some(sort_expr) = apply(&sort_expr)? {
172
        // A substitution was made, return the new sort expression.
173
        return Ok(sort_expr);
174
    }
175

            
176
    match sort_expr {
177
        SortExpression::Product { lhs, rhs } => {
178
            let lhs = apply_sort_expression_rec(*lhs, apply)?;
179
            let rhs = apply_sort_expression_rec(*rhs, apply)?;
180
            Ok(SortExpression::Product {
181
                lhs: Box::new(lhs),
182
                rhs: Box::new(rhs),
183
            })
184
        }
185
        SortExpression::Function { domain, range } => {
186
            let domain = apply_sort_expression_rec(*domain, apply)?;
187
            let range = apply_sort_expression_rec(*range, apply)?;
188
            Ok(SortExpression::Function {
189
                domain: Box::new(domain),
190
                range: Box::new(range),
191
            })
192
        }
193
        SortExpression::Struct { mut inner } => {
194
            for decl in &mut inner {
195
                for (_, sort) in &mut decl.args {
196
                    *sort = apply_sort_expression_rec(sort.clone(), apply)?;
197
                }
198
            }
199

            
200
            Ok(SortExpression::Struct { inner })
201
        }
202
        SortExpression::Complex(complex_sort, sort_expression) => {
203
            let inner = apply_sort_expression_rec(*sort_expression, apply)?;
204
            Ok(SortExpression::Complex(complex_sort, Box::new(inner)))
205
        }
206
        SortExpression::Reference(_) | SortExpression::Simple(_) => {
207
            // Ignored
208
            Ok(sort_expr)
209
        }
210
    }
211
}
212

            
213
#[cfg(test)]
mod tests {
    use crate::UntypedStateFrmSpec;

    use super::*;

    /// Runs a traversal that never substitutes and records every `StateFrm::Id` it
    /// reaches. Repeated identifiers must be reported once per occurrence, in the order
    /// the traversal visits them.
    #[test]
    fn test_visit_state_frm_variables() {
        let spec = UntypedStateFrmSpec::parse("mu X. [a]X && mu X. X && Y").unwrap();

        let mut seen = Vec::new();
        apply_statefrm(spec.formula, |frm| {
            if let StateFrm::Id(identifier, _) = frm {
                seen.push(identifier.clone());
            }
            // Never substitute, so the whole formula is traversed.
            Ok(None)
        })
        .unwrap();

        assert_eq!(seen, vec!["X", "X", "Y"]);
    }
}