1
use std::cell::UnsafeCell;
2
use std::collections::HashSet;
3
use std::fmt;
4
use std::sync::Arc;
5
use std::sync::LazyLock;
6
use std::sync::atomic::AtomicUsize;
7
use std::time::Instant;
8

            
9
use log::debug;
10

            
11
use merc_collections::ProtectionSet;
12
use merc_io::LargeFormatter;
13
use merc_sharedmutex::GlobalBfSharedMutex;
14
use merc_sharedmutex::RecursiveLockReadGuard;
15
use merc_unsafety::StablePointer;
16
use merc_utilities::debug_trace;
17

            
18
use crate::ATermIndex;
19
use crate::ATermRef;
20
use crate::Markable;
21
use crate::Symb;
22
use crate::Symbol;
23
use crate::SymbolIndex;
24
use crate::SymbolRef;
25
use crate::Term;
26
use crate::storage::ATermStorage;
27
use crate::storage::SharedTerm;
28
use crate::storage::SharedTermLookup;
29
use crate::storage::SymbolPool;
30

            
31
/// This is the global set of protection sets that are managed by the ThreadTermPool
32
pub static GLOBAL_TERM_POOL: LazyLock<GlobalBfSharedMutex<GlobalTermPool>> =
33
583
    LazyLock::new(|| GlobalBfSharedMutex::new(GlobalTermPool::new()));
34

            
35
/// Enables aggressive garbage collection, which is used for testing.
36
pub(crate) const AGGRESSIVE_GC: bool = false;
37

            
38
/// A type alias for the global term pool guard
39
pub(crate) type GlobalTermPoolGuard<'a> = RecursiveLockReadGuard<'a, GlobalTermPool>;
40

            
41
/// A type alias for deletion hooks
42
type DeletionHook = Box<dyn Fn(&ATermIndex) + Sync + Send>;
43

            
44
/// The single global (singleton) term pool.
45
pub struct GlobalTermPool {
46
    /// Unique table of all terms with stable pointers for references
47
    terms: ATermStorage,
48
    /// The symbol pool for managing function symbols.
49
    symbol_pool: SymbolPool,
50
    /// The thread-specific protection sets.
51
    thread_pools: Vec<Option<Arc<UnsafeCell<SharedTermProtection>>>>,
52

            
53
    // Data structures used for garbage collection
54
    /// Used to avoid reallocations for the markings of all terms - uses pointers as keys
55
    marked_terms: HashSet<ATermIndex>,
56
    /// Used to avoid reallocations for the markings of all symbols
57
    marked_symbols: HashSet<SymbolIndex>,
58
    /// A stack used to mark terms recursively.
59
    stack: Vec<ATermIndex>,
60

            
61
    /// Deletion hooks called whenever a term with the given head symbol is deleted.
62
    deletion_hooks: Vec<(Symbol, DeletionHook)>,
63

            
64
    /// Indicates whether automatic garbage collection is enabled.
65
    garbage_collection: bool,
66

            
67
    /// Default terms
68
    int_symbol: SymbolRef<'static>,
69
    empty_list_symbol: SymbolRef<'static>,
70
    list_symbol: SymbolRef<'static>,
71
}
72

            
73
unsafe impl Send for GlobalTermPool {}
74
unsafe impl Sync for GlobalTermPool {}
75

            
76
impl GlobalTermPool {
77
583
    fn new() -> GlobalTermPool {
78
        // Insert the default symbols.
79
583
        let symbol_pool = SymbolPool::new();
80
583
        let int_symbol = unsafe { SymbolRef::from_index(&symbol_pool.create("<aterm_int>", 0)) };
81
583
        let list_symbol = unsafe { SymbolRef::from_index(&symbol_pool.create("<list_constructor>", 2)) };
82
583
        let empty_list_symbol = unsafe { SymbolRef::from_index(&symbol_pool.create("<empty_list>", 0)) };
83

            
84
583
        GlobalTermPool {
85
583
            terms: ATermStorage::new(),
86
583
            symbol_pool,
87
583
            thread_pools: Vec::new(),
88
583
            marked_terms: HashSet::new(),
89
583
            marked_symbols: HashSet::new(),
90
583
            stack: Vec::new(),
91
583
            deletion_hooks: Vec::new(),
92
583
            garbage_collection: true,
93
583
            int_symbol,
94
583
            list_symbol,
95
583
            empty_list_symbol,
96
583
        }
97
583
    }
98

            
99
    /// Returns the number of terms in the pool.
100
56040
    pub fn len(&self) -> usize {
101
56040
        self.terms.len()
102
56040
    }
103

            
104
    /// Returns whether the term pool is empty.
105
    pub fn is_empty(&self) -> bool {
106
        self.len() == 0
107
    }
108

            
109
    /// Creates a term storing a single integer value.
110
3706324
    pub fn create_int(&self, value: usize) -> (StablePointer<SharedTerm>, bool) {
111
3706324
        let shared_term = SharedTermLookup {
112
3706324
            symbol: unsafe { SymbolRef::from_index(self.int_symbol.shared()) },
113
3706324
            arguments: &[],
114
3706324
            annotation: Some(value),
115
3706324
        };
116

            
117
3706324
        let (index, inserted) = unsafe {
118
3706324
            self.terms
119
3706324
                .insert_equiv_dst(&shared_term, SharedTerm::length_for(&shared_term), |ptr, key| {
120
1985
                    SharedTerm::construct(ptr, key)
121
1985
                })
122
        };
123

            
124
3706324
        (index, inserted)
125
3706324
    }
126

            
127
    /// Create a term from a head symbol and an iterator over its arguments
128
50063311
    pub fn create_term_array<'a, 'b, 'c>(
129
50063311
        &'c self,
130
50063311
        symbol: &'b impl Symb<'a, 'b>,
131
50063311
        args: &'c [ATermRef<'c>],
