p3_challenger/
hash_challenger.rs

1use alloc::vec;
2use alloc::vec::Vec;
3
4use p3_symmetric::CryptographicHasher;
5
6use crate::{CanObserve, CanSample};
7
/// A generic challenger that uses a cryptographic hash function to generate challenges.
///
/// Observed values are appended to an input buffer. When a challenge is requested and no
/// hashed output is pending, the whole input buffer is hashed. The resulting digest both
/// refills the output buffer (challenges are popped from its end) and replaces the input
/// buffer, so every hash is chained onto the previous one.
///
/// The struct itself carries no trait bounds: they live on the `impl` blocks that actually
/// use the hasher, so naming this type (or using its derived `Clone`/`Debug`) never has to
/// prove them.
#[derive(Clone, Debug)]
pub struct HashChallenger<T, H, const OUT_LEN: usize> {
    /// Values waiting to be hashed: the initial state or the previous digest, followed by
    /// everything observed since then.
    input_buffer: Vec<T>,
    /// Digest elements not yet handed out as challenges; sampling pops from the end.
    output_buffer: Vec<T>,
    /// The cryptographic hash function used for generating challenges.
    hasher: H,
}
22
23impl<T, H, const OUT_LEN: usize> HashChallenger<T, H, OUT_LEN>
24where
25    T: Clone,
26    H: CryptographicHasher<T, [T; OUT_LEN]>,
27{
28    pub const fn new(initial_state: Vec<T>, hasher: H) -> Self {
29        Self {
30            input_buffer: initial_state,
31            output_buffer: vec![],
32            hasher,
33        }
34    }
35
36    fn flush(&mut self) {
37        let inputs = self.input_buffer.drain(..);
38        let output = self.hasher.hash_iter(inputs);
39
40        self.output_buffer = output.to_vec();
41
42        // Chaining values.
43        self.input_buffer.extend(output.to_vec());
44    }
45}
46
47impl<T, H, const OUT_LEN: usize> CanObserve<T> for HashChallenger<T, H, OUT_LEN>
48where
49    T: Clone,
50    H: CryptographicHasher<T, [T; OUT_LEN]>,
51{
52    fn observe(&mut self, value: T) {
53        // Any buffered output is now invalid.
54        self.output_buffer.clear();
55
56        self.input_buffer.push(value);
57    }
58}
59
60impl<T, H, const N: usize, const OUT_LEN: usize> CanObserve<[T; N]>
61    for HashChallenger<T, H, OUT_LEN>
62where
63    T: Clone,
64    H: CryptographicHasher<T, [T; OUT_LEN]>,
65{
66    fn observe(&mut self, values: [T; N]) {
67        for value in values {
68            self.observe(value);
69        }
70    }
71}
72
73impl<T, H, const OUT_LEN: usize> CanSample<T> for HashChallenger<T, H, OUT_LEN>
74where
75    T: Clone,
76    H: CryptographicHasher<T, [T; OUT_LEN]>,
77{
78    fn sample(&mut self) -> T {
79        if self.output_buffer.is_empty() {
80            self.flush();
81        }
82        self.output_buffer
83            .pop()
84            .expect("Output buffer should be non-empty")
85    }
86}
87
#[cfg(test)]
mod tests {
    use p3_field::PrimeCharacteristicRing;
    use p3_goldilocks::Goldilocks;

    use super::*;

    const OUT_LEN: usize = 2;
    type F = Goldilocks;

    #[derive(Clone)]
    struct TestHasher {}

    impl CryptographicHasher<F, [F; OUT_LEN]> for TestHasher {
        /// A very simple hash iterator. From an input of type `IntoIterator<Item = Goldilocks>`,
        /// it outputs the sum of its elements and its length (as a field element).
        fn hash_iter<I>(&self, input: I) -> [F; OUT_LEN]
        where
            I: IntoIterator<Item = F>,
        {
            let (sum, len) = input
                .into_iter()
                .fold((F::ZERO, 0_usize), |(acc_sum, acc_len), f| {
                    (acc_sum + f, acc_len + 1)
                });
            [sum, F::from_usize(len)]
        }

