1
//! Authors: Maurice Laveaux, Flip van Spaendonck and Jan Friso Groote
2

            
3
use std::cell::Cell;
4
use std::error::Error;
5
use std::mem;
6
use std::ops::Deref;
7
use std::ops::DerefMut;
8

            
9
use crate::BfSharedMutex;
10
use crate::BfSharedMutexReadGuard;
11
use crate::BfSharedMutexWriteGuard;
12

            
13
/// An extension of the [BfSharedMutex] that allows recursive read locking without deadlocks.
///
/// Re-entrancy is tracked in `recursive_depth`: only the outermost recursive read
/// actually locks the underlying mutex; nested reads just bump the counter.
/// NOTE(review): the `Cell` fields make this type `!Sync`, so presumably each thread
/// owns its own `RecursiveLock` wrapper around a shared inner mutex — confirm.
14
pub struct RecursiveLock<T> {
15
    // The underlying shared mutex that performs the actual locking.
    inner: BfSharedMutex<T>,
16

            
17
    /// The number of times the current thread has read locked the mutex.
18
    recursive_depth: Cell<usize>,
19

            
20
    /// The number of calls to the write() method.
21
    write_calls: Cell<usize>,
22

            
23
    /// The number of calls to the read_recursive() method.
24
    read_recursive_calls: Cell<usize>,
25
}
26

            
27
impl<T> RecursiveLock<T> {
28
    /// Creates a new `RecursiveLock` with the given data.
29
5
    pub fn new(data: T) -> Self {
30
5
        RecursiveLock {
31
5
            inner: BfSharedMutex::new(data),
32
5
            recursive_depth: Cell::new(0),
33
5
            write_calls: Cell::new(0),
34
5
            read_recursive_calls: Cell::new(0),
35
5
        }
36
5
    }
37

            
38
    /// Creates a new `RecursiveLock` from an existing `BfSharedMutex`.
39
586
    pub fn from_mutex(mutex: BfSharedMutex<T>) -> Self {
40
586
        RecursiveLock {
41
586
            inner: mutex,
42
586
            recursive_depth: Cell::new(0),
43
586
            write_calls: Cell::new(0),
44
586
            read_recursive_calls: Cell::new(0),
45
586
        }
46
586
    }
47

            
48
    delegate::delegate! {
49
        to self.inner {
50
            pub fn data_ptr(&self) -> *const T;
51
140645
            pub fn is_locked(&self) -> bool;
52
            pub fn is_locked_exclusive(&self) -> bool;
53
        }
54
    }
55

            
56
    /// Acquires a write lock on the mutex.
57
29195
    pub fn write(&self) -> Result<RecursiveLockWriteGuard<'_, T>, Box<dyn Error + '_>> {
58
29195
        debug_assert!(
59
29195
            self.recursive_depth.get() == 0,
60
            "Cannot call write() inside a read section"
61
        );
62
29195
        self.write_calls.set(self.write_calls.get() + 1);
63
29195
        self.recursive_depth.set(1);
64
        Ok(RecursiveLockWriteGuard {
65
29195
            mutex: self,
66
29195
            guard: self.inner.write()?,
67
        })
68
29195
    }
69

            
70
    /// Acquires a read lock on the mutex.
71
1
    pub fn read(&self) -> Result<BfSharedMutexReadGuard<'_, T>, Box<dyn Error + '_>> {
72
1
        debug_assert!(
73
1
            self.recursive_depth.get() == 0,
74
            "Cannot call read() inside a read section"
75
        );
76
1
        self.inner.read()
77
1
    }
78

            
79
    /// Acquires a read lock on the mutex, allowing for recursive read locking.
80
1549627313
    pub fn read_recursive<'a>(&'a self) -> Result<RecursiveLockReadGuard<'a, T>, Box<dyn Error + 'a>> {
81
1549627313
        self.read_recursive_calls.set(self.read_recursive_calls.get() + 1);
82
1549627313
        if self.recursive_depth.get() == 0 {
83
            // If we are not already holding a read lock, we acquire one.
84
            // Acquire the read guard, but forget it to prevent it from being dropped.
85
966621249
            self.recursive_depth.set(1);
86
966621249
            mem::forget(self.inner.read());
87
966621249
            Ok(RecursiveLockReadGuard { mutex: self })
88
        } else {
89
            // If we are already holding a read lock, we just increment the depth.
90
583006064
            self.recursive_depth.set(self.recursive_depth.get() + 1);
91
583006064
            Ok(RecursiveLockReadGuard { mutex: self })
92
        }
93
1549627313
    }
