1
#![forbid(unsafe_code)]
2

            
3
use std::hash::Hash;
4
use std::hash::Hasher;
5
use std::sync::Arc;
6
use std::sync::atomic::AtomicUsize;
7
use std::sync::atomic::Ordering;
8

            
9
use dashmap::DashMap;
10
use equivalent::Equivalent;
11
use merc_unsafety::StablePointer;
12
use rustc_hash::FxBuildHasher;
13

            
14
use merc_unsafety::StablePointerSet;
15

            
16
use crate::Symb;
17
use crate::SymbolIndex;
18
use crate::SymbolRef;
19

            
20
/// Pool for maximal sharing of function symbols, see [crate::SymbolRef].
///
/// Ensures that function symbols with the same name and arity point to the same [SharedSymbol]
/// object: [SymbolPool::create] returns a `StablePointer<SharedSymbol>` to the shared entry.
/// NOTE(review): earlier docs stated that a [crate::Symbol] is returned to refer to the shared
/// symbol and to keep it from being garbage collected; presumably that wrapper is built from the
/// pointer elsewhere in the crate — confirm.
pub struct SymbolPool {
    /// Unique table of all function symbols, identified by their (name, arity) pair.
    symbols: StablePointerSet<SharedSymbol, FxBuildHasher>,

    /// Maps each registered prefix to a counter that tracks the next available index to append
    /// to that prefix when generating a fresh function symbol name.
    prefix_to_register_function_map: DashMap<String, Arc<AtomicUsize>, FxBuildHasher>,
}
31

            
32
impl SymbolPool {
33
    /// Creates a new empty symbol pool.
34
533
    pub(crate) fn new() -> Self {
35
533
        Self {
36
533
            symbols: StablePointerSet::with_hasher(FxBuildHasher),
37
533
            prefix_to_register_function_map: DashMap::with_hasher(FxBuildHasher),
38
533
        }
39
533
    }
40

            
41
    /// Creates or retrieves a function symbol with the given name and arity.
42
3260853
    pub fn create<N>(&self, name: N, arity: usize) -> StablePointer<SharedSymbol>
43
3260853
    where
44
3260853
        N: Into<String> + AsRef<str>,
45
    {
46
        // Get or create symbol index
47
3260853
        let (shared_symbol, inserted) = self.symbols.insert_equiv(&SharedSymbolLookup { name, arity });
48

            
49
3260853
        if inserted {
50
12905
            // If the symbol was newly created, register its prefix.
51
12905
            self.update_prefix(shared_symbol.name());
52
3247948
        }
53

            
54
        // Return cloned symbol
55
3260853
        shared_symbol
56
3260853
    }
57

            
58
    /// Return the symbol of the SharedTerm for the given ATermRef
59
    pub fn symbol_name<'a>(&self, symbol: &'a SymbolRef<'a>) -> &'a str {
60
        symbol.shared().name()
61
    }
62

            
63
    /// Returns the arity of the function symbol
64
    pub fn symbol_arity<'a, 'b>(&self, symbol: &'b impl Symb<'a, 'b>) -> usize {
65
        symbol.shared().arity()
66
    }
67

            
68
    /// Returns the number of symbols in the pool.
69
    pub fn len(&self) -> usize {
70
        self.symbols.len()
71
    }
72

            
73
    /// Returns true if the pool is empty.
74
    pub fn is_empty(&self) -> bool {
75
        self.symbols.is_empty()
76
    }
77

            
78
    /// Returns the capacity of the pool.
79
    pub fn capacity(&self) -> usize {
80
        self.symbols.capacity()
81
    }
82

            
83
    /// Retain only symbols satisfying the given predicate.
84
    pub fn retain<F>(&mut self, mut f: F)
85
    where
86
        F: FnMut(&SymbolIndex) -> bool,
87
    {
88
        self.symbols.retain(|element| f(element));
89
    }
90

            
91
    /// Creates a new prefix counter for the given prefix.
92
1
    pub fn create_prefix(&self, prefix: &str) -> Arc<AtomicUsize> {
93
        // Create a new counter for the prefix if it does not exist
94
1
        let result = match self.prefix_to_register_function_map.get(prefix) {
95
            Some(result) => result.clone(),
96
            None => {
97
1
                let result = Arc::new(AtomicUsize::new(0));
98
1
                assert!(
99
1
                    self.prefix_to_register_function_map
100
1
                        .insert(prefix.to_string(), result.clone())
101
1
                        .is_none(),
102
                    "This key should not yet exist"
103
                );
104
1
                result
105
            }
106
        };
107

            
108
        // Ensure the counter starts at a sufficiently large index
109
1
        self.get_sufficiently_large_postfix_index(prefix, &result);
110
1
        result
111
1
    }
112

            
113
    /// Removes a prefix counter from the pool.
114
    pub fn remove_prefix(&self, prefix: &str) {
115
        // Remove the prefix counter if it exists
116
        self.prefix_to_register_function_map.remove(prefix);
117
    }
118

            
119
    /// Updates the counter for a registered prefix for the newly created symbol.
120
20735
    fn update_prefix(&self, name: &str) {
121
        // Check whether there is a registered prefix p such that name equal pn where n is a number.
122
        // In that case prevent that pn will be generated as a fresh function name.
123
20735
        let start_of_index = name
124
24198
            .rfind(|c: char| !c.is_ascii_digit())
125
20735
            .map(|pos| pos + 1)
126
20735
            .unwrap_or(0);
127

            
128
20735
        if start_of_index < name.len() {
129
2562
            let potential_number = &name[start_of_index..];
130
2562
            let prefix = &name[..start_of_index];
131

            
132
2562
            if let Some(counter) = self.prefix_to_register_function_map.get(prefix) {
133
1
                if let Ok(number) = potential_number.parse::<usize>() {
134
1
                    counter.fetch_max(number + 1, Ordering::Relaxed);
135
1
                }
136
2561
            }
137
18173
        }
138
20735
    }
139

            
140
    /// Traverse all symbols to find the maximum numeric suffix for this prefix
141
1
    fn get_sufficiently_large_postfix_index(&self, prefix: &str, counter: &Arc<AtomicUsize>) {
142
5
        for symbol in self.symbols.iter() {
143
5
            let name = symbol.name();
144
5
            if name.starts_with(prefix) {
145
                // Symbol name starts with the prefix, check for numeric suffix
146
2
                let suffix_start = prefix.len();
147
2
                if suffix_start < name.len() {
148
2
                    let suffix = &name[suffix_start..];
149
2
                    if let Ok(number) = suffix.parse::<usize>() {
150
1
                        // There is a numeric suffix, update the counter if it's larger
151
1
                        counter.fetch_max(number + 1, Ordering::Relaxed);
152
1
                    }
153
                }
154
3
            }
155
        }
156
1
    }
157
}
158

            
159
/// A function symbol, consisting of a name and an arity.
///
/// Equality compares both the name and the arity. `Hash` is implemented by hand (name, then
/// arity) and must stay consistent with the hashing of `SharedSymbolLookup`, which is used to
/// find symbols without constructing a `SharedSymbol` first.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct SharedSymbol {
    /// Name of the function
    name: String,
    /// Number of arguments
    arity: usize,
}
167

            
168
impl SharedSymbol {
169
    /// Creates a new function symbol.
170
20735
    pub fn new(name: impl Into<String>, arity: usize) -> Self {
171
20735
        Self {
172
20735
            name: name.into(),
173
20735
            arity,
174
20735
        }
175
20735
    }
176

            
177
    /// Returns the name of the function symbol
178
1179694
    pub fn name(&self) -> &str {
179
1179694
        &self.name
180
1179694
    }
181

            
182
    /// Returns the arity of the function symbol
183
1722279972
    pub fn arity(&self) -> usize {
184
1722279972
        self.arity
185
1722279972
    }
186

            
187
    /// Returns a unique index for this shared symbol
188
    pub fn index(&self) -> usize {
189
        self as *const Self as *const u8 as usize
190
    }
191
}
192

            
193
/// A lookup key for a [SharedSymbol] that holds the name in whatever string-like form the
/// caller has, so no `SharedSymbol` needs to be built just to search for one.
///
/// The trait bounds live on the `impl` blocks that need them rather than on the struct itself.
/// The `Hash` and `Equivalent` impls of this key must agree with those of [SharedSymbol].
struct SharedSymbolLookup<T> {
    /// Name of the function
    name: T,
    /// Number of arguments
    arity: usize,
}
198

            
199
impl<T: Into<String> + AsRef<str>> From<&SharedSymbolLookup<T>> for SharedSymbol {
200
12905
    fn from(lookup: &SharedSymbolLookup<T>) -> Self {
201
        // TODO: Not optimal
202
12905
        let string = lookup.name.as_ref().to_string();
203
12905
        Self::new(string, lookup.arity)
204
12905
    }
205
}
206

            
207
impl<T: Into<String> + AsRef<str>> Equivalent<SharedSymbol> for SharedSymbolLookup<T> {
    /// A lookup key matches a stored symbol exactly when both the name and the arity agree.
    /// The name is compared first; the arity is only checked when the names match.
    fn equivalent(&self, other: &SharedSymbol) -> bool {
        let same_name = other.name.as_str() == self.name.as_ref();
        same_name && other.arity == self.arity
    }
}
212

            
213
/// These hash implementations should be the same as `SharedSymbol`.
214
impl<T: Into<String> + AsRef<str>> Hash for SharedSymbolLookup<T> {
215
3260853
    fn hash<H: Hasher>(&self, state: &mut H) {
216
3260853
        self.name.as_ref().hash(state);
217
3260853
        self.arity.hash(state);
218
3260853
    }
219
}
220

            
221
impl Hash for SharedSymbol {
222
30881
    fn hash<H: Hasher>(&self, state: &mut H) {
223
30881
        self.name.hash(state);
224
30881
        self.arity.hash(state);
225
30881
    }
226
}
227

            
228
#[cfg(test)]
mod tests {
    use std::sync::atomic::Ordering;

