1
use std::fmt;
2
use std::marker::PhantomData;
3

            
4
use bitvec::bitvec;
5
use bitvec::order::Lsb0;
6
use delegate::delegate;
7
use log::trace;
8

            
9
use merc_io::BytesFormatter;
10
use merc_utilities::TagIndex;
11
use merc_utilities::debug_trace;
12
use merc_utilities::is_valid_permutation;
13

            
14
/// Builds a [`crate::ByteCompressedVec`] using the familiar `vec![]` syntax.
///
/// - `bytevec![]` expands to an empty vector.
/// - `bytevec![value; count]` expands to a vector holding `count` copies of `value`.
#[macro_export]
macro_rules! bytevec {
    () => {
        $crate::ByteCompressedVec::new()
    };
    ($value:expr; $count:expr) => {
        $crate::ByteCompressedVec::from_elem($value, $count)
    };
}
24

            
25
/// A vector that keeps its elements in a byte-packed, compressed form.
///
/// # Details
///
/// Elements of type `T` implement the `CompressedEntry` trait, which converts
/// them to and from a byte representation. Every entry occupies the same
/// number of bytes, and that width is adjusted dynamically to the largest
/// entry added so far.
///
/// For numbers this means that only as many bytes are stored as the largest
/// number seen so far requires. The width only ever increases over time as
/// larger entries are added.
///
/// Note that the `drop()` function of `T` is never called, but we cannot
/// require that `T: !Drop`.
///
/// The derived equality compares the raw bytes and the entry width.
#[derive(Clone, Default, PartialEq, Eq)]
pub struct ByteCompressedVec<T> {
    /// Packed storage; every entry takes exactly `bytes_per_entry` bytes.
    data: Vec<u8>,
    /// Width in bytes of each stored entry; a width of zero means the vector is empty.
    bytes_per_entry: usize,
    /// Ties the vector to its element type without storing a `T`.
    _marker: PhantomData<T>,
}
46

            
47
impl<T: CompressedEntry> ByteCompressedVec<T> {
48
251874
    pub fn new() -> ByteCompressedVec<T> {
49
251874
        ByteCompressedVec {
50
251874
            data: Vec::new(),
51
251874
            bytes_per_entry: 0,
52
251874
            _marker: PhantomData,
53
251874
        }
54
251874
    }
55

            
56
    /// Initializes a ByteCompressedVec with the given capacity and (minimal) bytes per entry.
57
467901
    pub fn with_capacity(capacity: usize, bytes_per_entry: usize) -> ByteCompressedVec<T> {
58
467901
        ByteCompressedVec {
59
467901
            data: Vec::with_capacity(capacity * bytes_per_entry),
60
467901
            bytes_per_entry,
61
467901
            _marker: PhantomData,
62
467901
        }
63
467901
    }
64

            
65
    /// This is basically the collect() of `Vec`.
66
    ///
67
    /// However, we use it to determine the required bytes per entry in advance.
68
100
    pub fn with_iter<I>(iter: I) -> ByteCompressedVec<T>
69
100
    where
70
100
        I: ExactSizeIterator<Item = T> + Clone,
71
    {
72
100
        let bytes_per_entry = iter
73
100
            .clone()
74
532
            .fold(0, |max_bytes, entry| max_bytes.max(entry.bytes_required()));
75

            
76
100
        let mut vec = ByteCompressedVec::with_capacity(iter.len(), bytes_per_entry);
77
532
        for entry in iter {
78
532
            vec.push(entry);
79
532
        }
80
100
        vec
81
100
    }
82

            
83
    /// Adds a new entry to the vector.
84
349032647
    pub fn push(&mut self, entry: T) {
85
349032647
        self.resize_entries(entry.bytes_required());
86

            
87
        // Add the new entry to the end of the vector.
88
349032647
        let old_len = self.data.len();
89
349032647
        self.data.resize(old_len + self.bytes_per_entry, 0);
90
349032647
        entry.to_bytes(&mut self.data[old_len..]);
91
349032647
    }
92

            
93
    /// Removes the last element from the vector and returns it, or None if it is empty.
94
4409
    pub fn pop(&mut self) -> Option<T> {
95
4409
        if self.is_empty() {
96
            None
97
        } else {
98
4409
            let index = self.len() - 1;
99
4409
            let entry = self.index(index);
100
4409
            self.data.truncate(index * self.bytes_per_entry);
101
4409
            Some(entry)
102
        }
103
4409
    }
104

            
105
    /// Returns the entry at the given index.
106
11869241804
    pub fn index(&self, index: usize) -> T {
107
11869241804
        let start = index * self.bytes_per_entry;
108
11869241804
        let end = start + self.bytes_per_entry;
109
11869241804
        T::from_bytes(&self.data[start..end])
110
11869241804
    }
111

            
112
    /// Sets the entry at the given index.
113
501769450
    pub fn set(&mut self, index: usize, entry: T) {
114
501769450
        self.resize_entries(entry.bytes_required());
115

            
116
501769450
        let start = index * self.bytes_per_entry;
117
501769450
        let end = start + self.bytes_per_entry;
118
501769450
        entry.to_bytes(&mut self.data[start..end]);
119
501769450
    }
120

            
121
    /// Returns the number of elements in the vector.
122
142344916
    pub fn len(&self) -> usize {
123
142344916
        if self.bytes_per_entry == 0 {
124
655297
            0
125
        } else {
126
141689619
            debug_assert!(self.data.len().is_multiple_of(self.bytes_per_entry));
127
141689619
            self.data.len() / self.bytes_per_entry
128
        }
129
142344916
    }
130

            
131
    /// Returns true if the vector is empty.
132
344905
    pub fn is_empty(&self) -> bool {
133
344905
        self.len() == 0
134
344905
    }
135

            
136
    /// Returns metrics about memory usage of this compressed vector
137
    pub fn metrics(&self) -> CompressedVecMetrics {
138
        let element_count = self.len();
139
        let actual_memory =
140
            self.data.len() + std::mem::size_of_val(&self.bytes_per_entry) + std::mem::size_of::<PhantomData<T>>();
141
        let worst_case_memory = element_count * std::mem::size_of::<T>();
142

            
143
        CompressedVecMetrics {
144
            actual_memory,
145
            worst_case_memory,
146
        }
147
    }
148

            
149
    /// Returns an iterator over the elements in the vector.
150
104620
    pub fn iter(&self) -> ByteCompressedVecIterator<'_, T> {
151
104620
        ByteCompressedVecIterator {
152
104620
            vector: self,
153
104620
            current: 0,
154
104620
            end: self.len(),
155
104620
        }
156
104620
    }
