1
use std::alloc::handle_alloc_error;
2
use std::collections::hash_map::RandomState;
3
use std::fmt;
4
use std::hash::BuildHasher;
5
use std::hash::Hash;
6
use std::hash::Hasher;
7
use std::ops::Deref;
8
use std::ptr::NonNull;
9
use std::ptr::addr_eq;
10
#[cfg(debug_assertions)]
11
use std::sync::Arc;
12

            
13
use allocator_api2::alloc::Allocator;
14
use allocator_api2::alloc::Global;
15
use allocator_api2::alloc::Layout;
16
use dashmap::DashSet;
17
use equivalent::Equivalent;
18

            
19
use crate::AllocatorDst;
20
use crate::SliceDst;
21

            
22
/// A safe wrapper around a raw pointer that allows immutable dereferencing. This remains valid as long as the `StablePointerSet` remains
/// valid, which is not managed by the borrow checker.
///
/// Comparisons are based on the pointer's address, not the value it points to.
#[repr(C)]
pub struct StablePointer<T: ?Sized> {
    /// The raw pointer to the element.
    /// This is a NonNull pointer, which means it is guaranteed to be non-null.
    ptr: NonNull<T>,

    /// Keep track of reference counts in debug mode.
    #[cfg(debug_assertions)]
    reference_counter: Arc<()>,
}

/// Cloning copies the pointer (and, in debug builds, bumps the shared reference counter).
///
/// Implemented by hand rather than derived: `#[derive(Clone)]` would add a `T: Clone` bound, which
/// made handles to unsized elements (e.g. `StablePointer<str>` or slice DSTs) impossible to clone
/// even though cloning never touches the pointee.
impl<T: ?Sized> Clone for StablePointer<T> {
    fn clone(&self) -> Self {
        Self {
            ptr: self.ptr,
            #[cfg(debug_assertions)]
            reference_counter: Arc::clone(&self.reference_counter),
        }
    }
}
37

            
38
/// With debug assertions disabled, `StablePointer` is nothing but a `NonNull`, so wrapping it in an
/// `Option` must reuse the null niche and stay exactly one machine word. Checked at compile time.
#[cfg(not(debug_assertions))]
const _: () = {
    let wrapped_size = std::mem::size_of::<Option<StablePointer<usize>>>();
    let word_size = std::mem::size_of::<usize>();
    assert!(wrapped_size == word_size);
};
41

            
42
impl<T: ?Sized> StablePointer<T> {
43
    /// Returns true if this is the last reference to the pointer.
44
4
    fn is_last_reference(&self) -> bool {
45
        #[cfg(debug_assertions)]
46
        {
47
            // There is a reference in the table, and the one of `self.ptr`.
48
4
            Arc::strong_count(&self.reference_counter) == 2
49
        }
50
        #[cfg(not(debug_assertions))]
51
        {
52
            true
53
        }
54
4
    }
55

            
56
    /// Creates a new StablePointer from a raw pointer.
57
    ///
58
    /// # Safety
59
    ///
60
    /// The caller must ensure that the pointer is valid and points to a valid T that outlives the StablePointer.
61
    pub unsafe fn from_ptr(ptr: NonNull<T>) -> Self {
62
        Self {
63
            ptr,
64
            #[cfg(debug_assertions)]
65
            reference_counter: Arc::new(()),
66
        }
67
    }
68

            
69
    /// Returns public access to the underlying pointer.
70
    pub fn ptr(&self) -> NonNull<T> {
71
        self.ptr
72
    }
73
}
74

            
75
impl<T: ?Sized> PartialEq for StablePointer<T> {
76
12413595136
    fn eq(&self, other: &Self) -> bool {
77
        // SAFETY: This is safe because we are comparing pointers, which is a valid operation.
78
12413595136
        addr_eq(self.ptr.as_ptr(), other.ptr.as_ptr())
79
12413595136
    }
80
}
81

            
82
impl<T: ?Sized> Eq for StablePointer<T> {}
83

            
84
impl<T: ?Sized> Ord for StablePointer<T> {
85
159124390
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
86
        // SAFETY: This is safe because we are comparing pointers, which is a valid operation.
87
159124390
        self.ptr.as_ptr().cast::<()>().cmp(&(other.ptr.as_ptr().cast::<()>()))
88
159124390
    }
