1
use std::cell::UnsafeCell;
2
use std::collections::HashSet;
3
use std::fmt;
4
use std::sync::Arc;
5
use std::sync::LazyLock;
6
use std::sync::atomic::AtomicUsize;
7
use std::time::Instant;
8

            
9
use log::debug;
10

            
11
use merc_collections::ProtectionSet;
12
use merc_io::LargeFormatter;
13
use merc_sharedmutex::GlobalBfSharedMutex;
14
use merc_sharedmutex::RecursiveLockReadGuard;
15
use merc_unsafety::StablePointer;
16
use merc_utilities::debug_trace;
17

            
18
use crate::ATermIndex;
19
use crate::ATermRef;
20
use crate::Markable;
21
use crate::Symb;
22
use crate::Symbol;
23
use crate::SymbolIndex;
24
use crate::SymbolRef;
25
use crate::Term;
26
use crate::storage::ATermStorage;
27
use crate::storage::SharedTerm;
28
use crate::storage::SharedTermLookup;
29
use crate::storage::SymbolPool;
30

            
31
/// The single global term pool, created lazily on first access and wrapped in a
/// [`GlobalBfSharedMutex`] through which all access to the pool goes.
pub static GLOBAL_TERM_POOL: LazyLock<GlobalBfSharedMutex<GlobalTermPool>> =
    LazyLock::new(|| GlobalBfSharedMutex::new(GlobalTermPool::new()));
34

            
35
/// Enables aggressive garbage collection, which is used for testing.
///
/// When set, `trigger_garbage_collection` returns `1` instead of the current number of terms.
pub(crate) const AGGRESSIVE_GC: bool = false;

/// A type alias for the read guard through which the global term pool is accessed.
pub(crate) type GlobalTermPoolGuard<'a> = RecursiveLockReadGuard<'a, GlobalTermPool>;

/// A type alias for deletion hooks: thread-safe callbacks that receive the index of a term
/// that is being removed from the pool.
type DeletionHook = Box<dyn Fn(&ATermIndex) + Sync + Send>;
43

            
44
/// The single global (singleton) term pool.
///
/// Owns the table of unique terms, the symbol pool, the registered per-thread protection
/// sets, and the scratch data structures that are reused by every garbage collection run.
pub struct GlobalTermPool {
    /// Unique table of all terms with stable pointers for references
    terms: ATermStorage,
    /// The symbol pool for managing function symbols.
    symbol_pool: SymbolPool,
    /// The thread-specific protection sets. A slot is reset to `None` when the
    /// corresponding thread pool is deregistered, so indices of other slots stay valid.
    thread_pools: Vec<Option<Arc<UnsafeCell<SharedTermProtection>>>>,

    // Data structures used for garbage collection
    /// Used to avoid reallocations for the markings of all terms - uses pointers as keys
    marked_terms: HashSet<ATermIndex>,
    /// Used to avoid reallocations for the markings of all symbols
    marked_symbols: HashSet<SymbolIndex>,
    /// A stack used to mark terms recursively.
    stack: Vec<ATermIndex>,

    /// Deletion hooks called whenever a term with the given head symbol is deleted.
    deletion_hooks: Vec<(Symbol, DeletionHook)>,

    /// Indicates whether automatic garbage collection is enabled.
    garbage_collection: bool,

