1
#![forbid(unsafe_code)]
2

            
3
use std::fmt;
4

            
5
use itertools::Itertools;
6

            
7
use merc_collections::BlockIndex;
8
use merc_lts::IncomingTransitions;
9
use merc_lts::StateIndex;
10

            
11
use super::Partition;
12

            
13
/// A partition that explicitly stores a list of blocks and their indexing into
14
/// the list of elements.
15
#[derive(Debug)]
16
pub struct BlockPartition {
17
    elements: Vec<StateIndex>,
18
    blocks: Vec<Block>,
19

            
20
    // These are only used to provide O(1) marking of elements.
21
    /// Stores the block index for each element.
22
    element_to_block: Vec<BlockIndex>,
23

            
24
    /// Stores the offset within the block for every element.
25
    element_offset: Vec<usize>,
26
}
27

            
28
impl BlockPartition {
29
    /// Create an initial partition where all the states are in a single block
30
    /// 0. And all the elements in the block are marked.
31
1202
    pub fn new(num_of_elements: usize) -> BlockPartition {
32
1202
        debug_assert!(num_of_elements > 0, "Cannot partition the empty set");
33

            
34
1202
        let blocks = vec![Block::new(0, num_of_elements)];
35
1202
        let elements = (0..num_of_elements).map(StateIndex::new).collect();
36
1202
        let element_to_block = vec![BlockIndex::new(0); num_of_elements];
37
1202
        let element_to_block_offset = (0..num_of_elements).collect();
38

            
39
1202
        BlockPartition {
40
1202
            elements,
41
1202
            element_to_block,
42
1202
            element_offset: element_to_block_offset,
43
1202
            blocks,
44
1202
        }
45
1202
    }
46

            
47
    /// Partition the elements of the given block into multiple new blocks based
48
    /// on the given partitioner; which returns a number for each marked
49
    /// element. Elements with the same number belong to the same block, and the
50
    /// returned numbers should be dense.
51
    ///
52
    /// Returns an iterator over the new block indices, where the first element
53
    /// is the index of the block that was partitioned. And that block is the
54
    /// largest block.
55
150428
    pub fn partition_marked_with<F>(
56
150428
        &mut self,
57
150428
        block_index: BlockIndex,
58
150428
        builder: &mut BlockPartitionBuilder,
59
150428
        mut partitioner: F,
60
150428
    ) -> impl Iterator<Item = BlockIndex> + use<F>
61
150428
    where
62
150428
        F: FnMut(StateIndex, &BlockPartition) -> BlockIndex,
63
    {
64
150428
        let block = self.blocks[block_index];
65
150428
        debug_assert!(
66
150428
            block.has_marked(),
67
            "Cannot partition marked elements of a block without marked elements"
68
        );
69

            
70
150428
        if block.len() == 1 {
71
            // Block only has one element, so trivially partitioned.
72
62698
            self.blocks[block_index].unmark_all();
73
            // Note that all the returned iterators MUST have the same type, but we cannot chain typed_index since Step is an unstable trait.
74
62698
            return (block_index.value()..=block_index.value())
75
62698
                .chain(0..0)
76
62698
                .map(BlockIndex::new);
77
87730
        }
78

            
79
        // Keeps track of the block index for every element in this block by index.
80
87730
        builder.index_to_block.clear();
81
87730
        builder.block_sizes.clear();
82
87730
        builder.old_elements.clear();
83

            
84
87730
        builder.index_to_block.resize(block.len_marked(), BlockIndex::new(0));
85

            
86
        // O(n log n) Loop through the marked elements in order (to maintain topological sorting)
87
87730
        builder.old_elements.extend(block.iter_marked(&self.elements));
88
87730
        builder.old_elements.sort_unstable();
89

            
90
        // O(n) Loop over marked elements to determine the number of the new block each element is in.
91
1086541
        for (element_index, &element) in builder.old_elements.iter().enumerate() {
92
1086541
            let number = partitioner(element, self);
93

            
94
1086541
            builder.index_to_block[element_index] = number;
95
1086541
            if number.value() + 1 > builder.block_sizes.len() {
96
291588
                builder.block_sizes.resize(number.value() + 1, 0);
97
794953
            }
98

            
99
1086541
            builder.block_sizes[number] += 1;
100
        }
101

            
102
        // Convert block sizes into block offsets.
103
87730
        let end_of_blocks = self.blocks.len();
104
87730
        let new_block_index = if block.has_unmarked() {
105
35979
            self.blocks.len()
106
        } else {
107
51751
            self.blocks.len() - 1
108
        };
109

            
110
291588
        let _ = builder.block_sizes.iter_mut().fold(0usize, |current, size| {
111
291588
            debug_assert!(*size > 0, "Partition is not dense, there are empty blocks");
112

            
113
291588
            let current = if current == 0 {
114
87730
                if block.has_unmarked() {
115
                    // Adapt the offsets of the current block to only include the unmarked elements.
116
35979
                    self.blocks[block_index] = Block::new_unmarked(block.begin, block.marked_split);
117

            
118
                    // Introduce a new block for the zero block.
119
35979
                    self.blocks
120
35979
                        .push(Block::new_unmarked(block.marked_split, block.marked_split + *size));
121
35979
                    block.marked_split
122
                } else {
123
                    // Use this as the zero block.
124
51751
                    self.blocks[block_index] = Block::new_unmarked(block.begin, block.begin + *size);
125
51751
                    block.begin
126
                }
127
            } else {
128
                // Introduce a new block for every other non-empty block.
129
203858
                self.blocks.push(Block::new_unmarked(current, current + *size));
130
203858
                current
131
            };
132

            
133
291588
            let offset = current + *size;
134
291588
            *size = current;
135
291588
            offset
136
291588
        });
137
87730
        let block_offsets = &mut builder.block_sizes;
138

            
139
1086541
        for (index, offset_block_index) in builder.index_to_block.iter().enumerate() {
140
            // Swap the element to the correct position.
141
1086541
            let element = builder.old_elements[index];
142
1086541
            self.elements[block_offsets[*offset_block_index]] = builder.old_elements[index];
143
1086541
            self.element_offset[element] = block_offsets[*offset_block_index];
144
1086541
            self.element_to_block[element] = if *offset_block_index == 0 && !block.has_unmarked() {
145
192986
                block_index
146
            } else {
147
893555
                BlockIndex::new(new_block_index + offset_block_index.value())
148
            };
149

            
150
            // Update the offset for this block.
151
1086541
            block_offsets[*offset_block_index] += 1;
152
        }
153

            
154
        // Swap the first block and the maximum sized block.
155
87730
        let max_block_index = (block_index.value()..=block_index.value())
156
87730
            .chain(end_of_blocks..self.blocks.len())
157
87730
            .map(BlockIndex::new)
158
327567
            .max_by_key(|block_index| self.block(*block_index).len())
159
87730
            .unwrap();
160
87730
        self.swap_blocks(block_index, max_block_index);
161

            
162
87730
        self.assert_consistent();
163

            
164
87730
        (block_index.value()..=block_index.value())
165
87730
            .chain(end_of_blocks..self.blocks.len())
166
87730
            .map(BlockIndex::new)
167
150428
    }
168

            
169
    /// Split the given block into two separate block based on the splitter
170
    /// predicate.
171
3
    pub fn split_marked(&mut self, block_index: usize, mut splitter: impl FnMut(StateIndex) -> bool) {
172
3
        let mut updated_block = self.blocks[block_index];
173
3
        let mut new_block: Option<Block> = None;
174

            
175
        // Loop over all elements, we use a while loop since the index stays the
176
        // same when a swap takes place.
177
3
        let mut element_index = updated_block.marked_split;
178
23
        while element_index < updated_block.end {
179
20
            let element = self.elements[element_index];
180
20
            if splitter(element) {
181
10
                match &mut new_block {
182
3
                    None => {
183
3
                        new_block = Some(Block::new_unmarked(updated_block.end - 1, updated_block.end));
184
3

            
185
3
                        // Swap the current element to the last place
186
3
                        self.swap_elements(element_index, updated_block.end - 1);
187
3
                        updated_block.end -= 1;
188
3
                    }
189
7
                    Some(new_block_index) => {
190
7
                        // Swap the current element to the beginning of the new block.
191
7
                        new_block_index.begin -= 1;
192
7
                        updated_block.end -= 1;
193
7

            
194
7
                        self.swap_elements(element_index, new_block_index.begin);
195
7
                    }
196
                }
197
10
            } else {
198
10
                // If no swap takes place consider the next index.
199
10
                element_index += 1;
200
10
            }
201
        }
202

            
203
3
        if let Some(new_block) = new_block
204
3
            && (updated_block.end - updated_block.begin) != 0
205
        {
206
            // A new block was introduced, so we need to update the current
207
            // block. Unless the current block is empty in which case
208
            // nothing changes.
209
2
            updated_block.unmark_all();
210
2
            self.blocks[block_index] = updated_block;
211

            
212
            // Introduce a new block for the split, containing only the new element.
213
2
            self.blocks.push(new_block);
214

            
215
            // Update the elements for the new block
216
7
            for element in new_block.iter(&self.elements) {
217
7
                self.element_to_block[element] = BlockIndex::new(self.blocks.len() - 1);
218
7
            }
219
1
        }
220

            
221
3
        self.assert_consistent();
222
3
    }
223

            
224
    /// Makes the marked elements closed under the silent closure of incoming
225
    /// tau-transitions within the current block.
226
95377
    pub fn mark_backward_closure(&mut self, block_index: BlockIndex, incoming_transitions: &IncomingTransitions) {
227
95377
        let block = self.blocks[block_index];
228
95377
        let mut it = block.end - 1;
229

            
230
        // First compute backwards silent transitive closure.
231
323680
        while it >= self.blocks[block_index].marked_split && self.blocks[block_index].has_unmarked() {
232
228303
            for transition in incoming_transitions.incoming_silent_transitions(self.elements[it]) {
233
54472
                if self.block_number(transition.from) == block_index {
234
30442
                    self.mark_element(transition.from);
235
30442
                }
236
            }
237

            
238
228303
            if it == 0 {
239
                break;
240
228303
            }
241

            
242
228303
            it -= 1;
243
        }
244
95377
    }
245

            
246
    /// Swaps the given blocks given by the indices.
247
87730
    pub fn swap_blocks(&mut self, left_index: BlockIndex, right_index: BlockIndex) {
248
87730
        if left_index == right_index {
249
            // Nothing to do.
250
39439
            return;
251
48291
        }
252

            
253
48291
        self.blocks.swap(left_index.value(), right_index.value());
254

            
255
201592
        for element in self.block(left_index).iter(&self.elements) {
256
201592
            self.element_to_block[element] = left_index;
257
201592
        }
258

            
259
84238
        for element in self.block(right_index).iter(&self.elements) {
260
84238
            self.element_to_block[element] = right_index;
261
84238
        }
262

            
263
48291
        self.assert_consistent();
264
87730
    }
265

            
266
    /// Marks the given element, such that it is returned by iter_marked.
267
760373
    pub fn mark_element(&mut self, element: StateIndex) {
268
760373
        let block_index = self.element_to_block[element];
269
760373
        let offset = self.element_offset[element];
270
760373
        let marked_split = self.blocks[block_index].marked_split;
271

            
272
760373
        if offset < marked_split {
273
535608
            // Element was not already marked.
274
535608
            self.swap_elements(offset, marked_split - 1);
275
535608
            self.blocks[block_index].marked_split -= 1;
276
537469
        }
277

            
278
760373
        self.blocks[block_index].assert_consistent();
279
760373
    }
280

            
281
    /// Returns true iff the given element has already been marked.
282
113468
    pub fn is_element_marked(&self, element: StateIndex) -> bool {
283
113468
        let block_index = self.element_to_block[element];
284
113468
        let offset = self.element_offset[element];
285
113468
        let marked_split = self.blocks[block_index].marked_split;
286

            
287
113468
        offset >= marked_split
288
113468
    }
289

            
290
    /// Return a reference to the given block.
291
1304494
    pub fn block(&self, block_index: BlockIndex) -> &Block {
292
1304494
        &self.blocks[block_index]
293
1304494
    }
294

            
295
    /// Returns the number of blocks in the partition.
296
301480
    pub fn num_of_blocks(&self) -> usize {
297
301480
        self.blocks.len()
298
301480
    }
299

            
300
    /// Returns an iterator over the elements of a given block.
301
330688
    pub fn iter_block(&self, block_index: BlockIndex) -> BlockIter<'_> {
302
330688
        BlockIter {
303
330688
            elements: &self.elements,
304
330688
            index: self.blocks[block_index].begin,
305
330688
            end: self.blocks[block_index].end,
306
330688
        }
307
330688
    }