89
}
90

            
91
impl<T: ?Sized> PartialOrd for StablePointer<T> {
92
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
93
        // SAFETY: This is safe because we are comparing pointers, which is a valid operation.
94
        Some(self.cmp(other))
95
    }
96
}
97

            
98
impl<T: ?Sized> Hash for StablePointer<T> {
99
210322322
    fn hash<H: Hasher>(&self, state: &mut H) {
100
        // SAFETY: This is safe because we are hashing pointers, which is a valid operation.
101
210322322
        self.ptr.hash(state);
102
210322322
    }
103
}
104

            
105
unsafe impl<T: ?Sized + Send> Send for StablePointer<T> {}
106
unsafe impl<T: ?Sized + Sync> Sync for StablePointer<T> {}
107

            
108
impl<T: ?Sized> StablePointer<T> {
109
    /// Returns a copy of the StablePointer.
110
    ///
111
    /// # Safety
112
    /// The caller must ensure the pointer points to a valid T that outlives the returned StablePointer.
113
9909354017
    pub fn copy(&self) -> Self {
114
9909354017
        Self {
115
9909354017
            ptr: self.ptr,
116
9909354017
            #[cfg(debug_assertions)]
117
9909354017
            reference_counter: self.reference_counter.clone(),
118
9909354017
        }
119
9909354017
    }
120

            
121
    /// Creates a new StablePointer from a boxed element.
122
57131715
    fn from_entry(entry: &Entry<T>) -> Self {
123
57131715
        Self {
124
57131715
            ptr: entry.ptr,
125
57131715
            #[cfg(debug_assertions)]
126
57131715
            reference_counter: entry.reference_counter.clone(),
127
57131715
        }
128
57131715
    }
129
}
130

            
131
impl<T: ?Sized> Deref for StablePointer<T> {
132
    type Target = T;
133

            
134
8976506575
    fn deref(&self) -> &Self::Target {
135
        // The caller must ensure the pointer points to a valid T that outlives this StablePointer.
136
8976506575
        unsafe { self.ptr.as_ref() }
137
8976506575
    }
138
}
139

            
140
impl<T: fmt::Debug + ?Sized> fmt::Debug for StablePointer<T> {
141
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
142
        f.debug_tuple("StablePointer").field(&self.ptr).finish()
143
    }
144
}
145

            
146
/// A set that provides stable pointers to its elements.
147
///
148
/// Similar to `IndexedSet` but uses pointers instead of indices for direct access to elements.
149
/// Elements are stored in stable memory locations using a custom allocator, with the hash set maintaining references.
150
///
151
/// The set can use a custom hasher type for potentially better performance based on workload characteristics.
152
/// Uses an allocator for memory management, defaulting to the global allocator.
153
pub struct StablePointerSet<T: ?Sized, S = RandomState, A = Global>
154
where
155
    T: Hash + Eq + SliceDst,
156
    S: BuildHasher + Clone,
157
    A: Allocator + AllocatorDst,
158
{
159
    index: DashSet<Entry<T>, S>,
160

            
161
    allocator: A,
162
}
163

            
164
impl<T: ?Sized> Default for StablePointerSet<T, RandomState, Global>
165
where
166
    T: Hash + Eq + SliceDst,
167
{
168
    fn default() -> Self {
169
        Self::new()
170
    }
171
}
172

            
173
impl<T: ?Sized> StablePointerSet<T, RandomState, Global>
174
where
175
    T: Hash + Eq + SliceDst,