94

            
95
    /// Returns the number of times `write()` has been called.
96
9
    pub fn write_call_count(&self) -> usize {
97
9
        self.write_calls.get()
98
9
    }
99

            
100
    /// Returns the number of times `read_recursive()` has been called.
101
11
    pub fn read_recursive_call_count(&self) -> usize {
102
11
        self.read_recursive_calls.get()
103
11
    }
104
}
105

            
106
#[must_use = "Dropping the guard unlocks the recursive lock immediately"]
107
pub struct RecursiveLockReadGuard<'a, T> {
108
    // The recursive lock this guard was obtained from; dropping the guard
    // decrements its `recursive_depth`.
    mutex: &'a RecursiveLock<T>,
109
}
110

            
111
impl<T> RecursiveLockReadGuard<'_, T> {
112
    /// Returns the read depth of the recursive lock.
113
915530830
    pub fn read_depth(&self) -> usize {
114
915530830
        self.mutex.recursive_depth.get()
115
915530830
    }
116
}
117

            
118
/// Allow dereferences the underlying object.
119
impl<T> Deref for RecursiveLockReadGuard<'_, T> {
120
    type Target = T;
121

            
122
57133993
    fn deref(&self) -> &Self::Target {
123
        // There can only be shared guards, which only provide immutable access to the object.
124
57133993
        unsafe { self.mutex.inner.data_ptr().as_ref().unwrap_unchecked() }
125
57133993
    }
126
}
127

            
128
impl<T> Drop for RecursiveLockReadGuard<'_, T> {
129
1549627313
    // Decrements the recursion depth; only the outermost guard releases the
    // underlying mutex (which was acquired and mem::forget-ten by the first
    // read_recursive() call).
    fn drop(&mut self) {
130
1549627313
        self.mutex.recursive_depth.set(self.mutex.recursive_depth.get() - 1);
131
1549627313
        if self.mutex.recursive_depth.get() == 0 {
132
            // Depth reached zero: release the read lock that the outermost
            // read_recursive() leaked, so writers can make progress again.
133
            // NOTE(review): presumably create_read_guard_unchecked() builds a
            // guard without re-acquiring, and dropping the temporary performs
            // the actual unlock — confirm against BfSharedMutex's docs.
134
966621249
            unsafe {
135
966621249
                self.mutex.inner.create_read_guard_unchecked();
136
966621249
            }
137
583006064
        }
138
1549627313
    }
139
}
140

            
141
#[must_use = "Dropping the guard unlocks the recursive lock immediately"]
142
pub struct RecursiveLockWriteGuard<'a, T> {
143
    // The recursive lock this guard belongs to; its depth is decremented on drop.
    mutex: &'a RecursiveLock<T>,
144
    // The exclusive guard on the underlying mutex; released when this guard drops.
    guard: BfSharedMutexWriteGuard<'a, T>,
145
}
146

            
147
/// Allow dereferences the underlying object.
148
impl<T> Deref for RecursiveLockWriteGuard<'_, T> {
149
    type Target = T;
150

            
151
1756
    fn deref(&self) -> &Self::Target {
152
        // There can only be shared guards, which only provide immutable access to the object.
153
1756
        self.guard.deref()
154
1756
    }
155
}
156

            
157
/// Allow dereferences the underlying object.
158
impl<T> DerefMut for RecursiveLockWriteGuard<'_, T> {
159
29190
    fn deref_mut(&mut self) -> &mut Self::Target {
160
        // There can only be shared guards, which only provide immutable access to the object.
161
29190
        self.guard.deref_mut()
162
29190
    }
