1
use std::collections::HashMap;
2
use std::iter;
3

            
4
use quote::quote;
5
use syn::Error;
6
use syn::FnArg;
7
use syn::Ident;
8
use syn::ImplItem;
9
use syn::ImplItemFn;
10
use syn::ItemImpl;
11
use syn::LitBool;
12
use syn::Pat;
13
use syn::Path;
14
use syn::parse::Parse;
15
use syn::parse::ParseStream;
16
use syn::parse::Result;
17
use syn::parse_quote;
18
use syn::spanned::Spanned;
19
use syn::token;
20

            
21
mod kw {
22
    syn::custom_keyword!(shortcut);
23
    syn::custom_keyword!(rule);
24
    syn::custom_keyword!(parser);
25
}
26

            
27
/// Attributes for the parser macro
28
struct MakeParserAttrs {
29
    parser: Path,
30
    rule_enum: Path,
31
}
32

            
33
/// Arguments for an alias attribute
34
struct AliasArgs {
35
    target: Ident,
36
    is_shortcut: bool,
37
}
38

            
39
/// Source of an alias, including identifier and shortcut status.
40
struct AliasSrc {
41
    ident: Ident,      // Identifier
42
    is_shortcut: bool, // Whether it's a shortcut
43
}
44

            
45
/// Parsed function metadata including function body, name, input argument, and aliases.
46
struct ParsedFn<'a> {
47
    // Body of the function
48
    function: &'a mut ImplItemFn,
49
    // Name of the function.
50
    fn_name: Ident,
51
    // Name of the first argument of the function, which should be of type `Node`.
52
    input_arg: Ident,
53
    // List of aliases pointing to this function
54
    alias_srcs: Vec<AliasSrc>,
55
}
56

            
57
impl Parse for MakeParserAttrs {
58
8
    fn parse(input: ParseStream) -> Result<Self> {
59
        // By default, the pest parser is the same type as the pest_consume one
60
8
        let mut parser = parse_quote!(Self);
61
        // By default, use the `Rule` type in scope
62
8
        let mut rule_enum = parse_quote!(Rule);
63

            
64
8
        while !input.is_empty() {
65
            let lookahead = input.lookahead1();
66
            if lookahead.peek(kw::parser) {
67
                let _: kw::parser = input.parse()?;
68
                let _: token::Eq = input.parse()?;
69
                parser = input.parse()?;
70
            } else if lookahead.peek(kw::rule) {
71
                let _: kw::rule = input.parse()?;
72
                let _: token::Eq = input.parse()?;
73
                rule_enum = input.parse()?;
74
            } else {
75
                return Err(lookahead.error());
76
            }
77

            
78
            if input.peek(token::Comma) {
79
                let _: token::Comma = input.parse()?;
80
            } else {
81
                break;
82
            }
83
        }
84

            
85
8
        Ok(MakeParserAttrs { parser, rule_enum })
86
8
    }
87
}
88

            
89
impl Parse for AliasArgs {
90
    fn parse(input: ParseStream) -> Result<Self> {
91
        let target = input.parse()?;
92
        let is_shortcut = if input.peek(token::Comma) {
93
            // #[alias(rule, shortcut = true)]
94
            let _: token::Comma = input.parse()?;
95
            let _: kw::shortcut = input.parse()?;
96
            let _: token::Eq = input.parse()?;
97
            let b: LitBool = input.parse()?;
98
            b.value
99
        } else {
100
            // #[alias(rule)]
101
            false
102
        };
103
        Ok(AliasArgs { target, is_shortcut })
104
    }
105
}
106

            
107
/// Collects and maps aliases from an implementation block.
108
8
fn collect_aliases(imp: &mut ItemImpl) -> Result<HashMap<Ident, Vec<AliasSrc>>> {
109
304
    let functions = imp.items.iter_mut().flat_map(|item| match item {
110
304
        ImplItem::Fn(m) => Some(m),
111
        _ => None,
112
304
    });
113

            
114
8
    let mut alias_map = HashMap::new();
115
304
    for function in functions {
116
304
        let fn_name = function.sig.ident.clone();
117
304
        let mut alias_attrs = function.attrs.iter().filter(|attr| attr.path().is_ident("alias"));
118

            
119
304
        if let Some(attr) = alias_attrs.next() {
120
            let args: AliasArgs = attr.parse_args()?;
121
            alias_map.entry(args.target).or_insert_with(Vec::new).push(AliasSrc {
122
                ident: fn_name,
123
                is_shortcut: args.is_shortcut,
124
            });
125
304
        } else {
126
304
            // Self entry
127
304
            alias_map
128
304
                .entry(fn_name.clone())
129
304
                .or_insert_with(Vec::new)
130
304
                .push(AliasSrc {
131
304
                    ident: fn_name,
132
304
                    is_shortcut: false,
133
304
                });
134
304
        }
135
304
        if let Some(attr) = alias_attrs.next() {
136
            return Err(Error::new(attr.span(), "expected at most one alias attribute"));
137
304
        }
138
    }
139

            
140
8
    debug_assert!(!alias_map.is_empty(), "Alias map should not be empty after collection");
141
8
    Ok(alias_map)
142
8
}
143

            
144
/// Extracts an identifier from a function argument.
145
304
fn extract_ident_argument(input_arg: &FnArg) -> Result<Ident> {
146
304
    match input_arg {
147
        FnArg::Receiver(_) => Err(Error::new(input_arg.span(), "this argument should not be `self`")),
148
304
        FnArg::Typed(input_arg) => match &*input_arg.pat {
149
304
            Pat::Ident(pat) => Ok(pat.ident.clone()),
150
            _ => Err(Error::new(
151
                input_arg.span(),
152
                "this argument should be a plain identifier instead of a pattern",
153
            )),
154
        },
155
    }
156
304
}
157

            
158
/// Parses a function to extract metadata for rule method processing.
159
304
fn parse_fn<'a>(function: &'a mut ImplItemFn, alias_map: &mut HashMap<Ident, Vec<AliasSrc>>) -> Result<ParsedFn<'a>> {
160
    // Rule methods must have exactly one argument
