1
use std::cell::Cell;
2
use std::cell::RefCell;
3
use std::cell::UnsafeCell;
4
use std::mem::ManuallyDrop;
5
use std::ops::Deref;
6
use std::ops::DerefMut;
7
use std::sync::Arc;
8

            
9
use log::debug;
10

            
11
use merc_collections::ProtectionIndex;
12
use merc_pest_consume::Parser;
13
use merc_sharedmutex::RecursiveLock;
14
use merc_sharedmutex::RecursiveLockReadGuard;
15
use merc_unsafety::StablePointer;
16
use merc_utilities::MercError;
17
use merc_utilities::debug_trace;
18

            
19
use crate::Markable;
20
use crate::Return;
21
use crate::Rule;
22
use crate::Symb;
23
use crate::Symbol;
24
use crate::SymbolRef;
25
use crate::Term;
26
use crate::TermParser;
27
use crate::aterm::ATerm;
28
use crate::aterm::ATermRef;
29
use crate::storage::AGGRESSIVE_GC;
30
use crate::storage::GlobalTermPool;
31
use crate::storage::GlobalTermPoolGuard;
32
use crate::storage::SharedTerm;
33
use crate::storage::SharedTermProtection;
34
use crate::storage::global_aterm_pool::GLOBAL_TERM_POOL;
35

            
36
thread_local! {
37
    /// Thread-specific term pool that manages protection sets.
38
    pub static THREAD_TERM_POOL: RefCell<ThreadTermPool> = RefCell::new(ThreadTermPool::new());
39
}
40

            
41
/// Per-thread term pool managing local protection sets.
42
pub struct ThreadTermPool {
43
    /// A reference to the protection set of this thread pool.
44
    protection_set: Arc<UnsafeCell<SharedTermProtection>>,
45

            
46
    /// The number of times termms have been created before garbage collection is triggered.
47
    garbage_collection_counter: Cell<usize>,
48

            
49
    /// A vector of terms that are used to store the arguments of a term for loopup.
50
    tmp_arguments: RefCell<Vec<ATermRef<'static>>>,
51

            
52
    /// A local view for the global term pool.
53
    term_pool: RecursiveLock<GlobalTermPool>,
54

            
55
    /// Copy of the default terms since thread local access is cheaper.
56
    int_symbol: SymbolRef<'static>,
57
    empty_list_symbol: SymbolRef<'static>,
58
    list_symbol: SymbolRef<'static>,
59
}
60

            
61
impl ThreadTermPool {
62
    /// Creates a new thread-local term pool.
63
595
    fn new() -> Self {
64
        // Register protection sets with global pool
65
595
        let term_pool: RecursiveLock<GlobalTermPool> = RecursiveLock::from_mutex(GLOBAL_TERM_POOL.share());
66

            
67
595
        let mut pool = term_pool.write().expect("Lock poisoned!");
68

            
69
595
        let protection_set = pool.register_thread_term_pool();
70
595
        let int_symbol = pool.get_int_symbol().copy();
71
595
        let empty_list_symbol = pool.get_empty_list_symbol().copy();
72
595
        let list_symbol = pool.get_list_symbol().copy();
73
595
        drop(pool);
74

            
75
        // Arbitrary value to trigger garbage collection
76
        Self {
77
595
            protection_set,
78
595
            garbage_collection_counter: Cell::new(if AGGRESSIVE_GC { 1 } else { 1000 }),
79
595
            tmp_arguments: RefCell::new(Vec::new()),
80
595
            int_symbol,
81
595
            empty_list_symbol,
82
595
            list_symbol,
83
595
            term_pool,
84
        }
85
595
    }
86

            
87
    /// Creates a term without arguments.
88
2204402
    pub fn create_constant(&self, symbol: &SymbolRef<'_>) -> ATerm {
89
2204402
        assert!(symbol.arity() == 0, "A constant should not have arity > 0");
90

            
91
2204402
        let empty_args: [ATermRef<'_>; 0] = [];
92
2204402
        let guard = self.term_pool.read_recursive().expect("Lock poisoned!");
93

            
94
2204402
        let (index, inserted) = guard.create_term_array(symbol, &empty_args);
95
2204402
        let result = self.protect_guard(guard, &unsafe { ATermRef::from_index(&index) });
96

            
97
2204402
        if inserted {
98
9930
            // Intentially called after the guard is dropped.
99
9930
            self.trigger_garbage_collection();
100
2194480
        }
101

            
102
2204402
        result
103
2204402
    }
104

            
105
    /// Create a term with the given arguments
106
1509984
    pub fn create_term<'a, 'b, S: Symb<'a, 'b>, T: Term<'a, 'b>>(
107
1509984
        &self,
108
1509984
        symbol: &'b S,
109
1509984
        args: &'b [T],
110
1509984
    ) -> Return<ATermRef<'static>> {
111
1509984
        let mut arguments = self.tmp_arguments.borrow_mut();
112

            
113
1509984
        arguments.clear();
114
2463371
        for arg in args {
115
2463371
            unsafe {
116
2463371
                arguments.push(ATermRef::from_index(arg.shared()));
117
2463371
            }
118
        }
119

            
120
1509984
        let guard = self.term_pool.read_recursive().expect("Lock poisoned!");
121
1509984
        let (index, inserted) = guard.create_term_array(symbol, &arguments);
122

            
123
1509984
        let result = unsafe {
124
            // SAFETY: The guard is guaranteed to live as long as the returned term, since it is thread local and Return cannot be sended to other threads.
125
1509984
            Return::new(
126
1509984
                std::mem::transmute::<RecursiveLockReadGuard<'_, _>, RecursiveLockReadGuard<'static, _>>(guard),
127
1509984
                ATermRef::from_index(&index),
128
            )
129
        };
130

            
131
1509984
        if inserted {
132
150520
            // Intentially called after the guard is dropped.
133
150520
            self.trigger_garbage_collection();
134
1370785
        }
135

            
136
1509984
        result
137
1509984
    }
138

            
139
    /// Create a term with the given index.
140
3730244
    pub fn create_int(&self, value: usize) -> ATerm {
141
3730244
        let guard = self.term_pool.read_recursive().expect("Lock poisoned!");
142
3730244
        let (index, inserted) = guard.create_int(value);
143
3730244
        let result = self.protect_guard(guard, &unsafe { ATermRef::from_index(&index) });
144

            
145
3730244
        if inserted {
146
1985
            // Intentially called after the guard is dropped.
147
1985
            self.trigger_garbage_collection();
148
3728259
        }
149

            
150
3730244
        result
151
3730244
    }
152

            
153
    /// Create a term with the given arguments given by the iterator.
154
60634
    pub fn create_term_iter<'a, 'b, 'c, 'd, S, I, T>(&self, symbol: &'b S, args: I) -> ATerm
155
60634
    where
156
60634
        S: Symb<'a, 'b>,
157
60634
        I: IntoIterator<Item = T>,
158
60634
        T: Term<'c, 'd>,
159
    {
160
60634
        let mut arguments = self.tmp_arguments.borrow_mut();
161
60634
        arguments.clear();
162
89979
        for arg in args {
163
89979
            unsafe {
164
89979
                arguments.push(ATermRef::from_index(arg.shared()));
165
89979
            }
166
        }
167

            
168
60634
        let guard = self.term_pool.read_recursive().expect("Lock poisoned!");
169
60634
        let (index, inserted) = guard.create_term_array(symbol, &arguments);
170

            
171
60634
        let result = self.protect_guard(guard, &unsafe { ATermRef::from_index(&index) });
172

            
173
60634
        if inserted {
174
25760
            // Intentially called after the guard is dropped.
175
25760
            self.trigger_garbage_collection();
176
34878
        }
177

            
178
60634
        result
179
60634
    }
180

            
181
    /// Create a term with the given arguments given by the iterator that is failable.
182
195055
    pub fn try_create_term_iter<'a, 'b, 'c, 'd, S, I, T>(&self, symbol: &'b S, args: I) -> Result<ATerm, MercError>
183
195055
    where
184
195055
        S: Symb<'a, 'b>,
185
195055
        I: IntoIterator<Item = Result<T, MercError>>,
186
195055
        T: Term<'c, 'd>,
187
    {
188
195055
        let mut arguments = self.tmp_arguments.borrow_mut();
189
195055
        arguments.clear();
190
199346
        for arg in args {
191
            unsafe {
192
199346
                arguments.push(ATermRef::from_index(arg?.shared()));
193
            }
194
        }
195

            
196
195055
        let guard = self.term_pool.read_recursive().expect("Lock poisoned!");
197
195055
        let (index, inserted) = guard.create_term_array(symbol, &arguments);
198

            
199
195055
        let result = Ok(self.protect_guard(guard, &unsafe { ATermRef::from_index(&index) }));
200

            
201
195055
        if inserted {
202
2151
            // Intentially called after the guard is dropped.
203
2151
            self.trigger_garbage_collection();
204
193672
        }
205

            
206
195055
        result
207
195055
    }
208

            
209
    /// Create a term with the given arguments given by the iterator.
210
4325067
    pub fn create_term_iter_head<'a, 'b, 'c, 'd, 'e, 'f, S, H, I, T>(
211
4325067
        &self,
212
4325067
        symbol: &'b S,
213
4325067
        head: &'d H,
214
4325067
        args: I,
215
4325067
    ) -> ATerm
216
4325067
    where
217
4325067
        S: Symb<'a, 'b>,
218
4325067
        H: Term<'c, 'd>,
219
4325067
        I: IntoIterator<Item = T>,
220
4325067
        T: Term<'e, 'f>,
221
    {
222
4325067
        let mut arguments = self.tmp_arguments.borrow_mut();
223
4325067
        arguments.clear();
224
4325067
        unsafe {
225
4325067
            arguments.push(ATermRef::from_index(head.shared()));
226
4325067
        }
227
6822543
        for arg in args {
228
6822543
            unsafe {
229
6822543
                arguments.push(ATermRef::from_index(arg.shared()));
230
6822543
            }
231
        }
232

            
233
4325067
        let guard = self.term_pool.read_recursive().expect("Lock poisoned!");
234
4325067
        let (index, inserted) = guard.create_term_array(symbol, &arguments);
235

            
236
4325067
        let result = self.protect_guard(guard, &unsafe { ATermRef::from_index(&index) });
237

            
238
4325067
        if inserted {
239
418238
            // Intentially called after the guard is dropped.
240
418238
            self.trigger_garbage_collection();
241
3906831
        }
242

            
243
4325067
        result
244
4325067
    }
245

            
246
    /// Create a function symbol
247
3329864
    pub fn create_symbol<N: Into<String> + AsRef<str>>(&self, name: N, arity: usize) -> Symbol {
248
3329864
        self.term_pool
249
3329864
            .read_recursive()
250
3329864
            .expect("Lock poisoned!")
251
3329864
            .create_symbol(name, arity, |index| unsafe {
252
3329864
                self.protect_symbol(&SymbolRef::from_index(&index))
253
3329864
            })
254
3329864
    }
255

            
256
    /// Protect the term by adding its index to the protection set
257
354924177
    pub fn protect(&self, term: &ATermRef<'_>) -> ATerm {
258
        // Protect the term by adding its index to the protection set
259
354924177
        let root = self.lock_protection_set().protection_set.protect(term.shared().copy());
260

            
261
        // Return the protected term
262
354924177
        let result = ATerm::from_index(term.shared(), root);
263

            
264
354924177
        debug_trace!(
265
            "Protected term {:?}, root {}, protection set {}",
266
            term,
267
            root,
268
            self.index()
269
        );
270

            
271
354924177
        result
272
354924177
    }
273

            
274
    /// Protect the term by adding its index to the protection set
275
50761410
    pub fn protect_guard(&self, _guard: RecursiveLockReadGuard<'_, GlobalTermPool>, term: &ATermRef<'_>) -> ATerm {
276
        // Protect the term by adding its index to the protection set
277
        // SAFETY: If the global term pool is locked, so we can safely access the protection set.
278
50761410
        let root = unsafe { &mut *self.protection_set.get() }
279
50761410
            .protection_set
280
50761410
            .protect(term.shared().copy());
281

            
282
        // Return the protected term
283
50761410
        let result = ATerm::from_index(term.shared(), root);
284

            
285
50761410
        debug_trace!(
286
            "Protected term {:?}, root {}, protection set {}",
287
            term,
288
            root,
289
            self.index()
290
        );
291

            
292
50761410
        result
293
50761410
    }
294

            
295
    /// Unprotects a term from this thread's protection set.
296
405685587
    pub fn drop(&self, term: &ATerm) {
297
405685587
        self.lock_protection_set().protection_set.unprotect(term.root());
298

            
299
405685587
        debug_trace!(
300
            "Unprotected term {:?}, root {}, protection set {}",
301
            term,
302
            term.root(),
303
            self.index()
304
        );
305
405685587
    }
306

            
307
    /// Protects a container in this thread's container protection set.
308
33242771
    pub fn protect_container(&self, container: Arc<dyn Markable + Send + Sync>) -> ProtectionIndex {
309
33242771
        let root = self.lock_protection_set().container_protection_set.protect(container);
310

            
311
33242771
        debug_trace!("Protected container index {}, protection set {}", root, self.index());
312

            
313
33242771
        root
314
33242771
    }
315

            
316
    /// Unprotects a container from this thread's container protection set.
317
33242771
    pub fn drop_container(&self, root: ProtectionIndex) {
318
33242771
        self.lock_protection_set().container_protection_set.unprotect(root);
319

            
320
33242771
        debug_trace!("Unprotected container index {}, protection set {}", root, self.index());
321
33242771
    }
322

            
323
    /// Parse the given string and returns the Term representation.
324
3842
    pub fn from_string(&self, text: &str) -> Result<ATerm, MercError> {
325
3842
        let mut result = TermParser::parse(Rule::TermSpec, text)?;
326
3842
        let root = result.next().unwrap();
327

            
328
3842
        Ok(TermParser::TermSpec(root).unwrap())
329
3842
    }
330

            
331
    /// Protects a symbol from garbage collection.
332
5283994
    pub fn protect_symbol(&self, symbol: &SymbolRef<'_>) -> Symbol {
333
5283994
        let result = unsafe {
334
5283994
            Symbol::from_index(
335
5283994
                symbol.shared(),
336
5283994
                self.lock_protection_set()
337
5283994
                    .symbol_protection_set
338
5283994
                    .protect(symbol.shared().copy()),
339
            )
340
        };
341

            
342
5283994
        debug_trace!(
343
            "Protected symbol {}, root {}, protection set {}",
344
            symbol,
345
            result.root(),
346
            lock.index,
347
        );
348

            
349
5283994
        result
350
5283994
    }
351

            
352
    /// Unprotects a symbol, allowing it to be garbage collected.
353
5274454
    pub fn drop_symbol(&self, symbol: &mut Symbol) {
354
5274454
        self.lock_protection_set()
355
5274454
            .symbol_protection_set
356
5274454
            .unprotect(symbol.root());
357
5274454
    }
358

            
359
    /// Returns the symbol for ATermInt
360
686331103
    pub fn int_symbol(&self) -> &SymbolRef<'_> {
361
686331103
        &self.int_symbol
362
686331103
    }
363

            
364
    /// Returns the symbol for ATermList
365
1554538
    pub fn list_symbol(&self) -> &SymbolRef<'_> {
366
1554538
        &self.list_symbol
367
1554538
    }
368

            
369
    /// Returns the symbol for the empty ATermInt
370
1569262
    pub fn empty_list_symbol(&self) -> &SymbolRef<'_> {
371
1569262
        &self.empty_list_symbol
372
1569262
    }
