1use blake3::{
2    self,
3    guts::BLOCK_LEN,
4    platform::{Platform, MAX_SIMD_DEGREE},
5    IncrementCounter, OUT_LEN,
6};
7#[cfg(feature = "parallel")]
8use rayon::broadcast;
9
10use super::PowStrategy;
11
12#[derive(Clone, Copy)]
19pub struct Blake3PoW {
20    challenge: [u8; 32],
22    threshold: u64,
24    platform: Platform,
26    inputs: [[u8; BLOCK_LEN]; MAX_SIMD_DEGREE],
28    outputs: [u8; OUT_LEN * MAX_SIMD_DEGREE],
30}
31
32impl PowStrategy for Blake3PoW {
33    #[allow(clippy::cast_sign_loss)]
43    fn new(challenge: [u8; 32], bits: f64) -> Self {
44        assert_eq!(BLOCK_LEN, 64);
46        assert_eq!(OUT_LEN, 32);
48        assert!((0.0..60.0).contains(&bits), "bits must be smaller than 60");
50
51        let mut inputs = [[0u8; BLOCK_LEN]; MAX_SIMD_DEGREE];
53        for input in &mut inputs {
54            input[..32].copy_from_slice(&challenge);
55        }
56
57        Self {
58            challenge,
60            threshold: (64.0 - bits).exp2().ceil() as u64,
62            platform: Platform::detect(),
64            inputs,
66            outputs: [0; OUT_LEN * MAX_SIMD_DEGREE],
68        }
69    }
70
71    fn check(&mut self, nonce: u64) -> bool {
79        let mut hasher = blake3::Hasher::new();
81
82        hasher.update(&self.challenge);
84        hasher.update(&nonce.to_le_bytes());
86        hasher.update(&[0; 24]);
88
89        let mut hash = [0u8; 8];
91        hasher.finalize_xof().fill(&mut hash);
92
93        u64::from_le_bytes(hash) < self.threshold
95    }
96
97    #[cfg(not(feature = "parallel"))]
99    fn solve(&mut self) -> Option<u64> {
100        (0..)
101            .step_by(MAX_SIMD_DEGREE)
102            .find_map(|nonce| self.check_many(nonce))
103    }
104
105    #[cfg(feature = "parallel")]
111    fn solve(&mut self) -> Option<u64> {
112        use std::sync::atomic::{AtomicU64, Ordering};
113
114        let global_min = AtomicU64::new(u64::MAX);
117
118        let _ = broadcast(|ctx| {
120            let mut worker = *self;
122
123            let nonces = ((MAX_SIMD_DEGREE * ctx.index()) as u64..)
125                .step_by(MAX_SIMD_DEGREE * ctx.num_threads());
126
127            for nonce in nonces {
128                if nonce >= global_min.load(Ordering::Relaxed) {
133                    break;
134                }
135                if let Some(nonce) = worker.check_many(nonce) {
137                    global_min.fetch_min(nonce, Ordering::SeqCst);
140                    break;
141                }
142            }
143        });
144
145        match global_min.load(Ordering::SeqCst) {
147            u64::MAX => self.check(u64::MAX).then_some(u64::MAX),
148            nonce => Some(nonce),
149        }
150    }
151}
152
153impl Blake3PoW {
154    #[allow(clippy::unreadable_literal)]
156    const BLAKE3_IV: [u32; 8] = [
157        0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB,
158        0x5BE0CD19,
159    ];
160    const BLAKE3_FLAGS: u8 = 0x0B; fn check_many(&mut self, nonce: u64) -> Option<u64> {
167        for (i, input) in self.inputs.iter_mut().enumerate() {
169            let n = (nonce + i as u64).to_le_bytes();
171            input[32..40].copy_from_slice(&n);
172        }
173
174        let input_refs: [&[u8; BLOCK_LEN]; MAX_SIMD_DEGREE] =
176            std::array::from_fn(|i| &self.inputs[i]);
177
178        self.platform.hash_many::<BLOCK_LEN>(
180            &input_refs,
181            &Self::BLAKE3_IV,     0,                    IncrementCounter::No, Self::BLAKE3_FLAGS,   0,
186            0, &mut self.outputs,
188        );
189
190        for (i, chunk) in self.outputs.chunks_exact(OUT_LEN).enumerate() {
192            let hash = u64::from_le_bytes(chunk[..8].try_into().unwrap());
193            if hash < self.threshold {
194                return Some(nonce + i as u64);
195            }
196        }
197
198        None
200    }
201}
202
#[cfg(test)]
mod tests {
    use spongefish::{DefaultHash, DomainSeparator};

