1
use std::fmt;
2
use std::marker::PhantomData;
3

            
4
use bitvec::bitvec;
5
use bitvec::order::Lsb0;
6
use delegate::delegate;
7
use log::trace;
8

            
9
use merc_io::BytesFormatter;
10
use merc_utilities::TagIndex;
11
use merc_utilities::debug_trace;
12
use merc_utilities::is_valid_permutation;
13

            
14
/// A counterpart of `vec![]` that builds a [`crate::ByteCompressedVec`].
///
/// Supports the same three forms as `vec!`:
/// - `bytevec![]` creates an empty vector,
/// - `bytevec![elem; n]` creates a vector holding `n` clones of `elem`,
/// - `bytevec![a, b, c]` creates a vector holding the listed elements in order
///   (a trailing comma is accepted).
#[macro_export]
macro_rules! bytevec {
    () => {
        $crate::ByteCompressedVec::new()
    };
    ($elem:expr; $n:expr) => {
        $crate::ByteCompressedVec::from_elem($elem, $n)
    };
    ($($x:expr),+ $(,)?) => {{
        // Push one by one: `push` widens the entry size as needed, and no `Clone`
        // bound is required on the element type.
        let mut vec = $crate::ByteCompressedVec::new();
        $( vec.push($x); )+
        vec
    }};
}
24

            
25
/// A vector that keeps its elements in a byte-compressed encoding.
///
/// The element type `T` implements the `CompressedEntry` trait, which encodes a
/// value into bytes, decodes it again, and reports how many bytes the value needs.
/// All elements share one width, `bytes_per_entry`: the largest `bytes_required()`
/// seen so far. When a wider element arrives, every stored element is re-encoded at
/// the new width; the width never shrinks while elements remain.
///
/// For `usize` this means every entry occupies only as many bytes as the largest
/// number stored so far needs.
///
/// NOTE: `PartialEq`/`Eq` are derived and therefore compare the raw representation
/// (width and encoded bytes). Two vectors holding the same elements at different
/// widths compare unequal.
///
/// TODO: The `drop()` function of `T` is never called.
#[derive(Clone, Default, Eq, PartialEq)]
pub struct ByteCompressedVec<T> {
    /// Encoded elements laid out back to back, `bytes_per_entry` bytes each.
    data: Vec<u8>,
    /// Width in bytes of every encoded element; zero while no width has been
    /// established (e.g. for a freshly created vector).
    bytes_per_entry: usize,
    /// Ties the element type to the vector without storing any `T` values.
    _marker: PhantomData<T>,
}
42

            
43
impl<T: CompressedEntry> ByteCompressedVec<T> {
44
30928
    pub fn new() -> ByteCompressedVec<T> {
45
30928
        ByteCompressedVec {
46
30928
            data: Vec::new(),
47
30928
            bytes_per_entry: 0,
48
30928
            _marker: PhantomData,
49
30928
        }
50
30928
    }
51

            
52
    /// Initializes a ByteCompressedVec with the given capacity and (minimal) bytes per entry.
53
71516
    pub fn with_capacity(capacity: usize, bytes_per_entry: usize) -> ByteCompressedVec<T> {
54
71516
        ByteCompressedVec {
55
71516
            data: Vec::with_capacity(capacity * bytes_per_entry),
56
71516
            bytes_per_entry,
57
71516
            _marker: PhantomData,
58
71516
        }
59
71516
    }
60

            
61
    /// This is basically the collect() of `Vec`.
62
    ///
63
    /// However, we use it to determine the required bytes per entry in advance.
64
100
    pub fn with_iter<I>(iter: I) -> ByteCompressedVec<T>
65
100
    where
66
100
        I: ExactSizeIterator<Item = T> + Clone,
67
    {
68
100
        let bytes_per_entry = iter
69
100
            .clone()
70
491
            .fold(0, |max_bytes, entry| max_bytes.max(entry.bytes_required()));
71

            
72
100
        let mut vec = ByteCompressedVec::with_capacity(iter.len(), bytes_per_entry);
73
491
        for entry in iter {
74
491
            vec.push(entry);
75
491
        }
76
100
        vec
77
100
    }
78

            
79
    /// Adds a new entry to the vector.
80
3955052
    pub fn push(&mut self, entry: T) {
81
3955052
        self.resize_entries(entry.bytes_required());
82

            
83
        // Add the new entry to the end of the vector.
84
3955052
        let old_len = self.data.len();
85
3955052
        self.data.resize(old_len + self.bytes_per_entry, 0);
86
3955052
        entry.to_bytes(&mut self.data[old_len..]);
87
3955052
    }
88

            
89
    /// Removes the last element from the vector and returns it, or None if it is empty.
90
400
    pub fn pop(&mut self) -> Option<T> {
91
400
        if self.is_empty() {
92
            None
93
        } else {
94
400
            let index = self.len() - 1;
95
400
            let entry = self.index(index);
96
400
            self.data.truncate(index * self.bytes_per_entry);
97
400
            Some(entry)
98
        }
99
400
    }
100

            
101
    /// Returns the entry at the given index.
102
13665945
    pub fn index(&self, index: usize) -> T {
103
13665945
        let start = index * self.bytes_per_entry;
104
13665945
        let end = start + self.bytes_per_entry;
105
13665945
        T::from_bytes(&self.data[start..end])
106
13665945
    }
107

            
108
    /// Sets the entry at the given index.
109
5741124
    pub fn set(&mut self, index: usize, entry: T) {
110
5741124
        self.resize_entries(entry.bytes_required());
111

            
112
5741124
        let start = index * self.bytes_per_entry;
113
5741124
        let end = start + self.bytes_per_entry;
114
5741124
        entry.to_bytes(&mut self.data[start..end]);
115
5741124
    }
116

            
117
    /// Returns the number of elements in the vector.
118
1885032
    pub fn len(&self) -> usize {
119
1885032
        if self.bytes_per_entry == 0 {
120
61857
            0
121
        } else {
122
1823175
            debug_assert!(self.data.len().is_multiple_of(self.bytes_per_entry));
123
1823175
            self.data.len() / self.bytes_per_entry
124
        }
125
1885032
    }
126

            
127
    /// Returns true if the vector is empty.
128
400
    pub fn is_empty(&self) -> bool {
129
400
        self.len() == 0
130
400
    }
131

            
132
    /// Returns metrics about memory usage of this compressed vector
133
    pub fn metrics(&self) -> CompressedVecMetrics {
134
        let element_count = self.len();
135
        let actual_memory =
136
            self.data.len() + std::mem::size_of_val(&self.bytes_per_entry) + std::mem::size_of::<PhantomData<T>>();
137
        let worst_case_memory = element_count * std::mem::size_of::<T>();
138

            
139
        CompressedVecMetrics {
140
            actual_memory,
141
            worst_case_memory,
142
        }
143
    }
144

            
145
    /// Returns an iterator over the elements in the vector.
146
3467
    pub fn iter(&self) -> ByteCompressedVecIterator<'_, T> {
147
3467
        ByteCompressedVecIterator {
148
3467
            vector: self,
149
3467
            current: 0,
150
3467
            end: self.len(),
151
3467
        }
152
3467
    }
