ark_std/
rand_helper.rs

1#[cfg(feature = "std")]
2use rand::RngCore;
3use rand::{
4    distributions::{Distribution, Standard},
5    prelude::StdRng,
6    Rng,
7};
8
9pub use rand;
10
11pub trait UniformRand: Sized {
12    fn rand<R: Rng + ?Sized>(rng: &mut R) -> Self;
13}
14
15impl<T> UniformRand for T
16where
17    Standard: Distribution<T>,
18{
19    #[inline]
20    fn rand<R: Rng + ?Sized>(rng: &mut R) -> Self {
21        rng.sample(Standard)
22    }
23}
24
25fn test_rng_helper() -> StdRng {
26    use rand::SeedableRng;
27    // arbitrary seed
28    let seed = [
29        1, 0, 0, 0, 23, 0, 0, 0, 200, 1, 0, 0, 210, 30, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
30        0, 0, 0, 0,
31    ];
32    rand::rngs::StdRng::from_seed(seed)
33}
34
/// Returns an RNG that should be used only for tests, not for any real world usage.
///
/// Without the `std` feature this always returns the fixed-seed generator built
/// by `test_rng_helper`, so every call yields the same output stream.
#[cfg(not(feature = "std"))]
pub fn test_rng() -> impl rand::Rng {
    test_rng_helper()
}
40
/// Returns an RNG that should be used only for tests, not for any real world usage.
///
/// When the `getrandom` feature is enabled (or when compiling this crate's own
/// tests), the generator is randomized via [`rand::thread_rng`], unless the
/// environment variable `DETERMINISTIC_TEST_RNG` is set to exactly `1`. In that
/// case the fixed-seed generator from `test_rng_helper` is returned instead.
/// Without `getrandom` (outside tests) the fixed-seed generator is always used.
#[cfg(feature = "std")]
pub fn test_rng() -> impl rand::Rng {
    #[cfg(any(feature = "getrandom", test))]
    {
        // Look up only the one variable we need. Iterating `std::env::vars()`
        // would panic if any *unrelated* variable held non-UTF-8 data. A
        // non-UTF-8 value of this variable simply reads as "not 1".
        let is_deterministic =
            matches!(std::env::var("DETERMINISTIC_TEST_RNG").as_deref(), Ok("1"));
        if is_deterministic {
            RngWrapper::Deterministic(test_rng_helper())
        } else {
            RngWrapper::Randomized(rand::thread_rng())
        }
    }
    #[cfg(not(any(feature = "getrandom", test)))]
    {
        RngWrapper::Deterministic(test_rng_helper())
    }
}
59
/// Helper wrapper that lets `test_rng` return a single concrete type behind
/// `impl rand::Rng`, whichever generator it selects at runtime.
#[cfg(feature = "std")]
enum RngWrapper {
    /// Fixed-seed generator; `test_rng` fills it from `test_rng_helper`.
    Deterministic(StdRng),
    /// Generator from `rand::thread_rng`; only compiled when the `getrandom`
    /// feature is enabled or when building this crate's tests.
    #[cfg(any(feature = "getrandom", test))]
    Randomized(rand::rngs::ThreadRng),
}
67
/// Forwards every `RngCore` operation to whichever generator is wrapped.
#[cfg(feature = "std")]
impl RngCore for RngWrapper {
    /// Next `u32` from the wrapped generator.
    #[inline(always)]
    fn next_u32(&mut self) -> u32 {
        match self {
            RngWrapper::Deterministic(inner) => RngCore::next_u32(inner),
            #[cfg(any(feature = "getrandom", test))]
            RngWrapper::Randomized(inner) => RngCore::next_u32(inner),
        }
    }

    /// Next `u64` from the wrapped generator.
    #[inline(always)]
    fn next_u64(&mut self) -> u64 {
        match self {
            RngWrapper::Deterministic(inner) => RngCore::next_u64(inner),
            #[cfg(any(feature = "getrandom", test))]
            RngWrapper::Randomized(inner) => RngCore::next_u64(inner),
        }
    }

    /// Fills `dest` with bytes from the wrapped generator.
    #[inline(always)]
    fn fill_bytes(&mut self, dest: &mut [u8]) {
        match self {
            RngWrapper::Deterministic(inner) => RngCore::fill_bytes(inner, dest),
            #[cfg(any(feature = "getrandom", test))]
            RngWrapper::Randomized(inner) => RngCore::fill_bytes(inner, dest),
        }
    }

    /// Fallible fill of `dest`; returns whatever the wrapped generator returns.
    #[inline(always)]
    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
        match self {
            RngWrapper::Deterministic(inner) => RngCore::try_fill_bytes(inner, dest),
            #[cfg(any(feature = "getrandom", test))]
            RngWrapper::Randomized(inner) => RngCore::try_fill_bytes(inner, dest),
        }
    }
}
106
#[cfg(all(test, feature = "std"))]
mod test {
    use super::*;

    /// Environment variable consulted by `test_rng`.
    const ENV_KEY: &str = "DETERMINISTIC_TEST_RNG";

    /// Restores `ENV_KEY` to the value captured by [`EnvGuard::capture`] when
    /// dropped. The test therefore leaves the process environment as it found
    /// it, even if an assertion panics.
    struct EnvGuard(Option<std::ffi::OsString>);

    impl EnvGuard {
        /// Records the current value (or absence) of `ENV_KEY`.
        fn capture() -> Self {
            EnvGuard(std::env::var_os(ENV_KEY))
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            match self.0.take() {
                Some(val) => std::env::set_var(ENV_KEY, val),
                None => std::env::remove_var(ENV_KEY),
            }
        }
    }

    // NOTE(review): this test mutates the process-wide environment. Other tests
    // calling `test_rng` concurrently may briefly observe the deterministic mode.
    #[test]
    fn test_deterministic_rng() {
        let _guard = EnvGuard::capture();

        // Start from a known state. If the caller exported
        // DETERMINISTIC_TEST_RNG=1, the randomized check below would fail.
        std::env::remove_var(ENV_KEY);
        let a = u128::rand(&mut test_rng());
        // Reset the rng by sampling a new one.
        let b = u128::rand(&mut test_rng());
        assert_ne!(a, b); // should be unequal with overwhelming probability.

        // Make the rng deterministic.
        std::env::set_var(ENV_KEY, "1");
        let a = u128::rand(&mut test_rng());
        // Reset the rng by sampling a new one.
        let b = u128::rand(&mut test_rng());
        assert_eq!(a, b); // both generators start from the same fixed seed.
    }
}
131}