1
use std::cell::Cell;
2
use std::cell::RefCell;
3
use std::cell::UnsafeCell;
4
use std::ops::Deref;
5
use std::ops::DerefMut;
6
use std::sync::Arc;
7

            
8
use log::debug;
9

            
10
use merc_collections::ProtectionIndex;
11
use merc_pest_consume::Parser;
12
use merc_sharedmutex::RecursiveLock;
13
use merc_sharedmutex::RecursiveLockReadGuard;
14
use merc_unsafety::StablePointer;
15
use merc_utilities::MercError;
16
use merc_utilities::debug_trace;
17

            
18
use crate::Markable;
19
use crate::Return;
20
use crate::Rule;
21
use crate::Symb;
22
use crate::Symbol;
23
use crate::SymbolRef;
24
use crate::Term;
25
use crate::TermParser;
26
use crate::aterm::ATerm;
27
use crate::aterm::ATermRef;
28
use crate::storage::AGGRESSIVE_GC;
29
use crate::storage::GlobalTermPool;
30
use crate::storage::SharedTerm;
31
use crate::storage::SharedTermProtection;
32
use crate::storage::global_aterm_pool::GLOBAL_TERM_POOL;
33

            
34
thread_local! {
35
    /// Thread-specific term pool that manages protection sets.
36
    pub static THREAD_TERM_POOL: RefCell<ThreadTermPool> = RefCell::new(ThreadTermPool::new());
37
}
38

            
39
/// Per-thread term pool managing local protection sets.
40
pub struct ThreadTermPool {
41
    /// A reference to the protection set of this thread pool.
42
    protection_set: Arc<UnsafeCell<SharedTermProtection>>,
43

            
44
    /// The number of times termms have been created before garbage collection is triggered.
45
    garbage_collection_counter: Cell<usize>,
46

            
47
    /// A vector of terms that are used to store the arguments of a term for loopup.
48
    tmp_arguments: RefCell<Vec<ATermRef<'static>>>,
49

            
50
    /// A local view for the global term pool.
51
    term_pool: RecursiveLock<GlobalTermPool>,
52

            
53
    /// Copy of the default terms since thread local access is cheaper.
54
    int_symbol: SymbolRef<'static>,
55
    empty_list_symbol: SymbolRef<'static>,
56
    list_symbol: SymbolRef<'static>,
57
}
58

            
59
impl ThreadTermPool {
60
    /// Creates a new thread-local term pool.
61
535
    fn new() -> Self {
62
        // Register protection sets with global pool
63
535
        let term_pool: RecursiveLock<GlobalTermPool> = RecursiveLock::from_mutex(GLOBAL_TERM_POOL.share());
64

            
65
535
        let mut pool = term_pool.write().expect("Lock poisoned!");
66

            
67
535
        let protection_set = pool.register_thread_term_pool();
68
535
        let int_symbol = pool.get_int_symbol().copy();
69
535
        let empty_list_symbol = pool.get_empty_list_symbol().copy();
70
535
        let list_symbol = pool.get_list_symbol().copy();
71
535
        drop(pool);
72

            
73
        // Arbitrary value to trigger garbage collection
74
        Self {
75
535
            protection_set,
76
535
            garbage_collection_counter: Cell::new(if AGGRESSIVE_GC { 1 } else { 1000 }),
77
535
            tmp_arguments: RefCell::new(Vec::new()),
78
535
            int_symbol,
79
535
            empty_list_symbol,
80
535
            list_symbol,
81
535
            term_pool,
82
        }
83
535
    }
84

            
85
    /// Creates a term without arguments.
86
2142072
    pub fn create_constant(&self, symbol: &SymbolRef<'_>) -> ATerm {
87
2142072
        assert!(symbol.arity() == 0, "A constant should not have arity > 0");
88

            
89
2142072
        let empty_args: [ATermRef<'_>; 0] = [];
90
2142072
        let guard = self.term_pool.read_recursive().expect("Lock poisoned!");
91

            
92
2142072
        let (index, inserted) = guard.create_term_array(symbol, &empty_args);
93

            
94
2142072
        if inserted {
95
9540
            self.trigger_garbage_collection();
96
2132540
        }
97

            
98
2142072
        self.protect_guard(guard, &unsafe { ATermRef::from_index(&index) })
99
2142072
    }
100

            
101
    /// Create a term with the given arguments
102
1457120
    pub fn create_term<'a, 'b>(
103
1457120
        &self,
104
1457120
        symbol: &'b impl Symb<'a, 'b>,
105
1457120
        args: &'b [impl Term<'a, 'b>],
106
1457120
    ) -> Return<ATermRef<'static>> {
107
1457120
        let mut arguments = self.tmp_arguments.borrow_mut();
108

            
109
1457120
        arguments.clear();
110
2382902
        for arg in args {
111
2382902
            unsafe {
112
2382902
                arguments.push(ATermRef::from_index(arg.shared()));
113
2382902
            }
114
        }
115

            
116
1457120
        let guard = self.term_pool.read_recursive().expect("Lock poisoned!");
117

            
118
1457120
        let (index, inserted) = guard.create_term_array(symbol, &arguments);
119

            
120
1457120
        if inserted {
121
52786
            self.trigger_garbage_collection();
122
1404340
        }
123

            
124
        unsafe {
125
            // SAFETY: The guard is guaranteed to live as long as the returned term, since it is thread local and Return cannot be sended to other threads.
126
1457120
            Return::new(
127
1457120
                std::mem::transmute::<RecursiveLockReadGuard<'_, _>, RecursiveLockReadGuard<'static, _>>(guard),
128
1457120
                ATermRef::from_index(&index),
129
            )
130
        }
131
1457120
    }
132

            
133
    /// Create a term with the given index.
134
3741624
    pub fn create_int(&self, value: usize) -> ATerm {
135
3741624
        let guard = self.term_pool.read_recursive().expect("Lock poisoned!");
136

            
137
3741624
        let (index, inserted) = guard.create_int(value);
138

            
139
3741624
        if inserted {
140
1885
            self.trigger_garbage_collection();
141
3739739
        }
142

            
143
3741624
        self.protect_guard(guard, &unsafe { ATermRef::from_index(&index) })
144
3741624
    }
145

            
146
    /// Create a term with the given arguments given by the iterator.
147
60634
    pub fn create_term_iter<'a, 'b, 'c, 'd, I, T>(&self, symbol: &'b impl Symb<'a, 'b>, args: I) -> ATerm
148
60634
    where
149
60634
        I: IntoIterator<Item = T>,
150
60634
        T: Term<'c, 'd>,
151
    {
152
60634
        let mut arguments = self.tmp_arguments.borrow_mut();
153
60634
        arguments.clear();
154
89979
        for arg in args {
155
89979
            unsafe {
156
89979
                arguments.push(ATermRef::from_index(arg.shared()));
157
89979
            }
158
        }
159

            
160
60634
        let guard = self.term_pool.read_recursive().expect("Lock poisoned!");
161

            
162
60634
        let (index, inserted) = guard.create_term_array(symbol, &arguments);
163

            
164
60634
        if inserted {
165
25760
            self.trigger_garbage_collection();
166
34878
        }
167

            
168
60634
        self.protect_guard(guard, &unsafe { ATermRef::from_index(&index) })
169
60634
    }
170

            
171
    /// Create a term with the given arguments given by the iterator that is failable.
172
195247
    pub fn try_create_term_iter<'a, 'b, 'c, 'd, I, T>(
173
195247
        &self,
174
195247
        symbol: &'b impl Symb<'a, 'b>,
175
195247
        args: I,
176
195247
    ) -> Result<ATerm, MercError>
177
195247
    where
178
195247
        I: IntoIterator<Item = Result<T, MercError>>,
179
195247
        T: Term<'c, 'd>,
180
    {
181
195247
        let mut arguments = self.tmp_arguments.borrow_mut();
182
195247
        arguments.clear();
183
199223
        for arg in args {
184
            unsafe {
185
199223
                arguments.push(ATermRef::from_index(arg?.shared()));
186
            }
187
        }
188

            
189
195247
        let guard = self.term_pool.read_recursive().expect("Lock poisoned!");
190

            
191
195247
        let (index, inserted) = guard.create_term_array(symbol, &arguments);
192

            
193
195247
        if inserted {
194
1845
            self.trigger_garbage_collection();
195
194044
        }
196

            
197
195247
        Ok(self.protect_guard(guard, &unsafe { ATermRef::from_index(&index) }))
198
195247
    }
199

            
200
    /// Create a term with the given arguments given by the iterator.
201
4324541
    pub fn create_term_iter_head<'a, 'b, 'c, 'd, 'e, 'f, I, T>(
202
4324541
        &self,
203
4324541
        symbol: &'b impl Symb<'a, 'b>,
204
4324541
        head: &'d impl Term<'c, 'd>,
205
4324541
        args: I,
206
4324541
    ) -> ATerm
207
4324541
    where
208
4324541
        I: IntoIterator<Item = T>,
209
4324541
        T: Term<'e, 'f>,
210
    {
211
4324541
        let mut arguments = self.tmp_arguments.borrow_mut();
212
4324541
        arguments.clear();
213
4324541
        unsafe {
214
4324541
            arguments.push(ATermRef::from_index(head.shared()));
215
4324541
        }
216
6821611
        for arg in args {
217
6821611
            unsafe {
218
6821611
                arguments.push(ATermRef::from_index(arg.shared()));
219
6821611
            }
220
        }
221

            
222
4324541
        let guard = self.term_pool.read_recursive().expect("Lock poisoned!");
223

            
224
4324541
        let (index, inserted) = guard.create_term_array(symbol, &arguments);
225

            
226
4324541
        if inserted {
227
54298
            self.trigger_garbage_collection();
228
4270245
        }
229

            
230
4324541
        self.protect_guard(guard, &unsafe { ATermRef::from_index(&index) })
231
4324541
    }
232

            
233
    /// Create a function symbol
234
3259254
    pub fn create_symbol(&self, name: impl Into<String> + AsRef<str>, arity: usize) -> Symbol {
235
3259254
        self.term_pool
236
3259254
            .read_recursive()
237
3259254
            .expect("Lock poisoned!")
238
3259254
            .create_symbol(name, arity, |index| unsafe {
239
3259254
                self.protect_symbol(&SymbolRef::from_index(&index))
240
3259254
            })
241
3259254
    }
242

            
243
    /// Protect the term by adding its index to the protection set
244
354378654
    pub fn protect(&self, term: &ATermRef<'_>) -> ATerm {
245
        // Protect the term by adding its index to the protection set
246
354378654
        let root = self.lock_protection_set().protection_set.protect(term.shared().copy());
247

            
248
        // Return the protected term
249
354378654
        let result = ATerm::from_index(term.shared(), root);
250

            
251
354378654
        debug_trace!(
252
            "Protected term {:?}, root {}, protection set {}",
253
            term,
254
            root,
255
            self.index()
256
        );
257

            
258
354378654
        result
259
354378654
    }
260

            
261
    /// Protect the term by adding its index to the protection set
262
50707120
    pub fn protect_guard(&self, _guard: RecursiveLockReadGuard<'_, GlobalTermPool>, term: &ATermRef<'_>) -> ATerm {
263
        // Protect the term by adding its index to the protection set
264
        // SAFETY: If the global term pool is locked, so we can safely access the protection set.
265
50707120
        let root = unsafe { &mut *self.protection_set.get() }
266
50707120
            .protection_set
267
50707120
            .protect(term.shared().copy());
268

            
269
        // Return the protected term
270
50707120
        let result = ATerm::from_index(term.shared(), root);
271

            
272
50707120
        debug_trace!(
273
            "Protected term {:?}, root {}, protection set {}",
274
            term,
275
            root,
276
            self.index()
277
        );
278

            
279
50707120
        result
280
50707120
    }
281

            
282
    /// Unprotects a term from this thread's protection set.
283
405085774
    pub fn drop(&self, term: &ATerm) {
284
405085774
        self.lock_protection_set().protection_set.unprotect(term.root());
285

            
286
405085774
        debug_trace!(
287
            "Unprotected term {:?}, root {}, protection set {}",
288
            term,
289
            term.root(),
290
            self.index()
291
        );
292
405085774
    }
293

            
294
    /// Protects a container in this thread's container protection set.
295
33271991
    pub fn protect_container(&self, container: Arc<dyn Markable + Send + Sync>) -> ProtectionIndex {
296
33271991
        let root = self.lock_protection_set().container_protection_set.protect(container);
297

            
298
33271991
        debug_trace!("Protected container index {}, protection set {}", root, self.index());
299

            
300
33271991
        root
301
33271991
    }
302

            
303
    /// Unprotects a container from this thread's container protection set.
304
33271991
    pub fn drop_container(&self, root: ProtectionIndex) {
305
33271991
        self.lock_protection_set().container_protection_set.unprotect(root);
306

            
307
33271991
        debug_trace!("Unprotected container index {}, protection set {}", root, self.index());
308
33271991
    }
309

            
310
    /// Parse the given string and returns the Term representation.
311
802
    pub fn from_string(&self, text: &str) -> Result<ATerm, MercError> {
312
802
        let mut result = TermParser::parse(Rule::TermSpec, text)?;
313
802
        let root = result.next().unwrap();
314

            
315
802
        Ok(TermParser::TermSpec(root).unwrap())
316
802
    }
317

            
318
    /// Protects a symbol from garbage collection.
319
5220164
    pub fn protect_symbol(&self, symbol: &SymbolRef<'_>) -> Symbol {
320
5220164
        let result = unsafe {
321
5220164
            Symbol::from_index(
322
5220164
                symbol.shared(),
323
5220164
                self.lock_protection_set()
324
5220164
                    .symbol_protection_set
325
5220164
                    .protect(symbol.shared().copy()),
326
            )
327
        };
328

            
329
5220164
        debug_trace!(
330
            "Protected symbol {}, root {}, protection set {}",
331
            symbol,
332
            result.root(),
333
            lock.index,
334
        );
335

            
336
5220164
        result
337
5220164
    }
338

            
339
    /// Unprotects a symbol, allowing it to be garbage collected.
340
5217864
    pub fn drop_symbol(&self, symbol: &mut Symbol) {
341
5217864
        self.lock_protection_set()
342
5217864
            .symbol_protection_set
343
5217864
            .unprotect(symbol.root());
344
5217864
    }
345

            
346
    /// Returns the symbol for ATermInt
347
686065160
    pub fn int_symbol(&self) -> &SymbolRef<'_> {
348
686065160
        &self.int_symbol
349
686065160
    }
350

            
351
    /// Returns the symbol for ATermList
352
974055
    pub fn list_symbol(&self) -> &SymbolRef<'_> {
353
974055
        &self.list_symbol
354
974055
    }
