1
//! Authors: Maurice Laveaux, Flip van Spaendonck and Jan Friso Groote
2

            
3
use std::cell::Cell;
4
use std::error::Error;
5
use std::mem;
6
use std::ops::Deref;
7
use std::ops::DerefMut;
8

            
9
use crate::BfSharedMutex;
10
use crate::BfSharedMutexReadGuard;
11
use crate::BfSharedMutexWriteGuard;
12

            
13
/// An extension of the [BfSharedMutex] that allows recursive read locking without deadlocks.
14
pub struct RecursiveLock<T> {
15
    inner: BfSharedMutex<T>,
16

            
17
    /// The number of times the current thread has read locked the mutex.
18
    recursive_depth: Cell<usize>,
19

            
20
    /// The number of calls to the write() method.
21
    write_calls: Cell<usize>,
22

            
23
    /// The number of calls to the read_recursive() method.
24
    read_recursive_calls: Cell<usize>,
25
}
26

            
27
impl<T> RecursiveLock<T> {
28
    /// Creates a new `RecursiveLock` with the given data.
29
5
    pub fn new(data: T) -> Self {
30
5
        RecursiveLock {
31
5
            inner: BfSharedMutex::new(data),
32
5
            recursive_depth: Cell::new(0),
33
5
            write_calls: Cell::new(0),
34
5
            read_recursive_calls: Cell::new(0),
35
5
        }
36
5
    }
37

            
38
    /// Creates a new `RecursiveLock` from an existing `BfSharedMutex`.
39
963
    pub fn from_mutex(mutex: BfSharedMutex<T>) -> Self {
40
963
        RecursiveLock {
41
963
            inner: mutex,
42
963
            recursive_depth: Cell::new(0),
43
963
            write_calls: Cell::new(0),
44
963
            read_recursive_calls: Cell::new(0),
45
963
        }
46
963
    }
47

            
48
    delegate::delegate! {
49
        to self.inner {
50
            #[cfg(not(loom))]
51
            pub fn data_ptr(&self) -> *const T;
52
            #[cfg(loom)]
53
            pub fn data_ptr(&self) -> loom::cell::ConstPtr<T>;
54
917941
            pub fn is_locked(&self) -> bool;
55
            pub fn is_locked_exclusive(&self) -> bool;
56
        }
57
    }
58

            
59
    /// Acquires a write lock on the mutex.
60
390336
    pub fn write(&self) -> Result<RecursiveLockWriteGuard<'_, T>, Box<dyn Error + '_>> {
61
390336
        debug_assert!(
62
390336
            self.recursive_depth.get() == 0,
63
            "Cannot call write() inside a read section"
64
        );
65
390336
        self.write_calls.set(self.write_calls.get() + 1);
66
390336
        self.recursive_depth.set(1);
67
        Ok(RecursiveLockWriteGuard {
68
390336
            mutex: self,
69
390336
            guard: self.inner.write()?,
70
            #[cfg(loom)]
71
            ptr: self.inner.data_ptr(),
72
        })
73
390336
    }
74

            
75
    /// Acquires a read lock on the mutex.
76
1
    pub fn read(&self) -> Result<BfSharedMutexReadGuard<'_, T>, Box<dyn Error + '_>> {
77
1
        debug_assert!(
78
1
            self.recursive_depth.get() == 0,
79
            "Cannot call read() inside a read section"
80
        );
81
1
        self.inner.read()
82
1
    }
83

            
84
    /// Acquires a read lock on the mutex, allowing for recursive read locking.
85
2501409992
    pub fn read_recursive<'a>(&'a self) -> Result<RecursiveLockReadGuard<'a, T>, Box<dyn Error + 'a>> {
86
2501409992
        self.read_recursive_calls.set(self.read_recursive_calls.get() + 1);
87
2501409992
        if self.recursive_depth.get() == 0 {
88
            // If we are not already holding a read lock, we acquire one.
89
            // Acquire the read guard, but forget it to prevent it from being dropped.
90
1571091781
            self.recursive_depth.set(1);
91
1571091781
            mem::forget(self.inner.read()?);
92
1571091781
            Ok(RecursiveLockReadGuard {
93
1571091781
                mutex: self,
94
1571091781
                #[cfg(loom)]
95
1571091781
                ptr: self.inner.data_ptr(),
96
1571091781
            })
97
        } else {
98
            // If we are already holding a read lock, we just increment the depth.
99
930318211
            self.recursive_depth.set(self.recursive_depth.get() + 1);
100
930318211
            Ok(RecursiveLockReadGuard {
101
930318211
                mutex: self,
102
930318211
                #[cfg(loom)]
103
930318211
                ptr: self.inner.data_ptr(),
104
930318211
            })
105
        }
106
2501409992
    }
