1
use std::fmt;
2
use std::marker::PhantomData;
3

            
4
use bitvec::bitvec;
5
use bitvec::order::Lsb0;
6
use delegate::delegate;
7
use log::trace;
8

            
9
use merc_io::BytesFormatter;
10
use merc_utilities::TagIndex;
11
use merc_utilities::debug_trace;
12
use merc_utilities::is_valid_permutation;
13

            
14
/// A counterpart of `vec![]` that creates a [`crate::ByteCompressedVec`].
///
/// It supports the same three forms as `vec![]`:
/// - `bytevec![]` creates an empty vector,
/// - `bytevec![elem; n]` creates a vector holding `n` clones of `elem`,
/// - `bytevec![a, b, c]` creates a vector holding the given elements in order.
#[macro_export]
macro_rules! bytevec {
    () => {
        $crate::ByteCompressedVec::new()
    };
    ($elem:expr; $n:expr) => {
        $crate::ByteCompressedVec::from_elem($elem, $n)
    };
    // Must come after the `elem; n` arm, exactly like `vec![]`, so `x; n` is not parsed as a list.
    ($($x:expr),+ $(,)?) => {{
        let mut vec = $crate::ByteCompressedVec::new();
        $(vec.push($x);)+
        vec
    }};
}
24

            
25
/// A vector data structure that stores objects in a byte compressed format.
///
/// # Details
///
/// The basic idea is that elements of type `T` implement the `CompressedEntry`
/// trait, which allows them to be converted to and from a byte representation.
/// The vector dynamically adjusts the number of bytes used per entry based on
/// the maximum size of the entries added so far.
///
/// For numbers this means that we only store the number of bytes required to
/// represent the largest number added so far. Note that the number of bytes
/// used per entry is only increased over time as larger entries are added.
///
/// Note that the `drop()` function of `T` is never called, but we cannot
/// require that `T: !Drop`.
pub struct ByteCompressedVec<T> {
    /// The encoded entries, stored back to back with `bytes_per_entry` bytes each.
    data: Vec<u8>,
    /// The number of bytes every entry in `data` occupies.
    bytes_per_entry: usize,
    _marker: PhantomData<T>,
}

// The trait implementations below are written by hand because `#[derive]` would add
// `T: Default`, `T: Clone` and `T: PartialEq`/`T: Eq` bounds, even though `T` only
// appears inside `PhantomData` and is never stored directly.

impl<T> Default for ByteCompressedVec<T> {
    fn default() -> Self {
        ByteCompressedVec {
            data: Vec::new(),
            bytes_per_entry: 0,
            _marker: PhantomData,
        }
    }
}

impl<T> Clone for ByteCompressedVec<T> {
    fn clone(&self) -> Self {
        ByteCompressedVec {
            data: self.data.clone(),
            bytes_per_entry: self.bytes_per_entry,
            _marker: PhantomData,
        }
    }

    fn clone_from(&mut self, source: &Self) {
        // Reuse the existing allocation of `data` where possible.
        self.data.clone_from(&source.data);
        self.bytes_per_entry = source.bytes_per_entry;
    }
}

/// Two vectors are equal when both their encoded bytes and their entry widths are equal.
///
/// Vectors that hold the same entries encoded with different widths compare unequal.
impl<T> PartialEq for ByteCompressedVec<T> {
    fn eq(&self, other: &Self) -> bool {
        self.bytes_per_entry == other.bytes_per_entry && self.data == other.data
    }
}

