1
use std::borrow::Borrow;
2
use std::cmp::Ordering;
3
use std::fmt;
4
use std::hash::Hash;
5
use std::hash::Hasher;
6
use std::marker::PhantomData;
7
use std::ops::Deref;
8

            
9
use delegate::delegate;
10

            
11
use merc_collections::ProtectionIndex;
12
use merc_unsafety::StablePointer;
13

            
14
use crate::Markable;
15
use crate::storage::Marker;
16
use crate::storage::SharedSymbol;
17
use crate::storage::THREAD_TERM_POOL;
18

            
19
/// The public interface for a function symbol. Can be used to write generic
20
/// functions that accept both [Symbol] and [SymbolRef].
21
///
22
/// See [crate::Term] for more information on how to use this trait with two lifetimes.
23
pub trait Symb<'a, 'b> {
24
    /// Obtain the symbol's name.
25
    fn name(&'b self) -> &'a str;
26

            
27
    /// Obtain the symbol's arity.
28
    fn arity(&self) -> usize;
29

            
30
    /// Create a copy of the symbol reference.
31
    fn copy(&'b self) -> SymbolRef<'a>;
32

            
33
    /// Returns a unique index for the symbol.
34
    fn index(&self) -> usize;
35

            
36
    /// TODO: How to actually hide this implementation?
37
    fn shared(&self) -> &SymbolIndex;
38
}
39

            
40
/// An alias for the type that is used to reference into the symbol set.
41
pub type SymbolIndex = StablePointer<SharedSymbol>;
42

            
43
/// A reference to a function symbol in the symbol pool.
44
#[repr(transparent)]
45
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord)]
46
pub struct SymbolRef<'a> {
47
    shared: SymbolIndex,
48
    marker: PhantomData<&'a ()>,
49
}
50

            
51
// Compile-time layout checks, only evaluated when debug assertions are disabled:
// a `SymbolRef` must occupy exactly one `usize`, and `Option<SymbolRef>` must fit in
// that same space thanks to niche value optimisation.
#[cfg(not(debug_assertions))]
const _: () = {
    assert!(std::mem::size_of::<SymbolRef>() == std::mem::size_of::<usize>());
    assert!(std::mem::size_of::<Option<SymbolRef>>() == std::mem::size_of::<usize>());
};
58

            
59
/// A reference to a function symbol with a known lifetime.
60
impl<'a> SymbolRef<'a> {
61
    /// Protects the symbol from garbage collection, yielding a `Symbol`.
62
1866300
    pub fn protect(&self) -> Symbol {
63
1866300
        THREAD_TERM_POOL.with_borrow(|tp| tp.protect_symbol(self))
64
1866300
    }
65

            
66
    /// Internal constructor to create a `SymbolRef` from a `SymbolIndex`.
67
    ///
68
    /// # Safety
69
    ///
70
    /// We must ensure that the lifetime `'a` is valid for the returned `SymbolRef`.
71
7488028047
    pub unsafe fn from_index(index: &SymbolIndex) -> SymbolRef<'a> {
72
7488028047
        SymbolRef {
73
7488028047
            shared: index.copy(),
74
7488028047
            marker: PhantomData,
75
7488028047
        }
76
7488028047
    }
77
}
78

            
79
impl SymbolRef<'_> {
80
    /// Internal constructo to convert any `Symb` to a `SymbolRef`.
81
50063311
    pub(crate) fn from_symbol<'a, 'b>(symbol: &'b impl Symb<'a, 'b>) -> Self {
82
50063311
        SymbolRef {
83
50063311
            shared: symbol.shared().copy(),
84
50063311
            marker: PhantomData,
85
50063311
        }
86
50063311
    }
87
}
88

            
89
impl<'a> Symb<'a, '_> for SymbolRef<'a> {
90
1580093
    fn name(&self) -> &'a str {
91
1580093
        unsafe { std::mem::transmute(self.shared.name()) }
92
1580093
    }
93

            
94
1686003794
    fn arity(&self) -> usize {
95
1686003794
        self.shared.arity()
96
1686003794
    }
97

            
98
7470546491
    fn copy(&self) -> SymbolRef<'a> {
99
7470546491
        unsafe { SymbolRef::from_index(self.shared()) }
100
7470546491
    }
101

            
102
    fn index(&self) -> usize {
103
        self.shared.index()
104
    }
105

            
106
7598714860
    fn shared(&self) -> &SymbolIndex {
107
7598714860
        &self.shared
108
7598714860
    }
109
}
110

            
111
impl Markable for SymbolRef<'_> {
112
1550
    fn mark(&self, marker: &mut Marker) {
113
1550
        marker.mark_symbol(self);
114
1550
    }
115

            
116
    fn contains_term(&self, _term: &crate::aterm::ATermRef<'_>) -> bool {
117
        false
118
    }
119

            
120
292890
    fn contains_symbol(&self, symbol: &SymbolRef<'_>) -> bool {
121
292890
        self == symbol
122
292890
    }
