1
#![forbid(unsafe_code)]
2

            
3
use std::fmt;
4

            
5
use itertools::Itertools;
6
use merc_data::DataVariable;
7
use merc_data::DataVariableRef;
8
use merc_data::is_data_variable;
9

            
10
use crate::Rule;
11
use crate::utilities::DataPosition;
12
use crate::utilities::DataPositionIndexed;
13
use crate::utilities::DataPositionIterator;
14

            
15
/// An equivalence class is a variable with (multiple) positions. This is
16
/// necessary for non-linear patterns.
17
///
18
/// # Example
19
/// Suppose we have a pattern f(x,x), where x is a variable. Then it will have
20
/// one equivalence class storing "x" and the positions 1 and 2. The function
21
/// equivalences_hold checks whether the term has the same term on those
22
/// positions. For example, it will returns false on the term f(a, b) and true
23
/// on the term f(a, a).
24
#[derive(Hash, Clone, Eq, PartialEq, Ord, PartialOrd, Debug)]
25
pub struct EquivalenceClass {
26
    pub variable: DataVariable,
27
    pub positions: Vec<DataPosition>,
28
}
29

            
30
/// Derives the positions in a pattern with same variable (for non-linear patters)
31
15555
pub fn derive_equivalence_classes(rule: &Rule) -> Vec<EquivalenceClass> {
32
15555
    let mut var_equivalences = vec![];
33

            
34
32872
    for (term, pos) in DataPositionIterator::new(rule.lhs.copy()) {
35
32872
        if is_data_variable(&term) {
36
14423
            // Register the position of the variable
37
14423
            update_equivalences(&mut var_equivalences, &DataVariableRef::from(term), pos);
38
18449
        }
39
    }
40

            
41
    // Discard variables that only occur once
42
15555
    var_equivalences.retain(|x| x.positions.len() > 1);
43
15555
    var_equivalences
44
15555
}
45

            
46
/// Checks if the equivalence classes hold for the given term.
47
1131794
pub fn check_equivalence_classes<'a, T, P>(term: &'a P, eqs: &[EquivalenceClass]) -> bool
48
1131794
where
49
1131794
    P: DataPositionIndexed<'a, Target<'a> = T> + 'a,
50
1131794
    T: PartialEq,
51
{
52
1131794
    eqs.iter().all(|ec| {
53
1
        debug_assert!(
54
1
            ec.positions.len() >= 2,
55
            "An equivalence class must contain at least two positions"
56
        );
57

            
58
        // The term at the first position must be equivalent to all other positions.
59
1
        let mut iter_pos = ec.positions.iter();
60
1
        let first = iter_pos.next().unwrap();
61
1
        iter_pos.all(|other_pos| term.get_data_position(first) == term.get_data_position(other_pos))
62
1
    })
63
1131794
}
64

            
65
/// Adds the position of a variable to the equivalence classes
66
14423
fn update_equivalences(ve: &mut Vec<EquivalenceClass>, variable: &DataVariableRef<'_>, pos: DataPosition) {
67
    // Check if the variable was seen before
68
14423
    if ve.iter().any(|ec| ec.variable.copy() == *variable) {
69
1
        for ec in ve.iter_mut() {
70
            // Find the equivalence class and add the position
71
1
            if ec.variable.copy() == *variable && !ec.positions.iter().any(|x| x == &pos) {
72
1
                ec.positions.push(pos);
73
1
                break;
74
            }
75
        }
76
14422
    } else {
77
14422
        // If the variable was not found at another position add a new equivalence class
78
14422
        ve.push(EquivalenceClass {
79
14422
            variable: variable.protect(),
80
14422
            positions: vec![pos],
81
14422
        });
82
14422
    }
83
14423
}
84

            
85
impl fmt::Display for EquivalenceClass {
86
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87
        write!(f, "{}{{ {} }}", self.variable, self.positions.iter().format(", "))
88
    }
89
}
90

            
91
#[cfg(test)]
mod tests {
    use merc_data::DataExpression;
    use merc_data::DataVariable;

    use crate::test_utility::create_rewrite_rule;

    use super::*;

    #[test]
    fn test_derive_equivalence_classes() {
        let eq: Vec<EquivalenceClass> =
            derive_equivalence_classes(&create_rewrite_rule("f(x, h(x))", "result", &["x"]).unwrap());

        assert_eq!(
            eq,
            vec![EquivalenceClass {
                variable: DataVariable::new("x").into(),
                positions: vec![DataPosition::new(&[1]), DataPosition::new(&[2, 1])]
            },],
            "The derived equivalence classes are not as expected"
        );

        // Positions 1 and 2.1 both hold `a(b)`, so the classes must hold.
        let expression = DataExpression::from_string("f(a(b), h(a(b)))").unwrap();

        assert!(
            check_equivalence_classes(&expression, &eq),
            "The equivalence classes are not checked correctly, equivalences: {:?} and term {}",
            &eq,
            &expression
        );

        // Position 1 holds `a(b)` but position 2.1 holds `b`, so the classes must not hold.
        let mismatch = DataExpression::from_string("f(a(b), h(b))").unwrap();

        assert!(
            !check_equivalence_classes(&mismatch, &eq),
            "The equivalence classes should not hold, equivalences: {:?} and term {}",
            &eq,
            &mismatch
        );
    }

    #[test]
    fn test_linear_pattern_has_no_equivalence_classes() {
        let eq: Vec<EquivalenceClass> =
            derive_equivalence_classes(&create_rewrite_rule("f(x, y)", "result", &["x", "y"]).unwrap());

        assert!(
            eq.is_empty(),
            "A linear pattern should not have equivalence classes, got: {:?}",
            &eq
        );
    }
}