1
use std::collections::VecDeque;
2
use std::fmt::Debug;
3
use std::time::Instant;
4

            
5
use ahash::HashMap;
6
use log::debug;
7
use log::info;
8
use log::log_enabled;
9
use log::trace;
10
use log::warn;
11
use merc_aterm::Term;
12
use merc_data::DataExpression;
13
use merc_data::DataExpressionRef;
14
use merc_data::DataFunctionSymbol;
15
use merc_data::is_data_application;
16
use merc_data::is_data_function_symbol;
17
use merc_data::is_data_machine_number;
18
use merc_data::is_data_variable;
19
use smallvec::SmallVec;
20
use smallvec::smallvec;
21

            
22
use crate::rewrite_specification::RewriteSpecification;
23
use crate::rewrite_specification::Rule;
24
use crate::utilities::DataPosition;
25

            
26
use super::DotFormatter;
27
use super::MatchGoal;
28

            
29
/// The Set Automaton used to find all matching patterns in a term.
30
pub struct SetAutomaton<T> {
31
    states: Vec<State>,
32
    transitions: HashMap<(usize, usize), Transition<T>>,
33
}
34

            
35
/// A match announcement contains the rule that can be announced as a match at
36
/// the given position.
37
///
38
/// `symbols_seen` is internally used to keep track of how many symbols have
39
/// been observed so far. Since these symbols have a unique number this can be
40
/// used to speed up certain operations.
41
#[derive(Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
42
pub struct MatchAnnouncement {
43
    pub rule: Rule,
44
    pub position: DataPosition,
45
    pub symbols_seen: usize,
46
}
47

            
48
/// Represents a transition in the [SetAutomaton].
49
#[derive(Clone)]
50
pub struct Transition<T> {
51
    pub symbol: DataFunctionSymbol,
52
    pub announcements: SmallVec<[(MatchAnnouncement, T); 1]>,
53
    pub destinations: SmallVec<[(DataPosition, usize); 1]>,
54
}
55

            
56
/// Represents a match obligation in the [SetAutomaton].
57
#[derive(Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
58
pub struct MatchObligation {
59
    pub pattern: DataExpression,
60
    pub position: DataPosition,
61
}
62

            
63
impl MatchObligation {
64
    /// Returns the pattern of the match obligation
65
1077614
    pub fn new(pattern: DataExpression, position: DataPosition) -> Self {
66
1077614
        MatchObligation { pattern, position }
67
1077614
    }
68
}
69

            
70
/// Represents either the initial state or a set of match goals in the
71
/// [SetAutomaton].
72
///
73
/// This is only used during construction to avoid craating all the goals for
74
/// the initial state.
75
#[derive(Debug)]
76
enum GoalsOrInitial {
77
    InitialState,
78
    Goals(Vec<MatchGoal>),
79
}
80

            
81
impl<M> SetAutomaton<M> {
82
    /// Creates a new SetAutomaton from the given rewrite specification. If
83
    /// `apma` is true an Adaptive Pattern Matching Automaton is created,
84
    /// meaning that it only finds matches at the root position.
85
    ///
86
    /// The `annotate` function is used to create the annotation for each match
87
    /// announcement. This is used to accomondate different types of annotations
88
    /// for the different rewrite engines.
89
57
    pub fn new<F>(spec: &RewriteSpecification, annotate: F, apma: bool) -> SetAutomaton<M>
90
57
    where
91
57
        F: Fn(&Rule) -> M,
92
    {
93
57
        let start = Instant::now();
94

            
95
        // States are labelled s0, s1, s2, etcetera. state_counter keeps track of count.
96
57
        let mut state_counter: usize = 1;
97

            
98
        // Remove rules that we cannot deal with
99
57
        let supported_rules: Vec<Rule> = spec
100
57
            .rewrite_rules()
101
57
            .iter()
102
1560
            .filter(|rule| is_supported_rule(rule))
103
57
            .map(Rule::clone)
104
57
            .collect();
105

            
106
        // Find the indices of all the function symbols.
107
57
        let symbols = {
108
57
            let mut symbols = HashMap::default();
109

            
110
1560
            for rule in &supported_rules {
111
1560
                find_symbols(&rule.lhs.copy(), &mut symbols);
112
1560
                find_symbols(&rule.rhs.copy(), &mut symbols);
113

            
114
1560
                for cond in &rule.conditions {
115
160
                    find_symbols(&cond.lhs.copy(), &mut symbols);
116
160
                    find_symbols(&cond.rhs.copy(), &mut symbols);
117
160
                }
118
            }
119

            
120
57
            symbols
121
        };
122

            
123
1221
        for (index, (symbol, arity)) in symbols.iter().enumerate() {
124
1221
            trace!("{index}: {symbol} {arity}");
125
        }
126

            
127
        // The initial state has a match goals for each pattern. For each pattern l there is a match goal
128
        // with one obligation l@ε and announcement l@ε.
129
57
        let mut initial_match_goals = Vec::<MatchGoal>::new();
130
1560
        for rr in &supported_rules {
131
1560
            initial_match_goals.push(MatchGoal::new(
132
1560
                MatchAnnouncement {
133
1560
                    rule: (*rr).clone(),
134
1560
                    position: DataPosition::empty(),
135
1560
                    symbols_seen: 0,
136
1560
                },
137
1560
                vec![MatchObligation::new(rr.lhs.clone(), DataPosition::empty())],
138
1560
            ));
139
1560
        }
140

            
141
        // Match goals need to be sorted so that we can easily check whether a state with a certain
142
        // set of match goals already exists.
143
57
        initial_match_goals.sort();
144

            
145
        // Create the initial state
146
57
        let initial_state = State {
147
57
            label: DataPosition::empty(),
148
57
            match_goals: initial_match_goals.clone(),
149
57
        };
150

            
151
        // HashMap from goals to state number
152
57
        let mut map_goals_state = HashMap::default();
153

            
154
        // Queue of states that still need to be explored
155
57
        let mut queue = VecDeque::new();
156
57
        queue.push_back(0);
157

            
158
57
        map_goals_state.insert(initial_match_goals, 0);
159

            
160
57
        let mut states = vec![initial_state];
161
57
        let mut transitions = HashMap::default();
162

            
163
        // Pick a state to explore
164
1284
        while let Some(s_index) = queue.pop_front() {
165
48653
            for (symbol, arity) in &symbols {
166
48653
                let (mut announcements, pos_to_goals) =
167
48653
                    states
168
48653
                        .get(s_index)
169
48653
                        .unwrap()
170
48653
                        .derive_transition(symbol, *arity, &supported_rules, apma);
171

            
172
48653
                announcements.sort_by(|ma1, ma2| ma1.position.cmp(&ma2.position));
173

            
174
                // For the destinations we convert the match goal destinations to states
175
48653
                let mut destinations = smallvec![];
176

            
177
                // Loop over the hypertransitions
178
61467
                for (pos, goals_or_initial) in pos_to_goals {
179
                    // Match goals need to be sorted so that we can easily check whether a state with a certain
180
                    // set of match goals already exists.
181
61467
                    if let GoalsOrInitial::Goals(goals) = goals_or_initial {
182
                        // This code cannot be replaced by the entry since contains_key takes a reference.
183
                        #[allow(clippy::map_entry)]
184
47454
                        if map_goals_state.contains_key(&goals) {
185
                            // The destination state already exists
186
46284
                            destinations.push((pos, *map_goals_state.get(&goals).unwrap()))
187
1170
                        } else if !goals.is_empty() {
188
1170
                            // The destination state does not yet exist, create it
189
1170
                            let new_state = State::new(goals.clone());
190
1170
                            states.push(new_state);
191
1170
                            destinations.push((pos, state_counter));
192
1170
                            map_goals_state.insert(goals, state_counter);
193
1170
                            queue.push_back(state_counter);
194
1170
                            state_counter += 1;
195
1170
                        }
196
14013
                    } else {
197
14013
                        // The transition is to the initial state
198
14013
                        destinations.push((pos, 0));
199
14013
                    }
200
                }
201

            
202
                // Add the annotation for every match announcement.
203
48653
                let announcements = announcements
204
48653
                    .into_iter()
205
48653
                    .map(|ma| {
206
15554
                        let annotation = annotate(&ma.rule);
207
15554
                        (ma, annotation)
208
15554
                    })
209
48653
                    .collect();
210

            
211
                // Add the resulting outgoing transition to the state.
212
48653
                debug_assert!(
213
48653
                    !&transitions.contains_key(&(s_index, symbol.operation_id())),
214
                    "Set automaton should not contain duplicated transitions"
215
                );
216
48653
                transitions.insert(
217
48653
                    (s_index, symbol.operation_id()),
218
48653
                    Transition {
219
48653
                        symbol: symbol.clone(),
220
48653
                        announcements,
221
48653
                        destinations,
222
48653
                    },
223
                );
224
            }
225

            
226
1227
            debug!(
227
                "Queue size {}, currently {} states and {} transitions",
228
                queue.len(),
229
                states.len(),
230
                transitions.len()
231
            );
232
        }
233

            
234
        // Clear the match goals since they are only for debugging purposes.
235
57
        if !log_enabled!(log::Level::Debug) {
236
1227
            for state in &mut states {
237
1227
                state.match_goals.clear();
238
1227
            }
239
        }
240
57
        info!(
241
            "Created set automaton (states: {}, transitions: {}, apma: {}) in {} ms",
242
            states.len(),
243
            transitions.len(),
244
            apma,
245
            (Instant::now() - start).as_millis()
246
        );
247

            
248
57
        let result = SetAutomaton { states, transitions };
249
57
        debug!("{result:?}");
250

            
251
57
        result
252
57
    }
253

            
254
    /// Returns the number of states
255
    pub fn num_of_states(&self) -> usize {
256
        self.states.len()
257
    }
258

            
259
    /// Returns the number of transitions
260
    pub fn num_of_transitions(&self) -> usize {
261
        self.transitions.len()
262
    }
263

            
264
    /// Returns the states of the automaton
265
13811024
    pub fn states(&self) -> &Vec<State> {
266
13811024
        &self.states
267
13811024
    }
268

            
269
    /// Returns the transitions of the automaton
270
12643380
    pub fn transitions(&self) -> &HashMap<(usize, usize), Transition<M>> {
271
12643380
        &self.transitions
272
12643380
    }
273

            
274
    /// Provides a formatter for the .dot file format
275
    pub fn to_dot_graph(&self, show_backtransitions: bool, show_final: bool) -> DotFormatter<'_, M> {
276
        DotFormatter {
277
            automaton: self,
278
            show_backtransitions,
279
            show_final,
280
        }
281
    }
