1
use std::cell::Cell;
2
use std::cell::RefCell;
3
use std::cell::UnsafeCell;
4
use std::mem::ManuallyDrop;
5
use std::ops::Deref;
6
use std::ops::DerefMut;
7
use std::sync::Arc;
8

            
9
use log::debug;
10

            
11
use merc_collections::ProtectionIndex;
12
use merc_pest_consume::Parser;
13
use merc_sharedmutex::RecursiveLock;
14
use merc_sharedmutex::RecursiveLockReadGuard;
15
use merc_unsafety::StablePointer;
16
use merc_utilities::MercError;
17
use merc_utilities::debug_trace;
18

            
19
use crate::Markable;
20
use crate::Return;
21
use crate::Rule;
22
use crate::Symb;
23
use crate::Symbol;
24
use crate::SymbolRef;
25
use crate::Term;
26
use crate::TermParser;
27
use crate::aterm::ATerm;
28
use crate::aterm::ATermRef;
29
use crate::storage::AGGRESSIVE_GC;
30
use crate::storage::GlobalTermPool;
31
use crate::storage::GlobalTermPoolGuard;
32
use crate::storage::SharedTerm;
33
use crate::storage::SharedTermProtection;
34
use crate::storage::global_aterm_pool::GLOBAL_TERM_POOL;
35

            
36
thread_local! {
37
    /// Thread-specific term pool that manages protection sets.
38
    pub static THREAD_TERM_POOL: RefCell<ThreadTermPool> = RefCell::new(ThreadTermPool::new());
39
}
40

            
41
/// Per-thread term pool managing local protection sets.
42
pub struct ThreadTermPool {
43
    /// A reference to the protection set of this thread pool.
44
    protection_set: Arc<UnsafeCell<SharedTermProtection>>,
45

            
46
    /// The number of times termms have been created before garbage collection is triggered.
47
    garbage_collection_counter: Cell<usize>,
48

            
49
    /// A vector of terms that are used to store the arguments of a term for loopup.
50
    tmp_arguments: RefCell<Vec<ATermRef<'static>>>,
51

            
52
    /// A local view for the global term pool.
53
    term_pool: RecursiveLock<GlobalTermPool>,
54

            
55
    /// Copy of the default terms since thread local access is cheaper.
56
    int_symbol: SymbolRef<'static>,
57
    empty_list_symbol: SymbolRef<'static>,
58
    list_symbol: SymbolRef<'static>,
59
}
60

            
61
impl ThreadTermPool {
62
    /// Creates a new thread-local term pool.
63
585
    fn new() -> Self {
64
        // Register protection sets with global pool
65
585
        let term_pool: RecursiveLock<GlobalTermPool> = RecursiveLock::from_mutex(GLOBAL_TERM_POOL.share());
66

            
67
585
        let mut pool = term_pool.write().expect("Lock poisoned!");
68

            
69
585
        let protection_set = pool.register_thread_term_pool();
70
585
        let int_symbol = pool.get_int_symbol().copy();
71
585
        let empty_list_symbol = pool.get_empty_list_symbol().copy();
72
585
        let list_symbol = pool.get_list_symbol().copy();
73
585
        drop(pool);
74

            
75
        // Arbitrary value to trigger garbage collection
76
        Self {
77
585
            protection_set,
78
585
            garbage_collection_counter: Cell::new(if AGGRESSIVE_GC { 1 } else { 1000 }),
79
585
            tmp_arguments: RefCell::new(Vec::new()),
80
585
            int_symbol,
81
585
            empty_list_symbol,
82
585
            list_symbol,
83
585
            term_pool,
84
        }
85
585
    }
86

            
87
    /// Creates a term without arguments.
88
2169442
    pub fn create_constant(&self, symbol: &SymbolRef<'_>) -> ATerm {
89
2169442
        assert!(symbol.arity() == 0, "A constant should not have arity > 0");
90

            
91
2169442
        let empty_args: [ATermRef<'_>; 0] = [];
92
2169442
        let guard = self.term_pool.read_recursive().expect("Lock poisoned!");
93

            
94
2169442
        let (index, inserted) = guard.create_term_array(symbol, &empty_args);
95
2169442
        let result = self.protect_guard(guard, &unsafe { ATermRef::from_index(&index) });
96

            
97
2169442
        if inserted {
98
9810
            // Intentially called after the guard is dropped.
99
9810
            self.trigger_garbage_collection();
100
2159640
        }
101

            
102
2169442
        result
103
2169442
    }
104

            
105
    /// Create a term with the given arguments
106
1492384
    pub fn create_term<'a, 'b>(
107
1492384
        &self,
108
1492384
        symbol: &'b impl Symb<'a, 'b>,
109
1492384
        args: &'b [impl Term<'a, 'b>],
110
1492384
    ) -> Return<ATermRef<'static>> {
111
1492384
        let mut arguments = self.tmp_arguments.borrow_mut();
112

            
113
1492384
        arguments.clear();
114
2436563
        for arg in args {
115
2436563
            unsafe {
116
2436563
                arguments.push(ATermRef::from_index(arg.shared()));
117
2436563
            }
118
        }
119

            
120
1492384
        let guard = self.term_pool.read_recursive().expect("Lock poisoned!");
121
1492384
        let (index, inserted) = guard.create_term_array(symbol, &arguments);
122

            
123
1492384
        let result = unsafe {
124
            // SAFETY: The guard is guaranteed to live as long as the returned term, since it is thread local and Return cannot be sended to other threads.
125
1492384
            Return::new(
126
1492384
                std::mem::transmute::<RecursiveLockReadGuard<'_, _>, RecursiveLockReadGuard<'static, _>>(guard),
127
1492384
                ATermRef::from_index(&index),
128
            )
129
        };
130

            
131
1492384
        if inserted {
132
150328
            // Intentially called after the guard is dropped.
133
150328
            self.trigger_garbage_collection();
134
1355415
        }
135

            
136
1492384
        result
137
1492384
    }
138

            
139
    /// Create a term with the given index.
140
3706324
    pub fn create_int(&self, value: usize) -> ATerm {
141
3706324
        let guard = self.term_pool.read_recursive().expect("Lock poisoned!");
142
3706324
        let (index, inserted) = guard.create_int(value);
143
3706324
        let result = self.protect_guard(guard, &unsafe { ATermRef::from_index(&index) });
144

            
145
3706324
        if inserted {
146
1985
            // Intentially called after the guard is dropped.
147
1985
            self.trigger_garbage_collection();
148
3704339
        }
149

            
150
3706324
        result
151
3706324
    }
152

            
153
    /// Create a term with the given arguments given by the iterator.
154
60634
    pub fn create_term_iter<'a, 'b, 'c, 'd, I, T>(&self, symbol: &'b impl Symb<'a, 'b>, args: I) -> ATerm
155
60634
    where
156
60634
        I: IntoIterator<Item = T>,
157
60634
        T: Term<'c, 'd>,
158
    {
159
60634
        let mut arguments = self.tmp_arguments.borrow_mut();
160
60634
        arguments.clear();
161
89979
        for arg in args {
162
89979
            unsafe {
163
89979
                arguments.push(ATermRef::from_index(arg.shared()));
164
89979
            }
165
        }
166

            
167
60634
        let guard = self.term_pool.read_recursive().expect("Lock poisoned!");
168
60634
        let (index, inserted) = guard.create_term_array(symbol, &arguments);
169

            
170
60634
        let result = self.protect_guard(guard, &unsafe { ATermRef::from_index(&index) });
171

            
172
60634
        if inserted {
173
25760
            // Intentially called after the guard is dropped.
174
25760
            self.trigger_garbage_collection();
175
34878
        }
176

            
177
60634
        result
178
60634
    }
179

            
180
    /// Create a term with the given arguments given by the iterator that is failable.
181
193859
    pub fn try_create_term_iter<'a, 'b, 'c, 'd, I, T>(
182
193859
        &self,
183
193859
        symbol: &'b impl Symb<'a, 'b>,
184
193859
        args: I,
185
193859
    ) -> Result<ATerm, MercError>
186
193859
    where
187
193859
        I: IntoIterator<Item = Result<T, MercError>>,
188
193859
        T: Term<'c, 'd>,
189
    {
190
193859
        let mut arguments = self.tmp_arguments.borrow_mut();
191
193859
        arguments.clear();
192
198165
        for arg in args {
193
            unsafe {
194
198165
                arguments.push(ATermRef::from_index(arg?.shared()));
195
            }
196
        }
197

            
198
193859
        let guard = self.term_pool.read_recursive().expect("Lock poisoned!");
199
193859
        let (index, inserted) = guard.create_term_array(symbol, &arguments);
200

            
201
193859
        let result = Ok(self.protect_guard(guard, &unsafe { ATermRef::from_index(&index) }));
202

            
203
193859
        if inserted {
204
2151
            // Intentially called after the guard is dropped.
205
2151
            self.trigger_garbage_collection();
206
192476
        }
207

            
208
193859
        result
209
193859
    }
210

            
211
    /// Create a term with the given arguments given by the iterator.
212
4319782
    pub fn create_term_iter_head<'a, 'b, 'c, 'd, 'e, 'f, I, T>(
213
4319782
        &self,
214
4319782
        symbol: &'b impl Symb<'a, 'b>,
215
4319782
        head: &'d impl Term<'c, 'd>,
216
4319782
        args: I,
217
4319782
    ) -> ATerm
218
4319782
    where
219
4319782
        I: IntoIterator<Item = T>,
220
4319782
        T: Term<'e, 'f>,
221
    {
222
4319782
        let mut arguments = self.tmp_arguments.borrow_mut();
223
4319782
        arguments.clear();
224
4319782
        unsafe {
225
4319782
            arguments.push(ATermRef::from_index(head.shared()));
226
4319782
        }
227
6814139
        for arg in args {
228
6814139
            unsafe {
229
6814139
                arguments.push(ATermRef::from_index(arg.shared()));
230
6814139
            }
231
        }
232

            
233
4319782
        let guard = self.term_pool.read_recursive().expect("Lock poisoned!");
234
4319782
        let (index, inserted) = guard.create_term_array(symbol, &arguments);
235

            
236
4319782
        let result = self.protect_guard(guard, &unsafe { ATermRef::from_index(&index) });
237

            
238
4319782
        if inserted {
239
418142
            // Intentially called after the guard is dropped.
240
418142
            self.trigger_garbage_collection();
241
3901642
        }
242

            
243
4319782
        result
244
4319782
    }
245

            
246
    /// Create a function symbol
247
3288484
    pub fn create_symbol(&self, name: impl Into<String> + AsRef<str>, arity: usize) -> Symbol {
248
3288484
        self.term_pool
249
3288484
            .read_recursive()
250
3288484
            .expect("Lock poisoned!")
251
3288484
            .create_symbol(name, arity, |index| unsafe {
252
3288484
                self.protect_symbol(&SymbolRef::from_index(&index))
253
3288484
            })
254
3288484
    }
255

            
256
    /// Protect the term by adding its index to the protection set
257
354663226
    pub fn protect(&self, term: &ATermRef<'_>) -> ATerm {
258
        // Protect the term by adding its index to the protection set
259
354663226
        let root = self.lock_protection_set().protection_set.protect(term.shared().copy());
260

            
261
        // Return the protected term
262
354663226
        let result = ATerm::from_index(term.shared(), root);
263

            
264
354663226
        debug_trace!(
265
            "Protected term {:?}, root {}, protection set {}",
266
            term,
267
            root,
268
            self.index()
269
        );
270

            
271
354663226
        result
272
354663226
    }
273

            
274
    /// Protect the term by adding its index to the protection set
275
50637720
    pub fn protect_guard(&self, _guard: RecursiveLockReadGuard<'_, GlobalTermPool>, term: &ATermRef<'_>) -> ATerm {
276
        // Protect the term by adding its index to the protection set
277
        // SAFETY: If the global term pool is locked, so we can safely access the protection set.
278
50637720
        let root = unsafe { &mut *self.protection_set.get() }
279
50637720
            .protection_set
280
50637720
            .protect(term.shared().copy());
281

            
282
        // Return the protected term
283
50637720
        let result = ATerm::from_index(term.shared(), root);
284

            
285
50637720
        debug_trace!(
286
            "Protected term {:?}, root {}, protection set {}",
287
            term,
288
            root,
289
            self.index()
290
        );
291

            
292
50637720
        result
293
50637720
    }
294

            
295
    /// Unprotects a term from this thread's protection set.
296
405300946
    pub fn drop(&self, term: &ATerm) {
297
405300946
        self.lock_protection_set().protection_set.unprotect(term.root());
298

            
299
405300946
        debug_trace!(
300
            "Unprotected term {:?}, root {}, protection set {}",
301
            term,
302
            term.root(),
303
            self.index()
304
        );
305
405300946
    }
306

            
307
    /// Protects a container in this thread's container protection set.
308
33353691
    pub fn protect_container(&self, container: Arc<dyn Markable + Send + Sync>) -> ProtectionIndex {
309
33353691
        let root = self.lock_protection_set().container_protection_set.protect(container);
310

            
311
33353691
        debug_trace!("Protected container index {}, protection set {}", root, self.index());
312

            
313
33353691
        root
314
33353691
    }
315

            
316
    /// Unprotects a container from this thread's container protection set.
317
33353691
    pub fn drop_container(&self, root: ProtectionIndex) {
318
33353691
        self.lock_protection_set().container_protection_set.unprotect(root);
319

            
320
33353691
        debug_trace!("Unprotected container index {}, protection set {}", root, self.index());
321
33353691
    }
322

            
323
    /// Parse the given string and returns the Term representation.
324
2842
    pub fn from_string(&self, text: &str) -> Result<ATerm, MercError> {
325
2842
        let mut result = TermParser::parse(Rule::TermSpec, text)?;
326
2842
        let root = result.next().unwrap();
327

            
328
2842
        Ok(TermParser::TermSpec(root).unwrap())
329
2842
    }
330

            
331
    /// Protects a symbol from garbage collection.
332
5230654
    pub fn protect_symbol(&self, symbol: &SymbolRef<'_>) -> Symbol {
333
5230654
        let result = unsafe {
334
5230654
            Symbol::from_index(
335
5230654
                symbol.shared(),
336
5230654
                self.lock_protection_set()
337
5230654
                    .symbol_protection_set
338
5230654
                    .protect(symbol.shared().copy()),
339
            )
340
        };
341

            
342
5230654
        debug_trace!(
343
            "Protected symbol {}, root {}, protection set {}",
344
            symbol,
345
            result.root(),
346
            lock.index,
347
        );
348

            
349
5230654
        result
350
5230654
    }
351

            
352
    /// Unprotects a symbol, allowing it to be garbage collected.
353
5227534
    pub fn drop_symbol(&self, symbol: &mut Symbol) {
354
5227534
        self.lock_protection_set()
355
5227534
            .symbol_protection_set
356
5227534
            .unprotect(symbol.root());
357
5227534
    }
358

            
359
    /// Returns the symbol for ATermInt
360
687584789
    pub fn int_symbol(&self) -> &SymbolRef<'_> {
361
687584789
        &self.int_symbol
362
687584789
    }
363

            
364
    /// Returns the symbol for ATermList
365
1382194
    pub fn list_symbol(&self) -> &SymbolRef<'_> {
366
1382194
        &self.list_symbol
367
1382194
    }
