1
use proc_macro2::TokenStream;
2

            
3
use quote::ToTokens;
4
use quote::format_ident;
5
use quote::quote;
6
use syn::Item;
7
use syn::ItemMod;
8
use syn::parse_quote;
9

            
10
16
pub(crate) fn merc_derive_terms_impl(_attributes: TokenStream, input: TokenStream) -> TokenStream {
11
    // Parse the input tokens into a syntax tree
12
16
    let mut ast: ItemMod = syn::parse2(input.clone()).expect("merc_term can only be applied to a module");
13

            
14
16
    if let Some((_, content)) = &mut ast.content {
15
        // Generated code blocks are added to this list.
16
16
        let mut added = vec![];
17

            
18
155
        for item in content.iter_mut() {
19
155
            match item {
20
34
                Item::Struct(object) => {
21
                    // If the struct is annotated with term we process it as a term.
22
88
                    if let Some(attr) = object.attrs.iter().find(|attr| attr.meta.path().is_ident("merc_term")) {
23
                        // The #term(assertion) annotation must contain an assertion
24
34
                        let assertion = match attr.parse_args::<syn::Ident>() {
25
34
                            Ok(assertion) => {
26
34
                                let assertion_msg = format!("{assertion}");
27
34
                                quote!(
28
                                    debug_assert!(#assertion(&term), "Term {:?} does not satisfy {}", term, #assertion_msg)
29
                                )
30
                            }
31
                            Err(_x) => {
32
                                quote!()
33
                            }
34
                        };
35

            
36
                        // Add the expected derive macros to the input struct.
37
34
                        object
38
34
                            .attrs
39
34
                            .push(parse_quote!(#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]));
40

            
41
                        // ALL structs in this module must contain the term.
42
34
                        assert!(
43
34
                            object.fields.iter().any(|field| {
44
34
                                if let Some(name) = &field.ident {
45
34
                                    name == "term"
46
                                } else {
47
                                    false
48
                                }
49
34
                            }),
50
                            "The struct {} in mod {} has no field 'term: ATerm'",
51
                            object.ident,
52
                            ast.ident
53
                        );
54

            
55
34
                        let name = format_ident!("{}", object.ident);
56

            
57
                        // Simply the generics from the struct.
58
34
                        let generics = object.generics.clone();
59

            
60
                        // Helper to create generics with added lifetimes.
61
136
                        fn create_generics_with_lifetimes(
62
136
                            base_generics: &syn::Generics,
63
136
                            lifetime_names: &[&str],
64
136
                        ) -> syn::Generics {
65
136
                            let mut generics = base_generics.clone();
66
170
                            for &lifetime_name in lifetime_names {
67
170
                                generics.params.push(syn::GenericParam::Lifetime(syn::LifetimeParam {
68
170
                                    attrs: vec![],
69
170
                                    lifetime: syn::Lifetime::new(lifetime_name, proc_macro2::Span::call_site()),
70
170
                                    bounds: syn::punctuated::Punctuated::new(),
71
170
                                    colon_token: None,
72
170
                                }));
73
170
                            }
74
136
                            generics
75
136
                        }
76

            
77
                        // The generics from the struct with <'a, 'b> added for the Term trait.
78
34
                        let generics_term = create_generics_with_lifetimes(&object.generics, &["'a", "'b"]);
79

            
80
                        // Only 'a prepended for the Ref<'a> struct.
81
34
                        let generics_ref = create_generics_with_lifetimes(&object.generics, &["'a"]);
82

            
83
                        // Only 'b prepended for the Ref<'b> struct.
84
34
                        let generics_ref_b = create_generics_with_lifetimes(&object.generics, &["'b"]);
85

            
86
                        // Only 'static prepended for the Ref<'static> struct.
87
34
                        let generics_static = create_generics_with_lifetimes(&object.generics, &["'static"]);
88

            
89
                        // Handle PhantomData generics - use void type if no generics exist
90
34
                        let generics_phantom = if object.generics.params.is_empty() {
91
34
                            quote!(<()>)
92
                        } else {
93
                            generics.to_token_stream()
94
                        };
95

            
96
                        // Add a <name>Ref struct that contains the ATermRef<'a> and
97
                        // the implementation and both protect and borrow. Also add
98
                        // the conversion from and to an ATerm.
99
34
                        let name_ref = format_ident!("{}Ref", object.ident);
100
34
                        let generated: TokenStream = quote!(
101
                            impl #generics #name #generics {
102
                                pub fn copy #generics_ref(&'a self) -> #name_ref #generics_ref {
103
                                    self.term.copy().into()
104
                                }
105
                            }
106

            
107
                            impl #generics From<ATerm> for #name #generics {
108
                                fn from(term: ATerm) -> #name {
109
                                    #assertion;
110
                                    #name {
111
                                        term
112
                                    }
113
                                }
114
                            }
115

            
116
                            impl #generics ::std::convert::Into<ATerm> for #name #generics{
117
                                fn into(self) -> ATerm {
118
                                    self.term
119
                                }
120
                            }
121

            
122
                            impl #generics ::std::ops::Deref for #name #generics{
123
                                type Target = ATerm;
124

            
125
                                fn deref(&self) -> &Self::Target {
126
                                    &self.term
127
                                }
128
                            }
129

            
130
                            impl #generics ::std::borrow::Borrow<ATerm> for #name #generics{
131
                                fn borrow(&self) -> &ATerm {
132
                                    &self.term
133
                                }
134
                            }
135

            
136
                            impl #generics Markable for #name #generics{
137
                                fn mark(&self, marker: &mut Marker) {
138
                                    self.term.mark(marker);
139
                                }
140

            
141
                                fn contains_term(&self, term: &ATermRef<'_>) -> bool {
142
                                    &self.term.copy() == term
143
                                }
144

            
145
                                fn contains_symbol(&self, symbol: &SymbolRef<'_>) -> bool {
146
                                    self.get_head_symbol() == *symbol
147
                                }
148

            
149
                                fn len(&self) -> usize {
150
                                    1
151
                                }
152
                            }
153

            
154
                            impl ::std::fmt::Debug for #name #generics {
155
                                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
156
                                    write!(f, "{:?}", self.term)
157
                                }