impl<T> Eq for ByteCompressedVec<T> {}
46

            
47
impl<T: CompressedEntry> ByteCompressedVec<T> {
48
79896
    pub fn new() -> ByteCompressedVec<T> {
49
79896
        ByteCompressedVec {
50
79896
            data: Vec::new(),
51
79896
            bytes_per_entry: 0,
52
79896
            _marker: PhantomData,
53
79896
        }
54
79896
    }
55

            
56
    /// Initializes a ByteCompressedVec with the given capacity and (minimal) bytes per entry.
57
150950
    pub fn with_capacity(capacity: usize, bytes_per_entry: usize) -> ByteCompressedVec<T> {
58
150950
        ByteCompressedVec {
59
150950
            data: Vec::with_capacity(capacity * bytes_per_entry),
60
150950
            bytes_per_entry,
61
150950
            _marker: PhantomData,
62
150950
        }
63
150950
    }
64

            
65
    /// This is basically the collect() of `Vec`.
66
    ///
67
    /// However, we use it to determine the required bytes per entry in advance.
68
100
    pub fn with_iter<I>(iter: I) -> ByteCompressedVec<T>
69
100
    where
70
100
        I: ExactSizeIterator<Item = T> + Clone,
71
    {
72
100
        let bytes_per_entry = iter
73
100
            .clone()
74
493
            .fold(0, |max_bytes, entry| max_bytes.max(entry.bytes_required()));
75

            
76
100
        let mut vec = ByteCompressedVec::with_capacity(iter.len(), bytes_per_entry);
77
493
        for entry in iter {
78
493
            vec.push(entry);
79
493
        }
80
100
        vec
81
100
    }
82

            
83
    /// Adds a new entry to the vector.
84
55264624
    pub fn push(&mut self, entry: T) {
85
55264624
        self.resize_entries(entry.bytes_required());
86

            
87
        // Add the new entry to the end of the vector.
88
55264624
        let old_len = self.data.len();
89
55264624
        self.data.resize(old_len + self.bytes_per_entry, 0);
90
55264624
        entry.to_bytes(&mut self.data[old_len..]);
91
55264624
    }
92

            
93
    /// Removes the last element from the vector and returns it, or None if it is empty.
94
2003
    pub fn pop(&mut self) -> Option<T> {
95
2003
        if self.is_empty() {
96
            None
97
        } else {
98
2003
            let index = self.len() - 1;
99
2003
            let entry = self.index(index);
100
2003
            self.data.truncate(index * self.bytes_per_entry);
101
2003
            Some(entry)
102
        }
103
2003
    }
104

            
105
    /// Returns the entry at the given index.
106
683377726
    pub fn index(&self, index: usize) -> T {
107
683377726
        let start = index * self.bytes_per_entry;
108
683377726
        let end = start + self.bytes_per_entry;
109
683377726
        T::from_bytes(&self.data[start..end])
110
683377726
    }
111

            
112
    /// Sets the entry at the given index.
113
84391817
    pub fn set(&mut self, index: usize, entry: T) {
114
84391817
        self.resize_entries(entry.bytes_required());
115

            
116
84391817
        let start = index * self.bytes_per_entry;
117
84391817
        let end = start + self.bytes_per_entry;
118
84391817
        entry.to_bytes(&mut self.data[start..end]);
119
84391817
    }
120

            
121
    /// Returns the number of elements in the vector.
122
19878258
    pub fn len(&self) -> usize {
123
19878258
        if self.bytes_per_entry == 0 {
124
177145
            0
125
        } else {
126
19701113
            debug_assert!(self.data.len().is_multiple_of(self.bytes_per_entry));
127
19701113
            self.data.len() / self.bytes_per_entry
128
        }
129
19878258
    }
130

            
131
    /// Returns true if the vector is empty.
132
89925
    pub fn is_empty(&self) -> bool {
133
89925
        self.len() == 0
134
89925
    }
135

            
136
    /// Returns metrics about memory usage of this compressed vector
137
    pub fn metrics(&self) -> CompressedVecMetrics {
138
        let element_count = self.len();
139
        let actual_memory =
140
            self.data.len() + std::mem::size_of_val(&self.bytes_per_entry) + std::mem::size_of::<PhantomData<T>>();
141
        let worst_case_memory = element_count * std::mem::size_of::<T>();
142

            
143
        CompressedVecMetrics {
144
            actual_memory,
145
            worst_case_memory,
146
        }
147
    }
148

            
149
    /// Returns an iterator over the elements in the vector.
150
15407
    pub fn iter(&self) -> ByteCompressedVecIterator<'_, T> {
151
15407
        ByteCompressedVecIterator {
152
15407
            vector: self,
153
15407
            current: 0,
154
15407
            end: self.len(),
155
15407
        }
156
15407
    }
