1
use std::alloc::handle_alloc_error;
2
use std::collections::hash_map::RandomState;
3
use std::fmt;
4
use std::hash::BuildHasher;
5
use std::hash::Hash;
6
use std::hash::Hasher;
7
use std::ops::Deref;
8
use std::ptr::NonNull;
9
use std::ptr::addr_eq;
10
#[cfg(debug_assertions)]
11
use std::sync::Arc;
12

            
13
use allocator_api2::alloc::Allocator;
14
use allocator_api2::alloc::Global;
15
use allocator_api2::alloc::Layout;
16
use dashmap::DashSet;
17
use equivalent::Equivalent;
18

            
19
use crate::AllocatorDst;
20
use crate::SliceDst;
21

            
22
/// A safe wrapper around a raw pointer that allows immutable dereferencing. This remains valid as
/// long as the `StablePointerSet` it came from remains valid, which is not checked by the borrow
/// checker.
///
/// Comparisons are based on the pointer's address, not the value it points to.
#[repr(C)]
pub struct StablePointer<T: ?Sized> {
    /// The raw pointer to the element. Being `NonNull`, it lets `Option<StablePointer<T>>` use
    /// the null niche.
    ptr: NonNull<T>,

    /// Keep track of reference counts in debug mode.
    #[cfg(debug_assertions)]
    reference_counter: Arc<()>,
}

// Hand-written instead of `#[derive(Clone)]`: the derive would add a `T: Clone` bound, which no
// unsized `T` can satisfy. Cloning only copies the pointer (and, in debug builds, bumps the shared
// reference counter), so it is valid for every `T`.
impl<T: ?Sized> Clone for StablePointer<T> {
    fn clone(&self) -> Self {
        Self {
            ptr: self.ptr,
            #[cfg(debug_assertions)]
            reference_counter: Arc::clone(&self.reference_counter),
        }
    }
}

/// Check that the `Option<StablePointer>` is the same size as a `usize` for release builds.
#[cfg(not(debug_assertions))]
const _: () = assert!(std::mem::size_of::<Option<StablePointer<usize>>>() == std::mem::size_of::<usize>());
41

            
42
impl<T: ?Sized> StablePointer<T> {
43
    /// Returns true if this is the last reference to the pointer.
44
4
    fn is_last_reference(&self) -> bool {
45
        #[cfg(debug_assertions)]
46
        {
47
            // There is a reference in the table, and the one of `self.ptr`.
48
4
            Arc::strong_count(&self.reference_counter) == 2
49
        }
50
        #[cfg(not(debug_assertions))]
51
        {
52
            true
53
        }
54
4
    }
55

            
56
    /// Creates a new StablePointer from a raw pointer.
57
    ///
58
    /// # Safety
59
    ///
60
    /// The caller must ensure that the pointer is valid and points to a valid T that outlives the StablePointer.
61
    pub unsafe fn from_ptr(ptr: NonNull<T>) -> Self {
62
        Self {
63
            ptr,
64
            #[cfg(debug_assertions)]
65
            reference_counter: Arc::new(()),
66
        }
67
    }
68

            
69
    /// Creates a new StablePointer from a raw pointer while preserving the
70
    /// debug reference counter from another StablePointer.
71
    ///
72
    /// # Safety
73
    ///
74
    /// The caller must ensure that `ptr` points to the same allocation as
75
    /// `source` (potentially with a different pointee type/metadata) and
76
    /// remains valid for at least as long as any derived StablePointer.
77
114106384
    pub unsafe fn from_related_ptr<U: ?Sized>(ptr: NonNull<U>, #[allow(unused)] source: &Self) -> StablePointer<U> {
78
114106384
        StablePointer {
79
114106384
            ptr,
80
114106384
            #[cfg(debug_assertions)]
81
114106384
            reference_counter: source.reference_counter.clone(),
82
114106384
        }
83
114106384
    }
84

            
85
    /// Returns public access to the underlying pointer.
86
120545228
    pub fn ptr(&self) -> NonNull<T> {
87
120545228
        self.ptr
88
120545228
    }
89
}
90

            
91
impl<T: ?Sized> PartialEq for StablePointer<T> {
92
21107023447
    fn eq(&self, other: &Self) -> bool {
93
        // SAFETY: This is safe because we are comparing pointers, which is a valid operation.
94
21107023447
        addr_eq(self.ptr.as_ptr(), other.ptr.as_ptr())
95
21107023447
    }
96
}
97

            
98
/// Address equality (as implemented by `PartialEq` with `addr_eq`) is reflexive, so `Eq` holds for
/// every `T`.
impl<T: ?Sized> Eq for StablePointer<T> {}
99

            
100
impl<T: ?Sized> Ord for StablePointer<T> {
101
259134560
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
102
        // SAFETY: This is safe because we are comparing pointers, which is a valid operation.
103
259134560
        self.ptr.as_ptr().cast::<()>().cmp(&(other.ptr.as_ptr().cast::<()>()))
104
259134560
    }