158
                            }
159

            
160
                            impl #generics_term Term<'a, 'b> for #name #generics where 'b: 'a {
161
                                delegate! {
162
                                    to self.term {
163
                                        fn protect(&self) -> ATerm;
164
                                        fn arg(&'b self, index: usize) -> ATermRef<'a>;
165
                                        fn arguments(&'b self) -> ATermArgs<'a>;
166
                                        fn copy(&'b self) -> ATermRef<'a>;
167
                                        fn get_head_symbol(&'b self) -> SymbolRef<'a>;
168
                                        fn iter(&'b self) -> TermIterator<'a>;
169
                                        fn index(&self) -> usize;
170
                                        fn shared(&self) -> &ATermIndex;
171
                                        fn annotation(&self) -> Option<usize>;
172
                                    }
173
                                }
174
                            }
175

            
176
                            #[derive(Eq, Hash, Ord, PartialEq, PartialOrd)]
177
                            pub struct #name_ref #generics_ref {
178
                                pub(crate) term: ATermRef<'a>,
179
                                _marker: ::std::marker::PhantomData #generics_phantom,
180
                            }
181

            
182
                            impl #generics_ref  #name_ref #generics_ref  {
183
                                pub fn copy<'b>(&'b self) -> #name_ref #generics_ref_b{
184
                                    self.term.copy().into()
185
                                }
186

            
187
                                pub fn protect(&self) -> #name {
188
                                    self.term.protect().into()
189
                                }
190
                            }
191

            
192
                            impl #generics_ref ::std::convert::From<ATermRef<'a>> for #name_ref #generics_ref {
193
                                fn from(term: ATermRef<'a>) -> #name_ref #generics_ref  {
194
                                    #assertion;
195
                                    #name_ref {
196
                                        term,
197
                                        _marker: ::std::marker::PhantomData,
198
                                    }
199
                                }
200
                            }
201

            
202
                            impl #generics_ref ::std::convert::Into<ATermRef<'a>> for #name_ref #generics_ref  {
203
                                fn into(self) -> ATermRef<'a> {
204
                                    self.term
205
                                }
206
                            }
207

            
208
                            impl #generics_ref ::std::fmt::Debug for #name_ref #generics_ref {
209
                                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
210
                                    write!(f, "{:?}", self.term)
211
                                }
212
                            }
213

            
214
                            impl #generics_term Term<'a, '_> for #name_ref #generics_ref  {
215
                                delegate! {
216
                                    to self.term {
217
                                        fn protect(&self) -> ATerm;
218
                                        fn arg(&self, index: usize) -> ATermRef<'a>;
219
                                        fn arguments(&self) -> ATermArgs<'a>;
220
                                        fn copy(&self) -> ATermRef<'a>;
221
                                        fn get_head_symbol(&self) -> SymbolRef<'a>;
222
                                        fn iter(&self) -> TermIterator<'a>;
223
                                        fn index(&self) -> usize;
224
                                        fn shared(&self) -> &ATermIndex;
225
                                        fn annotation(&self) -> Option<usize>;
226
                                    }
227
                                }
228
                            }
229

            
230
                            impl #generics_ref ::std::borrow::Borrow<ATermRef<'a>> for #name_ref #generics_ref {