176
{
177
    /// Creates an empty StablePointerSet with the default hasher and global allocator.
178
541
    pub fn new() -> Self {
179
541
        Self {
180
541
            index: DashSet::default(),
181
541
            allocator: Global,
182
541
        }
183
541
    }
184

            
185
    /// Creates an empty StablePointerSet with the specified capacity, default hasher, and global allocator.
186
    pub fn with_capacity(capacity: usize) -> Self {
187
        Self {
188
            index: DashSet::with_capacity_and_hasher(capacity, RandomState::new()),
189
            allocator: Global,
190
        }
191
    }
192
}
193

            
194
impl<T: ?Sized, S> StablePointerSet<T, S, Global>
195
where
196
    T: Hash + Eq + SliceDst,
197
    S: BuildHasher + Clone,
198
{
199
    /// Creates an empty StablePointerSet with the specified hasher and global allocator.
200
533
    pub fn with_hasher(hasher: S) -> Self {
201
533
        Self {
202
533
            index: DashSet::with_hasher(hasher),
203
533
            allocator: Global,
204
533
        }
205
533
    }
206

            
207
    /// Creates an empty StablePointerSet with the specified capacity, hasher, and global allocator.
208
    pub fn with_capacity_and_hasher(capacity: usize, hasher: S) -> Self {
209
        Self {
210
            index: DashSet::with_capacity_and_hasher(capacity, hasher),
211
            allocator: Global,
212
        }
213
    }
214
}
215

            
216
impl<T: ?Sized, S, A> StablePointerSet<T, S, A>
217
where
218
    T: Hash + Eq + SliceDst,
219
    S: BuildHasher + Clone,
220
    A: Allocator + AllocatorDst,
221
{
222
    /// Creates an empty StablePointerSet with the specified allocator and default hasher.
223
1
    pub fn new_in(allocator: A) -> Self
224
1
    where
225
1
        S: Default,
226
    {
227
1
        Self {
228
1
            index: DashSet::with_hasher(S::default()),
229
1
            allocator,
230
1
        }
231
1
    }
232

            
233
    /// Creates an empty StablePointerSet with the specified capacity, allocator, and default hasher.
234
533
    pub fn with_capacity_in(capacity: usize, allocator: A) -> Self
235
533
    where
236
533
        S: Default,
237
    {
238
533
        Self {
239
533
            index: DashSet::with_capacity_and_hasher(capacity, S::default()),
240
533
            allocator,
241
533
        }
242
533
    }
243

            
244
    /// Creates an empty StablePointerSet with the specified hasher and allocator.
245
1
    pub fn with_hasher_in(hasher: S, allocator: A) -> Self {
246
1
        Self {
247
1
            index: DashSet::with_hasher(hasher),
248
1
            allocator,
249
1
        }
250
1
    }
251

            
252
    /// Creates an empty StablePointerSet with the specified capacity, hasher, and allocator.
253
    pub fn with_capacity_and_hasher_in(capacity: usize, hasher: S, allocator: A) -> Self {
254
        Self {
255
            index: DashSet::with_capacity_and_hasher(capacity, hasher),
256
            allocator,
257
        }
258
    }
259

            
260
    /// Returns the number of elements in the set.
261
9
    pub fn len(&self) -> usize {
262
9
        self.index.len()
263
9
    }
264

            
265
    /// Returns true if the set is empty.
266
    pub fn is_empty(&self) -> bool {
267
        self.len() == 0
268
    }
269

            
270
    /// Returns the capacity of the set.
271
    pub fn capacity(&self) -> usize {
272
        self.index.capacity()
273
    }
274

            
275
    /// Inserts an element into the set using an equivalent value.
276
    ///
277
    /// This version takes a reference to an equivalent value and creates the value to insert
278
    /// only if it doesn't already exist in the set. Returns a stable pointer to the element
279
    /// and a boolean indicating whether the element was inserted.
280
3260856
    pub fn insert_equiv<'a, Q>(&self, value: &'a Q) -> (StablePointer<T>, bool)
281
3260856
    where
282
3260856
        Q: Hash + Equivalent<T>,
283
3260856
        T: From<&'a Q>,
