1
use proc_macro2::TokenStream;
2

            
3
use quote::ToTokens;
4
use quote::format_ident;
5
use quote::quote;
6
use syn::Item;
7
use syn::ItemMod;
8
use syn::parse_quote;
9

            
10
16
pub(crate) fn merc_derive_terms_impl(_attributes: TokenStream, input: TokenStream) -> TokenStream {
11
    // Parse the input tokens into a syntax tree
12
16
    let mut ast: ItemMod = syn::parse2(input.clone()).expect("merc_term can only be applied to a module");
13

            
14
16
    if let Some((_, content)) = &mut ast.content {
15
        // Generated code blocks are added to this list.
16
16
        let mut added = vec![];
17

            
18
257
        for item in content.iter_mut() {
19
82
            match item {
20
37
                Item::Struct(object) => {
21
                    // If the struct is annotated with term we process it as a term.
22
91
                    if let Some(attr) = object.attrs.iter().find(|attr| attr.meta.path().is_ident("merc_term")) {
23
                        // The #term(assertion) annotation must contain an assertion
24
37
                        let assertion = match attr.parse_args::<syn::Ident>() {
25
37
                            Ok(assertion) => {
26
37
                                let assertion_msg = format!("{assertion}");
27
37
                                quote!(
28
                                    debug_assert!(#assertion(&term), "Term {:?} does not satisfy {}", term, #assertion_msg)
29
                                )
30
                            }
31
                            Err(_x) => {
32
                                quote!()
33
                            }
34
                        };
35

            
36
                        // Add the expected derive macros to the input struct.
37
37
                        object
38
37
                            .attrs
39
37
                            .push(parse_quote!(#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]));
40

            
41
                        // ALL structs in this module must contain the term.
42
37
                        assert!(
43
37
                            object.fields.iter().any(|field| {
44
37
                                if let Some(name) = &field.ident {
45
37
                                    name == "term"
46
                                } else {
47
                                    false
48
                                }
49
37
                            }),
50
                            "The struct {} in mod {} has no field 'term: ATerm'",
51
                            object.ident,
52
                            ast.ident
53
                        );
54

            
55
37
                        let name = format_ident!("{}", object.ident);
56

            
57
                        // Simply the generics from the struct.
58
37
                        let generics = object.generics.clone();
59

            
60
                        // Helper to create generics with added lifetimes.
61
148
                        fn create_generics_with_lifetimes(
62
148
                            base_generics: &syn::Generics,
63
148
                            lifetime_names: &[&str],
64
148
                        ) -> syn::Generics {
65
148
                            let mut generics = base_generics.clone();
66
185
                            for &lifetime_name in lifetime_names {
67
185
                                generics.params.push(syn::GenericParam::Lifetime(syn::LifetimeParam {
68
185
                                    attrs: vec![],
69
185
                                    lifetime: syn::Lifetime::new(lifetime_name, proc_macro2::Span::call_site()),
70
185
                                    bounds: syn::punctuated::Punctuated::new(),
71
185
                                    colon_token: None,
72
185
                                }));
73
185
                            }
74
148
                            generics
75
148
                        }
76

            
77
                        // The generics from the struct with <'a, 'b> added for the Term trait.
78
37
                        let generics_term = create_generics_with_lifetimes(&object.generics, &["'a", "'b"]);
79

            
80
                        // Only 'a prepended for the Ref<'a> struct.
81
37
                        let generics_ref = create_generics_with_lifetimes(&object.generics, &["'a"]);
82

            
83
                        // Only 'b prepended for the Ref<'b> struct.
84
37
                        let generics_ref_b = create_generics_with_lifetimes(&object.generics, &["'b"]);
85

            
86
                        // Only 'static prepended for the Ref<'static> struct.
87
37
                        let generics_static = create_generics_with_lifetimes(&object.generics, &["'static"]);
88

            
89
                        // Handle PhantomData generics - use void type if no generics exist
90
37
                        let generics_phantom = if object.generics.params.is_empty() {
91
37
                            quote!(<()>)
92
                        } else {
93
                            generics.to_token_stream()
94
                        };
95

            
96
                        // Add a <name>Ref struct that contains the ATermRef<'a> and
97
                        // the implementation and both protect and borrow. Also add
98
                        // the conversion from and to an ATerm.
99
37
                        let name_ref = format_ident!("{}Ref", object.ident);
100
37
                        let generated: TokenStream = quote!(
101
                            impl #generics #name #generics {
102
                                pub fn copy #generics_ref(&'a self) -> #name_ref #generics_ref {
103
                                    self.term.copy().into()
104
                                }
105
                            }
106

            
107
                            impl #generics From<ATerm> for #name #generics {
108
                                fn from(term: ATerm) -> #name {
109
                                    #assertion;
110
                                    #name {
111
                                        term
112
                                    }
113
                                }
114
                            }
115

            
116
                            impl #generics ::std::convert::Into<ATerm> for #name #generics{
117
                                fn into(self) -> ATerm {
118
                                    self.term
119
                                }
120
                            }
121

            
122
                            impl #generics ::std::ops::Deref for #name #generics{
123
                                type Target = ATerm;
124

            
125
                                fn deref(&self) -> &Self::Target {
126
                                    &self.term
127
                                }
128
                            }
129

            
130
                            impl #generics ::std::borrow::Borrow<ATerm> for #name #generics{
131
                                fn borrow(&self) -> &ATerm {
132
                                    &self.term
133
                                }
134
                            }
135

            
136
                            impl #generics Markable for #name #generics{
137
                                fn mark(&self, marker: &mut Marker) {
138
                                    self.term.mark(marker);
139
                                }
140

            
141
                                fn contains_term(&self, term: &ATermRef<'_>) -> bool {
142
                                    &self.term.copy() == term
143
                                }
144

            
145
                                fn contains_symbol(&self, symbol: &SymbolRef<'_>) -> bool {
146
                                    self.get_head_symbol() == *symbol
147
                                }
148

            
149
                                fn len(&self) -> usize {
150
                                    1
151
                                }
152
                            }
153

            
154
                            impl ::std::fmt::Debug for #name #generics {
155
                                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
156
                                    write!(f, "{:?}", self.term)
157
                                }