282
}
283

            
284
#[derive(Debug)]
285
pub struct Derivative {
286
    pub completed: Vec<MatchGoal>,
287
    pub unchanged: Vec<MatchGoal>,
288
    pub reduced: Vec<MatchGoal>,
289
}
290

            
291
pub struct State {
292
    label: DataPosition,
293
    match_goals: Vec<MatchGoal>,
294
}
295

            
296
impl State {
297
    /// Derive transitions from a state given a head symbol. The resulting transition is returned as a tuple
298
    /// The tuple consists of a vector of outputs and a set of destinations (which are sets of match goals).
299
    /// We don't use the struct Transition as it requires that the destination is a full state, with name.
300
    /// Since we don't yet know whether the state already exists we just return a set of match goals as 'state'.
301
    ///
302
    /// Parameter symbol is the symbol for which the transition is computed
303
48653
    fn derive_transition(
304
48653
        &self,
305
48653
        symbol: &DataFunctionSymbol,
306
48653
        arity: usize,
307
48653
        rewrite_rules: &Vec<Rule>,
308
48653
        apma: bool,
309
48653
    ) -> (Vec<MatchAnnouncement>, Vec<(DataPosition, GoalsOrInitial)>) {
310
        // Computes the derivative containing the goals that are completed, unchanged and reduced
311
48653
        let mut derivative = self.compute_derivative(symbol, arity);
312

            
313
        // The outputs/matching patterns of the transitions are those who are completed
314
48653
        let outputs = derivative.completed.into_iter().map(|x| x.announcement).collect();
315

            
316
        // The new match goals are the unchanged and reduced match goals.
317
48653
        let mut new_match_goals = derivative.unchanged;
318
48653
        new_match_goals.append(&mut derivative.reduced);
319

            
320
48653
        let mut destinations = vec![];
321
        // If we are building an APMA we do not deepen the position or create a hypertransitions
322
        // with multiple endpoints
323
48653
        if apma {
324
14224
            if !new_match_goals.is_empty() {
325
802
                destinations.push((DataPosition::empty(), GoalsOrInitial::Goals(new_match_goals)));
326
13422
            }
327
        } else {
328
            // In case we are building a set automaton we partition the match goals
329
34429
            let partitioned = MatchGoal::partition(new_match_goals);
330

            
331
            // Get the greatest common prefix and shorten the positions
332
34429
            let mut positions_per_partition = vec![];
333
34429
            let mut gcp_length_per_partition = vec![];
334
46652
            for (p, pos) in partitioned {
335
46652
                positions_per_partition.push(pos);
336
46652
                let gcp = MatchGoal::greatest_common_prefix(&p);
337
46652
                let gcp_length = gcp.len();
338
46652
                gcp_length_per_partition.push(gcp_length);
339
46652
                let mut goals = MatchGoal::remove_prefix(p, gcp_length);
340
46652
                goals.sort_unstable();
341
46652
                destinations.push((gcp, GoalsOrInitial::Goals(goals)));
342
46652
            }
343

            
344
            // Handle fresh match goals, they are the positions Label(state).i
345
            // where i is between 1 and the arity of the function symbol of
346
            // the transition. Position 1 is the first argument.
347
34429
            for i in 1..arity + 1 {
348
31805
                let mut pos = self.label.clone();
349
31805
                pos.push(i);
350

            
351
                // Check if the fresh goals are related to one of the existing partitions
352
31805
                let mut partition_key = None;
353
52695
                'outer: for (i, part_pos) in positions_per_partition.iter().enumerate() {
354
79196
                    for p in part_pos {
355
79196
                        if MatchGoal::pos_comparable(p, &pos) {
356
17792
                            partition_key = Some(i);
357
17792
                            break 'outer;
358
61404
                        }
359
                    }
360
                }
361

            
362
31805
                if let Some(key) = partition_key {
363
                    // If the fresh goals fall in an existing partition
364
17792
                    let gcp_length = gcp_length_per_partition[key];
365
17792
                    let pos = DataPosition::new(&pos.indices()[gcp_length..]);
366

            
367
                    // Add the fresh goals to the partition
368
1076054
                    for rr in rewrite_rules {
369
1076054
                        if let GoalsOrInitial::Goals(goals) = &mut destinations[key].1 {
370
1076054
                            goals.push(MatchGoal {
371
1076054
                                obligations: vec![MatchObligation::new(rr.lhs.clone(), pos.clone())],
372
1076054
                                announcement: MatchAnnouncement {
373
1076054
                                    rule: (*rr).clone(),
374
1076054
                                    position: pos.clone(),
375
1076054
                                    symbols_seen: 0,
376
1076054
                                },
377
1076054
                            });
378
1076054
                        }
379
                    }
380
14013
                } else {
381
14013
                    // The transition is simply to the initial state
382
14013
                    // GoalsOrInitial::InitialState avoids unnecessary work of creating all these fresh goals
383
14013
                    destinations.push((pos, GoalsOrInitial::InitialState));
384
14013
                }
385
            }
386
        }
387

            
388
        // Sort the destination such that transitions which do not deepen the position are listed first
389
48653
        destinations.sort_unstable_by(|(pos1, _), (pos2, _)| pos1.cmp(pos2));
390
48653
        (outputs, destinations)
391
48653
    }
