1
use merc_utilities::MercError;
2

            
3
use crate::StateFrm;
4

            
5
/// Applies the given function recursively to the state formula.
6
///
7
/// The substitution function takes a state formula and returns an optional new
8
/// formula. If it returns `Some(new_formula)`, the substitution is applied and
9
/// the new formula is returned. If it returns `None`, the substitution is not
10
/// applied and the function continues to traverse the formula tree.
11
10
pub fn apply_statefrm(
12
10
    formula: StateFrm,
13
10
    mut function: impl FnMut(&StateFrm) -> Result<Option<StateFrm>, MercError>,
14
10
) -> Result<StateFrm, MercError> {
15
10
    apply_statefrm_rec(formula, &mut function)
16
10
}
17

            
18
/// Visits the state formula and calls the given function on each subformula.
19
///
20
/// The substitution function takes a state formula and returns an optional new
21
/// formula. If it returns `Some(new_formula)`, the substitution is applied and
22
/// the new formula is returned. If it returns `None`, the substitution is not
23
/// applied and the function continues to traverse the formula tree.
24
4
pub fn visit_statefrm(
25
4
    formula: &StateFrm,
26
4
    mut visitor: impl FnMut(&StateFrm) -> Result<(), MercError>,
27
4
) -> Result<(), MercError> {
28
4
    visit_statefrm_rec(formula, &mut visitor)
29
4
}
30

            
31
/// See [`apply`].
32
42
fn apply_statefrm_rec(
33
42
    formula: StateFrm,
34
42
    apply: &mut impl FnMut(&StateFrm) -> Result<Option<StateFrm>, MercError>,
35
42
) -> Result<StateFrm, MercError> {
36
42
    if let Some(formula) = apply(&formula)? {
37
        // A substitution was made, return the new formula.
38
5
        return Ok(formula);
39
37
    }
40

            
41
37
    match formula {
42
9
        StateFrm::Binary { op, lhs, rhs } => {
43
9
            let new_lhs = apply_statefrm_rec(*lhs, apply)?;
44
9
            let new_rhs = apply_statefrm_rec(*rhs, apply)?;
45
9
            Ok(StateFrm::Binary {
46
9
                op,
47
9
                lhs: Box::new(new_lhs),
48
9
                rhs: Box::new(new_rhs),
49
9
            })
50
        }
51
        StateFrm::FixedPoint {
52
2
            operator,
53
2
            variable,
54
2
            body,
55
        } => {
56
2
            let new_body = apply_statefrm_rec(*body, apply)?;
57
2
            Ok(StateFrm::FixedPoint {
58
2
                operator,
59
2
                variable,
60
2
                body: Box::new(new_body),
61
2
            })
62
        }
63
        StateFrm::Bound { bound, variables, body } => {
64
            let new_body = apply_statefrm_rec(*body, apply)?;
65
            Ok(StateFrm::Bound {
66
                bound,
67
                variables,
68
                body: Box::new(new_body),
69
            })
70
        }
71
        StateFrm::Modality {
72
12
            operator,
73
12
            formula,
74
12
            expr,
75
        } => {
76
12
            let expr = apply_statefrm_rec(*expr, apply)?;
77
12
            Ok(StateFrm::Modality {
78
12
                operator,
79
12
                formula,
80
12
                expr: Box::new(expr),
81
12
            })
82
        }
83
        StateFrm::Quantifier {
84
            quantifier,
85
            variables,
86
            body,
87
        } => {
88
            let new_body = apply_statefrm_rec(*body, apply)?;
89
            Ok(StateFrm::Quantifier {
90
                quantifier,
91
                variables,
92
                body: Box::new(new_body),
93
            })
94
        }
95
        StateFrm::DataValExprRightMult(expr, data_val) => {
96
            let new_expr = apply_statefrm_rec(*expr, apply)?;
97
            Ok(StateFrm::DataValExprRightMult(Box::new(new_expr), data_val))
98
        }
99
        StateFrm::DataValExprLeftMult(data_val, expr) => {
100
            let new_expr = apply_statefrm_rec(*expr, apply)?;
101
            Ok(StateFrm::DataValExprLeftMult(data_val, Box::new(new_expr)))
102
        }
103
        StateFrm::Unary { op, expr } => {
104
            let new_expr = apply_statefrm_rec(*expr, apply)?;
105
            Ok(StateFrm::Unary {
106
                op,
107
                expr: Box::new(new_expr),
108
            })
109
        }
110
        StateFrm::Id(_, _)
111
        | StateFrm::True
112
        | StateFrm::False
113
        | StateFrm::Delay(_)
114
        | StateFrm::Yaled(_)
115
14
        | StateFrm::DataValExpr(_) => Ok(formula),
116
    }
117
42
}
118

            
119
/// See [`visit`].
120
38
fn visit_statefrm_rec(
121
38
    formula: &StateFrm,
122
38
    function: &mut impl FnMut(&StateFrm) -> Result<(), MercError>,
123
38
) -> Result<(), MercError> {
124
38
    function(formula)?;
125

            
126
38
    match formula {
127
7
        StateFrm::Binary { lhs, rhs, .. } => {
128
7
            visit_statefrm_rec(lhs, function)?;
129
7
            visit_statefrm_rec(rhs, function)?;
130
        }
131
9
        StateFrm::FixedPoint { body, .. } => {
132
9
            visit_statefrm_rec(body, function)?;
133
        }
134
        StateFrm::Bound { body, .. } => {
135
            visit_statefrm_rec(body, function)?;
136
        }
137
11
        StateFrm::Modality { expr, .. } => {
138
11
            visit_statefrm_rec(expr, function)?;
139
        }
140
        StateFrm::Quantifier { body, .. } => {
141
            visit_statefrm_rec(body, function)?;
142
        }
143
        StateFrm::DataValExprRightMult(expr, _data_val) => {
144
            visit_statefrm_rec(expr, function)?;
145
        }
146
        StateFrm::DataValExprLeftMult(_data_val, expr) => {
147
            visit_statefrm_rec(expr, function)?;
148
        }
149
        StateFrm::Unary { expr, .. } => {
150
            visit_statefrm_rec(expr, function)?;
151
        }
152
        StateFrm::Id(_, _)
153
        | StateFrm::True
154
        | StateFrm::False
155
        | StateFrm::Delay(_)
156
        | StateFrm::Yaled(_)
157
11
        | StateFrm::DataValExpr(_) => {}
158
    }
159

            
160
38
    Ok(())
161
38
}
162

            
163
#[cfg(test)]
mod tests {
    use crate::UntypedStateFrmSpec;