308

            
309
    /// Swaps the elements at the given indices and updates the element_to_block
310
535618
    fn swap_elements(&mut self, left_index: usize, right_index: usize) {
311
535618
        self.elements.swap(left_index, right_index);
312
535618
        self.element_offset[self.elements[left_index]] = left_index;
313
535618
        self.element_offset[self.elements[right_index]] = right_index;
314
535618
    }
315

            
316
    /// Returns true iff the invariants of a partition hold
317
136024
    fn assert_consistent(&self) -> bool {
318
136024
        if cfg!(debug_assertions) {
319
136024
            let mut marked = vec![false; self.elements.len()];
320

            
321
62129787
            for block in &self.blocks {
322
258143777
                for element in block.iter(&self.elements) {
323
258143777
                    debug_assert!(
324
258143777
                        !marked[element],
325
                        "Partition {self}, element {element} belongs to multiple blocks"
326
                    );
327
258143777
                    marked[element] = true;
328
                }
329

            
330
62129787
                block.assert_consistent();
331
            }
332

            
333
            // Check that every element belongs to a block.
334
136024
            debug_assert!(
335
136024
                !marked.contains(&false),
336
                "Partition {self} contains elements that do not belong to a block"
337
            );
338

            
339
            // Check that it belongs to the block indicated by element_to_block
340
258143777
            for (current_element, block_index) in self.element_to_block.iter().enumerate() {
341
258143777
                debug_assert!(
342
258143777
                    self.blocks[block_index.value()]
343
258143777
                        .iter(&self.elements)
344
35327129898
                        .any(|element| element == current_element),
345
                    "Partition {self:?}, element {current_element} does not belong to block {block_index} as indicated by element_to_block"
346
                );
347

            
348
258143777
                let index = self.element_offset[current_element];
349
258143777
                debug_assert_eq!(
350
258143777
                    self.elements[index], current_element,
351
                    "Partition {self:?}, element {current_element} does not have the correct offset in the block"
352
                );
353
            }
354
        }
355

            
356
136024
        true
357
136024
    }