105
}
106

            
107
impl<T: ?Sized> PartialOrd for StablePointer<T> {
108
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
109
        // SAFETY: This is safe because we are comparing pointers, which is a valid operation.
110
        Some(self.cmp(other))
111
    }
112
}
113

            
114
impl<T: ?Sized> Hash for StablePointer<T> {
115
622726649
    fn hash<H: Hasher>(&self, state: &mut H) {
116
        // SAFETY: This is safe because we are hashing pointers, which is a valid operation.
117
622726649
        self.ptr.hash(state);
118
622726649
    }
119
}
120

            
121
unsafe impl<T: ?Sized + Send> Send for StablePointer<T> {}
122
unsafe impl<T: ?Sized + Sync> Sync for StablePointer<T> {}
123

            
124
impl<T: ?Sized> StablePointer<T> {
125
    /// Returns a copy of the StablePointer.
126
    ///
127
    /// # Safety
128
    /// The caller must ensure the pointer points to a valid T that outlives the returned StablePointer.
129
17284994393
    pub fn copy(&self) -> Self {
130
17284994393
        Self {
131
17284994393
            ptr: self.ptr,
132
17284994393
            #[cfg(debug_assertions)]
133
17284994393
            reference_counter: self.reference_counter.clone(),
134
17284994393
        }
135
17284994393
    }
136

            
137
    /// Creates a new StablePointer from a boxed element.
138
131974775
    fn from_entry(entry: &Entry<T>) -> Self {
139
131974775
        Self {
140
131974775
            ptr: entry.ptr,
141
131974775
            #[cfg(debug_assertions)]
142
131974775
            reference_counter: entry.reference_counter.clone(),
143
131974775
        }
144
131974775
    }
145
}
146

            
147
impl<T: ?Sized> Deref for StablePointer<T> {
148
    type Target = T;
149

            
150
15581404658
    fn deref(&self) -> &Self::Target {
151
        // The caller must ensure the pointer points to a valid T that outlives this StablePointer.
152
15581404658
        unsafe { self.ptr.as_ref() }
153
15581404658
    }
154
}
155

            
156
impl<T: fmt::Debug + ?Sized> fmt::Debug for StablePointer<T> {
157
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158
        f.debug_tuple("StablePointer").field(&self.ptr).finish()
159
    }
160
}
161

            
162
/// A set that provides stable pointers to its elements.
163
///
164
/// Similar to `IndexedSet` but uses pointers instead of indices for direct access to elements.
165
/// Elements are stored in stable memory locations using a custom allocator, with the hash set maintaining references.
166
///
167
/// The set can use a custom hasher type for potentially better performance based on workload characteristics.
168
/// Uses an allocator for memory management, defaulting to the global allocator.
169
pub struct StablePointerSet<T: ?Sized, S = RandomState, A = Global>
170
where
171
    T: Hash + Eq + SliceDst,
172
    S: BuildHasher + Clone,
173
    A: Allocator + AllocatorDst,
174
{
175
    index: DashSet<Entry<T>, S>,
176

            
177
    allocator: A,
178
}
179

            
180
impl<T: ?Sized> Default for StablePointerSet<T, RandomState, Global>
181
where
182
    T: Hash + Eq + SliceDst,
183
{
184
    fn default() -> Self {
185
        Self::new()
186
    }
187
}
188

            
189
impl<T: ?Sized> StablePointerSet<T, RandomState, Global>
190
where
191
    T: Hash + Eq + SliceDst,