    // Default symbols, created together with the pool and always marked during garbage collection.
    /// Symbol of integer terms (`<aterm_int>`, arity 0).
    int_symbol: SymbolRef<'static>,
    /// Symbol of the empty list (`<empty_list>`, arity 0).
    empty_list_symbol: SymbolRef<'static>,
    /// Symbol of the list constructor (`<list_constructor>`, arity 2).
    list_symbol: SymbolRef<'static>,
}
72

            
73
// NOTE(review): these impls assert that the pool may be moved to and shared between threads
// even though it stores `Arc<UnsafeCell<...>>` values. Presumably this is sound because all
// access goes through the lock around `GLOBAL_TERM_POOL` — confirm.
unsafe impl Send for GlobalTermPool {}
unsafe impl Sync for GlobalTermPool {}
75

            
76
impl GlobalTermPool {
77
533
    fn new() -> GlobalTermPool {
78
        // Insert the default symbols.
79
533
        let symbol_pool = SymbolPool::new();
80
533
        let int_symbol = unsafe { SymbolRef::from_index(&symbol_pool.create("<aterm_int>", 0)) };
81
533
        let list_symbol = unsafe { SymbolRef::from_index(&symbol_pool.create("<list_constructor>", 2)) };
82
533
        let empty_list_symbol = unsafe { SymbolRef::from_index(&symbol_pool.create("<empty_list>", 0)) };
83

            
84
533
        GlobalTermPool {
85
533
            terms: ATermStorage::new(),
86
533
            symbol_pool,
87
533
            thread_pools: Vec::new(),
88
533
            marked_terms: HashSet::new(),
89
533
            marked_symbols: HashSet::new(),
90
533
            stack: Vec::new(),
91
533
            deletion_hooks: Vec::new(),
92
533
            garbage_collection: true,
93
533
            int_symbol,
94
533
            list_symbol,
95
533
            empty_list_symbol,
96
533
        }
97
533
    }
98

            
99
    /// Returns the number of terms currently stored in the pool.
    pub fn len(&self) -> usize {
        self.terms.len()
    }
103

            
104
    /// Returns whether the term pool contains no terms at all.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
108

            
109
    /// Creates a term storing a single integer value.
110
3741624
    pub fn create_int(&self, value: usize) -> (StablePointer<SharedTerm>, bool) {
111
3741624
        let shared_term = SharedTermLookup {
112
3741624
            symbol: unsafe { SymbolRef::from_index(self.int_symbol.shared()) },
113
3741624
            arguments: &[],
114
3741624
            annotation: Some(value),
115
3741624
        };
116

            
117
3741624
        let (index, inserted) = unsafe {
118
3741624
            self.terms
119
3741624
                .insert_equiv_dst(&shared_term, SharedTerm::length_for(&shared_term), |ptr, key| {
120
1885
                    SharedTerm::construct(ptr, key)
121
1885
                })
122
        };
123

            
124
3741624
        (index, inserted)
125
3741624
    }
126

            
127
    /// Create a term from a head symbol and an iterator over its arguments
128
50053331
    pub fn create_term_array<'a, 'b, 'c>(
129
50053331
        &'c self,
130
50053331
        symbol: &'b impl Symb<'a, 'b>,
131
50053331
        args: &'c [ATermRef<'c>],
132
50053331
    ) -> (StablePointer<SharedTerm>, bool) {
133
50053331
        let shared_term = SharedTermLookup {
134
50053331
            symbol: SymbolRef::from_symbol(symbol),
135
50053331
            arguments: args,
136
50053331
            annotation: None,
137
50053331
        };
138

            
139
50053331
        debug_assert_eq!(
140
50053331
            symbol.shared().arity(),
141
50053331
            shared_term.arguments.len(),
142
            "The number of arguments does not match the arity of the symbol"
143
        );
144

            
145
50053331
        let (index, inserted) = unsafe {
146
50053331
            self.terms
147
50053331
                .insert_equiv_dst(&shared_term, SharedTerm::length_for(&shared_term), |ptr, key| {
148
561000
                    SharedTerm::construct(ptr, key)
149
561000
                })
150
        };
151

            
152
50053331
        (index, inserted)
153
50053331
    }
154

            
155
    /// Create a function symbol
156
3259254
    pub fn create_symbol<P>(&self, name: impl Into<String> + AsRef<str>, arity: usize, protect: P) -> Symbol
157
3259254
    where
158
3259254
        P: FnOnce(SymbolIndex) -> Symbol,
159
    {
160
3259254
        protect(self.symbol_pool.create(name, arity))
161
3259254
    }
162

            
163
    /// Registers a new thread term pool.
164
    ///
165
    /// # Safety
166
    ///
167
    /// Note that the returned `Arc<UnsafeCell<...>>` is not Send or Sync, so it
168
    /// *must* be protected through other means.
169
    #[allow(clippy::arc_with_non_send_sync)]
