1
use std::cell::UnsafeCell;
2
use std::fmt;
3
use std::sync::Arc;
4
use std::sync::LazyLock;
5
use std::sync::Mutex;
6
use std::sync::atomic::AtomicUsize;
7
use std::time::Instant;
8

            
9
use log::debug;
10
use rustc_hash::FxHashSet;
11

            
12
use merc_io::LargeFormatter;
13
use merc_sharedmutex::GlobalBfSharedMutex;
14
use merc_sharedmutex::RecursiveLockReadGuard;
15
use merc_unsafety::ProtectionSet;
16
use merc_unsafety::StablePointer;
17
use merc_utilities::debug_trace;
18

            
19
use crate::ATermIndex;
20
use crate::ATermRef;
21
use crate::Markable;
22
use crate::Symb;
23
use crate::Symbol;
24
use crate::SymbolIndex;
25
use crate::SymbolRef;
26
use crate::Term;
27
use crate::storage::ATermStorage;
28
use crate::storage::SharedTerm;
29
use crate::storage::SymbolPool;
30

            
31
/// This is the global set of protection sets that are managed by the [crate::storage::ThreadTermPool].
32
pub static GLOBAL_TERM_POOL: LazyLock<GlobalBfSharedMutex<GlobalTermPool>> =
33
959
    LazyLock::new(|| GlobalBfSharedMutex::new(GlobalTermPool::new()));
34

            
35
/// Enables aggressive garbage collection, which is used for testing.
36
pub(crate) const AGGRESSIVE_GC: bool = false;
37

            
38
/// A type alias for the global term pool guard
39
pub(crate) type GlobalTermPoolGuard<'a> = RecursiveLockReadGuard<'a, GlobalTermPool>;
40

            
41
/// A type alias for deletion hooks
42
type DeletionHook = Box<dyn Fn(&ATermIndex) + Sync + Send>;
43

            
44
/// The single global (singleton) term pool, accessed via [GLOBAL_TERM_POOL].
45
pub struct GlobalTermPool {
46
    /// Unique table of all terms with stable pointers for references
47
    terms: ATermStorage,
48
    /// The symbol pool for managing function symbols.
49
    symbol_pool: SymbolPool,
50
    /// The thread-specific protection sets.
51
    thread_pools: ThreadPoolList,
52
    /// A separate protection set for sendable terms, see [crate::ATermSend].
53
    send_term_protection_sets: Vec<Option<Arc<Mutex<ProtectionSet<ATermIndex>>>>>,
54

            
55
    // Data structures used for garbage collection
56
    /// Used to avoid reallocations for the markings of all terms - uses pointers as keys
57
    marked_terms: FxHashSet<ATermIndex>,
58
    /// Used to avoid reallocations for the markings of all symbols
59
    marked_symbols: FxHashSet<SymbolIndex>,
60
    /// A stack used to mark terms recursively.
61
    stack: Vec<ATermIndex>,
62

            
63
    /// Deletion hooks called whenever a term with the given head symbol is deleted.
64
    deletion_hooks: Vec<(Symbol, DeletionHook)>,
65

            
66
    /// Indicates whether automatic garbage collection is enabled.
67
    garbage_collection: bool,
68

            
69
    /// Default terms
70
    int_symbol: SymbolRef<'static>,
71
    empty_list_symbol: SymbolRef<'static>,
72
    list_symbol: SymbolRef<'static>,
73
}
74

            
75
impl GlobalTermPool {
76
959
    fn new() -> GlobalTermPool {
77
        // Insert the default symbols, mirrors the symbols defined in mCRL2.
78
959
        let symbol_pool = SymbolPool::new();
79
959
        let int_symbol = unsafe { SymbolRef::from_index(&symbol_pool.create("<aterm_int>", 0)) };
80
959
        let list_symbol = unsafe { SymbolRef::from_index(&symbol_pool.create("<list_constructor>", 2)) };
81
959
        let empty_list_symbol = unsafe { SymbolRef::from_index(&symbol_pool.create("<empty_list>", 0)) };
82

            
83
959
        GlobalTermPool {
84
959
            terms: ATermStorage::new(),
85
959
            symbol_pool,
86
959
            thread_pools: ThreadPoolList(Vec::new()),
87
959
            send_term_protection_sets: Vec::new(),
88
959
            marked_terms: FxHashSet::default(),
89
959
            marked_symbols: FxHashSet::default(),
90
959
            stack: Vec::new(),
91
959
            deletion_hooks: Vec::new(),
92
959
            garbage_collection: true,
93
959
            int_symbol,
94
959
            list_symbol,
95
959
            empty_list_symbol,
96
959
        }
97
959
    }
98

            
99
    /// Returns the number of terms in the pool.
100
776814
    pub fn len(&self) -> usize {
101
776814
        self.terms.len()
102
776814
    }
103

            
104
    /// Returns whether the term pool is empty.
105
    pub fn is_empty(&self) -> bool {
106
        self.len() == 0
107
    }
108

            
109
    /// Creates a term storing a single integer value.
110
6438188
    pub fn create_int(&self, value: usize) -> (StablePointer<SharedTerm>, bool) {
111
6438188
        let (index, inserted) = unsafe {
112
6438188
            self.terms
113
6438188
                .insert_int_term(SymbolRef::from_index(self.int_symbol.shared()), value)
114
6438188
        };
115

            
116
6438188
        (index, inserted)
117
6438188
    }
118

            
119
    /// Create a term from a head symbol and an iterator over its arguments
120
80303400
    pub fn create_term_array<'a, 'b, 'c, S: Symb<'a, 'b>>(
121
80303400
        &'c self,
122
80303400
        symbol: &'b S,
123
80303400
        args: &'c [ATermRef<'c>],
