1
use std::fmt;
2

            
3
use ahash::HashMap;
4
use ahash::HashMapExt;
5
use merc_aterm::ATermRef;
6
use merc_aterm::Markable;
7
use merc_aterm::Protected;
8
use merc_aterm::SymbolRef;
9
use merc_aterm::Term;
10
use merc_aterm::Transmutable;
11
use merc_aterm::storage::Marker;
12
use merc_data::DataApplication;
13
use merc_data::DataExpression;
14
use merc_data::DataExpressionRef;
15
use merc_data::DataFunctionSymbolRef;
16
use merc_data::DataVariable;
17
use merc_data::DataVariableRef;
18
use merc_data::is_data_machine_number;
19
use merc_data::is_data_variable;
20
use merc_utilities::debug_trace;
21

            
22
use crate::Rule;
23
use crate::utilities::InnermostStack;
24

            
25
use super::DataPosition;
26
use super::DataPositionIterator;
27

            
28
/// A stack used to represent a term with free variables that can be constructed
29
/// efficiently.
30
///
31
/// It stores as much as possible in the term pool. Due to variables it cannot
32
/// be fully compressed. For variables it stores the position in the lhs of a
33
/// rewrite rule where the concrete term can be found that will replace the
34
/// variable.
35
///
36
#[derive(Hash, Debug, PartialEq, Eq, PartialOrd, Ord)]
37
pub struct TermStack {
38
    /// The innermost rewrite stack for the right hand side and the positions that must be added to the stack.
39
    pub innermost_stack: Protected<Vec<Config<'static>>>,
40
    /// The variables of the left-hand side that must be placed at certain places in the stack.
41
    pub variables: Vec<(DataPosition, usize)>,
42
    /// The number of elements that must be reserved on the innermost stack.
43
    pub stack_size: usize,
44
}
45

            
46
#[derive(Hash, Eq, PartialEq, Ord, PartialOrd, Debug)]
47
pub enum Config<'a> {
48
    /// Rewrite the top of the stack and put result at the given index.
49
    Rewrite(usize),
50
    /// Constructs function symbol with given arity at the given index.
51
    Construct(DataFunctionSymbolRef<'a>, usize, usize),
52
    /// A concrete term to be placed at the current position in the stack.
53
    Term(DataExpressionRef<'a>, usize),
54
    /// Yields the given index as returned term.
55
    Return(),
56
}
57

            
58
impl Markable for Config<'_> {
59
    fn mark(&self, marker: &mut Marker<'_>) {
60
        if let Config::Construct(t, _, _) = self {
61
            t.mark(marker);
62
        }
63
    }
64

            
65
206744252
    fn contains_term(&self, term: &ATermRef<'_>) -> bool {
66
206744252
        if let Config::Construct(t, _, _) = self {
67
191973150
            t.contains_term(term)
68
        } else {
69
14771102
            false
70
        }
71
206744252
    }
72

            
73
    fn contains_symbol(&self, symbol: &SymbolRef<'_>) -> bool {
74
        if let Config::Construct(t, _, _) = self {
75
            t.contains_symbol(symbol)
76
        } else {
77
            false
78
        }
79
    }
80

            
81
    fn len(&self) -> usize {
82
        if let Config::Construct(_, _, _) = self { 1 } else { 0 }
83
    }
84
}
85

            
86
/// Allows a stored `Config<'static>` to be viewed as a `Config<'a>` with a
/// caller-chosen lifetime.
impl Transmutable for Config<'static> {
    type Target<'a> = Config<'a>;

    /// Reinterprets a shared reference to this configuration as a reference
    /// to a configuration with lifetime `'a`.
    fn transmute_lifetime<'a>(&self) -> &'a Self::Target<'a> {
        // SAFETY: source and target differ only in lifetime parameters, so the
        // representation of the reference is unchanged.
        // NOTE(review): the returned lifetime is detached from `self`; its
        // validity presumably has to be guaranteed by the caller — confirm
        // against the `Transmutable` contract.
        unsafe { std::mem::transmute::<&Config<'static>, &'a Config<'a>>(self) }
    }

    /// Reinterprets an exclusive reference to this configuration as a
    /// reference to a configuration with lifetime `'a`.
    fn transmute_lifetime_mut<'a>(&mut self) -> &'a mut Self::Target<'a> {
        // SAFETY: source and target differ only in lifetime parameters, so the
        // representation of the reference is unchanged.
        // NOTE(review): the returned lifetime is detached from `self`; its
        // validity presumably has to be guaranteed by the caller — confirm
        // against the `Transmutable` contract.
        unsafe { std::mem::transmute::<&mut Config<'static>, &'a mut Config<'a>>(self) }
    }
}
97

            
98
impl fmt::Display for Config<'_> {
99
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100
        match self {
101
            Config::Rewrite(result) => write!(f, "Rewrite({result})"),
102
            Config::Construct(symbol, arity, result) => {
103
                write!(f, "Construct({symbol}, {arity}, {result})")
104
            }
105
            Config::Term(term, result) => {
106
                write!(f, "Term({term}, {result})")
107
            }
108
            Config::Return() => write!(f, "Return()"),
109
        }
