spongefish_pow/
blake3.rs

1use blake3::{
2    self,
3    guts::BLOCK_LEN,
4    platform::{Platform, MAX_SIMD_DEGREE},
5    IncrementCounter, OUT_LEN,
6};
7#[cfg(feature = "parallel")]
8use rayon::broadcast;
9
10use super::PowStrategy;
11
12/// A SIMD-accelerated BLAKE3-based proof-of-work engine.
13///
14/// This struct encapsulates the state needed to search for a nonce such that
15/// `BLAKE3(challenge || nonce)` is below a difficulty threshold.
16///
17/// It leverages `Platform::hash_many` for parallel hash evaluation using `MAX_SIMD_DEGREE` lanes.
18#[derive(Clone, Copy)]
19pub struct Blake3PoW {
20    /// The 32-byte challenge seed used as a prefix to every hash input.
21    challenge: [u8; 32],
22    /// Difficulty target: hashes must be less than this 64-bit threshold.
23    threshold: u64,
24    /// Platform-specific SIMD hashing backend selected at runtime.
25    platform: Platform,
26    /// SIMD batch of hash inputs, each 64 bytes (challenge + nonce).
27    inputs: [[u8; BLOCK_LEN]; MAX_SIMD_DEGREE],
28    /// SIMD batch of hash outputs (32 bytes each).
29    outputs: [u8; OUT_LEN * MAX_SIMD_DEGREE],
30}
31
32impl PowStrategy for Blake3PoW {
33    /// Create a new Blake3PoW instance with a given challenge and difficulty.
34    ///
35    /// The `bits` parameter controls the difficulty. A higher number means
36    /// lower probability of success per nonce. This function prepares the SIMD
37    /// input buffer with the challenge prefix and sets the internal threshold.
38    ///
39    /// # Panics
40    /// - If `bits` is not in the range [0.0, 60.0).
41    /// - If `BLOCK_LEN` or `OUT_LEN` do not match expected values.
42    #[allow(clippy::cast_sign_loss)]
43    fn new(challenge: [u8; 32], bits: f64) -> Self {
44        // BLAKE3 block size must be 64 bytes.
45        assert_eq!(BLOCK_LEN, 64);
46        // BLAKE3 output size must be 32 bytes.
47        assert_eq!(OUT_LEN, 32);
48        // Ensure the difficulty is within supported range.
49        assert!((0.0..60.0).contains(&bits), "bits must be smaller than 60");
50
51        // Prepare SIMD input buffer: fill each lane with the challenge prefix.
52        let mut inputs = [[0u8; BLOCK_LEN]; MAX_SIMD_DEGREE];
53        for input in &mut inputs {
54            input[..32].copy_from_slice(&challenge);
55        }
56
57        Self {
58            // Store challenge prefix.
59            challenge,
60            // Compute threshold: smaller means harder PoW.
61            threshold: (64.0 - bits).exp2().ceil() as u64,
62            // Detect SIMD platform (e.g., AVX2, NEON, etc).
63            platform: Platform::detect(),
64            // Pre-filled SIMD inputs (nonce injected later).
65            inputs,
66            // Zero-initialized output buffer for SIMD hashes.
67            outputs: [0; OUT_LEN * MAX_SIMD_DEGREE],
68        }
69    }
70
71    /// Check if a given `nonce` satisfies the challenge.
72    ///
73    /// This uses the standard high-level BLAKE3 interface to ensure
74    /// full compatibility with reference implementations.
75    ///
76    /// A nonce is valid if the first 8 bytes of the hash output,
77    /// interpreted as a little-endian `u64`, are below the internal threshold.
78    fn check(&mut self, nonce: u64) -> bool {
79        // Create a new BLAKE3 hasher instance.
80        let mut hasher = blake3::Hasher::new();
81
82        // Feed the challenge prefix.
83        hasher.update(&self.challenge);
84        // Feed the nonce as little-endian bytes.
85        hasher.update(&nonce.to_le_bytes());
86        // Zero-extend the nonce to 32 bytes (challenge + nonce = full block).
87        hasher.update(&[0; 24]);
88
89        // Hash the input and extract the first 8 bytes.
90        let mut hash = [0u8; 8];
91        hasher.finalize_xof().fill(&mut hash);
92
93        // Check whether the result is below the threshold.
94        u64::from_le_bytes(hash) < self.threshold
95    }
96
97    /// Finds the minimal `nonce` that satisfies the challenge.
98    #[cfg(not(feature = "parallel"))]
99    fn solve(&mut self) -> Option<u64> {
100        (0..)
101            .step_by(MAX_SIMD_DEGREE)
102            .find_map(|nonce| self.check_many(nonce))
103    }
104
105    /// Search for the lowest `nonce` that satisfies the challenge using parallel threads.
106    ///
107    /// Each thread scans disjoint chunks of the nonce space in stride-sized steps.
108    /// The first thread to find a satisfying nonce updates a shared atomic minimum,
109    /// and all others check against it to avoid unnecessary work.
110    #[cfg(feature = "parallel")]
111    fn solve(&mut self) -> Option<u64> {
112        use std::sync::atomic::{AtomicU64, Ordering};
113
114        // Split the work across all available threads.
115        // Use atomics to find the unique deterministic lowest satisfying nonce.
116        let global_min = AtomicU64::new(u64::MAX);
117
118        // Spawn parallel workers using Rayon’s broadcast.
119        let _ = broadcast(|ctx| {
120            // Copy the PoW instance for thread-local use.
121            let mut worker = *self;
122
123            // Each thread searches a distinct subset of nonces.
124            let nonces = ((MAX_SIMD_DEGREE * ctx.index()) as u64..)
125                .step_by(MAX_SIMD_DEGREE * ctx.num_threads());
126
127            for nonce in nonces {
128                // Skip work if another thread already found a lower valid nonce.
129                //
130                // Use relaxed ordering to eventually get notified of another thread's solution.
131                // (Propagation delay should be in the order of tens of nanoseconds.)
132                if nonce >= global_min.load(Ordering::Relaxed) {
133                    break;
134                }
135                // Check a batch of nonces starting from `nonce`.
136                if let Some(nonce) = worker.check_many(nonce) {
137                    // We found a solution, store it in the global_min.
138                    // Use fetch_min to solve race condition with simultaneous solutions.
139                    global_min.fetch_min(nonce, Ordering::SeqCst);
140                    break;
141                }
142            }
143        });
144
145        // Return the best found nonce, or fallback check on `u64::MAX`.
146        match global_min.load(Ordering::SeqCst) {
147            u64::MAX => self.check(u64::MAX).then_some(u64::MAX),
148            nonce => Some(nonce),
149        }
150    }
151}
152
153impl Blake3PoW {
154    /// Default Blake3 initialization vector. Copied here because it is not publicly exported.
155    #[allow(clippy::unreadable_literal)]
156    const BLAKE3_IV: [u32; 8] = [
157        0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB,
158        0x5BE0CD19,
159    ];
160    const BLAKE3_FLAGS: u8 = 0x0B; // CHUNK_START | CHUNK_END | ROOT
161
162    /// Check a SIMD-width batch of nonces starting at `nonce`.
163    ///
164    /// Returns the first nonce in the batch that satisfies the challenge threshold,
165    /// or `None` if none do.
166    fn check_many(&mut self, nonce: u64) -> Option<u64> {
167        // Fill each SIMD input block with the challenge + nonce suffix.
168        for (i, input) in self.inputs.iter_mut().enumerate() {
169            // Write the nonce as little-endian into bytes 32..40.
170            let n = (nonce + i as u64).to_le_bytes();
171            input[32..40].copy_from_slice(&n);
172        }
173
174        // Create references required by `hash_many`.
175        let input_refs: [&[u8; BLOCK_LEN]; MAX_SIMD_DEGREE] =
176            std::array::from_fn(|i| &self.inputs[i]);
177
178        // Perform parallel hashing over the input blocks.
179        self.platform.hash_many::<BLOCK_LEN>(
180            &input_refs,
181            &Self::BLAKE3_IV,     // Initialization vector
182            0,                    // Counter
183            IncrementCounter::No, // Do not increment counter
184            Self::BLAKE3_FLAGS,   // Default flags
185            0,
186            0, // No start/end flags
187            &mut self.outputs,
188        );
189
190        // Scan results and return the first nonce under the threshold.
191        for (i, chunk) in self.outputs.chunks_exact(OUT_LEN).enumerate() {
192            let hash = u64::from_le_bytes(chunk[..8].try_into().unwrap());
193            if hash < self.threshold {
194                return Some(nonce + i as u64);
195            }
196        }
197
198        // None of the batch satisfied the condition.
199        None
200    }
201}
202
#[cfg(test)]
mod tests {
    use spongefish::{DefaultHash, DomainSeparator};