373

            
374
    /// Enables or disables automatic garbage collection.
375
    pub fn automatic_garbage_collection(&self, enabled: bool) {
376
        let mut guard = self.term_pool.write().expect("Lock poisoned!");
377
        guard.automatic_garbage_collection(enabled);
378
    }
379

            
380
    /// Triggers delayed garbage collection if the counter has reached zero.
381
    ///
382
    /// # Safety
383
    ///
384
    /// This function drops the passed guard.
385
260722290
    pub(crate) unsafe fn trigger_delayed_garbage_collection(&self, guard: &mut ManuallyDrop<GlobalTermPoolGuard<'_>>) {
386
260722290
        unsafe {
387
260722290
            ManuallyDrop::drop(guard);
388
260722290
        }
389

            
390
260722290
        debug_assert!(
391
260722290
            guard.read_depth() == 0,
392
            "Cannot trigger garbage collection while holding a read lock"
393
        );
394
260722290
        if self.garbage_collection_counter.get() == 0 {
395
28020
            self.trigger_garbage_collection();
396
260694270
        }
397
260722290
    }
398

            
399
    /// Returns access to the shared protection set.
400
    pub(crate) fn get_protection_set(&self) -> &Arc<UnsafeCell<SharedTermProtection>> {
401
        &self.protection_set
402
    }
