// p3_symmetric/compression.rs

use core::array;

use crate::hasher::CryptographicHasher;
use crate::permutation::CryptographicPermutation;

/// An `N`-to-1 compression function that is collision-resistant in a hash-tree setting.
///
/// Unlike [`CompressionFunction`], it may not be collision-resistant in general.
/// Instead, it is only collision-resistant in hash-tree-like settings where
/// the preimage of a non-leaf node must consist of compression outputs.
pub trait PseudoCompressionFunction<T, const N: usize>: Clone {
    /// Compresses `N` values of type `T` into a single `T`.
    fn compress(&self, input: [T; N]) -> T;
}

13/// An `N`-to-1 compression function.
14pub trait CompressionFunction<T, const N: usize>: PseudoCompressionFunction<T, N> {}
15
/// A compression function obtained by permuting a padded state and truncating the result.
///
/// When used as a [`PseudoCompressionFunction`], the `N` input chunks (each `CHUNK`
/// elements long) are laid out back to back at the start of a `WIDTH`-element state,
/// the remaining slots are filled with `T::default()`, the state is passed through
/// `InnerP`, and the first `CHUNK` elements of the permuted state are returned.
#[derive(Clone, Debug)]
pub struct TruncatedPermutation<InnerP, const N: usize, const CHUNK: usize, const WIDTH: usize> {
    /// The permutation applied to the padded state.
    inner_permutation: InnerP,
}

21impl<InnerP, const N: usize, const CHUNK: usize, const WIDTH: usize>
22    TruncatedPermutation<InnerP, N, CHUNK, WIDTH>
23{
24    pub const fn new(inner_permutation: InnerP) -> Self {
25        Self { inner_permutation }
26    }
27}
28
29impl<T, InnerP, const N: usize, const CHUNK: usize, const WIDTH: usize>
30    PseudoCompressionFunction<[T; CHUNK], N> for TruncatedPermutation<InnerP, N, CHUNK, WIDTH>
31where
32    T: Copy + Default,
33    InnerP: CryptographicPermutation<[T; WIDTH]>,
34{
35    fn compress(&self, input: [[T; CHUNK]; N]) -> [T; CHUNK] {
36        const {
37            assert!(N > 0);
38            assert!(CHUNK > 0);
39            assert!(CHUNK * N <= WIDTH);
40        }
41        let mut pre = [T::default(); WIDTH];
42        for i in 0..N {
43            pre[i * CHUNK..(i + 1) * CHUNK].copy_from_slice(&input[i]);
44        }
45        let post = self.inner_permutation.permute(pre);
46        post[..CHUNK].try_into().unwrap()
47    }
48}
49
/// A compression function built from a hasher.
///
/// When used as a compression function, the `N` input chunks of `CHUNK` elements are
/// flattened, in order, into one stream of `N * CHUNK` elements, which `H` hashes into a
/// single `CHUNK`-element output.
#[derive(Clone, Debug)]
pub struct CompressionFunctionFromHasher<H, const N: usize, const CHUNK: usize> {
    /// The hasher applied to the flattened input.
    hasher: H,
}

