Skip to main content

spongefish_pow/
blake3.rs

1use blake3::{
2    guts::BLOCK_LEN,
3    platform::{Platform, MAX_SIMD_DEGREE},
4    IncrementCounter, OUT_LEN,
5};
6#[cfg(feature = "parallel")]
7use rayon::broadcast;
8
9use super::PowStrategy;
10use crate::PoWSolution;
11
12/// A SIMD-accelerated BLAKE3-based proof-of-work engine.
13///
14/// This struct encapsulates the state needed to search for a nonce such that
15/// `BLAKE3(challenge || nonce)` is below a difficulty threshold.
16///
17/// It leverages `Platform::hash_many` for parallel hash evaluation using `MAX_SIMD_DEGREE` lanes.
18#[derive(Clone, Copy)]
19pub struct Blake3PoW {
20    /// The 32-byte challenge seed used as a prefix to every hash input.
21    challenge: [u8; 32],
22    /// Difficulty target: hashes must be less than this 64-bit threshold.
23    threshold: u64,
24    /// Platform-specific SIMD hashing backend selected at runtime.
25    platform: Platform,
26    /// SIMD batch of hash inputs, each 64 bytes (challenge + nonce).
27    inputs: [[u8; BLOCK_LEN]; MAX_SIMD_DEGREE],
28    /// SIMD batch of hash outputs (32 bytes each).
29    outputs: [u8; OUT_LEN * MAX_SIMD_DEGREE],
30}
31
32impl PowStrategy for Blake3PoW {
33    /// Create a new Blake3PoW instance with a given challenge and difficulty.
34    ///
35    /// The `bits` parameter controls the difficulty. A higher number means
36    /// lower probability of success per nonce. This function prepares the SIMD
37    /// input buffer with the challenge prefix and sets the internal threshold.
38    ///
39    /// # Panics
40    /// - If `bits` is not in the range [0.0, 60.0).
41    /// - If `BLOCK_LEN` or `OUT_LEN` do not match expected values.
42    #[allow(clippy::cast_sign_loss)]
43    fn new(challenge: [u8; 32], bits: f64) -> Self {
44        // BLAKE3 block size must be 64 bytes.
45        assert_eq!(BLOCK_LEN, 64);
46        // BLAKE3 output size must be 32 bytes.
47        assert_eq!(OUT_LEN, 32);
48        // Ensure the difficulty is within supported range.
49        assert!((0.0..60.0).contains(&bits), "bits must be smaller than 60");
50
51        // Prepare SIMD input buffer: fill each lane with the challenge prefix.
52        let mut inputs = [[0u8; BLOCK_LEN]; MAX_SIMD_DEGREE];
53        for input in &mut inputs {
54            input[..32].copy_from_slice(&challenge);
55        }
56
57        Self {
58            // Store challenge prefix.
59            challenge,
60            // Compute threshold: smaller means harder PoW.
61            threshold: (64.0 - bits).exp2().ceil() as u64,
62            // Detect SIMD platform (e.g., AVX2, NEON, etc).
63            platform: Platform::detect(),
64            // Pre-filled SIMD inputs (nonce injected later).
65            inputs,
66            // Zero-initialized output buffer for SIMD hashes.
67            outputs: [0; OUT_LEN * MAX_SIMD_DEGREE],
68        }
69    }
70
71    fn solution(&self, nonce: u64) -> PoWSolution {
72        PoWSolution {
73            challenge: self.challenge,
74            nonce,
75        }
76    }
77
78    /// Check if a given `nonce` satisfies the challenge.
79    ///
80    /// This uses the standard high-level BLAKE3 interface to ensure
81    /// full compatibility with reference implementations.
82    ///
83    /// A nonce is valid if the first 8 bytes of the hash output,
84    /// interpreted as a little-endian `u64`, are below the internal threshold.
85    fn check(&mut self, nonce: u64) -> bool {
86        // Create a new BLAKE3 hasher instance.
87        let mut hasher = blake3::Hasher::new();
88
89        // Feed the challenge prefix.
90        hasher.update(&self.challenge);
91        // Feed the nonce as little-endian bytes.
92        hasher.update(&nonce.to_le_bytes());
93        // Zero-extend the nonce to 32 bytes (challenge + nonce = full block).
94        hasher.update(&[0; 24]);
95
96        // Hash the input and extract the first 8 bytes.
97        let mut hash = [0u8; 8];
98        hasher.finalize_xof().fill(&mut hash);
99
100        // Check whether the result is below the threshold.
101        u64::from_le_bytes(hash) < self.threshold
102    }
103
104    /// Finds the minimal `nonce` that satisfies the challenge.
105    #[cfg(not(feature = "parallel"))]
106    fn solve(&mut self) -> Option<PoWSolution> {
107        (0..)
108            .step_by(MAX_SIMD_DEGREE)
109            .find_map(|nonce| self.check_many(nonce))
110            .map(|nonce| self.solution(nonce))
111    }
112
113    /// Search for the lowest `nonce` that satisfies the challenge using parallel threads.
114    ///
115    /// Each thread scans disjoint chunks of the nonce space in stride-sized steps.
116    /// The first thread to find a satisfying nonce updates a shared atomic minimum,
117    /// and all others check against it to avoid unnecessary work.
118    #[cfg(feature = "parallel")]
119    fn solve(&mut self) -> Option<PoWSolution> {
120        use std::sync::atomic::{AtomicU64, Ordering};
121
122        // Split the work across all available threads.
123        // Use atomics to find the unique deterministic lowest satisfying nonce.
124        let global_min = AtomicU64::new(u64::MAX);
125
126        // Spawn parallel workers using Rayon's broadcast.
127        let _ = broadcast(|ctx| {
128            // Copy the PoW instance for thread-local use.
129            let mut worker = *self;
130
131            // Each thread searches a distinct subset of nonces.
132            let nonces = ((MAX_SIMD_DEGREE * ctx.index()) as u64..)
133                .step_by(MAX_SIMD_DEGREE * ctx.num_threads());
134
135            for nonce in nonces {
136                // Skip work if another thread already found a lower valid nonce.
137                //
138                // Use relaxed ordering to eventually get notified of another thread's solution.
139                // (Propagation delay should be in the order of tens of nanoseconds.)
140                if nonce >= global_min.load(Ordering::Relaxed) {
141                    break;
142                }
143                // Check a batch of nonces starting from `nonce`.
144                if let Some(nonce) = worker.check_many(nonce) {
145                    // We found a solution, store it in the global_min.
146                    // Use fetch_min to solve race condition with simultaneous solutions.
147                    global_min.fetch_min(nonce, Ordering::SeqCst);
148                    break;
149                }
150            }
151        });
152
153        // Return the best found nonce, or fallback check on `u64::MAX`.
154        match global_min.load(Ordering::SeqCst) {
155            u64::MAX => self.check(u64::MAX).then(|| self.solution(u64::MAX)),
156            nonce => Some(self.solution(nonce)),
157        }
158    }
159}
160
161impl Blake3PoW {
162    /// Default Blake3 initialization vector. Copied here because it is not publicly exported.
163    #[allow(clippy::unreadable_literal)]
164    const BLAKE3_IV: [u32; 8] = [
165        0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB,
166        0x5BE0CD19,
167    ];
168    const BLAKE3_FLAGS: u8 = 0x0B; // CHUNK_START | CHUNK_END | ROOT
169
170    /// Check a SIMD-width batch of nonces starting at `nonce`.
171    ///
172    /// Returns the first nonce in the batch that satisfies the challenge threshold,
173    /// or `None` if none do.
174    fn check_many(&mut self, nonce: u64) -> Option<u64> {
175        // Fill each SIMD input block with the challenge + nonce suffix.
176        // If the batch would overflow `u64`, stop at the last representable nonce
177        // and zero the remaining lanes so they never get reported as valid.
178        let mut valid_lanes = MAX_SIMD_DEGREE;
179        for i in 0..MAX_SIMD_DEGREE {
180            let Some(batch_nonce) = nonce.checked_add(i as u64) else {
181                valid_lanes = i;
182                break;
183            };
184            self.inputs[i][32..40].copy_from_slice(&batch_nonce.to_le_bytes());
185        }
186        for input in &mut self.inputs[valid_lanes..] {
187            input[32..40].fill(0);
188        }
189
190        // Create references required by `hash_many`.
191        let input_refs: [&[u8; BLOCK_LEN]; MAX_SIMD_DEGREE] =
192            std::array::from_fn(|i| &self.inputs[i]);
193
194        // Perform parallel hashing over the input blocks.
195        self.platform.hash_many::<BLOCK_LEN>(
196            &input_refs,
197            &Self::BLAKE3_IV,     // Initialization vector
198            0,                    // Counter
199            IncrementCounter::No, // Do not increment counter
200            Self::BLAKE3_FLAGS,   // Default flags
201            0,
202            0, // No start/end flags
203            &mut self.outputs,
204        );
205
206        // Scan results and return the first nonce under the threshold.
207        for (i, chunk) in self
208            .outputs
209            .chunks_exact(OUT_LEN)
210            .take(valid_lanes)
211            .enumerate()
212        {
213            let hash = u64::from_le_bytes(chunk[..8].try_into().unwrap());
214            if hash < self.threshold {
215                return nonce.checked_add(i as u64);
216            }
217        }
218
219        // None of the batch satisfied the condition.
220        None
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{convenience::*, PoWGrinder};

    /// A handful of structurally different 32-byte challenges.
    fn sample_challenges() -> Vec<[u8; 32]> {
        vec![
            [0u8; 32],                                              // All zeroes
            [0xFF; 32],                                             // All ones
            [42u8; 32],                                             // Constant value
            (0..32).collect::<Vec<u8>>().try_into().unwrap(),       // Increasing
            (0..32).rev().collect::<Vec<u8>>().try_into().unwrap(), // Decreasing
        ]
    }

    #[test]
    fn test_pow_blake3() {
        const BITS: f64 = 10.0;

        for challenge in sample_challenges() {
            // Generate a proof-of-work solution
            let _solution =
                grind_pow::<Blake3PoW>(challenge, BITS).expect("Should find a valid solution");

            // Verify we can generate solutions consistently
            assert!(grind_pow::<Blake3PoW>(challenge, BITS).is_some());
        }

        // Test using the PoWGrinder directly
        let challenge = [42u8; 32];
        let mut grinder = PoWGrinder::<Blake3PoW>::new(challenge, BITS);
        let _solution = grinder.grind().expect("Should find a valid solution");
    }

    #[test]
    #[allow(clippy::cast_sign_loss)]
    fn test_new_pow_valid_bits() {
        for bits in [0.1, 10.0, 20.0, 40.0, 59.99] {
            let challenge = [1u8; 32];
            let pow = Blake3PoW::new(challenge, bits);
            let expected_threshold = (64.0 - bits).exp2().ceil() as u64;
            assert_eq!(pow.threshold, expected_threshold);
            assert_eq!(pow.challenge, challenge);
        }
    }

    #[test]
    #[should_panic]
    fn test_new_invalid_bits() {
        let _ = Blake3PoW::new([0u8; 32], 60.0);
    }

    #[test]
    #[should_panic]
    fn test_new_negative_bits() {
        let _ = Blake3PoW::new([0u8; 32], -1.0);
    }

    #[test]
    fn test_check_function_basic() {
        let challenge = [0u8; 32];
        let mut pow = Blake3PoW::new(challenge, 8.0);
        for nonce in (0u64..10000).step_by(MAX_SIMD_DEGREE) {
            if let Some(solution) = pow.check_many(nonce) {
                assert!(pow.check(solution), "check() should match check_many()");
                return;
            }
        }
        panic!("Expected at least one valid nonce under threshold using check_many");
    }

    #[test]
    fn test_check_many_agrees_with_check() {
        // At 4 bits roughly one nonce in 16 qualifies, so this range mixes
        // batches with and without hits. The SIMD path must report exactly the
        // first nonce that the scalar reference path accepts.
        let mut pow = Blake3PoW::new([9u8; 32], 4.0);
        for start in (0u64..4096).step_by(MAX_SIMD_DEGREE) {
            let expected = (start..start + MAX_SIMD_DEGREE as u64).find(|&n| pow.check(n));
            assert_eq!(pow.check_many(start), expected, "batch starting at {start}");
        }
    }

    #[test]
    fn test_check_many_handles_u64_tail() {
        let mut pow = Blake3PoW::new([0u8; 32], 50.0);
        let found = pow.check_many(u64::MAX - 1);

        // Only `u64::MAX - 1` and `u64::MAX` are real lanes in this batch.
        if let Some(nonce) = found {
            assert!(nonce >= u64::MAX - 1);
        }

        let lanes: Vec<u64> = pow
            .inputs
            .iter()
            .map(|input| u64::from_le_bytes(input[32..40].try_into().unwrap()))
            .collect();
        assert_eq!(lanes[0], u64::MAX - 1);
        if let Some(&second) = lanes.get(1) {
            assert_eq!(second, u64::MAX);
        }
        assert!(
            lanes.iter().skip(2).all(|&n| n == 0),
            "lanes past u64::MAX must be zeroed"
        );
    }

    #[cfg(not(feature = "parallel"))]
    #[test]
    fn test_solve_sequential() {
        let challenge = [2u8; 32];
        let mut pow = Blake3PoW::new(challenge, 10.0);
        let solution = pow.solve().expect("Should find a nonce");
        assert_eq!(solution.challenge, challenge);
        assert!(
            pow.check(solution.nonce),
            "Found nonce does not satisfy challenge"
        );
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_solve_parallel() {
        let challenge = [3u8; 32];
        let mut pow = Blake3PoW::new(challenge, 10.0);
        let solution = pow.solve().expect("Should find a solution");
        assert_eq!(solution.challenge, challenge);
        assert!(
            pow.check(solution.nonce),
            "Found nonce does not satisfy challenge"
        );
    }

    #[test]
    fn test_different_challenges_consistency() {
        let bits = 8.0;
        for challenge in sample_challenges() {
            let mut pow = Blake3PoW::new(challenge, bits);
            let solution = pow.solve().expect("Must find solution for low difficulty");
            assert!(
                pow.check(solution.nonce),
                "solve() returned a nonce that fails check()"
            );
        }
    }

    #[test]
    fn test_check_many_determinism() {
        let challenge = [42u8; 32];
        let mut pow1 = Blake3PoW::new(challenge, 10.0);
        let mut pow2 = Blake3PoW::new(challenge, 10.0);

        let n1 = pow1.check_many(0);
        let n2 = pow2.check_many(0);
        assert_eq!(n1, n2, "check_many should be deterministic");
    }

    #[test]
    #[allow(clippy::cast_sign_loss)]
    fn test_threshold_rounding_boundaries() {
        let c = [7u8; 32];
        let bits = 24.5;
        let pow = Blake3PoW::new(c, bits);
        let expected = (64.0 - bits).exp2().ceil() as u64;
        assert_eq!(pow.threshold, expected);
    }

    #[test]
    fn test_check_many_inserts_nonce_bytes() {
        let challenge = [0xAB; 32];
        let mut pow = Blake3PoW::new(challenge, 50.0);

        // Run check_many to populate nonces.
        let base_nonce = 12_345_678;
        let _ = pow.check_many(base_nonce);

        for (i, input) in pow.inputs.iter().enumerate() {
            // Confirm prefix is unchanged
            assert_eq!(&input[..32], &challenge);
            // Confirm suffix is the correct nonce bytes
            let expected_nonce = base_nonce + i as u64;
            let actual = u64::from_le_bytes(input[32..40].try_into().unwrap());
            assert_eq!(actual, expected_nonce);
        }
    }

    #[test]
    fn test_solve_returns_minimal_nonce() {
        let c = [123; 32];
        let mut pow = Blake3PoW::new(c, 10.0);
        let mut best_nonce = None;
        for nonce in (0..10000).step_by(MAX_SIMD_DEGREE) {
            if let Some(found) = pow.check_many(nonce) {
                best_nonce = Some(found);
                break;
            }
        }

        // Get solution from solve()
        let solution = pow.solve().expect("Should find a solution");

        // If we found a nonce manually, create the same solution and compare
        if let Some(nonce) = best_nonce {
            let expected_solution = pow.solution(nonce);

            assert_eq!(
                solution.nonce, expected_solution.nonce,
                "solve should return the first valid nonce"
            );
        }
    }

    #[test]
    fn stress_test_check_many_entropy() {
        let challenge = [42u8; 32];
        let mut pow = Blake3PoW::new(challenge, 16.0);

        let mut found = 0;
        for nonce in (0..1_000_000).step_by(MAX_SIMD_DEGREE) {
            if pow.check_many(nonce).is_some() {
                found += 1;
            }
        }

        // Should find some hits at low difficulty
        assert!(found > 0, "Expected to find at least one solution");
    }
}