403

            
404
    /// Returns a reference to the global term pool.
405
654312901
    pub(crate) fn term_pool(&self) -> &RecursiveLock<GlobalTermPool> {
406
654312901
        &self.term_pool
407
654312901
    }
408

            
409
    /// Replace the entry in the protection set with the given term.
410
    pub(crate) fn replace(
411
        &self,
412
        _guard: RecursiveLockReadGuard<'_, GlobalTermPool>,
413
        root: ProtectionIndex,
414
        term: StablePointer<SharedTerm>,
415
    ) {
416
        // Protect the term by adding its index to the protection set
417
        // SAFETY: If the global term pool is locked, so we can safely access the protection set.
418
        unsafe { &mut *self.protection_set.get() }
419
            .protection_set
420
            .replace(root, term);
421
    }
422

            
423
    /// This triggers the global garbage collection based on heuristics.
424
5208411
    fn trigger_garbage_collection(&self) {
425
        // If the term was newly inserted, decrease the garbage collection counter and trigger garbage collection if necessary
426
5208411
        let mut value = self.garbage_collection_counter.get();
427
5208411
        value = value.saturating_sub(1);
428

            
429
5208411
        if value == 0 && !self.term_pool.is_locked() {
430
28020
            // Trigger garbage collection and acquire a new counter value.
431
28020
            value = self
432
28020
                .term_pool
433
28020
                .write()
434
28020
                .expect("Lock poisoned!")
435
28020
                .trigger_garbage_collection();
436
5180391
        }
437

            
438
5208411
        self.garbage_collection_counter.set(value);
439
5208411
    }
