p3_symmetric/
compression.rs

1use crate::hasher::CryptographicHasher;
2use crate::permutation::CryptographicPermutation;
3
/// An `N`-to-1 compression function collision-resistant in a hash tree setting.
///
/// Unlike [`CompressionFunction`], it may not be collision-resistant in general.
/// Instead it is only collision-resistant in hash-tree-like settings where
/// the preimage of a non-leaf node must consist of compression outputs.
pub trait PseudoCompressionFunction<T, const N: usize>: Clone {
    /// Compresses the `N` values in `input` into a single value of the same type `T`.
    fn compress(&self, input: [T; N]) -> T;
}
12
/// An `N`-to-1 compression function.
///
/// This trait adds no methods to [`PseudoCompressionFunction`]; it is a marker. By
/// implementing it, a type promises that `compress` is collision-resistant on arbitrary
/// inputs, not only in the hash-tree setting that the pseudo variant is restricted to.
pub trait CompressionFunction<T, const N: usize>: PseudoCompressionFunction<T, N> {}
15
/// A [`PseudoCompressionFunction`] derived from a [`CryptographicPermutation`] over a
/// `WIDTH`-element state.
///
/// The `N` input chunks of `CHUNK` elements each are written into the state. The state is
/// permuted once, and the first `CHUNK` elements of the result form the output.
#[derive(Clone, Debug)]
pub struct TruncatedPermutation<InnerP, const N: usize, const CHUNK: usize, const WIDTH: usize> {
    inner_permutation: InnerP,
}

impl<InnerP, const N: usize, const CHUNK: usize, const WIDTH: usize>
    TruncatedPermutation<InnerP, N, CHUNK, WIDTH>
{
    /// Wraps `inner_permutation`.
    ///
    /// The arity `N`, chunk size `CHUNK` and state width `WIDTH` are fixed by the type
    /// parameters, so no further configuration is needed.
    pub const fn new(inner_permutation: InnerP) -> Self {
        TruncatedPermutation { inner_permutation }
    }
}
28
29impl<T, InnerP, const N: usize, const CHUNK: usize, const WIDTH: usize>
30    PseudoCompressionFunction<[T; CHUNK], N> for TruncatedPermutation<InnerP, N, CHUNK, WIDTH>
31where
32    T: Copy + Default,
33    InnerP: CryptographicPermutation<[T; WIDTH]>,
34{
35    fn compress(&self, input: [[T; CHUNK]; N]) -> [T; CHUNK] {
36        debug_assert!(CHUNK * N <= WIDTH);
37        let mut pre = [T::default(); WIDTH];
38        for i in 0..N {
39            pre[i * CHUNK..(i + 1) * CHUNK].copy_from_slice(&input[i]);
40        }
41        let post = self.inner_permutation.permute(pre);
42        post[..CHUNK].try_into().unwrap()
43    }
44}
45
/// Adapts a [`CryptographicHasher`] into an `N`-to-1 compression function over
/// `CHUNK`-element digests.
///
/// The `N` input digests are concatenated and hashed as a single stream. The hasher's
/// `CHUNK`-element output is the compressed digest.
#[derive(Clone, Debug)]
pub struct CompressionFunctionFromHasher<H, const N: usize, const CHUNK: usize> {
    hasher: H,
}