157

            
158
    /// Returns an iterator over the elements in the vector for the begin, end range.
159
    pub fn iter_range(&self, begin: usize, end: usize) -> ByteCompressedVecIterator<'_, T> {
160
        ByteCompressedVecIterator {
161
            vector: self,
162
            current: begin,
163
            end,
164
        }
165
    }
166

            
167
    /// Updates the given entry using a closure.
168
39838750
    pub fn update<F>(&mut self, index: usize, mut update: F)
169
39838750
    where
170
39838750
        F: FnMut(&mut T),
171
    {
172
39838750
        let mut entry = self.index(index);
173
39838750
        update(&mut entry);
174
39838750
        self.set(index, entry);
175
39838750
    }
176

            
177
    /// Iterate over all elements and adapt the elements using a closure.
178
    pub fn map<F>(&mut self, mut f: F)
179
    where
180
        F: FnMut(&mut T),
181
    {
182
        for index in 0..self.len() {
183
            let mut entry = self.index(index);
184
            f(&mut entry);
185
            self.set(index, entry);
186
        }
187
    }
188

            
189
    /// Folds over the elements in the vector using the provided closure.
190
54732
    pub fn fold<B, F>(&mut self, init: B, mut f: F) -> B
191
54732
    where
192
54732
        F: FnMut(B, &mut T) -> B,
193
    {
194
54732
        let mut accumulator = init;
195
13136224
        for index in 0..self.len() {
196
13136224
            let mut element = self.index(index);
197
13136224
            accumulator = f(accumulator, &mut element);
198
13136224
            self.set(index, element);
199
13136224
        }
200
54732
        accumulator
201
54732
    }