123

            
124
    fn len(&self) -> usize {
125
        1
126
    }
127
}
128

            
129
impl fmt::Display for SymbolRef<'_> {
130
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131
        write!(f, "{}", self.name())
132
    }
133
}
134

            
135
impl fmt::Debug for SymbolRef<'_> {
136
763407
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137
763407
        write!(f, "{}", self.name())
138
763407
    }
139
}
140

            
141
/// A protected function symbol, with the same interface as [SymbolRef].
142
pub struct Symbol {
143
    symbol: SymbolRef<'static>,
144
    root: ProtectionIndex,
145
}
146

            
147
impl Symbol {
148
    /// Create a new symbol with the given name and arity.
149
3172254
    pub fn new(name: impl Into<String> + AsRef<str>, arity: usize) -> Symbol {
150
3172254
        THREAD_TERM_POOL.with_borrow(|tp| tp.create_symbol(name, arity))
151
3172254
    }
152
}
153

            
154
impl Symbol {
155
    /// Internal constructor to create a symbol from an index and a root.
156
5230654
    pub(crate) unsafe fn from_index(index: &SymbolIndex, root: ProtectionIndex) -> Symbol {
157
5230654
        Self {
158
5230654
            symbol: unsafe { SymbolRef::from_index(index) },
159
5230654
            root,
160
5230654
        }
161
5230654
    }
162

            
163
    /// Returns the root index, i.e., the index in the protection set. See `SharedTermProtection`.
164
5227534
    pub fn root(&self) -> ProtectionIndex {
165
5227534
        self.root
166
5227534
    }
167

            
168
    /// Create a copy of the symbol reference.
169
5573744
    pub fn copy(&self) -> SymbolRef<'_> {
170
5573744
        self.symbol.copy()
171
5573744
    }
172
}
173

            
174
impl<'a> Symb<'a, '_> for &'a Symbol {
175
    delegate! {
176
        to self.symbol {
177
            fn name(&self) -> &'a str;
178
            fn arity(&self) -> usize;
179
            fn copy(&self) -> SymbolRef<'a>;
180
            fn index(&self) -> usize;
181
            fn shared(&self) -> &SymbolIndex;
182
        }
183
    }
184
}
185

            
186
impl<'a, 'b> Symb<'a, 'b> for Symbol
187
where
188
    'b: 'a,
189
{
190
    delegate! {
191
        to self.symbol {
192
            fn name(&self) -> &'a str;
193
            fn arity(&self) -> usize;
194
6500
            fn copy(&self) -> SymbolRef<'a>;
195
            fn index(&self) -> usize;
196
2867212
            fn shared(&self) -> &SymbolIndex;
197
        }
198
    }
199
}
200

            
201
impl Drop for Symbol {
202
5227534
    fn drop(&mut self) {
203
5227534
        THREAD_TERM_POOL.with_borrow(|tp| {
204
5227534
            tp.drop_symbol(self);
205
5227534
        })
206
5227534
    }
207
}
208

            
209
impl From<&SymbolRef<'_>> for Symbol {
210
    fn from(value: &SymbolRef) -> Self {
211
        value.protect()
212
    }
213
}
214

            
215
impl Clone for Symbol {
216
    fn clone(&self) -> Self {
217
        self.copy().protect()
218
    }
219
}
220

            
221
impl Deref for Symbol {
222
    type Target = SymbolRef<'static>;
223

            
224
5099722641
    fn deref(&self) -> &Self::Target {
225
5099722641
        &self.symbol
226
5099722641
    }
227
}
228

            
229
impl fmt::Display for Symbol {
230
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
231
        write!(f, "{}", self.name())
232
    }
233
}
234

            
235
impl fmt::Debug for Symbol {
236
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
237
        write!(f, "{}", self.name())
238
    }
239
}
240

            
241
impl Hash for Symbol {
242
190890
    fn hash<H: Hasher>(&self, state: &mut H) {
243
190890
        self.copy().hash(state)
244
190890
    }
245
}
246

            
247
impl PartialEq for Symbol {
248
1852501
    fn eq(&self, other: &Self) -> bool {
249
1852501
        self.copy().eq(&other.copy())
250
1852501
    }
251
}
252

            
253
impl PartialEq<SymbolRef<'_>> for Symbol {
254
    fn eq(&self, other: &SymbolRef<'_>) -> bool {
255
        self.copy().eq(other)
256
    }
257
}
258

            
259
impl PartialOrd for Symbol {
260
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
261
        Some(self.cmp(other))
262
    }
263
}
264

            
265
impl Ord for Symbol {
266
    fn cmp(&self, other: &Self) -> Ordering {
267
        self.copy().cmp(&other.copy())
268
    }
269
}
270

            
271
impl Borrow<SymbolRef<'static>> for Symbol {
272
    fn borrow(&self) -> &SymbolRef<'static> {
273
        &self.symbol
274
    }
275
}
276

            
277
impl Eq for Symbol {}