355

            
356
    /// Returns the symbol for the empty ATermInt
357
982489
    pub fn empty_list_symbol(&self) -> &SymbolRef<'_> {
358
982489
        &self.empty_list_symbol
359
982489
    }
360

            
361
    /// Enables or disables automatic garbage collection.
362
    pub fn automatic_garbage_collection(&self, enabled: bool) {
363
        let mut guard = self.term_pool.write().expect("Lock poisoned!");
364
        guard.automatic_garbage_collection(enabled);
365
    }
366

            
367
    /// Returns access to the shared protection set.
368
    pub(crate) fn get_protection_set(&self) -> &Arc<UnsafeCell<SharedTermProtection>> {
369
        &self.protection_set
370
    }
371

            
372
    /// Returns a reference to the global term pool.
373
641756041
    pub(crate) fn term_pool(&self) -> &RecursiveLock<GlobalTermPool> {
374
641756041
        &self.term_pool
375
641756041
    }
376

            
377
    /// Replace the entry in the protection set with the given term.
378
    pub(crate) fn replace(
379
        &self,
380
        _guard: RecursiveLockReadGuard<'_, GlobalTermPool>,
381
        root: ProtectionIndex,
382
        term: StablePointer<SharedTerm>,
383
    ) {
384
        // Protect the term by adding its index to the protection set
385
        // SAFETY: If the global term pool is locked, so we can safely access the protection set.
386
        unsafe { &mut *self.protection_set.get() }
387
            .protection_set
388
            .replace(root, term);
389
    }
