1
use std::fmt;
2

            
3
use ahash::HashMap;
4
use ahash::HashMapExt;
5
use merc_aterm::ATermRef;
6
use merc_aterm::Markable;
7
use merc_aterm::Protected;
8
use merc_aterm::SymbolRef;
9
use merc_aterm::Term;
10
use merc_aterm::Transmutable;
11
use merc_aterm::storage::Marker;
12
use merc_data::DataApplication;
13
use merc_data::DataExpression;
14
use merc_data::DataExpressionRef;
15
use merc_data::DataFunctionSymbolRef;
16
use merc_data::DataVariable;
17
use merc_data::DataVariableRef;
18
use merc_data::is_data_machine_number;
19
use merc_data::is_data_variable;
20
use merc_utilities::debug_trace;
21

            
22
use crate::Rule;
23
use crate::utilities::InnermostStack;
24

            
25
use super::DataPosition;
26
use super::DataPositionIterator;
27

            
28
/// A stack used to represent a term with free variables that can be constructed
29
/// efficiently.
30
///
31
/// It stores as much as possible in the term pool. Due to variables it cannot
32
/// be fully compressed. For variables it stores the position in the lhs of a
33
/// rewrite rule where the concrete term can be found that will replace the
34
/// variable.
35
///
36
#[derive(Hash, Debug, PartialEq, Eq, PartialOrd, Ord)]
37
pub struct TermStack {
38
    /// The innermost rewrite stack for the right hand side and the positions that must be added to the stack.
39
    pub innermost_stack: Protected<Vec<Config<'static>>>,
40
    /// The variables of the left-hand side that must be placed at certain places in the stack.
41
    pub variables: Vec<(DataPosition, usize)>,
42
    /// The number of elements that must be reserved on the innermost stack.
43
    pub stack_size: usize,
44
}
45

            
46
#[derive(Hash, Eq, PartialEq, Ord, PartialOrd, Debug)]
47
pub enum Config<'a> {
48
    /// Rewrite the top of the stack and put result at the given index.
49
    Rewrite(usize),
50
    /// Constructs function symbol with given arity at the given index.
51
    Construct(DataFunctionSymbolRef<'a>, usize, usize),
52
    /// A concrete term to be placed at the current position in the stack.
53
    Term(DataExpressionRef<'a>, usize),
54
    /// Yields the given index as returned term.
55
    Return(),
56
}
57

            
58
impl Markable for Config<'_> {
59
3336336
    fn mark(&self, marker: &mut Marker<'_>) {
60
3336336
        if let Config::Construct(t, _, _) = self {
61
3336136
            t.mark(marker);
62
3336136
        }
63
3336336
    }
64

            
65
206854389
    fn contains_term(&self, term: &ATermRef<'_>) -> bool {
66
206854389
        if let Config::Construct(t, _, _) = self {
67
192072017
            t.contains_term(term)
68
        } else {
69
14782372
            false
70
        }
71
206854389
    }
72

            
73
    fn contains_symbol(&self, symbol: &SymbolRef<'_>) -> bool {
74
        if let Config::Construct(t, _, _) = self {
75
            t.contains_symbol(symbol)
76
        } else {
77
            false
78
        }
79
    }
80

            
81
    fn len(&self) -> usize {
82
        if let Config::Construct(_, _, _) = self { 1 } else { 0 }
83
    }
84
}
85

            
86
impl Transmutable for Config<'static> {
    type Target<'a> = Config<'a>;

    /// Views this `'static` configuration as a `Config<'a>` reference.
    fn transmute_lifetime<'a>(&'_ self) -> &'a Self::Target<'a> {
        let ptr = (self as *const Self).cast::<Config<'a>>();
        // SAFETY: `Config<'static>` and `Config<'a>` are the same type up to a lifetime
        // parameter, so the layout is identical. `ptr` comes from a valid reference, so it
        // is non-null and aligned.
        // NOTE(review): the returned lifetime `'a` is chosen by the caller and is not tied
        // to the borrow of `self`. Callers must not use it after `self` is moved or dropped.
        unsafe { &*ptr }
    }

    /// Views this `'static` configuration as a mutable `Config<'a>` reference.
    fn transmute_lifetime_mut<'a>(&'_ mut self) -> &'a mut Self::Target<'a> {
        let ptr = (self as *mut Self).cast::<Config<'a>>();
        // SAFETY: identical layout as above. `ptr` comes from a unique reference, so no
        // other alias exists at this point.
        // NOTE(review): the same caller-chosen-lifetime caveat as `transmute_lifetime` applies.
        unsafe { &mut *ptr }
    }
}
97

            
98
impl fmt::Display for Config<'_> {
99
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100
        match self {
101
            Config::Rewrite(result) => write!(f, "Rewrite({result})"),
102
            Config::Construct(symbol, arity, result) => {
103
                write!(f, "Construct({symbol}, {arity}, {result})")
104
            }
105
            Config::Term(term, result) => {
106
                write!(f, "Term({term}, {result})")
107
            }
108
            Config::Return() => write!(f, "Return()"),
109
        }
