1
use std::borrow::Borrow;
2
use std::cmp::Ordering;
3
use std::fmt;
4
use std::hash::Hash;
5
use std::hash::Hasher;
6
use std::marker::PhantomData;
7
use std::ops::Deref;
8

            
9
use delegate::delegate;
10

            
11
use merc_collections::ProtectionIndex;
12
use merc_unsafety::StablePointer;
13

            
14
use crate::Markable;
15
use crate::storage::Marker;
16
use crate::storage::SharedSymbol;
17
use crate::storage::THREAD_TERM_POOL;
18

            
19
/// The public interface for a function symbol. Can be used to write generic
20
/// functions that accept both [Symbol] and [SymbolRef].
21
///
22
/// See [crate::Term] for more information on how to use this trait with two lifetimes.
23
pub trait Symb<'a, 'b> {
24
    /// Obtain the symbol's name.
25
    fn name(&'b self) -> &'a str;
26

            
27
    /// Obtain the symbol's arity.
28
    fn arity(&self) -> usize;
29

            
30
    /// Create a copy of the symbol reference.
31
    fn copy(&'b self) -> SymbolRef<'a>;
32

            
33
    /// Returns a unique index for the symbol.
34
    fn index(&self) -> usize;
35

            
36
    /// TODO: How to actually hide this implementation?
37
    fn shared(&self) -> &SymbolIndex;
38
}
39

            
40
/// An alias for the type that is used to reference into the symbol set.
41
pub type SymbolIndex = StablePointer<SharedSymbol>;
42

            
43
/// A reference to a function symbol in the symbol pool.
44
#[repr(transparent)]
45
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord)]
46
pub struct SymbolRef<'a> {
47
    shared: SymbolIndex,
48
    marker: PhantomData<&'a ()>,
49
}
50

            
51
/// Compile-time layout checks, only evaluated when `debug_assertions` is off:
///
/// * a `SymbolRef` is the same size as a `usize`;
/// * an `Option<SymbolRef>` is also the same size as a `usize`, i.e. the
///   niche value optimisation applies.
#[cfg(not(debug_assertions))]
const _: () = {
    assert!(std::mem::size_of::<SymbolRef>() == std::mem::size_of::<usize>());
    assert!(std::mem::size_of::<Option<SymbolRef>>() == std::mem::size_of::<usize>());
};
58

            
59
/// A reference to a function symbol with a known lifetime.
60
impl<'a> SymbolRef<'a> {
61
    /// Protects the symbol from garbage collection, yielding a `Symbol`.
62
1885040
    pub fn protect(&self) -> Symbol {
63
1885040
        THREAD_TERM_POOL.with_borrow(|tp| tp.protect_symbol(self))
64
1885040
    }
65

            
66
    /// Internal constructor to create a `SymbolRef` from a `SymbolIndex`.
67
    ///
68
    /// # Safety
69
    ///
70
    /// We must ensure that the lifetime `'a` is valid for the returned `SymbolRef`.
71
6788307009
    pub unsafe fn from_index(index: &SymbolIndex) -> SymbolRef<'a> {
72
6788307009
        SymbolRef {
73
6788307009
            shared: index.copy(),
74
6788307009
            marker: PhantomData,
75
6788307009
        }
76
6788307009
    }
77
}
78

            
79
impl SymbolRef<'_> {
80
    /// Internal constructo to convert any `Symb` to a `SymbolRef`.
81
50053331
    pub(crate) fn from_symbol<'a, 'b>(symbol: &'b impl Symb<'a, 'b>) -> Self {
82
50053331
        SymbolRef {
83
50053331
            shared: symbol.shared().copy(),
84
50053331
            marker: PhantomData,
85
50053331
        }
86
50053331
    }
87
}
88

            
89
impl<'a> Symb<'a, '_> for SymbolRef<'a> {
90
1158954
    fn name(&self) -> &'a str {
91
1158954
        unsafe { std::mem::transmute(self.shared.name()) }
92
1158954
    }
93

            
94
1672226641
    fn arity(&self) -> usize {
95
1672226641
        self.shared.arity()
96
1672226641
    }
97

            
98
6775445612
    fn copy(&self) -> SymbolRef<'a> {
99
6775445612
        unsafe { SymbolRef::from_index(self.shared()) }
100
6775445612
    }
101

            
102
    fn index(&self) -> usize {
103
        self.shared.index()
104
    }
105

            
106
6890297112
    fn shared(&self) -> &SymbolIndex {
107
6890297112
        &self.shared
108
6890297112
    }
109
}
110

            
111
impl Markable for SymbolRef<'_> {
112
    fn mark(&self, marker: &mut Marker) {
113
        marker.mark_symbol(self);
114
    }
115

            
116
    fn contains_term(&self, _term: &crate::aterm::ATermRef<'_>) -> bool {
117
        false
118
    }
119

            
120
273110
    fn contains_symbol(&self, symbol: &SymbolRef<'_>) -> bool {
121
273110
        self == symbol
122
273110
    }
123

            
124
    fn len(&self) -> usize {
125
        1
126
    }
