1
use std::cell::UnsafeCell;
2
use std::collections::HashSet;
3
use std::fmt;
4
use std::sync::Arc;
5
use std::sync::LazyLock;
6
use std::sync::atomic::AtomicUsize;
7
use std::time::Instant;
8

            
9
use log::debug;
10

            
11
use merc_collections::ProtectionSet;
12
use merc_io::LargeFormatter;
13
use merc_sharedmutex::GlobalBfSharedMutex;
14
use merc_sharedmutex::RecursiveLockReadGuard;
15
use merc_unsafety::StablePointer;
16
use merc_utilities::debug_trace;
17

            
18
use crate::ATermIndex;
19
use crate::ATermRef;
20
use crate::Markable;
21
use crate::Symb;
22
use crate::Symbol;
23
use crate::SymbolIndex;
24
use crate::SymbolRef;
25
use crate::Term;
26
use crate::storage::ATermStorage;
27
use crate::storage::SharedTerm;
28
use crate::storage::SharedTermLookup;
29
use crate::storage::SymbolPool;
30

            
31
/// This is the global set of protection sets that are managed by the ThreadTermPool
32
pub static GLOBAL_TERM_POOL: LazyLock<GlobalBfSharedMutex<GlobalTermPool>> =
33
593
    LazyLock::new(|| GlobalBfSharedMutex::new(GlobalTermPool::new()));
34

            
35
/// Testing switch for aggressive garbage collection.
///
/// When `true`, `GlobalTermPool::trigger_garbage_collection` hands back a counter of `1`
/// instead of the number of live terms.
pub(crate) const AGGRESSIVE_GC: bool = false;
37

            
38
/// A type alias for the global term pool guard
39
pub(crate) type GlobalTermPoolGuard<'a> = RecursiveLockReadGuard<'a, GlobalTermPool>;
40

            
41
/// A type alias for deletion hooks
42
type DeletionHook = Box<dyn Fn(&ATermIndex) + Sync + Send>;
43

            
44
/// The single global (singleton) term pool.
45
pub struct GlobalTermPool {
46
    /// Unique table of all terms with stable pointers for references
47
    terms: ATermStorage,
48
    /// The symbol pool for managing function symbols.
49
    symbol_pool: SymbolPool,
50
    /// The thread-specific protection sets.
51
    thread_pools: Vec<Option<Arc<UnsafeCell<SharedTermProtection>>>>,
52

            
53
    // Data structures used for garbage collection
54
    /// Used to avoid reallocations for the markings of all terms - uses pointers as keys
55
    marked_terms: HashSet<ATermIndex>,
56
    /// Used to avoid reallocations for the markings of all symbols
57
    marked_symbols: HashSet<SymbolIndex>,
58
    /// A stack used to mark terms recursively.
59
    stack: Vec<ATermIndex>,
60

            
61
    /// Deletion hooks called whenever a term with the given head symbol is deleted.
62
    deletion_hooks: Vec<(Symbol, DeletionHook)>,
63

            
64
    /// Indicates whether automatic garbage collection is enabled.
65
    garbage_collection: bool,
66

            
67
    /// Default terms
68
    int_symbol: SymbolRef<'static>,
69
    empty_list_symbol: SymbolRef<'static>,
70
    list_symbol: SymbolRef<'static>,
71
}
72

            
73
unsafe impl Send for GlobalTermPool {}
74
unsafe impl Sync for GlobalTermPool {}
75

            
76
impl GlobalTermPool {
77
593
    fn new() -> GlobalTermPool {
78
        // Insert the default symbols.
79
593
        let symbol_pool = SymbolPool::new();
80
593
        let int_symbol = unsafe { SymbolRef::from_index(&symbol_pool.create("<aterm_int>", 0)) };
81
593
        let list_symbol = unsafe { SymbolRef::from_index(&symbol_pool.create("<list_constructor>", 2)) };
82
593
        let empty_list_symbol = unsafe { SymbolRef::from_index(&symbol_pool.create("<empty_list>", 0)) };
83

            
84
593
        GlobalTermPool {
85
593
            terms: ATermStorage::new(),
86
593
            symbol_pool,
87
593
            thread_pools: Vec::new(),
88
593
            marked_terms: HashSet::new(),
89
593
            marked_symbols: HashSet::new(),
90
593
            stack: Vec::new(),
91
593
            deletion_hooks: Vec::new(),
92
593
            garbage_collection: true,
93
593
            int_symbol,
94
593
            list_symbol,
95
593
            empty_list_symbol,
96
593
        }
97
593
    }
98

            
99
    /// Returns the number of terms in the pool.
100
56040
    pub fn len(&self) -> usize {
101
56040
        self.terms.len()
102
56040
    }
103

            
104
    /// Returns whether the term pool is empty.
105
    pub fn is_empty(&self) -> bool {
106
        self.len() == 0
107
    }
108

            
109
    /// Creates a term storing a single integer value.
110
3730244
    pub fn create_int(&self, value: usize) -> (StablePointer<SharedTerm>, bool) {
111
3730244
        let shared_term = SharedTermLookup {
112
3730244
            symbol: unsafe { SymbolRef::from_index(self.int_symbol.shared()) },
113
3730244
            arguments: &[],
114
3730244
            annotation: Some(value),
115
3730244
        };
116

            
117
3730244
        let (index, inserted) = unsafe {
118
3730244
            self.terms
119
3730244
                .insert_equiv_dst(&shared_term, SharedTerm::length_for(&shared_term), |ptr, key| {
120
1985
                    SharedTerm::construct(ptr, key)
121
1985
                })
122
        };
123

            
124
3730244
        (index, inserted)
125
3730244
    }
126

            
127
    /// Create a term from a head symbol and an iterator over its arguments
128
50185081
    pub fn create_term_array<'a, 'b, 'c, S: Symb<'a, 'b>>(
129
50185081
        &'c self,
130
50185081
        symbol: &'b S,
131
50185081
        args: &'c [ATermRef<'c>],