161
304
    if function.sig.inputs.len() != 1 {
162
        return Err(Error::new(
163
            function.sig.inputs.span(),
164
            "A rule method must have 1 argument",
165
        ));
166
304
    }
167

            
168
304
    let fn_name = function.sig.ident.clone();
169
    // Get the name of the first function argument
170
304
    let input_arg = extract_ident_argument(&function.sig.inputs[0])?;
171
304
    let alias_srcs = alias_map.remove(&fn_name).unwrap_or_default();
172

            
173
304
    debug_assert!(
174
304
        alias_srcs.iter().any(|src| src.ident == fn_name),
175
        "Function should have at least a self-reference in alias sources"
176
    );
177

            
178
304
    Ok(ParsedFn {
179
304
        function,
180
304
        fn_name,
181
304
        input_arg,
182
304
        alias_srcs,
183
304
    })
184
304
}
185

            
186
/// Applies special attributes to parsed functions.
187
304
fn apply_special_attrs(f: &mut ParsedFn, rule_enum: &Path) -> Result<()> {
188
304
    let function = &mut *f.function;
189
304
    let fn_name = &f.fn_name;
190
304
    let input_arg = &f.input_arg;
191

            
192
    // `alias` attr
193
    // f.alias_srcs has always at least 1 element because it has an entry pointing from itself.
194
304
    let aliases = f.alias_srcs.iter().map(|src| &src.ident).filter(|i| i != &fn_name);
195
304
    let block = &function.block;
196
304
    let self_ty = quote!(<Self as ::merc_pest_consume::Parser>);
197

            
198
    // Modify function block to handle shortcuts and aliases
199
304
    function.block = parse_quote!({
200
        let mut #input_arg = #input_arg;
201
        // While the current rule allows shortcutting, and there is a single child, and the
202
        // child can still be parsed by the current function, then skip to that child.
203
        while #self_ty::allows_shortcut(#input_arg.as_rule()) {
204
            if let ::std::result::Result::Ok(child)
205
                    = #input_arg.children().single() {
206
                if child.as_aliased_rule::<Self>() == #self_ty::rule_alias(#rule_enum::#fn_name) {
207
                    #input_arg = child;
208
                    continue;
209
                }
210
            }
211
            break
212
        }
213

            
214
        match #input_arg.as_rule() {
215
            #(#rule_enum::#aliases => Self::#aliases(#input_arg),)*
216
            #rule_enum::#fn_name => #block,
217
            r => panic!(
218
                "merc_pest_consume::parser: called the `{}` method on a node with rule `{:?}`",
219
                stringify!(#fn_name),
220
                r
221
            )
222
        }
223
    });
224

            
225
304
    debug_assert!(
226
304
        !f.alias_srcs.is_empty(),
227
        "Function must have at least one alias source (itself)"
228
    );
229
304
    Ok(())