392

            
393
    /// For a transition 'symbol' of state 'self' this function computes which match goals are
394
    /// completed, unchanged and reduced.
395
48653
    fn compute_derivative(&self, symbol: &DataFunctionSymbol, arity: usize) -> Derivative {
396
48653
        let mut result = Derivative {
397
48653
            completed: vec![],
398
48653
            unchanged: vec![],
399
48653
            reduced: vec![],
400
48653
        };
401

            
402
5138007
        for mg in &self.match_goals {
403
5138007
            debug_assert!(
404
5138007
                !mg.obligations.is_empty(),
405
                "The obligations should never be empty, should be completed then"
406
            );
407

            
408
            // Completed match goals
409
5138007
            if mg.obligations.len() == 1
410
5110907
                && mg.obligations.iter().any(|mo| {
411
5110907
                    mo.position == self.label
412
2565108
                        && mo.pattern.data_function_symbol() == symbol.copy()
413
48449
                        && mo.pattern.data_arguments().all(|x| is_data_variable(&x))
414
                    // Again skip the function symbol
415
5110907
                })
416
15554
            {
417
15554
                result.completed.push(mg.clone());
418
5122453
            } else if mg
419
5122453
                .obligations
420
5122453
                .iter()
421
5138864
                .any(|mo| mo.position == self.label && mo.pattern.data_function_symbol() != symbol.copy())
422
2531108
            {
423
2531108
                // Match goal is discarded since head symbol does not match.
424
2607355
            } else if mg.obligations.iter().all(|mo| mo.position != self.label) {
425
                // Unchanged match goals
426
2560969
                let mut mg = mg.clone();
427
2560969
                if mg.announcement.rule.lhs != mg.obligations.first().unwrap().pattern {
428
34869
                    mg.announcement.symbols_seen += 1;
429
2526100
                }
430

            
431
2560969
                result.unchanged.push(mg.clone());
432
            } else {
433
                // Reduce match obligations
434
30376
                let mut mg = mg.clone();
435
30376
                let mut new_obligations = vec![];
436

            
437
30777
                for mo in mg.obligations {
438
30777
                    if mo.pattern.data_function_symbol() == symbol.copy() && mo.position == self.label {
439
                        // Reduced match obligation
440
55500
                        for (index, t) in mo.pattern.data_arguments().enumerate() {
441
55500
                            assert!(
442
55500
                                index < arity,
443
                                "This pattern associates function symbol {:?} with different arities {} and {}",
444
                                symbol,
445
                                index + 1,
446
                                arity
447
                            );
448

            
449
55500
                            if !is_data_variable(&t) {
450
39582
                                let mut new_pos = mo.position.clone();
451
39582
                                new_pos.push(index + 1);
452
39582
                                new_obligations.push(MatchObligation {
453
39582
                                    pattern: t.protect(),
454
39582
                                    position: new_pos,
455
39582
                                });
456
39582
                            }
457
                        }
458
401
                    } else {
459
401
                        // remains unchanged
460
401
                        new_obligations.push(mo.clone());
461
401
                    }
462
                }
463

            
464
30376
                new_obligations.sort_unstable_by(|mo1, mo2| mo1.position.len().cmp(&mo2.position.len()));
465
30376
                mg.obligations = new_obligations;
466
30376
                mg.announcement.symbols_seen += 1;
467

            
468
30376
                result.reduced.push(mg);
469
            }
470
        }
471

            
472
48653
        trace!(
473
            "=== compute_derivative(symbol = {}, label = {}) ===",
474
            symbol, self.label
475
        );
476
48653
        trace!("Match goals: {{");
477
5138007
        for mg in &self.match_goals {
478
5138007
            trace!("\t {mg:?}");
479
        }
480

            
481
48653
        trace!("}}");
482
48653
        trace!("Completed: {{");
483
48653
        for mg in &result.completed {
484
15554
            trace!("\t {mg:?}");
485
        }
486

            
487
48653
        trace!("}}");
488
48653
        trace!("Unchanged: {{");
489
2560969
        for mg in &result.unchanged {
490
2560969
            trace!("\t {mg:?}");
491
        }
492

            
493
48653
        trace!("}}");
494
48653
        trace!("Reduced: {{");
495
48653
        for mg in &result.reduced {
496
30376
            trace!("\t {mg:?}");
497
        }
498
48653
        trace!("}}");
499

            
500
48653
        result
501
48653
    }