132
50185081
    ) -> (StablePointer<SharedTerm>, bool) {
133
50185081
        let shared_term = SharedTermLookup {
134
50185081
            symbol: SymbolRef::from_symbol(symbol),
135
50185081
            arguments: args,
136
50185081
            annotation: None,
137
50185081
        };
138

            
139
50185081
        debug_assert_eq!(
140
50185081
            symbol.shared().arity(),
141
50185081
            shared_term.arguments.len(),
142
            "The number of arguments does not match the arity of the symbol"
143
        );
144

            
145
50185081
        let (index, inserted) = unsafe {
146
50185081
            self.terms
147
50185081
                .insert_equiv_dst(&shared_term, SharedTerm::length_for(&shared_term), |ptr, key| {
148
5178407
                    SharedTerm::construct(ptr, key)
149
5178407
                })
150
        };
151

            
152
50185081
        (index, inserted)
153
50185081
    }
154

            
155
    /// Create a function symbol
156
3329864
    pub fn create_symbol<P, N>(&self, name: N, arity: usize, protect: P) -> Symbol
157
3329864
    where
158
3329864
        P: FnOnce(SymbolIndex) -> Symbol,
159
3329864
        N: Into<String> + AsRef<str>,
160
    {
161
3329864
        protect(self.symbol_pool.create(name, arity))
162
3329864
    }
163

            
164
    /// Registers a new thread term pool.
165
    ///
166
    /// # Safety
167
    ///
168
    /// Note that the returned `Arc<UnsafeCell<...>>` is not Send or Sync, so it
169
    /// *must* be protected through other means.
170
    #[allow(clippy::arc_with_non_send_sync)]
171
595
    pub(crate) fn register_thread_term_pool(&mut self) -> Arc<UnsafeCell<SharedTermProtection>> {
172
595
        let protection = Arc::new(UnsafeCell::new(SharedTermProtection {
173
595
            protection_set: ProtectionSet::new(),
174
595
            symbol_protection_set: ProtectionSet::new(),
175
595
            container_protection_set: ProtectionSet::new(),
176
595
            index: self.thread_pools.len(),
177
595
        }));
178

            
179
595
        debug!("Registered thread_local protection set(s) {}", self.thread_pools.len());
180
595
        self.thread_pools.push(Some(protection.clone()));
181

            
182
595
        protection
183
595
    }