170
535
    pub(crate) fn register_thread_term_pool(&mut self) -> Arc<UnsafeCell<SharedTermProtection>> {
171
535
        let protection = Arc::new(UnsafeCell::new(SharedTermProtection {
172
535
            protection_set: ProtectionSet::new(),
173
535
            symbol_protection_set: ProtectionSet::new(),
174
535
            container_protection_set: ProtectionSet::new(),
175
535
            index: self.thread_pools.len(),
176
535
        }));
177

            
178
535
        debug!("Registered thread_local protection set(s) {}", self.thread_pools.len());
179
535
        self.thread_pools.push(Some(protection.clone()));
180

            
181
535
        protection
182
535
    }
183

            
184
    /// Deregisters a thread pool.
185
535
    pub(crate) fn deregister_thread_pool(&mut self, index: usize) {
186
535
        debug!("Removed thread_local protection set(s) {index}");
187
535
        if let Some(entry) = self.thread_pools.get_mut(index) {
188
535
            *entry = None;
189
535
        }
190
535
    }
191

            
192
    /// Triggers garbage collection if necessary and returns an updated counter for the thread local pool.
193
    pub(crate) fn trigger_garbage_collection(&mut self) -> usize {
194
        self.collect_garbage();
195

            
196
        if AGGRESSIVE_GC {
197
            return 1;
198
        }
199

            
200
        self.len()
201
    }
202

            
203
    /// Returns a counter for the unique numeric suffix of the given prefix.
    ///
    /// The registration is delegated to the symbol pool.
    pub fn register_prefix(&self, prefix: &str) -> Arc<AtomicUsize> {
        self.symbol_pool.create_prefix(prefix)
    }
207

            
208
    /// Removes the registration of a prefix from the symbol pool.
    pub fn remove_prefix(&self, prefix: &str) {
        self.symbol_pool.remove_prefix(prefix)
    }
212

            
213
    /// Register a deletion hook that is called whenever a term is deleted with the given symbol.
214
    pub fn register_deletion_hook<F>(&mut self, symbol: SymbolRef<'static>, hook: F)
215
    where
216
        F: Fn(&ATermIndex) + Sync + Send + 'static,
217
    {
218
        self.deletion_hooks.push((symbol.protect(), Box::new(hook)));
219
    }
220

            
221
    /// Enables or disables automatic garbage collection.
    ///
    /// While disabled, `collect_garbage` returns immediately without removing anything.
    pub fn automatic_garbage_collection(&mut self, enabled: bool) {
        self.garbage_collection = enabled;
    }
225

            
226
    /// Collects garbage terms.
227
    fn collect_garbage(&mut self) {
228
        if !self.garbage_collection {
229
            // Garbage collection is disabled.
230
            return;
231
        }
232

            
233
        // Clear marking data structures
234
        self.marked_terms.clear();
235
        self.marked_symbols.clear();
236
        self.stack.clear();
237

            
238
        // Mark the default symbols
239
        self.marked_symbols.insert(self.int_symbol.shared().copy());
240
        self.marked_symbols.insert(self.list_symbol.shared().copy());
241
        self.marked_symbols.insert(self.empty_list_symbol.shared().copy());
242

            
243
        let mut marker = Marker {
244
            marked_terms: &mut self.marked_terms,
245
            marked_symbols: &mut self.marked_symbols,
246
            stack: &mut self.stack,
247
        };
248

            
249
        let mark_time = Instant::now();
250

            
251
        // Loop through all protection sets and mark the terms.
252
        for pool in self.thread_pools.iter().flatten() {
253
            // SAFETY: We have exclusive access to the global term pool, so no other thread can modify the protection sets.
254
            let pool = unsafe { &mut *pool.get() };
255

            
256
            for (_root, symbol) in pool.symbol_protection_set.iter() {
257
                debug_trace!("Marking root {_root} symbol {symbol:?}");
258
                // Remove all symbols that are not protected
259
                marker.marked_symbols.insert(symbol.copy());
260
            }
261

            
262
            for (_root, term) in pool.protection_set.iter() {
263
                debug_trace!("Marking root {_root} term {term:?}");
264
                unsafe {
265
                    ATermRef::from_index(term).mark(&mut marker);
266
                }
267
            }
268

            
269
            for (_, container) in pool.container_protection_set.iter() {
270
                container.mark(&mut marker);
271
            }
272
        }
273

            
274
        let mark_time_elapsed = mark_time.elapsed();
275
        let collect_time = Instant::now();
276

            
277
        let num_of_terms = self.len();
278
        let num_of_symbols = self.symbol_pool.len();
279

            
280
        // Delete all terms that are not marked
281
        self.terms.retain(|term| {
282
            if !self.marked_terms.contains(term) {
283
                debug_trace!("Dropping term: {:?}", term);
284

            
285
                // Call the deletion hooks for the term
286
                for (symbol, hook) in &self.deletion_hooks {
287
                    if symbol == term.symbol() {
288
                        debug_trace!("Calling deletion hook for term: {:?}", term);
289
                        hook(term);
290
                    }
291
                }
292

            
293
                return false;
294
            }
295

            
296
            true
297
        });
298

            
299
        // We ensure that every removed symbol is not used anymore.
300
        self.symbol_pool.retain(|symbol| {
301
            if !self.marked_symbols.contains(symbol) {
302
                debug_trace!("Dropping symbol: {:?}", symbol);
303
                return false;
304
            }
305

            
306
            true
307
        });
308

            
309
        debug!(
310
            "Garbage collection: marking took {}ms, collection took {}ms, {} terms and {} symbols removed",
311
            mark_time_elapsed.as_millis(),
312
            collect_time.elapsed().as_millis(),
313
            num_of_terms - self.len(),
314
            num_of_symbols - self.symbol_pool.len()
315
        );
316

            
317
        debug!("{}", self.metrics());
318

            
319
        // Print information from the protection sets.
320
        for pool in self.thread_pools.iter().flatten() {
321
            // SAFETY: We have exclusive access to the global term pool, so no other thread can modify the protection sets.
322
            let pool = unsafe { &mut *pool.get() };
323
            debug!("{}", pool.metrics());
324
        }
325
    }
326

            
327
    /// Returns the metrics of the term pool, can be formatted and written to output.
    ///
    /// The returned value borrows the pool and reports the number of terms and symbols
    /// at the time it is formatted.
    pub fn metrics(&self) -> TermPoolMetrics<'_> {
        TermPoolMetrics(self)
    }
