1
use std::borrow::Borrow;
2
use std::cmp::Ordering;
3
use std::fmt;
4
use std::hash::Hash;
5
use std::hash::Hasher;
6
use std::marker::PhantomData;
7
use std::ops::Deref;
8

            
9
use delegate::delegate;
10

            
11
use merc_collections::ProtectionIndex;
12
use merc_unsafety::StablePointer;
13

            
14
use crate::Markable;
15
use crate::storage::Marker;
16
use crate::storage::SharedSymbol;
17
use crate::storage::THREAD_TERM_POOL;
18

            
19
/// The public interface for a function symbol. Can be used to write generic
20
/// functions that accept both [Symbol] and [SymbolRef].
21
///
22
/// See [crate::Term] for more information on how to use this trait with two lifetimes.
23
pub trait Symb<'a, 'b> {
24
    /// Obtain the symbol's name.
25
    fn name(&'b self) -> &'a str;
26

            
27
    /// Obtain the symbol's arity.
28
    fn arity(&self) -> usize;
29

            
30
    /// Create a copy of the symbol reference.
31
    fn copy(&'b self) -> SymbolRef<'a>;
32

            
33
    /// Returns a unique index for the symbol.
34
    fn index(&self) -> usize;
35

            
36
    /// TODO: How to actually hide this implementation?
37
    fn shared(&self) -> &SymbolIndex;
38
}
39

            
40
/// An alias for the type that is used to reference into the symbol set.
41
pub type SymbolIndex = StablePointer<SharedSymbol>;
42

            
43
/// A reference to a function symbol in the symbol pool.
44
#[repr(transparent)]
45
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord)]
46
pub struct SymbolRef<'a> {
47
    shared: SymbolIndex,
48
    marker: PhantomData<&'a ()>,
49
}
50

            
51
// Compile-time layout checks for `SymbolRef`, only compiled when debug assertions are off
// (presumably the layout of `StablePointer` differs in debug builds — TODO confirm):
// - a `SymbolRef` occupies exactly one `usize`;
// - `Option<SymbolRef>` also occupies one `usize`, thanks to niche-value optimisation.
#[cfg(not(debug_assertions))]
const _: () = {
    let word = std::mem::size_of::<usize>();
    assert!(std::mem::size_of::<SymbolRef>() == word);
    assert!(std::mem::size_of::<Option<SymbolRef>>() == word);
};
58

            
59
/// A reference to a function symbol with a known lifetime.
60
impl<'a> SymbolRef<'a> {
61
    /// Protects the symbol from garbage collection, yielding a `Symbol`.
62
1878260
    pub fn protect(&self) -> Symbol {
63
1878260
        THREAD_TERM_POOL.with_borrow(|tp| tp.protect_symbol(self))
64
1878260
    }
65

            
66
    /// Internal constructor to create a `SymbolRef` from a `SymbolIndex`.
67
    ///
68
    /// # Safety
69
    ///
70
    /// We must ensure that the lifetime `'a` is valid for the returned `SymbolRef`.
71
7472863237
    pub unsafe fn from_index(index: &SymbolIndex) -> SymbolRef<'a> {
72
7472863237
        SymbolRef {
73
7472863237
            shared: index.copy(),
74
7472863237
            marker: PhantomData,
75
7472863237
        }
76
7472863237
    }
77
}
78

            
79
impl SymbolRef<'_> {
80
    /// Internal constructo to convert any `Symb` to a `SymbolRef`.
81
50185081
    pub(crate) fn from_symbol<'a, 'b, S: Symb<'a, 'b>>(symbol: &'b S) -> Self {
82
50185081
        SymbolRef {
83
50185081
            shared: symbol.shared().copy(),
84
50185081
            marker: PhantomData,
85
50185081
        }
86
50185081
    }
87
}
88

            
89
impl<'a> Symb<'a, '_> for SymbolRef<'a> {
90
1753437
    fn name(&self) -> &'a str {
91
1753437
        unsafe { std::mem::transmute(self.shared.name()) }
92
1753437
    }
93

            
94
1682850811
    fn arity(&self) -> usize {
95
1682850811
        self.shared.arity()
96
1682850811
    }
97

            
98
7455261093
    fn copy(&self) -> SymbolRef<'a> {
99
7455261093
        unsafe { SymbolRef::from_index(self.shared()) }
100
7455261093
    }
101

            
102
    fn index(&self) -> usize {
103
        self.shared.index()
104
    }
105

            
106
7583783790
    fn shared(&self) -> &SymbolIndex {
107
7583783790
        &self.shared
108
7583783790
    }
109
}
110

            
111
impl Markable for SymbolRef<'_> {
112
1550
    fn mark(&self, marker: &mut Marker) {
113
1550
        marker.mark_symbol(self);
114
1550
    }
115

            
116
    fn contains_term(&self, _term: &crate::aterm::ATermRef<'_>) -> bool {
117
        false
118
    }
119

            
120
292890
    fn contains_symbol(&self, symbol: &SymbolRef<'_>) -> bool {
121
292890
        self == symbol
122
292890
    }
123

            
124
    fn len(&self) -> usize {
125
        1
126
    }