157

            
158
    /// Returns an iterator over the elements in the vector for the begin, end range.
159
    pub fn iter_range(&self, begin: usize, end: usize) -> ByteCompressedVecIterator<'_, T> {
160
        ByteCompressedVecIterator {
161
            vector: self,
162
            current: begin,
163
            end,
164
        }
165
    }
166

            
167
    /// Updates the given entry using a closure.
168
7749126
    pub fn update<F>(&mut self, index: usize, mut update: F)
169
7749126
    where
170
7749126
        F: FnMut(&mut T),
171
    {
172
7749126
        let mut entry = self.index(index);
173
7749126
        update(&mut entry);
174
7749126
        self.set(index, entry);
175
7749126
    }
176

            
177
    /// Iterate over all elements and adapt the elements using a closure.
178
    pub fn map<F>(&mut self, mut f: F)
179
    where
180
        F: FnMut(&mut T),
181
    {
182
        for index in 0..self.len() {
183
            let mut entry = self.index(index);
184
            f(&mut entry);
185
            self.set(index, entry);
186
        }
187
    }
188

            
189
    /// Folds over the elements in the vector using the provided closure.
190
45760
    pub fn fold<B, F>(&mut self, init: B, mut f: F) -> B
191
45760
    where
192
45760
        F: FnMut(B, &mut T) -> B,
193
    {
194
45760
        let mut accumulator = init;
195
6474460
        for index in 0..self.len() {
196
6474460
            let mut element = self.index(index);
197
6474460
            accumulator = f(accumulator, &mut element);
198
6474460
            self.set(index, element);
199
6474460
        }
200
45760
        accumulator
201
45760
    }
202

            
203
    /// Permutes a vector in place according to the given permutation function.
204
    ///
205
    /// The resulting vector will be [v_p^-1(0), v_p^-1(1), ..., v_p^-1(n-1)] where p is the permutation function.
206
100
    pub fn permute<P>(&mut self, permutation: P)
207
100
    where
208
100
        P: Fn(usize) -> usize,
209
    {
210
100
        debug_assert!(
211
100
            is_valid_permutation(&permutation, self.len()),
212
            "The given permutation must be a bijective mapping"
213
        );
214

            
215
100
        let mut visited = bitvec![usize, Lsb0; 0; self.len()];
216
493
        for start in 0..self.len() {
217
493
            if visited[start] {
218
291
                continue;
219
202
            }
220

            
221
            // Perform the cycle starting at 'start'
222
202
            let mut current = start;
223

            
224
            // Keeps track of the last displaced element
225
202
            let mut old = self.index(start);
226

            
227
202
            debug_trace!("Starting new cycle at position {}", start);
228
695
            while !visited[current] {
229
493
                visited.set(current, true);
230
493
                let next = permutation(current);
231
493
                if next != current {
232
402
                    debug_trace!("Moving element from position {} to position {}", current, next);
233
402
                    let temp = self.index(next);
234
402
                    self.set(next, old);
235
402
                    old = temp;
236
402
                }
237

            
238
493
                current = next;
239
            }
240
        }
241
100
    }
242

            
243
    /// Applies a permutation to a vector in place using an index function.
244
    ///
245
    /// The resulting vector will be [v_p(0), v_p(1), ..., v_p(n-1)] where p is the index function.
246
100
    pub fn permute_indices<P>(&mut self, indices: P)
247
100
    where
248
100
        P: Fn(usize) -> usize,
