p3_poseidon2/external.rs
//! External layers for the Poseidon2 permutation.
//!
//! Poseidon2 applies *external layers* at both the beginning and the end of the permutation.
//! These layers are critical for ensuring proper diffusion and for enhancing security,
//! particularly against structural and algebraic attacks.
//!
//! An external round consists of:
//! 1. Addition of round constants,
//! 2. Application of a nonlinear S-box,
//! 3. A lightweight matrix multiplication (the external linear layer).
//!
//! The round constants and linear transformations used in these rounds are designed
//! to complement the internal structure of Poseidon2.
//!
//! Main purpose of the round constants:
//! - Inject randomness between rounds.
17
18use alloc::vec::Vec;
19
20use p3_field::{Field, PrimeCharacteristicRing};
21use p3_mds::MdsPermutation;
22use p3_symmetric::Permutation;
23use rand::Rng;
24use rand::distr::{Distribution, StandardUniform};
25
26/// Multiply a 4-element vector x by
27/// [ 5 7 1 3 ]
28/// [ 4 6 1 1 ]
29/// [ 1 3 5 7 ]
30/// [ 1 1 4 6 ].
31/// This uses the formula from the start of Appendix B in the Poseidon2 paper, with multiplications unrolled into additions.
32/// It is also the matrix used by the Horizon Labs implementation.
33#[inline(always)]
34fn apply_hl_mat4<R>(x: &mut [R; 4])
35where
36 R: PrimeCharacteristicRing,
37{
38 let t0 = x[0].clone() + x[1].clone();
39 let t1 = x[2].clone() + x[3].clone();
40 let t2 = x[1].clone() + x[1].clone() + t1.clone();
41 let t3 = x[3].clone() + x[3].clone() + t0.clone();
42 let t4 = t1.double().double() + t3.clone();
43 let t5 = t0.double().double() + t2.clone();
44 let t6 = t3 + t5.clone();
45 let t7 = t2 + t4.clone();
46 x[0] = t6;
47 x[1] = t5;
48 x[2] = t7;
49 x[3] = t4;
50}
51
// It turns out there is a 4x4 matrix which can be applied more cheaply than the one above.
53
54/// Multiply a 4-element vector x by:
55/// [ 2 3 1 1 ]
56/// [ 1 2 3 1 ]
57/// [ 1 1 2 3 ]
58/// [ 3 1 1 2 ].
59#[inline(always)]
60fn apply_mat4<R>(x: &mut [R; 4])
61where
62 R: PrimeCharacteristicRing,
63{
64 let t01 = x[0].clone() + x[1].clone();
65 let t23 = x[2].clone() + x[3].clone();
66 let t0123 = t01.clone() + t23.clone();
67 let t01123 = t0123.clone() + x[1].clone();
68 let t01233 = t0123 + x[3].clone();
69 // The order here is important. Need to overwrite x[0] and x[2] after x[1] and x[3].
70 x[3] = t01233.clone() + x[0].double(); // 3*x[0] + x[1] + x[2] + 2*x[3]
71 x[1] = t01123.clone() + x[2].double(); // x[0] + 2*x[1] + 3*x[2] + x[3]
72 x[0] = t01123 + t01; // 2*x[0] + 3*x[1] + x[2] + x[3]
73 x[2] = t01233 + t23; // x[0] + x[1] + 2*x[2] + 3*x[3]
74}
75
76/// The 4x4 MDS matrix used by the Horizon Labs implementation of Poseidon2.
77///
78/// This requires 10 additions and 4 doubles to compute.
79#[derive(Clone, Default)]
80pub struct HLMDSMat4;
81
82impl<R: PrimeCharacteristicRing> Permutation<[R; 4]> for HLMDSMat4 {
83 #[inline(always)]
84 fn permute_mut(&self, input: &mut [R; 4]) {
85 apply_hl_mat4(input);
86 }
87}
88impl<R: PrimeCharacteristicRing> MdsPermutation<R, 4> for HLMDSMat4 {}
89
90/// The fastest 4x4 MDS matrix.
91///
92/// This requires 7 additions and 2 doubles to compute.
93#[derive(Clone, Default)]
94pub struct MDSMat4;
95
96impl<R: PrimeCharacteristicRing> Permutation<[R; 4]> for MDSMat4 {
97 #[inline(always)]
98 fn permute_mut(&self, input: &mut [R; 4]) {
99 apply_mat4(input);
100 }
101}
102impl<R: PrimeCharacteristicRing> MdsPermutation<R, 4> for MDSMat4 {}
103
104/// Implement the matrix multiplication used by the external layer.
105///
106/// Given a 4x4 MDS matrix M, we multiply by the `4N x 4N` matrix
107/// `[[2M M ... M], [M 2M ... M], ..., [M M ... 2M]]`.
108///
109/// # Panics
110/// This will panic if `WIDTH` is not supported. Currently, the
111/// supported `WIDTH` values are 2, 3, 4, 8, 12, 16, 20, 24.`
112#[inline(always)]
113pub fn mds_light_permutation<
114 R: PrimeCharacteristicRing,
115 MdsPerm4: MdsPermutation<R, 4>,
116 const WIDTH: usize,
117>(
118 state: &mut [R; WIDTH],
119 mdsmat: &MdsPerm4,
120) {
121 match WIDTH {
122 2 => {
123 let sum = state[0].clone() + state[1].clone();
124 state[0] += sum.clone();
125 state[1] += sum;
126 }
127
128 3 => {
129 let sum = state[0].clone() + state[1].clone() + state[2].clone();
130 state[0] += sum.clone();
131 state[1] += sum.clone();
132 state[2] += sum;
133 }
134
135 4 | 8 | 12 | 16 | 20 | 24 => {
136 // First, we apply M_4 to each consecutive four elements of the state.
137 // In Appendix B's terminology, this replaces each x_i with x_i'.
138 for chunk in state.chunks_exact_mut(4) {
139 mdsmat.permute_mut(chunk.try_into().unwrap());
140 }
141 // Now, we apply the outer circulant matrix (to compute the y_i values).
142
143 // We first precompute the four sums of every four elements.
144 let sums: [R; 4] =
145 core::array::from_fn(|k| (0..WIDTH).step_by(4).map(|j| state[j + k].clone()).sum());
146
147 // The formula for each y_i involves 2x_i' term and x_j' terms for each j that equals i mod 4.
148 // In other words, we can add a single copy of x_i' to the appropriate one of our precomputed sums
149 state
150 .iter_mut()
151 .enumerate()
152 .for_each(|(i, elem)| *elem += sums[i % 4].clone());
153 }
154
155 _ => {
156 panic!("Unsupported width");
157 }
158 }
159}
160
/// Round constants for the external layers of Poseidon2.
///
/// Holds one `[T; WIDTH]` vector of constants per external round, kept separately
/// for the initial and the terminal external layers.
#[derive(Debug, Clone)]
pub struct ExternalLayerConstants<T, const WIDTH: usize> {
    /// Constants applied during each initial external round.
    ///
    /// Used in `permute_state_initial`. Each entry is a full-width vector of constants.
    initial: Vec<[T; WIDTH]>,