impl<H, const N: usize, const CHUNK: usize> CompressionFunctionFromHasher<H, N, CHUNK> {
    /// Wraps `hasher`; the arity `N` and digest size `CHUNK` come from the type parameters.
    pub const fn new(hasher: H) -> Self {
        CompressionFunctionFromHasher { hasher }
    }
}
56
57impl<T, H, const N: usize, const CHUNK: usize> PseudoCompressionFunction<[T; CHUNK], N>
58    for CompressionFunctionFromHasher<H, N, CHUNK>
59where
60    T: Clone,
61    H: CryptographicHasher<T, [T; CHUNK]>,
62{
63    fn compress(&self, input: [[T; CHUNK]; N]) -> [T; CHUNK] {
64        self.hasher.hash_iter(input.into_iter().flatten())
65    }
66}
67
68impl<T, H, const N: usize, const CHUNK: usize> CompressionFunction<[T; CHUNK], N>
69    for CompressionFunctionFromHasher<H, N, CHUNK>
70where
71    T: Clone,
72    H: CryptographicHasher<T, [T; CHUNK]>,
73{
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Permutation;

    /// Toy permutation that replaces every state element with the sum of all elements.
    ///
    /// It is order-insensitive: it shows that every input element reaches the state, but not
    /// where.
    #[derive(Clone)]
    struct MockPermutation;

    impl<T, const WIDTH: usize> Permutation<[T; WIDTH]> for MockPermutation
    where
        T: Copy + core::ops::Add<Output = T> + Default,
    {
        fn permute_mut(&self, input: &mut [T; WIDTH]) {
            let sum: T = input.iter().copied().fold(T::default(), |acc, x| acc + x);
            // Simplest impl: set every element to the sum
            *input = [sum; WIDTH];
        }
    }

    impl<T, const WIDTH: usize> CryptographicPermutation<[T; WIDTH]> for MockPermutation where
        T: Copy + core::ops::Add<Output = T> + Default
    {
    }

    /// Toy permutation that reverses the state.
    ///
    /// It is position-sensitive: it shows where each input chunk is placed and which end of
    /// the permuted state is kept.
    #[derive(Clone)]
    struct ReversePermutation;

    impl<T: Copy, const WIDTH: usize> Permutation<[T; WIDTH]> for ReversePermutation {
        fn permute_mut(&self, input: &mut [T; WIDTH]) {
            input.reverse();
        }
    }

    impl<T: Copy, const WIDTH: usize> CryptographicPermutation<[T; WIDTH]> for ReversePermutation {}

    #[derive(Clone)]
    struct MockHasher;

    impl<const CHUNK: usize> CryptographicHasher<u64, [u64; CHUNK]> for MockHasher {
        fn hash_iter<I: IntoIterator<Item = u64>>(&self, iter: I) -> [u64; CHUNK] {
            let sum: u64 = iter.into_iter().sum();
            // Simplest impl: set every element to the sum
            [sum; CHUNK]
        }
    }

    #[test]
    fn test_truncated_permutation_compress() {
        const N: usize = 2;
        const CHUNK: usize = 4;
        const WIDTH: usize = 8;

        let permutation = MockPermutation;
        let compressor = TruncatedPermutation::<MockPermutation, N, CHUNK, WIDTH>::new(permutation);

        let input: [[u64; CHUNK]; N] = [[1, 2, 3, 4], [5, 6, 7, 8]];
        let output = compressor.compress(input);
        let expected_sum = 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8;

        assert_eq!(output, [expected_sum; CHUNK]);
    }

    #[test]
    fn test_compression_function_from_hasher_compress() {
        const N: usize = 2;
        const CHUNK: usize = 4;

        let hasher = MockHasher;
        let compressor = CompressionFunctionFromHasher::<MockHasher, N, CHUNK>::new(hasher);

        let input = [[10, 20, 30, 40], [50, 60, 70, 80]];
        let output = compressor.compress(input);
        let expected_sum = 10 + 20 + 30 + 40 + 50 + 60 + 70 + 80;

        assert_eq!(output, [expected_sum; CHUNK]);
    }

    #[test]
    fn test_truncated_permutation_with_zeros() {
        const N: usize = 2;
        const CHUNK: usize = 4;
        const WIDTH: usize = 8;

        let permutation = MockPermutation;
        let compressor = TruncatedPermutation::<MockPermutation, N, CHUNK, WIDTH>::new(permutation);

        let input: [[u64; CHUNK]; N] = [[0, 0, 0, 0], [0, 0, 0, 0]];
        let output = compressor.compress(input);

        assert_eq!(output, [0; CHUNK]);
    }

    #[test]
    fn test_truncated_permutation_with_extra_width() {
        const N: usize = 2;
        const CHUNK: usize = 3;
        const WIDTH: usize = 10; // More than `CHUNK * N` (6 < 10)

        let permutation = MockPermutation;
        let compressor = TruncatedPermutation::<MockPermutation, N, CHUNK, WIDTH>::new(permutation);

        let input: [[u64; CHUNK]; N] = [[1, 2, 3], [4, 5, 6]];
        let output = compressor.compress(input);

        let expected_sum = 1 + 2 + 3 + 4 + 5 + 6;

        assert_eq!(
            output, [expected_sum; CHUNK],
            "Compression should correctly handle extra WIDTH space."
        );
    }

    #[test]
    fn test_truncated_permutation_chunk_layout() {
        const N: usize = 2;
        const CHUNK: usize = 4;
        const WIDTH: usize = 8;

        let compressor =
            TruncatedPermutation::<ReversePermutation, N, CHUNK, WIDTH>::new(ReversePermutation);

        // State before permuting: [1, 2, 3, 4, 5, 6, 7, 8]; after reversing:
        // [8, 7, 6, 5, 4, 3, 2, 1]. Only the first CHUNK elements are kept.
        let input: [[u64; CHUNK]; N] = [[1, 2, 3, 4], [5, 6, 7, 8]];

        assert_eq!(compressor.compress(input), [8, 7, 6, 5]);
    }

    #[test]
    fn test_truncated_permutation_default_padding() {
        const N: usize = 2;
        const CHUNK: usize = 3;
        const WIDTH: usize = 7; // One padding slot after the two chunks.

        let compressor =
            TruncatedPermutation::<ReversePermutation, N, CHUNK, WIDTH>::new(ReversePermutation);

        // State before permuting: [1, 2, 3, 4, 5, 6, 0]; after reversing:
        // [0, 6, 5, 4, 3, 2, 1]. The padding slot must hold `u64::default()`.
        let input: [[u64; CHUNK]; N] = [[1, 2, 3], [4, 5, 6]];

        assert_eq!(compressor.compress(input), [0, 6, 5]);
    }

    #[test]
    #[should_panic]
    fn test_truncated_permutation_width_too_small_panics() {
        const N: usize = 2;
        const CHUNK: usize = 4;
        const WIDTH: usize = 6; // `CHUNK * N` = 8 does not fit.

        let compressor =
            TruncatedPermutation::<MockPermutation, N, CHUNK, WIDTH>::new(MockPermutation);
        let _ = compressor.compress([[1u64; CHUNK]; N]);
    }
}