1
use std::fmt;
2
use std::marker::PhantomData;
3

            
4
use bitvec::bitvec;
5
use bitvec::order::Lsb0;
6
use delegate::delegate;
7
use log::trace;
8

            
9
use merc_io::BytesFormatter;
10
use merc_utilities::TagIndex;
11
use merc_utilities::debug_trace;
12
use merc_utilities::is_valid_permutation;
13

            
14
/// A copy of `vec![]` that can be used for the [`crate::ByteCompressedVec`].
///
/// Supports the same three forms as `vec![]`:
/// - `bytevec![]` creates an empty vector,
/// - `bytevec![elem; n]` creates a vector holding `n` clones of `elem`,
/// - `bytevec![a, b, c]` creates a vector holding the listed elements (the element type must be
///   `Clone`, because `ByteCompressedVec::with_iter` requires a cloneable iterator).
#[macro_export]
macro_rules! bytevec {
    () => {
        $crate::ByteCompressedVec::new()
    };
    ($elem:expr; $n:expr) => {
        $crate::ByteCompressedVec::from_elem($elem, $n)
    };
    ($($x:expr),+ $(,)?) => {
        // `IntoIterator::into_iter` on an array yields the elements by value in every edition,
        // and the resulting iterator is `ExactSizeIterator + Clone` as `with_iter` requires.
        $crate::ByteCompressedVec::with_iter(::core::iter::IntoIterator::into_iter([$($x),+]))
    };
}
24

            
25
/// A vector that stores its elements in a byte-compressed format.
///
/// Elements of type `T` implement the [`CompressedEntry`] trait, which converts them to and from
/// a byte representation. All entries share a single width (`bytes_per_entry`), which is the
/// largest `bytes_required()` of the entries added so far. For numbers this means only as many
/// bytes as the largest number added so far needs are stored per entry. The width grows as larger
/// entries are added and is not shrunk when entries are removed.
///
/// NOTE(review): the derived `PartialEq`/`Eq` compare the raw encoding including
/// `bytes_per_entry`, so two vectors with the same elements but different entry widths (e.g. a
/// vector emptied by `pop` versus a freshly created one) compare as unequal.
///
/// TODO: The `drop()` function of `T` is never called.
#[derive(Clone, Default, Eq, PartialEq)]
pub struct ByteCompressedVec<T> {
    /// The encoded entries, stored back to back with `bytes_per_entry` bytes each.
    data: Vec<u8>,
    /// The number of bytes used to encode every entry in `data`.
    bytes_per_entry: usize,
    /// Ties the vector to its element type, which is never stored directly.
    _marker: PhantomData<T>,
}
36

            
37
impl<T: CompressedEntry> ByteCompressedVec<T> {
38
20017
    pub fn new() -> ByteCompressedVec<T> {
39
20017
        ByteCompressedVec {
40
20017
            data: Vec::new(),
41
20017
            bytes_per_entry: 0,
42
20017
            _marker: PhantomData,
43
20017
        }
44
20017
    }
45

            
46
    /// Initializes a ByteCompressedVec with the given capacity and (minimal) bytes per entry.
47
47568
    pub fn with_capacity(capacity: usize, bytes_per_entry: usize) -> ByteCompressedVec<T> {
48
47568
        ByteCompressedVec {
49
47568
            data: Vec::with_capacity(capacity * bytes_per_entry),
50
47568
            bytes_per_entry,
51
47568
            _marker: PhantomData,
52
47568
        }
53
47568
    }
54

            
55
    /// This is basically the collect() of `Vec`.
56
    ///
57
    /// However, we use it to determine the required bytes per entry in advance.
58
100
    pub fn with_iter<I>(iter: I) -> ByteCompressedVec<T>
59
100
    where
60
100
        I: ExactSizeIterator<Item = T> + Clone,
61
    {
62
100
        let bytes_per_entry = iter
63
100
            .clone()
64
529
            .fold(0, |max_bytes, entry| max_bytes.max(entry.bytes_required()));
65

            
66
100
        let mut vec = ByteCompressedVec::with_capacity(iter.len(), bytes_per_entry);
67
529
        for entry in iter {
68
529
            vec.push(entry);
69
529
        }
70
100
        vec
71
100
    }
72

            
73
    /// Adds a new entry to the vector.
74
4954179
    pub fn push(&mut self, entry: T) {
75
4954179
        self.resize_entries(entry.bytes_required());
76

            
77
        // Add the new entry to the end of the vector.
78
4954179
        let old_len = self.data.len();
79
4954179
        self.data.resize(old_len + self.bytes_per_entry, 0);
80
4954179
        entry.to_bytes(&mut self.data[old_len..]);
81
4954179
    }
82

            
83
    /// Removes the last element from the vector and returns it, or None if it is empty.
84
300
    pub fn pop(&mut self) -> Option<T> {
85
300
        if self.is_empty() {
86
            None
87
        } else {
88
300
            let index = self.len() - 1;
89
300
            let entry = self.index(index);
90
300
            self.data.truncate(index * self.bytes_per_entry);
91
300
            Some(entry)
92
        }
93
300
    }
94

            
95
    /// Returns the entry at the given index.
96
32173715
    pub fn index(&self, index: usize) -> T {
97
32173715
        let start = index * self.bytes_per_entry;
98
32173715
        let end = start + self.bytes_per_entry;
99
32173715
        T::from_bytes(&self.data[start..end])
100
32173715
    }
101

            
102
    /// Sets the entry at the given index.
103
7909785
    pub fn set(&mut self, index: usize, entry: T) {
104
7909785
        self.resize_entries(entry.bytes_required());
105

            
106
7909785
        let start = index * self.bytes_per_entry;
107
7909785
        let end = start + self.bytes_per_entry;
108
7909785
        entry.to_bytes(&mut self.data[start..end]);
109
7909785
    }
110

            
111
    /// Returns the number of elements in the vector.
112
1921051
    pub fn len(&self) -> usize {
113
1921051
        if self.bytes_per_entry == 0 {
114
40034
            0
115
        } else {
116
1881017
            debug_assert!(self.data.len() % self.bytes_per_entry == 0);
117
1881017
            self.data.len() / self.bytes_per_entry
118
        }
119
1921051
    }
120

            
121
    /// Returns true if the vector is empty.
122
300
    pub fn is_empty(&self) -> bool {
123
300
        self.len() == 0
124
300
    }
125

            
126
    /// Returns metrics about memory usage of this compressed vector
127
    pub fn metrics(&self) -> CompressedVecMetrics {
128
        let element_count = self.len();
129
        let actual_memory =
130
            self.data.len() + std::mem::size_of_val(&self.bytes_per_entry) + std::mem::size_of::<PhantomData<T>>();
131
        let worst_case_memory = element_count * std::mem::size_of::<T>();
132

            
133
        CompressedVecMetrics {
134
            actual_memory,
135
            worst_case_memory,
136
        }
137
    }
138

            
139
    /// Returns an iterator over the elements in the vector.
140
3037
    pub fn iter(&self) -> ByteCompressedVecIterator<'_, T> {
141
3037
        ByteCompressedVecIterator {
142
3037
            vector: self,
143
3037
            current: 0,
144
3037
            end: self.len(),
145
3037
        }
146
3037
    }
