Skip to main content

p3_symmetric/
sponge.rs

1use alloc::string::String;
2use core::marker::PhantomData;
3
4use itertools::Itertools;
5use p3_field::{
6    PrimeField, PrimeField32, SpongePaddingValue, absorb_radix_bits,
7    max_shifted_absorb_injective_limbs, reduce_packed_shifted,
8};
9
10use crate::hasher::CryptographicHasher;
11use crate::permutation::CryptographicPermutation;
12
/// A padding-free, overwrite-mode sponge function.
///
/// Input elements overwrite the first `RATE` slots of the state one block at a time, and the
/// permutation is applied after each full block and after a trailing partial block. Because no
/// padding is applied, a trailing partial block leaves the unused rate slots holding the previous
/// permutation output, so the construction is **not** collision-resistant for variable-length
/// inputs.
///
/// `WIDTH` is the sponge's rate plus the sponge's capacity; `OUT` is the number of state elements
/// squeezed (taken from the start of the state) as the digest.
#[derive(Copy, Clone, Debug)]
pub struct PaddingFreeSponge<P, const WIDTH: usize, const RATE: usize, const OUT: usize> {
    /// The permutation applied after each absorbed block.
    permutation: P,
}
20
21impl<P, const WIDTH: usize, const RATE: usize, const OUT: usize>
22    PaddingFreeSponge<P, WIDTH, RATE, OUT>
23{
24    pub const fn new(permutation: P) -> Self {
25        const {
26            assert!(RATE > 0);
27            assert!(RATE < WIDTH);
28            assert!(OUT <= WIDTH);
29        }
30
31        Self { permutation }
32    }
33}
34
35impl<T, P, const WIDTH: usize, const RATE: usize, const OUT: usize> CryptographicHasher<T, [T; OUT]>
36    for PaddingFreeSponge<P, WIDTH, RATE, OUT>
37where
38    T: Default + SpongePaddingValue,
39    P: CryptographicPermutation<[T; WIDTH]>,
40{
41    fn hash_iter<I>(&self, input: I) -> [T; OUT]
42    where
43        I: IntoIterator<Item = T>,
44    {
45        const {
46            assert!(RATE > 0);
47            assert!(RATE < WIDTH);
48            assert!(OUT <= WIDTH);
49        }
50        // Start from the all-zero state.
51        let mut state = [T::default(); WIDTH];
52        let mut input = input.into_iter();
53
54        'outer: loop {
55            // Absorb one block: overwrite state[0..RATE] with input elements one at a time.
56            for i in 0..RATE {
57                if let Some(x) = input.next() {
58                    // Overwrite the i-th rate position.
59                    state[i] = x;
60                } else {
61                    // Input exhausted mid-block. Permute only if at least
62                    // one element was absorbed in this block (i > 0).
63                    // If i == 0 the state already reflects the previous
64                    // permutation output and needs no extra call.
65                    if i != 0 {
66                        self.permutation.permute_mut(&mut state);
67                    }
68                    break 'outer;
69                }
70            }
71
72            // Full block absorbed. Permute before the next block.
73            self.permutation.permute_mut(&mut state);
74        }
75
76        // Squeeze: return the first OUT elements of the final state.
77        state[..OUT].try_into().unwrap()
78    }
79}
80
/// Padding-free, overwrite-mode sponge over a large prime field `PF`, accepting elements of a
/// 32-bit prime field `F` as input.
///
/// The input is split into consecutive groups of `num_f_elms` small-field elements (the last
/// group may be shorter). Each group is packed into one `PF` element with
/// `reduce_packed_shifted` and overwrites the next rate slot. The permutation is applied after
/// every `RATE` packed slots and after a trailing partial block.
///
/// `WIDTH` is the sponge's rate plus the sponge's capacity.
///
/// # Security
///
/// **Not** collision-resistant for variable-length inputs.
#[derive(Clone, Debug)]
pub struct MultiField32PaddingFreeSponge<
    F,
    PF,
    P,
    const WIDTH: usize,
    const RATE: usize,
    const OUT: usize,