368

            
369
    /// Returns the symbol for the empty ATermInt
370
1396918
    pub fn empty_list_symbol(&self) -> &SymbolRef<'_> {
371
1396918
        &self.empty_list_symbol
372
1396918
    }
373

            
374
    /// Enables or disables automatic garbage collection.
375
    pub fn automatic_garbage_collection(&self, enabled: bool) {
376
        let mut guard = self.term_pool.write().expect("Lock poisoned!");
377
        guard.automatic_garbage_collection(enabled);
378
    }
379

            
380
    /// Triggers delayed garbage collection if the counter has reached zero.
381
    /// 
382
    /// # Safety
383
    /// 
384
    /// This function drops the passed guard.
385
260167850
    pub(crate) unsafe fn trigger_delayed_garbage_collection(&self, guard: &mut ManuallyDrop<GlobalTermPoolGuard<'_>>) {
386
260167850
        unsafe {
387
260167850
            ManuallyDrop::drop(guard);
388
260167850
        }
389

            
390
260167850
        debug_assert!(guard.read_depth() == 0, "Cannot trigger garbage collection while holding a read lock");        
391
260167850
        if self.garbage_collection_counter.get() == 0 {
392
28020
            self.trigger_garbage_collection();
393
260139830
        }
394
260167850
    }