124
80303400
    ) -> (StablePointer<SharedTerm>, bool) {
125
80303400
        self.terms.insert(symbol, args)
126
80303400
    }
127

            
128
    /// Create a function symbol
129
4981772
    pub fn create_symbol<P, N>(&self, name: N, arity: usize, protect: P) -> Symbol
130
4981772
    where
131
4981772
        P: FnOnce(SymbolIndex) -> Symbol,
132
4981772
        N: Into<String> + AsRef<str>,
133
    {
134
4981772
        protect(self.symbol_pool.create(name, arity))
135
4981772
    }
136

            
137
    /// Registers a new thread term pool.
138
    ///
139
    /// # Safety
140
    ///
141
    /// Note that the returned `Arc<UnsafeCell<...>>` is not Send or Sync, so it
142
    /// *must* be protected through other means.
143
    #[allow(clippy::arc_with_non_send_sync)]
144
962
    pub(crate) fn register_thread_term_pool(
145
962
        &mut self,
146
962
    ) -> (
147
962
        Arc<UnsafeCell<SharedTermProtection>>,
148
962
        Arc<Mutex<ProtectionSet<ATermIndex>>>,
149
962
    ) {
150
962
        let protection = Arc::new(UnsafeCell::new(SharedTermProtection {
151
962
            term_protection_set: ProtectionSet::new(),
152
962
            symbol_protection_set: ProtectionSet::new(),
153
962
            container_protection_set: ProtectionSet::new(),
154
962
            index: self.thread_pools.len(),
155
962
        }));
156

            
157
962
        debug!("Registered thread_local protection set(s) {}", self.thread_pools.len());
158
962
        self.thread_pools.push(Some(protection.clone()));
159

            
160
962
        let protection_set = Arc::new(Mutex::new(ProtectionSet::new()));
161
962
        self.send_term_protection_sets.push(Some(protection_set.clone()));
162

            
163
962
        (protection, protection_set)
164
962
    }
165

            
166
    /// Deregisters a thread pool.
167
962
    pub(crate) fn deregister_thread_pool(&mut self, index: usize) {
168
962
        debug!("Removed thread_local protection set(s) {index}");
169
962
        if let Some(entry) = self.thread_pools.get_mut(index) {
170
962
            *entry = None;
171
962
        }
172
962
    }
173

            
174
    /// Triggers garbage collection if necessary and returns an updated counter for the thread local pool.
175
388407
    pub(crate) fn trigger_garbage_collection(&mut self) -> usize {
176
388407
        if self.garbage_collection {
177
388407
            // Garbage collection is enabled.
178
388407
            self.collect_garbage();
179
388407
        }
180

            
181
388407
        if AGGRESSIVE_GC {
182
            return 1;
183
388407
        }
184

            
185
388407
        self.len()
186
388407
    }
187

            
188
    /// Returns a counter for the unique numeric suffix of the given prefix.
189
1
    pub fn register_prefix(&self, prefix: &str) -> Arc<AtomicUsize> {
190
1
        self.symbol_pool.create_prefix(prefix)
191
1
    }
192

            
193
    /// Removes the registration of a prefix from the symbol pool.
