1
#![forbid(unsafe_code)]
2

            
3
use std::hash::Hash;
4
use std::hash::Hasher;
5
use std::sync::Arc;
6
use std::sync::atomic::AtomicUsize;
7
use std::sync::atomic::Ordering;
8

            
9
use dashmap::DashMap;
10
use equivalent::Equivalent;
11
use merc_unsafety::StablePointer;
12
use rustc_hash::FxBuildHasher;
13

            
14
use merc_unsafety::StablePointerSet;
15

            
16
use crate::Symb;
17
use crate::SymbolIndex;
18
use crate::SymbolRef;
19

            
20
/// Pool for maximal sharing of function symbols, see [crate::SymbolRef]. Ensures that function symbols
21
/// with the same name and arity point to the same [SharedSymbol] object.
22
/// Returns [crate::Symbol] that can be used to refer to the shared symbol, avoiding
23
/// garbage collection of the underlying shared symbol.
24
pub struct SymbolPool {
25
    /// Unique table of all function symbols
26
    symbols: StablePointerSet<SharedSymbol, FxBuildHasher>,
27

            
28
    /// A map from prefixes to counters that track the next available index for function symbols
29
    prefix_to_register_function_map: DashMap<String, Arc<AtomicUsize>, FxBuildHasher>,
30
}
31

            
32
impl SymbolPool {
33
    /// Creates a new empty symbol pool.
34
593
    pub(crate) fn new() -> Self {
35
593
        Self {
36
593
            symbols: StablePointerSet::with_hasher(FxBuildHasher),
37
593
            prefix_to_register_function_map: DashMap::with_hasher(FxBuildHasher),
38
593
        }
39
593
    }
40

            
41
    /// Creates or retrieves a function symbol with the given name and arity.
42
3331643
    pub fn create<N>(&self, name: N, arity: usize) -> StablePointer<SharedSymbol>
43
3331643
    where
44
3331643
        N: Into<String> + AsRef<str>,
45
    {
46
        // Get or create symbol index
47
3331643
        let (shared_symbol, inserted) = self.symbols.insert_equiv(&SharedSymbolLookup { name, arity });
48

            
49
3331643
        if inserted {
50
21495
            // If the symbol was newly created, register its prefix.
51
21495
            self.update_prefix(shared_symbol.name());
52
3310148
        }
53

            
54
        // Return cloned symbol
55
3331643
        shared_symbol
56
3331643
    }
57

            
58
    /// Return the symbol of the SharedTerm for the given ATermRef
59
    pub fn symbol_name<'a>(&self, symbol: &'a SymbolRef<'a>) -> &'a str {
60
        symbol.shared().name()
61
    }
62

            
63
    /// Returns the arity of the function symbol
64
    pub fn symbol_arity<'a, 'b, S: Symb<'a, 'b>>(&self, symbol: &'b S) -> usize {
65
        symbol.shared().arity()
66
    }
67

            
68
    /// Returns the number of symbols in the pool.
69
28020
    pub fn len(&self) -> usize {
70
28020
        self.symbols.len()
71
28020
    }
72

            
73
    /// Returns true if the pool is empty.
74
    pub fn is_empty(&self) -> bool {
75
        self.symbols.is_empty()
76
    }
77

            
78
    /// Returns the capacity of the pool.
79
    pub fn capacity(&self) -> usize {
80
        self.symbols.capacity()
81
    }
82

            
83
    /// Retain only symbols satisfying the given predicate.
84
28020
    pub fn retain<F>(&mut self, mut f: F)
85
28020
    where
86
28020
        F: FnMut(&SymbolIndex) -> bool,
87
    {
88
1308700
        self.symbols.retain(|element| f(element));
89
28020
    }
90

            
91
    /// Creates a new prefix counter for the given prefix.
92
1
    pub fn create_prefix(&self, prefix: &str) -> Arc<AtomicUsize> {
93
        // Create a new counter for the prefix if it does not exist
94
1
        let result = match self.prefix_to_register_function_map.get(prefix) {
95
            Some(result) => result.clone(),
96
            None => {
97
1
                let result = Arc::new(AtomicUsize::new(0));
98
1
                assert!(
99
1
                    self.prefix_to_register_function_map
100
1
                        .insert(prefix.to_string(), result.clone())
101
1
                        .is_none(),
102
                    "This key should not yet exist"
103
                );
104
1
                result
105
            }
106
        };
107

            
108
        // Ensure the counter starts at a sufficiently large index
109
1
        self.get_sufficiently_large_postfix_index(prefix, &result);
110
1
        result
111
1
    }
112

            
113
    /// Removes a prefix counter from the pool.
114
    pub fn remove_prefix(&self, prefix: &str) {
115
        // Remove the prefix counter if it exists
116
        self.prefix_to_register_function_map.remove(prefix);
117
    }
118

            
119
    /// Updates the counter for a registered prefix for the newly created symbol.
120
29325
    fn update_prefix(&self, name: &str) {
121
        // Check whether there is a registered prefix p such that name equal pn where n is a number.
122
        // In that case prevent that pn will be generated as a fresh function name.
123
29325
        let start_of_index = name
124
33228
            .rfind(|c: char| !c.is_ascii_digit())
125
29325
            .map(|pos| pos + 1)
126
29325
            .unwrap_or(0);
127

            
128
29325
        if start_of_index < name.len() {
129
3072
            let potential_number = &name[start_of_index..];
130
3072
            let prefix = &name[..start_of_index];
131

            
132
3072
            if let Some(counter) = self.prefix_to_register_function_map.get(prefix)
133
1
                && let Ok(number) = potential_number.parse::<usize>() {
134
1
                    counter.fetch_max(number + 1, Ordering::Relaxed);
135
3071
                }
136
26253
        }
137
29325
    }
138

            
139
    /// Traverse all symbols to find the maximum numeric suffix for this prefix
140
1
    fn get_sufficiently_large_postfix_index(&self, prefix: &str, counter: &Arc<AtomicUsize>) {
141
5
        for symbol in self.symbols.iter() {
142
5
            let name = symbol.name();
143
5
            if name.starts_with(prefix) {
144
                // Symbol name starts with the prefix, check for numeric suffix
145
2
                let suffix_start = prefix.len();
146
2
                if suffix_start < name.len() {
147
2
                    let suffix = &name[suffix_start..];
148
2
                    if let Ok(number) = suffix.parse::<usize>() {
149
1
                        // There is a numeric suffix, update the counter if it's larger
150
1
                        counter.fetch_max(number + 1, Ordering::Relaxed);
151
1
                    }
152
                }
153
3
            }
154
        }
155
1
    }
156
}
157

            
158
/// A function symbol, described by its name together with its arity.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct SharedSymbol {
    /// The name of the function symbol.
    name: String,
    /// The number of arguments the function symbol takes.
    arity: usize,
}