        /// A very simple slice hash iterator. From an input of type `IntoIterator<Item = &'a [Goldilocks]>`,
        /// it outputs the sum of all elements across the slices and their total length (as a field element).
        fn hash_iter_slices<'a, I>(&self, input: I) -> [F; OUT_LEN]
        where
            I: IntoIterator<Item = &'a [F]>,
            F: 'a,
        {
            let (sum, len) = input
                .into_iter()
                .fold((F::ZERO, 0_usize), |(acc_sum, acc_len), n| {
                    (
                        acc_sum + n.iter().fold(F::ZERO, |acc, f| acc + *f),
                        acc_len + n.len(),
                    )
                });
            [sum, F::from_usize(len)]
        }
    }

    #[test]
    fn test_hash_challenger() {
        let initial_state = (1..11_u8).map(F::from_u8).collect::<Vec<_>>();
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(initial_state.clone(), test_hasher);

        assert_eq!(hash_challenger.input_buffer, initial_state);
        assert_eq!(hash_challenger.output_buffer, vec![]);

        hash_challenger.flush();

        let expected_sum = F::from_u8(55);
        let expected_len = F::from_u8(10);
        assert_eq!(
            hash_challenger.input_buffer,
            vec![expected_sum, expected_len]
        );
        assert_eq!(
            hash_challenger.output_buffer,
            vec![expected_sum, expected_len]
        );

        let new_element = F::from_u8(11);
        hash_challenger.observe(new_element);
        assert_eq!(
            hash_challenger.input_buffer,
            vec![expected_sum, expected_len, new_element]
        );
        assert_eq!(hash_challenger.output_buffer, vec![]);

        // Hashing [55, 10, 11] gives sum = 76 and len = 3.
        let new_expected_len = 3;
        let new_expected_sum = 76;

        let sampled = hash_challenger.sample();
        assert_eq!(sampled, F::from_u8(new_expected_len));
        assert_eq!(
            hash_challenger.output_buffer,
            [F::from_u8(new_expected_sum)]
        );
    }

    #[test]
    fn test_hash_challenger_flush() {
        let initial_state = (1..11_u8).map(F::from_u8).collect::<Vec<_>>();
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(initial_state, test_hasher);

        // The first sample triggers a flush; the second consumes the rest of the digest.
        let first_sample = hash_challenger.sample();

        let second_sample = hash_challenger.sample();

        // Verify that the first sample is the length of 1..11 (i.e. 10).
        assert_eq!(first_sample, F::from_u8(10));
        // Verify that the second sample is the sum of the numbers from 1 to 10 (i.e. 55).
        assert_eq!(second_sample, F::from_u8(55));

        // Verify that the output buffer is now empty
        assert!(hash_challenger.output_buffer.is_empty());
    }