    use super::*;
    use crate::{
        ByteDomainSeparator, BytesToUnitDeserialize, BytesToUnitSerialize, PoWChallenge,
        PoWDomainSeparator,
    };

    /// A handful of structurally different 32-byte challenges.
    fn sample_challenges() -> Vec<[u8; 32]> {
        vec![
            [0u8; 32],                                              // All zeroes
            [0xFF; 32],                                             // All ones
            [42u8; 32],                                             // Constant value
            (0..32).collect::<Vec<u8>>().try_into().unwrap(),       // Increasing
            (0..32).rev().collect::<Vec<u8>>().try_into().unwrap(), // Decreasing
        ]
    }

    /// End-to-end: a prover's PoW transcript is accepted by the verifier.
    #[test]
    fn test_pow_blake3() {
        const BITS: f64 = 10.0;

        let domain_separator = DomainSeparator::<DefaultHash>::new("the proof of work lottery 🎰")
            .add_bytes(1, "something")
            .challenge_pow("rolling dices");

        let mut prover = domain_separator.to_prover_state();
        prover.add_bytes(b"\0").expect("Invalid DomainSeparator");
        prover.challenge_pow::<Blake3PoW>(BITS).unwrap();

        let mut verifier = domain_separator.to_verifier_state(prover.narg_string());
        let byte = verifier.next_bytes::<1>().unwrap();
        assert_eq!(&byte, b"\0");
        verifier.challenge_pow::<Blake3PoW>(BITS).unwrap();
    }

