1
#![forbid(unsafe_code)]
2

            
3
use std::collections::VecDeque;
4
use std::fmt;
5

            
6
use merc_data::DataExpression;
7
use merc_data::DataExpressionRef;
8

            
9
use super::ExplicitPosition;
10

            
11
/// A newtype wrapper around [ExplicitPosition] specifically for data expressions
12
/// This provides type safety and clarity when dealing with positions in data expressions
13
#[repr(transparent)]
14
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
15
pub struct DataPosition(ExplicitPosition);
16

            
17
impl DataPosition {
18
    /// Creates a new empty position
19
76825
    pub fn empty() -> Self {
20
76825
        Self(ExplicitPosition::empty())
21
76825
    }
22

            
23
    /// Creates a new position from a slice of indices
24
5269475
    pub fn new(indices: &[usize]) -> Self {
25
5269475
        Self(ExplicitPosition::new(indices))
26
5269475
    }
27

            
28
    /// Returns the underlying indices
29
41543318
    pub fn indices(&self) -> &[usize] {
30
41543318
        self.0.indices()
31
41543318
    }
32

            
33
    /// Returns the length of the position indices
34
5642081
    pub fn len(&self) -> usize {
35
5642081
        self.0.len()
36
5642081
    }
37

            
38
    /// Returns true if the position is empty
39
2653196
    pub fn is_empty(&self) -> bool {
40
2653196
        self.0.is_empty()
41
2653196
    }
42

            
43
    /// Adds the index to the position
44
188407
    pub fn push(&mut self, index: usize) {
45
188407
        self.0.push(index);
46
188407
    }
47
}
48

            
49
impl fmt::Display for DataPosition {
50
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51
        write!(f, "{}", self.0)
52
    }
53
}
54

            
55
impl fmt::Debug for DataPosition {
56
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57
        write!(f, "{}", self.0)
58
    }
59
}
60

            
61
/// A specialisation of the [super::PositionIndexed] trait for [DataExpression]. This is used to keep the indexing consistent.
62
pub trait DataPositionIndexed<'b> {
63
    type Target<'a>
64
    where
65
        Self: 'a,
66
        Self: 'b;
67

            
68
    /// Returns the Target at the given position.
69
    fn get_data_position(&'b self, position: &DataPosition) -> Self::Target<'b>;
70
}
71

            
72
impl<'b> DataPositionIndexed<'b> for DataExpression {
73
    type Target<'a>
74
        = DataExpressionRef<'a>
75
    where
76
        Self: 'a;
77

            
78
4566
    fn get_data_position(&'b self, position: &DataPosition) -> Self::Target<'b> {
79
4566
        let mut result = self.copy();
80

            
81
4569
        for index in position.indices() {
82
1242
            result = result.data_arg(*index - 1); // Note that positions are 1 indexed.
83
1242
        }
84

            
85
4566
        result
86
4566
    }
87
}
88

            
89
impl<'b> DataPositionIndexed<'b> for DataExpressionRef<'b> {
90
    type Target<'a>
91
        = DataExpressionRef<'a>
92
    where
93
        Self: 'a;
94

            
95
24284004
    fn get_data_position(&'b self, position: &DataPosition) -> Self::Target<'b> {
96
24284004
        let mut result = self.copy();
97

            
98
24284006
        for index in position.indices() {
99
16368345
            result = result.data_arg(*index - 1); // Note that positions are 1 indexed.
100
16368345
        }
101

            
102
24284004
        result
103
24284004
    }
104
}
105

            
106
/// An iterator over all (term, position) pairs of the given [DataExpression].
107
pub struct DataPositionIterator<'a> {
108
    queue: VecDeque<(DataExpressionRef<'a>, DataPosition)>,
109
}
110

            
111
impl<'a> DataPositionIterator<'a> {
112
72846
    pub fn new(t: DataExpressionRef<'a>) -> Self {
113
72846
        Self {
114
72846
            queue: VecDeque::from([(t, DataPosition::empty())]),
115
72846
        }
116
72846
    }
117
}
118

            
119
impl<'a> Iterator for DataPositionIterator<'a> {
120
    type Item = (DataExpressionRef<'a>, DataPosition);
121

            
122
262712
    fn next(&mut self) -> Option<Self::Item> {
123
262712
        if self.queue.is_empty() {
124
72846
            None
125
        } else {
126
            // Get a subterm to inspect
127
189866
            let (term, pos) = self.queue.pop_front().unwrap();
128

            
129
            // Put subterms in the queue
130
189866
            for (i, argument) in term.data_arguments().enumerate() {
131
117020
                let mut new_position = pos.clone();
132
117020
                new_position.push(i + 1);
133
117020
                self.queue.push_back((argument, new_position));
134
117020
            }
135

            
136
189866
            Some((term, pos))
137
        }
138
262712
    }
139
}
140

            
141
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_data_position() {
        let t = DataExpression::from_string("f(g(a),b)").unwrap();
        let expected = DataExpression::from_string("a").unwrap();

        assert_eq!(t.get_data_position(&DataPosition::new(&[1, 1])), expected.copy());
    }

    #[test]
    fn test_get_data_position_root() {
        let t = DataExpression::from_string("f(g(a),b)").unwrap();

        // The empty position addresses the term itself.
        assert_eq!(t.get_data_position(&DataPosition::empty()), t.copy());
    }

    #[test]
    fn test_data_position_iterator() {
        let t = DataExpression::from_string("f(g(a),b)").unwrap();

        for (term, pos) in DataPositionIterator::new(t.copy()) {
            assert_eq!(
                t.get_data_position(&pos),
                term,
                "The resulting (subterm, position) pair doesn't match the get_data_position implementation"
            );
        }

        assert_eq!(
            DataPositionIterator::new(t.copy()).count(),
            4,
            "The number of subterms doesn't match the expected value"
        );
    }

    #[test]
    fn test_data_position_iterator_breadth_first() {
        let t = DataExpression::from_string("f(g(a),b)").unwrap();

        // The iterator visits the root first, then all direct arguments, then deeper levels.
        let positions: Vec<Vec<usize>> = DataPositionIterator::new(t.copy())
            .map(|(_, pos)| pos.indices().to_vec())
            .collect();

        assert_eq!(positions, vec![vec![], vec![1], vec![2], vec![1, 1]]);
    }
}