153

            
154
    /// Returns an iterator over the elements in the vector for the begin, end range.
155
    pub fn iter_range(&self, begin: usize, end: usize) -> ByteCompressedVecIterator<'_, T> {
156
        ByteCompressedVecIterator {
157
            vector: self,
158
            current: begin,
159
            end,
160
        }
161
    }
162

            
163
    /// Updates the given entry using a closure.
164
1848382
    pub fn update<F>(&mut self, index: usize, mut update: F)
165
1848382
    where
166
1848382
        F: FnMut(&mut T),
167
    {
168
1848382
        let mut entry = self.index(index);
169
1848382
        update(&mut entry);
170
1848382
        self.set(index, entry);
171
1848382
    }
172

            
173
    /// Iterate over all elements and adapt the elements using a closure.
174
    pub fn map<F>(&mut self, mut f: F)
175
    where
176
        F: FnMut(&mut T),
177
    {
178
        for index in 0..self.len() {
179
            let mut entry = self.index(index);
180
            f(&mut entry);
181
            self.set(index, entry);
182
        }
183
    }
184

            
185
    /// Folds over the elements in the vector using the provided closure.
186
25046
    pub fn fold<B, F>(&mut self, init: B, mut f: F) -> B
187
25046
    where
188
25046
        F: FnMut(B, &mut T) -> B,
189
    {
190
25046
        let mut accumulator = init;
191
374040
        for index in 0..self.len() {
192
374040
            let mut element = self.index(index);
193
374040
            accumulator = f(accumulator, &mut element);
194
374040
            self.set(index, element);
195
374040
        }
196
25046
        accumulator
197
25046
    }