395

            
396
    /// Returns access to the shared protection set.
397
    pub(crate) fn get_protection_set(&self) -> &Arc<UnsafeCell<SharedTermProtection>> {
398
        &self.protection_set
399
    }
400

            
401
    /// Returns a reference to the global term pool.
402
655362981
    pub(crate) fn term_pool(&self) -> &RecursiveLock<GlobalTermPool> {
403
655362981
        &self.term_pool
404
655362981
    }
405

            
406
    /// Replace the entry in the protection set with the given term.
407
    pub(crate) fn replace(
408
        &self,
409
        _guard: RecursiveLockReadGuard<'_, GlobalTermPool>,
410
        root: ProtectionIndex,
411
        term: StablePointer<SharedTerm>,
412
    ) {
413
        // Protect the term by adding its index to the protection set
414
        // SAFETY: If the global term pool is locked, so we can safely access the protection set.
415
        unsafe { &mut *self.protection_set.get() }
416
            .protection_set
417
            .replace(root, term);
418
    }
419

            
420
    /// This triggers the global garbage collection based on heuristics.
421
5206494
    fn trigger_garbage_collection(&self) {
422
        // If the term was newly inserted, decrease the garbage collection counter and trigger garbage collection if necessary
423
5206494
        let mut value = self.garbage_collection_counter.get();
424
5206494
        value = value.saturating_sub(1);
425

            
426
5206494
        if value == 0 && !self.term_pool.is_locked() {
427
28020
            // Trigger garbage collection and acquire a new counter value.
428
28020
            value = self
429
28020
                .term_pool
430
28020
                .write()
431
28020
                .expect("Lock poisoned!")
432
28020
                .trigger_garbage_collection();
433
5178474
        }
434

            
435
5206494
        self.garbage_collection_counter.set(value);
436
5206494
    }