    use super::*;

    /// `visit_statefrm` reaches every identifier, in pre-order, left to right.
    #[test]
    fn test_visit_state_frm_variables() {
        let input = UntypedStateFrmSpec::parse("mu X. [a]X && mu X. X && Y").unwrap();

        let mut variables = vec![];
        visit_statefrm(&input.formula, |frm| {
            if let StateFrm::Id(name, _) = frm {
                variables.push(name.clone());
            }

            Ok(())
        })
        .unwrap();

        assert_eq!(variables, vec!["X", "X", "Y"]);
    }

    /// `apply_statefrm` that never substitutes still offers every subformula.
    #[test]
    fn test_apply_state_frm_variables() {
        let input = UntypedStateFrmSpec::parse("mu X. [a]X && mu X. X && Y").unwrap();

        let mut variables = vec![];
        apply_statefrm(input.formula, |frm| {
            if let StateFrm::Id(name, _) = frm {
                variables.push(name.clone());
            }

            Ok(None)
        })
        .unwrap();

        assert_eq!(variables, vec!["X", "X", "Y"]);
    }

    /// `apply_statefrm` replaces matching subformulas and keeps the rest.
    #[test]
    fn test_apply_state_frm_substitution() {
        let input = UntypedStateFrmSpec::parse("mu X. [a]X && mu X. X && Y").unwrap();

        // Replace every occurrence of the identifier `Y` by `true`.
        let result = apply_statefrm(input.formula, |frm| {
            if matches!(frm, StateFrm::Id(name, _) if *name == "Y") {
                Ok(Some(StateFrm::True))
            } else {
                Ok(None)
            }
        })
        .unwrap();

        let mut variables = vec![];
        visit_statefrm(&result, |frm| {
            if let StateFrm::Id(name, _) = frm {
                variables.push(name.clone());
            }

            Ok(())
        })
        .unwrap();

        assert_eq!(variables, vec!["X", "X"]);
    }
}