Skip to main content

p3_challenger/
multi_field_challenger.rs

1use alloc::string::String;
2use alloc::vec;
3use alloc::vec::Vec;
4
5use p3_field::{BasedVectorSpace, Field, PrimeField, PrimeField32, reduce_32, split_32};
6use p3_symmetric::{CryptographicPermutation, Hash, MerkleCap};
7
8use crate::{CanFinalizeDigest, CanObserve, CanSample, CanSampleBits, FieldChallenger};
9
/// A challenger that operates natively on PF but produces challenges of F: PrimeField32.
///
/// Used for optimizing the cost of recursive proof verification of STARKs in SNARKs.
///
/// Observed `F` elements are buffered and packed `num_f_elms` at a time into a single `PF`
/// sponge cell when the sponge is duplexed; each squeezed `PF` cell of the rate portion is
/// split back into `num_f_elms` `F` elements.
///
/// SAFETY: There are some bias complications with using this challenger. In particular,
/// samples are actually random in [0, 2^64) and then reduced to be in F.
#[derive(Clone, Debug)]
pub struct MultiField32Challenger<F, PF, P, const WIDTH: usize, const RATE: usize>
where
    F: PrimeField32,
    PF: Field,
    P: CryptographicPermutation<[PF; WIDTH]>,
{
    /// Full permutation state. Input is written into the leading cells (at most `RATE` of
    /// them) and output is squeezed from `sponge_state[..RATE]`.
    sponge_state: [PF; WIDTH],
    /// `F` elements observed since the last duplexing that have not been absorbed yet.
    input_buffer: Vec<F>,
    /// `F` elements squeezed by the most recent duplexing; sampling pops from the end.
    output_buffer: Vec<F>,
    /// The sponge permutation applied on every duplexing.
    permutation: P,
    /// Number of `F` elements packed into / split out of each `PF` cell (`PF::bits() / 64`).
    num_f_elms: usize,
}
29
30impl<F, PF, P, const WIDTH: usize, const RATE: usize> MultiField32Challenger<F, PF, P, WIDTH, RATE>
31where
32    F: PrimeField32,
33    PF: Field,
34    P: CryptographicPermutation<[PF; WIDTH]>,
35{
36    pub fn new(permutation: P) -> Result<Self, String> {
37        if F::order() >= PF::order() {
38            return Err(String::from("F::order() must be less than PF::order()"));
39        }
40        let num_f_elms = PF::bits() / 64;
41        Ok(Self {
42            sponge_state: [PF::default(); WIDTH],
43            input_buffer: vec![],
44            output_buffer: vec![],
45            permutation,
46            num_f_elms,
47        })
48    }
49}
50
51impl<F, PF, P, const WIDTH: usize, const RATE: usize> MultiField32Challenger<F, PF, P, WIDTH, RATE>
52where
53    F: PrimeField32,
54    PF: PrimeField,
55    P: CryptographicPermutation<[PF; WIDTH]>,
56{
57    fn duplexing(&mut self) {
58        assert!(self.input_buffer.len() <= self.num_f_elms * RATE);
59
60        for (i, f_chunk) in self.input_buffer.chunks(self.num_f_elms).enumerate() {
61            self.sponge_state[i] = reduce_32(f_chunk);
62        }
63        self.input_buffer.clear();
64
65        // Apply the permutation.
66        self.permutation.permute_mut(&mut self.sponge_state);
67
68        self.output_buffer.clear();
69        for &pf_val in &self.sponge_state[..RATE] {
70            self.output_buffer
71                .extend(split_32::<PF, F>(pf_val, self.num_f_elms));
72        }
73    }
74}
75
76impl<F, PF, P, const WIDTH: usize, const RATE: usize> FieldChallenger<F>
77    for MultiField32Challenger<F, PF, P, WIDTH, RATE>
78where
79    F: PrimeField32,
80    PF: PrimeField,
81    P: CryptographicPermutation<[PF; WIDTH]>,
82{
83}
84
85impl<F, PF, P, const WIDTH: usize, const RATE: usize> CanObserve<F>
86    for MultiField32Challenger<F, PF, P, WIDTH, RATE>
87where
88    F: PrimeField32,
89    PF: PrimeField,
90    P: CryptographicPermutation<[PF; WIDTH]>,
91{
92    fn observe(&mut self, value: F) {
93        // Any buffered output is now invalid.
94        self.output_buffer.clear();
95
96        self.input_buffer.push(value);
97
98        if self.input_buffer.len() == self.num_f_elms * RATE {
99            self.duplexing();
100        }
101    }
102}
103
104impl<F, PF, const N: usize, P, const WIDTH: usize, const RATE: usize> CanObserve<[F; N]>
105    for MultiField32Challenger<F, PF, P, WIDTH, RATE>
106where
107    F: PrimeField32,
108    PF: PrimeField,
109    P: CryptographicPermutation<[PF; WIDTH]>,
110{
111    fn observe(&mut self, values: [F; N]) {
112        for value in values {
113            self.observe(value);
114        }
115    }
116}
117
118impl<F, PF, const N: usize, P, const WIDTH: usize, const RATE: usize> CanObserve<Hash<F, PF, N>>
119    for MultiField32Challenger<F, PF, P, WIDTH, RATE>
120where
121    F: PrimeField32,
122    PF: PrimeField,
123    P: CryptographicPermutation<[PF; WIDTH]>,
124{
125    fn observe(&mut self, values: Hash<F, PF, N>) {
126        for pf_val in values {
127            let f_vals: Vec<F> = split_32(pf_val, self.num_f_elms);
128            for f_val in f_vals {
129                self.observe(f_val);
130            }
131        }
132    }
133}
134
135impl<F, PF, const N: usize, P, const WIDTH: usize, const RATE: usize>
136    CanObserve<&MerkleCap<F, [PF; N]>> for MultiField32Challenger<F, PF, P, WIDTH, RATE>
137where
138    F: PrimeField32,
139    PF: PrimeField,
140    P: CryptographicPermutation<[PF; WIDTH]>,
141{
142    fn observe(&mut self, cap: &MerkleCap<F, [PF; N]>) {
143        for digest in cap.roots() {
144            for pf_val in digest {
145                let f_vals: Vec<F> = split_32(*pf_val, self.num_f_elms);
146                for f_val in f_vals {
147                    self.observe(f_val);
148                }
149            }
150        }
151    }
152}
153
154impl<F, PF, const N: usize, P, const WIDTH: usize, const RATE: usize>
155    CanObserve<MerkleCap<F, [PF; N]>> for MultiField32Challenger<F, PF, P, WIDTH, RATE>
156where
157    F: PrimeField32,
158    PF: PrimeField,
159    P: CryptographicPermutation<[PF; WIDTH]>,
160{
161    fn observe(&mut self, cap: MerkleCap<F, [PF; N]>) {
162        self.observe(&cap);
163    }
164}
165
166// for TrivialPcs
167impl<F, PF, P, const WIDTH: usize, const RATE: usize> CanObserve<Vec<Vec<F>>>
168    for MultiField32Challenger<F, PF, P, WIDTH, RATE>
169where
170    F: PrimeField32,
171    PF: PrimeField,
172    P: CryptographicPermutation<[PF; WIDTH]>,
173{
174    fn observe(&mut self, valuess: Vec<Vec<F>>) {
175        for values in valuess {
176            for value in values {
177                self.observe(value);
178            }
179        }
180    }
181}
182
183impl<F, EF, PF, P, const WIDTH: usize, const RATE: usize> CanSample<EF>
184    for MultiField32Challenger<F, PF, P, WIDTH, RATE>
185where
186    F: PrimeField32,
187    EF: BasedVectorSpace<F>,
188    PF: PrimeField,
189    P: CryptographicPermutation<[PF; WIDTH]>,
190{
191    fn sample(&mut self) -> EF {
192        EF::from_basis_coefficients_fn(|_| {
193            // If we have buffered inputs, we must perform a duplexing so that the challenge will
194            // reflect them. Or if we've run out of outputs, we must perform a duplexing to get more.
195            if !self.input_buffer.is_empty() || self.output_buffer.is_empty() {
196                self.duplexing();
197            }
198
199            self.output_buffer
200                .pop()
201                .expect("Output buffer should be non-empty")
202        })
203    }
204}
205
206impl<F, PF, P, const WIDTH: usize, const RATE: usize> CanSampleBits<usize>
207    for MultiField32Challenger<F, PF, P, WIDTH, RATE>
208where
209    F: PrimeField32,
210    PF: PrimeField,
211    P: CryptographicPermutation<[PF; WIDTH]>,
212{
213    /// The sampled bits are not perfectly uniform, but we can bound the error: every sequence
214    /// appears with probability 1/p-close to uniform (1/2^b).
215    ///
216    /// Proof:
217    /// We denote p = F::ORDER_U32, and b = `bits`.
218    /// If X follows a uniform distribution over F, if we consider the first b bits of X, each
219    /// sequence appears either with probability P1 = ⌊p / 2^b⌋ / p or P2 = (1 + ⌊p / 2^b⌋) / p.
220    /// We have 1/2^b - 1/p ≤ P1, P2 ≤ 1/2^b + 1/p
221    fn sample_bits(&mut self, bits: usize) -> usize {
222        assert!(bits < (usize::BITS as usize));
223        assert!((1 << bits) < F::ORDER_U32);
224        let rand_f: F = self.sample();
225        let rand_usize = rand_f.as_canonical_u32() as usize;
226        rand_usize & ((1 << bits) - 1)
227    }
228}
229
230impl<F, PF, P, const WIDTH: usize, const RATE: usize> CanFinalizeDigest
231    for MultiField32Challenger<F, PF, P, WIDTH, RATE>
232where
233    F: PrimeField32,
234    PF: PrimeField,
235    P: CryptographicPermutation<[PF; WIDTH]>,
236{
237    type Digest = [PF; RATE];
238
239    fn finalize(mut self) -> [PF; RATE] {
240        // Unconditionally duplex: absorb any pending input and permute.
241        //
242        // Note: sampling only pops from the output buffer without modifying
243        // sponge state, so it does not necessarily affect the digest (e.g.
244        // when the last observe already triggered auto-duplexing).
245        self.duplexing();
246        self.sponge_state[..RATE].try_into().unwrap()
247    }
248}
249
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::PrimeCharacteristicRing;
    use p3_goldilocks::Goldilocks;
    use p3_symmetric::Permutation;

    use super::*;

    const WIDTH: usize = 8;
    const RATE: usize = 4;

    type F = BabyBear;
    type PF = Goldilocks;

    /// Builds a BabyBear-over-Goldilocks challenger around the given permutation.
    fn make_challenger<P>(perm: P) -> MultiField32Challenger<F, PF, P, WIDTH, RATE>
    where
        P: CryptographicPermutation<[PF; WIDTH]>,
    {
        MultiField32Challenger::new(perm).unwrap()
    }

    /// Ignores its input and overwrites the state with the constants `1, 2, ..., WIDTH`.
    #[derive(Clone)]
    struct TestPermutation;

    impl Permutation<[PF; WIDTH]> for TestPermutation {
        fn permute_mut(&self, input: &mut [PF; WIDTH]) {
            for (cell, k) in input.iter_mut().zip(1u8..) {
                *cell = PF::from_u8(k);
            }
        }
    }

    impl CryptographicPermutation<[PF; WIDTH]> for TestPermutation {}

    /// A permutation where each output depends on all inputs, suitable for
    /// tests that need to detect state changes (e.g. finalize).
    #[derive(Clone)]
    struct MixingPermutation;

    impl Permutation<[PF; WIDTH]> for MixingPermutation {
        fn permute_mut(&self, input: &mut [PF; WIDTH]) {
            let total = input.iter().copied().sum::<PF>();
            for (cell, k) in input.iter_mut().zip(1u8..) {
                *cell = total + PF::from_u8(k);
            }
        }
    }

    impl CryptographicPermutation<[PF; WIDTH]> for MixingPermutation {}

    #[test]
    fn test_output_buffer_excludes_capacity() {
        let mut challenger = make_challenger(TestPermutation);

        // Goldilocks has 64 bits, so each PF cell carries PF::bits() / 64 = 1 F element.
        let per_cell = challenger.num_f_elms;

        // The first sample forces a duplexing and consumes one output element.
        let _: F = challenger.sample();

        // Only the rate portion may be squeezed, never the capacity.
        let rate_only = RATE * per_cell;
        let with_capacity = WIDTH * per_cell;

        assert_eq!(
            challenger.output_buffer.len(),
            rate_only - 1,
            "Output buffer should be based on RATE ({}), not WIDTH ({})",
            rate_only,
            with_capacity
        );
    }

    #[test]
    fn test_finalize() {
        // Observes `offset, offset + 1, ..., offset + 4` and returns the resulting digest.
        let digest_of = |offset: u8| {
            let mut challenger = make_challenger(MixingPermutation);
            for i in 0..5u8 {
                challenger.observe(F::from_u8(i + offset));
            }
            challenger.finalize()
        };

        // Deterministic: identical observations yield identical digests.
        assert_eq!(digest_of(0), digest_of(0));

        // Sensitive: shifted observations yield a different digest.
        assert_ne!(digest_of(0), digest_of(1));
    }

    /// Pins down how sampling interacts with finalize.
    ///
    /// As with DuplexChallenger, sampling only pops from the output buffer and leaves the
    /// sponge state alone; the digest moves only when a sample triggers a new duplexing.
    /// Each duplexing yields `num_f_elms * RATE` output elements (here 1 * 4 = 4 BabyBear
    /// elements for Goldilocks/BabyBear), so the digest is stable within each such batch.
    #[test]
    fn test_finalize_sample_interaction() {
        let batch_size = make_challenger(MixingPermutation).num_f_elms * RATE;

        // Observes 0, 1, 2, draws `n_samples` samples, then finalizes.
        let digest_after = |n_samples: usize| {
            let mut challenger = make_challenger(MixingPermutation);
            (0..3u8).for_each(|i| challenger.observe(F::from_u8(i)));
            for _ in 0..n_samples {
                let _: F = challenger.sample();
            }
            challenger.finalize()
        };

        // The first sample absorbs the pending input via duplexing, so finalize's own
        // duplexing is an additional permutation and the digest differs.
        assert_ne!(digest_after(0), digest_after(1));

        // Further samples inside the same batch reuse buffered output.
        assert_eq!(digest_after(1), digest_after(2));
        assert_eq!(digest_after(1), digest_after(batch_size));

        // Draining the batch forces a fresh duplexing.
        assert_ne!(digest_after(batch_size), digest_after(batch_size + 1));

        // The second batch is stable as well.
        assert_eq!(digest_after(batch_size + 1), digest_after(batch_size + 2));
    }

    #[test]
    fn test_duplexing_respects_rate() {
        let mut challenger = make_challenger(TestPermutation);
        let per_cell = challenger.num_f_elms;
        let fill = per_cell * RATE;

        // Observing exactly one rate's worth of elements triggers an automatic duplexing.
        for i in 0..fill {
            challenger.observe(F::from_u8(i as u8));
        }

        // That duplexing must squeeze exactly RATE * num_f_elms elements.
        assert_eq!(
            challenger.output_buffer.len(),
            RATE * per_cell,
            "After duplexing, output buffer should contain RATE * num_f_elms elements"
        );
    }
}