1use alloc::vec;
2use alloc::vec::Vec;
3use core::error::Error;
4use core::fmt::{Display, Formatter};
5
6use p3_field::{BasedVectorSpace, Field, PrimeField64};
7use p3_monty_31::{MontyField31, MontyParameters};
8use p3_symmetric::{CryptographicPermutation, Hash};
9
10use crate::{CanObserve, CanSample, CanSampleBits, CanSampleUniformBits, FieldChallenger};
11
/// A duplex-sponge challenger built on a cryptographic permutation over `[F; WIDTH]`.
///
/// Observed values are buffered in `input_buffer`. When a duplexing step runs, those
/// values overwrite the leading cells of `sponge_state`, the permutation is applied, and
/// the first `RATE` cells of the permuted state become `output_buffer`. Samples are then
/// taken from that buffer.
#[derive(Clone, Debug)]
pub struct DuplexChallenger<F, P, const WIDTH: usize, const RATE: usize>
where
    F: Clone,
    P: CryptographicPermutation<[F; WIDTH]>,
{
    /// The full internal sponge state of `WIDTH` field elements.
    pub sponge_state: [F; WIDTH],

    /// Observed values that have not yet been absorbed into `sponge_state`.
    ///
    /// In normal use this never reaches `RATE` entries between calls: `observe` duplexes
    /// as soon as it does.
    pub input_buffer: Vec<F>,

    /// Values squeezed by the last duplexing step that have not been sampled yet.
    ///
    /// Refilled with `sponge_state[..RATE]` after each permutation and consumed from the end.
    pub output_buffer: Vec<F>,

    /// The permutation applied on every duplexing step.
    pub permutation: P,
}
63
64impl<F, P, const WIDTH: usize, const RATE: usize> DuplexChallenger<F, P, WIDTH, RATE>
65where
66 F: Copy,
67 P: CryptographicPermutation<[F; WIDTH]>,
68{
69 pub fn new(permutation: P) -> Self
70 where
71 F: Default,
72 {
73 Self {
74 sponge_state: [F::default(); WIDTH],
75 input_buffer: vec![],
76 output_buffer: vec![],
77 permutation,
78 }
79 }
80
81 fn duplexing(&mut self) {
82 assert!(self.input_buffer.len() <= RATE);
83
84 for (i, val) in self.input_buffer.drain(..).enumerate() {
86 self.sponge_state[i] = val;
87 }
88
89 self.permutation.permute_mut(&mut self.sponge_state);
91
92 self.output_buffer.clear();
93 self.output_buffer.extend(&self.sponge_state[..RATE]);
94 }
95}
96
97impl<F, P, const WIDTH: usize, const RATE: usize> FieldChallenger<F>
98 for DuplexChallenger<F, P, WIDTH, RATE>
99where
100 F: PrimeField64,
101 P: CryptographicPermutation<[F; WIDTH]>,
102{
103}
104
105impl<F, P, const WIDTH: usize, const RATE: usize> CanObserve<F>
106 for DuplexChallenger<F, P, WIDTH, RATE>
107where
108 F: Copy,
109 P: CryptographicPermutation<[F; WIDTH]>,
110{
111 fn observe(&mut self, value: F) {
112 self.output_buffer.clear();
114
115 self.input_buffer.push(value);
116
117 if self.input_buffer.len() == RATE {
118 self.duplexing();
119 }
120 }
121}
122
123impl<F, P, const N: usize, const WIDTH: usize, const RATE: usize> CanObserve<[F; N]>
124 for DuplexChallenger<F, P, WIDTH, RATE>
125where
126 F: Copy,
127 P: CryptographicPermutation<[F; WIDTH]>,
128{
129 fn observe(&mut self, values: [F; N]) {
130 for value in values {
131 self.observe(value);
132 }
133 }
134}
135
136impl<F, P, const N: usize, const WIDTH: usize, const RATE: usize> CanObserve<Hash<F, F, N>>
137 for DuplexChallenger<F, P, WIDTH, RATE>
138where
139 F: Copy,
140 P: CryptographicPermutation<[F; WIDTH]>,
141{
142 fn observe(&mut self, values: Hash<F, F, N>) {
143 for value in values {
144 self.observe(value);
145 }
146 }
147}
148
149impl<F, P, const WIDTH: usize, const RATE: usize> CanObserve<Vec<Vec<F>>>
151 for DuplexChallenger<F, P, WIDTH, RATE>
152where
153 F: Copy,
154 P: CryptographicPermutation<[F; WIDTH]>,
155{
156 fn observe(&mut self, valuess: Vec<Vec<F>>) {
157 for values in valuess {
158 for value in values {
159 self.observe(value);
160 }
161 }
162 }
163}
164
165impl<F, EF, P, const WIDTH: usize, const RATE: usize> CanSample<EF>
166 for DuplexChallenger<F, P, WIDTH, RATE>
167where
168 F: Field,
169 EF: BasedVectorSpace<F>,
170 P: CryptographicPermutation<[F; WIDTH]>,
171{
172 fn sample(&mut self) -> EF {
173 EF::from_basis_coefficients_fn(|_| {
174 if !self.input_buffer.is_empty() || self.output_buffer.is_empty() {
177 self.duplexing();
178 }
179
180 self.output_buffer
181 .pop()
182 .expect("Output buffer should be non-empty")
183 })
184 }
185}
186
187impl<F, P, const WIDTH: usize, const RATE: usize> CanSampleBits<usize>
188 for DuplexChallenger<F, P, WIDTH, RATE>
189where
190 F: PrimeField64,
191 P: CryptographicPermutation<[F; WIDTH]>,
192{
193 fn sample_bits(&mut self, bits: usize) -> usize {
202 assert!(bits < (usize::BITS as usize));
203 assert!((1 << bits) < F::ORDER_U64);
204 let rand_f: F = self.sample();
205 let rand_usize = rand_f.as_canonical_u64() as usize;
206 rand_usize & ((1 << bits) - 1)
207 }
208}
209
/// Per-field constants used by `DuplexChallenger`'s uniform bit sampling, which rejects
/// draws whose canonical value is at or above a threshold.
pub trait UniformSamplingField {
    /// Widest bit request served by a single field element.
    ///
    /// Wider requests are split into two samples of `bits / 2` and `bits - bits / 2` bits.
    const MAX_SINGLE_SAMPLE_BITS: usize;
    /// Rejection thresholds indexed by bit count.
    ///
    /// When sampling `bits` bits, a draw is accepted only if its canonical value is
    /// strictly below `SAMPLING_BITS_M[bits]`.
    ///
    /// NOTE(review): presumably each entry is the largest multiple of `2^bits` not exceeding
    /// the field order, so the low `bits` bits of an accepted draw are uniform. Confirm
    /// against the per-field definitions.
    const SAMPLING_BITS_M: [u64; 64];
}
227
/// `MontyField31<MP>` takes its uniform-sampling constants unchanged from its parameter
/// type `MP`.
impl<MP> UniformSamplingField for MontyField31<MP>
where
    MP: UniformSamplingField + MontyParameters,
{
    const MAX_SINGLE_SAMPLE_BITS: usize = MP::MAX_SINGLE_SAMPLE_BITS;
    const SAMPLING_BITS_M: [u64; 64] = MP::SAMPLING_BITS_M;
}
238
/// Bit-sampling strategy that keeps drawing fresh field elements until one falls strictly
/// below the rejection threshold.
pub(super) struct ResampleOnRejection;
/// Bit-sampling strategy that returns a `ResamplingError` as soon as a draw is at or above
/// the rejection threshold.
pub(super) struct ErrorOnRejection;
245
/// Error produced by uniform bit sampling when a draw would have to be rejected but
/// resampling is disabled.
///
/// `value` is the canonical value that was drawn. `m` is the threshold it failed to be
/// strictly below.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ResamplingError {
    /// Canonical value of the rejected field element.
    value: u64,
    /// Exclusive upper bound a draw must lie below to be accepted.
    m: u64,
}