198

            
199
    /// Permutes a vector in place according to the given permutation function.
200
    ///
201
    /// The resulting vector will be [v_p^-1(0), v_p^-1(1), ..., v_p^-1(n-1)] where p is the permutation function.
202
100
    pub fn permute<P>(&mut self, permutation: P)
203
100
    where
204
100
        P: Fn(usize) -> usize,
205
    {
206
100
        debug_assert!(
207
100
            is_valid_permutation(&permutation, self.len()),
208
            "The given permutation must be a bijective mapping"
209
        );
210

            
211
100
        let mut visited = bitvec![usize, Lsb0; 0; self.len()];
212
491
        for start in 0..self.len() {
213
491
            if visited[start] {
214
281
                continue;
215
210
            }
216

            
217
            // Perform the cycle starting at 'start'
218
210
            let mut current = start;
219

            
220
            // Keeps track of the last displaced element
221
210
            let mut old = self.index(start);
222

            
223
210
            debug_trace!("Starting new cycle at position {}", start);
224
701
            while !visited[current] {
225
491
                visited.set(current, true);
226
491
                let next = permutation(current);
227
491
                if next != current {
228
393
                    debug_trace!("Moving element from position {} to position {}", current, next);
229
393
                    let temp = self.index(next);
230
393
                    self.set(next, old);
231
393
                    old = temp;
232
393
                }
233

            
234
491
                current = next;
235
            }
236
        }
237
100
    }
238

            
239
    /// Applies a permutation to a vector in place using an index function.
240
    ///
241
    /// The resulting vector will be [v_p(0), v_p(1), ..., v_p(n-1)] where p is the index function.
242
100
    pub fn permute_indices<P>(&mut self, indices: P)
243
100
    where
244
100
        P: Fn(usize) -> usize,
245
    {
246
100
        debug_assert!(
247
100
            is_valid_permutation(&indices, self.len()),
248
            "The given permutation must be a bijective mapping"
249
        );
250

            
251
100
        let mut visited = bitvec![usize, Lsb0; 0; self.len()];
252
491
        for start in 0..self.len() {
253
491
            if visited[start] {
254
275
                continue;
255
216
            }
256

            
257
            // Follow the cycle starting at 'start'
258
216
            debug_trace!("Starting new cycle at position {}", start);
259
216
            let mut current = start;
260
216
            let original = self.index(start);
261

            
262
593
            while !visited[current] {
263
491
                visited.set(current, true);
264
491
                let next = indices(current);
265

            
266
491
                if next != current {
267
389
                    if next != start {
268
275
                        debug_trace!("Moving element from position {} to position {}", current, next);
269
275
                        self.set(current, self.index(next));
270
275
                    } else {
271
114
                        break;
272
                    }
273
102
                }
274

            
275
377
                current = next;
276
            }
277

            
278
216
            trace!("Writing original to {}", current);
279
216
            self.set(current, original);
280
        }
281
100
    }
282

            
283
    /// Applies a permutation to a vector in place using an index function.
284
    ///
285
    /// This variant is faster but requires additional memory for the intermediate result vector.
286
    pub fn permute_indices_fast<P>(&mut self, indices: P)
287
    where
288
        P: Fn(usize) -> usize,
289
    {
290
        let mut result = ByteCompressedVec::with_capacity(self.data.capacity(), self.bytes_per_entry);
291
        for index in 0..self.len() {
292
            result.push(self.index(indices(index)));
293
        }
294
        *self = result;
295
    }
296

            
297
    /// Swaps the entries at the given indices.
298
1
    pub fn swap(&mut self, index1: usize, index2: usize) {
299
1
        if index1 != index2 {
300
1
            let start1 = index1 * self.bytes_per_entry;
301
1
            let start2 = index2 * self.bytes_per_entry;
302
1

            
303
1
            // Create a temporary buffer for one entry
304
1
            let temp = T::from_bytes(&self.data[start1..start1 + self.bytes_per_entry]);
305
1

            
306
1
            // Copy entry2 to entry1's position
307
1
            self.data.copy_within(start2..start2 + self.bytes_per_entry, start1);
308
1

            
309
1
            // Copy temp to entry2's position
310
1
            temp.to_bytes(&mut self.data[start2..start2 + self.bytes_per_entry]);
311
1
        }
312
1
    }