284
    {
285
3260856
        debug_assert!(std::mem::size_of::<T>() > 0, "Zero-sized types not supported");
286

            
287
        // Check if we already have this value
288
3260856
        let raw_ptr = self.get(value);
289

            
290
3260856
        if let Some(ptr) = raw_ptr {
291
            // We already have this value, return pointer to existing element
292
3247949
            return (ptr, false);
293
12907
        }
294

            
295
        // Allocate memory for the value
296
12907
        let layout = Layout::new::<T>();
297
12907
        let ptr = self.allocator.allocate(layout).expect("Allocation failed").cast::<T>();
298

            
299
        // Write the value to the allocated memory
300
12907
        unsafe {
301
12907
            ptr.as_ptr().write(value.into());
302
12907
        }
303

            
304
        // Insert new value using allocator
305
12907
        let entry = Entry::new(ptr);
306
12907
        let result = StablePointer::from_entry(&entry);
307

            
308
        // First add to storage, then to index
309
12907
        let inserted = self.index.insert(entry);
310
12907
        if !inserted {
311
            let entry = Entry::new(ptr);
312
            let element = self
313
                .index
314
                .get(&entry)
315
                .expect("Insertion failed, so entry must be in the set");
316
            return (StablePointer::from_entry(&element), false);
317
12907
        }
318

            
319
        // Insertion succeeded.
320
12907
        (result, true)
321
3260856
    }
322

            
323
    /// Returns `true` if the set contains a value.
324
13
    pub fn contains<Q>(&self, value: &Q) -> bool
325
13
    where
326
13
        T: Eq + Hash,
327
13
        Q: ?Sized + Hash + Equivalent<T>,
328
    {
329
13
        self.get(value).is_some()
330
13
    }
331

            
332
    /// Returns a stable pointer to a value in the set, if present.
333
    ///
334
    /// Searches for a value equal to the provided reference and returns a pointer to the stored element.
335
    /// The returned pointer remains valid until the element is removed from the set.
336
57055847
    pub fn get<Q>(&self, value: &Q) -> Option<StablePointer<T>>
337
57055847
    where
338
57055847
        T: Eq + Hash,
339
57055847
        Q: ?Sized + Hash + Equivalent<T>,
340
    {
341
        // Find the boxed element that contains an equivalent value
342
57055847
        let boxed = self.index.get(&LookUp(value))?;
343

            
344
        // SAFETY: The pointer is valid as long as the set is valid.
345
56480030
        let ptr = StablePointer::from_entry(boxed.key());
346
56480030
        Some(ptr)
347
57055847
    }
348

            
349
    /// Returns an iterator over the elements of the set.
350
2
    pub fn iter(&self) -> impl Iterator<Item = &T> {
351
8
        self.index.iter().map(|boxed| unsafe { boxed.ptr.as_ref() })
352
2
    }
353

            
354
    /// Removes an element from the set using its stable pointer.
355
    ///
356
    /// Returns true if the element was found and removed.
357
4
    pub fn remove(&self, pointer: StablePointer<T>) -> bool {
358
4
        debug_assert!(
359
4
            pointer.is_last_reference(),
360
            "Pointer must be the last reference to the element"
361
        );
362

            
363
        // SAFETY: This is the last reference to the element, so it is safe to remove it.
364
4
        let t = pointer.deref();
365
4
        let result = self.index.remove(&LookUp(t));
366

            
367
4
        if let Some(ptr) = result {
368
            // SAFETY: We have exclusive access during drop and the pointer is valid
369
4
            unsafe {
370
4
                self.drop_and_deallocate_entry(ptr.ptr);
371
4
            }
372
4
            true
373
        } else {
374
            false
375
        }
376
4
    }
377

            
378
    /// Retains only the elements specified by the predicate, modifying the set in-place.
379
    ///
380
    /// The predicate closure is called with a mutable reference to each element and must
381
    /// return true if the element should remain in the set.
382
    ///
383
    /// # Safety
384
    ///
385
    /// It invalidates any StablePointers to removed elements
386
1
    pub fn retain<F>(&self, mut predicate: F)
387
1
    where