147

            
148
    /// Returns an iterator over the elements in the vector for the begin, end range.
149
    pub fn iter_range(&self, begin: usize, end: usize) -> ByteCompressedVecIterator<'_, T> {
150
        ByteCompressedVecIterator {
151
            vector: self,
152
            current: begin,
153
            end,
154
        }
155
    }
156

            
157
    /// Updates the given entry using a closure.
158
1912982
    pub fn update<F>(&mut self, index: usize, mut update: F)
159
1912982
    where
160
1912982
        F: FnMut(&mut T),
161
    {
162
1912982
        let mut entry = self.index(index);
163
1912982
        update(&mut entry);
164
1912982
        self.set(index, entry);
165
1912982
    }
166

            
167
    /// Iterate over all elements and adapt the elements using a closure.
168
    pub fn map<F>(&mut self, mut f: F)
169
    where
170
        F: FnMut(&mut T),
171
    {
172
        for index in 0..self.len() {
173
            let mut entry = self.index(index);
174
            f(&mut entry);
175
            self.set(index, entry);
176
        }
177
    }
178

            
179
    /// Folds over the elements in the vector using the provided closure.
180
21530
    pub fn fold<B, F>(&mut self, init: B, mut f: F) -> B
181
21530
    where
182
21530
        F: FnMut(B, &mut T) -> B,