202

            
203
    /// Permutes a vector in place according to the given permutation function.
204
    ///
205
    /// The resulting vector will be [v_p^-1(0), v_p^-1(1), ..., v_p^-1(n-1)] where p is the permutation function.
206
100
    pub fn permute<P>(&mut self, permutation: P)
207
100
    where
208
100
        P: Fn(usize) -> usize,
209
    {
210
100
        debug_assert!(
211
100
            is_valid_permutation(&permutation, self.len()),
212
            "The given permutation must be a bijective mapping"
213
        );
214

            
215
100
        let mut visited = bitvec![usize, Lsb0; 0; self.len()];
216
532
        for start in 0..self.len() {
217
532
            if visited[start] {
218
310
                continue;
219
222
            }
220

            
221
            // Perform the cycle starting at 'start'
222
222
            let mut current = start;
223

            
224
            // Keeps track of the last displaced element
225
222
            let mut old = self.index(start);
226

            
227
222
            debug_trace!("Starting new cycle at position {}", start);
228
754
            while !visited[current] {
229
532
                visited.set(current, true);
230
532
                let next = permutation(current);
231
532
                if next != current {
232
437
                    debug_trace!("Moving element from position {} to position {}", current, next);
233
437
                    let temp = self.index(next);
234
437
                    self.set(next, old);
235
437
                    old = temp;
236
437
                }
237

            
238
532
                current = next;
239
            }
240
        }
241
100
    }
242

            
243
    /// Applies a permutation to a vector in place using an index function.
244
    ///
245
    /// The resulting vector will be [v_p(0), v_p(1), ..., v_p(n-1)] where p is the index function.
246
100
    pub fn permute_indices<P>(&mut self, indices: P)
247
100
    where
248
100
        P: Fn(usize) -> usize,
249
    {
250
100
        debug_assert!(
251
100
            is_valid_permutation(&indices, self.len()),
252
            "The given permutation must be a bijective mapping"
253
        );
254

            
255
100
        let mut visited = bitvec![usize, Lsb0; 0; self.len()];
256
532
        for start in 0..self.len() {
257
532
            if visited[start] {
258
306
                continue;
259
226
            }
260

            
261
            // Follow the cycle starting at 'start'
262
226
            debug_trace!("Starting new cycle at position {}", start);
263
226
            let mut current = start;
264
226
            let original = self.index(start);
265

            
266
631
            while !visited[current] {
267
532
                visited.set(current, true);
268
532
                let next = indices(current);
269

            
270
532
                if next != current {
271
433
                    if next != start {
272
306
                        debug_trace!("Moving element from position {} to position {}", current, next);
273
306
                        self.set(current, self.index(next));
274
306
                    } else {
275
127
                        break;
276
                    }
277
99
                }
278

            
279
405
                current = next;
280
            }
281

            
282
226
            trace!("Writing original to {}", current);
283
226
            self.set(current, original);
284
        }
285
100
    }
286

            
287
    /// Applies a permutation to a vector in place using an index function.
288
    ///
289
    /// This variant is faster but requires additional memory for the intermediate result vector.
290
    pub fn permute_indices_fast<P>(&mut self, indices: P)
291
    where
292
        P: Fn(usize) -> usize,
293
    {
294
        let mut result = ByteCompressedVec::with_capacity(self.data.capacity(), self.bytes_per_entry);
295
        for index in 0..self.len() {
296
            result.push(self.index(indices(index)));
297
        }
298
        *self = result;
299
    }
300

            
301
    /// Swaps the entries at the given indices.
302
1
    pub fn swap(&mut self, index1: usize, index2: usize) {
303
1
        if index1 != index2 {
304
1
            let start1 = index1 * self.bytes_per_entry;
305
1
            let start2 = index2 * self.bytes_per_entry;
306
1

            
307
1
            // Create a temporary buffer for one entry
308
1
            let temp = T::from_bytes(&self.data[start1..start1 + self.bytes_per_entry]);
309
1

            
310
1
            // Copy entry2 to entry1's position
311
1
            self.data.copy_within(start2..start2 + self.bytes_per_entry, start1);
312
1

            
313
1
            // Copy temp to entry2's position
314
1
            temp.to_bytes(&mut self.data[start2..start2 + self.bytes_per_entry]);
315
1
        }
316
1
    }