192
{
193
    /// Creates an empty StablePointerSet with the default hasher and global allocator.
194
8
    pub fn new() -> Self {
195
8
        Self {
196
8
            index: DashSet::default(),
197
8
            allocator: Global,
198
8
        }
199
8
    }
200

            
201
    /// Creates an empty StablePointerSet with the specified capacity, default hasher, and global allocator.
202
    pub fn with_capacity(capacity: usize) -> Self {
203
        Self {
204
            index: DashSet::with_capacity_and_hasher(capacity, RandomState::new()),
205
            allocator: Global,
206
        }
207
    }
208
}
209

            
210
impl<T: ?Sized, S> StablePointerSet<T, S, Global>
211
where
212
    T: Hash + Eq + SliceDst,
213
    S: BuildHasher + Clone,
214
{
215
    /// Creates an empty StablePointerSet with the specified hasher and global allocator.
216
    pub fn with_hasher(hasher: S) -> Self {
217
        Self {
218
            index: DashSet::with_hasher(hasher),
219
            allocator: Global,
220
        }
221
    }
222

            
223
    /// Creates an empty StablePointerSet with the specified capacity, hasher, and global allocator.
224
959
    pub fn with_capacity_and_hasher(capacity: usize, hasher: S) -> Self {
225
959
        Self {
226
959
            index: DashSet::with_capacity_and_hasher(capacity, hasher),
227
959
            allocator: Global,
228
959
        }
229
959
    }
230
}
231

            
232
impl<T: ?Sized, S, A> StablePointerSet<T, S, A>
233
where
234
    T: Hash + Eq + SliceDst,
235
    S: BuildHasher + Clone,
236
    A: Allocator + AllocatorDst,
237
{
238
    /// Creates an empty StablePointerSet with the specified allocator and default hasher.
239
1
    pub fn new_in(allocator: A) -> Self
240
1
    where
241
1
        S: Default,
242
    {
243
1
        Self {
244
1
            index: DashSet::with_hasher(S::default()),
245
1
            allocator,
246
1
        }
247
1
    }
248

            
249
    /// Creates an empty StablePointerSet with the specified capacity, allocator, and default hasher.
250
    pub fn with_capacity_in(capacity: usize, allocator: A) -> Self
251
    where
252
        S: Default,
253
    {
254
        Self {
255
            index: DashSet::with_capacity_and_hasher(capacity, S::default()),
256
            allocator,
257
        }
258
    }
259

            
260
    /// Creates an empty StablePointerSet with the specified hasher and allocator.
261
960
    pub fn with_hasher_in(hasher: S, allocator: A) -> Self {
262
960
        Self {
263
960
            index: DashSet::with_hasher(hasher),
264
960
            allocator,
265
960
        }
266
960
    }
267

            
268
    /// Creates an empty StablePointerSet with the specified capacity, hasher, and allocator.
269
8631
    pub fn with_capacity_and_hasher_in(capacity: usize, hasher: S, allocator: A) -> Self {
270
8631
        Self {
271
8631
            index: DashSet::with_capacity_and_hasher(capacity, hasher),
272
8631
            allocator,
273
8631
        }
274
8631
    }
275

            
276
    /// Returns the number of elements in the set.
277
8156556
    pub fn len(&self) -> usize {
278
8156556
        self.index.len()
279
8156556
    }
280

            
281
    /// Returns true if the set is empty.
282
    pub fn is_empty(&self) -> bool {
283
        self.len() == 0
284
    }
285

            
286
    /// Returns the capacity of the set.
287
    pub fn capacity(&self) -> usize {
288
        self.index.capacity()
289
    }
290

            
291
    /// Inserts an element into the set using an equivalent value.
292
    ///
293
    /// This version takes a reference to an equivalent value and creates the value to insert
294
    /// only if it doesn't already exist in the set. Returns a stable pointer to the element
295
    /// and a boolean indicating whether the element was inserted.
296
4984652
    pub fn insert_equiv<'a, Q>(&self, value: &'a Q) -> (StablePointer<T>, bool)
297
4984652
    where
298
4984652
        Q: Hash + Equivalent<T>,
299
4984652
        T: From<&'a Q>,