390

            
391
    /// This triggers the global garbage collection based on heuristics.
392
562885
    fn trigger_garbage_collection(&self) {
393
        // If the term was newly inserted, decrease the garbage collection counter and trigger garbage collection if necessary
394
562885
        let mut value = self.garbage_collection_counter.get();
395
562885
        value = value.saturating_sub(1);
396

            
397
562885
        if value == 0 && !self.term_pool.is_locked() {
398
            // Trigger garbage collection and acquire a new counter value.
399
            value = self
400
                .term_pool
401
                .write()
402
                .expect("Lock poisoned!")
403
                .trigger_garbage_collection();
404
562885
        }
405

            
406
562885
        self.garbage_collection_counter.set(value);
407
562885
    }
408

            
409
    /// Returns the index of the protection set.
410
535
    fn index(&self) -> usize {
411
535
        self.lock_protection_set().index
412
535
    }
413

            
414
    /// The protection set is locked by the global read-write lock
415
836446979
    fn lock_protection_set(&self) -> ProtectionSetGuard<'_> {
416
836446979
        let guard = self.term_pool.read_recursive().expect("Lock poisoned!");
417
836446979
        let protection_set = unsafe { &mut *self.protection_set.get() };
418

            
419
836446979
        ProtectionSetGuard::new(guard, protection_set)