317

            
318
    /// Resizes the vector to the given length, filling new entries with the provided value.
319
178699
    pub fn resize_with<F>(&mut self, new_len: usize, mut f: F)
320
178699
    where
321
178699
        F: FnMut() -> T,
322
    {
323
178699
        let current_len = self.len();
324
178699
        if new_len > current_len {
325
            // Preallocate the required space.
326
177619
            self.data.reserve(new_len * self.bytes_per_entry);
327
43597104
            for _ in current_len..new_len {
328
43597104
                self.push(f());
329
43597104
            }
330
1080
        } else if new_len < current_len {
331
            if new_len == 0 {
332
                self.data.clear();
333
                self.bytes_per_entry = 0;
334
            } else {
335
                // It could be that the bytes per entry is now less, but that we never reduce.
336
                self.data.truncate(new_len * self.bytes_per_entry);
337
            }
338
1080
        }
339
178699
    }
340

            
341
    /// Reserves capacity for at least additional more entries to be inserted with the given bytes per entry.
342
13227
    pub fn reserve(&mut self, additional: usize, bytes_per_entry: usize) {
343
13227
        self.resize_entries(bytes_per_entry);
344
13227
        self.data.reserve(additional * self.bytes_per_entry);
345
13227
    }
346

            
347
    /// Resizes all entries in the vector to the given length.
348
850957521
    fn resize_entries(&mut self, new_bytes_required: usize) {
349
850957521
        if new_bytes_required > self.bytes_per_entry {
350
441309
            let mut new_data: Vec<u8> = vec![0; self.len() * new_bytes_required];
351

            
352
441309
            if self.bytes_per_entry > 0 {
353
                // Resize all the existing elements because the new entry requires more bytes.
354
61739232
                for (index, entry) in self.iter().enumerate() {
355
61739232
                    let start = index * new_bytes_required;
356
61739232
                    let end = start + new_bytes_required;
357
61739232
                    entry.to_bytes(&mut new_data[start..end]);
358
61739232
                }
359
368540
            }
360

            
361
441309
            self.bytes_per_entry = new_bytes_required;
362
441309
            self.data = new_data;
363
850516212
        }
364
850957521
    }
365
}
366

            
367
impl<T: CompressedEntry + Clone> ByteCompressedVec<T> {
368
452126
    pub fn from_elem(entry: T, n: usize) -> ByteCompressedVec<T> {
369
452126
        let mut vec = ByteCompressedVec::with_capacity(n, entry.bytes_required());
370
167502404
        for _ in 0..n {
371
167502404
            vec.push(entry.clone());
372
167502404
        }
373
452126
        vec
374
452126
    }
375
}
376

            
377
/// Memory usage figures for a `ByteCompressedVec`.
#[derive(Debug, Clone)]
pub struct CompressedVecMetrics {
    /// Memory actually used by the compressed vector (in bytes).
    pub actual_memory: usize,
    /// Memory an uncompressed vector would use in the worst case (len * sizeof(T)), in bytes.
    pub worst_case_memory: usize,
}

impl CompressedVecMetrics {
    /// Returns how many bytes are saved compared to the worst case.
    ///
    /// The result is clamped to zero when the actual usage exceeds the worst case.
    pub fn memory_savings(&self) -> usize {
        let (worst_case, actual) = (self.worst_case_memory, self.actual_memory);
        worst_case.saturating_sub(actual)
    }

    /// Returns the actual memory usage as a percentage of the worst-case usage.
    ///
    /// Returns `0.0` when the worst-case usage is zero.
    pub fn used_percentage(&self) -> f64 {
        match self.worst_case_memory {
            0 => 0.0,
            worst_case => {
                let ratio = self.actual_memory as f64 / worst_case as f64;
                ratio * 100.0
            }
        }
    }
}
401

            
402
impl fmt::Display for CompressedVecMetrics {
403
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
404
        write!(
405
            f,
406
            "memory: {} ({:.1}%), saving: {} ",
407
            BytesFormatter(self.actual_memory),
408
            self.used_percentage(),
409
            BytesFormatter(self.memory_savings()),
410
        )
411
    }