440

            
441
    /// Returns the index of the protection set.
442
595
    fn index(&self) -> usize {
443
595
        self.lock_protection_set().index
444
595
    }
445

            
446
    /// The protection set is locked by the global read-write lock
447
837654355
    fn lock_protection_set(&self) -> ProtectionSetGuard<'_> {
448
837654355
        let guard = self.term_pool.read_recursive().expect("Lock poisoned!");
449
837654355
        let protection_set = unsafe { &mut *self.protection_set.get() };
450

            
451
837654355
        ProtectionSetGuard::new(guard, protection_set)
452
837654355
    }
453
}
454

            
455
impl Drop for ThreadTermPool {
456
595
    fn drop(&mut self) {
457
595
        let mut write = self.term_pool.write().expect("Lock poisoned!");
458

            
459
595
        debug!("{}", write.metrics());
460
595
        write.deregister_thread_pool(self.index());
461

            
462
595
        debug!("{}", unsafe { &mut *self.protection_set.get() }.metrics());
463
595
        debug!(
464
            "Acquired {} read locks and {} write locks",
465
            self.term_pool.read_recursive_call_count(),
466
            self.term_pool.write_call_count()
467
        )
468
595
    }
469
}
470

            
471
struct ProtectionSetGuard<'a> {
472
    _guard: RecursiveLockReadGuard<'a, GlobalTermPool>,