331

            
332
    /// Marks the given term as being reachable.
333
    ///
334
    /// # Safety
335
    ///
336
    /// Should only be called during garbage collection.
337
    pub unsafe fn mark_term(&mut self, term: &ATermRef<'_>) {
338
        // Ensure that the global term pool is locked for writing.
339
        let mut marker = Marker {
340
            marked_terms: &mut self.marked_terms,
341
            marked_symbols: &mut self.marked_symbols,
342
            stack: &mut self.stack,
343
        };
344
        term.mark(&mut marker);
345
    }
346

            
347
    /// Returns the built-in function symbol of integer terms (`<aterm_int>`).
    pub(crate) fn get_int_symbol(&self) -> &SymbolRef<'static> {
        &self.int_symbol
    }
351

            
352
    /// Returns the built-in function symbol of the list constructor (`<list_constructor>`).
    pub(crate) fn get_list_symbol(&self) -> &SymbolRef<'static> {
        &self.list_symbol
    }
356

            
357
    /// Returns the built-in function symbol of the empty list (`<empty_list>`).
    pub(crate) fn get_empty_list_symbol(&self) -> &SymbolRef<'static> {
        &self.empty_list_symbol
    }
361
}
362

            
363
/// A struct that can be used to print the number of terms and symbols in the global term pool.
pub struct TermPoolMetrics<'a>(&'a GlobalTermPool);
364

            
365
impl fmt::Display for TermPoolMetrics<'_> {
366
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
367
        write!(
368
            f,
369
            "There are {} terms, and {} symbols",
370
            self.0.terms.len(),
371
            self.0.symbol_pool.len()
372
        )
373
    }
