1
use rand::RngCore;
2
use rand::SeedableRng;
3
use rand::rngs::StdRng;
4

            
5
use crate::test_logger;
6

            
7
/// Constructs a random number generator that should be used in random tests. Prints its seed to the console for reproducibility.
8
57
pub fn random_test<F>(iterations: usize, mut test_function: F)
9
57
where
10
57
    F: FnMut(&mut StdRng),
11
{
12
57
    test_logger();
13

            
14
57
    if let Ok(seed_str) = std::env::var("MERC_SEED") {
15
        let seed = seed_str.parse::<u64>().expect("MERC_SEED must be a valid u64");
16
        println!("seed: {seed} (set by MERC_SEED)");
17
        let mut rng = StdRng::seed_from_u64(seed);
18
        for _ in 0..iterations {
19
            test_function(&mut rng);
20
        }
21
        return;
22
57
    }
23

            
24
57
    let seed: u64 = rand::random();
25
57
    println!("random seed: {seed} (use MERC_SEED=<seed> to set fixed seed)");
26
57
    let mut rng = StdRng::seed_from_u64(seed);
27

            
28
6520
    for _ in 0..iterations {
29
6520
        test_function(&mut rng);
30
6520
    }
31
57
}
32

            
33
1
pub fn random_test_threads<C, F, G>(iterations: usize, num_threads: usize, init_function: G, test_function: F)
34
1
where
35
1
    C: Send + 'static,
36
1
    F: Fn(&mut StdRng, &mut C) + Copy + Send + Sync + 'static,
37
1
    G: Fn() -> C,
38
{
39
1
    test_logger();
40

            
41
1
    let mut threads = vec![];
42

            
43
1
    let seed: u64 = rand::random();
44
1
    println!("seed: {seed}");
45
1
    let mut rng = StdRng::seed_from_u64(seed);
46

            
47
1
    for _ in 0..num_threads {
48
20
        let mut rng = StdRng::seed_from_u64(rng.next_u64());
49
20
        let mut init = init_function();
50
20
        threads.push(std::thread::spawn(move || {
51
100000
            for _ in 0..iterations {
52
100000
                test_function(&mut rng, &mut init);
53
100000
            }
54
20
        }));
55
    }
56

            
57
20
    for thread in threads {
58
20
        let _ = thread.join();
59
20
    }
60
1
}