1use blake3::{
2 self,
3 guts::BLOCK_LEN,
4 platform::{Platform, MAX_SIMD_DEGREE},
5 IncrementCounter, OUT_LEN,
6};
7#[cfg(feature = "parallel")]
8use rayon::broadcast;
9
10use super::PowStrategy;
11
12#[derive(Clone, Copy)]
19pub struct Blake3PoW {
20 challenge: [u8; 32],
22 threshold: u64,
24 platform: Platform,
26 inputs: [[u8; BLOCK_LEN]; MAX_SIMD_DEGREE],
28 outputs: [u8; OUT_LEN * MAX_SIMD_DEGREE],
30}
31
32impl PowStrategy for Blake3PoW {
33 #[allow(clippy::cast_sign_loss)]
43 fn new(challenge: [u8; 32], bits: f64) -> Self {
44 assert_eq!(BLOCK_LEN, 64);
46 assert_eq!(OUT_LEN, 32);
48 assert!((0.0..60.0).contains(&bits), "bits must be smaller than 60");
50
51 let mut inputs = [[0u8; BLOCK_LEN]; MAX_SIMD_DEGREE];
53 for input in &mut inputs {
54 input[..32].copy_from_slice(&challenge);
55 }
56
57 Self {
58 challenge,
60 threshold: (64.0 - bits).exp2().ceil() as u64,
62 platform: Platform::detect(),
64 inputs,
66 outputs: [0; OUT_LEN * MAX_SIMD_DEGREE],
68 }
69 }
70
71 fn check(&mut self, nonce: u64) -> bool {
79 let mut hasher = blake3::Hasher::new();
81
82 hasher.update(&self.challenge);
84 hasher.update(&nonce.to_le_bytes());
86 hasher.update(&[0; 24]);
88
89 let mut hash = [0u8; 8];
91 hasher.finalize_xof().fill(&mut hash);
92
93 u64::from_le_bytes(hash) < self.threshold
95 }
96
97 #[cfg(not(feature = "parallel"))]
99 fn solve(&mut self) -> Option<u64> {
100 (0..)
101 .step_by(MAX_SIMD_DEGREE)
102 .find_map(|nonce| self.check_many(nonce))
103 }
104
105 #[cfg(feature = "parallel")]
111 fn solve(&mut self) -> Option<u64> {
112 use std::sync::atomic::{AtomicU64, Ordering};
113
114 let global_min = AtomicU64::new(u64::MAX);
117
118 let _ = broadcast(|ctx| {
120 let mut worker = *self;
122
123 let nonces = ((MAX_SIMD_DEGREE * ctx.index()) as u64..)
125 .step_by(MAX_SIMD_DEGREE * ctx.num_threads());
126
127 for nonce in nonces {
128 if nonce >= global_min.load(Ordering::Relaxed) {
133 break;
134 }
135 if let Some(nonce) = worker.check_many(nonce) {
137 global_min.fetch_min(nonce, Ordering::SeqCst);
140 break;
141 }
142 }
143 });
144
145 match global_min.load(Ordering::SeqCst) {
147 u64::MAX => self.check(u64::MAX).then_some(u64::MAX),
148 nonce => Some(nonce),
149 }
150 }
151}
152
153impl Blake3PoW {
154 #[allow(clippy::unreadable_literal)]
156 const BLAKE3_IV: [u32; 8] = [
157 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB,
158 0x5BE0CD19,
159 ];
160 const BLAKE3_FLAGS: u8 = 0x0B; fn check_many(&mut self, nonce: u64) -> Option<u64> {
167 for (i, input) in self.inputs.iter_mut().enumerate() {
169 let n = (nonce + i as u64).to_le_bytes();
171 input[32..40].copy_from_slice(&n);
172 }
173
174 let input_refs: [&[u8; BLOCK_LEN]; MAX_SIMD_DEGREE] =
176 std::array::from_fn(|i| &self.inputs[i]);
177
178 self.platform.hash_many::<BLOCK_LEN>(
180 &input_refs,
181 &Self::BLAKE3_IV, 0, IncrementCounter::No, Self::BLAKE3_FLAGS, 0,
186 0, &mut self.outputs,
188 );
189
190 for (i, chunk) in self.outputs.chunks_exact(OUT_LEN).enumerate() {
192 let hash = u64::from_le_bytes(chunk[..8].try_into().unwrap());
193 if hash < self.threshold {
194 return Some(nonce + i as u64);
195 }
196 }
197
198 None
200 }
201}
202
#[cfg(test)]
mod tests {
    use spongefish::{DefaultHash, DomainSeparator};

