1use alloc::vec;
2use alloc::vec::Vec;
3
4use p3_symmetric::{CryptographicHasher, Hash, MerkleCap};
5
6use crate::{CanFinalizeDigest, CanObserve, CanSample};
7
8#[derive(Clone, Debug)]
10pub struct HashChallenger<T, H, const OUT_LEN: usize>
11where
12 T: Clone,
13 H: CryptographicHasher<T, [T; OUT_LEN]>,
14{
15 input_buffer: Vec<T>,
17 output_buffer: Vec<T>,
19 hasher: H,
21}
22
23impl<T, H, const OUT_LEN: usize> HashChallenger<T, H, OUT_LEN>
24where
25 T: Clone,
26 H: CryptographicHasher<T, [T; OUT_LEN]>,
27{
28 pub const fn new(initial_state: Vec<T>, hasher: H) -> Self {
29 Self {
30 input_buffer: initial_state,
31 output_buffer: vec![],
32 hasher,
33 }
34 }
35
36 fn flush(&mut self) {
37 let inputs = self.input_buffer.drain(..);
38 let output = self.hasher.hash_iter(inputs);
39
40 self.input_buffer.extend_from_slice(&output);
42 self.output_buffer = output.into();
43 }
44}
45
46impl<T, H, const OUT_LEN: usize> CanObserve<T> for HashChallenger<T, H, OUT_LEN>
47where
48 T: Clone,
49 H: CryptographicHasher<T, [T; OUT_LEN]>,
50{
51 fn observe(&mut self, value: T) {
52 self.output_buffer.clear();
54
55 self.input_buffer.push(value);
56 }
57}
58
59impl<T, H, const N: usize, const OUT_LEN: usize> CanObserve<[T; N]>
60 for HashChallenger<T, H, OUT_LEN>
61where
62 T: Clone,
63 H: CryptographicHasher<T, [T; OUT_LEN]>,
64{
65 fn observe(&mut self, values: [T; N]) {
66 self.output_buffer.clear();
67 self.input_buffer.extend(values);
68 }
69}
70
71impl<F, T, H, const N: usize, const OUT_LEN: usize> CanObserve<Hash<F, T, N>>
72 for HashChallenger<T, H, OUT_LEN>
73where
74 T: Clone,
75 H: CryptographicHasher<T, [T; OUT_LEN]>,
76{
77 fn observe(&mut self, values: Hash<F, T, N>) {
78 for value in values {
79 self.observe(value);
80 }
81 }
82}
83
84impl<F, T, H, const N: usize, const OUT_LEN: usize> CanObserve<&MerkleCap<F, [T; N]>>
85 for HashChallenger<T, H, OUT_LEN>
86where
87 T: Clone,
88 H: CryptographicHasher<T, [T; OUT_LEN]>,
89{
90 fn observe(&mut self, cap: &MerkleCap<F, [T; N]>) {
91 for digest in cap.roots() {
92 for value in digest {
93 self.observe(value.clone());
94 }
95 }
96 }
97}
98
99impl<F, T, H, const N: usize, const OUT_LEN: usize> CanObserve<MerkleCap<F, [T; N]>>
100 for HashChallenger<T, H, OUT_LEN>
101where
102 T: Clone,
103 H: CryptographicHasher<T, [T; OUT_LEN]>,
104{
105 fn observe(&mut self, cap: MerkleCap<F, [T; N]>) {
106 self.observe(&cap);
107 }
108}
109
110impl<T, H, const OUT_LEN: usize> CanObserve<Vec<Vec<T>>> for HashChallenger<T, H, OUT_LEN>
112where
113 T: Clone,
114 H: CryptographicHasher<T, [T; OUT_LEN]>,
115{
116 fn observe(&mut self, valuess: Vec<Vec<T>>) {
117 for values in valuess {
118 for value in values {
119 self.observe(value);
120 }
121 }
122 }
123}
124
125impl<T, H, const OUT_LEN: usize> CanSample<T> for HashChallenger<T, H, OUT_LEN>
126where
127 T: Clone,
128 H: CryptographicHasher<T, [T; OUT_LEN]>,
129{
130 fn sample(&mut self) -> T {
131 if self.output_buffer.is_empty() {
132 self.flush();
133 }
134 self.output_buffer
135 .pop()
136 .expect("Output buffer should be non-empty")
137 }
138}
139
140impl<T, H, const OUT_LEN: usize> CanFinalizeDigest for HashChallenger<T, H, OUT_LEN>
141where
142 T: Clone,
143 H: CryptographicHasher<T, [T; OUT_LEN]>,
144{
145 type Digest = [T; OUT_LEN];
146
147 fn finalize(mut self) -> [T; OUT_LEN] {
148 self.flush();
155 core::array::from_fn(|i| self.output_buffer[i].clone())
156 }
157}
158
#[cfg(test)]
mod tests {
    use p3_field::PrimeCharacteristicRing;
    use p3_goldilocks::Goldilocks;