184

            
185
    /// Deregisters a thread pool.
186
595
    pub(crate) fn deregister_thread_pool(&mut self, index: usize) {
187
595
        debug!("Removed thread_local protection set(s) {index}");
188
595
        if let Some(entry) = self.thread_pools.get_mut(index) {
189
595
            *entry = None;
190
595
        }
191
595
    }
192

            
193
    /// Triggers garbage collection if necessary and returns an updated counter for the thread local pool.
194
28020
    pub(crate) fn trigger_garbage_collection(&mut self) -> usize {
195
28020
        self.collect_garbage();
196

            
197
28020
        if AGGRESSIVE_GC {
198
            return 1;
199
28020
        }
200

            
201
28020
        self.len()
202
28020
    }
203

            
204
    /// Returns a counter for the unique numeric suffix of the given prefix.
205
1
    pub fn register_prefix(&self, prefix: &str) -> Arc<AtomicUsize> {
206
1
        self.symbol_pool.create_prefix(prefix)
207
1
    }
208

            
209
    /// Removes the registration of a prefix from the symbol pool.
210
    pub fn remove_prefix(&self, prefix: &str) {
211
        self.symbol_pool.remove_prefix(prefix)
212
    }
213

            
214
    /// Register a deletion hook that is called whenever a term is deleted with the given symbol.
215
    pub fn register_deletion_hook<F>(&mut self, symbol: SymbolRef<'static>, hook: F)
216
    where
217
        F: Fn(&ATermIndex) + Sync + Send + 'static,
218
    {
219
        self.deletion_hooks.push((symbol.protect(), Box::new(hook)));
220
    }
221

            
222
    /// Enables or disables automatic garbage collection.
223
    pub fn automatic_garbage_collection(&mut self, enabled: bool) {
224
        self.garbage_collection = enabled;
225
    }
226

            
227
    /// Collects garbage terms.
228
28020
    fn collect_garbage(&mut self) {
229
28020
        if !self.garbage_collection {
230
            // Garbage collection is disabled.
231
            return;
232
28020
        }
233

            
234
        // Clear marking data structures
235
28020
        self.marked_terms.clear();
236
28020
        self.marked_symbols.clear();
237
28020
        self.stack.clear();
238

            
239
        // Mark the default symbols
240
28020
        self.marked_symbols.insert(self.int_symbol.shared().copy());
241
28020
        self.marked_symbols.insert(self.list_symbol.shared().copy());
242
28020
        self.marked_symbols.insert(self.empty_list_symbol.shared().copy());
243

            
244
28020
        let mut marker = Marker {
245
28020
            marked_terms: &mut self.marked_terms,
246
28020
            marked_symbols: &mut self.marked_symbols,
247
28020
            stack: &mut self.stack,
248
28020
        };
249

            
250
28020
        let mark_time = Instant::now();
251

            
252
        // Loop through all protection sets and mark the terms.
253
28020
        for pool in self.thread_pools.iter().flatten() {
254
            // SAFETY: We have exclusive access to the global term pool, so no other thread can modify the protection sets.
255
28020
            let pool = unsafe { &mut *pool.get() };
256

            
257
647720
            for (_root, symbol) in pool.symbol_protection_set.iter() {
258
647720
                debug_trace!("Marking root {_root} symbol {symbol:?}");
259
647720
                // Remove all symbols that are not protected
260
647720
                marker.marked_symbols.insert(symbol.copy());
261
647720
            }
262

            
263
39942870
            for (_root, term) in pool.protection_set.iter() {
264
39942870
                debug_trace!("Marking root {_root} term {term:?}");
265
39942870
                unsafe {
266
39942870
                    ATermRef::from_index(term).mark(&mut marker);
267
39942870
                }
268
            }
269

            
270
12849560
            for (_, container) in pool.container_protection_set.iter() {
271
12849560
                container.mark(&mut marker);
272
12849560
            }
273
        }
274

            
275
28020
        let mark_time_elapsed = mark_time.elapsed();
276
28020
        let collect_time = Instant::now();
277

            
278
28020
        let num_of_terms = self.len();
279
28020
        let num_of_symbols = self.symbol_pool.len();
280

            
281
        // Delete all terms that are not marked
282
9943470
        self.terms.retain(|term| {
283
9943470
            if !self.marked_terms.contains(term) {
284
5003660
                debug_trace!("Dropping term: {:?}", term);
285

            
286
                // Call the deletion hooks for the term
287
5003660
                for (symbol, hook) in &self.deletion_hooks {
288
                    if symbol == term.symbol() {
289
                        debug_trace!("Calling deletion hook for term: {:?}", term);
290
                        hook(term);
291
                    }
292
                }
293

            
294
5003660
                return false;
295
4939810
            }
296

            
297
4939810
            true
298
9943470
        });
299

            
300
        // We ensure that every removed symbol is not used anymore.
301
1308700
        self.symbol_pool.retain(|symbol| {
302
1308700
            if !self.marked_symbols.contains(symbol) {
303
1690
                debug_trace!("Dropping symbol: {:?}", symbol);
304
1690
                return false;
305
1307010
            }
306

            
307
1307010
            true
308
1308700
        });
309

            
310
28020
        debug!(
311
            "Garbage collection: marking took {}ms, collection took {}ms, {} terms and {} symbols removed",
312
            mark_time_elapsed.as_millis(),
313
            collect_time.elapsed().as_millis(),
314
            num_of_terms - self.len(),
315
            num_of_symbols - self.symbol_pool.len()
316
        );
317

            
318
28020
        debug!("{}", self.metrics());
319

            
320
        // Print information from the protection sets.
321
28020
        for pool in self.thread_pools.iter().flatten() {
322
            // SAFETY: We have exclusive access to the global term pool, so no other thread can modify the protection sets.
323
28020
            let pool = unsafe { &mut *pool.get() };
324
28020
            debug!("{}", pool.metrics());
325
        }
326
28020
    }
327

            
328
    /// Returns the metrics of the term pool, can be formatted and written to output.
329
    pub fn metrics(&self) -> TermPoolMetrics<'_> {
330
        TermPoolMetrics(self)
331
    }