358
}
359

            
360
#[derive(Default)]
361
pub struct BlockPartitionBuilder {
362
    // Keeps track of the block index for every element in this block by index.
363
    index_to_block: Vec<BlockIndex>,
364

            
365
    /// Keeps track of the size of each block.
366
    block_sizes: Vec<usize>,
367

            
368
    /// Stores the old elements to perform the swaps safely.
369
    old_elements: Vec<StateIndex>,
370
}
371

            
372
impl Partition for BlockPartition {
373
4757937
    fn block_number(&self, element: StateIndex) -> BlockIndex {
374
4757937
        self.element_to_block[element.value()]
375
4757937
    }
376

            
377
200
    fn num_of_blocks(&self) -> usize {
378
200
        self.blocks.len()
379
200
    }
380

            
381
2702
    fn len(&self) -> usize {
382
2702
        self.elements.len()
383
2702
    }
384
}
385

            
386
impl fmt::Display for BlockPartition {
387
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
388
        let blocks_str = self.blocks.iter().format_with(", ", |block, f| {
389
            let elements = block
390
                .iter_unmarked(&self.elements)
391
                .map(|e| (e, false))
392
                .chain(block.iter_marked(&self.elements).map(|e| (e, true)))
393
                .format_with(", ", |(e, marked), f| {
394
                    if marked {
395
                        f(&format_args!("{}*", e))
396
                    } else {
397
                        f(&format_args!("{}", e))
398
                    }
399
                });
400

            
401
            f(&format_args!("{{{}}}", elements))
402
        });
403

            
404
        write!(f, "{{{}}}", blocks_str)
405
    }