420
836446979
    }
421
}
422

            
423
impl Drop for ThreadTermPool {
424
535
    fn drop(&mut self) {
425
535
        let mut write = self.term_pool.write().expect("Lock poisoned!");
426

            
427
535
        debug!("{}", write.metrics());
428
535
        write.deregister_thread_pool(self.index());
429

            
430
535
        debug!("{}", unsafe { &mut *self.protection_set.get() }.metrics());
431
535
        debug!(
432
            "Acquired {} read locks and {} write locks",
433
            self.term_pool.read_recursive_call_count(),
434
            self.term_pool.write_call_count()
435
        )
436
535
    }
437
}
438

            
439
struct ProtectionSetGuard<'a> {
440
    _guard: RecursiveLockReadGuard<'a, GlobalTermPool>,
441
    object: &'a mut SharedTermProtection,
442
}
443

            
444
impl ProtectionSetGuard<'_> {
445
836446979
    fn new<'a>(
446
836446979
        guard: RecursiveLockReadGuard<'a, GlobalTermPool>,
447
836446979
        object: &'a mut SharedTermProtection,
448
836446979
    ) -> ProtectionSetGuard<'a> {
449
836446979
        ProtectionSetGuard { _guard: guard, object }
450
836446979
    }
451
}
452

            
453
impl Deref for ProtectionSetGuard<'_> {
454
    type Target = SharedTermProtection;
