1
use std::alloc::GlobalAlloc;
2
use std::alloc::Layout;
3
use std::alloc::System;
4
use std::fmt;
5
use std::ptr::NonNull;
6
use std::sync::atomic::AtomicUsize;
7
use std::sync::atomic::Ordering;
8

            
9
use allocator_api2::alloc::AllocError;
10
use allocator_api2::alloc::Allocator;
11

            
12
use merc_io::BytesFormatter;
13

            
14
/// An allocator that can be used to count performance metrics
/// on the allocations performed.
///
/// Every allocation is forwarded to [`std::alloc::System`]; the bookkeeping is done
/// with relaxed atomic counters so a single instance can be shared between threads.
#[derive(Debug)]
pub struct AllocCounter {
    /// Number of allocations currently live (tracked since the last `reset`).
    number_of_allocations: AtomicUsize,
    /// Requested bytes currently live (tracked since the last `reset`).
    size_of_allocations: AtomicUsize,

    /// Number of allocations performed over the lifetime of the counter.
    total_number_of_allocations: AtomicUsize,
    /// Requested bytes allocated over the lifetime of the counter.
    total_size_of_allocations: AtomicUsize,

    /// Highest value observed for `number_of_allocations`.
    max_number_of_allocations: AtomicUsize,
    /// Highest value observed for `size_of_allocations`.
    max_size_of_allocations: AtomicUsize,
}
26

            
27
/// A point-in-time snapshot of the counters of an [`AllocCounter`].
///
/// Obtained through [`AllocCounter::get_metrics`]; its `Display` implementation
/// renders a human-readable summary.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct AllocMetrics {
    /// Number of allocations currently live (tracked since the last reset).
    number_of_allocations: usize,
    /// Requested bytes currently live (tracked since the last reset).
    size_of_allocations: usize,

    /// Number of allocations performed over the lifetime of the counter.
    total_number_of_allocations: usize,
    /// Requested bytes allocated over the lifetime of the counter.
    total_size_of_allocations: usize,

    /// Highest number of live allocations observed.
    max_number_of_allocations: usize,
    /// Highest number of live requested bytes observed.
    max_size_of_allocations: usize,
}
37

            
38
impl fmt::Display for AllocMetrics {
39
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40
        writeln!(
41
            f,
42
            "Current allocations: {} (size: {} bytes)",
43
            self.number_of_allocations,
44
            BytesFormatter(self.size_of_allocations)
45
        )?;
46
        writeln!(
47
            f,
48
            "Total allocations: {} (size: {} bytes)",
49
            self.total_number_of_allocations,
50
            BytesFormatter(self.total_size_of_allocations)
51
        )?;
52
        write!(
53
            f,
54
            "Peak allocations: {} (size: {} bytes)",
55
            self.max_number_of_allocations,
56
            BytesFormatter(self.max_size_of_allocations)
57
        )
58
    }
