//! Authors: Maurice Laveaux, Flip van Spaendonck and Jan Friso Groote

use std::cell::Cell;
use std::error::Error;
use std::mem;
use std::ops::Deref;
use std::ops::DerefMut;

use crate::BfSharedMutex;
use crate::BfSharedMutexReadGuard;
use crate::BfSharedMutexWriteGuard;

/// An extension of the [BfSharedMutex] that allows recursive read locking without deadlocks.
pub struct RecursiveLock<T> {
    inner: BfSharedMutex<T>,

    /// The number of times the current thread has read locked the mutex.
    recursive_depth: Cell<usize>,

    /// The number of calls to the write() method.
    write_calls: Cell<usize>,

    /// The number of calls to the read_recursive() method.
    read_recursive_calls: Cell<usize>,
}

impl<T> RecursiveLock<T> {
    /// Creates a new `RecursiveLock` with the given data.
    pub fn new(data: T) -> Self {
        RecursiveLock {
            inner: BfSharedMutex::new(data),
            recursive_depth: Cell::new(0),
            write_calls: Cell::new(0),
            read_recursive_calls: Cell::new(0),
        }
    }

    /// Creates a new `RecursiveLock` from an existing `BfSharedMutex`.
    pub fn from_mutex(mutex: BfSharedMutex<T>) -> Self {
        RecursiveLock {
            inner: mutex,
            recursive_depth: Cell::new(0),
            write_calls: Cell::new(0),
            read_recursive_calls: Cell::new(0),
        }
    }

    delegate::delegate! {
        to self.inner {
            pub fn data_ptr(&self) -> *const T;
            pub fn is_locked(&self) -> bool;
            pub fn is_locked_exclusive(&self) -> bool;
        }
    }

    /// Acquires a write lock on the mutex.
    pub fn write(&self) -> Result<RecursiveLockWriteGuard<'_, T>, Box<dyn Error + '_>> {
        debug_assert!(
            self.recursive_depth.get() == 0,
            "Cannot call write() inside a read section"
        );
        self.write_calls.set(self.write_calls.get() + 1);
        // Acquire the underlying write guard first, so that a failed acquisition does
        // not leave the recursion depth out of sync.
        let guard = self.inner.write()?;
        self.recursive_depth.set(1);
        Ok(RecursiveLockWriteGuard { mutex: self, guard })
    }

    /// Acquires a read lock on the mutex.
    pub fn read(&self) -> Result<BfSharedMutexReadGuard<'_, T>, Box<dyn Error + '_>> {
        debug_assert!(
            self.recursive_depth.get() == 0,
            "Cannot call read() inside a read section"
        );
        self.inner.read()
    }

    /// Acquires a read lock on the mutex, allowing for recursive read locking.
    pub fn read_recursive<'a>(&'a self) -> Result<RecursiveLockReadGuard<'a, T>, Box<dyn Error + 'a>> {
        self.read_recursive_calls.set(self.read_recursive_calls.get() + 1);
        if self.recursive_depth.get() == 0 {
            // We are not already holding a read lock, so acquire one. The read guard is
            // forgotten on purpose to prevent it from being dropped; the matching unlock
            // happens when the last `RecursiveLockReadGuard` is dropped.
            mem::forget(self.inner.read()?);
            self.recursive_depth.set(1);
            Ok(RecursiveLockReadGuard { mutex: self })
        } else {
            // We are already holding a read lock, so we just increment the depth.
            self.recursive_depth.set(self.recursive_depth.get() + 1);
            Ok(RecursiveLockReadGuard { mutex: self })
        }
    }

    /// Returns the number of times `write()` has been called.
    pub fn write_call_count(&self) -> usize {
        self.write_calls.get()
    }

    /// Returns the number of times `read_recursive()` has been called.
    pub fn read_recursive_call_count(&self) -> usize {
        self.read_recursive_calls.get()
    }
}

#[must_use = "Dropping the guard unlocks the recursive lock immediately"]
pub struct RecursiveLockReadGuard<'a, T> {
    mutex: &'a RecursiveLock<T>,
}

/// Allows dereferencing the underlying object.
impl<T> Deref for RecursiveLockReadGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        // There can only be shared guards, which only provide immutable access to the object.
        unsafe { self.mutex.inner.data_ptr().as_ref().unwrap_unchecked() }
    }
}

impl<T> Drop for RecursiveLockReadGuard<'_, T> {
    fn drop(&mut self) {
        self.mutex.recursive_depth.set(self.mutex.recursive_depth.get() - 1);
        if self.mutex.recursive_depth.get() == 0 {
            // This was the outermost recursive read guard, so release the shared lock that
            // was forgotten in `read_recursive()`: recreate the underlying read guard and
            // drop it immediately. This allows other threads to acquire an exclusive lock.
            unsafe {
                self.mutex.inner.create_read_guard_unchecked();
            }
        }
    }
}

#[must_use = "Dropping the guard unlocks the recursive lock immediately"]
pub struct RecursiveLockWriteGuard<'a, T> {
    mutex: &'a RecursiveLock<T>,
    guard: BfSharedMutexWriteGuard<'a, T>,
}

/// Allows dereferencing the underlying object.
impl<T> Deref for RecursiveLockWriteGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        // The write guard holds exclusive access, so shared access to the object is safe.
        self.guard.deref()
    }
}

/// Allows mutably dereferencing the underlying object.
impl<T> DerefMut for RecursiveLockWriteGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // The write guard holds exclusive access, so handing out a mutable reference is safe.
        self.guard.deref_mut()
    }
}

impl<T> Drop for RecursiveLockWriteGuard<'_, T> {
    fn drop(&mut self) {
        self.mutex.recursive_depth.set(self.mutex.recursive_depth.get() - 1);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_mutex() {
        let mutex = BfSharedMutex::new(100);
        let lock = RecursiveLock::from_mutex(mutex);
        assert_eq!(*lock.read().unwrap(), 100);
    }
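
    // A small sketch of the delegated `data_ptr()` accessor. It exposes the same pointer
    // that `RecursiveLockReadGuard::deref` reads through, so dereferencing it while a
    // read guard is held should be sound.
    #[test]
    fn test_data_ptr_matches_guarded_value() {
        let lock = RecursiveLock::new(7);
        let guard = lock.read_recursive().unwrap();
        assert_eq!(unsafe { *lock.data_ptr() }, *guard);
    }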

    #[test]
    fn test_single_recursive_read() {
        let lock = RecursiveLock::new(42);
        let guard = lock.read_recursive().unwrap();
        assert_eq!(*guard, 42);
        assert_eq!(lock.recursive_depth.get(), 1);
    }
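
    // An illustrative sketch of why `read_recursive()` exists: a helper that read locks
    // the same `RecursiveLock` can be called while the caller already holds a read guard,
    // without deadlocking. The helper `sum_under_read` is hypothetical.
    #[test]
    fn test_recursive_read_through_helper() {
        fn sum_under_read(lock: &RecursiveLock<Vec<usize>>) -> usize {
            // Takes its own recursive read guard, even if the caller holds one already.
            let guard = lock.read_recursive().unwrap();
            guard.iter().sum()
        }

        let lock = RecursiveLock::new(vec![1, 2, 3]);
        let outer = lock.read_recursive().unwrap();
        assert_eq!(lock.recursive_depth.get(), 1);

        // The nested acquisition only bumps the depth; the underlying shared mutex is
        // locked once.
        assert_eq!(sum_under_read(&lock), 6);
        assert_eq!(lock.recursive_depth.get(), 1);

        drop(outer);
        assert_eq!(lock.recursive_depth.get(), 0);
    }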

    #[test]
    fn test_nested_recursive_reads() {
        let lock = RecursiveLock::new(42);

        let guard1 = lock.read_recursive().unwrap();
        assert_eq!(*guard1, 42);
        assert_eq!(lock.recursive_depth.get(), 1);

        let guard2 = lock.read_recursive().unwrap();
        assert_eq!(*guard2, 42);
        assert_eq!(lock.recursive_depth.get(), 2);

        let guard3 = lock.read_recursive().unwrap();
        assert_eq!(*guard3, 42);
        assert_eq!(lock.recursive_depth.get(), 3);

        drop(guard3);
        assert_eq!(lock.recursive_depth.get(), 2);

        drop(guard2);
        assert_eq!(lock.recursive_depth.get(), 1);

        drop(guard1);
        assert_eq!(lock.recursive_depth.get(), 0);
    }
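
    // A minimal sketch of the unlock path: once every recursive read guard has been
    // dropped, the shared lock taken (and forgotten) by the first `read_recursive()` is
    // released again, so a write lock can be acquired and its update is visible to
    // later readers.
    #[test]
    fn test_write_after_recursive_reads_released() {
        let lock = RecursiveLock::new(42);

        {
            let guard1 = lock.read_recursive().unwrap();
            let guard2 = lock.read_recursive().unwrap();
            assert_eq!(*guard1 + *guard2, 84);
        }
        assert_eq!(lock.recursive_depth.get(), 0);

        // With no read guards alive, taking the write lock succeeds.
        {
            let mut guard = lock.write().unwrap();
            *guard = 43;
        }

        assert_eq!(*lock.read_recursive().unwrap(), 43);
    }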

    #[test]
    fn test_write_call_counter() {
        let lock = RecursiveLock::new(42);

        // Initially, the counter should be 0
        assert_eq!(lock.write_call_count(), 0);

        // After one write call, counter should be 1
        {
            let _guard = lock.write().unwrap();
            assert_eq!(lock.write_call_count(), 1);
        }

        // After another write call, counter should be 2
        {
            let _guard = lock.write().unwrap();
            assert_eq!(lock.write_call_count(), 2);
        }

        // Counter should remain 2
        assert_eq!(lock.write_call_count(), 2);
    }
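
    // A small sketch of the write guard's bookkeeping: `write()` records a recursion
    // depth of 1 while the guard is alive (this is what the `debug_assert!`s in
    // `write()` and `read()` check against), and dropping the guard resets it to 0.
    #[test]
    fn test_write_guard_resets_depth() {
        let lock = RecursiveLock::new(42);
        assert_eq!(lock.recursive_depth.get(), 0);

        {
            let _guard = lock.write().unwrap();
            assert_eq!(lock.recursive_depth.get(), 1);
        }

        // Dropping the write guard decrements the depth back to 0.
        assert_eq!(lock.recursive_depth.get(), 0);
    }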

    #[test]
    fn test_read_recursive_call_counter() {
        let lock = RecursiveLock::new(42);

        // Initially, the counter should be 0
        assert_eq!(lock.read_recursive_call_count(), 0);

        // After one read_recursive call, counter should be 1
        {
            let _guard = lock.read_recursive().unwrap();
            assert_eq!(lock.read_recursive_call_count(), 1);
        }

        // After another read_recursive call, counter should be 2
        {
            let _guard = lock.read_recursive().unwrap();
            assert_eq!(lock.read_recursive_call_count(), 2);
        }

        // Test nested recursive reads increment the counter
        {
            let _guard1 = lock.read_recursive().unwrap();
            assert_eq!(lock.read_recursive_call_count(), 3);

            let _guard2 = lock.read_recursive().unwrap();
            assert_eq!(lock.read_recursive_call_count(), 4);
        }

        // Counter should remain 4
        assert_eq!(lock.read_recursive_call_count(), 4);
    }

    #[test]
    fn test_both_counters() {
        let lock = RecursiveLock::new(42);

        // Initially, both counters should be 0
        assert_eq!(lock.write_call_count(), 0);
        assert_eq!(lock.read_recursive_call_count(), 0);

        // Call write and check counters
        {
            let _guard = lock.write().unwrap();
            assert_eq!(lock.write_call_count(), 1);
            assert_eq!(lock.read_recursive_call_count(), 0);
        }

        // Call read_recursive and check counters
        {
            let _guard = lock.read_recursive().unwrap();
            assert_eq!(lock.write_call_count(), 1);
            assert_eq!(lock.read_recursive_call_count(), 1);
        }

        // Call write again
        {
            let _guard = lock.write().unwrap();
            assert_eq!(lock.write_call_count(), 2);
            assert_eq!(lock.read_recursive_call_count(), 1);
        }

        // Call read_recursive multiple times
        {
            let _guard1 = lock.read_recursive().unwrap();
            let _guard2 = lock.read_recursive().unwrap();
            assert_eq!(lock.write_call_count(), 2);
            assert_eq!(lock.read_recursive_call_count(), 3);
        }
    }
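
    // A minimal sketch of reading and writing through the same write guard: `Deref`
    // gives shared access and `DerefMut` mutable access to the protected value.
    #[test]
    fn test_write_guard_deref_and_deref_mut() {
        let lock = RecursiveLock::new(vec![1, 2, 3]);

        {
            let mut guard = lock.write().unwrap();
            guard.push(4);
            assert_eq!(guard.len(), 4);
        }

        assert_eq!(*lock.read().unwrap(), vec![1, 2, 3, 4]);
    }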
}