388
1
        F: FnMut(&StablePointer<T>) -> bool,
389
    {
390
        // First pass: determine what to keep/remove without modifying the collection
391
4
        self.index.retain(|element| {
392
4
            let ptr = StablePointer::from_entry(element);
393

            
394
4
            if !predicate(&ptr) {
395
                // Note that retain can remove disconnect graphs of elements in
396
                // one go, so it is not necessarily the case that there is only
397
                // one reference to the element.
398

            
399
                // SAFETY: We have exclusive access during drop and the pointer
400
                // is valid
401
2
                unsafe {
402
2
                    self.drop_and_deallocate_entry(ptr.ptr);
403
2
                }
404
2
                return false;
405
2
            }
406

            
407
2
            true
408
4
        });
409
1
    }
410

            
411
    /// Drops the element at the given pointer and deallocates its memory.
412
    ///
413
    /// # Safety
414
    ///
415
    /// This requires that ptr can be dereferenced, so it must point to a valid element.
416
21
    unsafe fn drop_and_deallocate_entry(&self, ptr: NonNull<T>) {
417
        // SAFETY: We have exclusive access during drop and the pointer is valid
418
21
        let length = unsafe { T::length(ptr.as_ref()) };
419
21
        unsafe {
420
21
            // Drop the value in place before deallocating
421
21
            std::ptr::drop_in_place(ptr.as_ptr());
422
21
        }
423
21
        self.allocator.deallocate_slice_dst(ptr, length);
424
21
    }
425
}
426

            
427
impl<T: ?Sized + SliceDst, S, A> StablePointerSet<T, S, A>
428
where
429
    T: Hash + Eq,
430
    S: BuildHasher + Clone,
431
    A: Allocator + AllocatorDst + Sync,
432
{
433
    /// Clears the set, removing all values and invalidating all pointers.
434
    ///
435
    /// # Safety
436
    /// This is unsafe because it invalidates all pointers to the elements in the set.
437
    pub fn clear(&self) {
438
        #[cfg(debug_assertions)]
439
        debug_assert!(
440
            self.index.iter().all(|x| Arc::strong_count(&x.reference_counter) == 1),
441
            "All pointers must be the last reference to the element"
442
        );
443

            
444
        // Manually deallocate all entries before clearing
445
        for entry in self.index.iter() {
446
            // SAFETY: We have exclusive access during drop and the pointer is valid
447
            unsafe {
448
                self.drop_and_deallocate_entry(entry.ptr);
449
            }
450
        }
451

            
452
        self.index.clear();
453
        debug_assert!(self.index.is_empty(), "Index should be empty after clearing");
454
    }
455

            
456
    /// Inserts an element into the set using an equivalent value.
457
    ///
458
    /// This version takes a reference to an equivalent value and creates the
459
    /// value to insert only if it doesn't already exist in the set. Returns a
460
    /// stable pointer to the element and a boolean indicating whether the
461
    /// element was inserted.
462
    ///
463
    /// # Safety
464
    ///
465
    /// construct must fully initialize the value at the given pointer,
466
    /// otherwise it may lead to undefined behavior.
467
53794955
    pub unsafe fn insert_equiv_dst<'a, Q, C>(
468
53794955
        &self,
469
53794955
        value: &'a Q,
470
53794955
        length: usize,
471
53794955
        construct: C,
472
53794955
    ) -> (StablePointer<T>, bool)
473
53794955
    where
474
53794955
        Q: Hash + Equivalent<T>,