412
}
413
pub struct ByteCompressedVecIterator<'a, T> {
414
    vector: &'a ByteCompressedVec<T>,
415
    current: usize,
416
    end: usize,
417
}
418

            
419
impl<T: CompressedEntry> Iterator for ByteCompressedVecIterator<'_, T> {
420
    type Item = T;
421

            
422
130097148
    fn next(&mut self) -> Option<Self::Item> {
423
130097148
        if self.current < self.end {
424
130013430
            let result = self.vector.index(self.current);
425
130013430
            self.current += 1;
426
130013430
            Some(result)
427
        } else {
428
83718
            None
429
        }
430
130097148
    }
431
}
432

            
433
/// A value that can be stored in a `ByteCompressedVec` by converting it to and from bytes.
pub trait CompressedEntry {
    /// Writes the byte representation of the entry into `bytes`.
    fn to_bytes(&self, bytes: &mut [u8]);

    /// Creates an entry from its byte representation.
    fn from_bytes(bytes: &[u8]) -> Self;

    /// Returns the number of bytes required to store the current entry.
    fn bytes_required(&self) -> usize;
}

impl CompressedEntry for usize {
    /// Writes the low-order bytes of the value (little endian) into `bytes`.
    ///
    /// At most `size_of::<usize>()` bytes are written; when `bytes` is shorter, only the
    /// low-order bytes that fit are stored.
    fn to_bytes(&self, bytes: &mut [u8]) {
        let array = self.to_le_bytes();
        let count = bytes.len().min(array.len());
        bytes[..count].copy_from_slice(&array[..count]);
    }

    /// Reads a little-endian value from `bytes`; missing high-order bytes are taken as zero.
    fn from_bytes(bytes: &[u8]) -> Self {
        // Sized by the target's `usize` so this also compiles where `usize` is not 8 bytes.
        let mut array = [0u8; std::mem::size_of::<usize>()];
        let count = bytes.len().min(array.len());
        array[..count].copy_from_slice(&bytes[..count]);
        usize::from_le_bytes(array)
    }