110
    }
111
}
112

            
113
impl TermStack {
114
    /// Construct a new right-hand stack for a given equation/rewrite rule.
115
1419
    pub fn new(rule: &Rule) -> TermStack {
116
1419
        Self::from_term(&rule.rhs.copy(), &create_var_map(&rule.lhs))
117
1419
    }
118

            
119
    /// Construct a term stack from a data expression where variables are taken from a specific position of the left hand side.
120
26178
    pub fn from_term(term: &DataExpressionRef, var_map: &HashMap<DataVariable, DataPosition>) -> TermStack {
121
        // Compute the extra information for the InnermostRewriter.
122
26178
        let mut innermost_stack: Protected<Vec<Config>> = Protected::new(vec![]);
123
26178
        let mut variables = vec![];
124
26178
        let mut stack_size = 0;
125

            
126
91242
        for (term, _position) in DataPositionIterator::new(term.copy()) {
127
91242
            if is_data_variable(&term) {
128
21188
                let variable: DataVariableRef<'_> = term.into();
129
21188
                variables.push((
130
21188
                    var_map
131
21188
                        .get(&variable.protect())
132
21188
                        .expect("All variables in the right hand side must occur in the left hand side")
133
21188
                        .clone(),
134
21188
                    stack_size,
135
21188
                ));
136
21188
                stack_size += 1;
137
70054
            } else if is_data_machine_number(&term) {
138
                // Skip SortId(@NoValue) and OpId
139
70054
            } else {
140
70054
                let arity = term.data_arguments().len();
141
70054
                let mut write = innermost_stack.write();
142
70054
                write.push(Config::Construct(term.data_function_symbol(), arity, stack_size));
143
70054
                stack_size += 1;
144
70054
            }
145
        }
146

            
147
26178
        TermStack {
148
26178
            innermost_stack,
149
26178
            stack_size,
150
26178
            variables,
151
26178
        }
152
26178
    }
153

            
154
    // See [evaluate_with]
155
1410807
    pub fn evaluate<'a, 'b>(&self, term: &'b impl Term<'a, 'b>) -> DataExpression {
156
1410807
        let mut builder = TermStackBuilder::new();
157
1410807
        self.evaluate_with(term, &mut builder)
158
1410807
    }
159

            
160
    /// Evaluate the rhs stack for the given term and returns the result.
161
1621389
    pub fn evaluate_with<'a, 'b>(&self, term: &'b impl Term<'a, 'b>, builder: &mut TermStackBuilder) -> DataExpression {
162
1621389
        let stack = &mut builder.stack;
163
1621389
        {
164
1621389
            let mut write = stack.terms.write();
165
1621389
            write.clear();
166
1621389
            write.push(None);
167
1621389
        }
168

            
169
1621389
        InnermostStack::integrate(
170
1621389
            &mut stack.configs.write(),
171
1621389
            &mut stack.terms.write(),
172
1621389
            self,
173
1621389
            &DataExpressionRef::from(term.copy()),
174
            0,
175
        );
176

            
177
        loop {
178
3400287
            debug_trace!("{}", stack);
179

            
180
3400287
            let mut write_configs = stack.configs.write();
181
3400287
            if let Some(config) = write_configs.pop() {
182
1778898
                match config {
183
1778898
                    Config::Construct(symbol, arity, index) => {
184
                        // Take the last arity arguments.
185
1778898
                        let mut write_terms = stack.terms.write();
186
1778898
                        let length = write_terms.len();
187

            
188
1778898
                        let arguments = &write_terms[length - arity..];
189

            
190
1778898
                        let term: DataExpression = if arguments.is_empty() {
191
377500
                            symbol.protect().into()
192
                        } else {
193
1401398
                            DataApplication::with_iter(&symbol.copy(), arguments.len(), arguments.iter().flatten())
194
1401398
                                .into()
195
                        };
196

            
197
                        // Add the term on the stack.
198
1778898
                        write_terms.drain(length - arity..);
199
1778898
                        let t = write_terms.protect(&term);
200
1778898
                        write_terms[index] = Some(t.into());
201
                    }
202
                    Config::Term(term, index) => {
203
                        let mut write_terms = stack.terms.write();
204
                        let t = write_terms.protect(&term);
205
                        write_terms[index] = Some(t.into());
206
                    }
207
                    Config::Rewrite(_) => {
208
                        unreachable!("This case should not happen");
209
                    }
210
                    Config::Return() => {
211
                        unreachable!("This case should not happen");
212
                    }
213
                }
214
            } else {
215
1621389
                break;
216
            }
217
        }
218

            
219
1621389
        debug_assert!(
220
1621389
            stack.terms.read().len() == 1,
221
            "Expect exactly one term on the result stack"
222
        );
223

            
224
1621389
        let mut write_terms = stack.terms.write();
225

            
226
1621389
        write_terms
227
1621389
            .pop()
228
1621389
            .expect("The result should be the last element on the stack")
229
1621389
            .expect("The result should be Some")
230
1621389
            .protect()
231
1621389
    }
232

            
233
    /// Used to check if a subterm is duplicated, for example "times(s(x), y) =
234
    /// plus(y, times(x,y))" is duplicating.
235
14138
    pub(crate) fn contains_duplicate_var_references(&self) -> bool {
236
14138
        let mut variables: Vec<&DataPosition> = self.variables.iter().map(|(v, _)| v).collect();
237

            
238
        // Check if there are duplicates.
239
14138
        variables.sort_unstable();
240
14138
        let len = variables.len();
241
14138
        variables.dedup();
242

            
243
14138
        len != variables.len()
244
14138
    }
245
}
246

            
247
impl Clone for TermStack {
248
    fn clone(&self) -> Self {
249
        // TODO: It would make sense if Protected could implement Clone.
250
        let mut innermost_stack: Protected<Vec<Config>> = Protected::new(vec![]);
251

            
252
        let read = self.innermost_stack.read();
253
        let mut write = innermost_stack.write();
254
        for t in read.iter() {
255
            match t {
256
                Config::Rewrite(x) => write.push(Config::Rewrite(*x)),
257
                Config::Construct(f, x, y) => {
258
                    write.push(Config::Construct(f.copy(), *x, *y));
259
                }
260
                Config::Term(t, y) => {
261
                    write.push(Config::Term(t.copy(), *y));
262
                }
263
                Config::Return() => write.push(Config::Return()),
264
            }
265
        }
266
        drop(write);
267

            
268
        Self {
269
            variables: self.variables.clone(),
270
            stack_size: self.stack_size,
271
            innermost_stack,
272
        }
273
    }
274
}
275

            
276
pub struct TermStackBuilder {
277
    stack: InnermostStack,
278
}
279

            
280
impl TermStackBuilder {
281
1410833
    pub fn new() -> Self {
282
1410833
        Self {
283
1410833
            stack: InnermostStack::default(),
284
1410833
        }
285
1410833
    }
286
}
287

            
288
impl Default for TermStackBuilder {
289
    fn default() -> Self {
290
        Self::new()
291
    }
292
}
293

            
294
/// Create a mapping of variables to their position in the given term
295
31111
pub fn create_var_map(t: &DataExpression) -> HashMap<DataVariable, DataPosition> {
296
31111
    let mut result = HashMap::new();
297

            
298
65744
    for (term, position) in DataPositionIterator::new(t.copy()) {
299
65744
        if is_data_variable(&term) {
300
28846
            result.insert(term.protect().into(), position.clone());
301
36898
        }
302
    }
303

            
304
31111
    result
305
31111
}
306

            
307
#[cfg(test)]
mod tests {
    use super::*;

