//! Provides a generational index implementation that checks generations
//! in debug builds while having zero runtime cost in release builds.

            
4
use std::fmt;
5
use std::hash::Hash;
6
use std::hash::Hasher;
7
use std::ops::Deref;
8

            
9
/// A generational index that stores both an index and a generation counter.
///
/// The generation is only tracked in debug builds to avoid overhead in release:
/// with `debug_assertions` disabled the struct holds nothing but `index`.
///
/// This allows detecting use-after-free scenarios in debug mode while
/// maintaining zero overhead in release mode.
///
/// Trait bounds live on the `impl` blocks that need them rather than on the
/// struct definition, so merely naming the type imposes no extra requirements.
#[repr(C)]
#[derive(Copy, Clone)]
pub struct GenerationalIndex<I = usize> {
    /// The raw index value.
    index: I,

    /// Generation counter, only available in debug builds.
    #[cfg(debug_assertions)]
    generation: usize,
}
24

            
25
impl Default for GenerationalIndex<usize> {
26
352305182
    fn default() -> Self {
27
352305182
        GenerationalIndex {
28
352305182
            index: 0,
29
352305182
            #[cfg(debug_assertions)]
30
352305182
            generation: usize::MAX,
31
352305182
        }
32
352305182
    }
33
}
34

            
35
impl<I: Copy + Into<usize>> Deref for GenerationalIndex<I> {
36
    type Target = I;
37

            
38
    /// Deref implementation to access the underlying index value.
39
1103445690
    fn deref(&self) -> &Self::Target {
40
1103445690
        &self.index
41
1103445690
    }
42
}
43

            
44
impl<I: Copy + Into<usize>> GenerationalIndex<I> {
45
    /// Creates a new generational index with the specified index.
46
    #[cfg(debug_assertions)]
47
735048567
    fn new(index: I, generation: usize) -> Self {
48
735048567
        Self { index, generation }
49
735048567
    }
50

            
51
    /// Creates a new generational index with the specified index and generation.
52
    #[cfg(not(debug_assertions))]
53
    fn new(index: I) -> Self {
54
        Self { index }
55
    }
56
}
57

            
58
/// Tracks the current generation of every slot handed out as a generational index.
///
/// A slot's generation advances each time a fresh index is created for it, which
/// lets stale indices (use-after-free and similar issues) be detected. The
/// bookkeeping only exists in debug builds.
#[derive(Clone, Debug, Default)]
pub struct GenerationCounter {
    /// Generation of each slot, indexed by slot number; debug builds only.
    #[cfg(debug_assertions)]
    current_generation: Vec<usize>,
}
66

            
67
impl GenerationCounter {
68
    /// Creates a new generation counter.
69
208108
    pub fn new() -> Self {
70
        #[cfg(debug_assertions)]
71
        {
72
208108
            Self {
73
208108
                current_generation: Vec::new(),
74
208108
            }
75
        }
76

            
77
        #[cfg(not(debug_assertions))]
78
        Self {}
79
208108
    }
80
}
81

            
82
impl GenerationCounter {
83
    /// Creates a new generational index with the given index and the next generation.
84
684713033
    pub fn create_index<I>(&mut self, index: I) -> GenerationalIndex<I>
85
684713033
    where
86
684713033
        I: Copy + Into<usize>,
87
    {
88
        #[cfg(debug_assertions)]
89
        {
90
684713033
            let generation = if self.current_generation.len() <= index.into() {
91
72126043
                self.current_generation.resize(index.into() + 1, 0);
92
72126043
                0
93
            } else {
94
612586990
                let generation = &mut self.current_generation[index.into()];
95
612586990
                *generation = generation.wrapping_add(1);
96
612586990
                *generation
97
            };
98

            
99
684713033
            GenerationalIndex::new(index, generation)
100
        }
101

            
102
        #[cfg(not(debug_assertions))]
103
        {
104
            GenerationalIndex::new(index)
105
        }
106
684713033
    }
107

            
108
    /// Returns a generational index with the given index and the current generation.
109
50335534
    pub fn recall_index<I>(&self, index: I) -> GenerationalIndex<I>
110
50335534
    where
111
50335534
        I: Copy + Into<usize>,
112
    {
113
        #[cfg(debug_assertions)]
114
        {
115
50335534
            GenerationalIndex::new(index, self.current_generation[index.into()])
116
        }
117
        #[cfg(not(debug_assertions))]
118
        {
119
            GenerationalIndex::new(index)
120
        }
121
50335534
    }
122

            
123
    /// Returns the underlying index, checks if the generation is correct.
124
1440265820
    pub fn get_index<I>(&self, index: GenerationalIndex<I>) -> I
125
1440265820
    where
126
1440265820
        I: Copy + Into<usize> + fmt::Debug,
127
    {
128
        #[cfg(debug_assertions)]
129
        {
130
1440265820
            if self.current_generation[index.index.into()] != index.generation {
131
                panic!("Attempting to access an invalid index: {index:?}");
132
1440265820
            }
133
        }
134

            
135
1440265820
        index.index
136
1440265820
    }
137
}
138

            
139
// Standard trait implementations for GenerationalIndex
140

            
141
impl<I> PartialEq for GenerationalIndex<I>
142
where
143
    I: Copy + Into<usize> + Eq,
