1
#![forbid(unsafe_code)]
2

            
3
use merc_aterm::Protected;
4
use merc_aterm::Term;
5
use merc_aterm::storage::ThreadTermPool;
6
use merc_data::DataExpression;
7
use merc_data::DataExpressionRef;
8
use merc_data::is_data_application;
9

            
10
use super::DataPosition;
11

            
12
/// Scratch buffer holding protected references to the arguments of the terms that are rebuilt
/// during a substitution, see [data_substitute_with].
///
/// The substitution clears the buffer after every rebuilt term, so a single builder can be
/// reused across many substitutions instead of allocating a fresh buffer per call.
pub type DataSubstitutionBuilder = Protected<Vec<DataExpressionRef<'static>>>;
13

            
14
/// This function substitutes the term 't' at the position 'p' with 'new_subterm', see [super::substitute].
15
1
pub fn data_substitute(
16
1
    tp: &ThreadTermPool,
17
1
    t: &DataExpressionRef<'_>,
18
1
    new_subterm: DataExpression,
19
1
    position: &DataPosition,
20
1
) -> DataExpression {
21
1
    let mut args = Protected::new(vec![]);
22
1
    substitute_rec(tp, t, new_subterm, position.indices(), &mut args, 0)
23
1
}
24

            
25
/// This is the same as [data_substitute], but it uses a [DataSubstitutionBuilder] to store the arguments temporarily.
26
1292078
pub fn data_substitute_with(
27
1292078
    builder: &mut DataSubstitutionBuilder,
28
1292078
    tp: &ThreadTermPool,
29
1292078
    t: &DataExpressionRef<'_>,
30
1292078
    new_subterm: DataExpression,
31
1292078
    position: &DataPosition,
32
1292078
) -> DataExpression {
33
1292078
    substitute_rec(tp, t, new_subterm, position.indices(), builder, 0)
34
1292078
}
35

            
36
/// The recursive implementation for [data_substitute]
37
///
38
/// Uses `depth` to keep track of the depth in 't', initially 0.
39
1467662
fn substitute_rec(
40
1467662
    tp: &ThreadTermPool,
41
1467662
    t: &DataExpressionRef<'_>,
42
1467662
    new_subterm: DataExpression,
43
1467662
    p: &[usize],
44
1467662
    args: &mut DataSubstitutionBuilder,
45
1467662
    depth: usize,
46
1467662
) -> DataExpression {
47
1467662
    if p.len() == depth {
48
        // in this case we have arrived at the place where 'new_subterm' needs to be injected
49
1292079
        new_subterm
50
    } else {
51
        // else recurse deeper into 't', do not subtract 1 from the index, since we are using DataPosition
52
175583
        let new_child_index = p[depth];
53
175583
        let new_child = substitute_rec(tp, &t.arg(new_child_index).into(), new_subterm, p, args, depth + 1);
54

            
55
175583
        debug_assert!(
56
175583
            is_data_application(t),
57
            "Can only perform data substitution on DataApplications"
58
        );
59

            
60
175583
        let mut write_args = args.write();
61
623047
        for (index, arg) in t.arguments().enumerate() {
62
623047
            if index == new_child_index {
63
175583
                let t = write_args.protect(&new_child);
64
175583
                write_args.push(t.into());
65
447464
            } else {
66
447464
                let t = write_args.protect(&arg);
67
447464
                write_args.push(t.into());
68
447464
            }
69
        }
70

            
71
        // Avoid the (more expensive) DataApplication constructor by simply having the data_function_symbol in args.
72
175583
        let result = tp.create_term(&t.get_head_symbol(), &write_args);
73
175583
        drop(write_args);
74

            
75
        // Clear the args buffer for reuse.
76
175583
        args.write().clear();
77
175583
        result.protect().into()
78
    }
79
1467662
}
80

            
81
#[cfg(test)]
mod tests {
    use super::*;

    use merc_aterm::storage::THREAD_TERM_POOL;

    use crate::utilities::DataPosition;
    use crate::utilities::DataPositionIndexed;

    #[test]
    fn test_data_substitute() {
        let t = DataExpression::from_string("s(s(a))").unwrap();
        let t0 = DataExpression::from_string("0").unwrap();

        // Substitute 0 for the a in the term s(s(a)).
        let result =
            THREAD_TERM_POOL.with_borrow(|tp| data_substitute(tp, &t.copy(), t0.clone(), &DataPosition::new(&[1, 1])));

        // Check that the new term indeed has a 0 at position 1.1.
        assert_eq!(t0, result.get_data_position(&DataPosition::new(&[1, 1])).protect());
    }

    #[test]
    fn test_data_substitute_with_reused_builder() {
        let t = DataExpression::from_string("s(s(a))").unwrap();
        let t0 = DataExpression::from_string("0").unwrap();

        // The same builder is used for two substitutions; the first one must leave it empty so
        // that the second result is not polluted by stale arguments.
        let mut builder = DataSubstitutionBuilder::new(vec![]);
        let (inner, outer) = THREAD_TERM_POOL.with_borrow(|tp| {
            let inner = data_substitute_with(&mut builder, tp, &t.copy(), t0.clone(), &DataPosition::new(&[1, 1]));
            let outer = data_substitute_with(&mut builder, tp, &t.copy(), t0.clone(), &DataPosition::new(&[1]));
            (inner, outer)
        });

        // s(s(a)) with 0 at 1.1 gives s(s(0)); with 0 at 1 it gives s(0).
        assert_eq!(t0, inner.get_data_position(&DataPosition::new(&[1, 1])).protect());
        assert_eq!(t0, outer.get_data_position(&DataPosition::new(&[1])).protect());
    }
}