Skip to main content

p3_poseidon1/
utils.rs

1//! Equivalent matrix decomposition for efficient partial rounds.
2//!
3//! # Overview
4//!
5//! This module implements the sparse matrix optimization described in **Appendix B** of the Poseidon1
6//! paper (Grassi et al., USENIX Security 2021). It transforms the RP partial rounds
7//! from their textbook form (dense MDS multiply per round, O(t^2) each) into an
8//! equivalent form using sparse matrices (O(t) each).
9//!
10//! # Background: Textbook Partial Rounds
11//!
12//! In the unoptimized Poseidon1, each partial round applies:
13//!
14//! ```text
15//!   state <- M * SBox(state + rc)
16//! ```
17//!
//! where M is the dense t x t MDS matrix, SBox applies x^D only to `state[0]`, and
19//! rc is a full t-vector of round constants. The cost per round is dominated by the
20//! dense matrix multiply: O(t^2).
21//!
22//! # The Sparse Factorization (Poseidon1 Paper, Appendix B, Eq. 5)
23//!
24//! The key insight is that M can be factored as:
25//!
26//! ```text
27//!   M = M' * M''
28//! ```
29//!
30//! where:
31//!
32//! ```text
33//!   M' = ┌───┬───┐       M'' = ┌─────────┬─────┐
34//!        │ 1 │ 0 │             │ M[0][0] │  v  │
35//!        ├───┼───┤             ├─────────┼─────┤
36//!        │ 0 │ M̂ │             │   ŵ     │  I  │
37//!        └───┴───┘             └─────────┴─────┘
38//! ```
39//!
40//! Here M̂ is the (t-1)x(t-1) submatrix `M[1..t, 1..t]`, v is the first row of M
41//! (excluding `M[0][0]`), ŵ = M̂^{-1} * w where w is the first column of M
42//! (excluding `M[0][0]`), and I is the (t-1)x(t-1) identity.
43//!
44//! Since the partial S-box only touches `state[0]`, the dense M' factor can be
45//! "absorbed" into the next round's M, leaving only the sparse M'' to be applied
46//! per round via an O(t) matrix-vector product.
47//!
48//! After iterating this factorization across all RP rounds, we obtain:
49//! - One dense **transition matrix** m_i (applied once before the loop)
50//! - RP **sparse matrices** S_r, each parameterized by vectors v_r and ŵ_r of length t-1
51//!
52//! # Round Constant Compression
53//!
54//! In parallel, the round constants are compressed via backward substitution.
55//! Since M^{-1} is linear, we can "push" round constants backward through the
56//! inverse matrix, accumulating them into the first partial round. After this
57//! transformation:
58//! - The first partial round uses a full t-vector of constants.
59//! - Each subsequent partial round uses a single scalar constant (for `state[0]`).
60//! - The last partial round has no additive constant at all.
61//!
62//! # Implementation Note
63//!
64//! This implementation follows the HorizenLabs reference
65//! (`plain_implementations/src/poseidon/poseidon_params.rs`) which works on the
66//! **transposed** MDS matrix internally and reverses the sparse matrix ordering
67//! before returning.
68
69use alloc::vec;
70use alloc::vec::Vec;
71
72use p3_field::{Field, dot_product};
73
74/// Expand a circulant matrix from its first column into a dense NxN matrix.
75///
76/// Each column is a cyclic downward-shift of the previous one.
77pub fn circulant_to_dense<F: Field, const N: usize>(first_col: &[i64; N]) -> [[F; N]; N] {
78    let col: [F; N] = first_col.map(F::from_i64);
79    core::array::from_fn(|i| core::array::from_fn(|j| col[(N + i - j) % N]))
80}
81
82/// Dense NxN matrix multiplication: `C = A * B`.
83fn matrix_mul<F: Field, const N: usize>(a: &[[F; N]; N], b: &[[F; N]; N]) -> [[F; N]; N] {
84    core::array::from_fn(|i| {
85        core::array::from_fn(|j| dot_product(a[i].iter().copied(), (0..N).map(|k| b[k][j])))
86    })
87}
88
89/// Matrix-vector multiplication: `result = M * v`.
90fn matrix_vec_mul<F: Field, const N: usize>(m: &[[F; N]; N], v: &[F; N]) -> [F; N] {
91    core::array::from_fn(|i| F::dot_product(&m[i], v))
92}
93
94/// Matrix transpose: `result[i][j] = m[j][i]`.
95fn matrix_transpose<F: Field, const N: usize>(m: &[[F; N]; N]) -> [[F; N]; N] {
96    let mut result = [[F::ZERO; N]; N];
97    for i in 0..N {
98        for j in 0..N {
99            result[i][j] = m[j][i];
100        }
101    }
102    result
103}
104
105/// NxN matrix inverse via Gauss-Jordan elimination.
106///
107/// # Panics
108///
109/// Panics if the matrix is singular (i.e., not invertible over the field).
110fn matrix_inverse<F: Field, const N: usize>(m: &[[F; N]; N]) -> [[F; N]; N] {
111    // We work on [M | I] and reduce M to I, yielding [I | M^{-1}].
112    let mut aug = vec![[F::ZERO; N]; N];
113    let mut inv = [[F::ZERO; N]; N];
114
115    // Initialize: aug = M, inv = I.
116    for i in 0..N {
117        aug[i] = m[i];
118        inv[i][i] = F::ONE;
119    }
120
121    for col in 0..N {
122        // Partial pivoting: find a row with a nonzero entry in this column.
123        let pivot_row = (col..N)
124            .find(|&r| aug[r][col] != F::ZERO)
125            .expect("Matrix is singular");
126
127        // Swap the pivot row into position.
128        if pivot_row != col {
129            aug.swap(col, pivot_row);
130            inv.swap(col, pivot_row);
131        }
132
133        // Scale the pivot row so that aug[col][col] = 1.
134        let pivot_inv = aug[col][col].inverse();
135        for j in 0..N {
136            aug[col][j] *= pivot_inv;
137            inv[col][j] *= pivot_inv;
138        }
139
140        // Eliminate this column in all other rows.
141        for i in 0..N {
142            if i == col {
143                continue;
144            }
145            let factor = aug[i][col];
146            if factor == F::ZERO {
147                continue;
148            }
149            // Snapshot the pivot row to avoid aliasing.
150            let aug_col_row: [F; N] = aug[col];
151            let inv_col_row: [F; N] = inv[col];
152            for j in 0..N {
153                aug[i][j] -= factor * aug_col_row[j];
154                inv[i][j] -= factor * inv_col_row[j];
155            }
156        }
157    }
158
159    inv
160}
161
162/// Inverse of the (N-1)x(N-1) bottom-right submatrix: `m[1..N, 1..N]`.
163///
164/// This is M̂^{-1} from the sparse matrix factorization (Appendix B, Eq. 5 in the paper).
165///
166/// # Panics
167///
168/// Panics if the submatrix is singular. For an MDS matrix, every submatrix is
169/// non-singular by definition, so this should never happen with valid parameters.
170fn submatrix_inverse<F: Field, const N: usize>(m: &[[F; N]; N]) -> Vec<Vec<F>> {
171    let n = N - 1;
172
173    // Extract the (N-1)x(N-1) bottom-right submatrix.
174    let mut sub: Vec<Vec<F>> = (0..n).map(|_| F::zero_vec(n)).collect();
175    for i in 0..n {
176        for j in 0..n {
177            sub[i][j] = m[i + 1][j + 1];
178        }
179    }
180
181    // Standard Gauss-Jordan on the submatrix.
182    let mut inv: Vec<Vec<F>> = (0..n).map(|_| F::zero_vec(n)).collect();
183    for (i, row) in inv.iter_mut().enumerate() {
184        row[i] = F::ONE;
185    }
186
187    for col in 0..n {
188        let pivot_row = (col..n)
189            .find(|&r| sub[r][col] != F::ZERO)
190            .expect("Submatrix is singular");
191
192        if pivot_row != col {
193            sub.swap(col, pivot_row);
194            inv.swap(col, pivot_row);
195        }
196
197        let pivot_inv = sub[col][col].inverse();
198        for j in 0..n {
199            sub[col][j] *= pivot_inv;
200            inv[col][j] *= pivot_inv;
201        }
202
203        for i in 0..n {
204            if i == col {
205                continue;
206            }
207            let factor = sub[i][col];
208            if factor == F::ZERO {
209                continue;
210            }
211            let sub_col_row: Vec<F> = sub[col].clone();
212            let inv_col_row: Vec<F> = inv[col].clone();
213            for j in 0..n {
214                sub[i][j] -= factor * sub_col_row[j];
215                inv[i][j] -= factor * inv_col_row[j];
216            }
217        }
218    }
219
220    inv
221}
222
223/// Factor the dense MDS matrix into RP sparse matrices.
224///
225/// # Algorithm (following HorizenLabs)
226///
227/// The algorithm works on M^T (the transposed MDS matrix) and iterates RP times.
228/// In each iteration, it extracts the sparse components (v, ŵ) from the current
229/// accumulated matrix, then "peels off" one sparse factor by multiplying M^T back in.
230///
231/// After all RP iterations:
232/// - The accumulated remainder becomes the dense transition matrix m_i.
233/// - The sparse components (v_r, ŵ_r) are reversed to match forward application order.
234///
235/// # Returns
236///
237/// A tuple (m_i, v_collection, ŵ_collection) where:
238/// - m_i is a dense WIDTHxWIDTH transition matrix, applied once before the partial round loop.
239/// - `v_collection[r]` has WIDTH-1 elements: the first column of sparse factor S_r.
240/// - `ŵ_collection[r]` has WIDTH-1 elements: the first row of sparse factor S_r.
241#[allow(clippy::type_complexity)]
242fn compute_equivalent_matrices<F: Field, const N: usize>(
243    mds: &[[F; N]; N],
244    rounds_p: usize,
245) -> ([[F; N]; N], Vec<[F; N]>, Vec<[F; N]>) {
246    let mut w_hat_collection: Vec<[F; N]> = Vec::with_capacity(rounds_p);
247    let mut v_collection: Vec<[F; N]> = Vec::with_capacity(rounds_p);
248
249    // Work on M^T.
250    let mds_t = matrix_transpose(mds);
251    let mut m_mul = mds_t;
252    let mut m_i = [[F::ZERO; N]; N];
253
254    for _ in 0..rounds_p {
255        // Extract v = first row of m_mul (excluding [0,0]).
256        // In the transposed domain, this corresponds to the first column of M''.
257        // Stored in a flat [F; N] array, padded with zero at index N-1.
258        let v_arr: [F; N] =
259            core::array::from_fn(|j| if j < N - 1 { m_mul[0][j + 1] } else { F::ZERO });
260
261        // Extract w = first column of m_mul (excluding [0,0]).
262        let w: Vec<F> = (1..N).map(|i| m_mul[i][0]).collect();
263
264        // Compute M̂^{-1} (inverse of the bottom-right submatrix).
265        let m_hat_inv = submatrix_inverse::<F, N>(&m_mul);
266
267        // Compute ŵ = M̂^{-1} * w (Eq. 5 in the paper).
268        // Stored in a flat [F; N] array, padded with zero at index N-1.
269        let w_hat_arr: [F; N] = core::array::from_fn(|i| {
270            if i < N - 1 {
271                dot_product(m_hat_inv[i].iter().copied(), w.iter().copied())
272            } else {
273                F::ZERO
274            }
275        });
276
277        v_collection.push(v_arr);
278        w_hat_collection.push(w_hat_arr);
279
280        // Build m_i: identity-like matrix (zero out first row/column, set [0,0] = 1).
281        // This is the M' factor that gets absorbed into the next iteration.
282        m_i = m_mul;
283        m_i[0][0] = F::ONE;
284        for row in m_i.iter_mut().skip(1) {
285            row[0] = F::ZERO;
286        }
287        for elem in m_i[0].iter_mut().skip(1) {
288            *elem = F::ZERO;
289        }
290
291        // Accumulate: m_mul = M^T * m_i for the next iteration.
292        m_mul = matrix_mul(&mds_t, &m_i);
293    }
294
295    // Transpose m_i back (HorizenLabs works in the transposed domain).
296    let m_i_returned = matrix_transpose(&m_i);
297
298    // Reverse the collections: HorizenLabs computes them in reverse order
299    // (index RP-1 first, RP-2 second, ..., 0 last). After reversal, index 0
300    // corresponds to the first partial round applied.
301    v_collection.reverse();
302    w_hat_collection.reverse();
303
304    (m_i_returned, v_collection, w_hat_collection)
305}
306
307/// Compress round constants via backward substitution through M^{-1}.
308///
309/// # Algorithm
310///
311/// Starting from the last partial round's constants and working backward:
312///
313/// 1. Push the accumulated constant vector through M^{-1}.
314/// 2. Extract the first element as the scalar constant for that round.
315/// 3. Add the remaining elements to the previous round's constants.
316///
317/// After processing all rounds, the accumulated vector becomes the first partial
318/// round's full WIDTH-vector of constants.
319///
320/// # Returns
321///
322/// A tuple of (full_vector, scalar_constants) where:
323/// - The full vector has WIDTH elements, used for the first partial round.
324/// - The scalar constants have RP-1 entries, one per remaining partial round.
325fn equivalent_round_constants<F: Field, const N: usize>(
326    partial_rc: &[[F; N]],
327    mds_inv: &[[F; N]; N],
328) -> ([F; N], Vec<F>) {
329    let rounds_p = partial_rc.len();
330    let mut opt_partial_rc = F::zero_vec(rounds_p);
331
332    // Start with the last partial round's full constant vector.
333    let mut tmp = partial_rc[rounds_p - 1];
334
335    // Process rounds in reverse: from second-to-last down to first.
336    for i in (0..rounds_p - 1).rev() {
337        // Push the accumulated constants backward through M^{-1}.
338        let inv_cip = matrix_vec_mul(mds_inv, &tmp);
339
340        // The first element becomes the scalar constant for round i+1.
341        opt_partial_rc[i + 1] = inv_cip[0];
342
343        // Load round i's constants and add the remaining backward-substituted values.
344        tmp = partial_rc[i];
345        for j in 1..N {
346            tmp[j] += inv_cip[j];
347        }
348    }
349
350    // The accumulated vector is the first partial round's full constant vector.
351    let first_round_constants = tmp;
352
353    // Discard index 0 (round 0 uses the full vector, not a scalar).
354    let opt_partial_rc = opt_partial_rc[1..].to_vec();
355
356    (first_round_constants, opt_partial_rc)
357}
358
359/// Forward constant substitution for the textbook partial round path.
360///
361/// In a partial round, only `state[0]` goes through the S-box. The constants for
362/// `state[1..WIDTH]` can be folded forward through the MDS matrix, reducing each
363/// partial round to a single scalar addition to `state[0]` plus one MDS multiply.
364///
365/// # Algorithm
366///
367/// Starting from round 0, for each partial round:
368/// 1. The scalar constant for that round is `rc[0] + acc[0]` (original constant
369///    plus accumulated offset from previous rounds).
370/// 2. The remaining offsets `[0, rc[1]+acc[1], ..., rc[W-1]+acc[W-1]]` are
371///    propagated through the MDS matrix to produce the accumulator for the next round.
372///
373/// After all rounds, the final accumulator is a residual vector that must be added
374/// to the state.
375///
376/// # Returns
377///
378/// A tuple of (scalar_constants, residual) where:
379/// - `scalar_constants` has RP entries, one per partial round (added to `state[0]`
380///   before the S-box).
381/// - `residual` is a WIDTH-vector added to the state after all partial rounds complete.
382pub fn forward_constant_substitution<F: Field, const N: usize>(
383    mds: &[[F; N]; N],
384    partial_rc: &[[F; N]],
385) -> (Vec<F>, [F; N]) {
386    let rounds_p = partial_rc.len();
387    let mut acc = [F::ZERO; N];
388    let mut scalar_constants = Vec::with_capacity(rounds_p);
389
390    for rc in partial_rc {
391        // Scalar constant = rc[0] + accumulated offset.
392        scalar_constants.push(rc[0] + acc[0]);
393
394        // Build remainder: [0, rc[1]+acc[1], ..., rc[W-1]+acc[W-1]].
395        let remainder: [F; N] =
396            core::array::from_fn(|i| if i == 0 { F::ZERO } else { rc[i] + acc[i] });
397
398        // Propagate through MDS.
399        acc = matrix_vec_mul(mds, &remainder);
400    }
401
402    (scalar_constants, acc)
403}
404
405/// Compute all optimized partial round constants from raw parameters.
406///
407/// Combines the round constant compression and sparse matrix factorization
408/// into a single entry point, keeping the individual helpers private.
409///
410/// # Returns
411///
412/// A tuple of:
413/// - The compressed first-round constant vector (WIDTH elements).
414/// - The optimized scalar round constants (RP-1 entries).
415/// - The dense transition matrix m_i.
416/// - The per-round sparse v vectors.
417/// - The per-round sparse ŵ vectors.
418#[allow(clippy::type_complexity)]
419pub(crate) fn compute_optimized_constants<F: Field, const N: usize>(
420    mds: &[[F; N]; N],
421    rounds_p: usize,
422    partial_rc: &[[F; N]],
423) -> ([F; N], Vec<F>, [[F; N]; N], Vec<[F; N]>, Vec<[F; N]>) {
424    let mds_inv = matrix_inverse(mds);
425    let (first_round_constants, opt_partial_rc) = equivalent_round_constants(partial_rc, &mds_inv);
426    let (m_i, sparse_v, sparse_w_hat) = compute_equivalent_matrices(mds, rounds_p);
427
428    // Pre-assemble full first rows: [mds_0_0, ŵ[0], ŵ[1], ..., ŵ[N-2]].
429    // This enables branch-free dot product computation in cheap_matmul.
430    let mds_0_0 = mds[0][0];
431    let sparse_first_row = sparse_w_hat
432        .iter()
433        .map(|w| core::array::from_fn(|i| if i == 0 { mds_0_0 } else { w[i - 1] }))
434        .collect();
435
436    (
437        first_round_constants,
438        opt_partial_rc,
439        m_i,
440        sparse_v,
441        sparse_first_row,
442    )
443}
444
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::{InjectiveMonomial, PrimeCharacteristicRing};
    use rand::rngs::SmallRng;
    use rand::{RngExt, SeedableRng};

    use super::*;
    use crate::cheap_matmul;

    type F = BabyBear;

    /// Reference implementation of the textbook partial rounds on a width-4 state.
    ///
    /// Each round adds the full constant vector, applies the x^7 S-box to the
    /// first element, and multiplies by the dense MDS matrix.
    fn textbook_partial_rounds(mds: &[[F; 4]; 4], partial_rc: &[[F; 4]], input: [F; 4]) -> [F; 4] {
        partial_rc.iter().fold(input, |mut state, rc| {
            for (s, &c) in state.iter_mut().zip(rc.iter()) {
                *s += c;
            }
            state[0] = InjectiveMonomial::<7>::injective_exp_n(&state[0]);
            matrix_vec_mul(mds, &state)
        })
    }

    /// The computed inverse must satisfy M * M^{-1} = I.
    #[test]
    fn test_matrix_inverse_roundtrip() {
        let mut rng = SmallRng::seed_from_u64(42);
        let m: [[F; 4]; 4] = rng.random();

        let product = matrix_mul(&m, &matrix_inverse(&m));

        for i in 0..4 {
            for j in 0..4 {
                let val = product[i][j];
                if i == j {
                    assert_eq!(val, F::ONE, "Diagonal [{i}][{j}] should be 1");
                } else {
                    assert_eq!(val, F::ZERO, "Off-diagonal [{i}][{j}] should be 0");
                }
            }
        }
    }

    /// Textbook and optimized partial rounds must agree on a 4x4 state.
    ///
    /// Textbook form: every round adds the full constant vector, applies the S-box
    /// to the first element, and multiplies by the dense MDS matrix.
    ///
    /// Optimized form: the full constant vector is added once and the dense
    /// transition matrix applied once; afterwards each round applies the S-box to
    /// the first element, adds a scalar constant, and performs the sparse multiply.
    #[test]
    fn test_partial_rounds_equivalence_4x4() {
        let mut rng = SmallRng::seed_from_u64(42);
        let mds: [[F; 4]; 4] = rng.random();
        let rounds_p = 3;

        let partial_rc: Vec<[F; 4]> = (0..rounds_p).map(|_| rng.random()).collect();

        let mds_inv = matrix_inverse(&mds);
        let (first_rc, opt_rc) = equivalent_round_constants::<F, 4>(&partial_rc, &mds_inv);
        let (m_i, v_coll, w_hat_coll) = compute_equivalent_matrices::<F, 4>(&mds, rounds_p);

        let input: [F; 4] = rng.random();

        // Textbook path.
        let textbook_state = textbook_partial_rounds(&mds, &partial_rc, input);

        // Optimized path: full constant vector once, then the dense transition matrix once.
        let mut opt_state = input;
        for (s, &c) in opt_state.iter_mut().zip(first_rc.iter()) {
            *s += c;
        }
        opt_state = matrix_vec_mul(&m_i, &opt_state);

        // Pre-assemble full first rows (as done in compute_optimized_constants).
        let corner = mds[0][0];
        let sparse_first_rows: Vec<[F; 4]> = w_hat_coll
            .iter()
            .map(|w| core::array::from_fn(|i| if i == 0 { corner } else { w[i - 1] }))
            .collect();

        // Per round: S-box, scalar constant (absent in the last round), sparse multiply.
        for (r, (first_row, v)) in sparse_first_rows.iter().zip(v_coll.iter()).enumerate() {
            opt_state[0] = InjectiveMonomial::<7>::injective_exp_n(&opt_state[0]);
            if r < rounds_p - 1 {
                opt_state[0] += opt_rc[r];
            }
            cheap_matmul(&mut opt_state, first_row, v);
        }

        assert_eq!(
            textbook_state, opt_state,
            "Textbook and optimized partial rounds should match"
        );
    }

    /// Textbook partial rounds (full constant vector per round) and the
    /// textbook+scalar form (forward constant substitution) must agree on a 4x4 state.
    #[test]
    fn test_forward_constant_substitution_equivalence_4x4() {
        let mut rng = SmallRng::seed_from_u64(123);
        let mds: [[F; 4]; 4] = rng.random();
        let rounds_p = 5;

        let partial_rc: Vec<[F; 4]> = (0..rounds_p).map(|_| rng.random()).collect();

        let (scalar_constants, residual) = forward_constant_substitution::<F, 4>(&mds, &partial_rc);

        let input: [F; 4] = rng.random();

        // Textbook path.
        let textbook_state = textbook_partial_rounds(&mds, &partial_rc, input);

        // Textbook+scalar path: scalar constant on state[0] only, same S-box and
        // dense MDS multiply, residual added once at the end.
        let mut scalar_state = scalar_constants.iter().fold(input, |mut state, &c| {
            state[0] += c;
            state[0] = InjectiveMonomial::<7>::injective_exp_n(&state[0]);
            matrix_vec_mul(&mds, &state)
        });
        for (s, &r) in scalar_state.iter_mut().zip(residual.iter()) {
            *s += r;
        }

        assert_eq!(
            textbook_state, scalar_state,
            "Textbook and textbook+scalar partial rounds should match"
        );
    }
}