300
    {
301
4984652
        debug_assert!(std::mem::size_of::<T>() > 0, "Zero-sized types not supported");
302

            
303
        // Check if we already have this value
304
4984652
        let raw_ptr = self.get(value);
305

            
306
4984652
        if let Some(ptr) = raw_ptr {
307
            // We already have this value, return pointer to existing element
308
4904096
            return (ptr, false);
309
80556
        }
310

            
311
        // Allocate memory for the value
312
80556
        let layout = Layout::new::<T>();
313
80556
        let ptr = self.allocator.allocate(layout).expect("Allocation failed").cast::<T>();
314

            
315
        // Write the value to the allocated memory
316
80556
        unsafe {
317
80556
            ptr.as_ptr().write(value.into());
318
80556
        }
319

            
320
        // Insert new value using allocator
321
80556
        let entry = Entry::new(ptr);
322
80556
        let result = StablePointer::from_entry(&entry);
323

            
324
        // First add to storage, then to index
325
80556
        let inserted = self.index.insert(entry);
326
80556
        if !inserted {
327
            let entry = Entry::new(ptr);
328
            let element = self
329
                .index
330
                .get(&entry)
331
                .expect("Insertion failed, so entry must be in the set");
332

            
333
            // Call the drop function
334
            unsafe { std::ptr::drop_in_place(ptr.as_ptr()) };
335

            
336
            // Remove the entry we just created since it was not inserted
337
            unsafe {
338
                self.allocator.deallocate(ptr.cast(), layout);
339
            }
340

            
341
            return (StablePointer::from_entry(&element), false);
342
80556
        }
343

            
344
        // Insertion succeeded.
345
80556
        (result, true)
346
4984652
    }
347

            
348
    /// Returns `true` if the set contains a value.
349
13
    pub fn contains<Q>(&self, value: &Q) -> bool
350
13
    where
351
13
        T: Eq + Hash,
352
13
        Q: ?Sized + Hash + Equivalent<T>,
353
    {
354
13
        self.get(value).is_some()
355
13
    }
356

            
357
    /// Returns a stable pointer to a value in the set, if present.
358
    ///
359
    /// Searches for a value equal to the provided reference and returns a pointer to the stored element.
360
    /// The returned pointer remains valid until the element is removed from the set.
361
91726275
    pub fn get<Q>(&self, value: &Q) -> Option<StablePointer<T>>
362
91726275
    where
363
91726275
        T: Eq + Hash,
364
91726275
        Q: ?Sized + Hash + Equivalent<T>,
365
    {
366
        // Find the boxed element that contains an equivalent value
367
91726275
        let boxed = self.index.get(&LookUp(value))?;
368

            
369
        // SAFETY: The pointer is valid as long as the set is valid.
370
77710668
        let ptr = StablePointer::from_entry(boxed.key());
371
77710668
        Some(ptr)
372
91726275
    }
373

            
374
    /// Returns an iterator over the elements of the set.
375
2
    pub fn iter(&self) -> impl Iterator<Item = &T> {
376
8
        self.index.iter().map(|boxed| unsafe { boxed.ptr.as_ref() })
377
2
    }
378

            
379
    /// Removes an element from the set using its stable pointer.
380
    ///
381
    /// Returns true if the element was found and removed.
382
4
    pub fn remove(&self, pointer: StablePointer<T>) -> bool {
383
4
        debug_assert!(
384
4
            pointer.is_last_reference(),
385
            "Pointer must be the last reference to the element"
386
        );
387

            
388
        // SAFETY: This is the last reference to the element, so it is safe to remove it.
389
4
        let t = pointer.deref();
390
4
        let result = self.index.remove(&LookUp(t));
391

            
392
4
        if let Some(ptr) = result {
393
            // SAFETY: We have exclusive access during drop and the pointer is valid
394
4
            unsafe {
395
4
                self.drop_and_deallocate_entry(ptr.ptr);
396
4
            }
397
4
            true
398
        } else {
399
            false
400
        }
401
4
    }
402

            
403
    /// Retains only the elements specified by the predicate, modifying the set in-place.
404
    ///
405
    /// The predicate closure is called with a mutable reference to each element and must
406
    /// return true if the element should remain in the set.
407
    ///
408
    /// # Safety
409
    ///
410
    /// It invalidates any StablePointers to removed elements
411
4272478
    pub fn retain<F>(&self, mut predicate: F)
412
4272478
    where
413
4272478
        F: FnMut(&StablePointer<T>) -> bool,