158
                            }
159

            
160
                            impl #generics_term Term<'a, 'b> for #name #generics where 'b: 'a {
161
                                delegate! {
162
                                    to self.term {
163
                                        fn protect(&self) -> ATerm;
164
                                        fn arg(&'b self, index: usize) -> ATermRef<'a>;
165
                                        fn arguments(&'b self) -> ATermArgs<'a>;
166
                                        fn copy(&'b self) -> ATermRef<'a>;
167
                                        fn get_head_symbol(&'b self) -> SymbolRef<'a>;
168
                                        fn iter(&'b self) -> TermIterator<'a>;
169
                                        fn index(&self) -> usize;
170
                                        fn shared(&self) -> &ATermIndex;
171
                                    }
172
                                }
173
                            }
174

            
175
                            #[derive(Eq, Hash, Ord, PartialEq, PartialOrd)]
176
                            pub struct #name_ref #generics_ref {
177
                                pub(crate) term: ATermRef<'a>,
178
                                _marker: ::std::marker::PhantomData #generics_phantom,
179
                            }
180

            
181
                            impl #generics_ref  #name_ref #generics_ref  {
182
                                pub fn copy<'b>(&'b self) -> #name_ref #generics_ref_b{
183
                                    self.term.copy().into()
184
                                }
185

            
186
                                pub fn protect(&self) -> #name {
187
                                    self.term.protect().into()
188
                                }
189
                            }
190

            
191
                            impl #generics_ref ::std::convert::From<ATermRef<'a>> for #name_ref #generics_ref {
192
                                fn from(term: ATermRef<'a>) -> #name_ref #generics_ref  {
193
                                    #assertion;
194
                                    #name_ref {
195
                                        term,
196
                                        _marker: ::std::marker::PhantomData,
197
                                    }
198
                                }
199
                            }
200

            
201
                            impl #generics_ref ::std::convert::Into<ATermRef<'a>> for #name_ref #generics_ref  {
202
                                fn into(self) -> ATermRef<'a> {
203
                                    self.term
204
                                }
205
                            }
206

            
207
                            impl #generics_ref ::std::fmt::Debug for #name_ref #generics_ref {
208
                                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
209
                                    write!(f, "{:?}", self.term)
210
                                }
211
                            }
212

            
213
                            impl #generics_term Term<'a, '_> for #name_ref #generics_ref  {
214
                                delegate! {
215
                                    to self.term {
216
                                        fn protect(&self) -> ATerm;
217
                                        fn arg(&self, index: usize) -> ATermRef<'a>;
218
                                        fn arguments(&self) -> ATermArgs<'a>;
219
                                        fn copy(&self) -> ATermRef<'a>;
220
                                        fn get_head_symbol(&self) -> SymbolRef<'a>;
221
                                        fn iter(&self) -> TermIterator<'a>;
222
                                        fn index(&self) -> usize;
223
                                        fn shared(&self) -> &ATermIndex;
224
                                    }
225
                                }
226
                            }
227

            
228
                            impl #generics_ref ::std::borrow::Borrow<ATermRef<'a>> for #name_ref #generics_ref {
229
                                fn borrow(&self) -> &ATermRef<'a> {
230
                                    &self.term
231
                                }
232
                            }
233

            
234
                            impl #generics_ref Markable for #name_ref #generics_ref {
235
                                fn mark(&self, marker: &mut Marker) {
236
                                    self.term.mark(marker);
237
                                }
238

            
239
                                fn contains_term(&self, term: &ATermRef<'_>) -> bool {
240
                                    &self.term == term
241
                                }
242

            
243
                                fn contains_symbol(&self, symbol: &SymbolRef<'_>) -> bool {
244
                                    self.get_head_symbol() == *symbol
245
                                }
246

            
247
                                fn len(&self) -> usize {
248
                                    1
249
                                }
250
                            }
251

            
252
                            impl Transmutable for #name_ref #generics_static {
253
                                type Target #generics_ref = #name_ref #generics_ref;
254

            
255
                                fn transmute_lifetime<'a>(&self) -> &'a Self::Target #generics_ref {
