1
use std::fmt;
2
use std::ops::Deref;
3

            
4
use ahash::AHashSet;
5
use delegate::delegate;
6

            
7
use merc_aterm::ATerm;
8
use merc_aterm::ATermArgs;
9
use merc_aterm::ATermIndex;
10
use merc_aterm::ATermRef;
11
use merc_aterm::ATermString;
12
use merc_aterm::Markable;
13
use merc_aterm::Symb;
14
use merc_aterm::SymbolRef;
15
use merc_aterm::Term;
16
use merc_aterm::TermBuilder;
17
use merc_aterm::TermIterator;
18
use merc_aterm::Transmutable;
19
use merc_aterm::Yield;
20
use merc_aterm::storage::Marker;
21
use merc_aterm::storage::THREAD_TERM_POOL;
22
use merc_macros::merc_derive_terms;
23
use merc_macros::merc_ignore;
24
use merc_macros::merc_term;
25

            
26
use crate::DATA_SYMBOLS;
27
use crate::SortExpression;
28
use crate::SortExpressionRef;
29
use crate::is_data_application;
30
use crate::is_data_expression;
31
use crate::is_data_function_symbol;
32
use crate::is_data_machine_number;
33
use crate::is_data_variable;
34

            
35
// This module is only used internally to run the proc macro.
36
#[merc_derive_terms]
37
mod inner {
38

            
39
    use std::iter;
40

            
41
    use merc_aterm::ATermIntRef;
42
    use merc_aterm::ATermStringRef;
43
    use merc_utilities::MercError;
44

            
45
    use super::*;
46

            
47
    /// A data expression is an [merc_aterm::ATerm] with additional structure.
48
    ///
49
    /// # Details
50
    ///
51
    /// A data expression can be any of:
52
    ///     - a variable
53
    ///     - a function symbol, i.e. f without arguments.
54
    ///     - a term applied to a number of arguments, i.e., t_0(t1, ..., tn).
55
    ///     - an abstraction lambda x: Sort . e, or forall and exists.
56
    ///     - machine number, a value [0, ..., 2^64-1].
57
    ///
58
    /// Not supported:
59
    ///     - a where clause "e where [x := f, ...]"
60
    ///     - set enumeration
61
    ///     - bag enumeration
62
    ///
63
    #[merc_term(is_data_expression)]
64
    pub struct DataExpression {
65
        term: ATerm,
66
    }
67

            
68
    impl DataExpression {
69
        /// Returns the head symbol a data expression
70
        ///     - function symbol                  f -> f
71
        ///     - application       f(t_0, ..., t_n) -> f
72
281812554
        pub fn data_function_symbol(&self) -> DataFunctionSymbolRef<'_> {
73
281812554
            if is_data_application(&self.term) {
74
231417033
                self.term.arg(0).into()
75
50395521
            } else if is_data_function_symbol(&self.term) {
76
50395521
                self.term.copy().into()
77
            } else {
78
                // This can only happen if the term is an incorrect data expression.
79
                panic!("data_function_symbol not implemented for {self}");
80
            }
81
281812554
        }
82

            
83
        /// Returns the arguments of a data expression
84
        ///     - function symbol                  f -> []
85
        ///     - application       f(t_0, ..., t_n) -> [t_0, ..., t_n]
86
        #[merc_ignore]
87
1062895
        pub fn data_arguments(&self) -> impl ExactSizeIterator<Item = DataExpressionRef<'_>> + use<'_> {
88
1062895
            let mut result = self.term.arguments();
89
1062895
            if is_data_application(&self.term) {
90
923889
                result.next();
91
923889
            } else if is_data_function_symbol(&self.term) || is_data_variable(&self.term) {
92
139006
                result.next();
93
139006
                result.next();
94
139006
            } else {
95
                // This can only happen if the term is an incorrect data expression.
96
                panic!("data_arguments not implemented for {self}");
97
            }
98

            
99
1090923
            result.map(|t| t.into())
100
1062895
        }
101

            
102
        /// Creates a closed [DataExpression] from a string, i.e., has no free variables.
103
        #[merc_ignore]
104
4341
        pub fn from_string(text: &str) -> Result<DataExpression, MercError> {
105
4341
            Ok(to_untyped_data_expression(ATerm::from_string(text)?, None))
106
4341
        }
107

            
108
        /// Creates a [DataExpression] from a string with free untyped variables indicated by the set of names.
109
        #[merc_ignore]
110
42
        pub fn from_string_untyped(text: &str, variables: &AHashSet<String>) -> Result<DataExpression, MercError> {
111
42
            Ok(to_untyped_data_expression(ATerm::from_string(text)?, Some(variables)))
112
42
        }
113

            
114
        /// Returns the ith argument of a data application.
115
        #[merc_ignore]
116
2
        pub fn data_arg(&self, index: usize) -> DataExpressionRef<'_> {
117
2
            debug_assert!(is_data_application(self), "Term {self:?} is not a data application");
118
2
            debug_assert!(
119
2
                index + 1 < self.get_head_symbol().arity(),
120
                "data_arg({index}) is not defined for term {self:?}"
121
            );
122

            
123
2
            self.term.arg(index + 1).into()
124
2
        }