110
    }
111
}
112

            
113
impl TermStack {
114
    /// Construct a new right-hand stack for a given equation/rewrite rule.
115
1419
    pub fn new(rule: &Rule) -> TermStack {
116
1419
        Self::from_term(&rule.rhs.copy(), &create_var_map(&rule.lhs))
117
1419
    }
118

            
119
    /// Construct a term stack from a data expression where variables are taken from a specific position of the left hand side.
120
26178
    pub fn from_term(term: &DataExpressionRef, var_map: &HashMap<DataVariable, DataPosition>) -> TermStack {
121
        // Compute the extra information for the InnermostRewriter.
122
26178
        let mut innermost_stack: Protected<Vec<Config>> = Protected::new(vec![]);
123
26178
        let mut variables = vec![];
124
26178
        let mut stack_size = 0;
125

            
126
91242
        for (term, _position) in DataPositionIterator::new(term.copy()) {
127
91242
            if is_data_variable(&term) {
128
21188
                let variable: DataVariableRef<'_> = term.into();
129
21188
                variables.push((
130
21188
                    var_map
131
21188
                        .get(&variable.protect())
132
21188
                        .expect("All variables in the right hand side must occur in the left hand side")
133
21188
                        .clone(),
134
21188
                    stack_size,
135
21188
                ));
136
21188
                stack_size += 1;
137
70054
            } else if is_data_machine_number(&term) {
138
                // Skip SortId(@NoValue) and OpId
139
70054
            } else {
140
70054
                let arity = term.data_arguments().len();
141
70054
                let mut write = innermost_stack.write();
142
70054
                write.push(Config::Construct(term.data_function_symbol(), arity, stack_size));
143
70054
                stack_size += 1;
144
70054
            }
145
        }
146

            
147
26178
        TermStack {
148
26178
            innermost_stack,
149
26178
            stack_size,
150
26178
            variables,
151
26178
        }
152
26178
    }
153

            
154
    // See [evaluate_with]
155
1410030
    pub fn evaluate<'a, 'b, T: Term<'a, 'b>>(&self, term: &'b T) -> DataExpression {
156
1410030
        let mut builder = TermStackBuilder::new();
157
1410030
        self.evaluate_with(term, &mut builder)
158
1410030
    }
159

            
160
    /// Evaluate the rhs stack for the given term and returns the result.
161
1620730
    pub fn evaluate_with<'a, 'b, T: Term<'a, 'b>>(
162
1620730
        &self,
163
1620730
        term: &'b T,
164
1620730
        builder: &mut TermStackBuilder,
165
1620730
    ) -> DataExpression {
166
1620730
        let stack = &mut builder.stack;
167
1620730
        {
168
1620730
            let mut write = stack.terms.write();
169
1620730
            write.clear();
170
1620730
            write.push(None);
171
1620730
        }
172

            
173
1620730
        InnermostStack::integrate(
174
1620730
            &mut stack.configs.write(),
175
1620730
            &mut stack.terms.write(),
176
1620730
            self,
177
1620730
            &DataExpressionRef::from(term.copy()),
178
            0,
179
        );
180

            
181
        loop {
182
3399433
            debug_trace!("{}", stack);
183

            
184
3399433
            let mut write_configs = stack.configs.write();
185
3399433
            if let Some(config) = write_configs.pop() {
186
1778703
                match config {
187
1778703
                    Config::Construct(symbol, arity, index) => {
188
                        // Take the last arity arguments.
189
1778703
                        let mut write_terms = stack.terms.write();
190
1778703
                        let length = write_terms.len();
191

            
192
1778703
                        let arguments = &write_terms[length - arity..];
193

            
194
1778703
                        let term: DataExpression = if arguments.is_empty() {
195
377467
                            symbol.protect().into()
196
                        } else {
197
1401236
                            DataApplication::with_iter(&symbol.copy(), arguments.len(), arguments.iter().flatten())
198
1401236
                                .into()
199
                        };
200

            
201
                        // Add the term on the stack.
202
1778703
                        write_terms.drain(length - arity..);
203
1778703
                        let t = write_terms.protect(&term);
204
1778703
                        write_terms[index] = Some(t.into());
205
                    }
206
                    Config::Term(term, index) => {
207
                        let mut write_terms = stack.terms.write();
208
                        let t = write_terms.protect(&term);
209
                        write_terms[index] = Some(t.into());
210
                    }
211
                    Config::Rewrite(_) => {
212
                        unreachable!("This case should not happen");
213
                    }
214
                    Config::Return() => {
215
                        unreachable!("This case should not happen");
216
                    }
217
                }
218
            } else {
219
1620730
                break;
220
            }
221
        }
222

            
223
1620730
        debug_assert!(
224
1620730
            stack.terms.read().len() == 1,
225
            "Expect exactly one term on the result stack"
226
        );
227

            
228
1620730
        let mut write_terms = stack.terms.write();
229

            
230
1620730
        write_terms
231
1620730
            .pop()
232
1620730
            .expect("The result should be the last element on the stack")
233
1620730
            .expect("The result should be Some")
234
1620730
            .protect()
235
1620730
    }
236

            
237
    /// Used to check if a subterm is duplicated, for example "times(s(x), y) =
238
    /// plus(y, times(x,y))" is duplicating.
239
14138
    pub(crate) fn contains_duplicate_var_references(&self) -> bool {
240
14138
        let mut variables: Vec<&DataPosition> = self.variables.iter().map(|(v, _)| v).collect();
241

            
242
        // Check if there are duplicates.
243
14138
        variables.sort_unstable();
244
14138
        let len = variables.len();
245
14138
        variables.dedup();
246

            
247
14138
        len != variables.len()
248
14138
    }
249
}
250

            
251
impl Clone for TermStack {
252
    fn clone(&self) -> Self {
253
        // TODO: It would make sense if Protected could implement Clone.
254
        let mut innermost_stack: Protected<Vec<Config>> = Protected::new(vec![]);
255

            
256
        let read = self.innermost_stack.read();
257
        let mut write = innermost_stack.write();
258
        for t in read.iter() {
259
            match t {
260
                Config::Rewrite(x) => write.push(Config::Rewrite(*x)),
261
                Config::Construct(f, x, y) => {
262
                    write.push(Config::Construct(f.copy(), *x, *y));
263
                }
264
                Config::Term(t, y) => {
265
                    write.push(Config::Term(t.copy(), *y));
266
                }
267
                Config::Return() => write.push(Config::Return()),
268
            }
269
        }
270
        drop(write);
271

            
272
        Self {
273
            variables: self.variables.clone(),
274
            stack_size: self.stack_size,
275
            innermost_stack,
276
        }
277
    }
278
}
279

            
280
pub struct TermStackBuilder {
281
    stack: InnermostStack,
282
}
283

            
284
impl TermStackBuilder {
285
1410056
    pub fn new() -> Self {
286
1410056
        Self {
287
1410056
            stack: InnermostStack::default(),
288
1410056
        }
289
1410056
    }
290
}
291

            
292
impl Default for TermStackBuilder {
293
    fn default() -> Self {
294
        Self::new()
295
    }
296
}
297

            
298
/// Create a mapping of variables to their position in the given term
299
31111
pub fn create_var_map(t: &DataExpression) -> HashMap<DataVariable, DataPosition> {
300
31111
    let mut result = HashMap::new();
301

            
302
65744
    for (term, position) in DataPositionIterator::new(t.copy()) {
303
65744
        if is_data_variable(&term) {
304
28846
            result.insert(term.protect().into(), position.clone());
305
36898
        }
306
    }
307

            
308
31111
    result
309
31111
}
310

            
311
#[cfg(test)]
mod tests {
    use super::*;