414
    {
415
        // First pass: determine what to keep/remove without modifying the collection
416
40212586
        self.index.retain(|element| {
417
40212586
            let ptr = StablePointer::from_entry(element);
418

            
419
40212586
            if !predicate(&ptr) {
420
                // Note that retain can remove disconnect graphs of elements in
421
                // one go, so it is not necessarily the case that there is only
422
                // one reference to the element.
423

            
424
                // SAFETY: We have exclusive access during drop and the pointer
425
                // is valid
426
13779612
                unsafe {
427
13779612
                    self.drop_and_deallocate_entry(ptr.ptr);
428
13779612
                }
429
13779612
                return false;
430
26432974
            }
431

            
432
26432974
            true
433
40212586
        });
434
4272478
    }
435

            
436
    /// Returns mutable access to the underlying allocator.
437
3884070
    pub fn allocator_mut(&mut self) -> &mut A {
438
3884070
        &mut self.allocator
439
3884070
    }
440

            
441
    /// Drops the element at the given pointer and deallocates its memory.
442
    ///
443
    /// # Safety
444
    ///
445
    /// This requires that ptr can be dereferenced, so it must point to a valid element.
446
13779631
    unsafe fn drop_and_deallocate_entry(&self, ptr: NonNull<T>) {
447
        // SAFETY: We have exclusive access during drop and the pointer is valid
448
13779631
        let length = unsafe { T::length(ptr.as_ref()) };
449
13779631
        unsafe {
450
13779631
            // Drop the value in place before deallocating
451
13779631
            std::ptr::drop_in_place(ptr.as_ptr());
452
13779631
        }
453
13779631
        self.allocator.deallocate_slice_dst(ptr, length);
454
13779631
    }
455
}
456

            
457
impl<T: ?Sized + SliceDst, S, A> StablePointerSet<T, S, A>
458
where
459
    T: Hash + Eq,
460
    S: BuildHasher + Clone,
461
    A: Allocator + AllocatorDst + Sync,
462
{
463
    /// Clears the set, removing all values and invalidating all pointers.
464
    ///
465
    /// # Safety
466
    /// This is unsafe because it invalidates all pointers to the elements in the set.
467
    pub fn clear(&self) {
468
        #[cfg(debug_assertions)]
469
        debug_assert!(
470
            self.index.iter().all(|x| Arc::strong_count(&x.reference_counter) == 1),
471
            "All pointers must be the last reference to the element"
472
        );
473

            
474
        // Manually deallocate all entries before clearing
475
        for entry in self.index.iter() {
476
            // SAFETY: We have exclusive access during drop and the pointer is valid
477
            unsafe {
478
                self.drop_and_deallocate_entry(entry.ptr);
479
            }
480
        }
481

            
482
        self.index.clear();
483
        debug_assert!(self.index.is_empty(), "Index should be empty after clearing");
484
    }
485

            
486
    /// Inserts an element into the set using an equivalent value.
487
    ///
488
    /// This version takes a reference to an equivalent value and creates the
489
    /// value to insert only if it doesn't already exist in the set. Returns a
490
    /// stable pointer to the element and a boolean indicating whether the
491
    /// element was inserted.
492
    ///
493
    /// # Safety
494
    ///
495
    /// construct must fully initialize the value at the given pointer,
496
    /// otherwise it may lead to undefined behavior.
497
    pub unsafe fn insert_equiv_dst<'a, Q, C>(
498
        &self,
499
        value: &'a Q,
500
        length: usize,
501
        construct: C,
502
    ) -> (StablePointer<T>, bool)
503
    where
504
        Q: Hash + Equivalent<T>,