313

            
314
    /// Resizes the vector to the given length, filling new entries with the provided value.
315
30923
    pub fn resize_with<F>(&mut self, new_len: usize, mut f: F)
316
30923
    where
317
30923
        F: FnMut() -> T,
318
    {
319
30923
        let current_len = self.len();
320
30923
        if new_len > current_len {
321
            // Preallocate the required space.
322
30923
            self.data.reserve(new_len * self.bytes_per_entry);
323
341370
            for _ in current_len..new_len {
324
341370
                self.push(f());
325
341370
            }
326
        } else if new_len < current_len {
327
            if new_len == 0 {
328
                self.data.clear();
329
                self.bytes_per_entry = 0;
330
            } else {
331
                // It could be that the bytes per entry is now less, but that we never reduce.
332
                self.data.truncate(new_len * self.bytes_per_entry);
333
            }
334
        }
335
30923
    }
336

            
337
    /// Reserves capacity for at least additional more entries to be inserted with the given bytes per entry.
338
1200
    pub fn reserve(&mut self, additional: usize, bytes_per_entry: usize) {
339
1200
        self.resize_entries(bytes_per_entry);
340
1200
        self.data.reserve(additional * self.bytes_per_entry);
341
1200
    }
342

            
343
    /// Resizes all entries in the vector to the given length.
344
9700076
    fn resize_entries(&mut self, new_bytes_required: usize) {
345
9700076
        if new_bytes_required > self.bytes_per_entry {
346
31968
            let mut new_data: Vec<u8> = vec![0; self.len() * new_bytes_required];
347

            
348
31968
            if self.bytes_per_entry > 0 {
349
                // Resize all the existing elements because the new entry requires more bytes.
350
122877
                for (index, entry) in self.iter().enumerate() {
351
122877
                    let start = index * new_bytes_required;
352
122877
                    let end = start + new_bytes_required;
353
122877
                    entry.to_bytes(&mut new_data[start..end]);
354
122877
                }
355
30928
            }
356

            
357
31968
            self.bytes_per_entry = new_bytes_required;
358
31968
            self.data = new_data;
359
9668108
        }
360
9700076
    }
361
}
362

            
363
impl<T: CompressedEntry + Clone> ByteCompressedVec<T> {
364
70453
    pub fn from_elem(entry: T, n: usize) -> ByteCompressedVec<T> {
365
70453
        let mut vec = ByteCompressedVec::with_capacity(n, entry.bytes_required());
366
2384336
        for _ in 0..n {
367
2384336
            vec.push(entry.clone());
368
2384336
        }
369
70453
        vec
370
70453
    }
371
}
372

            
373
/// Metrics for tracking memory usage of a ByteCompressedVec.
///
/// `Default` yields all-zero metrics, a convenient starting point when aggregating
/// over several vectors. `PartialEq`/`Eq` allow metrics to be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CompressedVecMetrics {
    /// Actual memory used by the compressed vector (in bytes)
    pub actual_memory: usize,
    /// Worst-case memory that would be used by an uncompressed vector (len * sizeof(T))
    pub worst_case_memory: usize,
}
381

            
382
impl CompressedVecMetrics {
383
    /// Calculate memory savings in bytes
384
    pub fn memory_savings(&self) -> usize {
385
        self.worst_case_memory.saturating_sub(self.actual_memory)
386
    }
387

            
388
    /// Calculate memory savings as a percentage
389
    pub fn used_percentage(&self) -> f64 {
390
        if self.worst_case_memory == 0 {
391
            0.0
392
        } else {
393
            (self.actual_memory as f64 / self.worst_case_memory as f64) * 100.0
394
        }
395
    }
396
}
397

            
398
impl fmt::Display for CompressedVecMetrics {
399
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
400
        write!(
401
            f,
402
            "memory: {} ({:.1}%), saving: {} ",
403
            BytesFormatter(self.actual_memory),
404
            self.used_percentage(),
405
            BytesFormatter(self.memory_savings()),
406
        )
407
    }
408
}
409
pub struct ByteCompressedVecIterator<'a, T> {
410
    vector: &'a ByteCompressedVec<T>,
411
    current: usize,
412
    end: usize,
