1
use proc_macro2::Span;
2
use proc_macro2::TokenStream;
3
use proc_macro2::TokenTree;
4
use quote::quote;
5
use syn::Expr;
6
use syn::Ident;
7
use syn::Pat;
8
use syn::Type;
9
use syn::bracketed;
10
use syn::parenthesized;
11
use syn::parse::Parse;
12
use syn::parse::ParseStream;
13
use syn::parse::Result;
14
use syn::parse_quote;
15
use syn::punctuated::Punctuated;
16
use syn::spanned::Spanned;
17
use syn::token;
18

            
19
/// Pattern for matching a node in a pattern matching expression.
20
///
21
/// Represents an individual pattern element with optional tag, rule name,
22
/// binding pattern, and multiplicity flag.
23
struct Pattern {
24
    tag: Option<String>,      // Optional tag for the pattern
25
    rule_name: Option<Ident>, // Optional rule name for parsing
26
    binder: Pat,              // Pattern to bind the matched node to
27
    multiple: bool,           // Whether this pattern matches multiple nodes
28
}
29

            
30
/// Alternative in a pattern matching expression.
31
///
32
/// Represents a sequence of patterns that can be matched against nodes.
33
struct Alternative {
34
    patterns: Punctuated<Pattern, token::Comma>, // Comma-separated patterns
35
}
36

            
37
/// Branch in a pattern matching expression.
38
///
39
/// Represents a set of alternatives and their associated body expression.
40
struct MatchBranch {
41
    alternatives: Punctuated<Alternative, token::Or>, // Alternatives separated by |
42
    body: Expr,                                       // Body expression to evaluate on match
43
}
44

            
45
/// Input for the match_nodes macro.
46
///
47
/// Contains parser type, input expression, and pattern matching branches.
48
struct MacroInput {
49
    parser: Type,                                    // Parser type
50
    input_expr: Expr,                                // Expression to match against
51
    branches: Punctuated<MatchBranch, token::Comma>, // Pattern matching branches
52
}
53

            
54
impl Parse for MacroInput {
55
236
    fn parse(input: ParseStream) -> Result<Self> {
56
236
        let parser = if input.peek(token::Lt) {
57
7
            let _: token::Lt = input.parse()?;
58
7
            let parser = input.parse()?;
59
7
            let _: token::Gt = input.parse()?;
60
7
            let _: token::Semi = input.parse()?;
61
7
            parser
62
        } else {
63
229
            parse_quote!(Self)
64
        };
65

            
66
236
        let input_expr = input.parse()?;
67
236
        let _: token::Semi = input.parse()?;
68
236
        let branches = Punctuated::parse_terminated(input)?;
69

            
70
236
        Ok(MacroInput {
71
236
            parser,
72
236
            input_expr,
73
236
            branches,
74
236
        })
75
236
    }
76
}
77

            
78
impl Parse for MatchBranch {
79
286
    fn parse(input: ParseStream) -> Result<Self> {
80
286
        let alternatives = Punctuated::parse_separated_nonempty(input)?;
81
286
        let _: token::FatArrow = input.parse()?;
82
286
        let body = input.parse()?;
83

            
84
286
        Ok(MatchBranch { alternatives, body })
85
286
    }
86
}
87

            
88
impl Parse for Alternative {
89
287
    fn parse(input: ParseStream) -> Result<Self> {
90
        let contents;
91
287
        let _: token::Bracket = bracketed!(contents in input);
92
287
        let patterns = Punctuated::parse_terminated(&contents)?;
93
287
        Ok(Alternative { patterns })
94
287
    }
95
}
96

            
97
impl Parse for Pattern {
98
459
    fn parse(input: ParseStream) -> Result<Self> {
99
459
        let mut tag = None;
100
        let binder: Pat;
101
        let multiple;
102
        let rule_name;
103

            
104
459
        let ahead = input.fork();
105
459
        let _: TokenTree = ahead.parse()?;
106
459
        if ahead.peek(token::Pound) {
107
3
            let tag_ident: Ident = input.parse()?;
108
3
            tag = Some(tag_ident.to_string());
109
3
            let _: token::Pound = input.parse()?;
110
456
        }
111

            
112
459
        let ahead = input.fork();
113
459
        let _: TokenTree = ahead.parse()?;
114
459
        if ahead.peek(token::Paren) {
115
            // If `input` starts with `foo(`
116
444
            rule_name = Some(input.parse()?);
117
            let contents;
118
444
            parenthesized!(contents in input);
119
444
            binder = Pat::parse_multi(&contents)?;
120
        } else {
121
            // A plain pattern captures the node itself without parsing anything.
122
15
            rule_name = None;
123
15
            binder = Pat::parse_multi(input)?;
124
        }
125

            
126
459
        if input.peek(token::DotDot) {
127
80
            let _: token::DotDot = input.parse()?;
128
80
            multiple = true;
129
379
        } else if input.is_empty() || input.peek(token::Comma) {
130
379
            multiple = false;
131
379
        } else {
132
            return Err(input.error("expected `..` or nothing"));
133
        }
134

            
135
459
        Ok(Pattern {
136
459
            tag,
137
459
            rule_name,
138
459
            binder,
139
459
            multiple,
140
459
        })
141
459
    }
142
}
143

            
144
/// Traverses a pattern and generates code to match against nodes.
145
574
fn traverse_pattern(
146
574
    mut patterns: &[Pattern],
147
574
    i_iter: &Ident,
148
574
    matches_pat: impl Fn(&Pattern, TokenStream) -> TokenStream,
149
574
    process_item: impl Fn(&Pattern, TokenStream) -> TokenStream,
150
574
    error: TokenStream,
151
574
) -> TokenStream {
152
574
    let mut steps = Vec::new();
153

            
154
    // Handle trailing single patterns first for correct non-greedy matching
155
1286
    while patterns.last().is_some_and(|pat| !pat.multiple) {
156
712
        let [remaining_pats @ .., pat] = patterns else {
157
            unreachable!()
158
        };
159

            
160
712
        patterns = remaining_pats;
161
712
        let this_node = process_item(pat, quote!(node));
162
712
        steps.push(quote!(
163
            let Some(node) = #i_iter.next_back() else { #error };
164
            #this_node;
165
        ));
166
    }
167

            
168
    // Process remaining patterns
169
574
    for pat in patterns {
170
206
        if !pat.multiple {
171
46
            // Single pattern - match exactly one node
172
46
            let this_node = process_item(pat, quote!(node));
173
46
            steps.push(quote!(
174
46
                let Some(node) = #i_iter.next() else { #error };
175
46
                #this_node;
176
46
            ));
177
46
        } else {
178
            // Multiple pattern - match greedily as long as nodes match
179
160
            let matches_node = matches_pat(pat, quote!(node));
180
160
            let this_slice = process_item(pat, quote!(matched));
181
160
            steps.push(quote!(
182
                let matched = <_ as ::merc_pest_consume::Itertools>::peeking_take_while(&mut #i_iter, |node| #matches_node);
183
                #this_slice;
184
            ))
185
        }
186
    }
187

            
188
574
    debug_assert!(
189
574
        !steps.is_empty() || patterns.is_empty(),
190
        "Must generate steps for non-empty patterns"
191
    );
192

            
193
574
    quote!(
194
        #[allow(unused_mut)]
195
        let mut #i_iter = #i_iter.peekable();
196
        #(#steps)*
197
    )
198
574
}
199

            
200
/// Generates code for a single pattern matching alternative.
201
287
fn make_alternative(
202
287
    alternative: Alternative,
203
287
    body: &Expr,
204
287
    i_nodes: &Ident,
205
287
    i_node_namer: &Ident,
206
287
    parser: &Type,
207
287
) -> TokenStream {
208
287
    let i_nodes_iter = Ident::new("___nodes_iter", Span::call_site());
209
287
    let name_enum = quote!(<#parser as ::merc_pest_consume::NodeMatcher>::NodeName);
210
287
    let node_namer_ty = quote!(<_ as ::merc_pest_consume::NodeNamer<#parser>>);
211
287
    let patterns: Vec<_> = alternative.patterns.into_iter().collect();
212

            
213
    // Function to generate code for checking if a pattern matches a node
214
539
    let matches_pat = |pat: &Pattern, x| {
215
539
        let rule_cond = match &pat.rule_name {
216
524
            Some(rule_name) => {
217
524
                quote!(#node_namer_ty::node_name(&#i_node_namer, &#x) == #name_enum::#rule_name)
218
            }
219
15
            None => quote!(true),
220
        };
221
539
        let tag_cond = match &pat.tag {
222
5
            Some(tag) => {
223
5
                quote!(#node_namer_ty::tag(&#i_node_namer, &#x) == Some(#tag))
224
            }
225
534
            None => quote!(true),
226
        };
227
539
        quote!(#rule_cond && #tag_cond)
228
539
    };
229

            
230
    // Generate code for checking if this alternative matches
231
459
    let process_item = |pat: &Pattern, i_matched| {
232
459
        if !pat.multiple {
233
379
            let cond = matches_pat(pat, i_matched);
234
379
            quote!(
235
                if !(#cond) { return false; }
236
            )
237
        } else {
238
80
            quote!(
239
                // Consume the iterator.
240
                #i_matched.count();
241
            )
242
        }
243
459
    };
244

            
245
287
    let conditions = traverse_pattern(
246
287
        patterns.as_slice(),
247
287
        &i_nodes_iter,
248
287
        matches_pat,
249
287
        process_item,
250
287
        quote!(return false),
251
    );
252

            
253
    // Generate code for parsing nodes when the alternative matches
254
459
    let parse_rule = |rule: &Option<_>, node| match rule {
255
444
        Some(rule_name) => quote!(#parser::#rule_name(#node)),
256
15
        None => quote!(Ok(#node)),
257
459
    };
258

            
259
459
    let process_item = |pat: &Pattern, i_matched| {
260
459
        if !pat.multiple {
261
379
            let parse = parse_rule(&pat.rule_name, quote!(#i_matched));
262
379
            let binder = &pat.binder;
263
379
            quote!(
264
                let #binder = #parse?;
265
            )
266
        } else {
267
80
            let parse_node = parse_rule(&pat.rule_name, quote!(node));
268
80
            let binder = &pat.binder;
269
80
            quote!(
270
                let #binder = #i_matched
271
                    .map(|node| #parse_node)
272
                    .collect::<::std::result::Result<::std::vec::Vec<_>, _>>()?
273
                    .into_iter();
274
            )
275
        }
276
459
    };
277

            
278
287
    let parses = traverse_pattern(
279
287
        patterns.as_slice(),
280
287
        &i_nodes_iter,
281
287
        matches_pat,
282
287
        process_item,
283
287
        quote!(unreachable!()),
284
    );
285

            
286
287
    debug_assert!(!patterns.is_empty(), "Alternative must have at least one pattern");
287

            
288
287
    quote!(
289
        _ if {
290
            let check_condition = |slice: &[_]| -> bool {
291
                let #i_nodes_iter = slice.iter();
292
                #conditions
293
                #i_nodes_iter.next().is_none()
294
            };
295
            check_condition(#i_nodes.as_slice())
296
        } => {
297
            let #i_nodes_iter = #i_nodes.into_iter();
298
            #parses
299
            #body
300
        }
301
    )
302
287
}
303

            
304
/// Implements the match_nodes macro.
305
236
pub fn match_nodes(input: proc_macro::TokenStream) -> Result<proc_macro2::TokenStream> {
306
236
    let input: MacroInput = syn::parse(input)?;
307

            
308
236
    let i_nodes = Ident::new("___nodes", input.input_expr.span());
309
236
    let i_node_rules = Ident::new("___node_rules", Span::call_site());
310
236
    let i_node_namer = Ident::new("___node_namer", Span::call_site());
311

            
312
236
    let input_expr = &input.input_expr;
313
236
    let parser = &input.parser;
314

            
315
    // Generate code for each alternative in each branch
316
236
    let branches = input
317
236
        .branches
318
236
        .into_iter()
319
286
        .flat_map(|br| {
320
286
            let body = br.body;
321
286
            let i_nodes = &i_nodes;
322
286
            let i_node_namer = &i_node_namer;
323
286
            br.alternatives
324
286
                .into_iter()
325
287
                .map(move |alt| make_alternative(alt, &body, i_nodes, i_node_namer, parser))
326
286
        })
327
236
        .collect::<Vec<_>>();
328

            
329
236
    debug_assert!(!branches.is_empty(), "Must generate at least one branch");
330

            
331
236
    let node_list_ty = quote!(<_ as ::merc_pest_consume::NodeList<#parser>>);
332
236
    let node_namer_ty = quote!(<_ as ::merc_pest_consume::NodeNamer<#parser>>);
333
236
    Ok(quote!({
334
        let (#i_nodes, #i_node_namer) = #node_list_ty::consume(#input_expr);
335

            
336
        #[allow(unreachable_code, clippy::int_plus_one)]
337
        match () {
338
            #(#branches,)*
339
            _ => {
340
                // Collect the rule names to display.
341
                let #i_node_rules: ::std::vec::Vec<_> =
342
                        #i_nodes.iter().map(|n| #node_namer_ty::node_name(&#i_node_namer, n)).collect();
343
                return ::std::result::Result::Err(
344
                    #node_namer_ty::error(
345
                        #i_node_namer,
346
                        format!("Nodes didn't match any pattern: {:?}", #i_node_rules)
347
                    )
348
                );
349
            }
350
        }
351
    }))
352
236
}