impl Display for ResamplingError {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        write!(
            f,
            "Encountered value {0}, which requires resampling for uniform bits as it is not smaller than {1}. But resampling is not enabled.",
            self.value, self.m
        )
    }
}

impl Error for ResamplingError {}
267
268pub(super) trait BitSamplingStrategy<F, P, const W: usize, const R: usize>
270where
271 F: PrimeField64,
272 P: CryptographicPermutation<[F; W]>,
273{
274 const ERROR_ON_REJECTION: bool;
276
277 #[inline]
278 fn sample_value(
279 challenger: &mut DuplexChallenger<F, P, W, R>,
280 m: u64,
281 ) -> Result<F, ResamplingError> {
282 let mut result: F = challenger.sample();
283 if Self::ERROR_ON_REJECTION {
284 if result.as_canonical_u64() >= m {
285 return Err(ResamplingError {
286 value: result.as_canonical_u64(),
287 m,
288 });
289 }
290 } else {
291 while result.as_canonical_u64() >= m {
292 result = challenger.sample();
293 }
294 }
295 Ok(result)
296 }
297}
298
299impl<F, P, const W: usize, const R: usize> BitSamplingStrategy<F, P, W, R> for ResampleOnRejection
301where
302 F: PrimeField64,
303 P: CryptographicPermutation<[F; W]>,
304{
305 const ERROR_ON_REJECTION: bool = false;
306}
307
308impl<F, P, const W: usize, const R: usize> BitSamplingStrategy<F, P, W, R> for ErrorOnRejection
310where
311 F: PrimeField64,
312 P: CryptographicPermutation<[F; W]>,
313{
314 const ERROR_ON_REJECTION: bool = true;
315}
316
317impl<F, P, const WIDTH: usize, const RATE: usize> DuplexChallenger<F, P, WIDTH, RATE>
318where
319 F: UniformSamplingField + PrimeField64,
320 P: CryptographicPermutation<[F; WIDTH]>,
321{
322 #[inline]
324 fn sample_uniform_bits_with_strategy<S>(
325 &mut self,
326 bits: usize,
327 ) -> Result<usize, ResamplingError>
328 where
329 S: BitSamplingStrategy<F, P, WIDTH, RATE>,
330 {
331 if bits == 0 {
332 return Ok(0);
333 };
334 assert!(bits < usize::BITS as usize, "bit count must be valid");
335 assert!(
336 (1u64 << bits) < F::ORDER_U64,
337 "bit count exceeds field order"
338 );
339 let m = F::SAMPLING_BITS_M[bits];
340 if bits <= F::MAX_SINGLE_SAMPLE_BITS {
341 let rand_f = S::sample_value(self, m);
343 Ok(rand_f?.as_canonical_u64() as usize & ((1 << bits) - 1))
344 } else {
345 let half_bits1 = bits / 2;
348 let half_bits2 = bits - half_bits1;
349 let rand1 = S::sample_value(self, F::SAMPLING_BITS_M[half_bits1]);
351 let chunk1 = rand1?.as_canonical_u64() as usize & ((1 << half_bits1) - 1);
352 let rand2 = S::sample_value(self, F::SAMPLING_BITS_M[half_bits2]);
354 let chunk2 = rand2?.as_canonical_u64() as usize & ((1 << half_bits2) - 1);
355
356 Ok(chunk1 | (chunk2 << half_bits1))
358 }
359 }
360}
361
362impl<F, P, const WIDTH: usize, const RATE: usize> CanSampleUniformBits<F>
363 for DuplexChallenger<F, P, WIDTH, RATE>
364where
365 F: UniformSamplingField + PrimeField64,
366 P: CryptographicPermutation<[F; WIDTH]>,
367{
368 fn sample_uniform_bits<const RESAMPLE: bool>(
369 &mut self,
370 bits: usize,
371 ) -> Result<usize, ResamplingError> {
372 if RESAMPLE {
373 self.sample_uniform_bits_with_strategy::<ResampleOnRejection>(bits)
374 } else {
375 self.sample_uniform_bits_with_strategy::<ErrorOnRejection>(bits)
376 }
377 }
378}
379
#[cfg(test)]
mod tests {
    use core::iter;