impl SharedSymbol {
    /// Builds a function symbol from anything convertible into an owned name and an arity.
    pub fn new<N: Into<String>>(name: N, arity: usize) -> Self {
        let name = name.into();
        Self { name, arity }
    }

    /// Gives read access to the name of the function symbol.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Gives the number of arguments of the function symbol.
    pub fn arity(&self) -> usize {
        self.arity
    }

    /// Returns the memory address of this value as an integer, which is used as its unique index.
    ///
    /// NOTE(review): uniqueness presumably relies on the pool keeping shared symbols at stable
    /// addresses; confirm against the storage that owns these values.
    pub fn index(&self) -> usize {
        let address: *const Self = self;
        address.cast::<u8>() as usize
    }
}
191

            
192
/// Lightweight key for looking up a [SharedSymbol] without first constructing one.
///
/// The name can be any type that can be viewed as a string slice, and converted into an owned
/// string.
struct SharedSymbolLookup<T>
where
    T: Into<String> + AsRef<str>,
{
    /// The name of the function symbol that is searched for.
    name: T,
    /// The arity of the function symbol that is searched for.
    arity: usize,
}
197

            
198
/// Converts a borrowed lookup key into an owned [SharedSymbol] with the same name and arity.
impl<T> From<&SharedSymbolLookup<T>> for SharedSymbol
where
    T: Into<String> + AsRef<str>,
{
    fn from(lookup: &SharedSymbolLookup<T>) -> Self {
        // TODO: Not optimal, the lookup is only borrowed so the name is copied instead of moved.
        let name: &str = lookup.name.as_ref();
        Self::new(name.to_owned(), lookup.arity)
    }
}
205

            
206
/// A lookup key is equivalent to a [SharedSymbol] when both the name and the arity agree.
impl<T> Equivalent<SharedSymbol> for SharedSymbolLookup<T>
where
    T: Into<String> + AsRef<str>,
{
    fn equivalent(&self, other: &SharedSymbol) -> bool {
        let name: &str = self.name.as_ref();
        self.arity == other.arity && name == other.name.as_str()
    }
}
211

            
212
/// These hash implementations should be the same as `SharedSymbol`.
213
impl<T: Into<String> + AsRef<str>> Hash for SharedSymbolLookup<T> {
214
3331643
    fn hash<H: Hasher>(&self, state: &mut H) {
215
3331643
        self.name.as_ref().hash(state);
216
3331643
        self.arity.hash(state);
217
3331643
    }
218
}
219

            
220
impl Hash for SharedSymbol {
221
44453
    fn hash<H: Hasher>(&self, state: &mut H) {
222
44453
        self.name.hash(state);
223
44453
        self.arity.hash(state);
224
44453
    }
225
}
226

            
227
#[cfg(test)]
mod tests {
    use std::sync::atomic::Ordering;

    use crate::Symbol;
    use crate::storage::THREAD_TERM_POOL;

    /// Creating the same function symbol twice must yield equal symbols.
    #[test]
    fn test_symbol_sharing() {
        let _ = merc_utilities::test_logger();

        let first = Symbol::new("f", 2);
        let second = Symbol::new("f", 2);

        // Should be the same object
        assert_eq!(first, second);
    }

    /// A registered prefix counter must stay above every numeric suffix in use for that prefix.
    #[test]
    fn test_prefix_counter() {
        let _ = merc_utilities::test_logger();

        // One symbol with a numeric suffix and one without, both created before registration.
        let _numbered = Symbol::new("x69", 0);
        let _unnumbered = Symbol::new("x_y", 0);

        let counter =
            THREAD_TERM_POOL.with_borrow(|tp| tp.term_pool().write().expect("Lock poisoned!").register_prefix("x"));

        // The existing symbol "x69" pushes the counter to 70.
        assert_eq!(counter.load(Ordering::Relaxed), 70);

        // Only the name consisting of the prefix and a number moves the counter.
        let _without_suffix = Symbol::new("x_no_effect", 0);
        let _with_suffix = Symbol::new("x130", 0);

        assert_eq!(counter.load(Ordering::Relaxed), 131);
    }
}