125

            
126
        /// Returns the arguments of a data expression
127
        ///     - function symbol                  f -> []
128
        ///     - application       f(t_0, ..., t_n) -> [t_0, ..., t_n]
129
        pub fn data_sort(&self) -> SortExpression {
130
            if is_data_function_symbol(&self.term) {
131
                DataFunctionSymbolRef::from(self.term.copy()).sort().protect()
132
            } else if is_data_variable(&self.term) {
133
                DataVariableRef::from(self.term.copy()).sort().protect()
134
            } else {
135
                panic!("data_sort not implemented for {self}");
136
            }
137
        }
138
    }
139

            
140
    impl fmt::Display for DataExpression {
141
16591
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
142
16591
            if is_data_function_symbol(&self.term) {
143
9479
                write!(f, "{}", DataFunctionSymbolRef::from(self.term.copy()))
144
7112
            } else if is_data_application(&self.term) {
145
6062
                write!(f, "{}", DataApplicationRef::from(self.term.copy()))
146
1050
            } else if is_data_variable(&self.term) {
147
                write!(f, "{}", DataVariableRef::from(self.term.copy()))
148
1050
            } else if is_data_machine_number(&self.term) {
149
1050
                write!(f, "{}", MachineNumberRef::from(self.term.copy()))
150
            } else {
151
                write!(f, "{}", self.term)
152
            }
153
16591
        }
154
    }
155

            
156
    #[merc_term(is_data_function_symbol)]
157
    pub struct DataFunctionSymbol {
158
        term: ATerm,
159
    }
160

            
161
    impl DataFunctionSymbol {
162
        #[merc_ignore]
163
120845
        pub fn new<N>(name: N) -> DataFunctionSymbol
164
120845
        where
165
120845
            N: Into<ATermString> + AsRef<str>,
166
        {
167
120845
            DATA_SYMBOLS.with_borrow(|ds| DataFunctionSymbol {
168
120845
                term: ATerm::with_args(
169
120845
                    ds.data_function_symbol.deref(),
170
120845
                    &[Into::<ATerm>::into(name.into()), SortExpression::unknown_sort().into()],
171
120845
                )
172
120845
                .protect(),
173
120845
            })
174
120845
        }
175

            
176
        /// Returns the name of the function symbol
177
15545
        pub fn name(&self) -> ATermStringRef<'_> {
178
15545
            ATermStringRef::from(self.term.arg(0))
179
15545
        }
180

            
181
        /// Returns the sort of the function symbol.
182
        pub fn sort(&self) -> SortExpressionRef<'_> {
183
            self.term.arg(1).into()
184
        }
185

            
186
        /// Returns the internal operation id (a unique number) for the data::function_symbol.
187
178369604
        pub fn operation_id(&self) -> usize {
188
178369604
            self.term.index()
189
178369604
        }
190
    }
191

            
192
    impl fmt::Display for DataFunctionSymbol {
193
15543
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
194
15543
            write!(f, "{}", self.name())
195
15543
        }
196
    }
197

            
198
    #[merc_term(is_data_variable)]
199
    pub struct DataVariable {
200
        term: ATerm,
201
    }