502

            
503
    /// Create a state from a set of match goals
504
1170
    fn new(goals: Vec<MatchGoal>) -> State {
505
        // The label of the state is taken from a match obligation of a root match goal.
506
1170
        let mut label: Option<DataPosition> = None;
507

            
508
        // Go through all match goals until a root match goal is found
509
103185
        for goal in &goals {
510
103185
            if goal.announcement.position.is_empty() {
511
                // Find the shortest match obligation position.
512
                // This design decision was taken as it presumably has two advantages.
513
                // 1. Patterns that overlap will be more quickly distinguished, potentially decreasing
514
                // the size of the automaton.
515
                // 2. The average lookahead may be shorter.
516
2987
                if label.is_none() {
517
1170
                    label = Some(goal.obligations.first().unwrap().position.clone());
518
1817
                }
519

            
520
3400
                for obligation in &goal.obligations {
521
3400
                    if let Some(l) = &label && &obligation.position < l {
522
3
                        label = Some(obligation.position.clone());
523
3397
                    }
524
                }
525
100198
            }
526
        }
527

            
528
1170
        State {
529
1170
            label: label.unwrap(),
530
1170
            match_goals: goals,
531
1170
        }
532
1170
    }
533

            
534
    /// Returns the label of the state
535
13811024
    pub fn label(&self) -> &DataPosition {
536
13811024
        &self.label
537
13811024
    }