249
    {
250
100
        debug_assert!(
251
100
            is_valid_permutation(&indices, self.len()),
252
            "The given permutation must be a bijective mapping"
253
        );
254

            
255
100
        let mut visited = bitvec![usize, Lsb0; 0; self.len()];
256
493
        for start in 0..self.len() {
257
493
            if visited[start] {
258
287
                continue;
259
206
            }
260

            
261
            // Follow the cycle starting at 'start'
262
206
            debug_trace!("Starting new cycle at position {}", start);
263
206
            let mut current = start;
264
206
            let original = self.index(start);
265

            
266
588
            while !visited[current] {
267
493
                visited.set(current, true);
268
493
                let next = indices(current);
269

            
270
493
                if next != current {
271
398
                    if next != start {
272
287
                        debug_trace!("Moving element from position {} to position {}", current, next);
273
287
                        self.set(current, self.index(next));
274
287
                    } else {
275
111
                        break;
276
                    }
277
95
                }
278

            
279
382
                current = next;
280
            }
281

            
282
206
            trace!("Writing original to {}", current);
283
206
            self.set(current, original);
284
        }
285
100
    }
286

            
287
    /// Applies a permutation to a vector in place using an index function.
288
    ///
289
    /// This variant is faster but requires additional memory for the intermediate result vector.
290
    pub fn permute_indices_fast<P>(&mut self, indices: P)
291
    where
292
        P: Fn(usize) -> usize,
293
    {
294
        let mut result = ByteCompressedVec::with_capacity(self.data.capacity(), self.bytes_per_entry);
295
        for index in 0..self.len() {
296
            result.push(self.index(indices(index)));
297
        }
298
        *self = result;
299
    }
300

            
301
    /// Swaps the entries at the given indices.
302
1
    pub fn swap(&mut self, index1: usize, index2: usize) {
303
1
        if index1 != index2 {
304
1
            let start1 = index1 * self.bytes_per_entry;
305
1
            let start2 = index2 * self.bytes_per_entry;
306
1

            
307
1
            // Create a temporary buffer for one entry
308
1
            let temp = T::from_bytes(&self.data[start1..start1 + self.bytes_per_entry]);
309
1

            
310
1
            // Copy entry2 to entry1's position
311
1
            self.data.copy_within(start2..start2 + self.bytes_per_entry, start1);
312
1

            
313
1
            // Copy temp to entry2's position
314
1
            temp.to_bytes(&mut self.data[start2..start2 + self.bytes_per_entry]);
315
1
        }
316
1
    }
317

            
318
    /// Resizes the vector to the given length, filling new entries with the provided value.
319
63691
    pub fn resize_with<F>(&mut self, new_len: usize, mut f: F)
320
63691
    where
321
63691
        F: FnMut() -> T,
322
    {
323
63691
        let current_len = self.len();
324
63691
        if new_len > current_len {
325
            // Preallocate the required space.
326
63691
            self.data.reserve(new_len * self.bytes_per_entry);
327
9561514
            for _ in current_len..new_len {
328
9561514
                self.push(f());
329
9561514
            }
330
        } else if new_len < current_len {
331
            if new_len == 0 {
332
                self.data.clear();
333
                self.bytes_per_entry = 0;
334
            } else {
335
                // It could be that the bytes per entry is now less, but that we never reduce.
336
                self.data.truncate(new_len * self.bytes_per_entry);
337
            }
338
        }
339
63691
    }
340

            
341
    /// Reserves capacity for at least additional more entries to be inserted with the given bytes per entry.
342
6009
    pub fn reserve(&mut self, additional: usize, bytes_per_entry: usize) {
343
6009
        self.resize_entries(bytes_per_entry);
344
6009
        self.data.reserve(additional * self.bytes_per_entry);
345
6009
    }
346

            
347
    /// Resizes all entries in the vector to the given length.