144
{
145
387686218
    fn eq(&self, other: &Self) -> bool {
146
        // TODO: Should we have a default index?
147
        #[cfg(debug_assertions)]
148
        {
149
387686218
            if self.generation == usize::MAX || other.generation == usize::MAX {
150
9613789
                return false;
151
378072429
            }
152

            
153
378072429
            debug_assert_eq!(
154
                self.generation, other.generation,
155
                "Comparing indices of different generations"
156
            );
157
        }
158

            
159
378072428
        self.index == other.index
160
387686217
    }
161
}
162

            
163
impl<I> Eq for GenerationalIndex<I> where I: Copy + Into<usize> + Eq {}
164

            
165
impl<I> PartialOrd for GenerationalIndex<I>
166
where
167
    I: Copy + Into<usize> + PartialOrd + Eq,
168
{
169
2378487
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
170
        #[cfg(debug_assertions)]
171
2378487
        debug_assert_eq!(
172
            self.generation, other.generation,
173
            "Comparing indices of different generations"
174
        );
175

            
176
2378487
        self.index.partial_cmp(&other.index)
177
2378487
    }
178
}
179

            
180
impl<I> Ord for GenerationalIndex<I>
181
where
182
    I: Copy + Into<usize> + Eq + Ord,
183
{
184
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
185
        #[cfg(debug_assertions)]
186
        debug_assert_eq!(
187
            self.generation, other.generation,
188
            "Comparing indices of different generations"
189
        );
190
        self.index.cmp(&other.index)
191
    }
192
}
193

            
194
impl<I> Hash for GenerationalIndex<I>
195
where
196
    I: Copy + Into<usize> + Hash,
197
{
198
37723070
    fn hash<H: Hasher>(&self, state: &mut H) {
199
37723070
        self.index.hash(state);
200
37723070
    }
201
}
202

            
203
impl<I> fmt::Debug for GenerationalIndex<I>
204
where
205
    I: Copy + Into<usize> + fmt::Debug,
206
{
207
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208
        #[cfg(debug_assertions)]
209
        {
210
            write!(
211
                f,
212
                "GenerationalIndex(index: {:?}, generation: {})",
213
                self.index, self.generation
214
            )
215
        }
216
        #[cfg(not(debug_assertions))]
217
        {
218
            write!(f, "GenerationalIndex(index: {:?})", self.index)
219
        }
220
    }
221
}
222

            
223
impl fmt::Display for GenerationalIndex<usize> {
224
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
225
        write!(f, "{}", self.index)
226
    }
227
}
228

            
229
#[cfg(test)]
mod tests {
    #[cfg(debug_assertions)]
    use super::*;

    /// Indices of the same generation compare like their raw indices, and a
    /// recalled index matches the most recently created one for that slot.
    #[test]
    #[cfg(debug_assertions)]
    fn test_same_generation_comparisons() {
        let mut counter = GenerationCounter::new();
        let idx1 = counter.create_index(42usize); // slot 42, generation 0
        let idx2 = counter.create_index(42usize); // slot 42, generation 1
        let idx4 = counter.create_index(43usize); // slot 43, generation 0
        let idx3 = counter.recall_index(42usize); // slot 42, generation 1

        assert_ne!(idx1, idx4);
        assert_eq!(idx2, idx3);
        assert_eq!(counter.get_index(idx3), 42);
    }

    /// A default index carries the sentinel generation and never compares equal,
    /// not even to itself.
    #[test]
    #[cfg(debug_assertions)]
    fn test_default_index_never_equal() {
        let default_index: GenerationalIndex = Default::default();
        assert_ne!(default_index, default_index);
        assert_eq!(*default_index, 0);
    }

    /// Accessing a stale index through the counter panics.
    #[test]
    #[should_panic(expected = "Attempting to access an invalid index")]
    #[cfg(debug_assertions)]
    fn test_get_index_rejects_stale_index() {
        let mut counter = GenerationCounter::new();
        let stale = counter.create_index(7usize);
        let _fresh = counter.create_index(7usize);
        counter.get_index(stale);
    }

    /// Comparing indices from different generations of the same slot panics.
    ///
    /// `expected` pins the panic to the generation assertion, so a failure in the
    /// earlier sanity checks can no longer make this test pass by accident.
    #[test]
    #[should_panic(expected = "Comparing indices of different generations")]
    #[cfg(debug_assertions)]
    fn test_generational_index_equality() {
        let mut counter = GenerationCounter::new();
        let idx1 = counter.create_index(42usize);
        let idx2 = counter.create_index(42usize);
        let idx4 = counter.create_index(43usize);

        let idx3 = counter.recall_index(42usize);

        assert_ne!(idx1, idx4);
        assert_eq!(idx2, idx3);

        // This panics since idx1 and idx2 are from different generations
        assert_eq!(idx1, idx2);
    }
}