132
50063311
    ) -> (StablePointer<SharedTerm>, bool) {
133
50063311
        let shared_term = SharedTermLookup {
134
50063311
            symbol: SymbolRef::from_symbol(symbol),
135
50063311
            arguments: args,
136
50063311
            annotation: None,
137
50063311
        };
138

            
139
50063311
        debug_assert_eq!(
140
50063311
            symbol.shared().arity(),
141
50063311
            shared_term.arguments.len(),
142
            "The number of arguments does not match the arity of the symbol"
143
        );
144

            
145
50063311
        let (index, inserted) = unsafe {
146
50063311
            self.terms
147
50063311
                .insert_equiv_dst(&shared_term, SharedTerm::length_for(&shared_term), |ptr, key| {
148
5176489
                    SharedTerm::construct(ptr, key)
149
5176489
                })
150
        };
151

            
152
50063311
        (index, inserted)
153
50063311
    }
154

            
155
    /// Create a function symbol
156
3288484
    pub fn create_symbol<P>(&self, name: impl Into<String> + AsRef<str>, arity: usize, protect: P) -> Symbol
157
3288484
    where
158
3288484
        P: FnOnce(SymbolIndex) -> Symbol,
159
    {
160
3288484
        protect(self.symbol_pool.create(name, arity))
161
3288484
    }
162

            
163
    /// Registers a new thread term pool.
164
    ///
165
    /// # Safety
166
    ///
167
    /// Note that the returned `Arc<UnsafeCell<...>>` is not Send or Sync, so it
168
    /// *must* be protected through other means.
169
    #[allow(clippy::arc_with_non_send_sync)]
170
585
    pub(crate) fn register_thread_term_pool(&mut self) -> Arc<UnsafeCell<SharedTermProtection>> {
171
585
        let protection = Arc::new(UnsafeCell::new(SharedTermProtection {
172
585
            protection_set: ProtectionSet::new(),
173
585
            symbol_protection_set: ProtectionSet::new(),
174
585
            container_protection_set: ProtectionSet::new(),
175
585
            index: self.thread_pools.len(),
176
585
        }));
177

            
178
585
        debug!("Registered thread_local protection set(s) {}", self.thread_pools.len());
179
585
        self.thread_pools.push(Some(protection.clone()));
180

            
181
585
        protection
182
585
    }
183

            
184
    /// Deregisters a thread pool.
