1use alloc::string::String;
2use alloc::vec;
3use alloc::vec::Vec;
4
5use p3_field::{BasedVectorSpace, Field, PrimeField, PrimeField32, reduce_32, split_32};
6use p3_symmetric::{CryptographicPermutation, Hash, MerkleCap};
7
8use crate::{CanFinalizeDigest, CanObserve, CanSample, CanSampleBits, FieldChallenger};
9
/// A duplex-sponge challenger whose sponge state lives in a larger field `PF`, while the values
/// it observes and samples are elements of the 32-bit prime field `F`.
///
/// Observed `F` elements are buffered and packed `num_f_elms` at a time into single `PF` sponge
/// cells (via `reduce_32`). After each permutation, the first `RATE` cells are unpacked back into
/// `F` elements (via `split_32`) to serve samples.
///
/// NOTE(review): samples are obtained by converting wider `PF` data into `F`, so they are
/// presumably not perfectly uniform over `F` — confirm the bias bound against `split_32`.
#[derive(Clone, Debug)]
pub struct MultiField32Challenger<F, PF, P, const WIDTH: usize, const RATE: usize>
where
    F: PrimeField32,
    PF: Field,
    P: CryptographicPermutation<[PF; WIDTH]>,
{
    /// Full permutation state. Absorbed input overwrites (a prefix of) the first `RATE` cells,
    /// and only those `RATE` cells are exposed through `sample` and `finalize`.
    sponge_state: [PF; WIDTH],
    /// `F` elements observed since the last duplexing that have not been absorbed yet.
    input_buffer: Vec<F>,
    /// `F` elements squeezed by the last duplexing that `sample` has not handed out yet.
    output_buffer: Vec<F>,
    /// Permutation applied to `sponge_state` on every duplexing.
    permutation: P,
    /// Number of `F` elements packed into, and unpacked from, one `PF` sponge cell.
    num_f_elms: usize,
}
29
30impl<F, PF, P, const WIDTH: usize, const RATE: usize> MultiField32Challenger<F, PF, P, WIDTH, RATE>
31where
32 F: PrimeField32,
33 PF: Field,
34 P: CryptographicPermutation<[PF; WIDTH]>,
35{
36 pub fn new(permutation: P) -> Result<Self, String> {
37 if F::order() >= PF::order() {
38 return Err(String::from("F::order() must be less than PF::order()"));
39 }
40 let num_f_elms = PF::bits() / 64;
41 Ok(Self {
42 sponge_state: [PF::default(); WIDTH],
43 input_buffer: vec![],
44 output_buffer: vec![],
45 permutation,
46 num_f_elms,
47 })
48 }
49}
50
51impl<F, PF, P, const WIDTH: usize, const RATE: usize> MultiField32Challenger<F, PF, P, WIDTH, RATE>
52where
53 F: PrimeField32,
54 PF: PrimeField,
55 P: CryptographicPermutation<[PF; WIDTH]>,
56{
57 fn duplexing(&mut self) {
58 assert!(self.input_buffer.len() <= self.num_f_elms * RATE);
59
60 for (i, f_chunk) in self.input_buffer.chunks(self.num_f_elms).enumerate() {
61 self.sponge_state[i] = reduce_32(f_chunk);
62 }
63 self.input_buffer.clear();
64
65 self.permutation.permute_mut(&mut self.sponge_state);
67
68 self.output_buffer.clear();
69 for &pf_val in &self.sponge_state[..RATE] {
70 self.output_buffer
71 .extend(split_32::<PF, F>(pf_val, self.num_f_elms));
72 }
73 }
74}
75
/// Marker implementation of `FieldChallenger<F>`.
///
/// The body is empty. Everything the trait needs presumably comes from its default methods and
/// from the `CanObserve` / `CanSample` / `CanSampleBits` implementations for this type — confirm
/// against the trait definition.
impl<F, PF, P, const WIDTH: usize, const RATE: usize> FieldChallenger<F>
    for MultiField32Challenger<F, PF, P, WIDTH, RATE>