406
}
407

            
408
/// A block stores a subset of the elements in a partition.
///
/// # Details
///
/// A block uses `begin`, `marked_split` and `end` indices to indicate the range
/// `begin`..`end` of elements in the partition. The elements in
/// `marked_split`..`end` are the marked elements, and the elements in
/// `begin`..`marked_split` are the unmarked ones. This is useful to be able to
/// split off new blocks cheaply.
///
/// Invariant: `begin` <= `marked_split` <= `end` && `begin` < `end`.
#[derive(Clone, Copy, Debug)]
pub struct Block {
    /// Position of the first element of the block.
    begin: usize,
    /// Position of the first marked element; equal to `end` when nothing is marked.
    marked_split: usize,
    /// One past the position of the last element of the block.
    end: usize,
}
424

            
425
impl Block {
426
    /// Creates a new block where every element is marked.
427
1202
    pub fn new(begin: usize, end: usize) -> Block {
428
1202
        debug_assert!(begin < end, "The range of this block is incorrect");
429

            
430
1202
        Block {
431
1202
            begin,
432
1202
            marked_split: begin,
433
1202
            end,
434
1202
        }
435
1202
    }
436

            
437
327570
    pub fn new_unmarked(begin: usize, end: usize) -> Block {
438
327570
        debug_assert!(begin < end, "The range {begin} to {end} of this block is incorrect");
439

            
440
327570
        Block {
441
327570
            begin,
442
327570
            marked_split: end,
443
327570
            end,
444
327570
        }
445
327570
    }
446

            
447
    /// Returns an iterator over the elements in this block.
448
320370148
    pub fn iter<'a>(&self, elements: &'a Vec<StateIndex>) -> BlockIter<'a> {
