1use alloc::vec;
2use alloc::vec::Vec;
3
4use p3_symmetric::CryptographicHasher;
5
6use crate::{CanFinalizeDigest, CanObserve, CanSample};
7
/// A challenger that derives challenges by hashing a buffered transcript with a
/// [`CryptographicHasher`].
///
/// Observed values accumulate in `input_buffer`. When a challenge is requested and no unused
/// digest elements remain, the whole input buffer is hashed: the digest refills
/// `output_buffer` and also becomes the new content of `input_buffer`, so each hash is chained
/// to the previous one.
#[derive(Clone, Debug)]
pub struct HashChallenger<T, H, const OUT_LEN: usize>
where
    T: Clone,
    H: CryptographicHasher<T, [T; OUT_LEN]>,
{
    /// Pending hash input: the initial state or the last chained digest, followed by every
    /// value observed since.
    input_buffer: Vec<T>,
    /// Not-yet-sampled elements of the most recent digest; sampling pops from the back.
    output_buffer: Vec<T>,
    /// Hash function compressing `input_buffer` into an `OUT_LEN`-element digest.
    hasher: H,
}
22
23impl<T, H, const OUT_LEN: usize> HashChallenger<T, H, OUT_LEN>
24where
25 T: Clone,
26 H: CryptographicHasher<T, [T; OUT_LEN]>,
27{
28 pub const fn new(initial_state: Vec<T>, hasher: H) -> Self {
29 Self {
30 input_buffer: initial_state,
31 output_buffer: vec![],
32 hasher,
33 }
34 }
35
36 fn flush(&mut self) {
37 let inputs = self.input_buffer.drain(..);
38 let output = self.hasher.hash_iter(inputs);
39
40 self.input_buffer.extend_from_slice(&output);
42 self.output_buffer = output.into();
43 }
44}
45
46impl<T, H, const OUT_LEN: usize> CanObserve<T> for HashChallenger<T, H, OUT_LEN>
47where
48 T: Clone,
49 H: CryptographicHasher<T, [T; OUT_LEN]>,
50{
51 fn observe(&mut self, value: T) {
52 self.output_buffer.clear();
54
55 self.input_buffer.push(value);
56 }
57}
58
59impl<T, H, const N: usize, const OUT_LEN: usize> CanObserve<[T; N]>
60 for HashChallenger<T, H, OUT_LEN>
61where
62 T: Clone,
63 H: CryptographicHasher<T, [T; OUT_LEN]>,
64{
65 fn observe(&mut self, values: [T; N]) {
66 if N == 0 {
67 return;
68 }
69
70 self.output_buffer.clear();
71 self.input_buffer.extend(values);
72 }
73}
74
75impl<T, H, const OUT_LEN: usize> CanSample<T> for HashChallenger<T, H, OUT_LEN>
76where
77 T: Clone,
78 H: CryptographicHasher<T, [T; OUT_LEN]>,
79{
80 fn sample(&mut self) -> T {
81 if self.output_buffer.is_empty() {
82 self.flush();
83 }
84 self.output_buffer
85 .pop()
86 .expect("Output buffer should be non-empty")
87 }
88}
89
90impl<T, H, const OUT_LEN: usize> CanFinalizeDigest for HashChallenger<T, H, OUT_LEN>
91where
92 T: Clone,
93 H: CryptographicHasher<T, [T; OUT_LEN]>,
94{
95 type Digest = [T; OUT_LEN];
96
97 fn finalize(mut self) -> [T; OUT_LEN] {
98 self.flush();
105 core::array::from_fn(|i| self.output_buffer[i].clone())
106 }
107}
108
#[cfg(test)]
mod tests {
    use p3_field::PrimeCharacteristicRing;
    use p3_goldilocks::Goldilocks;

    use super::*;

    const OUT_LEN: usize = 2;
    type F = Goldilocks;

    /// Toy hasher whose digest is `[sum of inputs, number of inputs]`, so the expected
    /// challenger state can be worked out by hand.
    #[derive(Clone)]
    struct TestHasher {}

    impl CryptographicHasher<F, [F; OUT_LEN]> for TestHasher {
        fn hash_iter<I>(&self, input: I) -> [F; OUT_LEN]
        where
            I: IntoIterator<Item = F>,
        {
            let (sum, len) = input
                .into_iter()
                .fold((F::ZERO, 0_usize), |(acc_sum, acc_len), f| {
                    (acc_sum + f, acc_len + 1)
                });
            [sum, F::from_usize(len)]
        }

