1
use std::alloc::Layout;
2
use std::alloc::LayoutError;
3
use std::fmt;
4
use std::hash::Hash;
5
use std::mem::offset_of;
6
use std::ptr;
7
use std::ptr::NonNull;
8
use std::ptr::slice_from_raw_parts_mut;
9

            
10
use equivalent::Equivalent;
11
use merc_unsafety::Erasable;
12
use merc_unsafety::ErasedPtr;
13
use merc_unsafety::SliceDst;
14
use merc_unsafety::repr_c;
15

            
16
use crate::ATermRef;
17
use crate::Symb;
18
use crate::SymbolRef;
19
use crate::Term;
20

            
21
/// The underlying type of terms that are maximally shared.
///
/// # Details
///
/// Uses a C representation and is a dynamically sized type for compact memory
/// usage, implementing [SliceDst] and [Erasable]. This allows us to avoid
/// storing the length and capacity of an underlying vector: the number of
/// arguments is not stored explicitly but recovered from the arity of the
/// symbol. As such this is even more compact than `smallvec`. Arguments are
/// stored as [ATermRef] slices.
#[repr(C)]
pub struct SharedTerm {
    // Must remain the first field: `Erasable::unerase` reads the symbol directly
    // from the start of the allocation to recover the number of arguments.
    symbol: SymbolRef<'static>,
    // Trailing unsized slice; its length must equal the arity of `symbol`.
    arguments: [ATermRef<'static>],
}
34

            
35
impl PartialEq for SharedTerm {
36
    fn eq(&self, other: &Self) -> bool {
37
        self.symbol == other.symbol && self.arguments() == other.arguments()
38
    }
39
}
40

            
41
impl Eq for SharedTerm {}
42

            
43
/// Note that the length is stored in the symbol's arity
44
unsafe impl SliceDst for SharedTerm {
45
2
    fn layout_for(len: usize) -> Result<Layout, LayoutError> {
46
2
        let header_layout = Layout::new::<SymbolRef<'static>>();
47
2
        let slice_layout = Layout::array::<ATermRef<'static>>(len)?;
48

            
49
2
        repr_c(&[header_layout, slice_layout])
50
2
    }
51

            
52
1
    fn retype(ptr: std::ptr::NonNull<[()]>) -> NonNull<Self> {
53
1
        unsafe { NonNull::new_unchecked(ptr.as_ptr() as *mut _) }
54
1
    }
55

            
56
    fn length(&self) -> usize {
57
        self.symbol().arity()
58
    }
59
}
60

            
61
unsafe impl Erasable for SharedTerm {
62
    fn erase(this: NonNull<Self>) -> ErasedPtr {
63
        this.cast()
64
    }
65

            
66
    unsafe fn unerase(this: ErasedPtr) -> NonNull<Self> {
67
        unsafe {
68
            let symbol: SymbolRef = ptr::read(this.as_ptr().cast());
69
            let len = symbol.arity();
70

            
71
            let raw = NonNull::new_unchecked(slice_from_raw_parts_mut(this.as_ptr().cast(), len));
72
            Self::retype(raw)
73
        }
74
    }
75
}
76

            
77
impl fmt::Debug for SharedTerm {
78
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79
        write!(
80
            f,
81
            "SharedTerm {{ symbol: {:?}, arguments: {:?} }}",
82
            self.symbol,
83
            self.arguments()
84
        )
85
    }
86
}
87

            
88
impl SharedTerm {
89
    /// Returns the symbol of the term.
90
11877754267
    pub fn symbol(&self) -> &SymbolRef<'_> {
91
11877754267
        &self.symbol
92
11877754267
    }
93

            
94
    /// Returns the arguments of the term.
95
648972442
    pub fn arguments(&self) -> &[ATermRef<'static>] {
96
648972442
        &self.arguments
97
648972442
    }
98

            
99
    /// Returns a unique index for this shared term.
100
203850976
    pub fn index(&self) -> usize {
101
203850976
        self as *const Self as *const u8 as usize
102
203850976
    }
103

            
104
    /// Returns the length for a [SharedTermLookup]
105
    pub(crate) fn length_for(object: &SharedTermLookup) -> usize {
106
        object.arguments.len()
107
    }
108

            
109
    /// Constructs an uninitialised ptr from a [SharedTermLookup]