475
53794955
        C: Fn(*mut T, &'a Q),
476
    {
477
        // Check if we already have this value
478
53794955
        let raw_ptr = self.get(value);
479

            
480
53794955
        if let Some(ptr) = raw_ptr {
481
            // We already have this value, return pointer to existing element
482
53232070
            return (ptr, false);
483
562885
        }
484

            
485
        // Allocate space for the entry and construct it
486
562885
        let mut ptr = self
487
562885
            .allocator
488
562885
            .allocate_slice_dst::<T>(length)
489
562885
            .unwrap_or_else(|_| handle_alloc_error(Layout::new::<()>()));
490

            
491
562885
        unsafe {
492
562885
            construct(ptr.as_mut(), value);
493
562885
        }
494

            
495
        loop {
496
562885
            let entry = Entry::new(ptr);
497
562885
            let ptr = StablePointer::from_entry(&entry);
498

            
499
562885
            let inserted = self.index.insert(entry);
500
562885
            if !inserted {
501
                // Add the result to the storage, it could be at this point that the entry was inserted by another thread. So
502
                // this insertion might actually fail, in which case we should clean up the created entry and return the old pointer.
503

            
504
                // TODO: I suppose this can go wrong with begin_insert(x); insert(x); remove(x); end_insert(x) chain.
505
                if let Some(existing_ptr) = self.get(value) {
506
                    // SAFETY: We have exclusive access during drop and the pointer is valid
507
                    unsafe {
508
                        self.drop_and_deallocate_entry(ptr.ptr);
509
                    }
510

            
511
                    return (existing_ptr, false);
512
                }
513
            } else {
514
                // Value was successfully inserted
515
562885
                return (ptr, true);
516
            }
517
        }
518
53794955
    }
519
}
520

            
521
impl<T, S, A> StablePointerSet<T, S, A>
522
where
523
    T: Hash + Eq + SliceDst,
524
    S: BuildHasher + Clone,
525
    A: Allocator + AllocatorDst,
526
{
527
    /// Inserts an element into the set.
528
    ///
529
    /// If the set did not have this value present, `true` is returned along
530
    /// with a stable pointer to the inserted element.
531
    ///
532
    /// If the set already had this value present, `false` is returned along
533
    /// with a stable pointer to the existing element.
534
20
    pub fn insert(&self, value: T) -> (StablePointer<T>, bool) {
535
20
        debug_assert!(std::mem::size_of::<T>() > 0, "Zero-sized types not supported");
536

            
537
20
        if let Some(ptr) = self.get(&value) {
538
            // We already have this value, return pointer to existing element
539
1
            return (ptr, false);
540
19
        }
541

            
542
19
        let ptr = self
543
19
            .allocator
544
19
            .allocate(Layout::new::<T>())
545
19
            .unwrap_or_else(|_| handle_alloc_error(Layout::new::<T>()))
546
19
            .cast::<T>();
547

            
548
19
        unsafe {
549
19
            ptr.write(value);
550
19
        }
551

            
552
        // Insert new value using allocator
553
19
        let entry = Entry::new(ptr);
554
19
        let ptr = StablePointer::from_entry(&entry);
555

            
556
        // First add to storage, then to index
557
19
        let inserted = self.index.insert(entry);
558

            
559
19
        debug_assert!(inserted, "Value should not already exist in the index");
560

            
561
19
        (ptr, true)
562
20
    }
563
}
564

            
565
impl<T: ?Sized, S, A> Drop for StablePointerSet<T, S, A>
566
where
567
    T: Hash + Eq + SliceDst,
568
    S: BuildHasher + Clone,
569
    A: Allocator + AllocatorDst,
570
{
571
10
    fn drop(&mut self) {
572
        #[cfg(debug_assertions)]
573
10
        debug_assert!(
574
15
            self.index.iter().all(|x| Arc::strong_count(&x.reference_counter) == 1),
575
            "All pointers must be the last reference to the element"
576
        );
577

            
578
        // Manually drop and deallocate all entries
579
15
        for entry in self.index.iter() {
580
15
            unsafe {
581
15
                self.drop_and_deallocate_entry(entry.ptr);
582
15
            }
583
        }
584
10
    }
585
}
586

            
587
/// Index record for one element stored in a [`StablePointerSet`].
///
/// The element itself lives in memory obtained from the set's allocator (not a `Box`), which is
/// what allows custom allocators. `Entry` only records the pointer; it has no `Drop` impl, and the
/// set frees the element explicitly. Debug builds additionally carry a reference counter that is
/// shared with every `StablePointer` handed out for this element.
struct Entry<T: ?Sized> {
    /// Pointer to the allocated value.
    ptr: NonNull<T>,

