p3_symmetric/
serializing_hasher.rs

1use p3_field::Field;
2
3use crate::CryptographicHasher;
4
/// Adapts a hasher over bytes, `u32`s or `u64`s into a hasher over field elements.
///
/// Field elements are first converted into a stream of the inner hasher's word type, and that
/// stream is then hashed by the inner hasher, whose digest is returned unchanged.
///
/// Supports two types of hashing:
/// - Hashing a sequence of field elements, producing the inner hasher's `[W; N]` digest,
///   where the word type `W` is `u8`, `u32` or `u64`.
/// - Hashing a sequence of arrays of `M` field elements as if we are hashing `M` sequences of
///   field elements in parallel (lane `j` of every array belongs to sequence `j`), producing
///   the inner hasher's `[[W; M]; N]` digest.
///   This is useful when the inner hash is able to use vectorized instructions to compute
///   multiple hashes at once.
#[derive(Copy, Clone, Debug)]
pub struct SerializingHasher<Inner> {
    /// The wrapped hasher that consumes the serialized word stream.
    inner: Inner,
}
15
16impl<Inner> SerializingHasher<Inner> {
17    pub const fn new(inner: Inner) -> Self {
18        Self { inner }
19    }
20}
21
22impl<F, Inner, const N: usize> CryptographicHasher<F, [u8; N]> for SerializingHasher<Inner>
23where
24    F: Field,
25    Inner: CryptographicHasher<u8, [u8; N]>,
26{
27    fn hash_iter<I>(&self, input: I) -> [u8; N]
28    where
29        I: IntoIterator<Item = F>,
30    {
31        self.inner.hash_iter(F::into_byte_stream(input))
32    }
33}
34
35impl<F, Inner, const N: usize> CryptographicHasher<F, [u32; N]> for SerializingHasher<Inner>
36where
37    F: Field,
38    Inner: CryptographicHasher<u32, [u32; N]>,
39{
40    fn hash_iter<I>(&self, input: I) -> [u32; N]
41    where
42        I: IntoIterator<Item = F>,
43    {
44        self.inner.hash_iter(F::into_u32_stream(input))
45    }
46}
47
48impl<F, Inner, const N: usize> CryptographicHasher<F, [u64; N]> for SerializingHasher<Inner>
49where
50    F: Field,
51    Inner: CryptographicHasher<u64, [u64; N]>,
52{
53    fn hash_iter<I>(&self, input: I) -> [u64; N]
54    where
55        I: IntoIterator<Item = F>,
56    {
57        self.inner.hash_iter(F::into_u64_stream(input))
58    }
59}
60
61impl<F, Inner, const N: usize, const M: usize> CryptographicHasher<[F; M], [[u8; M]; N]>
62    for SerializingHasher<Inner>
63where
64    F: Field,
65    Inner: CryptographicHasher<[u8; M], [[u8; M]; N]>,
66{
67    fn hash_iter<I>(&self, input: I) -> [[u8; M]; N]
68    where
69        I: IntoIterator<Item = [F; M]>,
70    {
71        self.inner.hash_iter(F::into_parallel_byte_streams(input))
72    }
73}
74
75impl<F, Inner, const N: usize, const M: usize> CryptographicHasher<[F; M], [[u32; M]; N]>
76    for SerializingHasher<Inner>
77where
78    F: Field,
79    Inner: CryptographicHasher<[u32; M], [[u32; M]; N]>,
80{
81    fn hash_iter<I>(&self, input: I) -> [[u32; M]; N]
82    where
83        I: IntoIterator<Item = [F; M]>,
84    {
85        self.inner.hash_iter(F::into_parallel_u32_streams(input))
86    }
87}
88
89impl<F, Inner, const N: usize, const M: usize> CryptographicHasher<[F; M], [[u64; M]; N]>
90    for SerializingHasher<Inner>
91where
92    F: Field,
93    Inner: CryptographicHasher<[u64; M], [[u64; M]; N]>,
94{
95    fn hash_iter<I>(&self, input: I) -> [[u64; M]; N]
96    where
97        I: IntoIterator<Item = [F; M]>,
98    {
99        self.inner.hash_iter(F::into_parallel_u64_streams(input))
100    }
101}
102
#[cfg(test)]
mod tests {
    use core::array;

    use p3_koala_bear::KoalaBear;

    use crate::{CryptographicHasher, SerializingHasher};

    /// Toy inner hasher: every digest word is the wrapping sum of all input words.
    ///
    /// Not a real hash, but deterministic and sensitive to every input word, which is enough
    /// to check that `SerializingHasher`'s parallel path agrees with hashing each lane alone.
    #[derive(Clone)]
    struct MockHasher;