55impl<H, const N: usize, const CHUNK: usize> CompressionFunctionFromHasher<H, N, CHUNK> {
56    pub const fn new(hasher: H) -> Self {
57        Self { hasher }
58    }
59}
60
61impl<T, H, const N: usize, const CHUNK: usize> PseudoCompressionFunction<[T; CHUNK], N>
62    for CompressionFunctionFromHasher<H, N, CHUNK>
63where
64    T: Clone,
65    H: CryptographicHasher<T, [T; CHUNK]>,
66{
67    fn compress(&self, input: [[T; CHUNK]; N]) -> [T; CHUNK] {
68        self.hasher.hash_iter(input.into_iter().flatten())
69    }
70}
71
72impl<T, H, const N: usize, const CHUNK: usize> CompressionFunction<[T; CHUNK], N>
73    for CompressionFunctionFromHasher<H, N, CHUNK>
74where
75    T: Clone,
76    H: CryptographicHasher<T, [T; CHUNK]>,
77{
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Permutation;

    /// Sets every state element to the sum of all elements.
    ///
    /// Symmetric in its input, so it checks arithmetic but not element placement.
    #[derive(Clone)]
    struct MockPermutation;

    impl<T, const WIDTH: usize> Permutation<[T; WIDTH]> for MockPermutation
    where
        T: Copy + core::ops::Add<Output = T> + Default,
    {
        fn permute_mut(&self, input: &mut [T; WIDTH]) {
            let sum: T = input.iter().copied().fold(T::default(), |acc, x| acc + x);
            // Simplest impl: set every element to the sum
            *input = [sum; WIDTH];
        }
    }

    impl<T, const WIDTH: usize> CryptographicPermutation<[T; WIDTH]> for MockPermutation where
        T: Copy + core::ops::Add<Output = T> + Default
    {
    }

    /// Reverses the state, so the truncated output reveals where each input
    /// element and padding slot was placed.
    #[derive(Clone)]
    struct ReversePermutation;

    impl<T: Copy, const WIDTH: usize> Permutation<[T; WIDTH]> for ReversePermutation {
        fn permute_mut(&self, input: &mut [T; WIDTH]) {
            input.reverse();
        }
    }

    impl<T: Copy, const WIDTH: usize> CryptographicPermutation<[T; WIDTH]> for ReversePermutation {}

    /// Outputs the sum of all input elements in every slot (order-insensitive).
    #[derive(Clone)]
    struct MockHasher;

    impl<const CHUNK: usize> CryptographicHasher<u64, [u64; CHUNK]> for MockHasher {
        fn hash_iter<I: IntoIterator<Item = u64>>(&self, iter: I) -> [u64; CHUNK] {
            let sum: u64 = iter.into_iter().sum();
            // Simplest impl: set every element to the sum
            [sum; CHUNK]
        }
    }

    /// Reads the input as base-10 digits, so the output depends on element order.
    #[derive(Clone)]
    struct PositionalHasher;

    impl<const CHUNK: usize> CryptographicHasher<u64, [u64; CHUNK]> for PositionalHasher {
        fn hash_iter<I: IntoIterator<Item = u64>>(&self, iter: I) -> [u64; CHUNK] {
            let digits: u64 = iter.into_iter().fold(0, |acc, x| acc * 10 + x);
            [digits; CHUNK]
        }
    }

    #[test]
    fn test_truncated_permutation_compress() {
        const N: usize = 2;
        const CHUNK: usize = 4;
        const WIDTH: usize = 8;

        let permutation = MockPermutation;
        let compressor = TruncatedPermutation::<MockPermutation, N, CHUNK, WIDTH>::new(permutation);

        let input: [[u64; CHUNK]; N] = [[1, 2, 3, 4], [5, 6, 7, 8]];
        let output = compressor.compress(input);
        let expected_sum = 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8;

        assert_eq!(output, [expected_sum; CHUNK]);
    }

    #[test]
    fn test_compression_function_from_hasher_compress() {
        const N: usize = 2;
        const CHUNK: usize = 4;

        let hasher = MockHasher;
        let compressor = CompressionFunctionFromHasher::<MockHasher, N, CHUNK>::new(hasher);

        let input = [[10, 20, 30, 40], [50, 60, 70, 80]];
        let output = compressor.compress(input);
        let expected_sum = 10 + 20 + 30 + 40 + 50 + 60 + 70 + 80;

        assert_eq!(output, [expected_sum; CHUNK]);
    }

    #[test]
    fn test_truncated_permutation_with_zeros() {
        const N: usize = 2;
        const CHUNK: usize = 4;
        const WIDTH: usize = 8;

        let permutation = MockPermutation;
        let compressor = TruncatedPermutation::<MockPermutation, N, CHUNK, WIDTH>::new(permutation);

        let input: [[u64; CHUNK]; N] = [[0, 0, 0, 0], [0, 0, 0, 0]];
        let output = compressor.compress(input);

        assert_eq!(output, [0; CHUNK]);
    }

    #[test]
    fn test_truncated_permutation_with_extra_width() {
        const N: usize = 2;
        const CHUNK: usize = 3;
        const WIDTH: usize = 10; // More than `CHUNK * N` (6 < 10)

        let permutation = MockPermutation;
        let compressor = TruncatedPermutation::<MockPermutation, N, CHUNK, WIDTH>::new(permutation);

        let input: [[u64; CHUNK]; N] = [[1, 2, 3], [4, 5, 6]];
        let output = compressor.compress(input);

        let expected_sum = 1 + 2 + 3 + 4 + 5 + 6;

        assert_eq!(
            output, [expected_sum; CHUNK],
            "Compression should correctly handle extra WIDTH space."
        );
    }

    #[test]
    fn test_truncated_permutation_layout_and_padding() {
        const N: usize = 2;
        const CHUNK: usize = 2;
        const WIDTH: usize = 5;

        let compressor =
            TruncatedPermutation::<ReversePermutation, N, CHUNK, WIDTH>::new(ReversePermutation);

        // The state is [1, 2, 3, 4, 0]; reversed it is [0, 4, 3, 2, 1].
        let input: [[u64; CHUNK]; N] = [[1, 2], [3, 4]];
        assert_eq!(compressor.compress(input), [0, 4]);
    }

    #[test]
    fn test_compression_function_from_hasher_preserves_order() {
        const N: usize = 2;
        const CHUNK: usize = 2;

        let compressor =
            CompressionFunctionFromHasher::<PositionalHasher, N, CHUNK>::new(PositionalHasher);

        let input: [[u64; CHUNK]; N] = [[1, 2], [3, 4]];
        assert_eq!(compressor.compress(input), [1234; CHUNK]);
    }
}