p3_poseidon2/
external.rs

1//! External layers for the Poseidon2 permutation.
2//!
3//! Poseidon2 applies *external layers* at both the beginning and end of the permutation.
4//! These layers are critical for ensuring proper diffusion and enhancing security,
5//! particularly against structural and algebraic attacks.
6//!
7//! An external round consists of:
8//! 1. Addition of round constants,
9//! 2. Application of a nonlinear S-box,
10//! 3. A lightweight matrix multiplication (external linear layer).
11//!
12//! The constants and linear transformations used in these rounds are designed
13//! to complement the internal structure of Poseidon2.
14//!
15//! Main purpose of these constants:
16//! - Inject randomness between rounds.
17
18use alloc::vec::Vec;
19
20use p3_field::{Field, PrimeCharacteristicRing};
21use p3_mds::MdsPermutation;
22use p3_symmetric::Permutation;
23use rand::Rng;
24use rand::distr::{Distribution, StandardUniform};
25
26/// Multiply a 4-element vector x by
27/// [ 5 7 1 3 ]
28/// [ 4 6 1 1 ]
29/// [ 1 3 5 7 ]
30/// [ 1 1 4 6 ].
31/// This uses the formula from the start of Appendix B in the Poseidon2 paper, with multiplications unrolled into additions.
32/// It is also the matrix used by the Horizon Labs implementation.
33#[inline(always)]
34fn apply_hl_mat4<R>(x: &mut [R; 4])
35where
36    R: PrimeCharacteristicRing,
37{
38    let t0 = x[0].clone() + x[1].clone();
39    let t1 = x[2].clone() + x[3].clone();
40    let t2 = x[1].clone() + x[1].clone() + t1.clone();
41    let t3 = x[3].clone() + x[3].clone() + t0.clone();
42    let t4 = t1.double().double() + t3.clone();
43    let t5 = t0.double().double() + t2.clone();
44    let t6 = t3 + t5.clone();
45    let t7 = t2 + t4.clone();
46    x[0] = t6;
47    x[1] = t5;
48    x[2] = t7;
49    x[3] = t4;
50}
51
52// It turns out we can find a 4x4 matrix which is more efficient than the above.
53
54/// Multiply a 4-element vector x by:
55/// [ 2 3 1 1 ]
56/// [ 1 2 3 1 ]
57/// [ 1 1 2 3 ]
58/// [ 3 1 1 2 ].
59#[inline(always)]
60fn apply_mat4<R>(x: &mut [R; 4])
61where
62    R: PrimeCharacteristicRing,
63{
64    let t01 = x[0].clone() + x[1].clone();
65    let t23 = x[2].clone() + x[3].clone();
66    let t0123 = t01.clone() + t23.clone();
67    let t01123 = t0123.clone() + x[1].clone();
68    let t01233 = t0123 + x[3].clone();
69    // The order here is important. Need to overwrite x[0] and x[2] after x[1] and x[3].
70    x[3] = t01233.clone() + x[0].double(); // 3*x[0] + x[1] + x[2] + 2*x[3]
71    x[1] = t01123.clone() + x[2].double(); // x[0] + 2*x[1] + 3*x[2] + x[3]
72    x[0] = t01123 + t01; // 2*x[0] + 3*x[1] + x[2] + x[3]
73    x[2] = t01233 + t23; // x[0] + x[1] + 2*x[2] + 3*x[3]
74}
75
76/// The 4x4 MDS matrix used by the Horizon Labs implementation of Poseidon2.
77///
78/// This requires 10 additions and 4 doubles to compute.
79#[derive(Clone, Default)]
80pub struct HLMDSMat4;
81
82impl<R: PrimeCharacteristicRing> Permutation<[R; 4]> for HLMDSMat4 {
83    #[inline(always)]
84    fn permute_mut(&self, input: &mut [R; 4]) {
85        apply_hl_mat4(input);
86    }
87}
88impl<R: PrimeCharacteristicRing> MdsPermutation<R, 4> for HLMDSMat4 {}
89
90/// The fastest 4x4 MDS matrix.
91///
92/// This requires 7 additions and 2 doubles to compute.
93#[derive(Clone, Default)]
94pub struct MDSMat4;
95
96impl<R: PrimeCharacteristicRing> Permutation<[R; 4]> for MDSMat4 {
97    #[inline(always)]
98    fn permute_mut(&self, input: &mut [R; 4]) {
99        apply_mat4(input);
100    }
101}
102impl<R: PrimeCharacteristicRing> MdsPermutation<R, 4> for MDSMat4 {}
103
104/// Implement the matrix multiplication used by the external layer.
105///
106/// Given a 4x4 MDS matrix M, we multiply by the `4N x 4N` matrix
107/// `[[2M M  ... M], [M  2M ... M], ..., [M  M ... 2M]]`.
108///
109/// # Panics
110/// This will panic if `WIDTH` is not supported. Currently, the
111/// supported `WIDTH` values are 2, 3, 4, 8, 12, 16, 20, 24.`
112#[inline(always)]
113pub fn mds_light_permutation<
114    R: PrimeCharacteristicRing,
115    MdsPerm4: MdsPermutation<R, 4>,
116    const WIDTH: usize,
117>(
118    state: &mut [R; WIDTH],
119    mdsmat: &MdsPerm4,
120) {
121    match WIDTH {
122        2 => {
123            let sum = state[0].clone() + state[1].clone();
124            state[0] += sum.clone();
125            state[1] += sum;
126        }
127
128        3 => {
129            let sum = state[0].clone() + state[1].clone() + state[2].clone();
130            state[0] += sum.clone();
131            state[1] += sum.clone();
132            state[2] += sum;
133        }
134
135        4 | 8 | 12 | 16 | 20 | 24 => {
136            // First, we apply M_4 to each consecutive four elements of the state.
137            // In Appendix B's terminology, this replaces each x_i with x_i'.
138            for chunk in state.chunks_exact_mut(4) {
139                mdsmat.permute_mut(chunk.try_into().unwrap());
140            }
141            // Now, we apply the outer circulant matrix (to compute the y_i values).
142
143            // We first precompute the four sums of every four elements.
144            let sums: [R; 4] =
145                core::array::from_fn(|k| (0..WIDTH).step_by(4).map(|j| state[j + k].clone()).sum());
146
147            // The formula for each y_i involves 2x_i' term and x_j' terms for each j that equals i mod 4.
148            // In other words, we can add a single copy of x_i' to the appropriate one of our precomputed sums
149            state
150                .iter_mut()
151                .enumerate()
152                .for_each(|(i, elem)| *elem += sums[i % 4].clone());
153        }
154
155        _ => {
156            panic!("Unsupported width");
157        }
158    }
159}
160
/// Round constants for the external layers, kept separately for the initial
/// and the terminal group of external rounds.
///
/// Each entry of either list is one full-width `[T; WIDTH]` vector of
/// constants, i.e. the constants of a single external round.
#[derive(Clone, Debug)]
pub struct ExternalLayerConstants<T, const WIDTH: usize> {
    /// Per-round constants for the initial external rounds.
    ///
    /// Consumed by `permute_state_initial`.
    initial: Vec<[T; WIDTH]>,