> {
    /// The cryptographic permutation applied after each absorbed block.
    permutation: P,
    /// Number of small-field (`F`) limbs packed into each large-field (`PF`) rate slot, as
    /// computed by `max_shifted_absorb_injective_limbs::<F, PF>()` in the constructor.
    num_f_elms: usize,
    /// Radix for shifted packing into `PF`, expressed as a number of bits (the value returned
    /// by `absorb_radix_bits::<F>()`), not as the radix value itself.
    radix_bits: u32,
    _phantom: PhantomData<(F, PF)>,
}
103
104impl<F, PF, P, const WIDTH: usize, const RATE: usize, const OUT: usize>
105    MultiField32PaddingFreeSponge<F, PF, P, WIDTH, RATE, OUT>
106where
107    F: PrimeField32,
108    PF: PrimeField,
109{
110    pub fn new(permutation: P) -> Result<Self, String> {
111        const {
112            assert!(RATE > 0);
113            assert!(RATE < WIDTH);
114            assert!(OUT <= WIDTH);
115        }
116
117        if F::order() >= PF::order() {
118            return Err(String::from("F::order() must be less than PF::order()"));
119        }
120
121        // Use shifted-radix injective packing for robust absorb encoding.
122        let num_f_elms = max_shifted_absorb_injective_limbs::<F, PF>();
123        let radix_bits = absorb_radix_bits::<F>();
124        Ok(Self {
125            permutation,
126            num_f_elms,
127            radix_bits,
128            _phantom: PhantomData,
129        })
130    }
131}
132
133impl<F, PF, P, const WIDTH: usize, const RATE: usize, const OUT: usize>
134    CryptographicHasher<F, [PF; OUT]> for MultiField32PaddingFreeSponge<F, PF, P, WIDTH, RATE, OUT>
135where
136    F: PrimeField32,
137    PF: PrimeField + Default + Copy,
138    P: CryptographicPermutation<[PF; WIDTH]>,
139{
140    fn hash_iter<I>(&self, input: I) -> [PF; OUT]
141    where
142        I: IntoIterator<Item = F>,
143    {
144        const {
145            assert!(RATE > 0);
146            assert!(RATE < WIDTH);
147            assert!(OUT <= WIDTH);
148        }
149        let mut state = [PF::default(); WIDTH];
150
151        // Example: RATE = 3, num_f_elms = 2, input = [f0..f7]
152        //
153        //   block_chunk = [f0, f1, f2, f3, f4, f5]  (RATE * 2 = 6 small elems)
154        //     chunk 0: [f0, f1] -> pack into PF -> state[0]
155        //     chunk 1: [f2, f3] -> pack into PF -> state[1]
156        //     chunk 2: [f4, f5] -> pack into PF -> state[2]
157        //   -> permute
158        //
159        //   block_chunk = [f6, f7]  (partial)
160        //     chunk 0: [f6, f7] -> pack into PF -> state[0]
161        //   -> permute
162        for block_chunk in &input.into_iter().chunks(RATE * self.num_f_elms) {
163            for (chunk_id, chunk) in (&block_chunk.chunks(self.num_f_elms))
164                .into_iter()
165                .enumerate()
166            {
167                // Pack num_f_elms small-field elements into one large-field
168                // element via shifted-radix reduction.
169                state[chunk_id] = reduce_packed_shifted(&chunk.collect_vec(), self.radix_bits);
170            }
171            state = self.permutation.permute(state);
172        }
173
174        state[..OUT].try_into().unwrap()
175    }
176}
177
#[cfg(test)]
mod tests {
    use p3_field::{PrimeCharacteristicRing, absorb_radix_bits, reduce_32, reduce_packed_shifted};
    use p3_goldilocks::Goldilocks;
    use p3_koala_bear::KoalaBear;

    use super::*;
    use crate::Permutation;

    /// Test permutation that replaces every state element with the sum of all elements.
    #[derive(Clone)]
    struct MockPermutation;

    impl<T, const WIDTH: usize> Permutation<[T; WIDTH]> for MockPermutation
    where
        T: Copy + core::ops::Add<Output = T> + Default,
    {
        fn permute_mut(&self, input: &mut [T; WIDTH]) {
            let sum: T = input.iter().copied().fold(T::default(), |acc, x| acc + x);
            // Set every element to the sum
            *input = [sum; WIDTH];
        }
    }

    impl<T, const WIDTH: usize> CryptographicPermutation<[T; WIDTH]> for MockPermutation where
        T: Copy + core::ops::Add<Output = T> + Default
    {
    }

    /// Test permutation that leaves the state unchanged, so absorbed values can be read back.
    #[derive(Clone)]
    struct IdentityPermutation;

    impl<T: Clone, const WIDTH: usize> Permutation<[T; WIDTH]> for IdentityPermutation {
        fn permute_mut(&self, _input: &mut [T; WIDTH]) {}
    }

    impl<T: Clone, const WIDTH: usize> CryptographicPermutation<[T; WIDTH]> for IdentityPermutation {}

    #[test]
    fn test_padding_free_sponge_basic() {
        const WIDTH: usize = 4;
        const RATE: usize = 2;
        const OUT: usize = 2;

        let permutation = MockPermutation;
        let sponge = PaddingFreeSponge::<MockPermutation, WIDTH, RATE, OUT>::new(permutation);

        let input = [1, 2, 3, 4, 5];
        let output = sponge.hash_iter(input);

        // Explanation of why the final state results in [44, 44, 44, 44]:
        // Initial state: [0, 0, 0, 0]
        // First input chunk [1, 2] overwrites first two positions: [1, 2, 0, 0]
        // Apply permutation (sum all elements and overwrite): [3, 3, 3, 3]
        // Second input chunk [3, 4] overwrites first two positions: [3, 4, 3, 3]
        // Apply permutation: [13, 13, 13, 13] (3 + 4 + 3 + 3 = 13)
        // Third input chunk [5] overwrites first position: [5, 13, 13, 13]
        // Apply permutation: [44, 44, 44, 44] (5 + 13 + 13 + 13 = 44)

        assert_eq!(output, [44; OUT]);
    }