        fn hash_iter_slices<'a, I>(&self, input: I) -> [F; OUT_LEN]
        where
            I: IntoIterator<Item = &'a [F]>,
            F: 'a,
        {
            let (sum, len) = input
                .into_iter()
                .fold((F::ZERO, 0_usize), |(acc_sum, acc_len), n| {
                    (
                        acc_sum + n.iter().fold(F::ZERO, |acc, f| acc + *f),
                        acc_len + n.len(),
                    )
                });
            [sum, F::from_usize(len)]
        }
    }

    #[test]
    fn test_hash_challenger() {
        let initial_state = (1..11_u8).map(F::from_u8).collect::<Vec<_>>();
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(initial_state.clone(), test_hasher);

        assert_eq!(hash_challenger.input_buffer, initial_state);
        assert_eq!(hash_challenger.output_buffer, vec![]);

        hash_challenger.flush();

        // 1 + 2 + ... + 10 = 55, over 10 elements.
        let expected_sum = F::from_u8(55);
        let expected_len = F::from_u8(10);
        assert_eq!(
            hash_challenger.input_buffer,
            vec![expected_sum, expected_len]
        );
        assert_eq!(
            hash_challenger.output_buffer,
            vec![expected_sum, expected_len]
        );

        let new_element = F::from_u8(11);
        hash_challenger.observe(new_element);
        assert_eq!(
            hash_challenger.input_buffer,
            vec![expected_sum, expected_len, new_element]
        );
        assert_eq!(hash_challenger.output_buffer, vec![]);

        // The next hash input is [55, 10, 11]: sum 76 over 3 elements.
        let new_expected_len = 3;
        let new_expected_sum = 76;

        let new_element = hash_challenger.sample();
        assert_eq!(new_element, F::from_u8(new_expected_len));
        assert_eq!(
            hash_challenger.output_buffer,
            [F::from_u8(new_expected_sum)]
        );
    }

    #[test]
    fn test_hash_challenger_flush() {
        let initial_state = (1..11_u8).map(F::from_u8).collect::<Vec<_>>();
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(initial_state, test_hasher);

        // The first sample triggers a flush; outputs are popped from the back.
        let first_sample = hash_challenger.sample();

        let second_sample = hash_challenger.sample();

        assert_eq!(first_sample, F::from_u8(10));
        assert_eq!(second_sample, F::from_u8(55));

        assert!(hash_challenger.output_buffer.is_empty());
    }

