1use p3_field::PrimeCharacteristicRing;
17use p3_poseidon2::{
18 ExternalLayer, GenericPoseidon2LinearLayers, InternalLayer, MDSMat4, Poseidon2,
19 add_rc_and_sbox_generic, external_initial_permute_state, external_terminal_permute_state,
20 internal_permute_state,
21};
22
23use crate::{
24 Mersenne31, Poseidon2ExternalLayerMersenne31, Poseidon2InternalLayerMersenne31, from_u62,
25};
26
/// Degree `D` of the Poseidon2 S-box monomial `x -> x^D` used over Mersenne31.
///
/// This value is supplied as the degree const-generic argument of
/// [`Poseidon2Mersenne31`] and of the `InternalLayer` / `ExternalLayer`
/// implementations in this file.
///
/// NOTE(review): for `p = 2^31 - 1` we have `p - 1 = 2 * 3^2 * 7 * 11 * 31 * 151 * 331`,
/// so `5` is the smallest `D > 1` with `gcd(D, p - 1) = 1`; presumably that is
/// why it was chosen — confirm against the field definition.
pub(crate) const MERSENNE31_S_BOX_DEGREE: u64 = 5;
33
/// The Poseidon2 permutation instantiated for the Mersenne31 field.
///
/// The alias fixes the field to [`Mersenne31`], the external layer to
/// [`Poseidon2ExternalLayerMersenne31`], the internal layer to
/// [`Poseidon2InternalLayerMersenne31`] and the S-box degree to
/// [`MERSENNE31_S_BOX_DEGREE`]; only the state `WIDTH` remains generic.
///
/// This file implements `InternalLayer` for `WIDTH = 16` and `WIDTH = 24` only.
pub type Poseidon2Mersenne31<const WIDTH: usize> = Poseidon2<
    Mersenne31,
    Poseidon2ExternalLayerMersenne31<WIDTH>,
    Poseidon2InternalLayerMersenne31,
    WIDTH,
    MERSENNE31_S_BOX_DEGREE,
>;
45
/// Field-less marker type carrying the generic Poseidon2 linear layers for Mersenne31.
///
/// It implements `GenericPoseidon2LinearLayers` for widths 16 and 24 in this
/// file; the internal linear layer it provides works on `[R; WIDTH]` for any
/// `R: PrimeCharacteristicRing`, using only ring additions, doublings and
/// multiplications by powers of two.
pub struct GenericPoseidon2LinearLayersMersenne31 {}
52
/// Shift amounts defining the width-16 internal diffusion matrix `1 + Diag(V)`.
///
/// Entry `i - 1` is the exponent `k` such that `V[i] = 2^k` for state index
/// `i >= 1`; the entry for index 0 (`V[0] = -2`) is handled separately by the
/// code and is therefore not stored, which is why the table has 15 entries.
const POSEIDON2_INTERNAL_MATRIX_DIAG_16_SHIFTS: [u8; 15] =
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14, 15, 16];
55
/// Shift amounts defining the width-24 internal diffusion matrix `1 + Diag(V)`.
///
/// Entry `i - 1` is the exponent `k` such that `V[i] = 2^k` for state index
/// `i >= 1`; the entry for index 0 (`V[0] = -2`) is handled separately by the
/// code and is therefore not stored, which is why the table has 23 entries.
const POSEIDON2_INTERNAL_MATRIX_DIAG_24_SHIFTS: [u8; 23] = [
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
];
59
60fn permute_mut<const N: usize>(state: &mut [Mersenne31; N], shifts: &[u8]) {
64 debug_assert_eq!(shifts.len() + 1, N);
65 let part_sum: u64 = state[1..].iter().map(|x| x.value as u64).sum();
66 let full_sum = part_sum + (state[0].value as u64);
67 let s0 = part_sum + (-state[0]).value as u64;
68 state[0] = from_u62(s0);
69 for i in 1..N {
70 let si = full_sum + ((state[i].value as u64) << shifts[i - 1]);
71 state[i] = from_u62(si);
72 }
73}
74
75impl InternalLayer<Mersenne31, 16, MERSENNE31_S_BOX_DEGREE> for Poseidon2InternalLayerMersenne31 {
76 fn permute_state(&self, state: &mut [Mersenne31; 16]) {
78 internal_permute_state(
79 state,
80 |x| permute_mut(x, &POSEIDON2_INTERNAL_MATRIX_DIAG_16_SHIFTS),
81 &self.internal_constants,
82 );
83 }
84}
85
86impl InternalLayer<Mersenne31, 24, MERSENNE31_S_BOX_DEGREE> for Poseidon2InternalLayerMersenne31 {
87 fn permute_state(&self, state: &mut [Mersenne31; 24]) {
89 internal_permute_state(
90 state,
91 |x| permute_mut(x, &POSEIDON2_INTERNAL_MATRIX_DIAG_24_SHIFTS),
92 &self.internal_constants,
93 );
94 }
95}
96
97impl<const WIDTH: usize> ExternalLayer<Mersenne31, WIDTH, MERSENNE31_S_BOX_DEGREE>
98 for Poseidon2ExternalLayerMersenne31<WIDTH>
99{
100 fn permute_state_initial(&self, state: &mut [Mersenne31; WIDTH]) {
102 external_initial_permute_state(
103 state,
104 self.external_constants.get_initial_constants(),
105 add_rc_and_sbox_generic,
106 &MDSMat4,
107 );
108 }
109
110 fn permute_state_terminal(&self, state: &mut [Mersenne31; WIDTH]) {
112 external_terminal_permute_state(
113 state,
114 self.external_constants.get_terminal_constants(),
115 add_rc_and_sbox_generic,
116 &MDSMat4,
117 );
118 }
119}
120
121impl GenericPoseidon2LinearLayers<16> for GenericPoseidon2LinearLayersMersenne31 {
122 fn internal_linear_layer<R: PrimeCharacteristicRing>(state: &mut [R; 16]) {
123 let part_sum: R = state[1..].iter().cloned().sum();
124 let full_sum = part_sum.clone() + state[0].clone();
125
126 state[0] = part_sum - state[0].clone();
128 state[1] = full_sum.clone() + state[1].clone();
129 state[2] = full_sum.clone() + state[2].double();
130
131 state[1..]
135 .iter_mut()
136 .zip(POSEIDON2_INTERNAL_MATRIX_DIAG_16_SHIFTS)
137 .skip(2)
138 .for_each(|(val, diag_shift)| {
139 *val = full_sum.clone() + val.clone().mul_2exp_u64(diag_shift as u64);
140 });
141 }
142}
143
144impl GenericPoseidon2LinearLayers<24> for GenericPoseidon2LinearLayersMersenne31 {
145 fn internal_linear_layer<R: PrimeCharacteristicRing>(state: &mut [R; 24]) {
146 let part_sum: R = state[1..].iter().cloned().sum();
147 let full_sum = part_sum.clone() + state[0].clone();
148
149 state[0] = part_sum - state[0].clone();
151 state[1] = full_sum.clone() + state[1].clone();
152 state[2] = full_sum.clone() + state[2].double();
153
154 state[1..]
158 .iter_mut()
159 .zip(POSEIDON2_INTERNAL_MATRIX_DIAG_24_SHIFTS)
160 .skip(2)
161 .for_each(|(val, diag_shift)| {
162 *val = full_sum.clone() + val.clone().mul_2exp_u64(diag_shift as u64);
163 });
164 }
165}
166
#[cfg(test)]
mod tests {
    use p3_symmetric::Permutation;
    use rand::SeedableRng;
    use rand_xoshiro::Xoroshiro128Plus;

