1
use std::borrow::Borrow;
2
use std::cmp::Ordering;
3
use std::fmt;
4
use std::hash::Hash;
5
use std::hash::Hasher;
6
use std::marker::PhantomData;
7
use std::ops::Deref;
8

            
9
use delegate::delegate;
10

            
11
use merc_unsafety::ProtectionIndex;
12
use merc_unsafety::StablePointer;
13

            
14
use crate::Markable;
15
use crate::storage::Marker;
16
use crate::storage::SharedSymbol;
17
use crate::storage::THREAD_TERM_POOL;
18

            
19
/// The public interface for a function symbol. Can be used to write generic
20
/// functions that accept both [Symbol] and [SymbolRef].
21
///
22
/// See [crate::Term] for more information on how to use this trait with two lifetimes.
23
pub trait Symb<'a, 'b> {
24
    /// Obtain the symbol's name.
25
    fn name(&'b self) -> &'a str;
26

            
27
    /// Obtain the symbol's arity.
28
    fn arity(&self) -> usize;
29

            
30
    /// Create a copy of the symbol reference.
31
    fn copy(&'b self) -> SymbolRef<'a>;
32

            
33
    /// Returns a unique index for the symbol.
34
    fn index(&self) -> usize;
35

            
36
    /// TODO: How to actually hide this implementation?
37
    fn shared(&self) -> &SymbolIndex;
38
}
39

            
40
/// An alias for the type that is used to reference into the [SharedSymbol] set.
///
/// Every [SymbolRef] wraps exactly one of these; it is exposed through [Symb::shared].
pub type SymbolIndex = StablePointer<SharedSymbol>;
42

            
43
/// A reference to a function symbol in the symbol pool.
44
#[repr(transparent)]
45
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord)]
46
pub struct SymbolRef<'a> {
47
    shared: SymbolIndex,
48
    marker: PhantomData<&'a ()>,
49
}
50

            
51
// Compile-time check that a `SymbolRef` occupies exactly one `usize`.
// Only evaluated when debug assertions are disabled (typically release builds);
// presumably `SymbolIndex` carries extra data in debug builds — TODO confirm.
#[cfg(not(debug_assertions))]
const _: () = {
    let symbol_ref_size = std::mem::size_of::<SymbolRef>();
    assert!(symbol_ref_size == std::mem::size_of::<usize>());
};
54

            
55
// Compile-time check that `Option<SymbolRef>` still fits in one `usize`, i.e.
// that the niche of the underlying pointer is used to encode `None`.
// Like the check above, this only runs when debug assertions are disabled.
#[cfg(not(debug_assertions))]
const _: () = {
    let optional_size = std::mem::size_of::<Option<SymbolRef>>();
    assert!(optional_size == std::mem::size_of::<usize>());
};
58

            
59
/// A reference to a function symbol with a known lifetime.
60
impl<'a> SymbolRef<'a> {
61
    /// Protects the symbol from garbage collection, yielding a `Symbol`.
62
    pub fn protect(&self) -> Symbol {
63
        THREAD_TERM_POOL.with_borrow(|tp| tp.protect_symbol(self))
64
    }
65

            
66
    /// Internal constructor to create a [SymbolRef] from a [SymbolIndex].
67
    ///
68
    /// # Safety
69
    ///
70
    /// We must ensure that the lifetime `'a` is valid for the returned `SymbolRef`.
71
11948581336
    pub unsafe fn from_index(index: &SymbolIndex) -> SymbolRef<'a> {
72
11948581336
        SymbolRef {
73
11948581336
            shared: index.copy(),
74
11948581336
            marker: PhantomData,
75
11948581336
        }
76
11948581336
    }
77
}
78

            
79
impl SymbolRef<'_> {
80
    /// Internal constructor to convert any `Symb` to a `SymbolRef`.
81
80303399
    pub(crate) fn from_symbol<'a, 'b, S: Symb<'a, 'b>>(symbol: &'b S) -> Self {
82
80303399
        SymbolRef {
83
80303399
            shared: symbol.shared().copy(),
84
80303399
            marker: PhantomData,
85
80303399
        }
86
80303399
    }
87
}
88

            
89
impl<'a> Symb<'a, '_> for SymbolRef<'a> {
90
2238638
    fn name(&self) -> &'a str {
91
2238638
        unsafe { std::mem::transmute(self.shared.name()) }
92
2238638
    }
93

            
94
2848501290
    fn arity(&self) -> usize {
95
2848501290
        self.shared.arity()
96
2848501290
    }
97

            
98
11925666698
    fn copy(&self) -> SymbolRef<'a> {
99
11925666698
        unsafe { SymbolRef::from_index(self.shared()) }
100
11925666698
    }
101

            
102
    fn index(&self) -> usize {
103
        self.shared.index()
104
    }
105

            
106
12058320556
    fn shared(&self) -> &SymbolIndex {
107
12058320556
        &self.shared
108
12058320556
    }
109
}
110

            
111
impl Markable for SymbolRef<'_> {
112
4733568
    fn mark(&self, marker: &mut Marker) {
113
4733568
        marker.mark_symbol(self);
114
4733568
    }
115

            
116
    fn contains_term(&self, _term: &crate::aterm::ATermRef<'_>) -> bool {
117
        false
118
    }
119

            
120
498640
    fn contains_symbol(&self, symbol: &SymbolRef<'_>) -> bool {
121
498640
        self == symbol
122
498640
    }
123

            
124
    fn len(&self) -> usize {
125
        1
126
    }