413
}
414

            
415
impl<T: CompressedEntry> Iterator for ByteCompressedVecIterator<'_, T> {
    type Item = T;

    /// Decodes and yields the entry at the current position, or `None` once `end` is reached.
    fn next(&mut self) -> Option<Self::Item> {
        if self.current < self.end {
            let result = self.vector.index(self.current);
            self.current += 1;
            Some(result)
        } else {
            None
        }
    }

    /// The number of remaining entries is known exactly. `saturating_sub` keeps a range
    /// with `current > end` (possible through `iter_range`) from underflowing; such a
    /// range yields nothing.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.end.saturating_sub(self.current);
        (remaining, Some(remaining))
    }
}

impl<T: CompressedEntry> ExactSizeIterator for ByteCompressedVecIterator<'_, T> {}

// `next` never advances once `current >= end`, so after the first `None` it keeps returning `None`.
impl<T: CompressedEntry> std::iter::FusedIterator for ByteCompressedVecIterator<'_, T> {}
428

            
429
/// Conversion of a value to and from a compact byte encoding, as used by
/// [`ByteCompressedVec`].
///
/// The vector stores every entry at one shared width: the largest `bytes_required()`
/// among its entries. Implementations must therefore encode into, and decode from,
/// slices that may be wider than the value itself needs.
pub trait CompressedEntry {
    /// Writes the encoding of `self` into `bytes`.
    ///
    /// [`ByteCompressedVec`] always passes a slice whose length is its current entry
    /// width, which is at least `self.bytes_required()`.
    fn to_bytes(&self, bytes: &mut [u8]);

    /// Reconstructs an entry from `bytes`, the inverse of [`CompressedEntry::to_bytes`]
    /// for a slice of the same length.
    fn from_bytes(bytes: &[u8]) -> Self;

    /// Returns the number of bytes required to store the current entry.
    fn bytes_required(&self) -> usize;
}
439

            
440
impl CompressedEntry for usize {
441
79285672
    fn to_bytes(&self, bytes: &mut [u8]) {
442
79285672
        let array = &self.to_le_bytes();
443
91214150
        for (i, byte) in bytes.iter_mut().enumerate().take(usize::BITS as usize / 8) {
444
91214150
            *byte = array[i];
445
91214150
        }
446
79285672
    }
447

            
448
93237018
    fn from_bytes(bytes: &[u8]) -> Self {
449
93237018
        let mut array = [0; 8];
450
106888155
        for (i, byte) in bytes.iter().enumerate().take(usize::BITS as usize / 8) {
451
106888155
            array[i] = *byte;
452
106888155
        }
453
93237018
        usize::from_le_bytes(array)
454
93237018
    }
455

            
456
78550181
    fn bytes_required(&self) -> usize {
457
78550181
        ((self + 1).ilog2() / u8::BITS) as usize + 1
458
78550181
    }
459
}
460

            
461
impl<T: CompressedEntry + fmt::Debug> fmt::Debug for ByteCompressedVec<T> {
    /// Formats the decoded entries as a list, e.g. `[1, 256, 65536]`; the internal
    /// byte layout is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut list = f.debug_list();
        for entry in self.iter() {
            list.entry(&entry);
        }
        list.finish()
    }
}
466

            
467
/// Implement it for the TagIndex for convenience.
468
impl<T: CompressedEntry + Copy, Tag> CompressedEntry for TagIndex<T, Tag> {
469
    delegate! {
470
        to self.value() {
471
5909718
            fn to_bytes(&self, bytes: &mut [u8]);
472
5967247
            fn bytes_required(&self) -> usize;
473
        }
474
    }
475

            
476
7768887
    fn from_bytes(bytes: &[u8]) -> Self {
477
7768887
        TagIndex::new(T::from_bytes(bytes))
478
7768887
    }
479
}
480

            
481
#[cfg(test)]
mod tests {
    use super::*;

    use rand::Rng;
    use rand::distr::Uniform;
    use rand::seq::SliceRandom;

    use merc_utilities::random_test;

    #[test]
    fn test_index_bytevector() {
        let mut vec = ByteCompressedVec::new();
        vec.push(1);
        assert_eq!(vec.len(), 1);

        vec.push(1024);
        assert_eq!(vec.len(), 2);

        assert_eq!(vec.index(0), 1);
        assert_eq!(vec.index(1), 1024);
    }

    #[test]
    fn test_random_bytevector() {
        let rng = rand::rng();

        let range = Uniform::new(0, usize::MAX).unwrap();
        let expected_vector: Vec<usize> = rng.sample_iter(range).take(100).collect();
        let mut vector = ByteCompressedVec::new();

        for element in &expected_vector {
            vector.push(*element);

            // Every push may widen the entries, so re-check the whole prefix each time.
            for (expected, element) in expected_vector.iter().zip(vector.iter()) {
                assert_eq!(*expected, element);
            }
        }
    }