    use ahash::AHashSet;
    use merc_data::DataFunctionSymbol;
    use merc_utilities::test_logger;

    use crate::test_utility::create_rewrite_rule;

    use test_log::test;

    /// The stack of "times(s(N), fact(N))" holds one construction per function
    /// symbol and evaluates to the instantiated right-hand side.
    #[test]
    fn test_rhs_stack() {
        let rule = create_rewrite_rule("fact(s(N))", "times(s(N), fact(N))", &["N"]).unwrap();
        let rhs_stack = TermStack::new(&rule);
        let mut expected = Protected::new(vec![]);

        let times = DataFunctionSymbol::new("times");
        let successor = DataFunctionSymbol::new("s");
        let fact = DataFunctionSymbol::new("fact");

        let mut configs = expected.write();
        configs.push(Config::Construct(times.copy(), 2, 0));
        configs.push(Config::Construct(successor.copy(), 1, 1));
        configs.push(Config::Construct(fact.copy(), 1, 2));
        drop(configs);

        // The configurations must match the expected constructions.
        assert_eq!(
            rhs_stack.innermost_stack, expected,
            "The resulting config stack is not as expected"
        );

        assert_eq!(rhs_stack.stack_size, 5, "The stack size does not match");

        // Evaluating on a matching left-hand side instantiates N with a.
        let lhs = DataExpression::from_string("fact(s(a))").unwrap();
        let rhs = DataExpression::from_string("times(s(a), fact(a))").unwrap();

        assert_eq!(
            rhs_stack.evaluate(&lhs),
            rhs,
            "The rhs stack does not evaluate to the expected term"
        );
    }