    /// Per-round constants for the terminal external rounds.
    ///
    /// Consumed by `permute_state_terminal`. These are called "terminal" rather
    /// than "final" because `final` is a reserved word in Rust.
    terminal: Vec<[T; WIDTH]>,
}
174
175impl<T, const WIDTH: usize> ExternalLayerConstants<T, WIDTH> {
176    /// Create a new instance of external layer constants.
177    ///
178    /// # Panics
179    /// Panics if `initial.len() != terminal.len()` since the Poseidon2 spec requires
180    /// the same number of initial and terminal rounds to maintain symmetry.
181    pub const fn new(initial: Vec<[T; WIDTH]>, terminal: Vec<[T; WIDTH]>) -> Self {
182        assert!(
183            initial.len() == terminal.len(),
184            "The number of initial and terminal external rounds should be equal."
185        );
186        Self { initial, terminal }
187    }
188
189    /// Randomly generate a new set of external constants using a provided RNG.
190    ///
191    /// # Arguments
192    /// - `external_round_number`: Total number of external rounds (must be even).
193    /// - `rng`: A random number generator that supports uniform sampling.
194    ///
195    /// The constants are split equally between the initial and terminal rounds.
196    ///
197    /// # Panics
198    /// Panics if `external_round_number` is not even.
199    pub fn new_from_rng<R: Rng>(external_round_number: usize, rng: &mut R) -> Self
200    where
201        StandardUniform: Distribution<[T; WIDTH]>,
202    {
203        let half_f = external_round_number / 2;
204        assert_eq!(
205            2 * half_f,
206            external_round_number,
207            "The total number of external rounds should be even"
208        );
209        let initial_constants = rng.sample_iter(StandardUniform).take(half_f).collect();
210        let terminal_constants = rng.sample_iter(StandardUniform).take(half_f).collect();
211
212        Self::new(initial_constants, terminal_constants)
213    }
214
215    /// Construct constants from statically stored arrays, using a conversion function.
216    ///
217    /// This is useful when deserializing precomputed constants or embedding
218    /// them directly in the codebase (e.g., from `[[[u32; WIDTH]; N]; 2]` arrays).
219    ///
220    /// # Arguments
221    /// - `initial`, `terminal`: Two fixed-size arrays of size `N` containing round constants.
222    /// - `conversion_fn`: A function to convert from the source type `U` to `T`.
223    pub fn new_from_saved_array<U, const N: usize>(
224        [initial, terminal]: [[[U; WIDTH]; N]; 2],
225        conversion_fn: fn([U; WIDTH]) -> [T; WIDTH],
226    ) -> Self
227    where
228        T: Clone,
229    {
230        let initial_consts = initial.map(conversion_fn).to_vec();
231        let terminal_consts = terminal.map(conversion_fn).to_vec();
232        Self::new(initial_consts, terminal_consts)
233    }
234
235    /// Get a reference to the list of initial round constants.
236    ///
237    /// These are used in the first half of the external rounds.
238    pub const fn get_initial_constants(&self) -> &Vec<[T; WIDTH]> {
239        &self.initial
240    }
241
242    /// Get a reference to the list of terminal round constants.
243    ///
244    /// These are used in the second half (terminal rounds) of the external layer.
245    pub const fn get_terminal_constants(&self) -> &Vec<[T; WIDTH]> {
246        &self.terminal
247    }
248}
249
250/// Initialize an external layer from a set of constants.
251pub trait ExternalLayerConstructor<F, const WIDTH: usize>
252where
253    F: Field,
254{
255    /// A constructor which internally will convert the supplied
256    /// constants into the appropriate form for the implementation.
257    fn new_from_constants(external_constants: ExternalLayerConstants<F, WIDTH>) -> Self;
258}
259
260/// A trait containing all data needed to implement the external layers of Poseidon2.
261pub trait ExternalLayer<R, const WIDTH: usize, const D: u64>: Sync + Clone
262where
263    R: PrimeCharacteristicRing,
264{
265    // permute_state_initial, permute_state_terminal are split as the Poseidon2 specifications are slightly different
266    // with the initial rounds involving an extra matrix multiplication.
267
268    /// Perform the initial external layers of the Poseidon2 permutation on the given state.
269    fn permute_state_initial(&self, state: &mut [R; WIDTH]);
270
271    /// Perform the terminal external layers of the Poseidon2 permutation on the given state.
272    fn permute_state_terminal(&self, state: &mut [R; WIDTH]);
273}
274
275/// Applies the terminal external rounds of the Poseidon2 permutation.
276///
277/// Each external round consists of three steps:
278/// 1. Adding round constants to each element of the state.
279/// 2. Apply the S-box to each element of the state.
280/// 3. Applying an external linear layer (based on a `4x4` MDS matrix).
281///
282/// # Parameters
283/// - `state`: The current state of the permutation (size `WIDTH`).
284/// - `terminal_external_constants`: Per-round constants which are added to each state element.
285/// - `add_rc_and_sbox`: A function that adds the round constant and applies the S-box to a given element.
286/// - `mat4`: The 4x4 MDS matrix used in the external linear layer.
287#[inline]
288pub fn external_terminal_permute_state<
289    R: PrimeCharacteristicRing,
290    CT: Copy, // Whatever type the constants are stored as.
291    MdsPerm4: MdsPermutation<R, 4>,
292    const WIDTH: usize,
293>(
294    state: &mut [R; WIDTH],
295    terminal_external_constants: &[[CT; WIDTH]],
296    add_rc_and_sbox: fn(&mut R, CT),
297    mat4: &MdsPerm4,
298) {
299    for elem in terminal_external_constants {
300        state
301            .iter_mut()
302            .zip(elem.iter())
303            .for_each(|(s, &rc)| add_rc_and_sbox(s, rc));
304        mds_light_permutation(state, mat4);
305    }
306}
307
308/// Applies the initial external rounds of the Poseidon2 permutation.
309///
310/// Apply the external linear layer and run a sequence of standard external rounds consisting of
311/// 1. Adding round constants to each element of the state.
312/// 2. Apply the S-box to each element of the state.
313/// 3. Applying an external linear layer (based on a `4x4` MDS matrix).
314///
315/// # Parameters
316/// - `state`: The state array at the start of the permutation.
317/// - `initial_external_constants`: Per-round constants which are added to each state element.
318/// - `add_rc_and_sbox`: A function that adds the round constant and applies the S-box to a given element.
319/// - `mat4`: The 4x4 MDS matrix used in the external linear layer.
320#[inline]
321pub fn external_initial_permute_state<
322    R: PrimeCharacteristicRing,
323    CT: Copy, // Whatever type the constants are stored as.
324    MdsPerm4: MdsPermutation<R, 4>,
325    const WIDTH: usize,
326>(
327    state: &mut [R; WIDTH],
328    initial_external_constants: &[[CT; WIDTH]],
329    add_rc_and_sbox: fn(&mut R, CT),
330    mat4: &MdsPerm4,
331) {
332    mds_light_permutation(state, mat4);
333    // After the initial mds_light_permutation, the remaining layers are identical
334    // to the terminal permutation simply with different constants.
335    external_terminal_permute_state(state, initial_external_constants, add_rc_and_sbox, mat4);
336}
337
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use rand::SeedableRng;
    use rand::rngs::SmallRng;

    use super::*;

    type F = BabyBear;

    /// `apply_mat4` must agree with a direct evaluation of the matrix product.
    #[test]
    fn test_apply_mat4() {
        // Seeded RNG so the test is reproducible.
        let mut rng = SmallRng::seed_from_u64(12345678);

        // Random input vector [x0, x1, x2, x3].
        let input: [F; 4] = core::array::from_fn(|_| rng.random());
        let [x0, x1, x2, x3] = input;

        // Transform a copy in place.
        let mut x = input;
        apply_mat4(&mut x);

        // Row-by-row product with
        // [ 2 3 1 1 ]
        // [ 1 2 3 1 ]
        // [ 1 1 2 3 ]
        // [ 3 1 1 2 ]
        let two = F::TWO;
        let three = F::from_u8(3);
        let expected = [
            two * x0 + three * x1 + x2 + x3,
            x0 + two * x1 + three * x2 + x3,
            x0 + x1 + two * x2 + three * x3,
            three * x0 + x1 + x2 + two * x3,
        ];

        assert_eq!(x, expected, "apply_mat4 did not produce expected output");
    }

    /// `apply_hl_mat4` must agree with a direct evaluation of the matrix product.
    #[test]
    fn test_apply_hl_mat4_with_manual_verification() {
        // Seeded RNG so the test is reproducible.
        let mut rng = SmallRng::seed_from_u64(87654321);

        // Random input vector [x0, x1, x2, x3].
        let input: [F; 4] = core::array::from_fn(|_| rng.random());
        let [x0, x1, x2, x3] = input;

        // Transform a copy in place.
        let mut x = input;
        apply_hl_mat4(&mut x);

        // Row-by-row product with
        // [ 5 7 1 3 ]
        // [ 4 6 1 1 ]
        // [ 1 3 5 7 ]
        // [ 1 1 4 6 ]
        let three = F::from_u8(3);
        let four = F::from_u8(4);
        let five = F::from_u8(5);
        let six = F::from_u8(6);
        let seven = F::from_u8(7);
        let expected = [
            five * x0 + seven * x1 + x2 + three * x3,
            four * x0 + six * x1 + x2 + x3,
            x0 + three * x1 + five * x2 + seven * x3,
            x0 + x1 + four * x2 + six * x3,
        ];

        assert_eq!(x, expected, "apply_hl_mat4 did not produce expected output");
    }
}