183
    {
184
21530
        let mut accumulator = init;
185
403126
        for index in 0..self.len() {
186
403126
            let mut element = self.index(index);
187
403126
            accumulator = f(accumulator, &mut element);
188
403126
            self.set(index, element);
189
403126
        }
190
21530
        accumulator
191
21530
    }
192

            
193
    /// Permutes a vector in place according to the given permutation function.
194
    ///
195
    /// The resulting vector will be [v_p^-1(0), v_p^-1(1), ..., v_p^-1(n-1)] where p is the permutation function.
196
100
    pub fn permute<P>(&mut self, permutation: P)
197
100
    where
198
100
        P: Fn(usize) -> usize,
199
    {
200
100
        debug_assert!(
201
100
            is_valid_permutation(&permutation, self.len()),
202
            "The given permutation must be a bijective mapping"
203
        );
204

            
205
100
        let mut visited = bitvec![usize, Lsb0; 0; self.len()];
206
529
        for start in 0..self.len() {
207
529
            if visited[start] {
208
326
                continue;
209
203
            }
210

            
211
            // Perform the cycle starting at 'start'
212
203
            let mut current = start;
213

            
214
            // Keeps track of the last displaced element
215
203
            let mut old = self.index(start);
216

            
217
203
            debug_trace!("Starting new cycle at position {}", start);
218
732
            while !visited[current] {
219
529
                visited.set(current, true);
220
529
                let next = permutation(current);
221
529
                if next != current {
222
445
                    debug_trace!("Moving element from position {} to position {}", current, next);
223
445
                    let temp = self.index(next);
224
445
                    self.set(next, old);
225
445
                    old = temp;
226
445
                }
227

            
228
529
                current = next;
229
            }
230
        }
231
100
    }
232

            
233
    /// Applies a permutation to a vector in place using an index function.
234
    ///
235
    /// The resulting vector will be [v_p(0), v_p(1), ..., v_p(n-1)] where p is the index function.
236
100
    pub fn permute_indices<P>(&mut self, indices: P)
237
100
    where
238
100
        P: Fn(usize) -> usize,
239
    {
240
100
        debug_assert!(
241
100
            is_valid_permutation(&indices, self.len()),
242
            "The given permutation must be a bijective mapping"
243
        );
244

            
245
100
        let mut visited = bitvec![usize, Lsb0; 0; self.len()];
246
529
        for start in 0..self.len() {
247
529
            if visited[start] {
248
298
                continue;
249
231
            }
250

            
251
            // Follow the cycle starting at 'start'
252
231
            debug_trace!("Starting new cycle at position {}", start);
253
231
            let mut current = start;
254
231
            let original = self.index(start);
255

            
256
636
            while !visited[current] {
257
529
                visited.set(current, true);
258
529
                let next = indices(current);
259

            
260
529
                if next != current {
261
422
                    if next != start {
262
298
                        debug_trace!("Moving element from position {} to position {}", current, next);
263
298
                        self.set(current, self.index(next));
264
298
                    } else {
265
124
                        break;
266
                    }
267
107
                }
268

            
269
405
                current = next;
270
            }
271

            
272
231
            trace!("Writing original to {}", current);
273
231
            self.set(current, original);
274
        }
275
100
    }
276

            
277
    /// Applies a permutation to a vector in place using an index function.
278
    ///
279
    /// This variant is faster but requires additional memory for the intermediate result vector.
280
    pub fn permute_indices_fast<P>(&mut self, indices: P)
281
    where
282
        P: Fn(usize) -> usize,
283
    {
284
        let mut result = ByteCompressedVec::with_capacity(self.data.capacity(), self.bytes_per_entry);
285
        for entry in self.iter().enumerate() {
286
            result.push(self.index(indices(entry.0)));
287
        }
288
        *self = result;
289
    }
290

            
291
    /// Swaps the entries at the given indices.