127
}
128

            
129
impl fmt::Display for SymbolRef<'_> {
130
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131
        write!(f, "{}", self.name())
132
    }
133
}
134

            
135
impl fmt::Debug for SymbolRef<'_> {
136
1023083
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137
1023083
        write!(f, "{}", self.name())
138
1023083
    }
139
}
140

            
141
/// A protected function symbol, with the same interface as [SymbolRef].
142
pub struct Symbol {
143
    symbol: SymbolRef<'static>,
144
    root: ProtectionIndex,
145
}
146

            
147
impl Symbol {
148
    /// Create a new symbol with the given name and arity.
149
4775012
    pub fn new<N>(name: N, arity: usize) -> Symbol
150
4775012
    where
151
4775012
        N: Into<String> + AsRef<str>,
152
    {
153
4775012
        THREAD_TERM_POOL.with_borrow(|tp| tp.create_symbol(name, arity))
154
4775012
    }
155
}
156

            
157
impl Symbol {
158
    /// Internal constructor to create a symbol from an index and a root.
159
5017692
    pub(crate) unsafe fn from_index(index: &SymbolIndex, root: ProtectionIndex) -> Symbol {
160
5017692
        Self {
161
5017692
            symbol: unsafe { SymbolRef::from_index(index) },
162
5017692
            root,
163
5017692
        }
164
5017692
    }
165

            
166
    /// Returns the root index, i.e., the index in the protection set. See [crate::storage::SharedTermProtection].
167
5002140
    pub fn root(&self) -> ProtectionIndex {
168
5002140
        self.root
169
5002140
    }
170

            
171
    /// Create a copy of the symbol reference.
172
3700
    pub fn copy(&self) -> SymbolRef<'_> {
173
3700
        self.symbol.copy()
174
3700
    }
175
}
176

            
177
impl<'a> Symb<'a, '_> for &'a Symbol {
178
    delegate! {
179
        to self.symbol {
180
            fn name(&self) -> &'a str;
181
            fn arity(&self) -> usize;
182
            fn copy(&self) -> SymbolRef<'a>;
183
            fn index(&self) -> usize;
184
            fn shared(&self) -> &SymbolIndex;
185
        }
186
    }
187
}
188

            
189
impl<'a, 'b> Symb<'a, 'b> for Symbol
190
where
191
    'b: 'a,
192
{
193
    delegate! {
194
        to self.symbol {
195
            fn name(&self) -> &'a str;
196
3323277
            fn arity(&self) -> usize;
197
12532
            fn copy(&self) -> SymbolRef<'a>;
198
            fn index(&self) -> usize;
199
1661638
            fn shared(&self) -> &SymbolIndex;
200
        }
201
    }
202
}
203

            
204
impl Drop for Symbol {
205
5002140
    fn drop(&mut self) {
206
5002140
        THREAD_TERM_POOL.with_borrow(|tp| {
207
5002140
            tp.drop_symbol(self);
208
5002140
        })
209
5002140
    }
210
}
211

            
212
impl From<&SymbolRef<'_>> for Symbol {
213
    fn from(value: &SymbolRef) -> Self {
214
        value.protect()
215
    }
216
}
217

            
218
impl Clone for Symbol {
219
    fn clone(&self) -> Self {
220
        self.copy().protect()
221
    }
222
}
223

            
224
impl Deref for Symbol {
225
    type Target = SymbolRef<'static>;
226

            
227
8141907549
    fn deref(&self) -> &Self::Target {
228
8141907549
        &self.symbol
229
8141907549
    }
230
}
231

            
232
impl fmt::Display for Symbol {
233
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234
        write!(f, "{}", self.name())
235
    }
236
}
237

            
238
impl fmt::Debug for Symbol {
239
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
240
        write!(f, "{}", self.name())
241
    }
242
}
243

            
244
impl Hash for Symbol {
245
    fn hash<H: Hasher>(&self, state: &mut H) {
246
        self.copy().hash(state)
247
    }
248
}
249

            
250
impl PartialEq for Symbol {
251
1
    fn eq(&self, other: &Self) -> bool {
252
1
        self.copy().eq(&other.copy())
253
1
    }
254
}
255

            
256
impl PartialEq<SymbolRef<'_>> for Symbol {
    /// A symbol equals a [SymbolRef] when it refers to the same shared symbol.
    fn eq(&self, other: &SymbolRef<'_>) -> bool {
        let this = self.copy();
        this == *other
    }
}
261

            
262
impl PartialOrd for Symbol {
263
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
264
        Some(self.cmp(other))
265
    }
266
}
267

            
268
impl Ord for Symbol {
269
    fn cmp(&self, other: &Self) -> Ordering {
270
        self.copy().cmp(&other.copy())
271
    }
272
}
273

            
274
// Lets collections keyed by `Symbol` be queried with a `SymbolRef<'static>`.
// This relies on `Symbol`'s `Hash`, `PartialEq` and `Ord` impls all delegating to
// the wrapped `SymbolRef`, so borrowed and owned keys behave identically.
// NOTE(review): like `Deref`, this exposes the nominal `'static` lifetime.
impl Borrow<SymbolRef<'static>> for Symbol {
    fn borrow(&self) -> &SymbolRef<'static> {
        &self.symbol
    }
}
279

            
280
// Equality of `Symbol` is equality of the wrapped `SymbolRef`, which itself
// derives `Eq`, so it is a full equivalence relation.
impl Eq for Symbol {}