    use super::*;
    use crate::{
        ByteDomainSeparator, BytesToUnitDeserialize, BytesToUnitSerialize, PoWChallenge,
        PoWDomainSeparator,
    };

    /// Structurally different challenges: constant zero, all ones, constant 42, ascending and
    /// descending byte ramps.
    fn sample_challenges() -> Vec<[u8; 32]> {
        vec![
            [0u8; 32],
            [0xFF; 32],
            [42u8; 32],
            (0..32).collect::<Vec<u8>>().try_into().unwrap(),
            (0..32).rev().collect::<Vec<u8>>().try_into().unwrap(),
        ]
    }

    /// End-to-end prover/verifier round trip through the spongefish PoW API.
    #[test]
    fn test_pow_blake3() {
        const BITS: f64 = 10.0;

        let domain_separator = DomainSeparator::<DefaultHash>::new("the proof of work lottery 🎰")
            .add_bytes(1, "something")
            .challenge_pow("rolling dices");

        let mut prover = domain_separator.to_prover_state();
        prover.add_bytes(b"\0").expect("Invalid DomainSeparator");
        prover.challenge_pow::<Blake3PoW>(BITS).unwrap();

        let mut verifier = domain_separator.to_verifier_state(prover.narg_string());
        let byte = verifier.next_bytes::<1>().unwrap();
        assert_eq!(&byte, b"\0");
        verifier.challenge_pow::<Blake3PoW>(BITS).unwrap();
    }

    #[test]
    #[allow(clippy::cast_sign_loss)]
    fn test_new_pow_valid_bits() {
        for bits in [0.1, 10.0, 20.0, 40.0, 59.99] {
            let challenge = [1u8; 32];
            let pow = Blake3PoW::new(challenge, bits);
            let expected_threshold = (64.0 - bits).exp2().ceil() as u64;
            assert_eq!(pow.threshold, expected_threshold);
            assert_eq!(pow.challenge, challenge);
        }
    }

    #[test]
    #[should_panic]
    fn test_new_invalid_bits() {
        let _ = Blake3PoW::new([0u8; 32], 60.0);
    }

    /// Negative difficulties are outside `[0, 60)` and must be rejected.
    #[test]
    #[should_panic(expected = "bits must be smaller than 60")]
    fn test_new_negative_bits() {
        let _ = Blake3PoW::new([0u8; 32], -1.0);
    }

    /// NaN is not contained in any range and must be rejected.
    #[test]
    #[should_panic(expected = "bits must be smaller than 60")]
    fn test_new_nan_bits() {
        let _ = Blake3PoW::new([0u8; 32], f64::NAN);
    }

    /// The first solution `check_many` reports must pass `check`.
    /// Every earlier nonce in the same batch must fail it.
    #[test]
    fn test_check_function_basic() {
        let challenge = [0u8; 32];
        let mut pow = Blake3PoW::new(challenge, 8.0);
        for nonce in (0u64..10000).step_by(MAX_SIMD_DEGREE) {
            if let Some(solution) = pow.check_many(nonce) {
                assert!(pow.check(solution), "check() should match check_many()");
                assert!(
                    (nonce..solution).all(|n| !pow.check(n)),
                    "check_many() skipped an earlier nonce that check() accepts"
                );
                return;
            }
        }
        panic!("Expected at least one valid nonce under threshold using check_many");
    }