348
139679577
    fn resize_entries(&mut self, new_bytes_required: usize) {
349
139679577
        if new_bytes_required > self.bytes_per_entry {
350
105876
            let mut new_data: Vec<u8> = vec![0; self.len() * new_bytes_required];
351

            
352
105876
            if self.bytes_per_entry > 0 {
353
                // Resize all the existing elements because the new entry requires more bytes.
354
14063063
                for (index, entry) in self.iter().enumerate() {
355
14063063
                    let start = index * new_bytes_required;
356
14063063
                    let end = start + new_bytes_required;
357
14063063
                    entry.to_bytes(&mut new_data[start..end]);
358
14063063
                }
359
92968
            }
360

            
361
105876
            self.bytes_per_entry = new_bytes_required;
362
105876
            self.data = new_data;
363
139573701
        }
364
139679577
    }
365
}
366

            
367
impl<T: CompressedEntry + Clone> ByteCompressedVec<T> {
368
149851
    pub fn from_elem(entry: T, n: usize) -> ByteCompressedVec<T> {
369
149851
        let mut vec = ByteCompressedVec::with_capacity(n, entry.bytes_required());
370
27233950
        for _ in 0..n {
371
27233950
            vec.push(entry.clone());
372
27233950
        }
373
149851
        vec
374
149851
    }
375
}
376

            
377
/// Metrics for tracking the memory usage of a ByteCompressedVec.
#[derive(Debug, Clone)]
pub struct CompressedVecMetrics {
    /// Actual memory used by the compressed vector (in bytes).
    pub actual_memory: usize,
    /// Worst-case memory that would be used by an uncompressed vector (len * sizeof(T)).
    pub worst_case_memory: usize,
}

impl CompressedVecMetrics {
    /// Returns how many bytes are saved compared to the uncompressed vector.
    ///
    /// Returns zero (instead of underflowing) when the compressed vector uses more memory.
    pub fn memory_savings(&self) -> usize {
        self.worst_case_memory.saturating_sub(self.actual_memory)
    }

    /// Returns the memory actually used as a percentage of the worst-case (uncompressed) memory.
    ///
    /// For example, `25.0` means the compressed vector needs a quarter of the uncompressed
    /// size. Returns `0.0` when the worst-case memory is zero, to avoid dividing by zero.
    pub fn used_percentage(&self) -> f64 {
        if self.worst_case_memory == 0 {
            0.0
        } else {
            (self.actual_memory as f64 / self.worst_case_memory as f64) * 100.0
        }
    }
}
401

            
402
impl fmt::Display for CompressedVecMetrics {
403
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
404
        write!(
405
            f,
406
            "memory: {} ({:.1}%), saving: {} ",
407
            BytesFormatter(self.actual_memory),
408
            self.used_percentage(),
409
            BytesFormatter(self.memory_savings()),
410
        )
411
    }
412
}
413
pub struct ByteCompressedVecIterator<'a, T> {
414
    vector: &'a ByteCompressedVec<T>,
415
    current: usize,
416
    end: usize,
417
}
418

            
419
impl<T: CompressedEntry> Iterator for ByteCompressedVecIterator<'_, T> {
420
    type Item = T;
421

            
422
15794156
    fn next(&mut self) -> Option<Self::Item> {
423
15794156
        if self.current < self.end {
424
15780083
            let result = self.vector.index(self.current);
425
15780083
            self.current += 1;
426
15780083
            Some(result)
427
        } else {
428
14073
            None
429
        }
430
15794156
    }
431
}
432

            
433
/// A type that can be stored in a `ByteCompressedVec` using a fixed number of bytes per entry.
///
/// Every slot of the vector is at least as wide as the largest `bytes_required()` of the entries
/// stored so far. Therefore `to_bytes` and `from_bytes` may receive slices that are longer than
/// `bytes_required()` of the particular entry.
pub trait CompressedEntry {
    /// Writes the entry into `bytes`.
    fn to_bytes(&self, bytes: &mut [u8]);

    /// Reconstructs an entry from bytes written by [`CompressedEntry::to_bytes`].
    fn from_bytes(bytes: &[u8]) -> Self;