    #[test]
    fn test_observe_single_value() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![F::from_u8(123)], test_hasher);

        let value = F::from_u8(42);
        hash_challenger.observe(value);

        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_u8(123), F::from_u8(42)]
        );
        assert!(hash_challenger.output_buffer.is_empty());
    }

    #[test]
    fn test_observe_array() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![F::from_u8(123)], test_hasher);

        let values = [F::from_u8(1), F::from_u8(2), F::from_u8(3)];
        hash_challenger.observe(values);

        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_u8(123), F::from_u8(1), F::from_u8(2), F::from_u8(3)]
        );
        assert!(hash_challenger.output_buffer.is_empty());
    }

    #[test]
    fn test_observe_empty_array_keeps_output_buffer() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![F::from_u8(1)], test_hasher);
        hash_challenger.output_buffer = vec![F::from_u8(42)];

        let empty: [F; 0] = [];
        hash_challenger.observe(empty);

        // Nothing was observed, so the input is unchanged and pending outputs stay available.
        assert_eq!(hash_challenger.input_buffer, vec![F::from_u8(1)]);
        assert_eq!(hash_challenger.output_buffer, vec![F::from_u8(42)]);
    }

    #[test]
    fn test_sample_output_buffer() {
        let test_hasher = TestHasher {};
        let initial_state = vec![F::from_u8(5), F::from_u8(10)];
        let mut hash_challenger = HashChallenger::new(initial_state, test_hasher);

        // Digest is [15, 2]; the length is popped first.
        let sample = hash_challenger.sample();
        assert_eq!(sample, F::from_u8(2));
        assert_eq!(hash_challenger.output_buffer, vec![F::from_u8(15)]);
    }

    #[test]
    fn test_sample_chains_across_flushes() {
        let mut hash_challenger =
            HashChallenger::new(vec![F::from_u8(1), F::from_u8(2)], TestHasher {});

        // First flush: [1, 2] -> [3, 2]. Second flush hashes the chained digest: [3, 2] -> [5, 2].
        let samples: Vec<F> = (0..4).map(|_| hash_challenger.sample()).collect();
        assert_eq!(
            samples,
            vec![F::from_u8(2), F::from_u8(3), F::from_u8(2), F::from_u8(5)]
        );
        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_u8(5), F::from_u8(2)]
        );
    }

    #[test]
    fn test_flush_empty_buffer() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![], test_hasher);

        hash_challenger.flush();

        // Empty input hashes to [sum 0, length 0].
        assert_eq!(hash_challenger.input_buffer, vec![F::ZERO, F::ZERO]);
        assert_eq!(hash_challenger.output_buffer, vec![F::ZERO, F::ZERO]);
    }

    #[test]
    fn test_flush_with_data() {
        let test_hasher = TestHasher {};
        let initial_state = vec![F::from_u8(1), F::from_u8(2)];
        let mut hash_challenger = HashChallenger::new(initial_state, test_hasher);

        hash_challenger.flush();

        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_u8(3), F::from_u8(2)]
        );
        assert_eq!(
            hash_challenger.output_buffer,
            vec![F::from_u8(3), F::from_u8(2)]
        );
    }

    #[test]
    fn test_sample_after_observe() {
        let test_hasher = TestHasher {};
        let initial_state = vec![F::from_u8(1), F::from_u8(2)];
        let mut hash_challenger = HashChallenger::new(initial_state, test_hasher);

        hash_challenger.observe(F::from_u8(3));

        assert!(hash_challenger.output_buffer.is_empty());

        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_u8(1), F::from_u8(2), F::from_u8(3)]
        );

        let sample = hash_challenger.sample();

        // Digest of [1, 2, 3] is [6, 3]; the length is popped first.
        assert_eq!(sample, F::from_u8(3));
    }

    #[test]
    fn test_sample_with_non_empty_output_buffer() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![], test_hasher);

        hash_challenger.output_buffer = vec![F::from_u8(42), F::from_u8(24)];

        // Buffered outputs are used without hashing again.
        let sample = hash_challenger.sample();

        assert_eq!(sample, F::from_u8(24));

        assert_eq!(hash_challenger.output_buffer, vec![F::from_u8(42)]);
    }

    #[test]
    fn test_finalize() {
        let new_chal = || HashChallenger::new(vec![F::from_u8(1), F::from_u8(2)], TestHasher {});

        // Identical transcripts give identical digests.
        let mut h1 = new_chal();
        let mut h2 = new_chal();
        h1.observe(F::from_u8(42));
        h2.observe(F::from_u8(42));
        assert_eq!(h1.finalize(), h2.finalize());

        // Different transcripts give different digests.
        let mut h1 = new_chal();
        let mut h2 = new_chal();
        h1.observe(F::from_u8(1));
        h2.observe(F::from_u8(2));
        assert_ne!(h1.finalize(), h2.finalize());
    }

    #[test]
    fn test_finalize_sample_interaction() {
        let digest = |n_samples: usize| {
            let mut c = HashChallenger::new(vec![F::from_u8(1), F::from_u8(2)], TestHasher {});
            c.observe(F::from_u8(42));
            for _ in 0..n_samples {
                let _: F = c.sample();
            }
            c.finalize()
        };

        // The first sample flushes, so finalize hashes the chained digest instead.
        assert_ne!(digest(0), digest(1));

        // Draining the rest of the same digest does not change the pending input.
        assert_eq!(digest(1), digest(OUT_LEN));

        // One more sample forces another flush, which changes the pending input.
        assert_ne!(digest(OUT_LEN), digest(OUT_LEN + 1));

        // ...and draining that digest again leaves it unchanged.
        assert_eq!(digest(OUT_LEN + 1), digest(2 * OUT_LEN));
    }

    #[test]
    fn test_output_buffer_cleared_on_observe() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![], test_hasher);

        hash_challenger.output_buffer.push(F::from_u8(42));

        assert!(!hash_challenger.output_buffer.is_empty());

        hash_challenger.observe(F::from_u8(3));

        assert!(hash_challenger.output_buffer.is_empty());
    }
}