    use crate::Symbol;
    use crate::storage::THREAD_TERM_POOL;

    /// Two symbols created with the same name and arity must compare equal (shared object).
    #[test]
    fn test_symbol_sharing() {
        let _ = merc_utilities::test_logger();

        let (first, second) = (Symbol::new("f", 2), Symbol::new("f", 2));

        assert_eq!(first, second);
    }

    /// Registering prefix `x` starts its counter past existing `x<n>` symbols, and symbols
    /// created afterwards keep pushing it forward.
    #[test]
    fn test_prefix_counter() {
        let _ = merc_utilities::test_logger();

        // Only `x69` has a numeric suffix after `x`; `x_y` does not affect the counter.
        let _x69 = Symbol::new("x69", 0);
        let _x_y = Symbol::new("x_y", 0);

        let counter =
            THREAD_TERM_POOL.with_borrow(|thread_pool| thread_pool.term_pool().write().expect("Lock poisoned!").register_prefix("x"));
        assert_eq!(counter.load(Ordering::Relaxed), 70);

        // `x_no_effect` has no numeric suffix; `x130` moves the counter to 131.
        let _x_no_effect = Symbol::new("x_no_effect", 0);
        let _x130 = Symbol::new("x130", 0);
        assert_eq!(counter.load(Ordering::Relaxed), 131);
    }
}