202

            
203
    impl DataVariable {
204
        /// Create a new untyped variable with the given name.
205
        #[merc_ignore]
206
33898
        pub fn new<N: Into<ATermString>>(name: N) -> DataVariable {
207
33898
            DATA_SYMBOLS.with_borrow(|ds| {
208
                // TODO: Storing terms temporarily is not optimal.
209
33898
                let t = name.into();
210
33898
                let args: &[ATerm] = &[t.into(), SortExpression::unknown_sort().into()];
211

            
212
33898
                DataVariable {
213
33898
                    term: ATerm::with_args(ds.data_variable.deref(), args).protect(),
214
33898
                }
215
33898
            })
216
33898
        }
217

            
218
        /// Create a variable with the given sort and name.
219
        pub fn with_sort<N: Into<ATermString>>(name: N, sort: SortExpressionRef<'_>) -> DataVariable {
220
            DATA_SYMBOLS.with_borrow(|ds| {
221
                // TODO: Storing terms temporarily is not optimal.
222
                let t = name.into();
223
                let args: &[ATermRef<'_>] = &[t.copy().into(), sort.into()];
224

            
225
                DataVariable {
226
                    term: ATerm::with_args(ds.data_variable.deref(), args).protect(),
227
                }
228
            })
229
        }
230

            
231
        /// Returns the name of the variable.
232
8372
        pub fn name(&self) -> &str {
233
            // We only change the lifetime, but that is fine since it is derived from the current term.
234
8372
            self.term.arg(0).get_head_symbol().name()
235
8372
        }
236

            
237
        /// Returns the sort of the variable.
238
        pub fn sort(&self) -> SortExpressionRef<'_> {
239
            self.term.arg(1).into()
240
        }
241
    }
242

            
243
    impl fmt::Display for DataVariable {
244
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
245
            write!(f, "{}", self.name())
246
        }
247
    }
248

            
249
    #[merc_term(is_data_application)]
250
    pub struct DataApplication {
251
        term: ATerm,
252
    }
253

            
254
    impl DataApplication {
255
        /// Create a new data application with the given head and arguments.
256
        #[merc_ignore]
257
1090
        pub fn with_args<'a, 'b, H: Term<'a, 'b>, T: Term<'a, 'b>>(head: &'b H, arguments: &'b [T]) -> DataApplication {
258
1090
            DATA_SYMBOLS.with_borrow_mut(|ds| {
259
1090
                let symbol = ds.get_data_application_symbol(arguments.len() + 1).copy();
260

            
261
1973
                let args = iter::once(head.copy()).chain(arguments.iter().map(|t| t.copy()));
262
1090
                let term = ATerm::with_iter(&symbol, args);
263

            
264
1090
                DataApplication { term }
265
1090
            })
266
1090
        }
267

            
268
        /// Create a new data application with the given head and arguments.
269
        ///
270
        /// arity must be equal to the number of arguments + 1.
271
        #[merc_ignore]
272
4359807
        pub fn with_iter<'a, 'b, 'c, 'd, T, H, I>(head: &'b H, arity: usize, arguments: I) -> DataApplication
273
4359807
        where
274
4359807
            I: Iterator<Item = T>,
275
4359807
            T: Term<'c, 'd>,
276
4359807
            H: Term<'a, 'b>,
277
        {
278
4359807
            DATA_SYMBOLS.with_borrow_mut(|ds| {
279
4359807
                let symbol = ds.get_data_application_symbol(arity + 1).copy();
280

            
281
4359807
                let term = ATerm::with_iter_head(&symbol, head, arguments);
282

            
283
4359807
                DataApplication { term }
284
4359807
            })
285
4359807
        }
286

            
287
        /// Returns the head symbol a data application
288
6063
        pub fn data_function_symbol(&self) -> DataFunctionSymbolRef<'_> {
289
6063
            self.term.arg(0).into()
290
6063
        }
291

            
292
        /// Returns the arguments of a data application
293
6064
        pub fn data_arguments(&self) -> ATermArgs<'_> {
294
6064
            let mut result = self.term.arguments();
295
6064
            result.next();
296
6064
            result
297
6064
        }
298

            
299
        /// Returns the ith argument of a data application.