    /// Constants applied during each terminal external round.
    ///
    /// Used in `permute_state_terminal`. The name "terminal" is used because
    /// `final` is a reserved word in Rust.
    terminal: Vec<[T; WIDTH]>,
}
174
175impl<T, const WIDTH: usize> ExternalLayerConstants<T, WIDTH> {
176 /// Create a new instance of external layer constants.
177 ///
178 /// # Panics
179 /// Panics if `initial.len() != terminal.len()` since the Poseidon2 spec requires
180 /// the same number of initial and terminal rounds to maintain symmetry.
181 pub const fn new(initial: Vec<[T; WIDTH]>, terminal: Vec<[T; WIDTH]>) -> Self {
182 assert!(
183 initial.len() == terminal.len(),
184 "The number of initial and terminal external rounds should be equal."
185 );
186 Self { initial, terminal }
187 }
188
189 /// Randomly generate a new set of external constants using a provided RNG.
190 ///
191 /// # Arguments
192 /// - `external_round_number`: Total number of external rounds (must be even).
193 /// - `rng`: A random number generator that supports uniform sampling.
194 ///
195 /// The constants are split equally between the initial and terminal rounds.
196 ///
197 /// # Panics
198 /// Panics if `external_round_number` is not even.
199 pub fn new_from_rng<R: Rng>(external_round_number: usize, rng: &mut R) -> Self
200 where
201 StandardUniform: Distribution<[T; WIDTH]>,
202 {
203 let half_f = external_round_number / 2;
204 assert_eq!(
205 2 * half_f,
206 external_round_number,
207 "The total number of external rounds should be even"
208 );
209 let initial_constants = rng.sample_iter(StandardUniform).take(half_f).collect();
210 let terminal_constants = rng.sample_iter(StandardUniform).take(half_f).collect();
211
212 Self::new(initial_constants, terminal_constants)
213 }
214
215 /// Construct constants from statically stored arrays, using a conversion function.
216 ///
217 /// This is useful when deserializing precomputed constants or embedding
218 /// them directly in the codebase (e.g., from `[[[u32; WIDTH]; N]; 2]` arrays).
219 ///
220 /// # Arguments
221 /// - `initial`, `terminal`: Two fixed-size arrays of size `N` containing round constants.
222 /// - `conversion_fn`: A function to convert from the source type `U` to `T`.
223 pub fn new_from_saved_array<U, const N: usize>(
224 [initial, terminal]: [[[U; WIDTH]; N]; 2],
225 conversion_fn: fn([U; WIDTH]) -> [T; WIDTH],
226 ) -> Self
227 where
228 T: Clone,
229 {
230 let initial_consts = initial.map(conversion_fn).to_vec();
231 let terminal_consts = terminal.map(conversion_fn).to_vec();
232 Self::new(initial_consts, terminal_consts)
233 }
234
235 /// Get a reference to the list of initial round constants.
236 ///
237 /// These are used in the first half of the external rounds.
238 pub const fn get_initial_constants(&self) -> &Vec<[T; WIDTH]> {
239 &self.initial
240 }
241
242 /// Get a reference to the list of terminal round constants.
243 ///
244 /// These are used in the second half (terminal rounds) of the external layer.
245 pub const fn get_terminal_constants(&self) -> &Vec<[T; WIDTH]> {
246 &self.terminal
247 }
248}
249
250/// Initialize an external layer from a set of constants.
251pub trait ExternalLayerConstructor<F, const WIDTH: usize>
252where
253 F: Field,
254{
255 /// A constructor which internally will convert the supplied
256 /// constants into the appropriate form for the implementation.
257 fn new_from_constants(external_constants: ExternalLayerConstants<F, WIDTH>) -> Self;
258}
259
260/// A trait containing all data needed to implement the external layers of Poseidon2.
261pub trait ExternalLayer<R, const WIDTH: usize, const D: u64>: Sync + Clone
262where
263 R: PrimeCharacteristicRing,
264{
265 // permute_state_initial, permute_state_terminal are split as the Poseidon2 specifications are slightly different
266 // with the initial rounds involving an extra matrix multiplication.
267
268 /// Perform the initial external layers of the Poseidon2 permutation on the given state.
269 fn permute_state_initial(&self, state: &mut [R; WIDTH]);
270
271 /// Perform the terminal external layers of the Poseidon2 permutation on the given state.
272 fn permute_state_terminal(&self, state: &mut [R; WIDTH]);
273}
274
275/// Applies the terminal external rounds of the Poseidon2 permutation.
276///
277/// Each external round consists of three steps:
278/// 1. Adding round constants to each element of the state.
279/// 2. Apply the S-box to each element of the state.
280/// 3. Applying an external linear layer (based on a `4x4` MDS matrix).
281///
282/// # Parameters
283/// - `state`: The current state of the permutation (size `WIDTH`).
284/// - `terminal_external_constants`: Per-round constants which are added to each state element.
285/// - `add_rc_and_sbox`: A function that adds the round constant and applies the S-box to a given element.
286/// - `mat4`: The 4x4 MDS matrix used in the external linear layer.
287#[inline]
288pub fn external_terminal_permute_state<
289 R: PrimeCharacteristicRing,
290 CT: Copy, // Whatever type the constants are stored as.
291 MdsPerm4: MdsPermutation<R, 4>,
292 const WIDTH: usize,
293>(
294 state: &mut [R; WIDTH],
295 terminal_external_constants: &[[CT; WIDTH]],
296 add_rc_and_sbox: fn(&mut R, CT),
297 mat4: &MdsPerm4,
298) {
299 for elem in terminal_external_constants {
300 state
301 .iter_mut()
302 .zip(elem.iter())
303 .for_each(|(s, &rc)| add_rc_and_sbox(s, rc));
304 mds_light_permutation(state, mat4);
305 }
306}
307
308/// Applies the initial external rounds of the Poseidon2 permutation.
309///
310/// Apply the external linear layer and run a sequence of standard external rounds consisting of
311/// 1. Adding round constants to each element of the state.
312/// 2. Apply the S-box to each element of the state.
313/// 3. Applying an external linear layer (based on a `4x4` MDS matrix).
314///
315/// # Parameters
316/// - `state`: The state array at the start of the permutation.
317/// - `initial_external_constants`: Per-round constants which are added to each state element.
318/// - `add_rc_and_sbox`: A function that adds the round constant and applies the S-box to a given element.
319/// - `mat4`: The 4x4 MDS matrix used in the external linear layer.
320#[inline]
321pub fn external_initial_permute_state<
322 R: PrimeCharacteristicRing,
323 CT: Copy, // Whatever type the constants are stored as.
324 MdsPerm4: MdsPermutation<R, 4>,
325 const WIDTH: usize,
326>(
327 state: &mut [R; WIDTH],
328 initial_external_constants: &[[CT; WIDTH]],
329 add_rc_and_sbox: fn(&mut R, CT),
330 mat4: &MdsPerm4,
331) {
332 mds_light_permutation(state, mat4);
333 // After the initial mds_light_permutation, the remaining layers are identical
334 // to the terminal permutation simply with different constants.
335 external_terminal_permute_state(state, initial_external_constants, add_rc_and_sbox, mat4);
336}
337
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use rand::SeedableRng;
    use rand::rngs::SmallRng;

    use super::*;

    type F = BabyBear;

    /// Reference implementation: plain dense product of a small-integer 4x4
    /// matrix with a vector of field elements.
    fn mat_vec_mul(matrix: [[u8; 4]; 4], v: [F; 4]) -> [F; 4] {
        matrix.map(|row| {
            row.iter()
                .zip(v.iter())
                .map(|(&coeff, &elem)| F::from_u8(coeff) * elem)
                .sum()
        })
    }

    /// Draw four field elements from a deterministically seeded RNG.
    fn random_input(seed: u64) -> [F; 4] {
        let mut rng = SmallRng::seed_from_u64(seed);
        core::array::from_fn(|_| rng.random())
    }

    #[test]
    fn test_apply_mat4() {
        // Deterministic random input x = [x0, x1, x2, x3].
        let input = random_input(12345678);

        // Apply the matrix transformation in place on a copy.
        let mut x = input;
        apply_mat4(&mut x);

        // Expected result from the explicit matrix-vector product.
        let expected = mat_vec_mul(
            [
                [2, 3, 1, 1],
                [1, 2, 3, 1],
                [1, 1, 2, 3],
                [3, 1, 1, 2],
            ],
            input,
        );

        assert_eq!(x, expected, "apply_mat4 did not produce expected output");
    }

    #[test]
    fn test_apply_hl_mat4_with_manual_verification() {
        // Deterministic random input x = [x0, x1, x2, x3].
        let input = random_input(87654321);

        // Apply the hl matrix in place on a copy.
        let mut x = input;
        apply_hl_mat4(&mut x);

        // Expected result from the explicit matrix-vector product.
        let expected = mat_vec_mul(
            [
                [5, 7, 1, 3],
                [4, 6, 1, 1],
                [1, 3, 5, 7],
                [1, 1, 4, 6],
            ],
            input,
        );

        assert_eq!(x, expected, "apply_hl_mat4 did not produce expected output");
    }
}