    /// Returns the number of bytes used to store the value, computed from `self + 1`.
    ///
    /// As a consequence a value just below a byte boundary (e.g. 255) is assigned one byte
    /// more than strictly needed. The addition saturates so that `usize::MAX` does not
    /// overflow and is assigned the full `size_of::<usize>()` bytes.
    fn bytes_required(&self) -> usize {
        (self.saturating_add(1).ilog2() / u8::BITS) as usize + 1
    }
}
464

            
465
/// Formats the decoded entries as a list, rather than showing the raw bytes.
impl<T: CompressedEntry + fmt::Debug> fmt::Debug for ByteCompressedVec<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut list = f.debug_list();
        for entry in self.iter() {
            list.entry(&entry);
        }
        list.finish()
    }
}
470

            
471
/// Implement it for the TagIndex for convenience.
472
impl<T: CompressedEntry + Copy, Tag> CompressedEntry for TagIndex<T, Tag> {
473
    delegate! {
474
        to self.value() {
475
498541810
            fn to_bytes(&self, bytes: &mut [u8]);
476
479027413
            fn bytes_required(&self) -> usize;
477
        }
478
    }
479

            
480
5939469254
    fn from_bytes(bytes: &[u8]) -> Self {
481
5939469254
        TagIndex::new(T::from_bytes(bytes))
482
5939469254
    }
483
}
484

            
485
#[cfg(test)]
486
mod tests {
487
    use super::*;
488

            
489
    use rand::RngExt;
490
    use rand::distr::Uniform;
491
    use rand::seq::SliceRandom;
492

            
493
    use merc_utilities::random_test;
494

            
495
    /// Pushing values of different widths keeps both the length and earlier entries intact.
    #[test]
    fn test_index_bytevector() {
        let mut vec = ByteCompressedVec::new();
        for (count, value) in [1usize, 1024].into_iter().enumerate() {
            vec.push(value);
            assert_eq!(vec.len(), count + 1);
        }

        assert_eq!(vec.index(0), 1);
        assert_eq!(vec.index(1), 1024);
    }
507

            
508
    /// After every push of a random value, all entries stored so far still decode correctly.
    #[test]
    fn test_random_bytevector() {
        let range = Uniform::new(0, usize::MAX).unwrap();
        let expected_vector: Vec<usize> = rand::rng().sample_iter(range).take(100).collect();

        let mut vector = ByteCompressedVec::new();
        for &value in &expected_vector {
            vector.push(value);

            // `zip` stops at the shorter side, so exactly the stored prefix is compared.
            for (expected, actual) in expected_vector.iter().zip(vector.iter()) {
                assert_eq!(*expected, actual);
            }
        }
    }
524

            
525
    /// Overwriting every entry of a zero-filled vector with random values stores them correctly.
    #[test]
    fn test_random_setting_bytevector() {
        let range = Uniform::new(0, usize::MAX).unwrap();
        let expected_vector: Vec<usize> = rand::rng().sample_iter(range).take(100).collect();

        let mut vector = bytevec![0; 100];
        for (position, &value) in expected_vector.iter().enumerate() {
            vector.set(position, value);
        }

        for (expected, actual) in expected_vector.iter().zip(vector.iter()) {
            assert_eq!(*expected, actual);
        }
    }
541

            
542
    /// Values below 1024 need at most two bytes and survive a two-byte round trip.
    #[test]
    fn test_random_usize_entry() {
        random_test(100, |rng| {
            let value: usize = rng.random_range(0..1024);
            assert!(value.bytes_required() <= 2);

            let mut encoded = [0u8; 2];
            value.to_bytes(&mut encoded);
            let decoded = usize::from_bytes(&encoded);
            assert_eq!(decoded, value);
        });
    }
553

            
554
    /// Swapping the outer entries of a three-entry vector leaves the middle entry untouched.
    #[test]
    fn test_swap() {
        let mut vec = ByteCompressedVec::new();
        for value in [1usize, 256, 65536] {
            vec.push(value);
        }

        vec.swap(0, 2);

        for (position, expected) in [65536usize, 256, 1].into_iter().enumerate() {
            assert_eq!(vec.index(position), expected);
        }
    }
567

            
568
    /// Checks `permute` and `permute_indices` against random permutations of random vectors.
    #[test]
    fn test_random_bytevector_permute() {
        random_test(100, |rng| {
            // Generate random vector to permute
            let elements = (0..rng.random_range(1..10))
                .map(|_| rng.random_range(0..100))
                .collect::<Vec<_>>();

            let vec = ByteCompressedVec::with_iter(elements.iter().cloned());

            for is_inverse in [false, true] {
                println!("Inverse: {is_inverse}, Input: {:?}", vec);

                let permutation = {
                    let mut order: Vec<usize> = (0..elements.len()).collect();
                    order.shuffle(rng);
                    order
                };

                let mut permutated = vec.clone();
                if is_inverse {
                    permutated.permute_indices(|i| permutation[i]);
                } else {
                    permutated.permute(|i| permutation[i]);
                }

                println!("Permutation: {:?}", permutation);
                println!("After permutation: {:?}", permutated);

                // Check that the permutation was applied correctly
                for i in 0..elements.len() {
                    // `permute_indices` reads position i from permutation[i]; `permute` moves
                    // entries forward, so position i holds the entry whose image is i.
                    let pos = if is_inverse {
                        permutation[i]
                    } else {
                        permutation
                            .iter()
                            .position(|&j| i == j)
                            .expect("Should find inverse mapping")
                    };

                    // A plain `assert_eq!` so the check also runs when tests are built in release mode.
                    assert_eq!(
                        permutated.index(i),
                        elements[pos],
                        "Element at index {} should be {}",
                        i,
                        elements[pos]
                    );
                }
            }
        });
    }
619
}