1
#[cfg(debug_assertions)]
2
use std::cell::RefCell;
3
use std::fmt::Debug;
4
use std::hash::Hash;
5
use std::mem::transmute;
6
use std::ops::Deref;
7
use std::ops::DerefMut;
8
use std::sync::Arc;
9

            
10
use merc_unsafety::ProtectionIndex;
11
use merc_utilities::PhantomUnsend;
12

            
13
use crate::Markable;
14
use crate::Symb;
15
use crate::SymbolRef;
16
use crate::Term;
17
use crate::Transmutable;
18
use crate::aterm::ATermRef;
19
use crate::storage::GcMutex;
20
use crate::storage::GcMutexGuard;
21
use crate::storage::THREAD_TERM_POOL;
22

            
23
/// A container of objects, typically either terms or objects containing terms,
24
/// that implement [Markable]. These store [ATermRef]`<'static>` values that are
25
/// protected during garbage collection by being in the container itself.
26
pub struct Protected<C> {
27
    container: Arc<GcMutex<C>>,
28
    root: ProtectionIndex,
29

            
30
    // Protected is not Send because it uses thread-local state for its protection
31
    // mechanism.
32
    _unsend: PhantomUnsend,
33
}
34

            
35
impl<C: Markable + Send + Transmutable + 'static> Protected<C> {
36
    /// Creates a new Protected container from a given container.
37
3325363
    pub fn new(container: C) -> Protected<C> {
38
3325363
        let shared = Arc::new(GcMutex::new(container));
39

            
40
3325363
        let root = THREAD_TERM_POOL.with_borrow(|tp| tp.protect_container(shared.clone()));
41

            
42
3325363
        Protected {
43
3325363
            container: shared,
44
3325363
            root,
45
3325363
            _unsend: Default::default(),
46
3325363
        }
47
3325363
    }
48

            
49
    /// Provides mutable access to the underlying container, returning a [ProtectedWriteGuard].
50
38799739
    pub fn write(&mut self) -> ProtectedWriteGuard<'_, C> {
51
38799739
        ProtectedWriteGuard::new(self.container.lock())
52
38799739
    }
53

            
54
    /// Provides immutable access to the underlying container, returning a [ProtectedReadGuard].
55
27417505
    pub fn read(&self) -> ProtectedReadGuard<'_, C> {
56
27417505
        ProtectedReadGuard::new(self.container.lock())
57
27417505
    }
58
}
59

            
60
impl<C: Default + Markable + Send + Transmutable + 'static> Default for Protected<C> {
    /// Protects a default-constructed container.
    fn default() -> Self {
        Self::new(C::default())
    }
}
65

            
66
impl<C: Clone + Markable + Send + Transmutable + 'static> Clone for Protected<C> {
    /// Clones the inner container and wraps the copy in a new, separately
    /// registered [Protected]; the two values do not share storage.
    fn clone(&self) -> Self {
        Self::new(C::clone(&self.container.lock()))
    }
}
71

            
72
impl<C: Hash + Markable> Hash for Protected<C> {
    /// Hashes the contents of the container, not the handle.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let guard = self.container.lock();
        C::hash(&guard, state);
    }
}
77

            
78
impl<C: PartialEq + Markable> PartialEq for Protected<C> {
    /// Two protected containers are equal when their contents are equal.
    fn eq(&self, other: &Self) -> bool {
        let lhs = self.container.lock();
        let rhs = other.container.lock();
        *lhs == *rhs
    }
}
83

            
84
impl<C: PartialOrd + Markable> PartialOrd for Protected<C> {
    /// Orders protected containers by their contents.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Lock `other` first, then `self`.
        let rhs = other.container.lock();
        let lhs = self.container.lock();
        PartialOrd::partial_cmp(&*lhs, &*rhs)
    }
}
90

            
91
impl<C: Debug + Markable> Debug for Protected<C> {
    /// Formats the contents of the container using `C`'s [Debug]
    /// implementation, forwarding the caller's formatting options (for
    /// example `{:#?}` for pretty-printing).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let guard = self.container.lock();
        // Delegate with the same formatter: `write!(f, "{:?}", ..)` would start a
        // fresh format spec and silently drop flags such as `#` (alternate).
        Debug::fmt(&*guard, f)
    }
}
97

            
98
impl<C: Eq + PartialEq + Markable> Eq for Protected<C> {}
99
impl<C: Ord + PartialEq + PartialOrd + Markable> Ord for Protected<C> {
100
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
101
        let c: &C = &other.container.lock();
102
        self.container.lock().partial_cmp(c).unwrap()
103
    }
104
}
105

            
106
impl<C> Drop for Protected<C> {
107
3325363
    fn drop(&mut self) {
108
3325363
        THREAD_TERM_POOL.with_borrow(|tp| {
109
3325363
            tp.drop_container(self.root);
110
3325363
        });
111
3325363
    }
112
}
113

            
114
pub struct ProtectedWriteGuard<'a, C: Markable> {
115
    reference: GcMutexGuard<'a, C>,
116

            
117
    /// Terms that have been protected during the lifetime of this guard.
118
    #[cfg(debug_assertions)]
119
    protected: RefCell<Vec<ATermRef<'static>>>,
120

            
121
    /// Symbols that have been protected during the lifetime of this guard.
122
    #[cfg(debug_assertions)]
123
    protected_symbols: RefCell<Vec<SymbolRef<'static>>>,
