1
use std::collections::VecDeque;
2
use std::fmt::Debug;
3
use std::time::Instant;
4

            
5
use ahash::HashMap;
6
use log::debug;
7
use log::info;
8
use log::log_enabled;
9
use log::trace;
10
use log::warn;
11
use merc_aterm::Term;
12
use merc_data::DataExpression;
13
use merc_data::DataExpressionRef;
14
use merc_data::DataFunctionSymbol;
15
use merc_data::is_data_application;
16
use merc_data::is_data_function_symbol;
17
use merc_data::is_data_machine_number;
18
use merc_data::is_data_variable;
19
use smallvec::SmallVec;
20
use smallvec::smallvec;
21

            
22
use crate::rewrite_specification::RewriteSpecification;
23
use crate::rewrite_specification::Rule;
24
use crate::utilities::DataPosition;
25

            
26
use super::DotFormatter;
27
use super::MatchGoal;
28

            
29
/// The Set Automaton used to find all matching patterns in a term.
30
pub struct SetAutomaton<T> {
31
    states: Vec<State>,
32
    transitions: HashMap<(usize, usize), Transition<T>>,
33
}
34

            
35
/// A match announcement contains the rule that can be announced as a match at
36
/// the given position.
37
///
38
/// `symbols_seen` is internally used to keep track of how many symbols have
39
/// been observed so far. Since these symbols have a unique number this can be
40
/// used to speed up certain operations.
41
#[derive(Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
42
pub struct MatchAnnouncement {
43
    pub rule: Rule,
44
    pub position: DataPosition,
45
    pub symbols_seen: usize,
46
}
47

            
48
/// Represents a transition in the [SetAutomaton].
49
#[derive(Clone)]
50
pub struct Transition<T> {
51
    pub symbol: DataFunctionSymbol,
52
    pub announcements: SmallVec<[(MatchAnnouncement, T); 1]>,
53
    pub destinations: SmallVec<[(DataPosition, usize); 1]>,
54
}
55

            
56
/// Represents a match obligation in the [SetAutomaton].
57
#[derive(Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
58
pub struct MatchObligation {
59
    pub pattern: DataExpression,
60
    pub position: DataPosition,
61
}
62

            
63
impl MatchObligation {
64
    /// Returns the pattern of the match obligation
65
1077614
    pub fn new(pattern: DataExpression, position: DataPosition) -> Self {
66
1077614
        MatchObligation { pattern, position }
67
1077614
    }
68
}
69

            
70
/// Represents either the initial state or a set of match goals in the
71
/// [SetAutomaton].
72
///
73
/// This is only used during construction to avoid craating all the goals for
74
/// the initial state.
75
#[derive(Debug)]
76
enum GoalsOrInitial {
77
    InitialState,
78
    Goals(Vec<MatchGoal>),
79
}
80

            
81
impl<M> SetAutomaton<M> {
82
    /// Creates a new SetAutomaton from the given rewrite specification. If
83
    /// `apma` is true an Adaptive Pattern Matching Automaton is created,
84
    /// meaning that it only finds matches at the root position.
85
    ///
86
    /// The `annotate` function is used to create the annotation for each match
87
    /// announcement. This is used to accomondate different types of annotations
88
    /// for the different rewrite engines.
89
57
    pub fn new(spec: &RewriteSpecification, annotate: impl Fn(&Rule) -> M, apma: bool) -> SetAutomaton<M> {
90
57
        let start = Instant::now();
91

            
92
        // States are labelled s0, s1, s2, etcetera. state_counter keeps track of count.
93
57
        let mut state_counter: usize = 1;
94

            
95
        // Remove rules that we cannot deal with
96
57
        let supported_rules: Vec<Rule> = spec
97
57
            .rewrite_rules()
98
57
            .iter()
99
1560
            .filter(|rule| is_supported_rule(rule))
100
57
            .map(Rule::clone)
101
57
            .collect();
102

            
103
        // Find the indices of all the function symbols.
104
57
        let symbols = {
105
57
            let mut symbols = HashMap::default();
106

            
107
1560
            for rule in &supported_rules {
108
1560
                find_symbols(&rule.lhs.copy(), &mut symbols);
109
1560
                find_symbols(&rule.rhs.copy(), &mut symbols);
110

            
111
1560
                for cond in &rule.conditions {
112
160
                    find_symbols(&cond.lhs.copy(), &mut symbols);
113
160
                    find_symbols(&cond.rhs.copy(), &mut symbols);
114
160
                }
115
            }
116

            
117
57
            symbols
118
        };
119

            
120
1221
        for (index, (symbol, arity)) in symbols.iter().enumerate() {
121
1221
            trace!("{index}: {symbol} {arity}");
122
        }
123

            
124
        // The initial state has a match goals for each pattern. For each pattern l there is a match goal
125
        // with one obligation l@ε and announcement l@ε.
126
57
        let mut initial_match_goals = Vec::<MatchGoal>::new();
127
1560
        for rr in &supported_rules {
128
1560
            initial_match_goals.push(MatchGoal::new(
129
1560
                MatchAnnouncement {
130
1560
                    rule: (*rr).clone(),
131
1560
                    position: DataPosition::empty(),
132
1560
                    symbols_seen: 0,
133
1560
                },
134
1560
                vec![MatchObligation::new(rr.lhs.clone(), DataPosition::empty())],
135
1560
            ));
136
1560
        }
137

            
138
        // Match goals need to be sorted so that we can easily check whether a state with a certain
139
        // set of match goals already exists.
140
57
        initial_match_goals.sort();
141

            
142
        // Create the initial state
143
57
        let initial_state = State {
144
57
            label: DataPosition::empty(),
145
57
            match_goals: initial_match_goals.clone(),
146
57
        };
147

            
148
        // HashMap from goals to state number
149
57
        let mut map_goals_state = HashMap::default();
150

            
151
        // Queue of states that still need to be explored
152
57
        let mut queue = VecDeque::new();
153
57
        queue.push_back(0);
154

            
155
57
        map_goals_state.insert(initial_match_goals, 0);
156

            
157
57
        let mut states = vec![initial_state];
158
57
        let mut transitions = HashMap::default();
159

            
160
        // Pick a state to explore
161
1284
        while let Some(s_index) = queue.pop_front() {
162
48653
            for (symbol, arity) in &symbols {
163
48653
                let (mut announcements, pos_to_goals) =
164
48653
                    states
165
48653
                        .get(s_index)
166
48653
                        .unwrap()
167
48653
                        .derive_transition(symbol, *arity, &supported_rules, apma);
168

            
169
48653
                announcements.sort_by(|ma1, ma2| ma1.position.cmp(&ma2.position));
170

            
171
                // For the destinations we convert the match goal destinations to states
172
48653
                let mut destinations = smallvec![];
173

            
174
                // Loop over the hypertransitions
175
61467
                for (pos, goals_or_initial) in pos_to_goals {
176
                    // Match goals need to be sorted so that we can easily check whether a state with a certain
177
                    // set of match goals already exists.
178
61467
                    if let GoalsOrInitial::Goals(goals) = goals_or_initial {
179
                        // This code cannot be replaced by the entry since contains_key takes a reference.
180
                        #[allow(clippy::map_entry)]
181
47454
                        if map_goals_state.contains_key(&goals) {
182
                            // The destination state already exists
183
46284
                            destinations.push((pos, *map_goals_state.get(&goals).unwrap()))
184
1170
                        } else if !goals.is_empty() {
185
1170
                            // The destination state does not yet exist, create it
186
1170
                            let new_state = State::new(goals.clone());
187
1170
                            states.push(new_state);
188
1170
                            destinations.push((pos, state_counter));
189
1170
                            map_goals_state.insert(goals, state_counter);
190
1170
                            queue.push_back(state_counter);
191
1170
                            state_counter += 1;
192
1170
                        }
193
14013
                    } else {
194
14013
                        // The transition is to the initial state
195
14013
                        destinations.push((pos, 0));
196
14013
                    }
197
                }
198

            
199
                // Add the annotation for every match announcement.
200
48653
                let announcements = announcements
201
48653
                    .into_iter()
202
48653
                    .map(|ma| {
203
15554
                        let annotation = annotate(&ma.rule);
204
15554
                        (ma, annotation)
205
15554
                    })
206
48653
                    .collect();
207

            
208
                // Add the resulting outgoing transition to the state.
209
48653
                debug_assert!(
210
48653
                    !&transitions.contains_key(&(s_index, symbol.operation_id())),
211
                    "Set automaton should not contain duplicated transitions"
212
                );
213
48653
                transitions.insert(
214
48653
                    (s_index, symbol.operation_id()),
215
48653
                    Transition {
216
48653
                        symbol: symbol.clone(),
217
48653
                        announcements,
218
48653
                        destinations,
219
48653
                    },
220
                );
221
            }
222

            
223
1227
            debug!(
224
                "Queue size {}, currently {} states and {} transitions",
225
                queue.len(),
226
                states.len(),
227
                transitions.len()
228
            );
229
        }
230

            
231
        // Clear the match goals since they are only for debugging purposes.
232
57
        if !log_enabled!(log::Level::Debug) {
233
1227
            for state in &mut states {
234
1227
                state.match_goals.clear();
235
1227
            }
236
        }
237
57
        info!(
238
            "Created set automaton (states: {}, transitions: {}, apma: {}) in {} ms",
239
            states.len(),
240
            transitions.len(),
241
            apma,
242
            (Instant::now() - start).as_millis()
243
        );
244

            
245
57
        let result = SetAutomaton { states, transitions };
246
57
        debug!("{result:?}");
247

            
248
57
        result
249
57
    }
250

            
251
    /// Returns the number of states
252
    pub fn num_of_states(&self) -> usize {
253
        self.states.len()
254
    }
255

            
256
    /// Returns the number of transitions
257
    pub fn num_of_transitions(&self) -> usize {
258
        self.transitions.len()
259
    }
260

            
261
    /// Returns the states of the automaton
262
13821450
    pub fn states(&self) -> &Vec<State> {
263
13821450
        &self.states
264
13821450
    }
265

            
266
    /// Returns the transitions of the automaton
267
12653643
    pub fn transitions(&self) -> &HashMap<(usize, usize), Transition<M>> {
268
12653643
        &self.transitions
269
12653643
    }
270

            
271
    /// Provides a formatter for the .dot file format
272
    pub fn to_dot_graph(&self, show_backtransitions: bool, show_final: bool) -> DotFormatter<'_, M> {
273
        DotFormatter {
274
            automaton: self,
275
            show_backtransitions,
276
            show_final,
277
        }
278
    }