292
1
    pub fn swap(&mut self, index1: usize, index2: usize) {
293
1
        if index1 != index2 {
294
1
            let start1 = index1 * self.bytes_per_entry;
295
1
            let start2 = index2 * self.bytes_per_entry;
296
1

            
297
1
            // Create a temporary buffer for one entry
298
1
            let temp = T::from_bytes(&self.data[start1..start1 + self.bytes_per_entry]);
299
1

            
300
1
            // Copy entry2 to entry1's position
301
1
            self.data.copy_within(start2..start2 + self.bytes_per_entry, start1);
302
1

            
303
1
            // Copy temp to entry2's position
304
1
            temp.to_bytes(&mut self.data[start2..start2 + self.bytes_per_entry]);
305
1
        }
306
1
    }
307

            
308
    /// Resizes the vector to the given length, filling new entries with the provided value.
309
20013
    pub fn resize_with<F>(&mut self, new_len: usize, mut f: F)
310
20013
    where
311
20013
        F: FnMut() -> T,
312
    {
313
20013
        let current_len = self.len();
314
20013
        if new_len > current_len {
315
            // Preallocate the required space.
316
20013
            self.data.reserve(new_len * self.bytes_per_entry);
317
308524
            for _ in current_len..new_len {
318
308524
                self.push(f());
319
308524
            }
320
        } else if new_len < current_len {
321
            if new_len == 0 {
322
                self.data.clear();
323
                self.bytes_per_entry = 0;
324
            } else {
325
                // It could be that the bytes per entry is now less, but that we never reduce.
326
                self.data.truncate(new_len * self.bytes_per_entry);
327
            }
328
        }
329
20013
    }
330

            
331
    /// Reserves capacity for at least additional more entries to be inserted with the given bytes per entry.
332
900
    pub fn reserve(&mut self, additional: usize, bytes_per_entry: usize) {
333
900
        self.resize_entries(bytes_per_entry);
334
900
        self.data.reserve(additional * self.bytes_per_entry);
335
900
    }
336

            
337
    /// Resizes all entries in the vector to the given length.
338
12866064
    fn resize_entries(&mut self, new_bytes_required: usize) {
339
12866064
        if new_bytes_required > self.bytes_per_entry {
340
21287
            let mut new_data: Vec<u8> = vec![0; self.len() * new_bytes_required];
341

            
342
21287
            if self.bytes_per_entry > 0 {
343
                // Resize all the existing elements because the new entry requires more bytes.
344
438176
                for (index, entry) in self.iter().enumerate() {
345
438176
                    let start = index * new_bytes_required;
346
438176
                    let end = start + new_bytes_required;
347
438176
                    entry.to_bytes(&mut new_data[start..end]);
348
438176
                }
349
20017
            }
350

            
351
21287
            self.bytes_per_entry = new_bytes_required;
352
21287
            self.data = new_data;
353
12844777
        }
354
12866064
    }
355
}
356

            
357
impl<T: CompressedEntry + Clone> ByteCompressedVec<T> {
358
46835
    pub fn from_elem(entry: T, n: usize) -> ByteCompressedVec<T> {
359
46835
        let mut vec = ByteCompressedVec::with_capacity(n, entry.bytes_required());
360
3214372
        for _ in 0..n {
361
3214372
            vec.push(entry.clone());
362
3214372
        }
363
46835
        vec
364
46835
    }
365
}
366

            
367
/// Memory usage statistics of a `ByteCompressedVec`, as returned by its `metrics()` method.
#[derive(Clone, Debug)]
pub struct CompressedVecMetrics {
    /// Bytes actually used by the compressed vector.
    pub actual_memory: usize,
    /// Bytes an uncompressed vector with the same number of elements would use (`len * size_of::<T>()`).
    pub worst_case_memory: usize,
}
375

            
376
impl CompressedVecMetrics {
377
    /// Calculate memory savings in bytes
378
    pub fn memory_savings(&self) -> usize {
379
        self.worst_case_memory.saturating_sub(self.actual_memory)
380
    }
381

            
382
    /// Calculate memory savings as a percentage
383
    pub fn used_percentage(&self) -> f64 {
384
        if self.worst_case_memory == 0 {
385
            0.0
386
        } else {
387
            (self.actual_memory as f64 / self.worst_case_memory as f64) * 100.0
388
        }
389
    }
390
}
391

            
392
impl fmt::Display for CompressedVecMetrics {
393
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
394
        write!(
395
            f,
396
            "memory: {} ({:.1}%), saving: {} ",
397
            BytesFormatter(self.actual_memory),
398
            self.used_percentage(),
399
            BytesFormatter(self.memory_savings()),
400
        )
401
    }
