1
use std::alloc::Layout;
2
use std::array;
3
use std::cell::RefCell;
4
use std::fmt;
5
use std::mem::ManuallyDrop;
6
use std::ptr::NonNull;
7

            
8
use allocator_api2::alloc::AllocError;
9
use allocator_api2::alloc::Allocator;
10
use itertools::Itertools;
11

            
12
/// This is a slab allocator or also called block allocator for a concrete type
13
/// `T`. It stores blocks of `Size` to minimize the overhead of individual
14
/// memory allocations (which are typically in the range of one or two words).
15
///
16
/// Behaves like `Allocator`, except that it only allocates for layouts of `T`.
17
///
18
/// # Details
19
///
20
/// Internally stores blocks of `N` elements
21
pub struct BlockAllocator<T, const N: usize> {
22
    /// This is the block that contains unoccupied entries.
23
    head_block: Option<Box<Block<T, N>>>,
24

            
25
    /// The start of the freelist
26
    free: Option<NonNull<Entry<T>>>,
27
}
28

            
29
impl<T, const N: usize> Default for BlockAllocator<T, N> {
30
    fn default() -> Self {
31
        Self::new()
32
    }
33
}
34

            
35
impl<T, const N: usize> BlockAllocator<T, N> {
36
633
    pub fn new() -> Self {
37
633
        Self {
38
633
            head_block: None,
39
633
            free: None,
40
633
        }
41
633
    }
42

            
43
    /// Similar to the [Allocator] trait, but instead of passing a layout we allocate just an object of type `T`.
44
100000
    pub fn allocate_object(&mut self) -> Result<NonNull<T>, AllocError> {
45
100000
        if let Some(free) = self.free {
46
            unsafe {
47
                // Safety: By invariant of the freelist the next must point to the next free element.
48
                self.free = Some(free.as_ref().next);
49
            }
50
            return Ok(free.cast::<T>());
51
100000
        }
52

            
53
        // After this the block definitely has space for at least one element
54
100000
        let block = match &mut self.head_block {
55
99900
            Some(block) => {
56
99900
                if block.is_full() {
57
300
                    let mut new_block = Box::new(Block::new());
58
300
                    std::mem::swap(block, &mut new_block);
59
300
                    block.next = Some(new_block);
60
99600
                }
61

            
62
99900
                block
63
            }
64
            None => {
65
100
                let block = Box::new(Block::new());
66
100
                self.head_block = Some(block);
67
100
                self.head_block.as_mut().expect("Is initialized in the previous line")
68
            }
69
        };
70

            
71
100000
        let length = block.length;
72
100000
        block.length += 1;
73
        unsafe {
74
            // Safety: We take a pointer (value does not have to be initialized) to a ManuallDrop<T>, which has the same layout as T.
75
100000
            Ok(NonNull::new_unchecked(
76
100000
                &mut block.data[length].data as *mut ManuallyDrop<T> as *mut T,
77
100000
            ))
78
        }
79
100000
    }
80

            
81
    /// Deallocate the given pointer.
82
    pub fn deallocate_object(&mut self, ptr: NonNull<T>) {
83
        if let Some(free) = self.free {
84
            unsafe { (ptr.cast::<Entry<_>>()).as_mut().next = free }
85
        }
86

            
87
        self.free = Some(ptr.cast());
88
    }
89

            
90
    /// Returns an iterator over the free list entries.
91
    fn iter_free(&self) -> FreeListIterator<T> {
92
        FreeListIterator { current: self.free }
93
    }
94
}
95

            
96
/// A type that can implement `Allocator` using the underlying `BlockAllocator`.
97
pub struct AllocBlock<T, const N: usize> {
98
    block_allocator: RefCell<BlockAllocator<T, N>>,
99
}
100

            
101
impl<T, const N: usize> Default for AllocBlock<T, N> {
102
    fn default() -> Self {
103
        Self::new()
104
    }
105
}
106

            
107
impl<T, const N: usize> AllocBlock<T, N> {
108
    /// Creates a new `AllocBlock`.
109
533
    pub fn new() -> Self {
110
533
        Self {
111
533
            block_allocator: RefCell::new(BlockAllocator::new()),
112
533
        }
113
533
    }
114
}
115

            
116
unsafe impl<T, const N: usize> Allocator for AllocBlock<T, N> {
117
    fn allocate(&self, layout: std::alloc::Layout) -> Result<NonNull<[u8]>, AllocError> {
118
        debug_assert_eq!(
119
            layout,
120
            Layout::new::<T>(),
121
            "The requested layout should match the type T"
122
        );
123

            
124
        let ptr = self.block_allocator.borrow_mut().allocate_object()?;
125

            
126
        // Convert NonNull<T> to NonNull<[u8]> with the correct size
127
        let byte_ptr = ptr.cast::<u8>();
128
        let slice_ptr = NonNull::slice_from_raw_parts(byte_ptr, std::mem::size_of::<T>());
129

            
130
        Ok(slice_ptr)
131
    }
132

            
133
    unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
134
        debug_assert_eq!(
135
            layout,
136
            Layout::new::<T>(),
137
            "The requested layout should match the type T"
138
        );
139
        self.block_allocator.borrow_mut().deallocate_object(ptr.cast::<T>());
140
    }