279
}
280

            
281
#[derive(Debug)]
282
pub struct Derivative {
283
    pub completed: Vec<MatchGoal>,
284
    pub unchanged: Vec<MatchGoal>,
285
    pub reduced: Vec<MatchGoal>,
286
}
287

            
288
pub struct State {
289
    label: DataPosition,
290
    match_goals: Vec<MatchGoal>,
291
}
292

            
293
impl State {
294
    /// Derive transitions from a state given a head symbol. The resulting transition is returned as a tuple
295
    /// The tuple consists of a vector of outputs and a set of destinations (which are sets of match goals).
296
    /// We don't use the struct Transition as it requires that the destination is a full state, with name.
297
    /// Since we don't yet know whether the state already exists we just return a set of match goals as 'state'.
298
    ///
299
    /// Parameter symbol is the symbol for which the transition is computed
300
48653
    fn derive_transition(
301
48653
        &self,
302
48653
        symbol: &DataFunctionSymbol,
303
48653
        arity: usize,
304
48653
        rewrite_rules: &Vec<Rule>,
305
48653
        apma: bool,
306
48653
    ) -> (Vec<MatchAnnouncement>, Vec<(DataPosition, GoalsOrInitial)>) {
307
        // Computes the derivative containing the goals that are completed, unchanged and reduced
308
48653
        let mut derivative = self.compute_derivative(symbol, arity);
309

            
310
        // The outputs/matching patterns of the transitions are those who are completed
311
48653
        let outputs = derivative.completed.into_iter().map(|x| x.announcement).collect();
312

            
313
        // The new match goals are the unchanged and reduced match goals.
314
48653
        let mut new_match_goals = derivative.unchanged;
315
48653
        new_match_goals.append(&mut derivative.reduced);
316

            
317
48653
        let mut destinations = vec![];
318
        // If we are building an APMA we do not deepen the position or create a hypertransitions
319
        // with multiple endpoints
320
48653
        if apma {
321
14224
            if !new_match_goals.is_empty() {
322
802
                destinations.push((DataPosition::empty(), GoalsOrInitial::Goals(new_match_goals)));
323
13422
            }
324
        } else {
325
            // In case we are building a set automaton we partition the match goals
326
34429
            let partitioned = MatchGoal::partition(new_match_goals);
327

            
328
            // Get the greatest common prefix and shorten the positions
329
34429
            let mut positions_per_partition = vec![];
330
34429
            let mut gcp_length_per_partition = vec![];
331
46652
            for (p, pos) in partitioned {
332
46652
                positions_per_partition.push(pos);
333
46652
                let gcp = MatchGoal::greatest_common_prefix(&p);
334
46652
                let gcp_length = gcp.len();
335
46652
                gcp_length_per_partition.push(gcp_length);
336
46652
                let mut goals = MatchGoal::remove_prefix(p, gcp_length);
337
46652
                goals.sort_unstable();
338
46652
                destinations.push((gcp, GoalsOrInitial::Goals(goals)));
339
46652
            }
340

            
341
            // Handle fresh match goals, they are the positions Label(state).i
342
            // where i is between 1 and the arity of the function symbol of
343
            // the transition. Position 1 is the first argument.
344
34429
            for i in 1..arity + 1 {
345
31805
                let mut pos = self.label.clone();
346
31805
                pos.push(i);
347

            
348
                // Check if the fresh goals are related to one of the existing partitions
349
31805
                let mut partition_key = None;
350
52695
                'outer: for (i, part_pos) in positions_per_partition.iter().enumerate() {
351
79620
                    for p in part_pos {
352
79620
                        if MatchGoal::pos_comparable(p, &pos) {
353
17792
                            partition_key = Some(i);
354
17792
                            break 'outer;
355
61828
                        }
356
                    }
357
                }
358

            
359
31805
                if let Some(key) = partition_key {
360
                    // If the fresh goals fall in an existing partition
361
17792
                    let gcp_length = gcp_length_per_partition[key];
362
17792
                    let pos = DataPosition::new(&pos.indices()[gcp_length..]);
363

            
364
                    // Add the fresh goals to the partition
365
1076054
                    for rr in rewrite_rules {
366
1076054
                        if let GoalsOrInitial::Goals(goals) = &mut destinations[key].1 {
367
1076054
                            goals.push(MatchGoal {
368
1076054
                                obligations: vec![MatchObligation::new(rr.lhs.clone(), pos.clone())],
369
1076054
                                announcement: MatchAnnouncement {
370
1076054
                                    rule: (*rr).clone(),
371
1076054
                                    position: pos.clone(),
372
1076054
                                    symbols_seen: 0,
373
1076054
                                },
374
1076054
                            });
375
1076054
                        }
376
                    }
377
14013
                } else {
378
14013
                    // The transition is simply to the initial state
379
14013
                    // GoalsOrInitial::InitialState avoids unnecessary work of creating all these fresh goals
380
14013
                    destinations.push((pos, GoalsOrInitial::InitialState));
381
14013
                }
382
            }
383
        }
384

            
385
        // Sort the destination such that transitions which do not deepen the position are listed first
386
48653
        destinations.sort_unstable_by(|(pos1, _), (pos2, _)| pos1.cmp(pos2));
387
48653
        (outputs, destinations)
388
48653
    }