505
        C: Fn(*mut T, &'a Q),
506
    {
507
        // Check if we already have this value
508
        let raw_ptr = self.get(value);
509

            
510
        if let Some(ptr) = raw_ptr {
511
            // We already have this value, return pointer to existing element
512
            return (ptr, false);
513
        }
514

            
515
        // Allocate space for the entry and construct it
516
        let mut ptr = self
517
            .allocator
518
            .allocate_slice_dst::<T>(length)
519
            .unwrap_or_else(|_| handle_alloc_error(Layout::new::<()>()));
520

            
521
        unsafe {
522
            construct(ptr.as_mut(), value);
523
        }
524

            
525
        loop {
526
            let entry = Entry::new(ptr);
527
            let ptr = StablePointer::from_entry(&entry);
528

            
529
            let inserted = self.index.insert(entry);
530
            if !inserted {
531
                // Add the result to the storage, it could be at this point that the entry was inserted by another thread. So
532
                // this insertion might actually fail, in which case we should clean up the created entry and return the old pointer.
533

            
534
                // TODO: I suppose this can go wrong with begin_insert(x); insert(x); remove(x); end_insert(x) chain.
535
                if let Some(existing_ptr) = self.get(value) {
536
                    // SAFETY: We have exclusive access during drop and the pointer is valid
537
                    unsafe {
538
                        self.drop_and_deallocate_entry(ptr.ptr);
539
                    }
540

            
541
                    return (existing_ptr, false);
542
                }
543
            } else {
544
                // Value was successfully inserted
545
                return (ptr, true);
546
            }
547
        }
548
    }
549
}
550

            
551
impl<T, S, A> StablePointerSet<T, S, A>
552
where
553
    T: Hash + Eq + SliceDst,
554
    S: BuildHasher + Clone,
555
    A: Allocator + AllocatorDst,
556
{
557
    /// Inserts an element into the set.
558
    ///
559
    /// If the set did not have this value present, `true` is returned along
560
    /// with a stable pointer to the inserted element.
561
    ///
562
    /// If the set already had this value present, `false` is returned along
563
    /// with a stable pointer to the existing element.
564
86741607
    pub fn insert(&self, value: T) -> (StablePointer<T>, bool) {
565
86741607
        debug_assert!(std::mem::size_of::<T>() > 0, "Zero-sized types not supported");
566

            
567
86741607
        if let Some(ptr) = self.get(&value) {
568
            // We already have this value, return pointer to existing element
569
72806562
            return (ptr, false);
570
13935045
        }
571

            
572
13935045
        let ptr = self
573
13935045
            .allocator
574
13935045
            .allocate(Layout::new::<T>())
575
13935045
            .unwrap_or_else(|_| handle_alloc_error(Layout::new::<T>()))
576
13935045
            .cast::<T>();
577

            
578
13935045
        unsafe {
579
13935045
            ptr.write(value);
580
13935045
        }
581

            
582
        // Insert new value using allocator
583
13935045
        let entry = Entry::new(ptr);
584
13935045
        let ptr = StablePointer::from_entry(&entry);
585

            
586
        // First add to storage, then to index
587
13935045
        let inserted = self.index.insert(entry);
588
13935045
        if !inserted {
589
            let entry = Entry::new(ptr.ptr());
590
            let element = self
591
                .index
592
                .get(&entry)
593
                .expect("Insertion failed, so entry must be in the set");
594

            
595
            // Drop and deallocate the allocation we created since it was not inserted.
596
            unsafe { std::ptr::drop_in_place(ptr.ptr().as_ptr()) };
597
            unsafe { self.allocator.deallocate(ptr.ptr().cast(), Layout::new::<T>()) };
598

            
599
            return (StablePointer::from_entry(&element), false);
600
13935045
        }
601

            
602
13935045
        (ptr, true)
603
86741607
    }
604
}
605

            
606
impl<T: ?Sized, S, A> Drop for StablePointerSet<T, S, A>
607
where
608
    T: Hash + Eq + SliceDst,
609
    S: BuildHasher + Clone,
610
    A: Allocator + AllocatorDst,
611
{
612
10
    fn drop(&mut self) {
613
        #[cfg(debug_assertions)]
614
10
        debug_assert!(
615
15
            self.index.iter().all(|x| Arc::strong_count(&x.reference_counter) == 1),
616
            "All pointers must be the last reference to the element"
617
        );
618

            
619
        // Manually drop and deallocate all entries
620
15
        for entry in self.index.iter() {
621
15
            unsafe {
622
15
                self.drop_and_deallocate_entry(entry.ptr);
623
15
            }
624
        }
625
10
    }
626
}
627

            
628
/// Index entry referring to one element owned by a `StablePointerSet`.
///
/// The element is allocated manually rather than boxed so that a custom allocator can be used;
/// the entry itself does not free it. Debug builds additionally carry the reference counter that
/// is shared with every `StablePointer` created from this entry.
struct Entry<T: ?Sized> {
    /// Pointer to the allocated element.
    ptr: NonNull<T>,

