1
use proc_macro2::TokenStream;
2

            
3
use quote::ToTokens;
4
use quote::format_ident;
5
use quote::quote;
6
use syn::Item;
7
use syn::ItemMod;
8
use syn::parse_quote;
9

            
10
16
pub(crate) fn merc_derive_terms_impl(_attributes: TokenStream, input: TokenStream) -> TokenStream {
11
    // Parse the input tokens into a syntax tree
12
16
    let mut ast: ItemMod = syn::parse2(input.clone()).expect("merc_term can only be applied to a module");
13

            
14
16
    if let Some((_, content)) = &mut ast.content {
15
        // Generated code blocks are added to this list.
16
16
        let mut added = vec![];
17

            
18
155
        for item in content.iter_mut() {
19
155
            match item {
20
34
                Item::Struct(object) => {
21
                    // If the struct is annotated with term we process it as a term.
22
88
                    if let Some(attr) = object.attrs.iter().find(|attr| attr.meta.path().is_ident("merc_term")) {
23
                        // The #term(assertion) annotation must contain an assertion
24
34
                        let assertion = match attr.parse_args::<syn::Ident>() {
25
34
                            Ok(assertion) => {
26
34
                                let assertion_msg = format!("{assertion}");
27
34
                                quote!(
28
                                    debug_assert!(#assertion(&term), "Term {:?} does not satisfy {}", term, #assertion_msg)
29
                                )
30
                            }
31
                            Err(_x) => {
32
                                quote!()
33
                            }
34
                        };
35

            
36
                        // Add the expected derive macros to the input struct.
37
34
                        object
38
34
                            .attrs
39
34
                            .push(parse_quote!(#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]));
40

            
41
                        // ALL structs in this module must contain the term.
42
34
                        assert!(
43
34
                            object.fields.iter().any(|field| {
44
34
                                if let Some(name) = &field.ident {
45
34
                                    name == "term"
46
                                } else {
47
                                    false
48
                                }
49
34
                            }),
50
                            "The struct {} in mod {} has no field 'term: ATerm'",
51
                            object.ident,
52
                            ast.ident
53
                        );
54

            
55
34
                        let name = format_ident!("{}", object.ident);
56

            
57
                        // Simply the generics from the struct.
58
34
                        let generics = object.generics.clone();
59

            
60
                        // Helper to create generics with added lifetimes.
61
136
                        fn create_generics_with_lifetimes(
62
136
                            base_generics: &syn::Generics,
63
136
                            lifetime_names: &[&str],
64
136
                        ) -> syn::Generics {
65
136
                            let mut generics = base_generics.clone();
66
170
                            for &lifetime_name in lifetime_names {
67
170
                                generics.params.push(syn::GenericParam::Lifetime(syn::LifetimeParam {
68
170
                                    attrs: vec![],
69
170
                                    lifetime: syn::Lifetime::new(lifetime_name, proc_macro2::Span::call_site()),
70
170
                                    bounds: syn::punctuated::Punctuated::new(),
71
170
                                    colon_token: None,
72
170
                                }));
73
170
                            }
74
136
                            generics
75
136
                        }
76

            
77
                        // The generics from the struct with <'a, 'b> added for the Term trait.
78
34
                        let generics_term = create_generics_with_lifetimes(&object.generics, &["'a", "'b"]);
79

            
80
                        // Only 'a prepended for the Ref<'a> struct.
81
34
                        let generics_ref = create_generics_with_lifetimes(&object.generics, &["'a"]);
82

            
83
                        // Only 'b prepended for the Ref<'b> struct.
84
34
                        let generics_ref_b = create_generics_with_lifetimes(&object.generics, &["'b"]);
85

            
86
                        // Only 'static prepended for the Ref<'static> struct.
87
34
                        let generics_static = create_generics_with_lifetimes(&object.generics, &["'static"]);
88

            
89
                        // Handle PhantomData generics - use void type if no generics exist
90
34
                        let generics_phantom = if object.generics.params.is_empty() {
91
34
                            quote!(<()>)
92
                        } else {
93
                            generics.to_token_stream()
94
                        };
95

            
96
                        // Add a <name>Ref struct that contains the ATermRef<'a> and
97
                        // the implementation and both protect and borrow. Also add
98
                        // the conversion from and to an ATerm.
99
34
                        let name_ref = format_ident!("{}Ref", object.ident);
100
34
                        let generated: TokenStream = quote!(
101
                            impl #generics #name #generics {
102
                                pub fn copy #generics_ref(&'a self) -> #name_ref #generics_ref {
103
                                    self.term.copy().into()
104
                                }
105
                            }
106

            
107
                            impl #generics From<ATerm> for #name #generics {
108
                                fn from(term: ATerm) -> #name {
109
                                    #assertion;
110
                                    #name {
111
                                        term
112
                                    }
113
                                }
114
                            }
115

            
116
                            impl #generics ::std::convert::Into<ATerm> for #name #generics{
117
                                fn into(self) -> ATerm {
118
                                    self.term
119
                                }
120
                            }
121

            
122
                            impl #generics ::std::ops::Deref for #name #generics{
123
                                type Target = ATerm;
124

            
125
                                fn deref(&self) -> &Self::Target {
126
                                    &self.term
127
                                }
128
                            }
129

            
130
                            impl #generics ::std::borrow::Borrow<ATerm> for #name #generics{
131
                                fn borrow(&self) -> &ATerm {
132
                                    &self.term
133
                                }
134
                            }
135

            
136
                            impl #generics Markable for #name #generics{
137
                                fn mark(&self, marker: &mut Marker) {
138
                                    self.term.mark(marker);
139
                                }
140

            
141
                                fn contains_term(&self, term: &ATermRef<'_>) -> bool {
142
                                    &self.term.copy() == term
143
                                }
144

            
145
                                fn contains_symbol(&self, symbol: &SymbolRef<'_>) -> bool {
146
                                    self.get_head_symbol() == *symbol
147
                                }
148

            
149
                                fn len(&self) -> usize {
150
                                    1
151
                                }
152
                            }
153

            
154
                            impl #generics_term Term<'a, 'b> for #name #generics where 'b: 'a {
155
                                delegate! {
156
                                    to self.term {
157
                                        fn protect(&self) -> ATerm;
158
                                        fn arg(&'b self, index: usize) -> ATermRef<'a>;
159
                                        fn arguments(&'b self) -> ATermArgs<'a>;
160
                                        fn copy(&'b self) -> ATermRef<'a>;
161
                                        fn get_head_symbol(&'b self) -> SymbolRef<'a>;
162
                                        fn iter(&'b self) -> TermIterator<'a>;
163
                                        fn index(&self) -> usize;
164
                                        fn shared(&self) -> &ATermIndex;
165
                                        fn annotation(&self) -> Option<usize>;
166
                                    }
167
                                }
168
                            }
169

            
170
                            #[derive(Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
171
                            pub struct #name_ref #generics_ref {
172
                                pub(crate) term: ATermRef<'a>,
173
                                _marker: ::std::marker::PhantomData #generics_phantom,
174
                            }
175

            
176
                            impl #generics_ref  #name_ref #generics_ref  {
177
                                pub fn copy<'b>(&'b self) -> #name_ref #generics_ref_b{
178
                                    self.term.copy().into()
179
                                }
180

            
181
                                pub fn protect(&self) -> #name {
182
                                    self.term.protect().into()
183
                                }
184
                            }