124
}
125

            
126
impl<'a, C: Markable> ProtectedWriteGuard<'a, C> {
127
38799739
    fn new(reference: GcMutexGuard<'a, C>) -> Self {
128
        #[cfg(debug_assertions)]
129
38799739
        return ProtectedWriteGuard {
130
38799739
            reference,
131
38799739
            protected: RefCell::new(vec![]),
132
38799739
            protected_symbols: RefCell::new(vec![]),
133
38799739
        };
134

            
135
        #[cfg(not(debug_assertions))]
136
        return ProtectedWriteGuard { reference };
137
38799739
    }
138

            
139
    /// Yields a term to insert into the container.
140
    ///
141
    /// # Safety
142
    ///
143
    /// The invariant to uphold is that the resulting term MUST be inserted into
144
    /// the container. This is checked in debug mode, but not in release mode.
145
    /// If this invariant is violated, undefined behaviour may occur during
146
    /// garbage collection. We do not mark this function unsafe since that would
147
    /// make its use cumbersome.
148
26592719
    pub fn protect<'b, T: Term<'a, 'b>>(&self, term: &'b T) -> ATermRef<'static> {
149
        unsafe {
150
            // Store terms that are marked as protected to check if they are
151
            // actually in the container when the protection is dropped.
152
            #[cfg(debug_assertions)]
153
26592719
            self.protected
154
26592719
                .borrow_mut()
155
26592719
                .push(transmute::<ATermRef<'_>, ATermRef<'static>>(term.copy()));
156

            
157
26592719
            transmute::<ATermRef<'_>, ATermRef<'static>>(term.copy())
158
        }
159
26592719
    }
160

            
161
    /// Yields a symbol to insert into the container.
162
    ///
163
    /// The invariant to uphold is that the resulting symbol MUST be inserted
164
    /// into the container.
165
6266
    pub fn protect_symbol<'b, S: Symb<'a, 'b>>(&self, symbol: &'b S) -> SymbolRef<'static> {
166
        unsafe {
167
            // Store symbols that are marked as protected to check if they are
168
            // actually in the container when the protection is dropped.
169
            #[cfg(debug_assertions)]
170
6266
            self.protected_symbols
171
6266
                .borrow_mut()
172
6266
                .push(transmute::<SymbolRef<'_>, SymbolRef<'static>>(symbol.copy()));
173

            
174
6266
            transmute::<SymbolRef<'_>, SymbolRef<'static>>(symbol.copy())
175
        }
176
6266
    }
177
}
178

            
179
#[cfg(debug_assertions)]
180
impl<C: Markable> Drop for ProtectedWriteGuard<'_, C> {
181
38799739
    fn drop(&mut self) {
182
        {
183
38799739
            for term in self.protected.borrow().iter() {
184
26592719
                debug_assert!(
185
26592719
                    self.reference.contains_term(term),
186
                    "Term was protected but not actually inserted"
187
                );
188
            }
189

            
190
38799739
            for symbol in self.protected_symbols.borrow().iter() {
191
6266
                debug_assert!(
192
6266
                    self.reference.contains_symbol(symbol),
193
                    "Symbol was protected but not actually inserted"
194
                );
195
            }
196
        }
197
38799739
    }
198
}
199

            
200
impl<'a, C: Markable + Transmutable + 'a> Deref for ProtectedWriteGuard<'a, C> {
    type Target = C::Target<'a>;

    /// Exposes the locked container with its lifetime parameter rebound to
    /// the guard's lifetime `'a`.
    fn deref(&self) -> &Self::Target {
        let container: &C = &self.reference;
        container.transmute_lifetime()
    }
}
207

            
208
impl<C: Markable + Transmutable> DerefMut for ProtectedWriteGuard<'_, C> {
    /// Mutable counterpart of `deref`: exposes the locked container with its
    /// lifetime parameter rebound to the guard's lifetime.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let container: &mut C = &mut self.reference;
        container.transmute_lifetime_mut()
    }
}
213

            
214
pub struct ProtectedReadGuard<'a, C> {
215
    reference: GcMutexGuard<'a, C>,
216
}
217

            
218
impl<'a, C> ProtectedReadGuard<'a, C> {
219
27417505
    fn new(reference: GcMutexGuard<'a, C>) -> Self {
220
27417505
        Self { reference }
221
27417505
    }
222
}
223

            
224
impl<'a, C: Transmutable> Deref for ProtectedReadGuard<'a, C> {
    type Target = C::Target<'a>;

    /// Exposes the locked container with its lifetime parameter rebound to
    /// the guard's lifetime `'a`.
    fn deref(&self) -> &Self::Target {
        let container: &C = &self.reference;
        container.transmute_lifetime()
    }
}
231

            
232
#[cfg(test)]
mod tests {
    use crate::ATerm;
    use crate::ATermRef;
    use crate::Protected;

    #[test]
    fn test_aterm_container() {
        let _ = merc_utilities::test_logger();

        let t = ATerm::from_string("f(g(a),b)").unwrap();

        // First test the trait for a standard container.
        let mut container = Protected::<Vec<ATermRef<'static>>>::new(vec![]);

        for _ in 0..1000 {
            let mut write = container.write();
            write.push(t.get());
        }

        // Each write guard was dropped at the end of its iteration, so every
        // push must be visible through a fresh read guard.
        assert_eq!(container.read().len(), 1000);
    }
}