59
}
60

            
61
impl Default for AllocCounter {
62
    /// Creates a new allocation counter with all metrics initialized to zero
63
    fn default() -> Self {
64
        Self::new()
65
    }
66
}
67

            
68
impl AllocCounter {
69
    /// Creates a new allocation counter with all metrics initialized to zero
70
3
    pub const fn new() -> Self {
71
3
        Self {
72
3
            number_of_allocations: AtomicUsize::new(0),
73
3
            size_of_allocations: AtomicUsize::new(0),
74
3
            total_number_of_allocations: AtomicUsize::new(0),
75
3
            total_size_of_allocations: AtomicUsize::new(0),
76
3
            max_number_of_allocations: AtomicUsize::new(0),
77
3
            max_size_of_allocations: AtomicUsize::new(0),
78
3
        }
79
3
    }
80

            
81
    /// Returns the performance metrics of the allocator
82
4
    pub fn get_metrics(&self) -> AllocMetrics {
83
4
        AllocMetrics {
84
4
            number_of_allocations: self.number_of_allocations.load(Ordering::Relaxed),
85
4
            size_of_allocations: self.size_of_allocations.load(Ordering::Relaxed),
86
4

            
87
4
            total_number_of_allocations: self.total_number_of_allocations.load(Ordering::Relaxed),
88
4
            total_size_of_allocations: self.total_size_of_allocations.load(Ordering::Relaxed),
89
4

            
90
4
            max_number_of_allocations: self.max_number_of_allocations.load(Ordering::Relaxed),
91
4
            max_size_of_allocations: self.max_size_of_allocations.load(Ordering::Relaxed),
92
4
        }
93
4
    }
94

            
95
    /// Resets all current allocation metrics (but preserves total and max metrics)
96
1
    pub fn reset(&self) {
97
1
        self.number_of_allocations.store(0, Ordering::Relaxed);
98
1
        self.size_of_allocations.store(0, Ordering::Relaxed);
99
1
    }
100

            
101
4001
    fn alloc(&self, layout: Layout) -> *mut u8 {
102
4001
        let ret = unsafe { System.alloc(layout) };
103

            
104
4001
        if !ret.is_null() {
105
            // Update allocation counters atomically
106
4001
            self.number_of_allocations.fetch_add(1, Ordering::Relaxed);
107
4001
            self.size_of_allocations.fetch_add(layout.size(), Ordering::Relaxed);
108

            
109
4001
            self.total_number_of_allocations.fetch_add(1, Ordering::Relaxed);
110
4001
            self.total_size_of_allocations
111
4001
                .fetch_add(layout.size(), Ordering::Relaxed);
112

            
113
            // Update max counters using compare-and-swap loops
114
4001
            let current_allocs = self.number_of_allocations.load(Ordering::Relaxed);
115
4001
            let mut max_allocs = self.max_number_of_allocations.load(Ordering::Relaxed);
116
4001
            while current_allocs > max_allocs {
117
2
                match self.max_number_of_allocations.compare_exchange_weak(
118
2
                    max_allocs,
119
2
                    current_allocs,
120
2
                    Ordering::Relaxed,
121
2
                    Ordering::Relaxed,
122
2
                ) {
123
2
                    Ok(_) => break,
124
                    Err(val) => max_allocs = val,
125
                }
126
            }
127

            
128
4001
            let current_size = self.size_of_allocations.load(Ordering::Relaxed);
129
4001
            let mut max_size = self.max_size_of_allocations.load(Ordering::Relaxed);
130
4001
            while current_size > max_size {
131
2
                match self.max_size_of_allocations.compare_exchange_weak(
132
2
                    max_size,
133
2
                    current_size,
134
2
                    Ordering::Relaxed,
135
2
                    Ordering::Relaxed,
136
2
                ) {
137
2
                    Ok(_) => break,
138
                    Err(val) => max_size = val,
139
                }
140
            }
141
        }
142

            
143
4001
        ret
144
4001
    }
145

            
146
4001
    fn dealloc(&self, ptr: *mut u8, layout: Layout) {
147
4001
        unsafe {
148
4001
            System.dealloc(ptr, layout);
149
4001
        }
150

            
151
        // Update allocation counters atomically
152
4001
        self.number_of_allocations.fetch_sub(1, Ordering::Relaxed);
153
4001
        self.size_of_allocations.fetch_sub(layout.size(), Ordering::Relaxed);
154
4001
    }
155
}
156

            
157
unsafe impl GlobalAlloc for AllocCounter {
158
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
159
        self.alloc(layout)
160
    }
161

            
162
    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
163
        self.dealloc(ptr, layout)
164
    }
165
}
166

            
167
unsafe impl Allocator for AllocCounter {
168
    fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
169
        let ptr = self.alloc(layout);
170

            
171
        if ptr.is_null() {
172
            return Err(AllocError);
173
        }
174

            
175
        let slice_ptr = std::ptr::slice_from_raw_parts_mut(ptr, layout.size());
176
        Ok(NonNull::new(slice_ptr).expect("The resulting ptr will never be null"))
177
    }
178

            
179
    unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
180
        self.dealloc(ptr.as_ptr(), layout)
181
    }