185

            
186
                            impl #generics_ref  From<ATermRef<'a>> for #name_ref #generics_ref {
187
                                fn from(term: ATermRef<'a>) -> #name_ref #generics_ref  {
188
                                    #assertion;
189
                                    #name_ref {
190
                                        term,
191
                                        _marker: ::std::marker::PhantomData,
192
                                    }
193
                                }
194
                            }
195

            
196
                            impl #generics_ref  Into<ATermRef<'a>> for #name_ref #generics_ref  {
197
                                fn into(self) -> ATermRef<'a> {
198
                                    self.term
199
                                }
200
                            }
201

            
202
                            impl #generics_term Term<'a, '_> for #name_ref #generics_ref  {
203
                                delegate! {
204
                                    to self.term {
205
                                        fn protect(&self) -> ATerm;
206
                                        fn arg(&self, index: usize) -> ATermRef<'a>;
207
                                        fn arguments(&self) -> ATermArgs<'a>;
208
                                        fn copy(&self) -> ATermRef<'a>;
209
                                        fn get_head_symbol(&self) -> SymbolRef<'a>;
210
                                        fn iter(&self) -> TermIterator<'a>;
211
                                        fn index(&self) -> usize;
212
                                        fn shared(&self) -> &ATermIndex;
213
                                        fn annotation(&self) -> Option<usize>;
214
                                    }
215
                                }
216
                            }
217

            
218
                            impl #generics_ref ::std::borrow::Borrow<ATermRef<'a>> for #name_ref #generics_ref {
219
                                fn borrow(&self) -> &ATermRef<'a> {
220
                                    &self.term
221
                                }
222
                            }
223

            
224
                            impl #generics_ref Markable for #name_ref #generics_ref {
225
                                fn mark(&self, marker: &mut Marker) {
226
                                    self.term.mark(marker);
227
                                }
228

            
229
                                fn contains_term(&self, term: &ATermRef<'_>) -> bool {
230
                                    &self.term == term
231
                                }
232

            
233
                                fn contains_symbol(&self, symbol: &SymbolRef<'_>) -> bool {
234
                                    self.get_head_symbol() == *symbol
235
                                }
236

            
237
                                fn len(&self) -> usize {
238
                                    1
239
                                }
240
                            }