374
}
375

            
376
/// The protection sets of a single thread term pool.
///
/// Everything stored in these sets is treated as a root during garbage collection.
pub struct SharedTermProtection {
    /// Protection set for terms
    pub protection_set: ProtectionSet<ATermIndex>,
    /// Protection set to prevent garbage collection of symbols
    pub symbol_protection_set: ProtectionSet<SymbolIndex>,
    /// Protection set for containers
    pub container_protection_set: ProtectionSet<Arc<dyn Markable + Sync + Send>>,
    /// Index in global pool's thread pools list
    pub index: usize,
}
386

            
387
impl SharedTermProtection {
    /// Returns the metrics of these protection sets, can be formatted and written to output.
    pub fn metrics(&self) -> ProtectionMetrics<'_> {
        ProtectionMetrics(self)
    }
}
393

            
394
/// A struct that can be used to print the performance of the protection sets: the current
/// number of roots, the maximum size and the number of insertions of each set.
pub struct ProtectionMetrics<'a>(&'a SharedTermProtection);
396

            
397
impl fmt::Display for ProtectionMetrics<'_> {
398
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
399
        writeln!(
400
            f,
401
            "Protection set {} has {} roots, max {} and {} insertions",
402
            self.0.index,
403
            LargeFormatter(self.0.protection_set.len()),
404
            LargeFormatter(self.0.protection_set.maximum_size()),
405
            LargeFormatter(self.0.protection_set.number_of_insertions())
406
        )?;
407

            
408
        writeln!(
409
            f,
410
            "Containers: {} roots, max {} and {} insertions",
411
            LargeFormatter(self.0.container_protection_set.len()),
412
            LargeFormatter(self.0.container_protection_set.maximum_size()),
413
            LargeFormatter(self.0.container_protection_set.number_of_insertions()),
414
        )?;
415

            
416
        write!(
417
            f,
418
            "Symbols: {} roots, max {} and {} insertions",
419
            LargeFormatter(self.0.symbol_protection_set.len()),
420
            LargeFormatter(self.0.symbol_protection_set.maximum_size()),
421
            LargeFormatter(self.0.symbol_protection_set.number_of_insertions()),
422
        )
423
    }
424
}
425

            
426
/// Helper struct to pass private data required to mark term recursively.
pub struct Marker<'a> {
    /// The set of terms that have been marked as reachable so far.
    marked_terms: &'a mut HashSet<ATermIndex>,
    /// The set of symbols that have been marked as reachable so far.
    marked_symbols: &'a mut HashSet<SymbolIndex>,
    /// Work list of terms whose arguments still have to be visited.
    stack: &'a mut Vec<ATermIndex>,
}
432

            
433
impl Marker<'_> {
434
    // Marks the given term as being reachable.
435
    pub fn mark(&mut self, term: &ATermRef<'_>) {
436
        if !self.marked_terms.contains(term.shared()) {
437
            self.stack.push(term.shared().copy());
438

            
439
            while let Some(term) = self.stack.pop() {
440
                // Each term should be marked.
441
                self.marked_terms.insert(term.copy());
442

            
443
                // Mark the function symbol.
444
                self.marked_symbols.insert(term.symbol().shared().copy());
445

            
446
                // For some terms, such as ATermInt, we must ONLY consider the valid arguments (indicated by the arity)
447
                for arg in term.arguments()[0..term.symbol().arity()].iter() {
448
                    // Skip if unnecessary, otherwise mark before pushing to stack since it can be shared.
449
                    if !self.marked_terms.contains(arg.shared()) {
450
                        self.marked_terms.insert(arg.shared().copy());
451
                        self.marked_symbols.insert(arg.get_head_symbol().shared().copy());
452
                        self.stack.push(arg.shared().copy());
453
                    }
454
                }
455
            }
456
        }
457
    }
458

            
459
    /// Marks the given symbol as being reachable.
460
    pub fn mark_symbol(&mut self, symbol: &SymbolRef<'_>) {
461
        self.marked_symbols.insert(symbol.shared().copy());
462
    }
463
}
464

            
465
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::collections::hash_map::Entry;

    use merc_utilities::random_test;

    use crate::random_term;

    /// Checks maximal sharing: two random terms with the same textual representation
    /// must compare equal.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_maximal_sharing() {
        random_test(100, |rng| {
            // Maps the textual representation to the first term that produced it.
            let mut seen = HashMap::new();

            for _ in 0..1000 {
                let term = random_term(rng, &[("f".into(), 2), ("g".into(), 1)], &["a".to_string()], 10);

                match seen.entry(term.to_string()) {
                    Entry::Vacant(slot) => {
                        slot.insert(term);
                    }
                    Entry::Occupied(previous) => {
                        assert_eq!(
                            term,
                            *previous.get(),
                            "There is another term with the same representation"
                        );
                    }
                }
            }
        });
    }
}