449
320370148
        BlockIter {
450
320370148
            elements,
451
320370148
            index: self.begin,
452
320370148
            end: self.end,
453
320370148
        }
454
320370148
    }
455

            
456
    /// Returns an iterator over the marked elements in this block.
457
87730
    pub fn iter_marked<'a>(&self, elements: &'a Vec<StateIndex>) -> BlockIter<'a> {
458
87730
        BlockIter {
459
87730
            elements,
460
87730
            index: self.marked_split,
461
87730
            end: self.end,
462
87730
        }
463
87730
    }
464

            
465
    /// Returns an iterator over the unmarked elements in this block.
466
    pub fn iter_unmarked<'a>(&self, elements: &'a Vec<StateIndex>) -> BlockIter<'a> {
467
        BlockIter {
468
            elements,
469
            index: self.begin,
470
            end: self.marked_split,
471
        }
472
    }
473

            
474
    /// Returns true iff the block has marked elements.
475
1030773
    pub fn has_marked(&self) -> bool {
476
1030773
        self.assert_consistent();
477

            
478
1030773
        self.marked_split < self.end
479
1030773
    }
480

            
481
    /// Returns true iff the block has unmarked elements.
482
724324
    pub fn has_unmarked(&self) -> bool {
483
724324
        self.assert_consistent();
484

            
485
724324
        self.begin < self.marked_split
486
724324
    }
487

            
488
    /// Returns the number of elements in the block.
489
477995
    pub fn len(&self) -> usize {
490
477995
        self.assert_consistent();
491

            
492
477995
        self.end - self.begin
493
477995
    }
494

            
495
    /// Returns true iff the block is empty.
496
    pub fn is_empty(&self) -> bool {
497
        self.assert_consistent();
498

            
499
        self.begin == self.end
500
    }
501

            
502
    /// Returns the number of marked elements in the block.
503
87730
    pub fn len_marked(&self) -> usize {
504
87730
        self.assert_consistent();
505

            
506
87730
        self.end - self.marked_split
507
87730
    }
508

            
509
    /// Unmark all elements in the block.
510
62700
    fn unmark_all(&mut self) {
511
62700
        self.marked_split = self.end;
512
62700
    }
513

            
514
    /// Returns true iff the block is consistent.
515
65210982
    fn assert_consistent(self) {
516
65210982
        debug_assert!(self.begin < self.end, "The range of block {self:?} is incorrect",);
517

            
518
65210982
        debug_assert!(
519
65210982
            self.begin <= self.marked_split,
520
            "The marked_split lies before the beginning of the block {self:?}"
521
        );
522

            
523
65210982
        debug_assert!(
524
65210982
            self.marked_split <= self.end,
525
            "The marked_split lies after the beginning of the block {self:?}"
526
        );
527
65210982
    }
528
}
529

            
530
pub struct BlockIter<'a> {
531
    elements: &'a Vec<StateIndex>,
532
    index: usize,