    use p3_baby_bear::BabyBear;
    use p3_field::PrimeCharacteristicRing;
    use p3_field::extension::BinomialExtensionField;
    use p3_goldilocks::Goldilocks;
    use p3_symmetric::Permutation;

    use super::*;
    use crate::grinding_challenger::GrindingChallenger;

    const WIDTH: usize = 24;
    const RATE: usize = 16;

    type G = Goldilocks;
    type EF2G = BinomialExtensionField<G, 2>;

    type BB = BabyBear;

    /// Test-only "permutation" that reverses the state, so every output is easy to predict.
    #[derive(Clone)]
    struct TestPermutation {}

    impl<F: Clone> Permutation<[F; WIDTH]> for TestPermutation {
        fn permute_mut(&self, input: &mut [F; WIDTH]) {
            input.reverse();
        }
    }

    impl<F: Clone> CryptographicPermutation<[F; WIDTH]> for TestPermutation {}

    #[test]
    fn test_duplex_challenger() {
        type Chal = DuplexChallenger<G, TestPermutation, WIDTH, RATE>;
        let permutation = TestPermutation {};
        let mut duplex_challenger = Chal::new(permutation);

        // Fewer than RATE values: nothing is absorbed until the first sample.
        (0..12).for_each(|element| duplex_challenger.observe(G::from_u8(element as u8)));

        // Cells 0..12 are overwritten with 0..12, then the 24-cell state is reversed.
        let state_after_duplexing: Vec<_> = iter::repeat_n(G::ZERO, 12)
            .chain((0..12).map(G::from_u8).rev())
            .collect();

        // Samples are popped from the end of the RATE-long output buffer.
        let expected_samples: Vec<G> = state_after_duplexing[..16].iter().copied().rev().collect();
        let samples = <Chal as CanSample<G>>::sample_vec(&mut duplex_challenger, 16);
        assert_eq!(samples, expected_samples);
    }