    /// Shared with each `StablePointer` to this element (debug builds only).
    #[cfg(debug_assertions)]
    reference_counter: Arc<()>,
}

// SAFETY: an `Entry` is only a pointer to a `T` (plus an `Arc<()>` in debug builds), so it may move
// between or be shared across threads whenever `T` itself may.
unsafe impl<T> Send for Entry<T> where T: ?Sized + Send {}
unsafe impl<T> Sync for Entry<T> where T: ?Sized + Sync {}

impl<T: ?Sized> Entry<T> {
    /// Wraps a pointer to an already allocated and initialised element. Does not allocate.
    fn new(ptr: NonNull<T>) -> Self {
        #[cfg(debug_assertions)]
        let reference_counter = Arc::new(());
        Entry {
            ptr,
            #[cfg(debug_assertions)]
            reference_counter,
        }
    }
}

/// Lets the index hash and compare entries by the element they point to.
impl<T: ?Sized> Deref for Entry<T> {
    type Target = T;

    fn deref(&self) -> &T {
        // SAFETY: an entry only exists while its element is alive and owned by the set.
        unsafe { &*self.ptr.as_ptr() }
    }
}

/// Entries compare by value, so the index deduplicates equal elements.
impl<T: PartialEq + ?Sized> PartialEq for Entry<T> {
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&**self, &**other)
    }
}

/// Entries hash by value, matching the value-based equality above.
impl<T: Hash + ?Sized> Hash for Entry<T> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&**self, state);
    }
}

impl<T: Eq + ?Sized> Eq for Entry<T> {}
635

            
636
/// A helper struct to look up elements in the set using a reference.
637
#[derive(Hash, PartialEq, Eq)]
638
struct LookUp<'a, T: ?Sized>(&'a T);
639

            
640
impl<T: ?Sized, Q: ?Sized> Equivalent<Entry<T>> for LookUp<'_, Q>
641
where
642
    Q: Equivalent<T>,
643
{
644
56665332
    fn equivalent(&self, other: &Entry<T>) -> bool {
645
56665332
        self.0.equivalent(&**other)
646
56665332
    }
647
}
648

            
649
#[cfg(test)]
mod tests {
    use super::*;
    use allocator_api2::alloc::System;
    use rustc_hash::FxHasher;
    use std::hash::BuildHasherDefault;