    use super::*;
    use crate::{
        ByteDomainSeparator, BytesToUnitDeserialize, BytesToUnitSerialize, PoWChallenge,
        PoWDomainSeparator,
    };

    /// Structurally different challenges: constant bytes plus ascending and
    /// descending byte ramps.
    fn sample_challenges() -> Vec<[u8; 32]> {
        vec![
            [0u8; 32],
            [0xFF; 32],
            [42u8; 32],
            (0..32).collect::<Vec<u8>>().try_into().unwrap(),
            (0..32).rev().collect::<Vec<u8>>().try_into().unwrap(),
        ]
    }

    /// End to end: the prover solves the PoW challenge and the verifier accepts it.
    #[test]
    fn test_pow_blake3() {
        const BITS: f64 = 10.0;

        let domain_separator = DomainSeparator::<DefaultHash>::new("the proof of work lottery 🎰")
            .add_bytes(1, "something")
            .challenge_pow("rolling dices");

        let mut prover = domain_separator.to_prover_state();
        prover.add_bytes(b"\0").expect("Invalid DomainSeparator");
        prover.challenge_pow::<Blake3PoW>(BITS).unwrap();

        let mut verifier = domain_separator.to_verifier_state(prover.narg_string());
        let byte = verifier.next_bytes::<1>().unwrap();
        assert_eq!(&byte, b"\0");
        verifier.challenge_pow::<Blake3PoW>(BITS).unwrap();
    }

    #[test]
    #[allow(clippy::cast_sign_loss)]
    fn test_new_pow_valid_bits() {
        for bits in [0.1, 10.0, 20.0, 40.0, 59.99] {
            let challenge = [1u8; 32];
            let pow = Blake3PoW::new(challenge, bits);
            let expected_threshold = (64.0 - bits).exp2().ceil() as u64;
            assert_eq!(pow.threshold, expected_threshold);
            assert_eq!(pow.challenge, challenge);
        }
    }

    #[test]
    #[should_panic]
    fn test_new_invalid_bits() {
        let _ = Blake3PoW::new([0u8; 32], 60.0);
    }

    /// Negative difficulties lie outside `[0, 60)` and must be rejected.
    #[test]
    #[should_panic]
    fn test_new_negative_bits() {
        let _ = Blake3PoW::new([0u8; 32], -1.0);
    }

    /// NaN fails every range check and must be rejected.
    #[test]
    #[should_panic]
    fn test_new_nan_bits() {
        let _ = Blake3PoW::new([0u8; 32], f64::NAN);
    }

    #[test]
    fn test_check_function_basic() {
        let challenge = [0u8; 32];
        let mut pow = Blake3PoW::new(challenge, 8.0);
        for nonce in (0u64..10000).step_by(MAX_SIMD_DEGREE) {
            if let Some(solution) = pow.check_many(nonce) {
                assert!(pow.check(solution), "check() should match check_many()");
                return;
            }
        }
        panic!("Expected at least one valid nonce under threshold using check_many");
    }

    /// The SIMD batch path must produce the standard BLAKE3 digest of every
    /// lane's input block (this pins the IV and flags). Its answer must also
    /// agree with the scalar `check`, lane by lane.
    #[test]
    fn test_check_many_matches_reference_hash() {
        let lanes = MAX_SIMD_DEGREE as u64;
        for challenge in sample_challenges() {
            // Low difficulty, so several lanes per batch are expected to solve.
            let mut pow = Blake3PoW::new(challenge, 4.0);
            for base_nonce in [0, 1, 1 << 32, u64::MAX - lanes] {
                let found = pow.check_many(base_nonce);
                // Copy the scratch buffers out so that `check` (`&mut self`) can
                // be called while iterating over them.
                let inputs = pow.inputs;
                let outputs = pow.outputs;

                let mut expected = None;
                for (i, (input, digest)) in inputs
                    .iter()
                    .zip(outputs.chunks_exact(OUT_LEN))
                    .enumerate()
                {
                    let reference = ::blake3::hash(input);
                    assert_eq!(
                        digest,
                        reference.as_bytes().as_slice(),
                        "lane {i} digest differs from blake3::hash"
                    );

                    let lane_nonce = base_nonce + i as u64;
                    if pow.check(lane_nonce) && expected.is_none() {
                        expected = Some(lane_nonce);
                    }
                }
                assert_eq!(found, expected, "check_many disagrees with check");
            }
        }
    }