185
585
    pub(crate) fn deregister_thread_pool(&mut self, index: usize) {
186
585
        debug!("Removed thread_local protection set(s) {index}");
187
585
        if let Some(entry) = self.thread_pools.get_mut(index) {
188
585
            *entry = None;
189
585
        }
190
585
    }
191

            
192
    /// Triggers garbage collection if necessary and returns an updated counter for the thread local pool.
193
28020
    pub(crate) fn trigger_garbage_collection(&mut self) -> usize {
194
28020
        self.collect_garbage();
195

            
196
28020
        if AGGRESSIVE_GC {
197
            return 1;
198
28020
        }
199

            
200
28020
        self.len()
201
28020
    }
202

            
203
    /// Returns a counter for the unique numeric suffix of the given prefix.
204
1
    pub fn register_prefix(&self, prefix: &str) -> Arc<AtomicUsize> {
205
1
        self.symbol_pool.create_prefix(prefix)
206
1
    }
207

            
208
    /// Removes the registration of a prefix from the symbol pool.
209
    pub fn remove_prefix(&self, prefix: &str) {
210
        self.symbol_pool.remove_prefix(prefix)
211
    }
212

            
213
    /// Register a deletion hook that is called whenever a term is deleted with the given symbol.
214
    pub fn register_deletion_hook<F>(&mut self, symbol: SymbolRef<'static>, hook: F)
215
    where
216
        F: Fn(&ATermIndex) + Sync + Send + 'static,
217
    {
218
        self.deletion_hooks.push((symbol.protect(), Box::new(hook)));
219
    }
220

            
221
    /// Enables or disables automatic garbage collection.
222
    pub fn automatic_garbage_collection(&mut self, enabled: bool) {
223
        self.garbage_collection = enabled;
224
    }
225

            
226
    /// Collects garbage terms.
227
28020
    fn collect_garbage(&mut self) {
228
28020
        if !self.garbage_collection {
229
            // Garbage collection is disabled.
230
            return;
231
28020
        }
232

            
233
        // Clear marking data structures
234
28020
        self.marked_terms.clear();
235
28020
        self.marked_symbols.clear();
236
28020
        self.stack.clear();
237

            
238
        // Mark the default symbols
239
28020
        self.marked_symbols.insert(self.int_symbol.shared().copy());
240
28020
        self.marked_symbols.insert(self.list_symbol.shared().copy());
241
28020
        self.marked_symbols.insert(self.empty_list_symbol.shared().copy());
242

            
243
28020
        let mut marker = Marker {
244
28020
            marked_terms: &mut self.marked_terms,
245
28020
            marked_symbols: &mut self.marked_symbols,
246
28020
            stack: &mut self.stack,
247
28020
        };
248

            
249
28020
        let mark_time = Instant::now();
250

            
251
        // Loop through all protection sets and mark the terms.
252
28020
        for pool in self.thread_pools.iter().flatten() {
253
            // SAFETY: We have exclusive access to the global term pool, so no other thread can modify the protection sets.
254
28020
            let pool = unsafe { &mut *pool.get() };
255

            
256
311480
            for (_root, symbol) in pool.symbol_protection_set.iter() {
257
311480
                debug_trace!("Marking root {_root} symbol {symbol:?}");
258
311480
                // Remove all symbols that are not protected
259
311480
                marker.marked_symbols.insert(symbol.copy());
260
311480
            }
261

            
262
39942990
            for (_root, term) in pool.protection_set.iter() {
263
39942990
                debug_trace!("Marking root {_root} term {term:?}");
264
39942990
                unsafe {
265
39942990
                    ATermRef::from_index(term).mark(&mut marker);
266
39942990
                }
267
            }
268

            
269
12849560
            for (_, container) in pool.container_protection_set.iter() {
270
12849560
                container.mark(&mut marker);
271
12849560
            }
272
        }
273

            
274
28020
        let mark_time_elapsed = mark_time.elapsed();
275
28020
        let collect_time = Instant::now();
276

            
277
28020
        let num_of_terms = self.len();
278
28020
        let num_of_symbols = self.symbol_pool.len();
279

            
280
        // Delete all terms that are not marked
281
9939580
        self.terms.retain(|term| {
282
9939580
            if !self.marked_terms.contains(term) {
283
5001660
                debug_trace!("Dropping term: {:?}", term);
284

            
285
                // Call the deletion hooks for the term
286
5001660
                for (symbol, hook) in &self.deletion_hooks {
287
                    if symbol == term.symbol() {
288
                        debug_trace!("Calling deletion hook for term: {:?}", term);
289
                        hook(term);
290
                    }
291
                }
292

            
293
5001660
                return false;
294
4937920
            }
295

            
296
4937920
            true
297
9939580
        });
298

            
299
        // We ensure that every removed symbol is not used anymore.
300
972480
        self.symbol_pool.retain(|symbol| {
301
972480
            if !self.marked_symbols.contains(symbol) {
302
1690
                debug_trace!("Dropping symbol: {:?}", symbol);
303
1690
                return false;
304
970790
            }
305

            
306
970790
            true
307
972480
        });
308

            
309
28020
        debug!(
310
            "Garbage collection: marking took {}ms, collection took {}ms, {} terms and {} symbols removed",
311
            mark_time_elapsed.as_millis(),
312
            collect_time.elapsed().as_millis(),
313
            num_of_terms - self.len(),
314
            num_of_symbols - self.symbol_pool.len()
315
        );
316

            
317
28020
        debug!("{}", self.metrics());
318

            
319
        // Print information from the protection sets.
320
28020
        for pool in self.thread_pools.iter().flatten() {
321
            // SAFETY: We have exclusive access to the global term pool, so no other thread can modify the protection sets.
322
28020
            let pool = unsafe { &mut *pool.get() };
323
28020
            debug!("{}", pool.metrics());
324
        }
325
28020
    }
326

            
327
    /// Returns the metrics of the term pool, can be formatted and written to output.
328
    pub fn metrics(&self) -> TermPoolMetrics<'_> {
329
        TermPoolMetrics(self)
330
    }