231
                                fn borrow(&self) -> &ATermRef<'a> {
232
                                    &self.term
233
                                }
234
                            }
235

            
236
                            impl #generics_ref Markable for #name_ref #generics_ref {
237
                                fn mark(&self, marker: &mut Marker) {
238
                                    self.term.mark(marker);
239
                                }
240

            
241
                                fn contains_term(&self, term: &ATermRef<'_>) -> bool {
242
                                    &self.term == term
243
                                }
244

            
245
                                fn contains_symbol(&self, symbol: &SymbolRef<'_>) -> bool {
246
                                    self.get_head_symbol() == *symbol
247
                                }
248

            
249
                                fn len(&self) -> usize {
250
                                    1
251
                                }
252
                            }
253

            
254
                            impl Transmutable for #name_ref #generics_static {
255
                                type Target #generics_ref = #name_ref #generics_ref;
256

            
257
                                fn transmute_lifetime<'a>(&self) -> &'a Self::Target #generics_ref {
258
                                    unsafe { ::std::mem::transmute::<&Self, &'a #name_ref #generics_ref>(self) }
259
                                }
260

            
261
                                fn transmute_lifetime_mut<'a>(&mut self) -> &'a mut Self::Target #generics_ref {
262
                                    unsafe { ::std::mem::transmute::<&mut Self, &'a mut #name_ref #generics_ref>(self) }
263
                                }
264
                            }
265
                        );
266

            
267
34
                        added.push(Item::Verbatim(generated));
268
                    }
269
                }
270
76
                Item::Impl(implementation) => {
271
76
                    if !implementation
272
76
                        .attrs
273
76
                        .iter()
274
76
                        .any(|attr| attr.meta.path().is_ident("merc_ignore"))
275
                    {
276
                        // Duplicate the implementation for the Ref struct that is generated above.
277
55
                        let mut ref_implementation = implementation.clone();
278

            
279
                        // Remove ignored functions
280
127
                        ref_implementation.items.retain(|item| match item {
281
127
                            syn::ImplItem::Fn(func) => {
282
160
                                !func.attrs.iter().any(|attr| attr.meta.path().is_ident("merc_ignore"))
283
                            }
284
                            _ => true,
285
127
                        });
286

            
287
55
                        if let syn::Type::Path(path) = ref_implementation.self_ty.as_ref() {
288
55
                            let path = if let Some(identifier) = path.path.get_ident() {
289
                                // Build an identifier with the postfix Ref<'_>
290
55
                                let name_ref = format_ident!("{}Ref", identifier);
291
55
                                parse_quote!(#name_ref <'_>)
292
                            } else {
293
                                let path_segments = &path.path.segments;
294

            
295
                                let _name_ref = format_ident!(
296
                                    "{}Ref",
297
                                    path_segments
298
                                        .first()
299
                                        .expect("Path should at least have an identifier")
300
                                        .ident
301
                                );
302
                                // let segments: Vec<syn::PathSegment> = path_segments.iter().skip(1).collect();
303
                                // parse_quote!(#name_ref #segments)
304
                                unimplemented!()
305
                            };
306

            
307
55
                            ref_implementation.self_ty = Box::new(syn::Type::Path(syn::TypePath { qself: None, path }));
308

            
309
55
                            added.push(Item::Verbatim(ref_implementation.into_token_stream()));
310
                        }
311
21
                    }
312
                }
313
45
                _ => {
314
45
                    // Ignore the rest.
315
45
                }
316
            }
317
        }
318

            
319
16
        content.append(&mut added);
320
    }
321

            
322
    // Hand the output tokens back to the compiler
323
16
    ast.into_token_stream()
324
16
}
325

            
326
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::*;

    /// Expands a minimal module with one annotated struct and one `impl` block and
    /// checks that the `TestRef` companion and the duplicated `impl` are emitted.
    #[test]
    fn test_macro() {
        let input = "
            mod anything {

                #[merc_term(test)]
                #[derive(Debug)]
                struct Test {
                    term: ATerm,
                }

                impl Test {
                    fn a_function() {

                    }
                }
            }
        ";

        let tokens = TokenStream::from_str(input).unwrap();
        let result = merc_derive_terms_impl(TokenStream::default(), tokens);

        // Printed so the expansion can be inspected with `cargo test -- --nocapture`.
        let output = result.to_string();
        println!("{output}");

        assert!(
            output.contains("TestRef"),
            "the expansion should define the TestRef companion type"
        );
        // Once in the original impl block and once in the copy made for TestRef<'_>.
        assert_eq!(
            output.matches("a_function").count(),
            2,
            "the impl block should be duplicated for TestRef"
        );
    }
}