389

            
390
    /// For a transition 'symbol' of state 'self' this function computes which match goals are
391
    /// completed, unchanged and reduced.
392
48653
    fn compute_derivative(&self, symbol: &DataFunctionSymbol, arity: usize) -> Derivative {
393
48653
        let mut result = Derivative {
394
48653
            completed: vec![],
395
48653
            unchanged: vec![],
396
48653
            reduced: vec![],
397
48653
        };
398

            
399
5138007
        for mg in &self.match_goals {
400
5138007
            debug_assert!(
401
5138007
                !mg.obligations.is_empty(),
402
                "The obligations should never be empty, should be completed then"
403
            );
404

            
405
            // Completed match goals
406
5138007
            if mg.obligations.len() == 1
407
5110907
                && mg.obligations.iter().any(|mo| {
408
5110907
                    mo.position == self.label
409
2565108
                        && mo.pattern.data_function_symbol() == symbol.copy()
410
48449
                        && mo.pattern.data_arguments().all(|x| is_data_variable(&x))
411
                    // Again skip the function symbol
412
5110907
                })
413
15554
            {
414
15554
                result.completed.push(mg.clone());
415
5122453
            } else if mg
416
5122453
                .obligations
417
5122453
                .iter()
418
5138864
                .any(|mo| mo.position == self.label && mo.pattern.data_function_symbol() != symbol.copy())
419
2531108
            {
420
2531108
                // Match goal is discarded since head symbol does not match.
421
2607355
            } else if mg.obligations.iter().all(|mo| mo.position != self.label) {
422
                // Unchanged match goals
423
2560969
                let mut mg = mg.clone();
424
2560969
                if mg.announcement.rule.lhs != mg.obligations.first().unwrap().pattern {
425
34869
                    mg.announcement.symbols_seen += 1;
426
2526100
                }
427

            
428
2560969
                result.unchanged.push(mg.clone());
429
            } else {
430
                // Reduce match obligations
431
30376
                let mut mg = mg.clone();
432
30376
                let mut new_obligations = vec![];
433

            
434
30777
                for mo in mg.obligations {
435
30777
                    if mo.pattern.data_function_symbol() == symbol.copy() && mo.position == self.label {
436
                        // Reduced match obligation
437
55500
                        for (index, t) in mo.pattern.data_arguments().enumerate() {
438
55500
                            assert!(
439
55500
                                index < arity,
440
                                "This pattern associates function symbol {:?} with different arities {} and {}",
441
                                symbol,
442
                                index + 1,
443
                                arity
444
                            );
445

            
446
55500
                            if !is_data_variable(&t) {
447
39582
                                let mut new_pos = mo.position.clone();
448
39582
                                new_pos.push(index + 1);
449
39582
                                new_obligations.push(MatchObligation {
450
39582
                                    pattern: t.protect(),
451
39582
                                    position: new_pos,
452
39582
                                });
453
39582
                            }
454
                        }
455
401
                    } else {
456
401
                        // remains unchanged
457
401
                        new_obligations.push(mo.clone());
458
401
                    }
459
                }
460

            
461
30376
                new_obligations.sort_unstable_by(|mo1, mo2| mo1.position.len().cmp(&mo2.position.len()));
462
30376
                mg.obligations = new_obligations;
463
30376
                mg.announcement.symbols_seen += 1;
464

            
465
30376
                result.reduced.push(mg);
466
            }
467
        }
468

            
469
48653
        trace!(
470
            "=== compute_derivative(symbol = {}, label = {}) ===",
471
            symbol, self.label
472
        );
473
48653
        trace!("Match goals: {{");
474
5138007
        for mg in &self.match_goals {
475
5138007
            trace!("\t {mg:?}");
476
        }
477

            
478
48653
        trace!("}}");
479
48653
        trace!("Completed: {{");
480
48653
        for mg in &result.completed {
481
15554
            trace!("\t {mg:?}");
482
        }
483

            
484
48653
        trace!("}}");
485
48653
        trace!("Unchanged: {{");
486
2560969
        for mg in &result.unchanged {
487
2560969
            trace!("\t {mg:?}");
488
        }
489

            
490
48653
        trace!("}}");
491
48653
        trace!("Reduced: {{");
492
48653
        for mg in &result.reduced {
493
30376
            trace!("\t {mg:?}");
494
        }
495
48653
        trace!("}}");
496

            
497
48653
        result
498
48653
    }