331

            
332
    /// Marks the given term as being reachable.
333
    ///
334
    /// # Safety
335
    ///
336
    /// Should only be called during garbage collection.
337
    pub unsafe fn mark_term(&mut self, term: &ATermRef<'_>) {
338
        // Ensure that the global term pool is locked for writing.
339
        let mut marker = Marker {
340
            marked_terms: &mut self.marked_terms,
341
            marked_symbols: &mut self.marked_symbols,
342
            stack: &mut self.stack,
343
        };
344
        term.mark(&mut marker);
345
    }
346

            
347
    /// Returns integer function symbol.
348
585
    pub(crate) fn get_int_symbol(&self) -> &SymbolRef<'static> {
349
585
        &self.int_symbol
350
585
    }
351

            
352
    /// Returns integer function symbol.
353
585
    pub(crate) fn get_list_symbol(&self) -> &SymbolRef<'static> {
354
585
        &self.list_symbol
355
585
    }
356

            
357
    /// Returns integer function symbol.
358
585
    pub(crate) fn get_empty_list_symbol(&self) -> &SymbolRef<'static> {
359
585
        &self.empty_list_symbol
360
585
    }
361
}
362

            
363
pub struct TermPoolMetrics<'a>(&'a GlobalTermPool);
364

            
365
impl fmt::Display for TermPoolMetrics<'_> {
366
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
367
        write!(
368
            f,
369
            "There are {} terms, and {} symbols",
370
            self.0.terms.len(),
371
            self.0.symbol_pool.len()
372
        )
373
    }
374
}
375

            
376
pub struct SharedTermProtection {
377
    /// Protection set for terms
378
    pub protection_set: ProtectionSet<ATermIndex>,
379
    /// Protection set to prevent garbage collection of symbols
380
    pub symbol_protection_set: ProtectionSet<SymbolIndex>,
381
    /// Protection set for containers
382
    pub container_protection_set: ProtectionSet<Arc<dyn Markable + Sync + Send>>,
383
    /// Index in global pool's thread pools list
384
    pub index: usize,
385
}
386

            
387
impl SharedTermProtection {
388
    /// Returns the metrics of the term pool, can be formatted and written to output.
389
    pub fn metrics(&self) -> ProtectionMetrics<'_> {
390
        ProtectionMetrics(self)
391
    }
