Skip to main content

p3_poseidon1/
external.rs

1//! Full (external) round layers for the Poseidon1 permutation.
2//!
3//! # Overview
4//!
5//! Full rounds apply the S-box to **every** state element, providing strong resistance
6//! against statistical attacks (differential, linear, truncated differential, rebound).
7//! The Poseidon1 paper requires at least RF = 6 full rounds for 128-bit security against
8//! these attacks (see Section 5 and Appendix C of the paper).
9//!
10//! # Round Structure
11//!
12//! Each full round applies three operations in sequence:
13//!
14//! ```text
15//!   state → AddRoundConstants → S-box(all elements) → MDS multiply → state'
16//! ```
17//!
18//! The MDS multiply is dispatched via the [`Permutation`] trait, allowing concrete fields
19//! to use fast convolution (e.g., Karatsuba) while generic `Algebra<F>` types fall back
20//! to O(t^2) dense multiplication.
21//!
22//! # Cost
23//!
//! Each full round costs t S-box evaluations plus one MDS multiply. The multiply is
//! O(t^2) when done densely, and cheaper when a fast convolution (e.g., Karatsuba) is
//! used. The total full-round cost is therefore at most O(RF * t^2). Since RF is small
//! (typically 8), this is acceptable even for large t.
27
28use alloc::vec::Vec;
29
30use p3_field::{Algebra, Field, InjectiveMonomial, PrimeCharacteristicRing};
31use p3_symmetric::Permutation;
32
/// Pre-computed constants for the full (external) rounds.
///
/// The full rounds are split equally: half before the partial rounds (initial),
/// and half after (terminal).
///
/// The generic full-round functions do **not** read the MDS matrix from this
/// struct. They receive the MDS layer as a separate [`Permutation`] argument at
/// the call site. This lets concrete fields plug in optimized implementations
/// (e.g., Karatsuba convolution), while generic algebra types fall back to dense
/// O(t^2) multiplication.
///
/// A dense copy of the matrix is still kept in [`dense_mds`](Self::dense_mds).
/// It serves implementations that cannot use the scalar fast path.
#[derive(Debug, Clone)]
pub struct FullRoundConstants<F, const WIDTH: usize> {
    /// Round constants for the initial full rounds, one `[F; WIDTH]` row per round.
    pub initial: Vec<[F; WIDTH]>,

    /// Round constants for the terminal full rounds, one `[F; WIDTH]` row per round.
    pub terminal: Vec<[F; WIDTH]>,

    /// Dense `WIDTH x WIDTH` MDS matrix expanded from the circulant first column.
    ///
    /// The scalar MDS path uses a Karatsuba convolution with `i64` intermediates.
    /// That approach relies on bit-shifts for halving.
    ///
    /// Packed SIMD types cannot perform bit-shifts. They only support field
    /// arithmetic.
    ///
    /// Storing the fully expanded matrix lets SIMD implementations either:
    ///
    /// - Fall back to dense O(t^2) multiplication over `Algebra<F>`.
    /// - Extract the first column for a field-level Karatsuba that uses
    ///   `halve()` instead of bit-shifts.
    pub dense_mds: [[F; WIDTH]; WIDTH],
}
65
66/// Construct a full round layer from pre-computed constants.
67pub trait FullRoundLayerConstructor<F: Field, const WIDTH: usize> {
68    /// Build the layer from the full-round constants.
69    fn new_from_constants(constants: FullRoundConstants<F, WIDTH>) -> Self;
70}
71
72/// The full (external) round layer of the Poseidon1 permutation.
73///
74/// Implementors apply the RF/2 initial or terminal full rounds to the state.
75/// Field-specific implementations (e.g., NEON, AVX2) can override the generic
76/// behavior for better performance.
77pub trait FullRoundLayer<R, const WIDTH: usize, const D: u64>: Sync + Clone
78where
79    R: PrimeCharacteristicRing,
80{
81    /// Apply the RF/2 initial full rounds.
82    fn permute_state_initial(&self, state: &mut [R; WIDTH]);
83
84    /// Apply the RF/2 terminal full rounds.
85    fn permute_state_terminal(&self, state: &mut [R; WIDTH]);
86}
87
88/// Dense matrix-vector multiplication in O(t^2).
89///
90/// Only used for the non-circulant transition matrix in partial rounds, which
91/// is applied once per permutation call. The circulant MDS multiply in full
92/// rounds uses the MDS crate via the permutation trait instead.
93#[inline]
94pub fn mds_multiply<F: PrimeCharacteristicRing, A: Algebra<F>, const WIDTH: usize>(
95    state: &mut [A; WIDTH],
96    mds: &[[F; WIDTH]; WIDTH],
97) {
98    // Snapshot the current state before overwriting.
99    let input = state.clone();
100
101    // Compute each output element as a dot product of one MDS row with the input.
102    for (out, row) in state.iter_mut().zip(mds.iter()) {
103        *out = A::mixed_dot_product(&input, row);
104    }
105}
106
107/// Apply the initial full rounds (generic implementation).
108///
109/// Each round: add round constants, S-box on all elements, MDS multiply.
110/// The MDS multiply is dispatched via the permutation trait parameter.
111#[inline]
112pub fn full_round_initial_permute_state<
113    F: Field,
114    A: Algebra<F> + InjectiveMonomial<D>,
115    Mds: Permutation<[A; WIDTH]>,
116    const WIDTH: usize,
117    const D: u64,
118>(
119    state: &mut [A; WIDTH],
120    constants: &FullRoundConstants<F, WIDTH>,
121    mds: &Mds,
122) {
123    for round_constants in &constants.initial {
124        // AddRoundConstants: state[i] += rc[i].
125        for (s, &rc) in state.iter_mut().zip(round_constants.iter()) {
126            *s += rc;
127        }
128        // S-box: state[i] = state[i]^D for all i.
129        for s in state.iter_mut() {
130            *s = s.injective_exp_n();
131        }
132        // MixLayer: dispatched via Permutation trait.
133        mds.permute_mut(state);
134    }
135}
136
137/// Apply the terminal full rounds (generic implementation).
138///
139/// Same structure as the initial full rounds, but uses the terminal constants.
140#[inline]
141pub fn full_round_terminal_permute_state<
142    F: Field,
143    A: Algebra<F> + InjectiveMonomial<D>,
144    Mds: Permutation<[A; WIDTH]>,
145    const WIDTH: usize,
146    const D: u64,
147>(
148    state: &mut [A; WIDTH],
149    constants: &FullRoundConstants<F, WIDTH>,
150    mds: &Mds,
151) {
152    for round_constants in &constants.terminal {
153        // AddRoundConstants: state[i] += rc[i].
154        for (s, &rc) in state.iter_mut().zip(round_constants.iter()) {
155            *s += rc;
156        }
157        // S-box: state[i] = state[i]^D for all i.
158        for s in state.iter_mut() {
159            *s = s.injective_exp_n();
160        }
161        // MixLayer: dispatched via Permutation trait.
162        mds.permute_mut(state);
163    }
164}