194
    pub fn remove_prefix(&self, prefix: &str) {
195
        self.symbol_pool.remove_prefix(prefix)
196
    }
197

            
198
    /// Register a deletion hook that is called whenever a term is deleted with the given symbol.
199
    pub fn register_deletion_hook<F>(&mut self, symbol: SymbolRef<'static>, hook: F)
200
    where
201
        F: Fn(&ATermIndex) + Sync + Send + 'static,
202
    {
203
        self.deletion_hooks.push((symbol.protect(), Box::new(hook)));
204
    }
205

            
206
    /// Enables or disables automatic garbage collection.
207
    pub fn automatic_garbage_collection(&mut self, enabled: bool) {
208
        self.garbage_collection = enabled;
209
    }
210

            
211
    /// Collects garbage terms.
212
388407
    pub fn collect_garbage(&mut self) {
213
        // Mark the default symbols
214
388407
        self.marked_symbols.insert(self.int_symbol.shared().copy());
215
388407
        self.marked_symbols.insert(self.list_symbol.shared().copy());
216
388407
        self.marked_symbols.insert(self.empty_list_symbol.shared().copy());
217

            
218
388407
        let mut marker = Marker {
219
388407
            marked_terms: &mut self.marked_terms,
220
388407
            marked_symbols: &mut self.marked_symbols,
221
388407
            stack: &mut self.stack,
222
388407
        };
223

            
224
388407
        let mark_time = Instant::now();
225

            
226
        // Loop through all protection sets and mark the terms.
227
388407
        for pool in self.thread_pools.iter().flatten() {
228
            // SAFETY: We have exclusive access to the global term pool, so no other thread can modify the protection sets.
229
388407
            let pool = unsafe { &mut *pool.get() };
230

            
231
7220503
            for (_root, symbol) in pool.symbol_protection_set.iter() {
232
7220103
                debug_trace!("Marking root {_root} symbol {symbol:?}");
233
7220103
                // Remove all symbols that are not protected
234
7220103
                marker.marked_symbols.insert(symbol.copy());
235
7220103
            }
236

            
237
65923831
            for (_root, term) in pool.term_protection_set.iter() {
238
65923831
                debug_trace!("Marking root {_root} term {term:?}");
239
65923831
                unsafe {
240
65923831
                    ATermRef::from_index(term).mark(&mut marker);
241
65923831
                }
242
            }
243

            
244
21242512
            for (_, container) in pool.container_protection_set.iter() {
245
21242512
                container.mark(&mut marker);
246
21242512
            }
247
        }
248

            
249
388407
        for pool in self.send_term_protection_sets.iter().flatten() {
250
388407
            let pool = pool.lock().expect("Lock poisoned!");
251
388407
            for (_root, term) in pool.iter() {
252
                debug_trace!("Marking sendable term {term:?}");
253
                unsafe {
254
                    ATermRef::from_index(term).mark(&mut marker);
255
                }
256
            }
257
        }
258

            
259
388407
        let mark_time_elapsed = mark_time.elapsed();
260
388407
        let collect_time = Instant::now();
261

            
262
388407
        let num_of_terms = self.len();
263
388407
        let num_of_symbols = self.symbol_pool.len();
264

            
265
        // Delete all terms that are not marked
266
27364797
        self.terms.retain(|term| {
267
27364797
            if !self.marked_terms.contains(term) {
268
13735528
                debug_trace!("Dropping term: {:?}", term);
269

            
270
                // Call the deletion hooks for the term
271
13735528
                for (symbol, hook) in &self.deletion_hooks {
272
                    if symbol == term.symbol() {
273
                        debug_trace!("Calling deletion hook for term: {:?}", term);
274
                        hook(term);
275
                    }
276
                }
277

            
278
13735528
                return false;
279
13629269
            }
280

            
281
13629269
            true
282
27364797
        });
283

            
284
        // We ensure that every removed symbol is not used anymore.
285
12847785
        self.symbol_pool.retain(|symbol| {
286
12847785
            if !self.marked_symbols.contains(symbol) {
287
44082
                debug_trace!("Dropping symbol: {:?}", symbol);
288
44082
                return false;
289
12803703
            }
290

            
291
12803703
            true
292
12847785
        });
293

            
294
388407
        debug!(
295
            "Garbage collection: marking took {}ms, collection took {}ms, {} terms and {} symbols removed",
296
            mark_time_elapsed.as_millis(),
297
            collect_time.elapsed().as_millis(),
298
            num_of_terms - self.len(),
299
            num_of_symbols - self.symbol_pool.len()
300
        );
301

            
302
388407
        debug!("{}", self.metrics());
303

            
304
        // Print information from the protection sets.
305
388407
        for pool in self.thread_pools.iter().flatten() {
306
            // SAFETY: We have exclusive access to the global term pool, so no other thread can modify the protection sets.
307
388407
            let pool = unsafe { &mut *pool.get() };
308
388407
            debug!("{}", pool.metrics());
309
        }
310

            
311
        // Clear marking data structures
312
388407
        self.marked_terms.clear();
313
388407
        self.marked_symbols.clear();
314
388407
        self.stack.clear();
315
388407
    }
316

            
317
    /// Returns the metrics of the term pool, can be formatted and written to output.
318
    pub fn metrics(&self) -> TermPoolMetrics<'_> {
319
        TermPoolMetrics(self)
320
    }