533
    end: usize,
534
}
535

            
536
impl Iterator for BlockIter<'_> {
537
    type Item = StateIndex;
538

            
539
35650067053
    fn next(&mut self) -> Option<Self::Item> {
540
35650067053
        if self.index < self.end {
541
35587513115
            let element = self.elements[self.index];
542
35587513115
            self.index += 1;
543
35587513115
            Some(element)
544
        } else {
545
62553938
            None
546
        }
547
35650067053
    }
548
}
549

            
550
#[cfg(test)]
mod tests {
    use merc_lts::StateIndex;
    use test_log::test;

    use merc_collections::BlockIndex;

    use crate::BlockPartition;
    use crate::BlockPartitionBuilder;

    #[test]
    fn test_block_partition_split() {
        let mut partition = BlockPartition::new(10);

        partition.split_marked(0, |element| element < 3);

        // The new block only has elements that satisfy the predicate.
        for element in partition.iter_block(BlockIndex::new(1)) {
            assert!(element < 3);
        }

        for element in partition.iter_block(BlockIndex::new(0)) {
            assert!(element >= 3);
        }

        for i in (0..10).map(StateIndex::new) {
            partition.mark_element(i);
        }

        partition.split_marked(0, |element| element < 7);
        for element in partition.iter_block(BlockIndex::new(2)) {
            assert!((3..7).contains(&element.value()));
        }

        for element in partition.iter_block(BlockIndex::new(0)) {
            assert!(element >= 7);
        }

        // Test the case where all elements belong to the split block.
        partition.split_marked(1, |element| element < 7);
    }

    #[test]
    fn test_block_partition_partitioning() {
        // Test the partitioning function for a fixed assignment of elements to blocks.
        let mut partition = BlockPartition::new(10);
        let mut builder = BlockPartitionBuilder::default();

        let new_blocks: Vec<usize> = partition
            .partition_marked_with(BlockIndex::new(0), &mut builder, |element, _| match element.value() {
                0..=1 => BlockIndex::new(0),
                2..=6 => BlockIndex::new(1),
                _ => BlockIndex::new(2),
            })
            .map(|block| block.value())
            .collect();

        // Block 0 is reused and afterwards holds the largest block {2, ..., 6}.
        assert_eq!(new_blocks, vec![0, 1, 2]);
        assert_eq!(partition.num_of_blocks(), 3);
        assert_eq!(partition.block(BlockIndex::new(0)).len(), 5);
        assert_eq!(partition.block(BlockIndex::new(1)).len(), 2);
        assert_eq!(partition.block(BlockIndex::new(2)).len(), 3);
        for element in partition.iter_block(BlockIndex::new(0)) {
            assert!((2..=6).contains(&element.value()));
        }
        for element in partition.iter_block(BlockIndex::new(1)) {
            assert!(element.value() <= 1);
        }
        for element in partition.iter_block(BlockIndex::new(2)) {
            assert!(element.value() >= 7);
        }

        // All resulting blocks are unmarked.
        for element in (0..10).map(StateIndex::new) {
            assert!(!partition.is_element_marked(element));
        }

        partition.mark_element(StateIndex::new(7));
        partition.mark_element(StateIndex::new(8));
        let new_blocks: Vec<usize> = partition
            .partition_marked_with(BlockIndex::new(2), &mut builder, |element, _| match element.value() {
                7 => BlockIndex::new(0),
                8 => BlockIndex::new(1),
                _ => BlockIndex::new(2),
            })
            .map(|block| block.value())
            .collect();

        // The unmarked element 9 stays in its own block, while 7 and 8 are split
        // into two new blocks; all three blocks are singletons.
        assert_eq!(new_blocks, vec![2, 3, 4]);
        assert_eq!(partition.num_of_blocks(), 5);
        for block in 2..5 {
            assert_eq!(partition.block(BlockIndex::new(block)).len(), 1);
        }

        let mut contents: Vec<usize> = (2..5)
            .flat_map(|block| partition.iter_block(BlockIndex::new(block)))
            .map(|element| element.value())
            .collect();
        contents.sort_unstable();
        assert_eq!(contents, vec![7, 8, 9]);

        for element in (0..10).map(StateIndex::new) {
            assert!(!partition.is_element_marked(element));
        }
    }
}