392
}
393

            
394
/// A struct that can be used to print the performance of the protection sets.
395
pub struct ProtectionMetrics<'a>(&'a SharedTermProtection);
396

            
397
impl fmt::Display for ProtectionMetrics<'_> {
398
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
399
        writeln!(
400
            f,
401
            "Protection set {} has {} roots, max {} and {} insertions",
402
            self.0.index,
403
            LargeFormatter(self.0.protection_set.len()),
404
            LargeFormatter(self.0.protection_set.maximum_size()),
405
            LargeFormatter(self.0.protection_set.number_of_insertions())
406
        )?;
407

            
408
        writeln!(
409
            f,
410
            "Containers: {} roots, max {} and {} insertions",
411
            LargeFormatter(self.0.container_protection_set.len()),
412
            LargeFormatter(self.0.container_protection_set.maximum_size()),
413
            LargeFormatter(self.0.container_protection_set.number_of_insertions()),
414
        )?;
415

            
416
        write!(
417
            f,
418
            "Symbols: {} roots, max {} and {} insertions",
419
            LargeFormatter(self.0.symbol_protection_set.len()),
420
            LargeFormatter(self.0.symbol_protection_set.maximum_size()),
421
            LargeFormatter(self.0.symbol_protection_set.number_of_insertions()),
422
        )
423
    }
424
}
425

            
426
/// Helper struct to pass private data required to mark term recursively.
427
pub struct Marker<'a> {
428
    marked_terms: &'a mut HashSet<ATermIndex>,
429
    marked_symbols: &'a mut HashSet<SymbolIndex>,
430
    stack: &'a mut Vec<ATermIndex>,
431
}
432

            
433
impl Marker<'_> {
434
    // Marks the given term as being reachable.
435
73862860
    pub fn mark(&mut self, term: &ATermRef<'_>) {
436
73862860
        if !self.marked_terms.contains(term.shared()) {
437
1265810
            self.stack.push(term.shared().copy());
438

            
439
6203730
            while let Some(term) = self.stack.pop() {
440
                // Each term should be marked.
441
4937920
                self.marked_terms.insert(term.copy());
442

            
443
                // Mark the function symbol.
444
4937920
                self.marked_symbols.insert(term.symbol().shared().copy());
445

            
446
                // For some terms, such as ATermInt, we must ONLY consider the valid arguments (indicated by the arity)
447
12298050
                for arg in term.arguments()[0..term.symbol().arity()].iter() {
448
                    // Skip if unnecessary, otherwise mark before pushing to stack since it can be shared.
449
12298050
                    if !self.marked_terms.contains(arg.shared()) {
450
3672110
                        self.marked_terms.insert(arg.shared().copy());
451
3672110
                        self.marked_symbols.insert(arg.get_head_symbol().shared().copy());
452
3672110
                        self.stack.push(arg.shared().copy());
453
8625940
                    }
454
                }
455
            }
456
72597050
        }
457
73862860
    }
458

            
459
    /// Marks the given symbol as being reachable.
460
1550
    pub fn mark_symbol(&mut self, symbol: &SymbolRef<'_>) {
461
1550
        self.marked_symbols.insert(symbol.shared().copy());
462
1550
    }
463
}
464

            
465
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::collections::hash_map::Entry;

    use merc_utilities::random_test;

    use crate::random_term;

    /// Checks maximal sharing: two random terms with the same textual
    /// representation must compare equal.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_maximal_sharing() {
        random_test(100, |rng| {
            // Maps each printed representation to the first term seen with it.
            let mut seen = HashMap::new();

            for _ in 0..1000 {
                let term = random_term(rng, &[("f".into(), 2), ("g".into(), 1)], &["a".to_string()], 10);

                match seen.entry(term.to_string()) {
                    Entry::Occupied(previous) => {
                        assert_eq!(term, *previous.get(), "There is another term with the same representation");
                    }
                    Entry::Vacant(slot) => {
                        slot.insert(term);
                    }
                }
            }
        });
    }
}