256
                                    unsafe { ::std::mem::transmute::<&Self, &'a #name_ref #generics_ref>(self) }
257
                                }
258

            
259
                                fn transmute_lifetime_mut<'a>(&mut self) -> &'a mut Self::Target #generics_ref {
260
                                    unsafe { ::std::mem::transmute::<&mut Self, &'a mut #name_ref #generics_ref>(self) }
261
                                }
262
                            }
263
                        );
264

            
265
37
                        added.push(Item::Verbatim(generated));
266
                    }
267
                }
268
82
                Item::Impl(implementation)
269
82
                    if !implementation
270
82
                        .attrs
271
82
                        .iter()
272
82
                        .any(|attr| attr.meta.path().is_ident("merc_ignore")) =>
273
                {
274
                    // Duplicate the implementation for the Ref struct that is generated above.
275
61
                    let mut ref_implementation = implementation.clone();
276

            
277
                    // Remove ignored functions
278
133
                    ref_implementation.items.retain(|item| match item {
279
133
                        syn::ImplItem::Fn(func) => {
280
199
                            !func.attrs.iter().any(|attr| attr.meta.path().is_ident("merc_ignore"))
281
                        }
282
                        _ => true,
283
133
                    });
284

            
285
61
                    if let syn::Type::Path(path) = ref_implementation.self_ty.as_ref() {
286
61
                        let path = if let Some(identifier) = path.path.get_ident() {
287
                            // Build an identifier with the postfix Ref<'_>
288
61
                            let name_ref = format_ident!("{}Ref", identifier);
289
61
                            parse_quote!(#name_ref <'_>)
290
                        } else {
291
                            let path_segments = &path.path.segments;
292

            
293
                            let _name_ref = format_ident!(
294
                                "{}Ref",
295
                                path_segments
296
                                    .first()
297
                                    .expect("Path should at least have an identifier")
298
                                    .ident
299
                            );
300
                            // let segments: Vec<syn::PathSegment> = path_segments.iter().skip(1).collect();
301
                            // parse_quote!(#name_ref #segments)
302
                            unimplemented!()
303
                        };
304

            
305
61
                        ref_implementation.self_ty = Box::new(syn::Type::Path(syn::TypePath { qself: None, path }));
306

            
307
61
                        added.push(Item::Verbatim(ref_implementation.into_token_stream()));
308
                    }
309
                }
310
159
                _ => {
311
159
                    // Ignore the rest.
312
159
                }
313
            }
314
        }
315

            
316
16
        content.append(&mut added);
317
    }
318

            
319
    // Hand the output tokens back to the compiler
320
16
    ast.into_token_stream()
321
16
}
322

            
323
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use proc_macro2::TokenStream;
    use syn::Item;

    use crate::merc_derive_terms_impl;

    /// Runs the macro on `input`, prints the expansion, and parses it back as a module.
    /// Parsing also checks that the expansion is syntactically valid Rust. Returns the
    /// module's items.
    fn expand(input: &str) -> Vec<Item> {
        let tokens = TokenStream::from_str(input).unwrap();
        let result = merc_derive_terms_impl(TokenStream::default(), tokens);

        println!("{result}");

        let module: syn::ItemMod = syn::parse2(result).expect("the expansion should parse as a module");
        module.content.expect("the module should keep its body").1
    }

    /// Returns true when `items` contain an `impl` block whose self type path ends in
    /// `type_name` and which defines a function named `fn_name`.
    fn has_impl_fn(items: &[Item], type_name: &str, fn_name: &str) -> bool {
        items.iter().any(|item| match item {
            Item::Impl(implementation) => {
                let on_type = matches!(
                    implementation.self_ty.as_ref(),
                    syn::Type::Path(type_path)
                        if matches!(type_path.path.segments.last(), Some(segment) if segment.ident == type_name)
                );
                on_type
                    && implementation
                        .items
                        .iter()
                        .any(|impl_item| matches!(impl_item, syn::ImplItem::Fn(func) if func.sig.ident == fn_name))
            }
            _ => false,
        })
    }

    #[test]
    fn test_macro() {
        let items = expand(
            "
            mod anything {

                #[merc_term(test)]
                #[derive(Debug)]
                struct Test {
                    term: ATerm,
                }

                impl Test {
                    fn a_function() {

                    }
                }
            }
        ",
        );

        // The annotated struct keeps its own derive and receives the generated one.
        let test_struct = items
            .iter()
            .find_map(|item| match item {
                Item::Struct(object) if object.ident == "Test" => Some(object),
                _ => None,
            })
            .expect("the original struct should be kept");
        assert_eq!(
            test_struct
                .attrs
                .iter()
                .filter(|attr| attr.path().is_ident("derive"))
                .count(),
            2
        );

        // The companion Ref struct is generated.
        assert!(items
            .iter()
            .any(|item| matches!(item, Item::Struct(object) if object.ident == "TestRef")));

        // The user impl is kept and duplicated for the Ref struct.
        assert!(has_impl_fn(&items, "Test", "a_function"));
        assert!(has_impl_fn(&items, "TestRef", "a_function"));
    }
}