107

            
108
    /// Returns the number of times `write()` has been called.
109
9
    pub fn write_call_count(&self) -> usize {
110
9
        self.write_calls.get()
111
9
    }
112

            
113
    /// Returns the number of times `read_recursive()` has been called.
114
11
    pub fn read_recursive_call_count(&self) -> usize {
115
11
        self.read_recursive_calls.get()
116
11
    }
117
}
118

            
119
#[must_use = "Dropping the guard unlocks the recursive lock immediately"]
120
pub struct RecursiveLockReadGuard<'a, T> {
121
    mutex: &'a RecursiveLock<T>,
122

            
123
    #[cfg(loom)]
124
    ptr: loom::cell::ConstPtr<T>,
125
}
126

            
127
impl<T> RecursiveLockReadGuard<'_, T> {
128
    /// Returns the read depth of the recursive lock.
129
1524397918
    pub fn read_depth(&self) -> usize {
130
1524397918
        self.mutex.recursive_depth.get()
131
1524397918
    }
132
}
133

            
134
/// Allow dereferences the underlying object.
135
impl<T> Deref for RecursiveLockReadGuard<'_, T> {
136
    type Target = T;
137

            
138
91759284
    fn deref(&self) -> &Self::Target {
139
        // There can only be shared guards, which only provide immutable access to the object.
140
        #[cfg(not(loom))]
141
        unsafe {
142
91759284
            self.mutex.inner.data_ptr().as_ref().unwrap_unchecked()
143
        }
144

            
145
        #[cfg(loom)]
146
        unsafe {
147
            self.ptr.deref()
148
        }
149
91759284
    }
150
}
151

            
152
impl<T> Drop for RecursiveLockReadGuard<'_, T> {
153
2501409992
    fn drop(&mut self) {
154
2501409992
        self.mutex.recursive_depth.set(self.mutex.recursive_depth.get() - 1);
155
2501409992
        if self.mutex.recursive_depth.get() == 0 {
156
            // If we are not holding a read lock anymore, we release the mutex.
157
            // This will allow other threads to acquire a read lock.
158
1571091781
            unsafe {
159
1571091781
                // Drop the guard immediately to release busy=false via its Drop impl.
160
1571091781
                let _ = self.mutex.inner.create_read_guard_unchecked();
161
1571091781
            }
162
930318211
        }
163
2501409992
    }
164
}
165

            
166
#[must_use = "Dropping the guard unlocks the recursive lock immediately"]
167
pub struct RecursiveLockWriteGuard<'a, T> {
168
    mutex: &'a RecursiveLock<T>,
169

            
170
    guard: BfSharedMutexWriteGuard<'a, T>,
171

            
172
    #[cfg(loom)]
173
    ptr: loom::cell::ConstPtr<T>,
174
}
175

            
176
/// Allow dereferences the underlying object.
177
impl<T> Deref for RecursiveLockWriteGuard<'_, T> {
178
    type Target = T;
179

            
180
2887
    fn deref(&self) -> &Self::Target {
181
        // We hold the write guard, so immutable access is safe.
182
        #[cfg(loom)]
183
        unsafe {
184
            return self.ptr.deref();
185
        }
186

            
187
        #[cfg(not(loom))]
188
2887
        self.guard.deref()
189
2887
    }
190
}
191

            
192
/// Allow dereferences the underlying object.
193
impl<T> DerefMut for RecursiveLockWriteGuard<'_, T> {
194
390331
    fn deref_mut(&mut self) -> &mut Self::Target {
195
        // We hold the write guard exclusively, so mutable access is safe.
196
390331
        self.guard.deref_mut()
197
390331
    }
198
}
199

            
200
impl<T> Drop for RecursiveLockWriteGuard<'_, T> {
201
390336
    fn drop(&mut self) {
202
390336
        self.mutex.recursive_depth.set(self.mutex.recursive_depth.get() - 1);
203
390336
    }
204
}
205

            
206
#[cfg(test)]
mod tests {
    use crate::BfSharedMutex;
    use crate::RecursiveLock;

    #[test]
    fn test_from_mutex() {
        let mutex = BfSharedMutex::new(100);
        let lock = RecursiveLock::from_mutex(mutex);
        assert_eq!(*lock.read().unwrap(), 100);
    }

    #[test]
    fn test_single_recursive_read() {
        let lock = RecursiveLock::new(42);
        let guard = lock.read_recursive().unwrap();
        assert_eq!(*guard, 42);
        assert_eq!(lock.recursive_depth.get(), 1);
    }