    /// `check_many(base)` must return exactly the first nonce in
    /// `base..base + MAX_SIMD_DEGREE` accepted by `check`, or `None` when there is none.
    /// This includes bases that are not multiples of `MAX_SIMD_DEGREE`.
    #[test]
    fn test_check_many_agrees_with_check() {
        for bits in [1.0, 4.0] {
            for challenge in sample_challenges() {
                let mut pow = Blake3PoW::new(challenge, bits);
                for base in [0u64, 7, 1_000, 1 << 40, u64::MAX / 2] {
                    let expected =
                        (base..base + MAX_SIMD_DEGREE as u64).find(|&n| pow.check(n));
                    assert_eq!(
                        pow.check_many(base),
                        expected,
                        "bits={bits}, base={base}, challenge={challenge:?}"
                    );
                }
            }
        }
    }

    #[cfg(not(feature = "parallel"))]
    #[test]
    fn test_solve_sequential() {
        let challenge = [2u8; 32];
        let mut pow = Blake3PoW::new(challenge, 10.0);
        let nonce = pow.solve().expect("Should find a nonce");
        assert!(pow.check(nonce), "Found nonce does not satisfy challenge");
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_solve_parallel() {
        let challenge = [3u8; 32];
        let mut pow = Blake3PoW::new(challenge, 10.0);
        let nonce = pow.solve().expect("Should find a nonce");
        assert!(pow.check(nonce), "Found nonce does not satisfy challenge");
    }

    #[test]
    fn test_different_challenges_consistency() {
        let bits = 8.0;
        for challenge in sample_challenges() {
            let mut pow = Blake3PoW::new(challenge, bits);
            let nonce = pow.solve().expect("Must find solution for low difficulty");
            assert!(pow.check(nonce));
        }
    }

    #[test]
    fn test_check_many_determinism() {
        let challenge = [42u8; 32];
        let mut pow1 = Blake3PoW::new(challenge, 10.0);
        let mut pow2 = Blake3PoW::new(challenge, 10.0);

        let n1 = pow1.check_many(0);
        let n2 = pow2.check_many(0);
        assert_eq!(n1, n2, "check_many should be deterministic");
    }

    #[test]
    #[allow(clippy::cast_sign_loss)]
    fn test_threshold_rounding_boundaries() {
        let c = [7u8; 32];
        let bits = 24.5;
        let pow = Blake3PoW::new(c, bits);
        let expected = (64.0 - bits).exp2().ceil() as u64;
        assert_eq!(pow.threshold, expected);
    }

    /// After a call, every lane holds the challenge followed by its own consecutive nonce.
    #[test]
    fn test_check_many_inserts_nonce_bytes() {
        let challenge = [0xAB; 32];
        let mut pow = Blake3PoW::new(challenge, 50.0);

        let base_nonce = 12_345_678;
        let _ = pow.check_many(base_nonce);

        for (i, input) in pow.inputs.iter().enumerate() {
            assert_eq!(&input[..32], &challenge);
            let expected_nonce = base_nonce + i as u64;
            let actual = u64::from_le_bytes(input[32..40].try_into().unwrap());
            assert_eq!(actual, expected_nonce);
        }
    }

    /// `solve` must return the globally smallest valid nonce. This is cross-checked against
    /// both a batched `check_many` scan and an independent scalar `check` scan.
    #[test]
    fn test_solve_returns_minimal_nonce() {
        let c = [123; 32];
        let mut pow = Blake3PoW::new(c, 10.0);
        let mut best = None;
        for nonce in (0..10000).step_by(MAX_SIMD_DEGREE) {
            if let Some(found) = pow.check_many(nonce) {
                best = Some(found);
                break;
            }
        }
        assert!(best.is_some(), "expected a solution below 10000 at 10 bits");

        let scalar_first = (0u64..10000).find(|&n| pow.check(n));
        assert_eq!(best, scalar_first, "batched and scalar scans disagree");

        let result = pow.solve();
        assert_eq!(result, best, "solve should return the first valid nonce");
    }

    /// Scans many batches at 16 bits. `found` counts batches containing at least one solution.
    #[test]
    fn stress_test_check_many_entropy() {
        let challenge = [42u8; 32];
        let mut pow = Blake3PoW::new(challenge, 16.0);

        let mut found = 0;
        for nonce in (0..1_000_000).step_by(MAX_SIMD_DEGREE) {
            if pow.check_many(nonce).is_some() {
                found += 1;
            }
        }

        assert!(found > 0, "Expected to find at least one solution");
    }
}