    /// Returns the number of bytes required to store the current entry.
    fn bytes_required(&self) -> usize;
}

/// A `usize` is stored in little-endian order, keeping only its least significant bytes.
impl CompressedEntry for usize {
    /// Writes the first `min(bytes.len(), size_of::<usize>())` little-endian bytes of the value.
    ///
    /// Any remaining bytes of a wider slot are left untouched. For the value to round-trip,
    /// `bytes` must be at least `self.bytes_required()` long.
    fn to_bytes(&self, bytes: &mut [u8]) {
        let le = self.to_le_bytes();
        let count = bytes.len().min(le.len());
        bytes[..count].copy_from_slice(&le[..count]);
    }

    /// Reads up to `size_of::<usize>()` little-endian bytes; missing high bytes are zero.
    fn from_bytes(bytes: &[u8]) -> Self {
        // Sized from the target's `usize` so this also compiles on 32-bit platforms.
        let mut le = [0u8; std::mem::size_of::<usize>()];
        let count = bytes.len().min(le.len());
        le[..count].copy_from_slice(&bytes[..count]);
        usize::from_le_bytes(le)
    }

    /// Returns the minimal number of bytes needed to represent the value; zero still takes one byte.
    fn bytes_required(&self) -> usize {
        // `checked_ilog2` is `None` only for zero. Unlike `(self + 1).ilog2()`, it cannot
        // overflow for `usize::MAX`, and it does not over-count values such as 255.
        match self.checked_ilog2() {
            Some(log) => (log / u8::BITS) as usize + 1,
            None => 1,
        }
    }
}
464

            
465
/// Formats the vector as a list of its decoded entries, e.g. `[1, 1024]`.
impl<T: CompressedEntry + fmt::Debug> fmt::Debug for ByteCompressedVec<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut list = f.debug_list();
        list.entries(self.iter());
        list.finish()
    }
}
470

            
471
/// Implement it for the TagIndex for convenience.
472
impl<T: CompressedEntry + Copy, Tag> CompressedEntry for TagIndex<T, Tag> {
473
    delegate! {
474
        to self.value() {
475
71654982
            fn to_bytes(&self, bytes: &mut [u8]);
476
69316674
            fn bytes_required(&self) -> usize;
477
        }
478
    }
479

            
480
404266082
    fn from_bytes(bytes: &[u8]) -> Self {
481
404266082
        TagIndex::new(T::from_bytes(bytes))
482
404266082
    }
483
}
484

            
485
#[cfg(test)]
mod tests {
    use super::*;

    use rand::RngExt;
    use rand::distr::Uniform;
    use rand::seq::SliceRandom;

    use merc_utilities::random_test;

    #[test]
    fn test_index_bytevector() {
        let mut vec = ByteCompressedVec::new();
        vec.push(1);
        assert_eq!(vec.len(), 1);

        vec.push(1024);
        assert_eq!(vec.len(), 2);

        assert_eq!(vec.index(0), 1);
        assert_eq!(vec.index(1), 1024);
    }

    #[test]
    fn test_random_bytevector() {
        let rng = rand::rng();

        let range = Uniform::new(0, usize::MAX).unwrap();
        let expected_vector: Vec<usize> = rng.sample_iter(range).take(100).collect();
        let mut vector = ByteCompressedVec::new();

        for element in &expected_vector {
            vector.push(*element);

            for (expected, element) in expected_vector.iter().zip(vector.iter()) {
                assert_eq!(*expected, element);
            }
        }
    }

    #[test]
    fn test_random_setting_bytevector() {
        let rng = rand::rng();

        let range = Uniform::new(0, usize::MAX).unwrap();
        let expected_vector: Vec<usize> = rng.sample_iter(range).take(100).collect();
        let mut vector = bytevec![0; 100];

        for (index, element) in expected_vector.iter().enumerate() {
            vector.set(index, *element);
        }

        for (expected, element) in expected_vector.iter().zip(vector.iter()) {
            assert_eq!(*expected, element);
        }
    }