    /// Counts outstanding `StablePointer`s to this element (debug builds only).
    #[cfg(debug_assertions)]
    reference_counter: Arc<()>,
}
639

            
640
unsafe impl<T: ?Sized + Send> Send for Entry<T> {}
641
unsafe impl<T: ?Sized + Sync> Sync for Entry<T> {}
642

            
643
impl<T: ?Sized> Entry<T> {
644
    /// Creates a new entry by allocating memory for the value using the provided allocator.
645
14022073
    fn new(ptr: NonNull<T>) -> Self {
646
14022073
        Self {
647
14022073
            ptr,
648
14022073
            #[cfg(debug_assertions)]
649
14022073
            reference_counter: Arc::new(()),
650
14022073
        }
651
14022073
    }
652
}
653

            
654
impl<T: ?Sized> Deref for Entry<T> {
655
    type Target = T;
656

            
657
91932477
    fn deref(&self) -> &Self::Target {
658
        // SAFETY: The pointer is valid as long as the Entry exists
659
91932477
        unsafe { self.ptr.as_ref() }
660
91932477
    }
661
}
662

            
663
impl<T: PartialEq + ?Sized> PartialEq for Entry<T> {
664
45825
    fn eq(&self, other: &Self) -> bool {
665
45825
        **self == **other
666
45825
    }
667
}
668

            
669
impl<T: Hash + ?Sized> Hash for Entry<T> {
670
14048493
    fn hash<H: Hasher>(&self, state: &mut H) {
671
14048493
        (**self).hash(state);
672
14048493
    }
673
}
674

            
675
/// Entries compare by the pointed-to value (see `PartialEq for Entry`), which is an equivalence
/// relation whenever `T: Eq`.
impl<T: Eq + ?Sized> Eq for Entry<T> {}
676

            
677
/// Borrowed probe for looking up entries by value without constructing an `Entry`.
///
/// The derived `Hash` hashes the referenced value, which matches how `Entry` hashes its element.
#[derive(PartialEq, Eq, Hash)]
struct LookUp<'a, T: ?Sized>(&'a T);
680

            
681
impl<T: ?Sized, Q: ?Sized> Equivalent<Entry<T>> for LookUp<'_, Q>
682
where
683
    Q: Equivalent<T>,
684
{
685
77761374
    fn equivalent(&self, other: &Entry<T>) -> bool {
686
77761374
        self.0.equivalent(&**other)
687
77761374
    }
688
}
689

            
690
#[cfg(test)]
mod tests {
    use std::hash::BuildHasherDefault;
    use std::hash::Hash;
    use std::hash::Hasher;
    use std::hash::RandomState;

    use allocator_api2::alloc::System;
    use dashmap::Equivalent;
    use rustc_hash::FxHasher;

    use crate::StablePointerSet;