    #[cfg(not(feature = "parallel"))]
    #[test]
    fn test_solve_sequential() {
        let challenge = [2u8; 32];
        let mut pow = Blake3PoW::new(challenge, 10.0);
        let nonce = pow.solve().expect("Should find a nonce");
        assert!(pow.check(nonce), "Found nonce does not satisfy challenge");
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_solve_parallel() {
        let challenge = [3u8; 32];
        let mut pow = Blake3PoW::new(challenge, 10.0);
        let nonce = pow.solve().expect("Should find a nonce");
        assert!(pow.check(nonce), "Found nonce does not satisfy challenge");
    }

    #[test]
    fn test_different_challenges_consistency() {
        let bits = 8.0;
        for challenge in sample_challenges() {
            let mut pow = Blake3PoW::new(challenge, bits);
            let nonce = pow.solve().expect("Must find solution for low difficulty");
            assert!(pow.check(nonce));
        }
    }

    #[test]
    fn test_check_many_determinism() {
        let challenge = [42u8; 32];
        let mut pow1 = Blake3PoW::new(challenge, 10.0);
        let mut pow2 = Blake3PoW::new(challenge, 10.0);

        let n1 = pow1.check_many(0);
        let n2 = pow2.check_many(0);
        assert_eq!(n1, n2, "check_many should be deterministic");
        assert_eq!(pow1.outputs, pow2.outputs, "batch digests should be deterministic");
    }

    #[test]
    #[allow(clippy::cast_sign_loss)]
    fn test_threshold_rounding_boundaries() {
        let c = [7u8; 32];
        let bits = 24.5;
        let pow = Blake3PoW::new(c, bits);
        let expected = (64.0 - bits).exp2().ceil() as u64;
        assert_eq!(pow.threshold, expected);
    }

    #[test]
    fn test_check_many_inserts_nonce_bytes() {
        let challenge = [0xAB; 32];
        let mut pow = Blake3PoW::new(challenge, 50.0);

        let base_nonce = 12_345_678;
        let _ = pow.check_many(base_nonce);

        for (i, input) in pow.inputs.iter().enumerate() {
            assert_eq!(&input[..32], &challenge);
            let expected_nonce = base_nonce + i as u64;
            let actual = u64::from_le_bytes(input[32..40].try_into().unwrap());
            assert_eq!(actual, expected_nonce);
            // The padding after the nonce must stay zero, so every lane hashes
            // exactly `challenge || nonce || [0; 24]`.
            assert!(
                input[40..].iter().all(|&b| b == 0),
                "lane {i} padding is not zero"
            );
        }
    }

    #[test]
    fn test_solve_returns_minimal_nonce() {
        let c = [123; 32];
        let mut pow = Blake3PoW::new(c, 10.0);
        let mut best = None;
        for nonce in (0..10000).step_by(MAX_SIMD_DEGREE) {
            if let Some(found) = pow.check_many(nonce) {
                best = Some(found);
                break;
            }
        }
        let result = pow.solve();
        assert_eq!(result, best, "solve should return the first valid nonce");
    }

    #[test]
    fn stress_test_check_many_entropy() {
        let challenge = [42u8; 32];
        let mut pow = Blake3PoW::new(challenge, 16.0);

        // `check_many` reports at most one solution per batch, so count batches.
        let mut batches = 0usize;
        let mut batches_with_solution = 0usize;
        for nonce in (0..1_000_000).step_by(MAX_SIMD_DEGREE) {
            batches += 1;
            if pow.check_many(nonce).is_some() {
                batches_with_solution += 1;
            }
        }

        assert!(batches_with_solution > 0, "Expected to find at least one solution");
        // At 16 bits only about 1 in 65536 nonces solves. A threshold that
        // accepted everything would show up as every batch succeeding.
        assert!(
            batches_with_solution < batches,
            "every batch contained a solution; the threshold is not filtering"
        );
    }
}