402
}
403
pub struct ByteCompressedVecIterator<'a, T> {
404
    vector: &'a ByteCompressedVec<T>,
405
    current: usize,
406
    end: usize,
407
}
408

            
409
impl<T: CompressedEntry> Iterator for ByteCompressedVecIterator<'_, T> {
    type Item = T;

    /// Decodes and returns the next element, or `None` once `end` has been reached.
    fn next(&mut self) -> Option<Self::Item> {
        if self.current < self.end {
            let result = self.vector.index(self.current);
            self.current += 1;
            Some(result)
        } else {
            None
        }
    }

    /// The number of remaining elements is exactly `end - current` (zero for an empty range).
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.end.saturating_sub(self.current);
        (remaining, Some(remaining))
    }
}

/// The iterator walks a fixed index range, so its length is known exactly.
impl<T: CompressedEntry> ExactSizeIterator for ByteCompressedVecIterator<'_, T> {}

/// `current` is never advanced past `end`, so once exhausted `next` keeps returning `None`.
impl<T: CompressedEntry> std::iter::FusedIterator for ByteCompressedVecIterator<'_, T> {}
422

            
423
/// Conversion of a value to and from a compact byte representation, used by
/// `ByteCompressedVec` to store its elements.
///
/// All entries of a `ByteCompressedVec` share one width (the largest `bytes_required()` seen so
/// far), so the slices passed to `to_bytes` and `from_bytes` may be longer than
/// `bytes_required()` for the particular value.
pub trait CompressedEntry {
    /// Writes the byte representation of `self` into `bytes`.
    ///
    /// `bytes` is at least `self.bytes_required()` long. It may still contain the encoding of a
    /// previously stored entry (when an entry is overwritten in place), so implementations should
    /// write every byte of the slice.
    fn to_bytes(&self, bytes: &mut [u8]);

    /// Reconstructs an entry from bytes previously written by `to_bytes`.
    fn from_bytes(bytes: &[u8]) -> Self;

    /// Returns the number of bytes required to store the current entry.
    fn bytes_required(&self) -> usize;
}
433

            
434
impl CompressedEntry for usize {
435
83326573
    fn to_bytes(&self, bytes: &mut [u8]) {
436
83326573
        let array = &self.to_le_bytes();
437
101250007
        for (i, byte) in bytes.iter_mut().enumerate().take(usize::BITS as usize / 8) {
438
101250007
            *byte = array[i];
439
101250007
        }
440
83326573
    }
441

            
442
154424268
    fn from_bytes(bytes: &[u8]) -> Self {
443
154424268
        let mut array = [0; 8];
444
205102021
        for (i, byte) in bytes.iter().enumerate().take(usize::BITS as usize / 8) {
445
205102021
            array[i] = *byte;
446
205102021
        }
447
154424268
        usize::from_le_bytes(array)
448
154424268
    }
449

            
450
81554029
    fn bytes_required(&self) -> usize {
451
81554029
        ((self + 1).ilog2() / u8::BITS) as usize + 1
452
81554029
    }
453
}
454

            
455
impl<T: CompressedEntry + fmt::Debug> fmt::Debug for ByteCompressedVec<T> {
    /// Formats the vector like a list (e.g. `[1, 2, 3]`), decoding every element on the fly.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut list = f.debug_list();
        list.entries(self.iter());
        list.finish()
    }
}
460

            
461
/// Implement it for the TagIndex for convenience.
462
impl<T: CompressedEntry + Copy, Tag> CompressedEntry for TagIndex<T, Tag> {
463
    delegate! {
464
        to self.value() {
465
8386001
            fn to_bytes(&self, bytes: &mut [u8]);
466
8195860
            fn bytes_required(&self) -> usize;
467
        }
468
    }
469

            
470
23024185
    fn from_bytes(bytes: &[u8]) -> Self {
471
23024185
        TagIndex::new(T::from_bytes(bytes))
472
23024185
    }
473
}
474

            
475
#[cfg(test)]
mod tests {
    use super::*;

    use rand::Rng;
    use rand::distr::Uniform;
    use rand::seq::SliceRandom;

    use merc_utilities::random_test;