    use ahash::AHashSet;
    use merc_data::DataFunctionSymbol;
    use merc_utilities::test_logger;

    use crate::test_utility::create_rewrite_rule;

    use test_log::test;

    /// Parses `text` as an untyped data expression in which `x` is a variable.
    fn parse_with_variable_x(text: &str) -> DataExpression {
        let variables = AHashSet::from([String::from("x")]);
        DataExpression::from_string_untyped(text, &variables).unwrap()
    }

    /// A variable map that binds only `x`, to position 1 of the left-hand side.
    fn map_x_to_position_one() -> HashMap<DataVariable, DataPosition> {
        let mut map = HashMap::new();
        map.insert(DataVariable::new("x"), DataPosition::new(&[1]));
        map
    }

    #[test]
    fn test_rhs_stack() {
        let rule = create_rewrite_rule("fact(s(N))", "times(s(N), fact(N))", &["N"]).unwrap();
        let stack = TermStack::new(&rule);

        let times = DataFunctionSymbol::new("times");
        let succ = DataFunctionSymbol::new("s");
        let fact = DataFunctionSymbol::new("fact");

        let mut expected = Protected::new(vec![]);
        {
            let mut configs = expected.write();
            configs.push(Config::Construct(times.copy(), 2, 0));
            configs.push(Config::Construct(succ.copy(), 1, 1));
            configs.push(Config::Construct(fact.copy(), 1, 2));
        }

        // Check if the resulting construction succeeded.
        assert_eq!(
            stack.innermost_stack, expected,
            "The resulting config stack is not as expected"
        );
        assert_eq!(stack.stack_size, 5, "The stack size does not match");

        // Test the evaluation
        let lhs = DataExpression::from_string("fact(s(a))").unwrap();
        let rhs = DataExpression::from_string("times(s(a), fact(a))").unwrap();
        assert_eq!(
            stack.evaluate(&lhs),
            rhs,
            "The rhs stack does not evaluate to the expected term"
        );
    }