300
        pub fn data_arg(&self, index: usize) -> DataExpressionRef<'_> {
301
            debug_assert!(
302
                index + 1 < self.get_head_symbol().arity(),
303
                "data_arg({index}) is not defined for term {self:?}"
304
            );
305

            
306
            self.term.arg(index + 1).into()
307
        }
308

            
309
        /// Returns the sort of a data application.
310
        pub fn sort(&self) -> SortExpressionRef<'_> {
311
            // We only change the lifetime, but that is fine since it is derived from the current term.
312
            SortExpressionRef::from(self.term.arg(0))
313
        }
314
    }
315

            
316
    impl fmt::Display for DataApplication {
317
6063
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
318
6063
            write!(f, "{}", self.data_function_symbol())?;
319

            
320
6063
            let mut first = true;
321
11075
            for arg in self.data_arguments() {
322
11075
                if !first {
323
5012
                    write!(f, ", ")?;
324
                } else {
325
6063
                    write!(f, "(")?;
326
                }
327

            
328
11075
                write!(f, "{}", DataExpressionRef::from(arg.copy()))?;
329
11075
                first = false;
330
            }
331

            
332
6063
            if !first {
333
6063
                write!(f, ")")?;
334
            }
335

            
336
6063
            Ok(())
337
6063
        }
338
    }
339

            
340
    #[merc_term(is_data_machine_number)]
341
    struct MachineNumber {
342
        pub term: ATerm,
343
    }
344

            
345
    impl MachineNumber {
346
        /// Obtain the underlying value of a machine number.
347
        ///
348
        /// # Safety
349
        ///
350
        /// This method assumes that the term is indeed an integer term, which
351
        /// should be guaranteed by the constructor and the
352
        /// `is_data_machine_number` function.
353
1050
        pub fn value(&self) -> u64 {
354
1050
            Into::<ATermIntRef<'_>>::into(self.term.copy()).value() as u64
355
1050
        }
356
    }
357

            
358
    impl fmt::Display for MachineNumber {
359
1050
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
360
1050
            write!(f, "{}", self.value())
361
1050
        }
362
    }
363

            
364
    /// Conversions to `DataExpression`
365
    #[merc_ignore]
366
    impl From<DataFunctionSymbol> for DataExpression {
367
13318340
        fn from(value: DataFunctionSymbol) -> Self {
368
13318340
            value.term.into()
369
13318340
        }
370
    }
371

            
372
    #[merc_ignore]
373
    impl From<DataApplication> for DataExpression {
374
59917649
        fn from(value: DataApplication) -> Self {
375
59917649
            value.term.into()
376
59917649
        }
377
    }
378

            
379
    #[merc_ignore]
380
    impl From<DataVariable> for DataExpression {
381
4200
        fn from(value: DataVariable) -> Self {
382
4200
            value.term.into()
383
4200
        }
384
    }
385

            
386
    #[merc_ignore]
387
    impl From<DataExpression> for DataFunctionSymbol {
388
35000
        fn from(value: DataExpression) -> Self {
389
35000
            value.term.into()
390
35000
        }
391
    }
392

            
393
    #[merc_ignore]
394
    impl From<DataExpression> for DataVariable {
395
403844
        fn from(value: DataExpression) -> Self {
396
403844
            value.term.into()
397
403844
        }
398
    }
399

            
400
    #[merc_ignore]
401
    impl<'a> From<DataExpressionRef<'a>> for DataVariableRef<'a> {
402
506926
        fn from(value: DataExpressionRef<'a>) -> Self {
403
506926
            value.term.into()
404
506926
        }
405
    }