    #[test]
    #[allow(clippy::cast_sign_loss)]
    fn test_new_pow_valid_bits() {
        for bits in [0.1, 10.0, 20.0, 40.0, 59.99] {
            let challenge = [1u8; 32];
            let pow = Blake3PoW::new(challenge, bits);
            let expected_threshold = (64.0 - bits).exp2().ceil() as u64;
            assert_eq!(pow.threshold, expected_threshold);
            assert_eq!(pow.challenge, challenge);
        }
    }

    /// The upper bound is exclusive; pin the message so an unrelated panic can't pass.
    #[test]
    #[should_panic(expected = "bits must be smaller than 60")]
    fn test_new_invalid_bits() {
        let _ = Blake3PoW::new([0u8; 32], 60.0);
    }

    /// Negative difficulties are outside `[0.0, 60.0)` and must be rejected too.
    #[test]
    #[should_panic(expected = "bits must be smaller than 60")]
    fn test_new_negative_bits() {
        let _ = Blake3PoW::new([0u8; 32], -1.0);
    }

    #[test]
    fn test_check_function_basic() {
        let challenge = [0u8; 32];
        let mut pow = Blake3PoW::new(challenge, 8.0);
        for nonce in (0u64..10000).step_by(MAX_SIMD_DEGREE) {
            if let Some(solution) = pow.check_many(nonce) {
                assert!(pow.check(solution), "check() should match check_many()");
                return;
            }
        }
        panic!("Expected at least one valid nonce under threshold using check_many");
    }