163
}
164

            
165
impl<T> Drop for RecursiveLockWriteGuard<'_, T> {
166
29195
    fn drop(&mut self) {
167
29195
        self.mutex.recursive_depth.set(self.mutex.recursive_depth.get() - 1);
168
29195
    }
169
}
170

            
171
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_mutex() {
        let shared = BfSharedMutex::new(100);
        let lock = RecursiveLock::from_mutex(shared);
        assert_eq!(*lock.read().unwrap(), 100);
    }

    #[test]
    fn test_single_recursive_read() {
        let lock = RecursiveLock::new(42);

        let guard = lock.read_recursive().unwrap();
        assert_eq!(*guard, 42);
        assert_eq!(lock.recursive_depth.get(), 1);
    }

    #[test]
    fn test_nested_recursive_reads() {
        let lock = RecursiveLock::new(42);

        let outer = lock.read_recursive().unwrap();
        assert_eq!(*outer, 42);
        assert_eq!(lock.recursive_depth.get(), 1);

        let middle = lock.read_recursive().unwrap();
        assert_eq!(*middle, 42);
        assert_eq!(lock.recursive_depth.get(), 2);

        let inner = lock.read_recursive().unwrap();
        assert_eq!(*inner, 42);
        assert_eq!(lock.recursive_depth.get(), 3);

        // Dropping guards unwinds the depth one level at a time.
        drop(inner);
        assert_eq!(lock.recursive_depth.get(), 2);

        drop(middle);
        assert_eq!(lock.recursive_depth.get(), 1);

        drop(outer);
        assert_eq!(lock.recursive_depth.get(), 0);
    }

    #[test]
    fn test_write_call_counter() {
        let lock = RecursiveLock::new(42);

        // The counter starts at zero and advances by one per write() call.
        assert_eq!(lock.write_call_count(), 0);
        for expected in 1..=2 {
            let _guard = lock.write().unwrap();
            assert_eq!(lock.write_call_count(), expected);
        }

        // Dropping the guards does not change the counter.
        assert_eq!(lock.write_call_count(), 2);
    }

    #[test]
    fn test_read_recursive_call_counter() {
        let lock = RecursiveLock::new(42);

        // The counter starts at zero and advances by one per read_recursive() call.
        assert_eq!(lock.read_recursive_call_count(), 0);
        for expected in 1..=2 {
            let _guard = lock.read_recursive().unwrap();
            assert_eq!(lock.read_recursive_call_count(), expected);
        }

        // Nested recursive reads bump the counter as well.
        {
            let _outer = lock.read_recursive().unwrap();
            assert_eq!(lock.read_recursive_call_count(), 3);

            let _inner = lock.read_recursive().unwrap();
            assert_eq!(lock.read_recursive_call_count(), 4);
        }

        // Dropping the guards does not change the counter.
        assert_eq!(lock.read_recursive_call_count(), 4);
    }

    #[test]
    fn test_both_counters() {
        let lock = RecursiveLock::new(42);

        // Both counters start at zero.
        assert_eq!(lock.write_call_count(), 0);
        assert_eq!(lock.read_recursive_call_count(), 0);

        // write() only advances the write counter.
        {
            let _w = lock.write().unwrap();
            assert_eq!(lock.write_call_count(), 1);
            assert_eq!(lock.read_recursive_call_count(), 0);
        }

        // read_recursive() only advances the read counter.
        {
            let _r = lock.read_recursive().unwrap();
            assert_eq!(lock.write_call_count(), 1);
            assert_eq!(lock.read_recursive_call_count(), 1);
        }

        // A second write.
        {
            let _w = lock.write().unwrap();
            assert_eq!(lock.write_call_count(), 2);
            assert_eq!(lock.read_recursive_call_count(), 1);
        }

        // Two nested recursive reads.
        {
            let _r1 = lock.read_recursive().unwrap();
            let _r2 = lock.read_recursive().unwrap();
            assert_eq!(lock.write_call_count(), 2);
            assert_eq!(lock.read_recursive_call_count(), 3);
        }
    }
}