    #[test]
    fn test_rhs_stack_variable() {
        let rule = create_rewrite_rule("f(x)", "x", &["x"]).unwrap();
        let stack = TermStack::new(&rule);

        // Check if the resulting construction succeeded.
        assert!(
            stack.innermost_stack.read().is_empty(),
            "The resulting config stack is not as expected"
        );
        assert_eq!(stack.stack_size, 1, "The stack size does not match");
    }

    #[test]
    fn test_evaluation() {
        test_logger();

        let rhs = parse_with_variable_x("f(f(a,a),x)");
        let lhs = DataExpression::from_string("g(b)").unwrap();

        // Make a variable map with only x@1.
        let term_stack = TermStack::from_term(&rhs.copy(), &map_x_to_position_one());

        let expected = DataExpression::from_string("f(f(a,a),b)").unwrap();
        assert_eq!(term_stack.evaluate(&lhs), expected);
    }

    #[test]
    fn test_create_varmap() {
        let term = parse_with_variable_x("f(x,x)");
        let map = create_var_map(&term);
        assert!(map.contains_key(&DataVariable::new("x")));
    }

    #[test]
    fn test_is_duplicating() {
        let rhs = parse_with_variable_x("f(x,x)");

        // Make a variable map with only x@1.
        let term_stack = TermStack::from_term(&rhs.copy(), &map_x_to_position_one());
        assert!(term_stack.contains_duplicate_var_references(), "This sctt is duplicating");
    }
}