499

            
500
    /// Create a state from a set of match goals
501
1170
    fn new(goals: Vec<MatchGoal>) -> State {
502
        // The label of the state is taken from a match obligation of a root match goal.
503
1170
        let mut label: Option<DataPosition> = None;
504

            
505
        // Go through all match goals until a root match goal is found
506
103185
        for goal in &goals {
507
103185
            if goal.announcement.position.is_empty() {
508
                // Find the shortest match obligation position.
509
                // This design decision was taken as it presumably has two advantages.
510
                // 1. Patterns that overlap will be more quickly distinguished, potentially decreasing
511
                // the size of the automaton.
512
                // 2. The average lookahead may be shorter.
513
2987
                if label.is_none() {
514
1170
                    label = Some(goal.obligations.first().unwrap().position.clone());
515
1817
                }
516

            
517
3400
                for obligation in &goal.obligations {
518
3400
                    if let Some(l) = &label {
519
3400
                        if &obligation.position < l {
520
4
                            label = Some(obligation.position.clone());
521
3396
                        }
522
                    }
523
                }
524
100198
            }
525
        }
526

            
527
1170
        State {
528
1170
            label: label.unwrap(),
529
1170
            match_goals: goals,
530
1170
        }
531
1170
    }
532

            
533
    /// Returns the label of the state