127
}
128

            
129
impl fmt::Display for SymbolRef<'_> {
130
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131
        write!(f, "{}", self.name())
132
    }
133
}
134

            
135
impl fmt::Debug for SymbolRef<'_> {
136
566168
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137
566168
        write!(f, "{}", self.name())
138
566168
    }
139
}
140

            
141
/// A protected function symbol, with the same interface as [SymbolRef].
142
pub struct Symbol {
143
    symbol: SymbolRef<'static>,
144
    root: ProtectionIndex,
145
}
146

            
147
impl Symbol {
148
    /// Create a new symbol with the given name and arity.
149
3143024
    pub fn new(name: impl Into<String> + AsRef<str>, arity: usize) -> Symbol {
150
3143024
        THREAD_TERM_POOL.with_borrow(|tp| tp.create_symbol(name, arity))
151
3143024
    }
152
}
153

            
154
impl Symbol {
155
    /// Internal constructor to create a symbol from an index and a root.
156
5220164
    pub(crate) unsafe fn from_index(index: &SymbolIndex, root: ProtectionIndex) -> Symbol {
157
5220164
        Self {
158
5220164
            symbol: unsafe { SymbolRef::from_index(index) },
159
5220164
            root,
160
5220164
        }
161
5220164
    }
162

            
163
    /// Returns the root index, i.e., the index in the protection set. See `SharedTermProtection`.
164
5217864
    pub fn root(&self) -> ProtectionIndex {
165
5217864
        self.root
166
5217864
    }
167

            
168
    /// Create a copy of the symbol reference.
169
5629944
    pub fn copy(&self) -> SymbolRef<'_> {
170
5629944
        self.symbol.copy()
171
5629944
    }
172
}
173

            
174
impl<'a> Symb<'a, '_> for &'a Symbol {
175
    delegate! {
176
        to self.symbol {
177
            fn name(&self) -> &'a str;
178
            fn arity(&self) -> usize;
179
            fn copy(&self) -> SymbolRef<'a>;
180
            fn index(&self) -> usize;
181
            fn shared(&self) -> &SymbolIndex;
182
        }
183
    }
184
}
185

            
186
impl<'a, 'b> Symb<'a, 'b> for Symbol
187
where
188
    'b: 'a,
189
{
190
    delegate! {
191
        to self.symbol {
192
            fn name(&self) -> &'a str;
193
            fn arity(&self) -> usize;
194
4840
            fn copy(&self) -> SymbolRef<'a>;
195
            fn index(&self) -> usize;
196
2779052
            fn shared(&self) -> &SymbolIndex;
197
        }
198
    }
199
}
200

            
201
impl Drop for Symbol {
202
5217864
    fn drop(&mut self) {
203
5217864
        THREAD_TERM_POOL.with_borrow(|tp| {
204
5217864
            tp.drop_symbol(self);
205
5217864
        })
206
5217864
    }
207
}
208

            
209
impl From<&SymbolRef<'_>> for Symbol {
210
    fn from(value: &SymbolRef) -> Self {
211
        value.protect()
212
    }
213
}
214

            
215
impl Clone for Symbol {
216
    fn clone(&self) -> Self {
217
        self.copy().protect()
218
    }
219
}
220

            
221
impl Deref for Symbol {
222
    type Target = SymbolRef<'static>;
223

            
224
4414798991
    fn deref(&self) -> &Self::Target {
225
4414798991
        &self.symbol
226
4414798991
    }
227
}
228

            
229
impl fmt::Display for Symbol {
230
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
231
        write!(f, "{}", self.name())
232
    }
233
}
234

            
235
impl fmt::Debug for Symbol {
236
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
237
        write!(f, "{}", self.name())
238
    }
239
}
240

            
241
impl Hash for Symbol {
242
192764
    fn hash<H: Hasher>(&self, state: &mut H) {
243
192764
        self.copy().hash(state)
244
192764
    }
245
}
246

            
247
impl PartialEq for Symbol {
248
1871241
    fn eq(&self, other: &Self) -> bool {
249
1871241
        self.copy().eq(&other.copy())
250
1871241
    }
251
}
252

            
253
/// Allows comparing a protected [Symbol] directly against a [SymbolRef].
impl PartialEq<SymbolRef<'_>> for Symbol {
    fn eq(&self, other: &SymbolRef<'_>) -> bool {
        self.copy() == *other
    }
}
258

            
259
impl PartialOrd for Symbol {
260
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
261
        Some(self.cmp(other))
262
    }
263
}
264

            
265
impl Ord for Symbol {
266
    fn cmp(&self, other: &Self) -> Ordering {
267
        self.copy().cmp(&other.copy())
268
    }
269
}
270

            
271
impl Borrow<SymbolRef<'static>> for Symbol {
272
    fn borrow(&self) -> &SymbolRef<'static> {
273
        &self.symbol
274
    }
275
}
276

            
277
// Marker impl: `Symbol`'s `PartialEq` delegates to the derived `PartialEq` of
// `SymbolRef`, which also derives `Eq`.
impl Eq for Symbol {}