    use super::*;

    /// Digest width produced by `TestHasher`: `[sum, count]`.
    const OUT_LEN: usize = 2;
    type F = Goldilocks;

    /// Toy, non-cryptographic hasher that maps its input to
    /// `[sum of elements, number of elements]`, so that every challenger state in the
    /// tests below can be predicted by hand.
    #[derive(Clone)]
    struct TestHasher {}

    impl CryptographicHasher<F, [F; OUT_LEN]> for TestHasher {
        /// Returns `[sum, len]` of the input elements.
        fn hash_iter<I>(&self, input: I) -> [F; OUT_LEN]
        where
            I: IntoIterator<Item = F>,
        {
            let (sum, len) = input
                .into_iter()
                .fold((F::ZERO, 0_usize), |(acc_sum, acc_len), f| {
                    (acc_sum + f, acc_len + 1)
                });
            [sum, F::from_usize(len)]
        }

        /// Returns `[sum, len]` over all elements of all slices, i.e. the same result
        /// as `hash_iter` on their concatenation.
        fn hash_iter_slices<'a, I>(&self, input: I) -> [F; OUT_LEN]
        where
            I: IntoIterator<Item = &'a [F]>,
            F: 'a,
        {
            let (sum, len) = input
                .into_iter()
                .fold((F::ZERO, 0_usize), |(acc_sum, acc_len), n| {
                    (
                        acc_sum + n.iter().fold(F::ZERO, |acc, f| acc + *f),
                        acc_len + n.len(),
                    )
                });
            [sum, F::from_usize(len)]
        }
    }

    /// Walks a challenger through new -> flush -> observe -> sample and checks both
    /// buffers at every step.
    #[test]
    fn test_hash_challenger() {
        let initial_state = (1..11_u8).map(F::from_u8).collect::<Vec<_>>();
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(initial_state.clone(), test_hasher);

        // `new` seeds the input buffer and starts with nothing to sample.
        assert_eq!(hash_challenger.input_buffer, initial_state);
        assert_eq!(hash_challenger.output_buffer, vec![]);

        hash_challenger.flush();

        // 1 + 2 + ... + 10 = 55, over 10 elements.
        let expected_sum = F::from_u8(55);
        let expected_len = F::from_u8(10);
        // The digest replaces the input buffer and also fills the output buffer.
        assert_eq!(
            hash_challenger.input_buffer,
            vec![expected_sum, expected_len]
        );
        assert_eq!(
            hash_challenger.output_buffer,
            vec![expected_sum, expected_len]
        );

        let new_element = F::from_u8(11);
        hash_challenger.observe(new_element);
        // Observing appends to the input and discards buffered output.
        assert_eq!(
            hash_challenger.input_buffer,
            vec![expected_sum, expected_len, new_element]
        );
        assert_eq!(hash_challenger.output_buffer, vec![]);

        // The next flush hashes [55, 10, 11]: sum 76, length 3.
        let new_expected_len = 3;
        let new_expected_sum = 76;

        // `sample` pops from the back, so the length element comes out first.
        let new_element = hash_challenger.sample();
        assert_eq!(new_element, F::from_u8(new_expected_len));
        assert_eq!(
            hash_challenger.output_buffer,
            [F::from_u8(new_expected_sum)]
        );
    }

    /// Two samples are served from a single flush, last digest element first.
    #[test]
    fn test_hash_challenger_flush() {
        let initial_state = (1..11_u8).map(F::from_u8).collect::<Vec<_>>();
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(initial_state, test_hasher);

        // Output is empty, so this flushes 1..=10 into [55, 10] and pops 10.
        let first_sample = hash_challenger.sample();

        // Output still holds 55, so no new flush happens.
        let second_sample = hash_challenger.sample();

        assert_eq!(first_sample, F::from_u8(10));
        assert_eq!(second_sample, F::from_u8(55));

        // Both digest elements have been consumed.
        assert!(hash_challenger.output_buffer.is_empty());
    }