    #[test]
    #[should_panic]
    fn test_duplex_challenger_sample_bits_security() {
        type GoldilocksChal = DuplexChallenger<G, TestPermutation, WIDTH, RATE>;
        let permutation = TestPermutation {};
        let mut duplex_challenger = GoldilocksChal::new(permutation);

        for _ in 0..100 {
            assert!(duplex_challenger.sample_bits(129) < 4);
        }
    }

    #[test]
    #[should_panic]
    fn test_duplex_challenger_sample_bits_security_small_field() {
        type BabyBearChal = DuplexChallenger<BB, TestPermutation, WIDTH, RATE>;
        let permutation = TestPermutation {};
        let mut duplex_challenger = BabyBearChal::new(permutation);

        for _ in 0..100 {
            assert!(duplex_challenger.sample_bits(40) < 1 << 31);
        }
    }

    #[test]
    #[should_panic]
    fn test_duplex_challenger_grind_security() {
        type GoldilocksChal = DuplexChallenger<G, TestPermutation, WIDTH, RATE>;
        let permutation = TestPermutation {};
        let mut duplex_challenger = GoldilocksChal::new(permutation);

        // Grinding for usize::BITS bits is expected to panic.
        let too_many_bits = usize::BITS as usize;

        let witness = duplex_challenger.grind(too_many_bits);
        assert!(duplex_challenger.check_witness(too_many_bits, witness));
    }