141
}
142

            
143
/// A single slot of a block: it holds either a `T` handed out to a caller or, once freed, a
/// link to the next free slot.
///
/// `#[repr(C)]` guarantees that both fields start at offset 0. The allocator relies on this when
/// it converts a `NonNull<T>` returned by a caller back into a `NonNull<Entry<T>>`. The default
/// union representation gives no such guarantee.
#[repr(C)]
union Entry<T> {
    /// Stores the actual element.
    data: ManuallyDrop<T>,

    /// If the element is free, this points to the next entry in the freelist.
    next: NonNull<Entry<T>>,
}
150

            
151
/// We maintain a list of a blocks that store N elements each.
152
struct Block<T, const N: usize> {
153
    data: [Entry<T>; N],
154

            
155
    /// Keeps track of the number of elements in the block that are used.
156
    length: usize,
157

            
158
    /// Pointer to the next block.
159
    next: Option<Box<Block<T, N>>>,
160
}
161

            
162
impl<T, const N: usize> Block<T, N> {
163
400
    fn new() -> Self {
164
        Self {
165
400
            data: array::from_fn(|_i| Entry {
166
102400
                next: NonNull::dangling(),
167
102400
            }),
168
            length: 0,
169
400
            next: None,
170
        }
171
400
    }
172

            
173
    /// Returns true iff this block is full.
174
99900
    fn is_full(&self) -> bool {
175
99900
        self.length == N
176
99900
    }
177
}
178

            
179
/// Iterator over the free list entries in a BlockAllocator.
180
struct FreeListIterator<T> {
181
    current: Option<NonNull<Entry<T>>>,
182
}
183

            
184
impl<T> Iterator for FreeListIterator<T> {
185
    type Item = NonNull<Entry<T>>;
186

            
187
    fn next(&mut self) -> Option<Self::Item> {
188
        if let Some(current) = self.current {
189
            // Safety: We assume the free list is properly constructed and current points to a valid Entry
190
            unsafe {
191
                self.current = Some(current.as_ref().next);
192
            }
193
            Some(current)
194
        } else {
195
            None
196
        }
197
    }
198
}
199

            
200
impl<T, const N: usize> fmt::Debug for BlockAllocator<T, N> {
201
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
202
        write!(f, "freelist = {:?}", self.iter_free().format(", "))
203
    }
204
}
205

            
206
#[cfg(test)]
mod tests {
    use super::*;

    use std::collections::HashSet;

    use rand::Rng;

    use merc_utilities::random_test;

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_block_allocator() {
        random_test(100, |rng| {
            let mut allocator: BlockAllocator<u64, 256> = BlockAllocator::new();

            // Allocate objects and remember the value written into each of them.
            let mut allocated = Vec::new();
            for _ in 0..1000 {
                let ptr = allocator.allocate_object().unwrap();
                let value: u64 = rng.random();
                unsafe {
                    ptr.as_ptr().write(value);
                }
                allocated.push((ptr, value));
            }

            // Remove various elements and check whether all the remaining elements are valid.
            let mut remaining = Vec::new();
            let mut freed = HashSet::new();
            for (ptr, value) in allocated {
                if rng.random::<bool>() {
                    allocator.deallocate_object(ptr);
                    freed.insert(ptr);
                } else {
                    remaining.push((ptr, value));
                }
            }

            for &(ptr, value) in &remaining {
                assert_eq!(
                    unsafe { ptr.as_ptr().read() },
                    value,
                    "A live object was overwritten"
                );
            }

            // Every freed slot must be handed out again before any fresh slot is used.
            let mut reused = Vec::new();
            for _ in 0..freed.len() {
                let ptr = allocator.allocate_object().unwrap();
                assert!(
                    freed.remove(&ptr),
                    "Expected a previously freed slot to be reused"
                );
                unsafe { ptr.as_ptr().write(0) };
                reused.push(ptr);
            }

            // Once the free list is exhausted, allocation must fall back to a fresh block slot.
            let extra = allocator.allocate_object().unwrap();
            assert!(remaining.iter().all(|&(ptr, _)| ptr != extra));
            assert!(!reused.contains(&extra));
            unsafe { extra.as_ptr().write(0) };

            for &(ptr, value) in &remaining {
                assert_eq!(
                    unsafe { ptr.as_ptr().read() },
                    value,
                    "A live object was overwritten"
                );
            }
        })
    }
}