110
1
    pub(crate) unsafe fn construct(ptr: *mut SharedTerm, object: &SharedTermLookup) {
111
1
        let header_layout = Layout::new::<SymbolRef<'static>>();
112
1
        let slice_layout =
113
1
            Layout::array::<ATermRef<'static>>(object.arguments.len()).expect("Layout should not exceed isize");
114

            
115
1
        let (_, slice_offset) = header_layout
116
1
            .extend(slice_layout)
117
1
            .expect("Layout should not exceed isize");
118
        unsafe {
119
1
            ptr.cast::<SymbolRef<'static>>()
120
1
                .write(SymbolRef::from_index(object.symbol.shared()));
121

            
122
2
            for (index, argument) in object.arguments.iter().enumerate() {
123
2
                ptr.byte_offset(slice_offset as isize)
124
2
                    .cast::<ATermRef<'static>>()
125
2
                    .add(index)
126
2
                    .write(ATermRef::from_index(argument.shared()));
127
2
            }
128
        }
129
1
    }
130
}
131

            
132
impl Hash for SharedTerm {
133
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
134
        self.symbol.hash(state);
135
        self.arguments().hash(state);
136
    }
137
}
138

            
139
/// A cheap reference to the elements of a [SharedTerm] that can be used for
/// lookup of terms without allocating.
pub(crate) struct SharedTermLookup<'a> {
    /// The head symbol of the term to look up.
    pub(crate) symbol: SymbolRef<'a>,
    /// The arguments of the term to look up. [SharedTerm::construct] copies them
    /// into a newly allocated term; their number is expected to match the arity
    /// of `symbol`, since a [SharedTerm] derives its length from that arity.
    pub(crate) arguments: &'a [ATermRef<'a>],
}
145

            
146
impl Equivalent<SharedTerm> for SharedTermLookup<'_> {
    /// A lookup matches a shared term when both have the same symbol and the
    /// same arguments, mirroring the `PartialEq` implementation of [SharedTerm].
    fn equivalent(&self, other: &SharedTerm) -> bool {
        let same_symbol = self.symbol == other.symbol;
        same_symbol && self.arguments == other.arguments()
    }
}
151

            
152
/// This Hash implement must be the same as for [SharedTerm]
153
impl Hash for SharedTermLookup<'_> {
154
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
155
        self.symbol.hash(state);
156
        self.arguments.hash(state);
157
    }
158
}
159

            
160
// `symbol` must be at offset 0 in all term representations so that any pointer to a term.
161
const _: () = assert!(offset_of!(SharedTerm, symbol) == 0);
162

            
163
#[cfg(test)]
mod tests {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::Hash;
    use std::hash::Hasher;

    use allocator_api2::alloc::Global;

    use equivalent::Equivalent;
    use merc_unsafety::AllocatorDst;
    #[cfg(not(debug_assertions))]
    use merc_unsafety::SliceDst;

    use crate::ATerm;
    use crate::Symbol;
    use crate::Term;
    use crate::storage::SharedTerm;
    use crate::storage::SharedTermLookup;

    /// Hashes `value` with a deterministic hasher, to compare hash values.
    fn hash_of<T: Hash + ?Sized>(value: &T) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    #[cfg(not(debug_assertions))]
    fn test_shared_symbol_size() {
        // Cannot be a const assertion since the size depends on the length.
        assert_eq!(
            SharedTerm::layout_for(0)
                .expect("The layout should not overflow")
                .size(),
            std::mem::size_of::<usize>(),
            "A SharedTerm without arguments should be the same size as the Symbol"
        );

        // TODO: Shared terms are still too large. Ideally `SharedTerm::layout_for(2)`
        // has size `3 * std::mem::size_of::<usize>()`: a SharedTerm with arity two
        // should be the same size as the Symbol and two ATermRef arguments.
    }

    #[test]
    fn test_shared_term_lookup() {
        let symbol = Symbol::new("a", 2);

        let term = ATerm::constant(&Symbol::new("b", 0));

        let lookup = SharedTermLookup {
            symbol: symbol.copy(),
            arguments: &[term.copy(), term.copy()],
        };

        let ptr = Global.allocate_slice_dst(2).expect("Could not allocate slice dst");

        unsafe {
            SharedTerm::construct(ptr.as_ptr(), &lookup);
            let shared = ptr.as_ref();

            assert_eq!(
                *shared.symbol(),
                symbol.copy(),
                "The symbol should match the lookup symbol"
            );
            assert_eq!(
                shared.arguments()[0],
                term.copy(),
                "The arguments should match the lookup arguments"
            );
            assert_eq!(
                shared.arguments()[1],
                term.copy(),
                "The arguments should match the lookup arguments"
            );

            // Hash-consing relies on a lookup finding the term it describes.
            assert!(
                lookup.equivalent(shared),
                "The lookup should be equivalent to the constructed term"
            );
            assert_eq!(
                hash_of(&lookup),
                hash_of(shared),
                "The lookup and the constructed term should hash identically"
            );
        }

        Global.deallocate_slice_dst(ptr, 2);
    }
}