    #[test]
    fn test_padding_free_sponge_empty_input() {
        // Empty input: no elements absorbed, no permutation called.
        //
        // The initial all-zero state is returned directly.
        const WIDTH: usize = 4;
        const RATE: usize = 2;
        const OUT: usize = 2;

        let sponge = PaddingFreeSponge::<MockPermutation, WIDTH, RATE, OUT>::new(MockPermutation);

        let input: [u64; 0] = [];
        let output = sponge.hash_iter(input);

        // Squeeze from the untouched zero state.
        assert_eq!(output, [0; OUT]);
    }

    #[test]
    fn test_padding_free_sponge_exact_block_size() {
        // Exactly one full block: one permutation, no extra call for the (empty) next block.
        const WIDTH: usize = 6;
        const RATE: usize = 3;
        const OUT: usize = 2;

        let permutation = MockPermutation;
        let sponge = PaddingFreeSponge::<MockPermutation, WIDTH, RATE, OUT>::new(permutation);

        let input = [10, 20, 30];
        let output = sponge.hash_iter(input);

        assert_eq!(output, [60; OUT]);
    }

    #[test]
    fn test_multi_field32_padding_free_sponge_uses_absorb_radix() {
        const WIDTH: usize = 5;
        const RATE: usize = 4;
        const OUT: usize = 1;

        type F = KoalaBear;
        type PF = Goldilocks;

        let sponge =
            MultiField32PaddingFreeSponge::<F, PF, _, WIDTH, RATE, OUT>::new(IdentityPermutation)
                .unwrap();

        let input = [F::from_u32(1 << 30), F::ONE];
        let output = sponge.hash_iter(input);
        let expected = [reduce_packed_shifted::<F, PF>(
            &input,
            absorb_radix_bits::<F>(),
        )];

        assert_eq!(output, expected);
        // The shifted-radix packing must differ from the plain 32-bit reduction.
        assert_ne!(output[0], reduce_32::<F, PF>(&input));
    }

    #[test]
    fn test_multi_field32_padding_free_sponge_fills_full_pf_rate_rows() {
        const WIDTH: usize = 6;
        const RATE: usize = 5;
        const OUT: usize = 4;

        type F = KoalaBear;
        type PF = Goldilocks;

        let sponge =
            MultiField32PaddingFreeSponge::<F, PF, _, WIDTH, RATE, OUT>::new(IdentityPermutation)
                .unwrap();

        let input = core::array::from_fn::<_, 8, _>(|i| F::from_u32((i + 1) as u32));
        let radix_bits = absorb_radix_bits::<F>();
        let packed = [
            reduce_packed_shifted::<F, PF>(&input[0..2], radix_bits),
            reduce_packed_shifted::<F, PF>(&input[2..4], radix_bits),
            reduce_packed_shifted::<F, PF>(&input[4..6], radix_bits),
            reduce_packed_shifted::<F, PF>(&input[6..8], radix_bits),
        ];

        assert_eq!(sponge.num_f_elms, 2);
        assert_eq!(sponge.hash_iter(input), packed);
    }

    #[test]
    fn test_multi_field32_padding_free_sponge_distinguishes_trailing_zero_in_slot() {
        const WIDTH: usize = 2;
        const RATE: usize = 1;
        const OUT: usize = 1;

        type F = KoalaBear;
        type PF = Goldilocks;

        let sponge =
            MultiField32PaddingFreeSponge::<F, PF, _, WIDTH, RATE, OUT>::new(MockPermutation)
                .unwrap();

        assert_eq!(sponge.num_f_elms, 2);
        assert_ne!(
            sponge.hash_iter([F::ONE]),
            sponge.hash_iter([F::ONE, F::ZERO])
        );
    }

    #[test]
    fn test_multi_field32_padding_free_sponge_rejects_non_smaller_field() {
        // `new` requires F::order() < PF::order(); using the same field for both must fail.
        const WIDTH: usize = 4;
        const RATE: usize = 2;
        const OUT: usize = 1;

        type F = KoalaBear;

        let result =
            MultiField32PaddingFreeSponge::<F, F, _, WIDTH, RATE, OUT>::new(IdentityPermutation);

        assert_eq!(
            result.err().as_deref(),
            Some("F::order() must be less than PF::order()")
        );
    }
}