// spongefish_pow/blake3.rs

1use blake3::{
2    guts::BLOCK_LEN,
3    platform::{Platform, MAX_SIMD_DEGREE},
4    IncrementCounter, OUT_LEN,
5};
6#[cfg(feature = "parallel")]
7use rayon::broadcast;
8
9use super::PowStrategy;
10use crate::PoWSolution;
11
12/// A SIMD-accelerated BLAKE3-based proof-of-work engine.
13///
14/// This struct encapsulates the state needed to search for a nonce such that
15/// `BLAKE3(challenge || nonce)` is below a difficulty threshold.
16///
17/// It leverages `Platform::hash_many` for parallel hash evaluation using `MAX_SIMD_DEGREE` lanes.
18#[derive(Clone, Copy)]
19pub struct Blake3PoW {
20    /// The 32-byte challenge seed used as a prefix to every hash input.
21    challenge: [u8; 32],
22    /// Difficulty target: hashes must be less than this 64-bit threshold.
23    threshold: u64,
24    /// Platform-specific SIMD hashing backend selected at runtime.
25    platform: Platform,
26    /// SIMD batch of hash inputs, each 64 bytes (challenge + nonce).
27    inputs: [[u8; BLOCK_LEN]; MAX_SIMD_DEGREE],
28    /// SIMD batch of hash outputs (32 bytes each).
29    outputs: [u8; OUT_LEN * MAX_SIMD_DEGREE],
30}
31
32impl PowStrategy for Blake3PoW {
33    /// Create a new Blake3PoW instance with a given challenge and difficulty.
34    ///
35    /// The `bits` parameter controls the difficulty. A higher number means
36    /// lower probability of success per nonce. This function prepares the SIMD
37    /// input buffer with the challenge prefix and sets the internal threshold.
38    ///
39    /// # Panics
40    /// - If `bits` is not in the range [0.0, 60.0).
41    /// - If `BLOCK_LEN` or `OUT_LEN` do not match expected values.
42    #[allow(clippy::cast_sign_loss)]
43    fn new(challenge: [u8; 32], bits: f64) -> Self {
44        // BLAKE3 block size must be 64 bytes.
45        assert_eq!(BLOCK_LEN, 64);
46        // BLAKE3 output size must be 32 bytes.
47        assert_eq!(OUT_LEN, 32);
48        // Ensure the difficulty is within supported range.
49        assert!((0.0..60.0).contains(&bits), "bits must be smaller than 60");
50
51        // Prepare SIMD input buffer: fill each lane with the challenge prefix.
52        let mut inputs = [[0u8; BLOCK_LEN]; MAX_SIMD_DEGREE];
53        for input in &mut inputs {
54            input[..32].copy_from_slice(&challenge);
55        }
56
57        Self {
58            // Store challenge prefix.
59            challenge,
60            // Compute threshold: smaller means harder PoW.
61            threshold: (64.0 - bits).exp2().ceil() as u64,
62            // Detect SIMD platform (e.g., AVX2, NEON, etc).
63            platform: Platform::detect(),
64            // Pre-filled SIMD inputs (nonce injected later).
65            inputs,
66            // Zero-initialized output buffer for SIMD hashes.
67            outputs: [0; OUT_LEN * MAX_SIMD_DEGREE],
68        }
69    }
70
71    fn solution(&self, nonce: u64) -> PoWSolution {
72        PoWSolution {
73            challenge: self.challenge,
74            nonce,
75        }
76    }
77
78    /// Check if a given `nonce` satisfies the challenge.
79    ///
80    /// This uses the standard high-level BLAKE3 interface to ensure
81    /// full compatibility with reference implementations.
82    ///
83    /// A nonce is valid if the first 8 bytes of the hash output,
84    /// interpreted as a little-endian `u64`, are below the internal threshold.
85    fn check(&mut self, nonce: u64) -> bool {
86        // Create a new BLAKE3 hasher instance.
87        let mut hasher = blake3::Hasher::new();
88
89        // Feed the challenge prefix.
90        hasher.update(&self.challenge);
91        // Feed the nonce as little-endian bytes.
92        hasher.update(&nonce.to_le_bytes());
93        // Zero-extend the nonce to 32 bytes (challenge + nonce = full block).
94        hasher.update(&[0; 24]);
95
96        // Hash the input and extract the first 8 bytes.
97        let mut hash = [0u8; 8];
98        hasher.finalize_xof().fill(&mut hash);
99
100        // Check whether the result is below the threshold.
101        u64::from_le_bytes(hash) < self.threshold
102    }
103
104    /// Finds the minimal `nonce` that satisfies the challenge.
105    #[cfg(not(feature = "parallel"))]
106    fn solve(&mut self) -> Option<PoWSolution> {
107        (0..)
108            .step_by(MAX_SIMD_DEGREE)
109            .find_map(|nonce| self.check_many(nonce))
110            .map(|nonce| self.solution(nonce))
111    }
112
113    /// Search for the lowest `nonce` that satisfies the challenge using parallel threads.
114    ///
115    /// Each thread scans disjoint chunks of the nonce space in stride-sized steps.
116    /// The first thread to find a satisfying nonce updates a shared atomic minimum,
117    /// and all others check against it to avoid unnecessary work.
118    #[cfg(feature = "parallel")]
119    fn solve(&mut self) -> Option<PoWSolution> {
120        use std::sync::atomic::{AtomicU64, Ordering};
121
122        // Split the work across all available threads.
123        // Use atomics to find the unique deterministic lowest satisfying nonce.
124        let global_min = AtomicU64::new(u64::MAX);
125
126        // Spawn parallel workers using Rayon's broadcast.
127        let _ = broadcast(|ctx| {
128            // Copy the PoW instance for thread-local use.
129            let mut worker = *self;
130
131            // Each thread searches a distinct subset of nonces.
132            let nonces = ((MAX_SIMD_DEGREE * ctx.index()) as u64..)
133                .step_by(MAX_SIMD_DEGREE * ctx.num_threads());
134
135            for nonce in nonces {
136                // Skip work if another thread already found a lower valid nonce.
137                //
138                // Use relaxed ordering to eventually get notified of another thread's solution.
139                // (Propagation delay should be in the order of tens of nanoseconds.)
140                if nonce >= global_min.load(Ordering::Relaxed) {
141                    break;
142                }
143                // Check a batch of nonces starting from `nonce`.
144                if let Some(nonce) = worker.check_many(nonce) {
145                    // We found a solution, store it in the global_min.
146                    // Use fetch_min to solve race condition with simultaneous solutions.
147                    global_min.fetch_min(nonce, Ordering::SeqCst);
148                    break;
149                }
150            }
151        });
152
153        // Return the best found nonce, or fallback check on `u64::MAX`.
154        match global_min.load(Ordering::SeqCst) {
155            u64::MAX => self.check(u64::MAX).then(|| self.solution(u64::MAX)),
156            nonce => Some(self.solution(nonce)),
157        }
158    }
159}
160
161impl Blake3PoW {
162    /// Default Blake3 initialization vector. Copied here because it is not publicly exported.
163    #[allow(clippy::unreadable_literal)]
164    const BLAKE3_IV: [u32; 8] = [
165        0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB,
166        0x5BE0CD19,
167    ];
168    const BLAKE3_FLAGS: u8 = 0x0B; // CHUNK_START | CHUNK_END | ROOT
169
170    /// Check a SIMD-width batch of nonces starting at `nonce`.
171    ///
172    /// Returns the first nonce in the batch that satisfies the challenge threshold,
173    /// or `None` if none do.
174    fn check_many(&mut self, nonce: u64) -> Option<u64> {
175        // Fill each SIMD input block with the challenge + nonce suffix.
176        // If the batch would overflow `u64`, stop at the last representable nonce
177        // and zero the remaining lanes so they never get reported as valid.
178        let mut valid_lanes = MAX_SIMD_DEGREE;
179        for i in 0..MAX_SIMD_DEGREE {
180            let Some(batch_nonce) = nonce.checked_add(i as u64) else {
181                valid_lanes = i;
182                break;
183            };
184            self.inputs[i][32..40].copy_from_slice(&batch_nonce.to_le_bytes());
185        }
186        for input in &mut self.inputs[valid_lanes..] {
187            input[32..40].fill(0);
188        }
189
190        // Create references required by `hash_many`.
191        let input_refs: [&[u8; BLOCK_LEN]; MAX_SIMD_DEGREE] =
192            std::array::from_fn(|i| &self.inputs[i]);
193
194        // Perform parallel hashing over the input blocks.
195        self.platform.hash_many::<BLOCK_LEN>(
196            &input_refs,
197            &Self::BLAKE3_IV,     // Initialization vector
198            0,                    // Counter
199            IncrementCounter::No, // Do not increment counter
200            Self::BLAKE3_FLAGS,   // Default flags
201            0,
202            0, // No start/end flags
203            &mut self.outputs,
204        );
205
206        // Scan results and return the first nonce under the threshold.
207        for (i, chunk) in self
208            .outputs
209            .as_chunks::<OUT_LEN>()
210            .0
211            .iter()
212            .take(valid_lanes)
213            .enumerate()
214        {
215            let hash = u64::from_le_bytes(chunk[..8].try_into().unwrap());
216            if hash < self.threshold {
217                return nonce.checked_add(i as u64);
218            }
219        }
220
221        // None of the batch satisfied the condition.
222        None
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{convenience::*, PoWGrinder};

    /// A handful of structurally different 32-byte challenges.
    fn sample_challenges() -> Vec<[u8; 32]> {
        vec![
            [0u8; 32],                                              // All zeroes
            [0xFF; 32],                                             // All ones
            [42u8; 32],                                             // Constant value
            (0..32).collect::<Vec<u8>>().try_into().unwrap(),       // Increasing
            (0..32).rev().collect::<Vec<u8>>().try_into().unwrap(), // Decreasing
        ]
    }

    #[test]
    fn test_pow_blake3() {
        const BITS: f64 = 10.0;

        for challenge in sample_challenges() {
            // Generate a proof-of-work solution
            let _solution =
                grind_pow::<Blake3PoW>(challenge, BITS).expect("Should find a valid solution");

            // Verify we can generate solutions consistently
            assert!(grind_pow::<Blake3PoW>(challenge, BITS).is_some());
        }

        // Test using the PoWGrinder directly
        let challenge = [42u8; 32];
        let mut grinder = PoWGrinder::<Blake3PoW>::new(challenge, BITS);
        let _solution = grinder.grind().expect("Should find a valid solution");
    }

    #[test]
    #[allow(clippy::cast_sign_loss)]
    fn test_new_pow_valid_bits() {
        for bits in [0.1, 10.0, 20.0, 40.0, 59.99] {
            let challenge = [1u8; 32];
            let pow = Blake3PoW::new(challenge, bits);
            let expected_threshold = (64.0 - bits).exp2().ceil() as u64;
            assert_eq!(pow.threshold, expected_threshold);
            assert_eq!(pow.challenge, challenge);
        }
    }

    #[test]
    #[should_panic]
    fn test_new_invalid_bits() {
        let _ = Blake3PoW::new([0u8; 32], 60.0);
    }

    #[test]
    fn test_check_function_basic() {
        let challenge = [0u8; 32];
        let mut pow = Blake3PoW::new(challenge, 8.0);
        for nonce in (0u64..10000).step_by(MAX_SIMD_DEGREE) {
            if let Some(solution) = pow.check_many(nonce) {
                assert!(pow.check(solution), "check() should match check_many()");
                return;
            }
        }
        panic!("Expected at least one valid nonce under threshold using check_many");
    }

    #[test]
    fn test_check_many_handles_u64_tail() {
        let mut pow = Blake3PoW::new([0u8; 32], 50.0);
        let _ = pow.check_many(u64::MAX - 1);
        assert_eq!(
            u64::from_le_bytes(pow.inputs[0][32..40].try_into().unwrap()),
            u64::MAX - 1
        );
        // Lane 2 would hold `u64::MAX + 1`, so it must have been zeroed instead.
        if MAX_SIMD_DEGREE > 2 {
            assert_eq!(
                u64::from_le_bytes(pow.inputs[2][32..40].try_into().unwrap()),
                0
            );
        }
    }

    #[cfg(not(feature = "parallel"))]
    #[test]
    fn test_solve_sequential() {
        let challenge = [2u8; 32];
        let mut pow = Blake3PoW::new(challenge, 10.0);
        // `solve` returns a `PoWSolution`; `check` wants the raw nonce.
        let solution = pow.solve().expect("Should find a nonce");
        assert!(
            pow.check(solution.nonce),
            "Found nonce does not satisfy challenge"
        );
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_solve_parallel() {
        let challenge = [3u8; 32];
        let mut pow = Blake3PoW::new(challenge, 10.0);
        let solution = pow.solve().expect("Should find a solution");
        assert!(
            pow.check(solution.nonce),
            "Found nonce does not satisfy challenge"
        );
    }

    #[test]
    fn test_different_challenges_consistency() {
        let bits = 8.0;
        for challenge in sample_challenges() {
            let mut pow = Blake3PoW::new(challenge, bits);
            let solution = pow.solve().expect("Must find solution for low difficulty");
            assert!(
                pow.check(solution.nonce),
                "Found nonce does not satisfy challenge"
            );
        }
    }

    #[test]
    fn test_check_many_determinism() {
        let challenge = [42u8; 32];
        let mut pow1 = Blake3PoW::new(challenge, 10.0);
        let mut pow2 = Blake3PoW::new(challenge, 10.0);

        let n1 = pow1.check_many(0);
        let n2 = pow2.check_many(0);
        assert_eq!(n1, n2, "check_many should be deterministic");
    }

    #[test]
    #[allow(clippy::cast_sign_loss)]
    fn test_threshold_rounding_boundaries() {
        let c = [7u8; 32];
        let bits = 24.5;
        let pow = Blake3PoW::new(c, bits);
        let expected = (64.0 - bits).exp2().ceil() as u64;
        assert_eq!(pow.threshold, expected);
    }

    #[test]
    fn test_check_many_inserts_nonce_bytes() {
        let challenge = [0xAB; 32];
        let mut pow = Blake3PoW::new(challenge, 50.0);

        // Run check_many to populate nonces.
        let base_nonce = 12_345_678;
        let _ = pow.check_many(base_nonce);

        for (i, input) in pow.inputs.iter().enumerate() {
            // Confirm prefix is unchanged
            assert_eq!(&input[..32], &challenge);
            // Confirm suffix is the correct nonce bytes
            let expected_nonce = base_nonce + i as u64;
            let actual = u64::from_le_bytes(input[32..40].try_into().unwrap());
            assert_eq!(actual, expected_nonce);
        }
    }

    #[test]
    fn test_solve_returns_minimal_nonce() {
        let c = [123; 32];
        let mut pow = Blake3PoW::new(c, 10.0);

        // Reference answer: first hit of an ascending batch-by-batch scan.
        let mut best_nonce = None;
        for nonce in (0..10000).step_by(MAX_SIMD_DEGREE) {
            if let Some(found) = pow.check_many(nonce) {
                best_nonce = Some(found);
                break;
            }
        }

        // Get solution from solve()
        let solution = pow.solve().expect("Should find a solution");

        // If we found a nonce manually, create the same solution and compare
        if let Some(nonce) = best_nonce {
            let expected_solution = pow.solution(nonce);

            assert_eq!(
                solution.nonce, expected_solution.nonce,
                "solve should return the first valid nonce"
            );
        }
    }

    #[test]
    fn stress_test_check_many_entropy() {
        let challenge = [42u8; 32];
        let mut pow = Blake3PoW::new(challenge, 16.0);

        let mut found = 0;
        for nonce in (0..1_000_000).step_by(MAX_SIMD_DEGREE) {
            if pow.check_many(nonce).is_some() {
                found += 1;
            }
        }

        // Should find some hits at low difficulty
        assert!(found > 0, "Expected to find at least one solution");
    }
}