332

            
333
    /// Marks the given term as being reachable.
334
    ///
335
    /// # Safety
336
    ///
337
    /// Should only be called during garbage collection.
338
    pub unsafe fn mark_term(&mut self, term: &ATermRef<'_>) {
339
        // Ensure that the global term pool is locked for writing.
340
        let mut marker = Marker {
341
            marked_terms: &mut self.marked_terms,
342
            marked_symbols: &mut self.marked_symbols,
343
            stack: &mut self.stack,
344
        };
345
        term.mark(&mut marker);
346
    }
347

            
348
    /// Returns integer function symbol.
349
595
    pub(crate) fn get_int_symbol(&self) -> &SymbolRef<'static> {
350
595
        &self.int_symbol
351
595
    }
352

            
353
    /// Returns integer function symbol.
354
595
    pub(crate) fn get_list_symbol(&self) -> &SymbolRef<'static> {
355
595
        &self.list_symbol
356
595
    }
357

            
358
    /// Returns integer function symbol.
359
595
    pub(crate) fn get_empty_list_symbol(&self) -> &SymbolRef<'static> {
360
595
        &self.empty_list_symbol
361
595
    }
362
}
363

            
364
/// Displayable size summary of a [`GlobalTermPool`], obtained via `GlobalTermPool::metrics`.
pub struct TermPoolMetrics<'a>(&'a GlobalTermPool);
365

            
366
impl fmt::Display for TermPoolMetrics<'_> {
367
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
368
        write!(
369
            f,
370
            "There are {} terms, and {} symbols",
371
            self.0.terms.len(),
372
            self.0.symbol_pool.len()
373
        )
374
    }
375
}
376

            
377
pub struct SharedTermProtection {
378
    /// Protection set for terms
379
    pub protection_set: ProtectionSet<ATermIndex>,
380
    /// Protection set to prevent garbage collection of symbols
381
    pub symbol_protection_set: ProtectionSet<SymbolIndex>,
382
    /// Protection set for containers
383
    pub container_protection_set: ProtectionSet<Arc<dyn Markable + Sync + Send>>,
384
    /// Index in global pool's thread pools list
385
    pub index: usize,
386
}
387

            
388
impl SharedTermProtection {
389
    /// Returns the metrics of the term pool, can be formatted and written to output.
390
    pub fn metrics(&self) -> ProtectionMetrics<'_> {
391
        ProtectionMetrics(self)
392
    }