where
    F: PrimeField32,
    PF: PrimeField,
    P: CryptographicPermutation<[PF; WIDTH]>,
{
}
84
85impl<F, PF, P, const WIDTH: usize, const RATE: usize> CanObserve<F>
86 for MultiField32Challenger<F, PF, P, WIDTH, RATE>
87where
88 F: PrimeField32,
89 PF: PrimeField,
90 P: CryptographicPermutation<[PF; WIDTH]>,
91{
92 fn observe(&mut self, value: F) {
93 self.output_buffer.clear();
95
96 self.input_buffer.push(value);
97
98 if self.input_buffer.len() == self.num_f_elms * RATE {
99 self.duplexing();
100 }
101 }
102}
103
104impl<F, PF, const N: usize, P, const WIDTH: usize, const RATE: usize> CanObserve<[F; N]>
105 for MultiField32Challenger<F, PF, P, WIDTH, RATE>
106where
107 F: PrimeField32,
108 PF: PrimeField,
109 P: CryptographicPermutation<[PF; WIDTH]>,
110{
111 fn observe(&mut self, values: [F; N]) {
112 for value in values {
113 self.observe(value);
114 }
115 }
116}
117
118impl<F, PF, const N: usize, P, const WIDTH: usize, const RATE: usize> CanObserve<Hash<F, PF, N>>
119 for MultiField32Challenger<F, PF, P, WIDTH, RATE>
120where
121 F: PrimeField32,
122 PF: PrimeField,
123 P: CryptographicPermutation<[PF; WIDTH]>,
124{
125 fn observe(&mut self, values: Hash<F, PF, N>) {
126 for pf_val in values {
127 let f_vals: Vec<F> = split_32(pf_val, self.num_f_elms);
128 for f_val in f_vals {
129 self.observe(f_val);
130 }
131 }
132 }
133}
134
135impl<F, PF, const N: usize, P, const WIDTH: usize, const RATE: usize>
136 CanObserve<&MerkleCap<F, [PF; N]>> for MultiField32Challenger<F, PF, P, WIDTH, RATE>
137where
138 F: PrimeField32,
139 PF: PrimeField,
140 P: CryptographicPermutation<[PF; WIDTH]>,
141{
142 fn observe(&mut self, cap: &MerkleCap<F, [PF; N]>) {
143 for digest in cap.roots() {
144 for pf_val in digest {
145 let f_vals: Vec<F> = split_32(*pf_val, self.num_f_elms);
146 for f_val in f_vals {
147 self.observe(f_val);
148 }
149 }
150 }
151 }
152}
153
154impl<F, PF, const N: usize, P, const WIDTH: usize, const RATE: usize>
155 CanObserve<MerkleCap<F, [PF; N]>> for MultiField32Challenger<F, PF, P, WIDTH, RATE>
156where
157 F: PrimeField32,
158 PF: PrimeField,
159 P: CryptographicPermutation<[PF; WIDTH]>,
160{
161 fn observe(&mut self, cap: MerkleCap<F, [PF; N]>) {
162 self.observe(&cap);
163 }
164}
165
166impl<F, PF, P, const WIDTH: usize, const RATE: usize> CanObserve<Vec<Vec<F>>>
168 for MultiField32Challenger<F, PF, P, WIDTH, RATE>
169where
170 F: PrimeField32,
171 PF: PrimeField,
172 P: CryptographicPermutation<[PF; WIDTH]>,
173{
174 fn observe(&mut self, valuess: Vec<Vec<F>>) {
175 for values in valuess {
176 for value in values {
177 self.observe(value);
178 }
179 }
180 }
181}
182
183impl<F, EF, PF, P, const WIDTH: usize, const RATE: usize> CanSample<EF>
184 for MultiField32Challenger<F, PF, P, WIDTH, RATE>
185where
186 F: PrimeField32,
187 EF: BasedVectorSpace<F>,
188 PF: PrimeField,
189 P: CryptographicPermutation<[PF; WIDTH]>,
190{
191 fn sample(&mut self) -> EF {
192 EF::from_basis_coefficients_fn(|_| {
193 if !self.input_buffer.is_empty() || self.output_buffer.is_empty() {
196 self.duplexing();
197 }
198
199 self.output_buffer
200 .pop()
201 .expect("Output buffer should be non-empty")
202 })
203 }
204}
205
206impl<F, PF, P, const WIDTH: usize, const RATE: usize> CanSampleBits<usize>
207 for MultiField32Challenger<F, PF, P, WIDTH, RATE>
208where
209 F: PrimeField32,
210 PF: PrimeField,
211 P: CryptographicPermutation<[PF; WIDTH]>,
212{
213 fn sample_bits(&mut self, bits: usize) -> usize {
222 assert!(bits < (usize::BITS as usize));
223 assert!((1 << bits) < F::ORDER_U32);
224 let rand_f: F = self.sample();
225 let rand_usize = rand_f.as_canonical_u32() as usize;
226 rand_usize & ((1 << bits) - 1)
227 }
228}
229
230impl<F, PF, P, const WIDTH: usize, const RATE: usize> CanFinalizeDigest
231 for MultiField32Challenger<F, PF, P, WIDTH, RATE>
232where
233 F: PrimeField32,
234 PF: PrimeField,
235 P: CryptographicPermutation<[PF; WIDTH]>,
236{
237 type Digest = [PF; RATE];
238
239 fn finalize(mut self) -> [PF; RATE] {
240 self.duplexing();
246 self.sponge_state[..RATE].try_into().unwrap()
247 }
248}
249
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::PrimeCharacteristicRing;
    use p3_goldilocks::Goldilocks;
    use p3_symmetric::Permutation;

    use super::*;

    const WIDTH: usize = 8;
    const RATE: usize = 4;

    // BabyBear challenges on top of a Goldilocks sponge.
    type F = BabyBear;
    type PF = Goldilocks;

    /// Permutation that ignores its input and sets cell `i` to `i + 1`.
    #[derive(Clone)]
    struct TestPermutation;

    impl Permutation<[PF; WIDTH]> for TestPermutation {
        fn permute_mut(&self, input: &mut [PF; WIDTH]) {
            for (i, val) in input.iter_mut().enumerate() {
                *val = PF::from_u8((i + 1) as u8);
            }
        }
    }

    impl CryptographicPermutation<[PF; WIDTH]> for TestPermutation {}

    /// Permutation that sets cell `i` to (sum of all input cells) + `i + 1`.
    ///
    /// Unlike `TestPermutation`, its output depends on the absorbed input (through the sum),
    /// so digests can distinguish transcripts.
    #[derive(Clone)]
    struct MixingPermutation;

    impl Permutation<[PF; WIDTH]> for MixingPermutation {
        fn permute_mut(&self, input: &mut [PF; WIDTH]) {
            let sum: PF = input.iter().copied().sum();
            for (i, val) in input.iter_mut().enumerate() {
                *val = sum + PF::from_u8((i + 1) as u8);
            }
        }
    }

    impl CryptographicPermutation<[PF; WIDTH]> for MixingPermutation {}

    #[test]
    fn test_output_buffer_excludes_capacity() {
        let permutation = TestPermutation;
        let mut challenger =
            MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(permutation).unwrap();

        let num_f_elms = challenger.num_f_elms;

        // Sampling from an empty output buffer forces one duplexing, then pops one element.
        let _: F = challenger.sample();

        // Only the RATE cells may be squeezed into the output buffer. Squeezing all WIDTH
        // cells would also expose the WIDTH - RATE capacity cells.
        let expected_output_size = RATE * num_f_elms;
        let incorrect_output_size = WIDTH * num_f_elms;

        assert_eq!(
            challenger.output_buffer.len(),
            expected_output_size - 1,
            "Output buffer should be based on RATE ({}), not WIDTH ({})",
            expected_output_size,
            incorrect_output_size
        );
    }

    #[test]
    fn test_finalize() {
        let new_chal =
            || MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();

        // Identical transcripts must produce identical digests.
        let mut c1 = new_chal();
        let mut c2 = new_chal();
        for i in 0..5u8 {
            c1.observe(F::from_u8(i));
            c2.observe(F::from_u8(i));
        }
        assert_eq!(c1.finalize(), c2.finalize());

        // Different transcripts must produce different digests.
        let mut c1 = new_chal();
        let mut c2 = new_chal();
        for i in 0..5u8 {
            c1.observe(F::from_u8(i));
            c2.observe(F::from_u8(i + 1));
        }
        assert_ne!(c1.finalize(), c2.finalize());
    }

    /// Checks how sampling before `finalize` affects the digest.
    ///
    /// `finalize` always duplexes once more. A sample only duplexes when inputs are pending or
    /// the output buffer is empty, and one duplexing yields `batch_size` outputs. So the digest
    /// changes exactly when the number of samples crosses a buffer-refill boundary.
    #[test]
    fn test_finalize_sample_interaction() {
        // Number of F outputs produced by a single duplexing.
        let batch_size = {
            let c =
                MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
            c.num_f_elms * RATE
        };

        // Observe 3 values, draw `n_samples` samples, then finalize.
        let digest = |n_samples: usize| {
            let mut c =
                MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
            for i in 0..3u8 {
                c.observe(F::from_u8(i));
            }
            for _ in 0..n_samples {
                let _: F = c.sample();
            }
            c.finalize()
        };

        // The first sample absorbs the pending inputs and permutes, so finalize adds a second
        // permutation that the zero-sample case does not have.
        assert_ne!(digest(0), digest(1));

        // Samples 2..=batch_size are served from the buffer, with no further permutations.
        assert_eq!(digest(1), digest(2));
        assert_eq!(digest(1), digest(batch_size));

        // Sample batch_size + 1 finds the buffer empty and triggers another duplexing.
        assert_ne!(digest(batch_size), digest(batch_size + 1));

        // ...after which the refilled buffer again serves further samples.
        assert_eq!(digest(batch_size + 1), digest(batch_size + 2));
    }

    #[test]
    fn test_duplexing_respects_rate() {
        let permutation = TestPermutation;
        let mut challenger =
            MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(permutation).unwrap();

        let num_f_elms = challenger.num_f_elms;

        // Observing exactly one full batch triggers a duplexing from inside `observe`.
        for i in 0..(num_f_elms * RATE) {
            challenger.observe(F::from_u8(i as u8));
        }

        // No sample was drawn, so the buffer holds everything squeezed from the RATE cells.
        assert_eq!(
            challenger.output_buffer.len(),
            RATE * num_f_elms,
            "After duplexing, output buffer should contain RATE * num_f_elms elements"
        );
    }
}