    #[test]
    fn test_index_bytevector() {
        let mut vec = ByteCompressedVec::new();
        vec.push(1);
        assert_eq!(vec.len(), 1);

        vec.push(1024);
        assert_eq!(vec.len(), 2);

        assert_eq!(vec.index(0), 1);
        assert_eq!(vec.index(1), 1024);
    }

    #[test]
    fn test_random_bytevector() {
        let rng = rand::rng();

        let range = Uniform::new(0, usize::MAX).unwrap();
        let expected_vector: Vec<usize> = rng.sample_iter(range).take(100).collect();
        let mut vector = ByteCompressedVec::new();

        for element in &expected_vector {
            vector.push(*element);

            // Every push may widen the entries, so re-check all elements stored so far.
            for (expected, element) in expected_vector.iter().zip(vector.iter()) {
                assert_eq!(*expected, element);
            }
        }
    }

    #[test]
    fn test_random_setting_bytevector() {
        let rng = rand::rng();

        let range = Uniform::new(0, usize::MAX).unwrap();
        let expected_vector: Vec<usize> = rng.sample_iter(range).take(100).collect();
        let mut vector = bytevec![0; 100];

        for (index, element) in expected_vector.iter().enumerate() {
            vector.set(index, *element);
        }

        for (expected, element) in expected_vector.iter().zip(vector.iter()) {
            assert_eq!(*expected, element);
        }
    }

    #[test]
    fn test_random_usize_entry() {
        random_test(100, |rng| {
            let value = rng.random_range(0..1024);
            assert!(value.bytes_required() <= 2);

            let mut bytes = [0; 2];
            value.to_bytes(&mut bytes);
            assert_eq!(usize::from_bytes(&bytes), value);
        });
    }

    #[test]
    fn test_swap() {
        let mut vec = ByteCompressedVec::new();
        vec.push(1);
        vec.push(256);
        vec.push(65536);

        vec.swap(0, 2);

        assert_eq!(vec.index(0), 65536);
        assert_eq!(vec.index(1), 256);
        assert_eq!(vec.index(2), 1);
    }

    #[test]
    fn test_pop() {
        let mut vec = ByteCompressedVec::new();
        vec.push(5usize);
        vec.push(70000);

        assert_eq!(vec.pop(), Some(70000));
        assert_eq!(vec.pop(), Some(5));
        assert_eq!(vec.pop(), None);
        assert!(vec.is_empty());
    }

    #[test]
    fn test_permute_indices_fast() {
        let mut fast = ByteCompressedVec::with_iter([10usize, 300, 7, 70000].into_iter());
        let mut slow = fast.clone();
        let order = [2, 0, 3, 1];

        fast.permute_indices_fast(|i| order[i]);
        slow.permute_indices(|i| order[i]);

        assert_eq!(fast.len(), order.len());
        for i in 0..order.len() {
            assert_eq!(fast.index(i), slow.index(i));
        }
        assert_eq!(fast.index(0), 7);
    }

    #[test]
    fn test_random_bytevector_permute() {
        random_test(100, |rng| {
            // Generate random vector to permute
            let elements = (0..rng.random_range(1..10))
                .map(|_| rng.random_range(0..100))
                .collect::<Vec<_>>();

            let vec = ByteCompressedVec::with_iter(elements.iter().cloned());

            for is_inverse in [false, true] {
                println!("Inverse: {is_inverse}, Input: {:?}", vec);

                let permutation = {
                    let mut order: Vec<usize> = (0..elements.len()).collect();
                    order.shuffle(rng);
                    order
                };

                let mut permuted = vec.clone();
                if is_inverse {
                    permuted.permute_indices(|i| permutation[i]);
                } else {
                    permuted.permute(|i| permutation[i]);
                }

                println!("Permutation: {:?}", permutation);
                println!("After permutation: {:?}", permuted);

                // Check that the permutation was applied correctly. `assert_eq!` (rather than
                // `debug_assert_eq!`) keeps the check active in release-mode test runs.
                for i in 0..elements.len() {
                    let pos = if is_inverse {
                        permutation[i]
                    } else {
                        permutation
                            .iter()
                            .position(|&j| i == j)
                            .expect("Should find inverse mapping")
                    };

                    assert_eq!(
                        permuted.index(i),
                        elements[pos],
                        "Element at index {} should be {}",
                        i,
                        elements[pos]
                    );
                }
            }
        });
    }
}