393
}
394

            
395
/// A struct that can be used to print the performance of the protection sets.
396
pub struct ProtectionMetrics<'a>(&'a SharedTermProtection);
397

            
398
impl fmt::Display for ProtectionMetrics<'_> {
399
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
400
        writeln!(
401
            f,
402
            "Protection set {} has {} roots, max {} and {} insertions",
403
            self.0.index,
404
            LargeFormatter(self.0.protection_set.len()),
405
            LargeFormatter(self.0.protection_set.maximum_size()),
406
            LargeFormatter(self.0.protection_set.number_of_insertions())
407
        )?;
408

            
409
        writeln!(
410
            f,
411
            "Containers: {} roots, max {} and {} insertions",
412
            LargeFormatter(self.0.container_protection_set.len()),
413
            LargeFormatter(self.0.container_protection_set.maximum_size()),
414
            LargeFormatter(self.0.container_protection_set.number_of_insertions()),
415
        )?;
416

            
417
        write!(
418
            f,
419
            "Symbols: {} roots, max {} and {} insertions",
420
            LargeFormatter(self.0.symbol_protection_set.len()),
421
            LargeFormatter(self.0.symbol_protection_set.maximum_size()),
422
            LargeFormatter(self.0.symbol_protection_set.number_of_insertions()),
423
        )
424
    }
425
}
426

            
427
/// Helper struct to pass private data required to mark term recursively.
428
pub struct Marker<'a> {
429
    marked_terms: &'a mut HashSet<ATermIndex>,
430
    marked_symbols: &'a mut HashSet<SymbolIndex>,
431
    stack: &'a mut Vec<ATermIndex>,
432
}
433

            
434
impl Marker<'_> {
435
    // Marks the given term as being reachable.
436
73862970
    pub fn mark(&mut self, term: &ATermRef<'_>) {
437
73862970
        if !self.marked_terms.contains(term.shared()) {
438
1291320
            self.stack.push(term.shared().copy());
439

            
440
6231130
            while let Some(term) = self.stack.pop() {
441
                // Each term should be marked.
442
4939810
                self.marked_terms.insert(term.copy());
443

            
444
                // Mark the function symbol.
445
4939810
                self.marked_symbols.insert(term.symbol().shared().copy());
446

            
447
                // For some terms, such as ATermInt, we must ONLY consider the valid arguments (indicated by the arity)
448
12306110
                for arg in term.arguments()[0..term.symbol().arity()].iter() {
449
                    // Skip if unnecessary, otherwise mark before pushing to stack since it can be shared.
450
12306110
                    if !self.marked_terms.contains(arg.shared()) {
451
3648490
                        self.marked_terms.insert(arg.shared().copy());
452
3648490
                        self.marked_symbols.insert(arg.get_head_symbol().shared().copy());
453
3648490
                        self.stack.push(arg.shared().copy());
454
8657620
                    }
455
                }
456
            }
457
72571650
        }
458
73862970
    }
459

            
460
    /// Marks the given symbol as being reachable.
461
1550
    pub fn mark_symbol(&mut self, symbol: &SymbolRef<'_>) {
462
1550
        self.marked_symbols.insert(symbol.shared().copy());
463
1550
    }
464
}
465

            
466
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::collections::hash_map::Entry;

    use merc_utilities::random_test;

    use crate::random_term;

    /// Two randomly generated terms with the same printed form must be the same shared term.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_maximal_sharing() {
        random_test(100, |rng| {
            let mut seen = HashMap::new();

            for _ in 0..1000 {
                let term = random_term(rng, &[("f".into(), 2), ("g".into(), 1)], &["a".to_string()], 10);

                match seen.entry(term.to_string()) {
                    Entry::Occupied(existing) => {
                        assert_eq!(term, *existing.get(), "There is another term with the same representation");
                    }
                    Entry::Vacant(slot) => {
                        slot.insert(term);
                    }
                }
            }
        });
    }
}