    use super::*;

    type F = Mersenne31;

    /// Known-answer test for the width-16 permutation.
    ///
    /// The permutation's constants are drawn from `Xoroshiro128Plus` seeded
    /// with 1, so the expected output below is tied to that seed.
    // NOTE(review): the origin of the expected vector is not visible here —
    // presumably produced by a reference implementation; confirm.
    #[test]
    fn test_poseidon2_width_16_random() {
        let mut rng = Xoroshiro128Plus::seed_from_u64(1);
        let perm: Poseidon2Mersenne31<16> = Poseidon2Mersenne31::new_from_rng_128(&mut rng);

        let mut state: [F; 16] = F::new_array([
            894848333, 1437655012, 1200606629, 1690012884, 71131202, 1749206695, 1717947831,
            120589055, 19776022, 42382981, 1831865506, 724844064, 171220207, 1299207443, 227047920,
            1783754913,
        ]);

        let want: [F; 16] = F::new_array([
            1124552602, 2127602268, 1834113265, 1207687593, 1891161485, 245915620, 981277919,
            627265710, 1534924153, 1580826924, 887997842, 1526280482, 547791593, 1028672510,
            1803086471, 323071277,
        ]);

        perm.permute_mut(&mut state);
        assert_eq!(state, want);
    }

    /// Known-answer test for the width-24 permutation.
    ///
    /// The permutation's constants are drawn from `Xoroshiro128Plus` seeded
    /// with 1, so the expected output below is tied to that seed.
    // NOTE(review): the origin of the expected vector is not visible here —
    // presumably produced by a reference implementation; confirm.
    #[test]
    fn test_poseidon2_width_24_random() {
        let mut rng = Xoroshiro128Plus::seed_from_u64(1);
        let perm: Poseidon2Mersenne31<24> = Poseidon2Mersenne31::new_from_rng_128(&mut rng);

        let mut state: [F; 24] = F::new_array([
            886409618, 1327899896, 1902407911, 591953491, 648428576, 1844789031, 1198336108,
            355597330, 1799586834, 59617783, 790334801, 1968791836, 559272107, 31054313,
            1042221543, 474748436, 135686258, 263665994, 1962340735, 1741539604, 2026927696,
            449439011, 1131357108, 50869465,
        ]);

        let want: [F; 24] = F::new_array([
            87189408, 212775836, 954807335, 1424761838, 1222521810, 1264950009, 1891204592,
            710452896, 957091834, 1776630156, 1091081383, 786687731, 1101902149, 1281649821,
            436070674, 313565599, 1961711763, 2002894460, 2040173120, 854107426, 25198245,
            1967213543, 604802266, 2086190331,
        ]);

        perm.permute_mut(&mut state);
        assert_eq!(state, want);
    }
}