406
}
407

            
408
pub use inner::*;
409

            
410
impl<'a> DataExpressionRef<'a> {
411
35260190
    pub fn data_arguments(&self) -> impl ExactSizeIterator<Item = DataExpressionRef<'a>> + use<'a> {
412
35260190
        let mut result = self.term.arguments();
413
35260190
        if is_data_application(&self.term) {
414
27111448
            result.next();
415
27111448
        } else if is_data_function_symbol(&self.term) || is_data_variable(&self.term) {
416
8148742
            result.next();
417
8148742
            result.next();
418
8148742
        } else {
419
            // This can only happen if the term is not a data expression.
420
            panic!("data_arguments not implemented for {self}");
421
        }
422

            
423
35260190
        result.map(|t| t.into())
424
35260190
    }
425

            
426
    /// Returns the ith argument of a data application.
427
229011847
    pub fn data_arg(&self, index: usize) -> DataExpressionRef<'a> {
428
229011847
        debug_assert!(is_data_application(self), "Term {self:?} is not a data application");
429
229011847
        debug_assert!(
430
229011847
            index + 1 < self.get_head_symbol().arity(),
431
            "data_arg({index}) is not defined for term {self:?}"
432
        );
433

            
434
229011847
        self.term.arg(index + 1).into()
435
229011847
    }
436
}
437

            
438
/// Converts an [ATerm] to an untyped data expression.
439
30955
pub fn to_untyped_data_expression(t: ATerm, variables: Option<&AHashSet<String>>) -> DataExpression {
440
30955
    let mut builder = TermBuilder::<ATerm, ATerm>::new();
441
30955
    THREAD_TERM_POOL.with_borrow(|tp| {
442
30955
        builder
443
30955
            .evaluate(
444
30955
                tp,
445
30955
                t,
446
147439
                |_tp, args, t| {
447
147439
                    if variables.is_some_and(|v| v.contains(t.get_head_symbol().name())) {
448
                        // Convert a constant variable, for example 'x', into an untyped variable.
449
26642
                        Ok(Yield::Term(DataVariable::new(t.get_head_symbol().name()).into()))
450
120797
                    } else if t.get_head_symbol().arity() == 0 {
451
39735
                        Ok(Yield::Term(DataFunctionSymbol::new(t.get_head_symbol().name()).into()))
452
                    } else {
453
                        // This is a function symbol applied to a number of arguments
454
81062
                        let head = DataFunctionSymbol::new(t.get_head_symbol().name());
455

            
456
116484
                        for arg in t.arguments() {
457
116484
                            args.push(arg.protect());
458
116484
                        }
459

            
460
81062
                        Ok(Yield::Construct(head.into()))
461
                    }
462
147439
                },
463
81062
                |_tp, input, args| {
464
81062
                    let arity = args.clone().count();
465
81062
                    Ok(DataApplication::with_iter(&input, arity, args).into())
466
81062
                },
467
            )
468
30955
            .unwrap()
469
30955
            .into()
470
30955
    })
471
30955
}
472

            
473
#[cfg(test)]
mod tests {
    use super::*;

    use merc_aterm::ATerm;

    /// Function symbols print as their name, applications as `head(args)`.
    #[test]
    fn test_print() {
        let _ = merc_utilities::test_logger();

        let constant = DataFunctionSymbol::new("a");
        assert_eq!("a", format!("{constant}"));

        // Check printing of data applications.
        let head = DataFunctionSymbol::new("f");
        let application = DataApplication::with_args(&head, &[constant]);
        assert_eq!("f(a)", format!("{application}"));
    }

    /// A constructed application is recognized as a data application.
    #[test]
    fn test_recognizers() {
        let constant = DataFunctionSymbol::new("a");
        let head = DataFunctionSymbol::new("f");
        let application = DataApplication::with_args(&head, &[constant]);

        let term: ATerm = application.into();
        assert!(is_data_application(&term));
    }

    /// The head is not counted among the data arguments.
    #[test]
    fn test_data_arguments() {
        let constant = DataFunctionSymbol::new("a");
        let head = DataFunctionSymbol::new("f");
        let application = DataApplication::with_args(&head, &[constant]);

        let direct_count = application.data_arguments().count();
        assert_eq!(direct_count, 1);

        let expression: DataExpression = application.clone().into();

        let expression_count = expression.data_arguments().count();
        assert_eq!(expression_count, 1);
    }

    /// Parsing a nested term yields nested applications with the expected heads.
    #[test]
    fn test_to_data_expression() {
        let expression = DataExpression::from_string("s(s(a, b), c)").unwrap();

        let inner = expression.data_arg(0);
        assert_eq!(inner.data_function_symbol().name(), "s");

        let innermost = inner.data_arg(0);
        assert_eq!(innermost.data_function_symbol().name(), "a");
    }
}