    #[test]
    fn test_observe_single_value() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        chal.observe(G::from_u8(42));
        assert_eq!(chal.input_buffer, vec![G::from_u8(42)]);
        assert!(chal.output_buffer.is_empty());
    }

    #[test]
    fn test_observe_array_of_values() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        chal.observe([G::from_u8(1), G::from_u8(2), G::from_u8(3)]);
        assert_eq!(
            chal.input_buffer,
            vec![G::from_u8(1), G::from_u8(2), G::from_u8(3)]
        );
        assert!(chal.output_buffer.is_empty());
    }

    #[test]
    fn test_observe_hash_array() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        let hash = Hash::<G, G, 4>::from([G::from_u8(10); 4]);
        chal.observe(hash);
        assert_eq!(chal.input_buffer, vec![G::from_u8(10); 4]);
    }

    #[test]
    fn test_observe_nested_vecs() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        chal.observe(vec![
            vec![G::from_u8(1), G::from_u8(2)],
            vec![G::from_u8(3)],
        ]);
        assert_eq!(
            chal.input_buffer,
            vec![G::from_u8(1), G::from_u8(2), G::from_u8(3)]
        );
    }

    #[test]
    fn test_sample_triggers_duplex() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        chal.observe(G::from_u8(5));
        assert!(chal.output_buffer.is_empty());
        let _sample: G = chal.sample();
        assert!(!chal.output_buffer.is_empty());
    }

    #[test]
    fn test_sample_multiple_extension_field() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});

        chal.observe(G::from_u8(1));
        chal.observe(G::from_u8(2));

        // Absorbing [1, 2] and reversing the state leaves the first RATE cells zero,
        // so every coefficient squeezed from this duplexing step is zero.
        let first: EF2G = chal.sample();
        let second: EF2G = chal.sample();
        assert_eq!(first, EF2G::ZERO);
        assert_eq!(second, EF2G::ZERO);

        // One duplexing step produced RATE outputs; two degree-2 samples consumed four.
        assert!(chal.input_buffer.is_empty());
        assert_eq!(chal.output_buffer.len(), RATE - 4);
    }

    #[test]
    fn test_sample_bits_within_bounds() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        for i in 0..RATE {
            chal.observe(G::from_u8(i as u8));
        }

        // After absorbing 0..RATE and reversing, the output buffer ends with 8, so 8 is the
        // element sampled first; its low 3 bits are 0.
        let bits = 3;
        let value = chal.sample_bits(bits);
        let expected = G::from_u8(8).as_canonical_u64() as usize & ((1 << bits) - 1);
        assert_eq!(value, expected);
    }

    #[test]
    fn test_sample_bits_trigger_duplex_when_empty() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        assert_eq!(chal.input_buffer.len(), 0);
        assert_eq!(chal.output_buffer.len(), 0);

        // Duplexing an all-zero state under a reversal yields zeros.
        let bits = 2;
        let sample = chal.sample_bits(bits);
        let expected = G::ZERO.as_canonical_u64() as usize & ((1 << bits) - 1);
        assert_eq!(sample, expected);
    }

    #[test]
    fn test_output_buffer_pops_correctly() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});

        // Exactly RATE observations trigger a duplexing step.
        for i in 0..RATE {
            chal.observe(G::from_u8(i as u8));
        }

        let expected = [
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(15),
            G::from_u8(14),
            G::from_u8(13),
            G::from_u8(12),
            G::from_u8(11),
            G::from_u8(10),
            G::from_u8(9),
            G::from_u8(8),
        ]
        .to_vec();

        assert_eq!(chal.output_buffer, expected);

        let first: G = chal.sample();
        let second: G = chal.sample();

        // Samples are popped from the end of the buffer.
        assert_eq!(first, G::from_u8(8));
        assert_eq!(second, G::from_u8(9));
    }

    #[test]
    fn test_duplexing_only_when_needed() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        chal.output_buffer = vec![G::from_u8(10), G::from_u8(20)];

        // No pending input and a non-empty output buffer: sample pops without duplexing.
        let sample: G = chal.sample();
        assert_eq!(sample, G::from_u8(20));
        assert_eq!(chal.output_buffer, vec![G::from_u8(10)]);
    }

    #[test]
    fn test_flush_when_input_full() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});

        for i in 0..RATE {
            chal.observe(G::from_u8(i as u8));
        }

        let expected_output = [
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(15),
            G::from_u8(14),
            G::from_u8(13),
            G::from_u8(12),
            G::from_u8(11),
            G::from_u8(10),
            G::from_u8(9),
            G::from_u8(8),
        ]
        .to_vec();

        // The full input buffer was absorbed and drained.
        assert!(chal.input_buffer.is_empty());

        assert_eq!(chal.output_buffer, expected_output);
    }

    #[test]
    fn test_observe_base_as_algebra_element_consistency_with_direct_observe() {
        let mut chal1 =
            DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        let mut chal2 =
            DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});

        let base_val = G::from_u8(99);

        chal1.observe_base_as_algebra_element::<EF2G>(base_val);

        let ext_val = EF2G::from(base_val);
        chal2.observe_algebra_element(ext_val);

        assert_eq!(chal1.input_buffer, chal2.input_buffer);
        assert_eq!(chal1.output_buffer, chal2.output_buffer);
        assert_eq!(chal1.sponge_state, chal2.sponge_state);
    }

    #[test]
    fn test_observe_base_as_algebra_element_stream_consistency() {
        let mut chal1 =
            DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        let mut chal2 =
            DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});

        let base_values: Vec<_> = (0u8..25).map(G::from_u8).collect();

        for &val in &base_values {
            chal1.observe_base_as_algebra_element::<EF2G>(val);
        }

        for &val in &base_values {
            let ext_val = EF2G::from(val);
            chal2.observe_algebra_element(ext_val);
        }

        assert_eq!(chal1.input_buffer, chal2.input_buffer);
        assert_eq!(chal1.output_buffer, chal2.output_buffer);
        assert_eq!(chal1.sponge_state, chal2.sponge_state);

        let sample1: EF2G = chal1.sample_algebra_element();
        let sample2: EF2G = chal2.sample_algebra_element();
        assert_eq!(sample1, sample2);

        assert_eq!(chal1.input_buffer, chal2.input_buffer);
        assert_eq!(chal1.output_buffer, chal2.output_buffer);
        assert_eq!(chal1.sponge_state, chal2.sponge_state);
    }

    #[test]
    fn test_observe_algebra_elements_equivalence() {
        let mut chal1 =
            DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        let mut chal2 =
            DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});

        let ext_values: Vec<EF2G> = (0u8..10).map(|i| EF2G::from(G::from_u8(i))).collect();

        chal1.observe_algebra_slice(&ext_values);

        for ext_val in &ext_values {
            chal2.observe_algebra_element(*ext_val);
        }

        assert_eq!(chal1.input_buffer, chal2.input_buffer);
        assert_eq!(chal1.output_buffer, chal2.output_buffer);
        assert_eq!(chal1.sponge_state, chal2.sponge_state);

        let sample1: EF2G = chal1.sample_algebra_element();
        let sample2: EF2G = chal2.sample_algebra_element();
        assert_eq!(sample1, sample2);
    }

    #[test]
    fn test_observe_algebra_elements_empty_slice() {
        let mut chal1 =
            DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        let mut chal2 =
            DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});

        chal1.observe(G::from_u8(42));
        chal2.observe(G::from_u8(42));

        // Observing an empty slice must leave the challenger untouched.
        let empty: Vec<EF2G> = vec![];
        chal1.observe_algebra_slice(&empty);

        assert_eq!(chal1.input_buffer, chal2.input_buffer);
        assert_eq!(chal1.output_buffer, chal2.output_buffer);
        assert_eq!(chal1.sponge_state, chal2.sponge_state);
    }

    #[test]
    fn test_observe_algebra_elements_triggers_duplexing() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});

        // 8 degree-2 elements = 16 = RATE base-field values, exactly one duplexing step.
        let ext_values: Vec<EF2G> = (0u8..8).map(|i| EF2G::from(G::from_u8(i))).collect();

        assert!(chal.input_buffer.is_empty());
        assert!(chal.output_buffer.is_empty());

        chal.observe_algebra_slice(&ext_values);

        assert!(chal.input_buffer.is_empty());
        assert!(!chal.output_buffer.is_empty());
    }
}
785}