182
}
183

            
184
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn test_basic_allocation_tracking() {
        let counter = AllocCounter::new();
        let metrics = counter.get_metrics();

        // Initially all metrics should be zero
        assert_eq!(metrics.number_of_allocations, 0);
        assert_eq!(metrics.size_of_allocations, 0);
        assert_eq!(metrics.total_number_of_allocations, 0);
        assert_eq!(metrics.total_size_of_allocations, 0);
        assert_eq!(metrics.max_number_of_allocations, 0);
        assert_eq!(metrics.max_size_of_allocations, 0);

        // A single allocation is reflected in every metric.
        let layout = Layout::from_size_align(48, 8).unwrap();
        let ptr = counter.alloc(layout);
        assert!(!ptr.is_null());

        let metrics = counter.get_metrics();
        assert_eq!(metrics.number_of_allocations, 1);
        assert_eq!(metrics.size_of_allocations, 48);
        assert_eq!(metrics.total_number_of_allocations, 1);
        assert_eq!(metrics.total_size_of_allocations, 48);
        assert_eq!(metrics.max_number_of_allocations, 1);
        assert_eq!(metrics.max_size_of_allocations, 48);

        // Freeing it clears the current metrics but keeps totals and peaks.
        counter.dealloc(ptr, layout);
        let metrics = counter.get_metrics();
        assert_eq!(metrics.number_of_allocations, 0);
        assert_eq!(metrics.size_of_allocations, 0);
        assert_eq!(metrics.total_number_of_allocations, 1);
        assert_eq!(metrics.total_size_of_allocations, 48);
        assert_eq!(metrics.max_number_of_allocations, 1);
        assert_eq!(metrics.max_size_of_allocations, 48);
    }

    #[test]
    fn test_peak_tracking() {
        let counter = AllocCounter::new();
        let layout = Layout::from_size_align(16, 8).unwrap();

        // Three simultaneously live allocations form the peak.
        let ptrs: Vec<*mut u8> = (0..3).map(|_| counter.alloc(layout)).collect();
        assert!(ptrs.iter().all(|ptr| !ptr.is_null()));
        for &ptr in &ptrs {
            counter.dealloc(ptr, layout);
        }

        // A later, smaller burst must not lower the recorded peak.
        let ptr = counter.alloc(layout);
        assert!(!ptr.is_null());
        counter.dealloc(ptr, layout);

        let metrics = counter.get_metrics();
        assert_eq!(metrics.number_of_allocations, 0);
        assert_eq!(metrics.size_of_allocations, 0);
        assert_eq!(metrics.total_number_of_allocations, 4);
        assert_eq!(metrics.total_size_of_allocations, 64);
        assert_eq!(metrics.max_number_of_allocations, 3);
        assert_eq!(metrics.max_size_of_allocations, 48);
    }

    #[test]
    fn test_thread_safety() {
        let counter = Arc::new(AllocCounter::new());
        let num_threads = 4;
        let allocations_per_thread = 1000;

        let handles: Vec<_> = (0..num_threads)
            .map(|_| {
                let counter = Arc::clone(&counter);
                thread::spawn(move || {
                    for _ in 0..allocations_per_thread {
                        let layout = Layout::from_size_align(64, 8).unwrap();
                        let ptr = counter.alloc(layout);
                        if !ptr.is_null() {
                            counter.dealloc(ptr, layout);
                        }
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        let metrics = counter.get_metrics();

        // After all threads complete, current allocations should be 0
        assert_eq!(metrics.number_of_allocations, 0);
        assert_eq!(metrics.size_of_allocations, 0);

        // Total allocations should equal num_threads * allocations_per_thread
        assert_eq!(
            metrics.total_number_of_allocations,
            num_threads * allocations_per_thread
        );
        assert_eq!(
            metrics.total_size_of_allocations,
            num_threads * allocations_per_thread * 64
        );

        // Each thread holds at most one allocation at a time, so the peak is
        // bounded by the number of threads.
        assert!(metrics.max_number_of_allocations >= 1);
        assert!(metrics.max_number_of_allocations <= num_threads);
        assert!(metrics.max_size_of_allocations >= 64);
        assert!(metrics.max_size_of_allocations <= num_threads * 64);
    }

    #[test]
    fn test_reset_functionality() {
        let counter = AllocCounter::new();

        // Simulate some allocations
        let layout = Layout::from_size_align(32, 8).unwrap();
        let ptr = counter.alloc(layout);
        assert!(!ptr.is_null());

        let metrics_before = counter.get_metrics();
        assert!(metrics_before.number_of_allocations > 0);

        counter.reset();
        let metrics_after = counter.get_metrics();

        // Current metrics should be reset
        assert_eq!(metrics_after.number_of_allocations, 0);
        assert_eq!(metrics_after.size_of_allocations, 0);

        // Total and max metrics should be preserved
        assert_eq!(
            metrics_after.total_number_of_allocations,
            metrics_before.total_number_of_allocations
        );
        assert_eq!(
            metrics_after.total_size_of_allocations,
            metrics_before.total_size_of_allocations
        );
        assert_eq!(
            metrics_after.max_number_of_allocations,
            metrics_before.max_number_of_allocations
        );
        assert_eq!(
            metrics_after.max_size_of_allocations,
            metrics_before.max_size_of_allocations
        );

        // Clean up
        counter.dealloc(ptr, layout);
    }
}