    #[test]
    fn test_random_setting_bytevector() {
        let rng = rand::rng();

        let range = Uniform::new(0, usize::MAX).unwrap();
        let expected_vector: Vec<usize> = rng.sample_iter(range).take(100).collect();
        let mut vector = bytevec![0; 100];

        for (index, element) in expected_vector.iter().enumerate() {
            vector.set(index, *element);
        }

        for (expected, element) in expected_vector.iter().zip(vector.iter()) {
            assert_eq!(*expected, element);
        }
    }

    #[test]
    fn test_random_usize_entry() {
        random_test(100, |rng| {
            let value = rng.random_range(0..1024);
            assert!(value.bytes_required() <= 2);

            let mut bytes = [0; 2];
            value.to_bytes(&mut bytes);
            assert_eq!(usize::from_bytes(&bytes), value);
        });
    }

    #[test]
    fn test_swap() {
        let mut vec = ByteCompressedVec::new();
        vec.push(1);
        vec.push(256);
        vec.push(65536);

        vec.swap(0, 2);

        assert_eq!(vec.index(0), 65536);
        assert_eq!(vec.index(1), 256);
        assert_eq!(vec.index(2), 1);
    }

    #[test]
    fn test_pop_and_resize_with() {
        let mut vec = ByteCompressedVec::new();
        vec.resize_with(3, || 300usize);
        assert_eq!(vec.len(), 3);

        assert_eq!(vec.pop(), Some(300));
        assert_eq!(vec.len(), 2);

        vec.resize_with(1, || 0);
        assert_eq!(vec.len(), 1);
        assert_eq!(vec.index(0), 300);

        vec.resize_with(0, || 0);
        assert!(vec.is_empty());
        assert_eq!(vec.pop(), None);
    }

    #[test]
    fn test_random_bytevector_permute() {
        random_test(100, |rng| {
            // Generate random vector to permute
            let elements = (0..rng.random_range(1..10))
                .map(|_| rng.random_range(0..100))
                .collect::<Vec<_>>();

            let vec = ByteCompressedVec::with_iter(elements.iter().cloned());

            for is_inverse in [false, true] {
                println!("Inverse: {is_inverse}, Input: {:?}", vec);

                let permutation = {
                    let mut order: Vec<usize> = (0..elements.len()).collect();
                    order.shuffle(rng);
                    order
                };

                let mut permutated = vec.clone();
                if is_inverse {
                    permutated.permute_indices(|i| permutation[i]);
                } else {
                    permutated.permute(|i| permutation[i]);
                }

                println!("Permutation: {:?}", permutation);
                println!("After permutation: {:?}", permutated);

                // Check that the permutation was applied correctly. `assert_eq!` (not
                // `debug_assert_eq!`) keeps the check active under `cargo test --release`.
                for i in 0..elements.len() {
                    let pos = if is_inverse {
                        permutation[i]
                    } else {
                        permutation
                            .iter()
                            .position(|&j| i == j)
                            .expect("Should find inverse mapping")
                    };

                    assert_eq!(
                        permutated.index(i),
                        elements[pos],
                        "Element at index {} should be {}",
                        i,
                        elements[pos]
                    );
                }
            }
        });
    }

    #[test]
    fn test_random_bytevector_permute_indices_fast() {
        random_test(100, |rng| {
            let elements = (0..rng.random_range(1..10))
                .map(|_| rng.random_range(0..100_000))
                .collect::<Vec<usize>>();

            let permutation = {
                let mut order: Vec<usize> = (0..elements.len()).collect();
                order.shuffle(rng);
                order
            };

            let mut fast = ByteCompressedVec::with_iter(elements.iter().cloned());
            let mut cyclic = fast.clone();
            fast.permute_indices_fast(|i| permutation[i]);
            cyclic.permute_indices(|i| permutation[i]);

            // Both variants must place the element from `permutation[i]` at position `i`.
            assert_eq!(fast.len(), elements.len());
            for (i, &source) in permutation.iter().enumerate() {
                assert_eq!(fast.index(i), elements[source]);
                assert_eq!(cyclic.index(i), elements[source]);
            }
        });
    }
}