534
13821450
    pub fn label(&self) -> &DataPosition {
535
13821450
        &self.label
536
13821450
    }
537

            
538
    /// Returns the match goals of the state
539
    pub fn match_goals(&self) -> &Vec<MatchGoal> {
540
        &self.match_goals
541
    }
542
}
543

            
544
/// Adds the given function symbol to the indexed symbols. Errors when a
545
/// function symbol is overloaded with different arities.
546
7754
fn add_symbol(function_symbol: DataFunctionSymbol, arity: usize, symbols: &mut HashMap<DataFunctionSymbol, usize>) {
547
7754
    if let Some(x) = symbols.get(&function_symbol) {
548
6533
        assert_eq!(
549
            *x, arity,
550
            "Function symbol {function_symbol} occurs with different arities",
551
        );
552
1221
    } else {
553
1221
        symbols.insert(function_symbol, arity);
554
1221
    }
555
7754
}
556

            
557
/// Returns false iff this is a higher order term, of the shape t(t_0, ..., t_n), or an unknown term.
558
3440
fn is_supported_term(t: &DataExpression) -> bool {
559
50978
    for subterm in t.iter() {
560
50978
        if is_data_application(&subterm) && !is_data_function_symbol(&subterm.arg(0)) {
561
            warn!("{} is higher order", &subterm);
562
            return false;
563
50978
        }
564
    }
565

            
566
3440
    true
567
3440
}
568

            
569
/// Checks whether the set automaton can use this rule, no higher order rules or binders.
570
1560
pub fn is_supported_rule(rule: &Rule) -> bool {
571
    // There should be no terms of the shape t(t0,...,t_n)
572
1560
    if !is_supported_term(&rule.rhs) || !is_supported_term(&rule.lhs) {
573
        return false;
574
1560
    }
575

            
576
1560
    for cond in &rule.conditions {
577
160
        if !is_supported_term(&cond.rhs) || !is_supported_term(&cond.lhs) {
578
            return false;
579
160
        }
580
    }
581

            
582
1560
    true
583
1560
}
584

            
585
/// Finds all data symbols in the term and adds them to the symbol index.
586
11431
fn find_symbols(t: &DataExpressionRef<'_>, symbols: &mut HashMap<DataFunctionSymbol, usize>) {
587
11431
    if is_data_function_symbol(t) {
588
2500
        add_symbol(t.protect().into(), 0, symbols);
589
8931
    } else if is_data_application(t) {
590
        // REC specifications should never contain this so it can be a debug error.
591
5254
        assert!(
592
5254
            is_data_function_symbol(&t.data_function_symbol()),
593
            "Error in term {t}, higher order term rewrite systems are not supported"
594
        );
595

            
596
5254
        add_symbol(t.data_function_symbol().protect(), t.data_arguments().len(), symbols);
597
7991
        for arg in t.data_arguments() {
598
7991
            find_symbols(&arg, symbols);
599
7991
        }
600
3677
    } else if is_data_machine_number(t) {
601
        // Ignore machine numbers during matching?
602
3677
    } else if !is_data_variable(t) {
603
        panic!("Unexpected term {t:?}");
604
3677
    }
605
11431
}