127
}
128

            
129
impl fmt::Display for SymbolRef<'_> {
130
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131
        write!(f, "{}", self.name())
132
    }
133
}
134

            
135
impl fmt::Debug for SymbolRef<'_> {
136
849505
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137
849505
        write!(f, "{}", self.name())
138
849505
    }
139
}
140

            
141
/// A protected function symbol, with the same interface as [SymbolRef].
142
pub struct Symbol {
143
    symbol: SymbolRef<'static>,
144
    root: ProtectionIndex,
145
}
146

            
147
impl Symbol {
148
    /// Create a new symbol with the given name and arity.
149
3213634
    pub fn new<N>(name: N, arity: usize) -> Symbol
150
3213634
    where
151
3213634
        N: Into<String> + AsRef<str>,
152
    {
153
3213634
        THREAD_TERM_POOL.with_borrow(|tp| tp.create_symbol(name, arity))
154
3213634
    }
155
}
156

            
157
impl Symbol {
158
    /// Internal constructor to create a symbol from an index and a root.
159
5283994
    pub(crate) unsafe fn from_index(index: &SymbolIndex, root: ProtectionIndex) -> Symbol {
160
5283994
        Self {
161
5283994
            symbol: unsafe { SymbolRef::from_index(index) },
162
5283994
            root,
163
5283994
        }
164
5283994
    }
165

            
166
    /// Returns the root index, i.e., the index in the protection set. See `SharedTermProtection`.
167
5274454
    pub fn root(&self) -> ProtectionIndex {
168
5274454
        self.root
169
5274454
    }
170

            
171
    /// Create a copy of the symbol reference.
172
5609624
    pub fn copy(&self) -> SymbolRef<'_> {
173
5609624
        self.symbol.copy()
174
5609624
    }
175
}
176

            
177
impl<'a> Symb<'a, '_> for &'a Symbol {
178
    delegate! {
179
        to self.symbol {
180
            fn name(&self) -> &'a str;
181
            fn arity(&self) -> usize;
182
            fn copy(&self) -> SymbolRef<'a>;
183
            fn index(&self) -> usize;
184
            fn shared(&self) -> &SymbolIndex;
185
        }
186
    }
187
}
188

            
189
impl<'a, 'b> Symb<'a, 'b> for Symbol
190
where
191
    'b: 'a,
192
{
193
    delegate! {
194
        to self.symbol {
195
            fn name(&self) -> &'a str;
196
            fn arity(&self) -> usize;
197
6500
            fn copy(&self) -> SymbolRef<'a>;
198
            fn index(&self) -> usize;
199
2911212
            fn shared(&self) -> &SymbolIndex;
200
        }
201
    }
202
}
203

            
204
impl Drop for Symbol {
205
5274454
    fn drop(&mut self) {
206
5274454
        THREAD_TERM_POOL.with_borrow(|tp| {
207
5274454
            tp.drop_symbol(self);
208
5274454
        })
209
5274454
    }
210
}
211

            
212
impl From<&SymbolRef<'_>> for Symbol {
    /// Protects the referenced symbol; equivalent to calling [SymbolRef::protect].
    fn from(value: &SymbolRef<'_>) -> Self {
        value.protect()
    }
}
217

            
218
impl Clone for Symbol {
219
    fn clone(&self) -> Self {
220
        self.copy().protect()
221
    }
222
}
223

            
224
impl Deref for Symbol {
225
    type Target = SymbolRef<'static>;
226

            
227
5088411321
    fn deref(&self) -> &Self::Target {
228
5088411321
        &self.symbol
229
5088411321
    }
230
}
231

            
232
impl fmt::Display for Symbol {
233
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234
        write!(f, "{}", self.name())
235
    }
236
}
237

            
238
impl fmt::Debug for Symbol {
239
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
240
        write!(f, "{}", self.name())
241
    }
242
}
243

            
244
impl Hash for Symbol {
245
192086
    fn hash<H: Hasher>(&self, state: &mut H) {
246
192086
        self.copy().hash(state)
247
192086
    }
248
}
249

            
250
impl PartialEq for Symbol {
251
1864461
    fn eq(&self, other: &Self) -> bool {
252
1864461
        self.copy().eq(&other.copy())
253
1864461
    }
254
}
255

            
256
impl PartialEq<SymbolRef<'_>> for Symbol {
    /// Compares the reference wrapped by this symbol with `other`.
    fn eq(&self, other: &SymbolRef<'_>) -> bool {
        self.copy() == *other
    }
}
261

            
262
impl PartialOrd for Symbol {
263
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
264
        Some(self.cmp(other))
265
    }
266
}
267

            
268
impl Ord for Symbol {
269
    fn cmp(&self, other: &Self) -> Ordering {
270
        self.copy().cmp(&other.copy())
271
    }
272
}
273

            
274
impl Borrow<SymbolRef<'static>> for Symbol {
    /// Allows `Symbol` keys in maps and sets to be looked up with a `SymbolRef<'static>`.
    fn borrow(&self) -> &SymbolRef<'static> {
        &self.symbol
    }
}
279

            
280
/// Equality on `Symbol` is a full equivalence relation, as it is for [SymbolRef].
impl Eq for Symbol {}