437

            
438
    /// Returns the index of the protection set.
439
585
    fn index(&self) -> usize {
440
585
        self.lock_protection_set().index
441
585
    }
442

            
443
    /// The protection set is locked by the global read-write lock
444
837130333
    fn lock_protection_set(&self) -> ProtectionSetGuard<'_> {
445
837130333
        let guard = self.term_pool.read_recursive().expect("Lock poisoned!");
446
837130333
        let protection_set = unsafe { &mut *self.protection_set.get() };
447

            
448
837130333
        ProtectionSetGuard::new(guard, protection_set)
449
837130333
    }
450
}
451

            
452
impl Drop for ThreadTermPool {
453
585
    fn drop(&mut self) {
454
585
        let mut write = self.term_pool.write().expect("Lock poisoned!");
455

            
456
585
        debug!("{}", write.metrics());
457
585
        write.deregister_thread_pool(self.index());
458

            
459
585
        debug!("{}", unsafe { &mut *self.protection_set.get() }.metrics());
460
585
        debug!(
461
            "Acquired {} read locks and {} write locks",
462
            self.term_pool.read_recursive_call_count(),
463
            self.term_pool.write_call_count()
464
        )
465
585
    }
466
}
467

            
468
struct ProtectionSetGuard<'a> {
469
    _guard: RecursiveLockReadGuard<'a, GlobalTermPool>,