321

            
322
    /// Marks the given term as being reachable.
323
    ///
324
    /// # Safety
325
    ///
326
    /// Should only be called during garbage collection.
327
    pub unsafe fn mark_term(&mut self, term: &ATermRef<'_>) {
328
        // Ensure that the global term pool is locked for writing.
329
        let mut marker = Marker {
330
            marked_terms: &mut self.marked_terms,
331
            marked_symbols: &mut self.marked_symbols,
332
            stack: &mut self.stack,
333
        };
334
        term.mark(&mut marker);
335
    }
336

            
337
    /// Returns integer function symbol.
338
962
    pub(crate) fn get_int_symbol(&self) -> &SymbolRef<'static> {
339
962
        &self.int_symbol
340
962
    }
341

            
342
    /// Returns integer function symbol.
343
962
    pub(crate) fn get_list_symbol(&self) -> &SymbolRef<'static> {
344
962
        &self.list_symbol
345
962
    }
346

            
347
    /// Returns integer function symbol.
348
962
    pub(crate) fn get_empty_list_symbol(&self) -> &SymbolRef<'static> {
349
962
        &self.empty_list_symbol
350
962
    }
351
}
352

            
353
pub struct TermPoolMetrics<'a>(&'a GlobalTermPool);
354

            
355
impl fmt::Display for TermPoolMetrics<'_> {
356
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
357
        write!(
358
            f,
359
            "There are {} terms, and {} symbols",
360
            self.0.terms.len(),
361
            self.0.symbol_pool.len()
362
        )
363
    }
364
}
365

            
366
/// A newtype wrapping the per-thread protection-set list stored inside
367
/// [`GlobalTermPool`].
368
///
369
/// # Safety
370
///
371
/// Note that [`UnsafeCell`] is not [`Sync`], but we explicitly only use this in
372
/// `&mut self` contexts, so we can safely implement `Sync` for this wrapper.
373
struct ThreadPoolList(Vec<Option<Arc<UnsafeCell<SharedTermProtection>>>>);
374

            
375
// SAFETY: See the safety documentation on `ThreadPoolList`.
376
unsafe impl Sync for ThreadPoolList {}
377
unsafe impl Send for ThreadPoolList {}
378

            
379
impl std::ops::Deref for ThreadPoolList {
380
    type Target = Vec<Option<Arc<UnsafeCell<SharedTermProtection>>>>;
381

            
382
777776
    fn deref(&self) -> &Self::Target {
383
777776
        &self.0
384
777776
    }
385
}
386

            
387
impl std::ops::DerefMut for ThreadPoolList {
388
1924
    fn deref_mut(&mut self) -> &mut Self::Target {
389
1924
        &mut self.0
390
1924
    }
391
}
392

            
393
/// A struct that contains the protection sets for a thread, as well as the
394
/// index of the thread pool in the global term pool.
395
pub struct SharedTermProtection {
396
    /// Protection set for terms
397
    pub term_protection_set: ProtectionSet<ATermIndex>,
398
    /// Protection set to prevent garbage collection of symbols
399
    pub symbol_protection_set: ProtectionSet<SymbolIndex>,
400
    /// Protection set for containers
401
    pub container_protection_set: ProtectionSet<Arc<dyn Markable + Sync + Send>>,
402
    /// Index in global pool's thread pools list
403
    pub index: usize,
404
}
405

            
406
impl SharedTermProtection {
407
    /// Returns the metrics of the term pool, can be formatted and written to output.
408
    pub fn metrics(&self) -> ProtectionMetrics<'_> {
409
        ProtectionMetrics(self)
410
    }