    /// Transposes an `R x C` matrix into a `C x R` matrix.
    fn transpose<T: Copy, const R: usize, const C: usize>(m: [[T; C]; R]) -> [[T; R]; C] {
        array::from_fn(|c| m.map(|row| row[c]))
    }

    impl CryptographicHasher<u8, [u8; 4]> for MockHasher {
        fn hash_iter<I: IntoIterator<Item = u8>>(&self, iter: I) -> [u8; 4] {
            let sum = iter.into_iter().fold(0u8, u8::wrapping_add);
            // Simplest impl: set every element to the sum.
            [sum; 4]
        }
    }

    impl CryptographicHasher<[u8; 4], [[u8; 4]; 4]> for MockHasher {
        fn hash_iter<I: IntoIterator<Item = [u8; 4]>>(&self, iter: I) -> [[u8; 4]; 4] {
            // Sum each lane independently, mirroring the scalar impl lane by lane.
            let sum = iter
                .into_iter()
                .fold([0u8; 4], |acc, x| array::from_fn(|i| acc[i].wrapping_add(x[i])));
            [sum; 4]
        }
    }

    impl CryptographicHasher<u32, [u32; 4]> for MockHasher {
        fn hash_iter<I: IntoIterator<Item = u32>>(&self, iter: I) -> [u32; 4] {
            let sum = iter.into_iter().fold(0u32, u32::wrapping_add);
            // Simplest impl: set every element to the sum.
            [sum; 4]
        }
    }

    impl CryptographicHasher<[u32; 4], [[u32; 4]; 4]> for MockHasher {
        fn hash_iter<I: IntoIterator<Item = [u32; 4]>>(&self, iter: I) -> [[u32; 4]; 4] {
            // Sum each lane independently, mirroring the scalar impl lane by lane.
            let sum = iter
                .into_iter()
                .fold([0u32; 4], |acc, x| array::from_fn(|i| acc[i].wrapping_add(x[i])));
            [sum; 4]
        }
    }

    impl CryptographicHasher<u64, [u64; 4]> for MockHasher {
        fn hash_iter<I: IntoIterator<Item = u64>>(&self, iter: I) -> [u64; 4] {
            let sum = iter.into_iter().fold(0u64, u64::wrapping_add);
            // Simplest impl: set every element to the sum.
            [sum; 4]
        }
    }

    impl CryptographicHasher<[u64; 4], [[u64; 4]; 4]> for MockHasher {
        fn hash_iter<I: IntoIterator<Item = [u64; 4]>>(&self, iter: I) -> [[u64; 4]; 4] {
            // Sum each lane independently, mirroring the scalar impl lane by lane.
            let sum = iter
                .into_iter()
                .fold([0u64; 4], |acc, x| array::from_fn(|i| acc[i].wrapping_add(x[i])));
            [sum; 4]
        }
    }

    /// Hashing 4 lanes in parallel must match hashing each lane on its own and transposing
    /// the per-lane digests into the `[word][lane]` layout of the parallel output.
    #[test]
    fn test_parallel_hashers() {
        let hasher = SerializingHasher::new(MockHasher);
        let input: [KoalaBear; 256] = KoalaBear::new_array(array::from_fn(|x| x as u32));

        // Row `i` holds elements `4 * i .. 4 * i + 4` (row-major), built without `unsafe`.
        let parallel_input: [[KoalaBear; 4]; 64] =
            array::from_fn(|i| array::from_fn(|j| input[4 * i + j]));
        // Lane `j` of the parallel input as a standalone sequence of 64 elements.
        let unzipped_input: [[KoalaBear; 64]; 4] = transpose(parallel_input);

        let u8_output_parallel: [[u8; 4]; 4] = hasher.hash_iter(parallel_input);
        let u8_output_individual: [[u8; 4]; 4] = unzipped_input.map(|x| hasher.hash_iter(x));
        assert_eq!(u8_output_parallel, transpose(u8_output_individual));

        let u32_output_parallel: [[u32; 4]; 4] = hasher.hash_iter(parallel_input);
        let u32_output_individual: [[u32; 4]; 4] = unzipped_input.map(|x| hasher.hash_iter(x));
        assert_eq!(u32_output_parallel, transpose(u32_output_individual));

        let u64_output_parallel: [[u64; 4]; 4] = hasher.hash_iter(parallel_input);
        let u64_output_individual: [[u64; 4]; 4] = unzipped_input.map(|x| hasher.hash_iter(x));
        assert_eq!(u64_output_parallel, transpose(u64_output_individual));
    }
}