230
304
}
231

            
232
/// Main function for generating the parser implementation.
233
8
pub fn make_parser(attrs: proc_macro::TokenStream, input: proc_macro::TokenStream) -> Result<proc_macro2::TokenStream> {
234
8
    let attrs: MakeParserAttrs = syn::parse(attrs)?;
235
8
    let parser = &attrs.parser;
236
8
    let rule_enum = &attrs.rule_enum;
237
8
    let mut imp: ItemImpl = syn::parse(input)?;
238

            
239
    // Collect aliases and build rule matching logic
240
8
    let mut alias_map = collect_aliases(&mut imp)?;
241
8
    let rule_alias_branches: Vec<_> = alias_map
242
8
        .iter()
243
304
        .flat_map(|(tgt, srcs)| iter::repeat(tgt).zip(srcs))
244
304
        .map(|(tgt, src)| {
245
304
            let ident = &src.ident;
246
304
            quote!(
247
                #rule_enum::#ident => Self::AliasedRule::#tgt,
248
            )
249
304
        })
250
8
        .collect();
251
8
    let aliased_rule_variants: Vec<_> = alias_map.keys().cloned().collect();
252
8
    let shortcut_branches: Vec<_> = alias_map
253
8
        .iter()
254
8
        .flat_map(|(_tgt, srcs)| srcs)
255
304
        .map(|AliasSrc { ident, is_shortcut }| {
256
304
            quote!(
257
                #rule_enum::#ident => #is_shortcut,
258
            )
259
304
        })
260
8
        .collect();
261

            
262
    // Process functions and apply attributes
263
8
    let fn_map: HashMap<Ident, ParsedFn> = imp
264
8
        .items
265
8
        .iter_mut()
266
304
        .flat_map(|item| match item {
267
304
            ImplItem::Fn(m) => Some(m),
268
            _ => None,
269
304
        })
270
304
        .map(|method| {
271
304
            *method = parse_quote!(
272
                #[allow(non_snake_case)]
273
                #method
274
            );
275

            
276
304
            let mut f = parse_fn(method, &mut alias_map)?;
277
304
            apply_special_attrs(&mut f, rule_enum)?;
278
304
            Ok((f.fn_name.clone(), f))
279
304
        })
280
8
        .collect::<Result<_>>()?;
281

            
282
    // Create functions for any remaining aliases
283
8
    let extra_fns: Vec<_> = alias_map
284
8
        .iter()
285
8
        .map(|(tgt, srcs)| {
286
            // Get the signature of one of the functions that has this alias
287
            let f = fn_map.get(&srcs.first().unwrap().ident).unwrap();
288
            let input_arg = f.input_arg.clone();
289
            let mut sig = f.function.sig.clone();
290
            sig.ident = tgt.clone();
291
            let srcs = srcs.iter().map(|src| &src.ident);
292

            
293
            Ok(parse_quote!(
294
                #sig {
295
                    match #input_arg.as_rule() {
296
                        #(#rule_enum::#srcs => Self::#srcs(#input_arg),)*
297
                        // We can't match on #rule_enum::#tgt since `tgt` might be an arbitrary
298
                        // identifier.
299
                        r if &format!("{:?}", r) == stringify!(#tgt) =>
300
                            return ::std::result::Result::Err(#input_arg.error(format!(
301
                                "merc_pest_consume::parser: missing method for rule {}",
302
                                stringify!(#tgt),
303
                            ))),
304
                        r => return ::std::result::Result::Err(#input_arg.error(format!(
305
                            "merc_pest_consume::parser: called method `{}` on a node with rule `{:?}`",
306
                            stringify!(#tgt),
307
                            r
308
                        ))),
309
                    }
310
                }
311
            ))
312
        })
313
8
        .collect::<Result<_>>()?;
314
8
    imp.items.extend(extra_fns);
315

            
316
    // Generate the final implementation
317
8
    let ty = &imp.self_ty;
318
8
    let (impl_generics, _, where_clause) = imp.generics.split_for_impl();
319

            
320
8
    debug_assert!(
321
8
        !aliased_rule_variants.is_empty(),
322
        "Must have at least one aliased rule variant"
323
    );
324

            
325
8
    Ok(quote!(
326
        #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
327
        #[allow(non_camel_case_types)]
328
        pub enum AliasedRule {
329
            #(#aliased_rule_variants,)*
330
        }
331

            
332
        impl #impl_generics ::merc_pest_consume::Parser for #ty #where_clause {
333
            type Rule = #rule_enum;
334
            type AliasedRule = AliasedRule;
335
            type Parser = #parser;
336
            fn rule_alias(rule: Self::Rule) -> Self::AliasedRule {
337
                match rule {
338
                    #(#rule_alias_branches)*
339
                    // TODO: return a proper error ?
340
                    r => panic!("Rule `{:?}` does not have a corresponding parsing method", r),
341
                }
342
            }
343
            fn allows_shortcut(rule: Self::Rule) -> bool {
344
                match rule {
345
                    #(#shortcut_branches)*
346
                    _ => false,
347
                }
348
            }
349
        }
350

            
351
        #imp
352
    ))
353
8
}