    /// Observing a single value appends it after the initial state.
    #[test]
    fn test_observe_single_value() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![F::from_u8(123)], test_hasher);

        let value = F::from_u8(42);
        hash_challenger.observe(value);

        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_u8(123), F::from_u8(42)]
        );
        assert!(hash_challenger.output_buffer.is_empty());
    }

    /// Observing an array appends its elements in order.
    #[test]
    fn test_observe_array() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![F::from_u8(123)], test_hasher);

        let values = [F::from_u8(1), F::from_u8(2), F::from_u8(3)];
        hash_challenger.observe(values);

        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_u8(123), F::from_u8(1), F::from_u8(2), F::from_u8(3)]
        );
        assert!(hash_challenger.output_buffer.is_empty());
    }

    /// `Hash`, `&MerkleCap`, `MerkleCap` and `Vec<Vec<_>>` observations all reduce to
    /// absorbing their elements in order.
    #[test]
    fn test_observe_hash_cap_and_nested_vec() {
        let test_hasher = TestHasher {};

        // A `Hash` absorbs exactly like the array of its elements.
        let mut from_hash = HashChallenger::new(vec![], test_hasher.clone());
        from_hash.observe(Hash::<F, F, 3>::from([
            F::from_u8(1),
            F::from_u8(2),
            F::from_u8(3),
        ]));

        let mut from_array = HashChallenger::new(vec![], test_hasher.clone());
        from_array.observe([F::from_u8(1), F::from_u8(2), F::from_u8(3)]);
        assert_eq!(from_hash.input_buffer, from_array.input_buffer);

        // A cap's roots are flattened, root by root.
        let cap = MerkleCap::<F, [F; 2]>::new(vec![
            [F::from_u8(4), F::from_u8(5)],
            [F::from_u8(6), F::from_u8(7)],
        ]);
        let flat = vec![F::from_u8(4), F::from_u8(5), F::from_u8(6), F::from_u8(7)];

        let mut from_cap_ref = HashChallenger::new(vec![], test_hasher.clone());
        from_cap_ref.observe(&cap);
        assert_eq!(from_cap_ref.input_buffer, flat);

        // The owned-cap impl must match the borrowed one.
        let mut from_cap_owned = HashChallenger::new(vec![], test_hasher.clone());
        from_cap_owned.observe(cap);
        assert_eq!(from_cap_owned.input_buffer, flat);

        // Nested vectors are flattened, outer vector first.
        let mut from_nested = HashChallenger::new(vec![], test_hasher);
        from_nested.observe(vec![
            vec![F::from_u8(8), F::from_u8(9)],
            vec![F::from_u8(10)],
        ]);
        assert_eq!(
            from_nested.input_buffer,
            vec![F::from_u8(8), F::from_u8(9), F::from_u8(10)]
        );
    }

    /// The first sample flushes [5, 10] into [15, 2] and returns the last element.
    #[test]
    fn test_sample_output_buffer() {
        let test_hasher = TestHasher {};
        let initial_state = vec![F::from_u8(5), F::from_u8(10)];
        let mut hash_challenger = HashChallenger::new(initial_state, test_hasher);

        let sample = hash_challenger.sample();
        assert_eq!(sample, F::from_u8(2));
        // The sum is left over for the next sample.
        assert_eq!(hash_challenger.output_buffer, vec![F::from_u8(15)]);
    }

    /// Flushing with no input hashes the empty sequence: sum 0, length 0.
    #[test]
    fn test_flush_empty_buffer() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![], test_hasher);

        hash_challenger.flush();

        assert_eq!(hash_challenger.input_buffer, vec![F::ZERO, F::ZERO]);
        assert_eq!(hash_challenger.output_buffer, vec![F::ZERO, F::ZERO]);
    }

    /// Flushing [1, 2] yields [3, 2] in both buffers.
    #[test]
    fn test_flush_with_data() {
        let test_hasher = TestHasher {};
        let initial_state = vec![F::from_u8(1), F::from_u8(2)];
        let mut hash_challenger = HashChallenger::new(initial_state, test_hasher);

        hash_challenger.flush();

        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_u8(3), F::from_u8(2)]
        );
        assert_eq!(
            hash_challenger.output_buffer,
            vec![F::from_u8(3), F::from_u8(2)]
        );
    }

    /// A sample after an observe hashes the initial state plus the new value.
    #[test]
    fn test_sample_after_observe() {
        let test_hasher = TestHasher {};
        let initial_state = vec![F::from_u8(1), F::from_u8(2)];
        let mut hash_challenger = HashChallenger::new(initial_state, test_hasher);

        hash_challenger.observe(F::from_u8(3));

        assert!(hash_challenger.output_buffer.is_empty());

        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_u8(1), F::from_u8(2), F::from_u8(3)]
        );

        let sample = hash_challenger.sample();

        // [1, 2, 3] hashes to [6, 3]; the length is popped first.
        assert_eq!(sample, F::from_u8(3));
    }

    /// When output is already buffered, `sample` pops it without flushing.
    #[test]
    fn test_sample_with_non_empty_output_buffer() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![], test_hasher);

        hash_challenger.output_buffer = vec![F::from_u8(42), F::from_u8(24)];

        let sample = hash_challenger.sample();

        assert_eq!(sample, F::from_u8(24));

        assert_eq!(hash_challenger.output_buffer, vec![F::from_u8(42)]);
    }

    /// Identical transcripts finalize to identical digests; different ones do not.
    #[test]
    fn test_finalize() {
        let new_chal = || HashChallenger::new(vec![F::from_u8(1), F::from_u8(2)], TestHasher {});

        let mut h1 = new_chal();
        let mut h2 = new_chal();
        h1.observe(F::from_u8(42));
        h2.observe(F::from_u8(42));
        assert_eq!(h1.finalize(), h2.finalize());

        let mut h1 = new_chal();
        let mut h2 = new_chal();
        h1.observe(F::from_u8(1));
        h2.observe(F::from_u8(2));
        assert_ne!(h1.finalize(), h2.finalize());
    }

    /// `finalize` always flushes, so its result depends on how many flushes sampling
    /// has already triggered, not on how many samples were drawn.
    #[test]
    fn test_finalize_sample_interaction() {
        let digest = |n_samples: usize| {
            let mut c = HashChallenger::new(vec![F::from_u8(1), F::from_u8(2)], TestHasher {});
            c.observe(F::from_u8(42));
            for _ in 0..n_samples {
                let _: F = c.sample();
            }
            c.finalize()
        };

        // The first sample flushes, replacing the input with a digest, so finalize
        // hashes different data.
        assert_ne!(digest(0), digest(1));

        // Samples 1..=OUT_LEN all come from the same flush, so the input is identical.
        assert_eq!(digest(1), digest(OUT_LEN));

        // Sample OUT_LEN + 1 finds the output empty and triggers a second flush.
        assert_ne!(digest(OUT_LEN), digest(OUT_LEN + 1));

        // Samples OUT_LEN + 1..=2 * OUT_LEN share that second flush.
        assert_eq!(digest(OUT_LEN + 1), digest(2 * OUT_LEN));
    }

    /// Observing a single value discards stale output.
    #[test]
    fn test_output_buffer_cleared_on_observe() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![], test_hasher);

        hash_challenger.output_buffer.push(F::from_u8(42));

        assert!(!hash_challenger.output_buffer.is_empty());

        hash_challenger.observe(F::from_u8(3));

        assert!(hash_challenger.output_buffer.is_empty());
    }

    /// Even an empty array observation discards stale output, while leaving the input
    /// untouched.
    #[test]
    fn test_observe_empty_array_clears_output_buffer() {
        let test_hasher = TestHasher {};
        let initial_state = vec![F::from_u8(1), F::from_u8(2)];
        let mut hash_challenger = HashChallenger::new(initial_state.clone(), test_hasher);

        hash_challenger.output_buffer.push(F::from_u8(42));
        assert!(!hash_challenger.output_buffer.is_empty());

        let values: [F; 0] = [];
        hash_challenger.observe(values);

        assert!(hash_challenger.output_buffer.is_empty());
        assert_eq!(hash_challenger.input_buffer, initial_state);
    }
}