241

            
242
                            impl Transmutable for #name_ref #generics_static {
243
                                type Target #generics_ref = #name_ref #generics_ref;
244

            
245
                                fn transmute_lifetime<'a>(&self) -> &'a Self::Target #generics_ref {
246
                                    unsafe { ::std::mem::transmute::<&Self, &'a #name_ref #generics_ref>(self) }
247
                                }
248

            
249
                                fn transmute_lifetime_mut<'a>(&mut self) -> &'a mut Self::Target #generics_ref {
250
                                    unsafe { ::std::mem::transmute::<&mut Self, &'a mut #name_ref #generics_ref>(self) }
251
                                }
252
                            }
253
                        );
254

            
255
34
                        added.push(Item::Verbatim(generated));
256
                    }
257
                }
258
76
                Item::Impl(implementation) => {
259
76
                    if !implementation
260
76
                        .attrs
261
76
                        .iter()
262
76
                        .any(|attr| attr.meta.path().is_ident("merc_ignore"))
263
                    {
264
                        // Duplicate the implementation for the Ref struct that is generated above.
265
55
                        let mut ref_implementation = implementation.clone();
266

            
267
                        // Remove ignored functions
268
127
                        ref_implementation.items.retain(|item| match item {
269
127
                            syn::ImplItem::Fn(func) => {
270
160
                                !func.attrs.iter().any(|attr| attr.meta.path().is_ident("merc_ignore"))
271
                            }
272
                            _ => true,
273
127
                        });
274

            
275
55
                        if let syn::Type::Path(path) = ref_implementation.self_ty.as_ref() {
276
55
                            let path = if let Some(identifier) = path.path.get_ident() {
277
                                // Build an identifier with the postfix Ref<'_>
278
55
                                let name_ref = format_ident!("{}Ref", identifier);
279
55
                                parse_quote!(#name_ref <'_>)
280
                            } else {
281
                                let path_segments = &path.path.segments;
282

            
283
                                let _name_ref = format_ident!(
284
                                    "{}Ref",
285
                                    path_segments
286
                                        .first()
287
                                        .expect("Path should at least have an identifier")
288
                                        .ident
289
                                );
290
                                // let segments: Vec<syn::PathSegment> = path_segments.iter().skip(1).collect();
291
                                // parse_quote!(#name_ref #segments)
292
                                unimplemented!()
293
                            };
294

            
295
55
                            ref_implementation.self_ty = Box::new(syn::Type::Path(syn::TypePath { qself: None, path }));
296

            
297
55
                            added.push(Item::Verbatim(ref_implementation.into_token_stream()));
298
                        }
299
21
                    }
300
                }
301
45
                _ => {
302
45
                    // Ignore the rest.
303
45
                }
304
            }
305
        }
306

            
307
16
        content.append(&mut added);
308
    }
309

            
310
    // Hand the output tokens back to the compiler
311
16
    ast.into_token_stream()
312
16
}
313

            
314
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::*;

    /// Runs the macro on the module source `input` and returns the expansion as a string.
    fn expand(input: &str) -> String {
        let tokens = TokenStream::from_str(input).unwrap();
        merc_derive_terms_impl(TokenStream::default(), tokens).to_string()
    }

    #[test]
    fn test_macro() {
        let input = "
            mod anything {

                #[merc_term(test)]
                #[derive(Debug)]
                struct Test {
                    term: ATerm,
                }

                impl Test {
                    fn a_function() {

                    }
                }
            }
        ";

        let result = expand(input);
        println!("{result}");

        // The companion Ref struct is generated.
        assert!(result.contains("TestRef"));
        // The assertion argument is turned into a debug assertion in the conversions.
        assert!(result.contains("debug_assert"));
        // The impl block is duplicated for TestRef, so the function appears twice.
        assert_eq!(result.matches("a_function").count(), 2);
    }

    #[test]
    fn test_merc_ignore_and_optional_assertion() {
        let input = "
            mod anything {
                #[merc_term]
                struct Test {
                    term: ATerm,
                }

                #[merc_ignore]
                impl Test {
                    fn ignored_impl() {}
                }

                impl Test {
                    #[merc_ignore]
                    fn ignored_function() {}

                    fn kept_function() {}
                }
            }
        ";

        let result = expand(input);

        // Without an assertion argument no debug assertion is generated.
        assert!(!result.contains("debug_assert"));
        // Ignored impl blocks and ignored functions are not duplicated for TestRef.
        assert_eq!(result.matches("ignored_impl").count(), 1);
        assert_eq!(result.matches("ignored_function").count(), 1);
        assert_eq!(result.matches("kept_function").count(), 2);
    }

    #[test]
    #[should_panic(expected = "has no field 'term: ATerm'")]
    fn test_missing_term_field() {
        expand(
            "
            mod anything {
                #[merc_term]
                struct Test {
                    other: ATerm,
                }
            }
        ",
        );
    }
}