    /// `check_many` must report exactly the first lane that the reference `check`
    /// accepts (or `None` when no lane passes), not merely some passing lane.
    #[test]
    fn test_check_many_agrees_with_check_per_lane() {
        // Low difficulty so that most batches contain several passing lanes.
        let mut pow = Blake3PoW::new([9u8; 32], 2.0);
        for base in (0u64..256).step_by(MAX_SIMD_DEGREE) {
            let expected = (base..base + MAX_SIMD_DEGREE as u64).find(|&n| pow.check(n));
            assert_eq!(pow.check_many(base), expected, "batch starting at {base}");
        }
    }

    #[cfg(not(feature = "parallel"))]
    #[test]
    fn test_solve_sequential() {
        let challenge = [2u8; 32];
        let mut pow = Blake3PoW::new(challenge, 10.0);
        let nonce = pow.solve().expect("Should find a nonce");
        assert!(pow.check(nonce), "Found nonce does not satisfy challenge");
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_solve_parallel() {
        let challenge = [3u8; 32];
        let mut pow = Blake3PoW::new(challenge, 10.0);
        let nonce = pow.solve().expect("Should find a nonce");
        assert!(pow.check(nonce), "Found nonce does not satisfy challenge");
    }

    #[test]
    fn test_different_challenges_consistency() {
        let bits = 8.0;
        for challenge in sample_challenges() {
            let mut pow = Blake3PoW::new(challenge, bits);
            let nonce = pow.solve().expect("Must find solution for low difficulty");
            assert!(pow.check(nonce));
        }
    }

    #[test]
    fn test_check_many_determinism() {
        let challenge = [42u8; 32];
        let mut pow1 = Blake3PoW::new(challenge, 10.0);
        let mut pow2 = Blake3PoW::new(challenge, 10.0);

        let n1 = pow1.check_many(0);
        let n2 = pow2.check_many(0);
        assert_eq!(n1, n2, "check_many should be deterministic");
    }

    #[test]
    #[allow(clippy::cast_sign_loss)]
    fn test_threshold_rounding_boundaries() {
        let c = [7u8; 32];
        let bits = 24.5;
        let pow = Blake3PoW::new(c, bits);
        let expected = (64.0 - bits).exp2().ceil() as u64;
        assert_eq!(pow.threshold, expected);
    }

    #[test]
    fn test_check_many_inserts_nonce_bytes() {
        let challenge = [0xAB; 32];
        let mut pow = Blake3PoW::new(challenge, 50.0);

        // Run check_many to populate nonces.
        let base_nonce = 12_345_678;
        let _ = pow.check_many(base_nonce);

        for (i, input) in pow.inputs.iter().enumerate() {
            // Confirm prefix is unchanged
            assert_eq!(&input[..32], &challenge);
            // Confirm suffix is the correct nonce bytes
            let expected_nonce = base_nonce + i as u64;
            let actual = u64::from_le_bytes(input[32..40].try_into().unwrap());
            assert_eq!(actual, expected_nonce);
        }
    }

    #[test]
    fn test_solve_returns_minimal_nonce() {
        let c = [123; 32];
        let mut pow = Blake3PoW::new(c, 10.0);
        let mut best = None;
        for nonce in (0..10000).step_by(MAX_SIMD_DEGREE) {
            if let Some(found) = pow.check_many(nonce) {
                best = Some(found);
                break;
            }
        }
        let result = pow.solve();
        assert_eq!(result, best, "solve should return the first valid nonce");
    }

    /// Independent oracle for minimality: a plain linear scan with the scalar
    /// reference `check`, which does not go through the SIMD batch path.
    #[test]
    fn test_solve_matches_reference_scan() {
        let mut pow = Blake3PoW::new([77u8; 32], 10.0);
        let expected = (0u64..).find(|&n| pow.check(n));
        assert_eq!(pow.solve(), expected, "solve must return the smallest valid nonce");
    }

    #[test]
    fn stress_test_check_many_entropy() {
        let challenge = [42u8; 32];
        let mut pow = Blake3PoW::new(challenge, 16.0);

        let mut found = 0;
        for nonce in (0..1_000_000).step_by(MAX_SIMD_DEGREE) {
            if pow.check_many(nonce).is_some() {
                found += 1;
            }
        }

        // Should find some hits at low difficulty
        assert!(found > 0, "Expected to find at least one solution");
    }
}