538

            
539
    /// Returns the match goals of the state
540
    pub fn match_goals(&self) -> &Vec<MatchGoal> {
541
        &self.match_goals
542
    }
543
}
544

            
545
/// Adds the given function symbol to the indexed symbols. Errors when a
546
/// function symbol is overloaded with different arities.
547
7754
fn add_symbol(function_symbol: DataFunctionSymbol, arity: usize, symbols: &mut HashMap<DataFunctionSymbol, usize>) {
548
7754
    if let Some(x) = symbols.get(&function_symbol) {
549
6533
        assert_eq!(
550
            *x, arity,
551
            "Function symbol {function_symbol} occurs with different arities",
552
        );
553
1221
    } else {
554
1221
        symbols.insert(function_symbol, arity);
555
1221
    }
556
7754
}
557

            
558
/// Returns false iff this is a higher order term, of the shape t(t_0, ..., t_n), or an unknown term.
559
3440
fn is_supported_term(t: &DataExpression) -> bool {
560
50978
    for subterm in t.iter() {
561
50978
        if is_data_application(&subterm) && !is_data_function_symbol(&subterm.arg(0)) {
562
            warn!("{} is higher order", &subterm);
563
            return false;
564
50978
        }
565
    }
566

            
567
3440
    true
568
3440
}
569

            
570
/// Checks whether the set automaton can use this rule, no higher order rules or binders.
571
1560
pub fn is_supported_rule(rule: &Rule) -> bool {
572
    // There should be no terms of the shape t(t0,...,t_n)
573
1560
    if !is_supported_term(&rule.rhs) || !is_supported_term(&rule.lhs) {
574
        return false;
575
1560
    }
576

            
577
1560
    for cond in &rule.conditions {
578
160
        if !is_supported_term(&cond.rhs) || !is_supported_term(&cond.lhs) {
579
            return false;
580
160
        }
581
    }
582

            
583
1560
    true
584
1560
}
585

            
586
/// Finds all data symbols in the term and adds them to the symbol index.
587
11431
fn find_symbols(t: &DataExpressionRef<'_>, symbols: &mut HashMap<DataFunctionSymbol, usize>) {
588
11431
    if is_data_function_symbol(t) {
589
2500
        add_symbol(t.protect().into(), 0, symbols);
590
8931
    } else if is_data_application(t) {
591
        // REC specifications should never contain this so it can be a debug error.
592
5254
        assert!(
593
5254
            is_data_function_symbol(&t.data_function_symbol()),
594
            "Error in term {t}, higher order term rewrite systems are not supported"
595
        );
596

            
597
5254
        add_symbol(t.data_function_symbol().protect(), t.data_arguments().len(), symbols);
598
7991
        for arg in t.data_arguments() {
599
7991
            find_symbols(&arg, symbols);
600
7991
        }
601
3677
    } else if is_data_machine_number(t) {
602
        // Ignore machine numbers during matching?
603
3677
    } else if !is_data_variable(t) {
604
        panic!("Unexpected term {t:?}");
605
3677
    }
606
11431
}