    #[test]
    fn test_nested_recursive_reads() {
        let lock = RecursiveLock::new(42);

        let guard1 = lock.read_recursive().unwrap();
        assert_eq!(*guard1, 42);
        assert_eq!(lock.recursive_depth.get(), 1);

        let guard2 = lock.read_recursive().unwrap();
        assert_eq!(*guard2, 42);
        assert_eq!(lock.recursive_depth.get(), 2);

        let guard3 = lock.read_recursive().unwrap();
        assert_eq!(*guard3, 42);
        assert_eq!(lock.recursive_depth.get(), 3);

        drop(guard3);
        assert_eq!(lock.recursive_depth.get(), 2);

        drop(guard2);
        assert_eq!(lock.recursive_depth.get(), 1);

        drop(guard1);
        assert_eq!(lock.recursive_depth.get(), 0);
    }

    /// `read_depth()` reports the depth shared by all guards of the same lock.
    #[test]
    fn test_read_depth() {
        let lock = RecursiveLock::new(42);

        let outer = lock.read_recursive().unwrap();
        assert_eq!(outer.read_depth(), 1);

        let inner = lock.read_recursive().unwrap();
        assert_eq!(inner.read_depth(), 2);
        assert_eq!(outer.read_depth(), 2);

        drop(inner);
        assert_eq!(outer.read_depth(), 1);
    }

    /// A write section holds depth 1, restores depth 0 on drop, and its write is visible afterwards.
    #[test]
    fn test_write_resets_depth() {
        let lock = RecursiveLock::new(42);

        {
            let mut guard = lock.write().unwrap();
            *guard = 43;
            assert_eq!(lock.recursive_depth.get(), 1);
        }

        assert_eq!(lock.recursive_depth.get(), 0);
        assert_eq!(*lock.read_recursive().unwrap(), 43);
        assert_eq!(lock.recursive_depth.get(), 0);
    }

    #[test]
    fn test_write_call_counter() {
        let lock = RecursiveLock::new(42);

        // Initially, the counter should be 0
        assert_eq!(lock.write_call_count(), 0);

        // After one write call, counter should be 1
        {
            let _guard = lock.write().unwrap();
            assert_eq!(lock.write_call_count(), 1);
        }

        // After another write call, counter should be 2
        {
            let _guard = lock.write().unwrap();
            assert_eq!(lock.write_call_count(), 2);
        }

        // Counter should remain 2
        assert_eq!(lock.write_call_count(), 2);
    }

    #[test]
    fn test_read_recursive_call_counter() {
        let lock = RecursiveLock::new(42);

        // Initially, the counter should be 0
        assert_eq!(lock.read_recursive_call_count(), 0);

        // After one read_recursive call, counter should be 1
        {
            let _guard = lock.read_recursive().unwrap();
            assert_eq!(lock.read_recursive_call_count(), 1);
        }

        // After another read_recursive call, counter should be 2
        {
            let _guard = lock.read_recursive().unwrap();
            assert_eq!(lock.read_recursive_call_count(), 2);
        }

        // Test nested recursive reads increment the counter
        {
            let _guard1 = lock.read_recursive().unwrap();
            assert_eq!(lock.read_recursive_call_count(), 3);

            let _guard2 = lock.read_recursive().unwrap();
            assert_eq!(lock.read_recursive_call_count(), 4);
        }

        // Counter should remain 4
        assert_eq!(lock.read_recursive_call_count(), 4);
    }

    #[test]
    fn test_both_counters() {
        let lock = RecursiveLock::new(42);

        // Initially, both counters should be 0
        assert_eq!(lock.write_call_count(), 0);
        assert_eq!(lock.read_recursive_call_count(), 0);

        // Call write and check counters
        {
            let _guard = lock.write().unwrap();
            assert_eq!(lock.write_call_count(), 1);
            assert_eq!(lock.read_recursive_call_count(), 0);
        }

        // Call read_recursive and check counters
        {
            let _guard = lock.read_recursive().unwrap();
            assert_eq!(lock.write_call_count(), 1);
            assert_eq!(lock.read_recursive_call_count(), 1);
        }

        // Call write again
        {
            let _guard = lock.write().unwrap();
            assert_eq!(lock.write_call_count(), 2);
            assert_eq!(lock.read_recursive_call_count(), 1);
        }

        // Call read_recursive multiple times
        {
            let _guard1 = lock.read_recursive().unwrap();
            let _guard2 = lock.read_recursive().unwrap();
            assert_eq!(lock.write_call_count(), 2);
            assert_eq!(lock.read_recursive_call_count(), 3);
        }
    }
}