    #[test]
    fn test_insert_and_get() {
        let set = StablePointerSet::new();

        // The first insertion creates the element.
        let (first, was_new) = set.insert(42);
        assert!(was_new);
        assert_eq!(*first, 42);

        // Re-inserting an equal value hands back the existing element.
        let (second, was_new) = set.insert(42);
        assert!(!was_new);
        assert_eq!(*second, 42);

        // Both handles refer to the very same allocation.
        assert_eq!(first, second);

        // Only a single element is stored.
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn test_contains() {
        let set = StablePointerSet::new();
        for v in [42, 100] {
            set.insert(v);
        }

        for (v, expected) in [(42, true), (100, true), (200, false)] {
            assert_eq!(set.contains(&v), expected);
        }
    }

    #[test]
    fn test_get() {
        let set = StablePointerSet::new();
        for v in [42, 100] {
            set.insert(v);
        }

        for v in [42, 100] {
            let found = set.get(&v).expect("Value should exist");
            assert_eq!(*found, v);
        }

        assert!(set.get(&200).is_none(), "Value should not exist");
    }

    #[test]
    fn test_iteration() {
        let set = StablePointerSet::new();
        for v in 1..=3 {
            set.insert(v);
        }

        let mut seen: Vec<i32> = set.iter().copied().collect();
        seen.sort();

        assert_eq!(seen, vec![1, 2, 3]);
    }

    #[test]
    fn test_stable_pointer_set_insert_equiv_ref() {
        #[derive(PartialEq, Eq, Debug)]
        struct TestValue {
            id: i32,
            name: String,
        }

        impl From<&i32> for TestValue {
            fn from(id: &i32) -> Self {
                let name = format!("Value-{}", id);
                TestValue { id: *id, name }
            }
        }

        impl Hash for TestValue {
            fn hash<H: Hasher>(&self, state: &mut H) {
                self.id.hash(state);
            }
        }

        impl Equivalent<TestValue> for i32 {
            fn equivalent(&self, key: &TestValue) -> bool {
                key.id == *self
            }
        }

        let set: StablePointerSet<TestValue> = StablePointerSet::new();

        // Inserting through the equivalent key (i32 -> TestValue) builds the value.
        let (first, was_new) = set.insert_equiv(&42);
        assert!(was_new, "Value should be inserted");
        assert_eq!(first.id, 42);
        assert_eq!(first.name, "Value-42");

        // A second insertion with the same key is a no-op returning the same element.
        let (again, was_new) = set.insert_equiv(&42);
        assert!(!was_new, "Value should not be inserted again");
        assert_eq!(first, again, "Should return the same pointer");

        // A different key produces a new element.
        let (other, was_new) = set.insert_equiv(&100);
        assert!(was_new, "New value should be inserted");
        assert_eq!(other.id, 100);
        assert_eq!(other.name, "Value-100");

        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_stable_pointer_deref() {
        let set = StablePointerSet::new();
        let (handle, _) = set.insert(42);

        // Deref coercion to a plain reference.
        let value: &i32 = &handle;
        assert_eq!(*value, 42);

        // Methods of the pointee are reachable through the handle.
        assert_eq!((*handle).checked_add(10), Some(52));
    }

    #[test]
    fn test_stable_pointer_set_remove() {
        let set = StablePointerSet::new();

        let (first, _) = set.insert(42);
        let (second, _) = set.insert(100);
        assert_eq!(set.len(), 2);

        assert!(set.remove(first));
        assert_eq!(set.len(), 1);

        assert!(set.remove(second));
        assert_eq!(set.len(), 0);
    }

    #[test]
    fn test_stable_pointer_set_retain() {
        let set = StablePointerSet::new();

        set.insert(1);
        let (two, _) = set.insert(2);
        set.insert(3);
        let (four, _) = set.insert(4);
        assert_eq!(set.len(), 4);

        // Keep the even numbers only.
        set.retain(|x| **x % 2 == 0);

        assert_eq!(set.len(), 2);
        for (v, expected) in [(1, false), (2, true), (3, false), (4, true)] {
            assert_eq!(set.contains(&v), expected);
        }

        // Handles to surviving elements are still usable.
        assert!(set.remove(two));
        assert!(set.remove(four));
    }

    #[test]
    fn test_stable_pointer_set_custom_allocator() {
        // Back the set with the System allocator.
        let set: StablePointerSet<i32, RandomState, System> = StablePointerSet::new_in(System);

        let (first, was_new) = set.insert(42);
        assert!(was_new);
        let (second, was_new) = set.insert(100);
        assert!(was_new);

        assert_eq!(set.len(), 2);
        assert_eq!(*first, 42);
        assert_eq!(*second, 100);

        for (v, expected) in [(42, true), (100, true), (200, false)] {
            assert_eq!(set.contains(&v), expected);
        }
    }

    #[test]
    fn test_stable_pointer_set_custom_hasher_and_allocator() {
        // Combine a custom hasher with a custom allocator.
        let hasher = BuildHasherDefault::<FxHasher>::default();
        let set: StablePointerSet<i32, BuildHasherDefault<FxHasher>, System> =
            StablePointerSet::with_hasher_in(hasher, System);

        let (first, was_new) = set.insert(42);
        assert!(was_new);
        let (second, was_new) = set.insert(100);
        assert!(was_new);

        assert_eq!(set.len(), 2);
        assert_eq!(*first, 42);
        assert_eq!(*second, 100);

        for (v, expected) in [(42, true), (100, true), (200, false)] {
            assert_eq!(set.contains(&v), expected);
        }
    }
}