    #[test]
    fn test_insert_and_get() {
        let set = StablePointerSet::new();

        // A fresh value is stored and reported as inserted.
        let (first, was_inserted) = set.insert(42);
        assert!(was_inserted);
        assert_eq!(*first, 42);

        // An equal value is not stored twice; the existing element comes back instead.
        let (second, was_inserted) = set.insert(42);
        assert!(!was_inserted);
        assert_eq!(*second, 42);

        // Both handles refer to the same element, and only one element is stored.
        assert_eq!(first, second);
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn test_contains() {
        let set = StablePointerSet::new();
        for value in [42, 100] {
            set.insert(value);
        }

        assert!(set.contains(&42));
        assert!(set.contains(&100));
        assert!(!set.contains(&200));
    }

    #[test]
    fn test_get() {
        let set = StablePointerSet::new();
        set.insert(42);
        set.insert(100);

        let forty_two = set.get(&42).expect("Value should exist");
        assert_eq!(*forty_two, 42);

        let hundred = set.get(&100).expect("Value should exist");
        assert_eq!(*hundred, 100);

        assert!(set.get(&200).is_none(), "Value should not exist");
    }

    #[test]
    fn test_iteration() {
        let set = StablePointerSet::new();
        (1..=3).for_each(|value| {
            set.insert(value);
        });

        // Iteration order is unspecified, so compare a sorted snapshot.
        let mut values: Vec<i32> = set.iter().copied().collect();
        values.sort_unstable();

        assert_eq!(values, vec![1, 2, 3]);
    }

    #[test]
    fn test_stable_pointer_set_insert_equiv_ref() {
        #[derive(PartialEq, Eq, Debug)]
        struct TestValue {
            id: i32,
            name: String,
        }

        // Only the id is hashed, so an `i32` key hashes like the `TestValue` it stands for.
        impl Hash for TestValue {
            fn hash<H: Hasher>(&self, state: &mut H) {
                self.id.hash(state);
            }
        }

        impl Equivalent<TestValue> for i32 {
            fn equivalent(&self, key: &TestValue) -> bool {
                key.id == *self
            }
        }

        impl From<&i32> for TestValue {
            fn from(id: &i32) -> Self {
                let name = format!("Value-{}", id);
                TestValue { id: *id, name }
            }
        }

        let set: StablePointerSet<TestValue> = StablePointerSet::new();

        // Inserting through an equivalent key builds the value on demand.
        let (first, inserted) = set.insert_equiv(&42);
        assert!(inserted, "Value should be inserted");
        assert_eq!(first.id, 42);
        assert_eq!(first.name, "Value-42");

        // The same key a second time finds the existing element.
        let (again, inserted) = set.insert_equiv(&42);
        assert!(!inserted, "Value should not be inserted again");
        assert_eq!(first, again, "Should return the same pointer");

        // A different key creates a separate element.
        let (other, inserted) = set.insert_equiv(&100);
        assert!(inserted, "New value should be inserted");
        assert_eq!(other.id, 100);
        assert_eq!(other.name, "Value-100");

        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_stable_pointer_deref() {
        let set = StablePointerSet::new();
        let (ptr, _) = set.insert(42);

        // Deref coercion turns the handle into a plain reference.
        let value: &i32 = &ptr;
        assert_eq!(*value, 42);

        // Methods of the pointee are reachable through the handle.
        assert_eq!(ptr.checked_add(10), Some(52));
    }

    #[test]
    fn test_stable_pointer_set_remove() {
        let set = StablePointerSet::new();

        let (ptr1, _) = set.insert(42);
        let (ptr2, _) = set.insert(100);
        assert_eq!(set.len(), 2);

        // Each removal consumes the last handle and shrinks the set by one.
        for (expected_len, ptr) in [(1, ptr1), (0, ptr2)] {
            assert!(set.remove(ptr));
            assert_eq!(set.len(), expected_len);
        }
    }

    #[test]
    fn test_stable_pointer_set_retain() {
        let set = StablePointerSet::new();

        set.insert(1);
        let (ptr2, _) = set.insert(2);
        set.insert(3);
        let (ptr4, _) = set.insert(4);
        assert_eq!(set.len(), 4);

        // Keep only the even numbers.
        set.retain(|element| **element % 2 == 0);

        assert_eq!(set.len(), 2);
        for odd in [1, 3] {
            assert!(!set.contains(&odd));
        }
        for even in [2, 4] {
            assert!(set.contains(&even));
        }

        // The retained elements are still reachable through their original pointers.
        assert!(set.remove(ptr2));
        assert!(set.remove(ptr4));
    }

    #[test]
    fn test_stable_pointer_set_custom_allocator() {
        let set: StablePointerSet<i32, RandomState, System> = StablePointerSet::new_in(System);

        let (ptr1, inserted) = set.insert(42);
        assert!(inserted);
        let (ptr2, inserted) = set.insert(100);
        assert!(inserted);

        assert_eq!(set.len(), 2);
        assert_eq!((*ptr1, *ptr2), (42, 100));

        assert!(set.contains(&42));
        assert!(set.contains(&100));
        assert!(!set.contains(&200));
    }

    #[test]
    fn test_stable_pointer_set_custom_hasher_and_allocator() {
        let hasher = BuildHasherDefault::<FxHasher>::default();
        let set: StablePointerSet<i32, BuildHasherDefault<FxHasher>, System> =
            StablePointerSet::with_hasher_in(hasher, System);

        let (ptr1, inserted) = set.insert(42);
        assert!(inserted);
        let (ptr2, inserted) = set.insert(100);
        assert!(inserted);

        assert_eq!(set.len(), 2);
        assert_eq!((*ptr1, *ptr2), (42, 100));

        assert!(set.contains(&42));
        assert!(set.contains(&100));
        assert!(!set.contains(&200));
    }
}