    #[test]
    fn test_random_usize_entry() {
        random_test(100, |rng| {
            let value = rng.random_range(0..1024);
            assert!(value.bytes_required() <= 2);

            let mut bytes = [0; 2];
            value.to_bytes(&mut bytes);
            assert_eq!(usize::from_bytes(&bytes), value);
        });
    }

    #[test]
    fn test_swap() {
        let mut vec = ByteCompressedVec::new();
        vec.push(1);
        vec.push(256);
        vec.push(65536);

        vec.swap(0, 2);

        assert_eq!(vec.index(0), 65536);
        assert_eq!(vec.index(1), 256);
        assert_eq!(vec.index(2), 1);
    }

    #[test]
    fn test_pop() {
        let mut vec = ByteCompressedVec::new();
        vec.push(7usize);
        vec.push(300);

        assert_eq!(vec.pop(), Some(300));
        assert_eq!(vec.pop(), Some(7));
        assert_eq!(vec.pop(), None);
        assert!(vec.is_empty());
    }

    #[test]
    fn test_resize_with() {
        let mut vec: ByteCompressedVec<usize> = ByteCompressedVec::new();

        // Growing from an empty vector.
        vec.resize_with(5, || 3);
        assert_eq!(vec.len(), 5);
        assert!(vec.iter().all(|x| x == 3));

        // Widen the entries, then shrink; the remaining entries must be preserved.
        vec.set(4, 70000);
        vec.resize_with(2, || 0);
        assert_eq!(vec.iter().collect::<Vec<_>>(), vec![3, 3]);

        // Shrinking to zero empties the vector.
        vec.resize_with(0, || 0);
        assert!(vec.is_empty());
    }

    #[test]
    fn test_random_bytevector_permute() {
        random_test(100, |rng| {
            // Generate random vector to permute
            let elements = (0..rng.random_range(1..10))
                .map(|_| rng.random_range(0..100))
                .collect::<Vec<_>>();

            let vec = ByteCompressedVec::with_iter(elements.iter().cloned());

            for is_inverse in [false, true] {
                println!("Inverse: {is_inverse}, Input: {:?}", vec);

                let permutation = {
                    let mut order: Vec<usize> = (0..elements.len()).collect();
                    order.shuffle(rng);
                    order
                };

                let mut permutated = vec.clone();
                if is_inverse {
                    permutated.permute_indices(|i| permutation[i]);
                } else {
                    permutated.permute(|i| permutation[i]);
                }

                println!("Permutation: {:?}", permutation);
                println!("After permutation: {:?}", permutated);

                // Check that the permutation was applied correctly
                for i in 0..elements.len() {
                    let pos = if is_inverse {
                        permutation[i]
                    } else {
                        permutation
                            .iter()
                            .position(|&j| i == j)
                            .expect("Should find inverse mapping")
                    };

                    debug_assert_eq!(
                        permutated.index(i),
                        elements[pos],
                        "Element at index {} should be {}",
                        i,
                        elements[pos]
                    );
                }
            }
        });
    }

    #[test]
    fn test_random_permute_indices_fast() {
        random_test(100, |rng| {
            let elements = (0..rng.random_range(1..10))
                .map(|_| rng.random_range(0..100_000))
                .collect::<Vec<usize>>();

            let mut permutation: Vec<usize> = (0..elements.len()).collect();
            permutation.shuffle(rng);

            let mut fast = ByteCompressedVec::with_iter(elements.iter().cloned());
            fast.permute_indices_fast(|i| permutation[i]);

            // Position `i` must hold the element that was at `permutation[i]`.
            assert_eq!(fast.len(), elements.len());
            for (i, &source) in permutation.iter().enumerate() {
                assert_eq!(fast.index(i), elements[source]);
            }
        });
    }
}