    #[test]
    fn test_sample_chains_across_flushes() {
        let initial_state = (1..11_u8).map(F::from_u8).collect::<Vec<_>>();
        let mut hash_challenger = HashChallenger::new(initial_state, TestHasher {});

        // Exhaust the first digest [sum = 55, len = 10], popped from the back.
        assert_eq!(hash_challenger.sample(), F::from_u8(10));
        assert_eq!(hash_challenger.sample(), F::from_u8(55));

        // The next sample re-hashes the chained digest [55, 10]: sum = 65, len = 2.
        assert_eq!(hash_challenger.sample(), F::from_u8(2));
        assert_eq!(hash_challenger.output_buffer, vec![F::from_u8(65)]);
        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_u8(65), F::from_u8(2)]
        );
    }

    #[test]
    fn test_observe_single_value() {
        let test_hasher = TestHasher {};
        // Initial state non-empty
        let mut hash_challenger = HashChallenger::new(vec![F::from_u8(123)], test_hasher);

        // Observe a single value
        let value = F::from_u8(42);
        hash_challenger.observe(value);

        // Check that the input buffer contains the initial and observed values
        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_u8(123), F::from_u8(42)]
        );
        // Check that the output buffer is empty (it is cleared on observation)
        assert!(hash_challenger.output_buffer.is_empty());
    }

    #[test]
    fn test_observe_array() {
        let test_hasher = TestHasher {};
        // Initial state non-empty
        let mut hash_challenger = HashChallenger::new(vec![F::from_u8(123)], test_hasher);

        // Observe an array of values
        let values = [F::from_u8(1), F::from_u8(2), F::from_u8(3)];
        hash_challenger.observe(values);

        // Check that the input buffer contains the initial state followed by the values
        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_u8(123), F::from_u8(1), F::from_u8(2), F::from_u8(3)]
        );
        // Check that the output buffer is empty (it is cleared on observation)
        assert!(hash_challenger.output_buffer.is_empty());
    }

    #[test]
    fn test_sample_output_buffer() {
        let test_hasher = TestHasher {};
        let initial_state = vec![F::from_u8(5), F::from_u8(10)];
        let mut hash_challenger = HashChallenger::new(initial_state, test_hasher);

        let sample = hash_challenger.sample();
        // Verify that the sample is the length of the initial state
        assert_eq!(sample, F::from_u8(2));
        // Check that the output buffer contains the sum of the initial state
        assert_eq!(hash_challenger.output_buffer, vec![F::from_u8(15)]);
    }

    #[test]
    fn test_flush_empty_buffer() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![], test_hasher);

        // Flush empty buffer
        hash_challenger.flush();

        // Check that the input and output buffers contain the sum and length of the empty buffer
        assert_eq!(hash_challenger.input_buffer, vec![F::ZERO, F::ZERO]);
        assert_eq!(hash_challenger.output_buffer, vec![F::ZERO, F::ZERO]);
    }

    #[test]
    fn test_flush_with_data() {
        let test_hasher = TestHasher {};
        // Initial state non-empty
        let initial_state = vec![F::from_u8(1), F::from_u8(2)];
        let mut hash_challenger = HashChallenger::new(initial_state, test_hasher);

        hash_challenger.flush();

        // Check that the input buffer contains the sum and length of the initial state
        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_u8(3), F::from_u8(2)]
        );
        // Check that the output buffer contains the sum and length of the initial state
        assert_eq!(
            hash_challenger.output_buffer,
            vec![F::from_u8(3), F::from_u8(2)]
        );
    }

    #[test]
    fn test_sample_after_observe() {
        let test_hasher = TestHasher {};
        let initial_state = vec![F::from_u8(1), F::from_u8(2)];
        let mut hash_challenger = HashChallenger::new(initial_state, test_hasher);

        // Observing clears the output buffer
        hash_challenger.observe(F::from_u8(3));

        // Verify that the output buffer is empty
        assert!(hash_challenger.output_buffer.is_empty());

        // Verify the new value is in the input buffer
        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_u8(1), F::from_u8(2), F::from_u8(3)]
        );

        let sample = hash_challenger.sample();

        // Length of initial state + observed value
        assert_eq!(sample, F::from_u8(3));
    }

    #[test]
    fn test_sample_with_non_empty_output_buffer() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![], test_hasher);

        hash_challenger.output_buffer = vec![F::from_u8(42), F::from_u8(24)];

        let sample = hash_challenger.sample();

        // Sample pops the last element from the output buffer
        assert_eq!(sample, F::from_u8(24));

        // Check that the output buffer is now one element shorter
        assert_eq!(hash_challenger.output_buffer, vec![F::from_u8(42)]);
    }

    #[test]
    fn test_output_buffer_cleared_on_observe() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![], test_hasher);

        // Artificially populate the output buffer
        hash_challenger.output_buffer.push(F::from_u8(42));

        // Ensure the output buffer is populated
        assert!(!hash_challenger.output_buffer.is_empty());

        // Observe a new value
        hash_challenger.observe(F::from_u8(3));

        // Verify that the output buffer is cleared after observing
        assert!(hash_challenger.output_buffer.is_empty());
    }
}