    /// A right-hand side that is only a variable needs no configurations but
    /// still occupies one stack slot.
    #[test]
    fn test_rhs_stack_variable() {
        let rule = create_rewrite_rule("f(x)", "x", &["x"]).unwrap();
        let rhs = TermStack::new(&rule);

        assert!(
            rhs.innermost_stack.read().is_empty(),
            "The resulting config stack is not as expected"
        );

        assert_eq!(rhs.stack_size, 1, "The stack size does not match");
    }

    /// Evaluation substitutes the subterm found at the mapped position.
    #[test]
    fn test_evaluation() {
        test_logger();

        let variable_names = AHashSet::from([String::from("x")]);
        let rhs = DataExpression::from_string_untyped("f(f(a,a),x)", &variable_names).unwrap();
        let lhs = DataExpression::from_string("g(b)").unwrap();

        // The variable map only contains x@1.
        let mut var_map = HashMap::new();
        var_map.insert(DataVariable::new("x"), DataPosition::new(&[1]));

        let sctt = TermStack::from_term(&rhs.copy(), &var_map);

        let expected = DataExpression::from_string("f(f(a,a),b)").unwrap();

        assert_eq!(sctt.evaluate(&lhs), expected);
    }

    /// Variables occurring in a term end up as keys of the variable map.
    #[test]
    fn test_create_varmap() {
        let variable_names = AHashSet::from([String::from("x")]);
        let t = DataExpression::from_string_untyped("f(x,x)", &variable_names).unwrap();
        let x = DataVariable::new("x");

        let var_map = create_var_map(&t);
        assert!(var_map.contains_key(&x));
    }

    /// Two occurrences of the same variable are reported as duplicating.
    #[test]
    fn test_is_duplicating() {
        let variable_names = AHashSet::from([String::from("x")]);
        let rhs = DataExpression::from_string_untyped("f(x,x)", &variable_names).unwrap();

        // The variable map only contains x@1.
        let mut var_map = HashMap::new();
        var_map.insert(DataVariable::new("x"), DataPosition::new(&[1]));

        let sctt = TermStack::from_term(&rhs.copy(), &var_map);
        assert!(sctt.contains_duplicate_var_references(), "This sctt is duplicating");
    }
}