411
}
412

            
413
/// A struct that can be used to print the performance of the protection sets.
414
pub struct ProtectionMetrics<'a>(&'a SharedTermProtection);
415

            
416
impl fmt::Display for ProtectionMetrics<'_> {
417
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
418
        writeln!(
419
            f,
420
            "Protection set {} has {} roots, max {} and {} insertions",
421
            self.0.index,
422
            LargeFormatter(self.0.term_protection_set.len()),
423
            LargeFormatter(self.0.term_protection_set.maximum_size()),
424
            LargeFormatter(self.0.term_protection_set.number_of_insertions())
425
        )?;
426

            
427
        writeln!(
428
            f,
429
            "Containers: {} roots, max {} and {} insertions",
430
            LargeFormatter(self.0.container_protection_set.len()),
431
            LargeFormatter(self.0.container_protection_set.maximum_size()),
432
            LargeFormatter(self.0.container_protection_set.number_of_insertions()),
433
        )?;
434

            
435
        write!(
436
            f,
437
            "Symbols: {} roots, max {} and {} insertions",
438
            LargeFormatter(self.0.symbol_protection_set.len()),
439
            LargeFormatter(self.0.symbol_protection_set.maximum_size()),
440
            LargeFormatter(self.0.symbol_protection_set.number_of_insertions()),
441
        )
442
    }
443
}
444

            
445
/// Helper struct to pass private data required to mark term recursively.
446
pub struct Marker<'a> {
447
    marked_terms: &'a mut FxHashSet<ATermIndex>,
448
    marked_symbols: &'a mut FxHashSet<SymbolIndex>,
449
    stack: &'a mut Vec<ATermIndex>,
450
}
451

            
452
impl Marker<'_> {
453
    // Marks the given term as being reachable.
454
124219927
    pub fn mark(&mut self, term: &ATermRef<'_>) {
455
124219927
        if !self.marked_terms.contains(term.shared()) {
456
3718629
            self.stack.push(term.shared().copy());
457

            
458
17347898
            while let Some(term) = self.stack.pop() {
459
                // Each term should be marked.
460
13629269
                self.marked_terms.insert(term.copy());
461

            
462
                // Mark the function symbol.
463
13629269
                self.marked_symbols.insert(term.symbol().shared().copy());
464

            
465
                // For some terms, such as ATermInt, we must ONLY consider the valid arguments (indicated by the arity)
466
26492641
                for arg in term.arguments()[0..term.symbol().arity()].iter() {
467
                    // Skip if unnecessary, otherwise mark before pushing to stack since it can be shared.
468
26492641
                    if !self.marked_terms.contains(arg.shared()) {
469
9910640
                        self.marked_terms.insert(arg.shared().copy());
470
9910640
                        self.marked_symbols.insert(arg.get_head_symbol().shared().copy());
471
9910640
                        self.stack.push(arg.shared().copy());
472
16582001
                    }
473
                }
474
            }
475
120501298
        }
476
124219927
    }
477

            
478
    /// Marks the given symbol as being reachable.
479
4733568
    pub fn mark_symbol(&mut self, symbol: &SymbolRef<'_>) {
480
4733568
        self.marked_symbols.insert(symbol.shared().copy());
481
4733568
    }
482
}
483

            
484
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::collections::hash_map::Entry;

    use merc_utilities::random_test;

    use crate::ATerm;
    use crate::Symbol;
    use crate::Term;
    use crate::random_term;

    /// Terms with the same printed representation must be the very same shared term.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_maximal_sharing() {
        random_test(100, |rng| {
            let mut seen = HashMap::new();

            for _ in 0..1000 {
                let term = random_term(rng, &[("f".into(), 2), ("g".into(), 1)], &["a".to_string()], 10);

                match seen.entry(term.to_string()) {
                    Entry::Occupied(previous) => {
                        assert_eq!(term, *previous.get(), "There is another term with the same representation");
                    }
                    Entry::Vacant(slot) => {
                        slot.insert(term);
                    }
                }
            }
        });
    }

    /// Supplying more arguments than the arity allows is only detected when accessing them.
    #[test]
    #[should_panic]
    fn test_term_out_of_bound_arity() {
        let constant_a = ATerm::constant(&Symbol::new("a", 0));

        let f_term = ATerm::with_args(&Symbol::new("f", 1), &[constant_a.copy(), constant_a.copy()]);

        // Currently we check on access
        let _ = f_term.arg(1);
    }
}