470
    object: &'a mut SharedTermProtection,
471
}
472

            
473
impl ProtectionSetGuard<'_> {
474
837130333
    fn new<'a>(
475
837130333
        guard: RecursiveLockReadGuard<'a, GlobalTermPool>,
476
837130333
        object: &'a mut SharedTermProtection,
477
837130333
    ) -> ProtectionSetGuard<'a> {
478
837130333
        ProtectionSetGuard { _guard: guard, object }
479
837130333
    }
480
}
481

            
482
impl Deref for ProtectionSetGuard<'_> {
483
    type Target = SharedTermProtection;
484

            
485
591
    fn deref(&self) -> &Self::Target {
486
591
        self.object
487
591
    }
488
}
489

            
490
impl DerefMut for ProtectionSetGuard<'_> {
491
837129742
    fn deref_mut(&mut self) -> &mut Self::Target {
492
837129742
        self.object
493
837129742
    }
494
}
495

            
496
#[cfg(test)]
mod tests {
    use crate::Term;

    use super::*;
    use std::thread;

    #[test]
    fn test_thread_local_protection() {
        let _ = merc_utilities::test_logger();

        thread::scope(|scope| {
            (0..3).for_each(|_| {
                scope.spawn(|| {
                    // Build a constant and register it in this thread's protection set.
                    let symbol = Symbol::new("test", 0);
                    let constant = ATerm::constant(&symbol);
                    let protected_term = constant.protect();

                    // While alive, its root is part of the protection set.
                    THREAD_TERM_POOL.with_borrow(|tp| {
                        assert!(tp.lock_protection_set().protection_set.contains_root(protected_term.root()));
                    });

                    // Dropping the term must remove its root again.
                    let root = protected_term.root();
                    drop(protected_term);

                    THREAD_TERM_POOL.with_borrow(|tp| {
                        assert!(!tp.lock_protection_set().protection_set.contains_root(root));
                    });
                });
            });
        });
    }

    #[test]
    fn test_parsing() {
        let _ = merc_utilities::test_logger();

        let term = ATerm::from_string("f(g(a),b)").unwrap();

        assert!(term.get_head_symbol().name() == "f");
        assert!(term.arg(0).get_head_symbol().name() == "g");
        assert!(term.arg(1).get_head_symbol().name() == "b");
    }

    #[test]
    fn test_create_term() {
        let _ = merc_utilities::test_logger();

        let f = Symbol::new("f", 2);
        let g = Symbol::new("g", 1);

        let term = THREAD_TERM_POOL.with_borrow(|tp| {
            let a = tp.create_constant(&Symbol::new("a", 0));
            let g_of_a = tp.create_term(&g, &[a]).protect();
            let b = tp.create_constant(&Symbol::new("b", 0));
            tp.create_term(&f, &[g_of_a, b]).protect()
        });

        assert!(term.get_head_symbol().name() == "f");
        assert!(term.arg(0).get_head_symbol().name() == "g");
        assert!(term.arg(1).get_head_symbol().name() == "b");
    }
}