455

            
456
541
    fn deref(&self) -> &Self::Target {
457
541
        self.object
458
541
    }
459
}
460

            
461
impl DerefMut for ProtectionSetGuard<'_> {
462
836446438
    fn deref_mut(&mut self) -> &mut Self::Target {
463
836446438
        self.object
464
836446438
    }
465
}
466

            
467
#[cfg(test)]
mod tests {
    use crate::Term;

    use super::*;
    use std::thread;

    /// A term protected on a worker thread shows up in that thread's
    /// protection set and disappears from it once the protection is dropped.
    #[test]
    fn test_thread_local_protection() {
        let _ = merc_utilities::test_logger();

        thread::scope(|scope| {
            (0..3).for_each(|_| {
                scope.spawn(|| {
                    // Create a constant and take an explicit protection for it.
                    let test_symbol = Symbol::new("test", 0);
                    let constant = ATerm::constant(&test_symbol);
                    let protected = constant.protect();

                    // The new root must be registered in this thread's protection set.
                    THREAD_TERM_POOL.with_borrow(|pool| {
                        let present = pool.lock_protection_set().protection_set.contains_root(protected.root());
                        assert!(present);
                    });

                    // Dropping the protection must remove that root again.
                    let released_root = protected.root();
                    drop(protected);

                    THREAD_TERM_POOL.with_borrow(|pool| {
                        let present = pool.lock_protection_set().protection_set.contains_root(released_root);
                        assert!(!present);
                    });
                });
            });
        });
    }

    /// Parsing `f(g(a),b)` yields the expected head symbols.
    #[test]
    fn test_parsing() {
        let _ = merc_utilities::test_logger();

        let parsed = ATerm::from_string("f(g(a),b)").unwrap();

        assert!(parsed.get_head_symbol().name() == "f");
        assert!(parsed.arg(0).get_head_symbol().name() == "g");
        assert!(parsed.arg(1).get_head_symbol().name() == "b");
    }

    /// Building `f(g(a),b)` through the thread pool yields the expected head symbols.
    #[test]
    fn test_create_term() {
        let _ = merc_utilities::test_logger();

        let f = Symbol::new("f", 2);
        let g = Symbol::new("g", 1);

        let built = THREAD_TERM_POOL.with_borrow(|pool| {
            pool.create_term(
                &f,
                &[
                    pool.create_term(&g, &[pool.create_constant(&Symbol::new("a", 0))])
                        .protect(),
                    pool.create_constant(&Symbol::new("b", 0)),
                ],
            )
            .protect()
        });

        assert!(built.get_head_symbol().name() == "f");
        assert!(built.arg(0).get_head_symbol().name() == "g");
        assert!(built.arg(1).get_head_symbol().name() == "b");
    }
}