473
    object: &'a mut SharedTermProtection,
474
}
475

            
476
impl ProtectionSetGuard<'_> {
477
837654355
    fn new<'a>(
478
837654355
        guard: RecursiveLockReadGuard<'a, GlobalTermPool>,
479
837654355
        object: &'a mut SharedTermProtection,
480
837654355
    ) -> ProtectionSetGuard<'a> {
481
837654355
        ProtectionSetGuard { _guard: guard, object }
482
837654355
    }
483
}
484

            
485
impl Deref for ProtectionSetGuard<'_> {
486
    type Target = SharedTermProtection;
487

            
488
601
    fn deref(&self) -> &Self::Target {
489
601
        self.object
490
601
    }
491
}
492

            
493
impl DerefMut for ProtectionSetGuard<'_> {
494
837653754
    fn deref_mut(&mut self) -> &mut Self::Target {
495
837653754
        self.object
496
837653754
    }
497
}
498

            
499
#[cfg(test)]
mod tests {
    use std::thread;

    use super::*;
    use crate::Term;

    #[test]
    fn test_thread_local_protection() {
        let _ = merc_utilities::test_logger();

        // Every spawned thread has its own `THREAD_TERM_POOL`, so each thread
        // only observes its own protection set.
        thread::scope(|scope| {
            for _ in 0..3 {
                scope.spawn(|| {
                    let symbol = Symbol::new("test", 0);
                    let term = ATerm::constant(&symbol);
                    let protected = term.protect();

                    // A live protected term must have its root registered.
                    THREAD_TERM_POOL.with_borrow(|tp| {
                        let guard = tp.lock_protection_set();
                        assert!(guard.protection_set.contains_root(protected.root()));
                    });

                    // Dropping the protected term must remove its root again.
                    let root = protected.root();
                    drop(protected);

                    THREAD_TERM_POOL.with_borrow(|tp| {
                        let guard = tp.lock_protection_set();
                        assert!(!guard.protection_set.contains_root(root));
                    });
                });
            }
        });
    }

    #[test]
    fn test_parsing() {
        let _ = merc_utilities::test_logger();

        let parsed = ATerm::from_string("f(g(a),b)").unwrap();

        // The root and both of its arguments carry the expected head symbols.
        assert!(parsed.get_head_symbol().name() == "f");
        assert!(parsed.arg(0).get_head_symbol().name() == "g");
        assert!(parsed.arg(1).get_head_symbol().name() == "b");
    }

    #[test]
    fn test_create_term() {
        let _ = merc_utilities::test_logger();

        let f = Symbol::new("f", 2);
        let g = Symbol::new("g", 1);

        // Builds f(g(a), b) directly through the thread-local pool. The nested
        // expression keeps every temporary alive until the outer term is created.
        let built = THREAD_TERM_POOL.with_borrow(|tp| {
            tp.create_term(
                &f,
                &[
                    tp.create_term(&g, &[tp.create_constant(&Symbol::new("a", 0))])
                        .protect(),
                    tp.create_constant(&Symbol::new("b", 0)),
                ],
            )
            .protect()
        });

        assert!(built.get_head_symbol().name() == "f");
        assert!(built.arg(0).get_head_symbol().name() == "g");
        assert!(built.arg(1).get_head_symbol().name() == "b");
    }
}