p3_poseidon1/utils.rs
1//! Equivalent matrix decomposition for efficient partial rounds.
2//!
3//! # Overview
4//!
5//! This module implements the sparse matrix optimization described in **Appendix B** of the Poseidon1
6//! paper (Grassi et al., USENIX Security 2021). It transforms the RP partial rounds
7//! from their textbook form (dense MDS multiply per round, O(t^2) each) into an
8//! equivalent form using sparse matrices (O(t) each).
9//!
10//! # Background: Textbook Partial Rounds
11//!
12//! In the unoptimized Poseidon1, each partial round applies:
13//!
14//! ```text
15//! state <- M * SBox(state + rc)
16//! ```
17//!
//! where M is the dense t×t MDS matrix, SBox applies x^D only to `state[0]`, and
19//! rc is a full t-vector of round constants. The cost per round is dominated by the
20//! dense matrix multiply: O(t^2).
21//!
22//! # The Sparse Factorization (Poseidon1 Paper, Appendix B, Eq. 5)
23//!
24//! The key insight is that M can be factored as:
25//!
26//! ```text
27//! M = M' * M''
28//! ```
29//!
30//! where:
31//!
32//! ```text
//! M' = ┌───┬───┐        M'' = ┌─────────┬─────┐
//!      │ 1 │ 0 │              │ M[0][0] │  v  │
//!      ├───┼───┤              ├─────────┼─────┤
//!      │ 0 │ M̂ │              │    ŵ    │  I  │
//!      └───┴───┘              └─────────┴─────┘
38//! ```
39//!
40//! Here M̂ is the (t-1)x(t-1) submatrix `M[1..t, 1..t]`, v is the first row of M
41//! (excluding `M[0][0]`), ŵ = M̂^{-1} * w where w is the first column of M
42//! (excluding `M[0][0]`), and I is the (t-1)x(t-1) identity.
43//!
44//! Since the partial S-box only touches `state[0]`, the dense M' factor can be
45//! "absorbed" into the next round's M, leaving only the sparse M'' to be applied
46//! per round via an O(t) matrix-vector product.
47//!
48//! After iterating this factorization across all RP rounds, we obtain:
49//! - One dense **transition matrix** m_i (applied once before the loop)
50//! - RP **sparse matrices** S_r, each parameterized by vectors v_r and ŵ_r of length t-1
51//!
52//! # Round Constant Compression
53//!
54//! In parallel, the round constants are compressed via backward substitution.
55//! Since M^{-1} is linear, we can "push" round constants backward through the
56//! inverse matrix, accumulating them into the first partial round. After this
57//! transformation:
58//! - The first partial round uses a full t-vector of constants.
59//! - Each subsequent partial round uses a single scalar constant (for `state[0]`).
60//! - The last partial round has no additive constant at all.
61//!
62//! # Implementation Note
63//!
64//! This implementation follows the HorizenLabs reference
65//! (`plain_implementations/src/poseidon/poseidon_params.rs`) which works on the
66//! **transposed** MDS matrix internally and reverses the sparse matrix ordering
67//! before returning.
68
69use alloc::vec;
70use alloc::vec::Vec;
71
72use p3_field::{Field, dot_product};
73
74/// Expand a circulant matrix from its first column into a dense NxN matrix.
75///
76/// Each column is a cyclic downward-shift of the previous one.
77pub fn circulant_to_dense<F: Field, const N: usize>(first_col: &[i64; N]) -> [[F; N]; N] {
78 let col: [F; N] = first_col.map(F::from_i64);
79 core::array::from_fn(|i| core::array::from_fn(|j| col[(N + i - j) % N]))
80}
81
82/// Dense NxN matrix multiplication: `C = A * B`.
83fn matrix_mul<F: Field, const N: usize>(a: &[[F; N]; N], b: &[[F; N]; N]) -> [[F; N]; N] {
84 core::array::from_fn(|i| {
85 core::array::from_fn(|j| dot_product(a[i].iter().copied(), (0..N).map(|k| b[k][j])))
86 })
87}
88
89/// Matrix-vector multiplication: `result = M * v`.
90fn matrix_vec_mul<F: Field, const N: usize>(m: &[[F; N]; N], v: &[F; N]) -> [F; N] {
91 core::array::from_fn(|i| F::dot_product(&m[i], v))
92}
93
94/// Matrix transpose: `result[i][j] = m[j][i]`.
95fn matrix_transpose<F: Field, const N: usize>(m: &[[F; N]; N]) -> [[F; N]; N] {
96 let mut result = [[F::ZERO; N]; N];
97 for i in 0..N {
98 for j in 0..N {
99 result[i][j] = m[j][i];
100 }
101 }
102 result
103}
104
105/// NxN matrix inverse via Gauss-Jordan elimination.
106///
107/// # Panics
108///
109/// Panics if the matrix is singular (i.e., not invertible over the field).
110fn matrix_inverse<F: Field, const N: usize>(m: &[[F; N]; N]) -> [[F; N]; N] {
111 // We work on [M | I] and reduce M to I, yielding [I | M^{-1}].
112 let mut aug = vec![[F::ZERO; N]; N];
113 let mut inv = [[F::ZERO; N]; N];
114
115 // Initialize: aug = M, inv = I.
116 for i in 0..N {
117 aug[i] = m[i];
118 inv[i][i] = F::ONE;
119 }
120
121 for col in 0..N {
122 // Partial pivoting: find a row with a nonzero entry in this column.
123 let pivot_row = (col..N)
124 .find(|&r| aug[r][col] != F::ZERO)
125 .expect("Matrix is singular");
126
127 // Swap the pivot row into position.
128 if pivot_row != col {
129 aug.swap(col, pivot_row);
130 inv.swap(col, pivot_row);
131 }
132
133 // Scale the pivot row so that aug[col][col] = 1.
134 let pivot_inv = aug[col][col].inverse();
135 for j in 0..N {
136 aug[col][j] *= pivot_inv;
137 inv[col][j] *= pivot_inv;
138 }
139
140 // Eliminate this column in all other rows.
141 for i in 0..N {
142 if i == col {
143 continue;
144 }
145 let factor = aug[i][col];
146 if factor == F::ZERO {
147 continue;
148 }
149 // Snapshot the pivot row to avoid aliasing.
150 let aug_col_row: [F; N] = aug[col];
151 let inv_col_row: [F; N] = inv[col];
152 for j in 0..N {
153 aug[i][j] -= factor * aug_col_row[j];
154 inv[i][j] -= factor * inv_col_row[j];
155 }
156 }
157 }
158
159 inv
160}
161
162/// Inverse of the (N-1)x(N-1) bottom-right submatrix: `m[1..N, 1..N]`.
163///
164/// This is M̂^{-1} from the sparse matrix factorization (Appendix B, Eq. 5 in the paper).
165///
166/// # Panics
167///
168/// Panics if the submatrix is singular. For an MDS matrix, every submatrix is
169/// non-singular by definition, so this should never happen with valid parameters.
170fn submatrix_inverse<F: Field, const N: usize>(m: &[[F; N]; N]) -> Vec<Vec<F>> {
171 let n = N - 1;
172
173 // Extract the (N-1)x(N-1) bottom-right submatrix.
174 let mut sub: Vec<Vec<F>> = (0..n).map(|_| F::zero_vec(n)).collect();
175 for i in 0..n {
176 for j in 0..n {
177 sub[i][j] = m[i + 1][j + 1];
178 }
179 }
180
181 // Standard Gauss-Jordan on the submatrix.
182 let mut inv: Vec<Vec<F>> = (0..n).map(|_| F::zero_vec(n)).collect();
183 for (i, row) in inv.iter_mut().enumerate() {
184 row[i] = F::ONE;
185 }
186
187 for col in 0..n {
188 let pivot_row = (col..n)
189 .find(|&r| sub[r][col] != F::ZERO)
190 .expect("Submatrix is singular");
191
192 if pivot_row != col {
193 sub.swap(col, pivot_row);
194 inv.swap(col, pivot_row);
195 }
196
197 let pivot_inv = sub[col][col].inverse();
198 for j in 0..n {
199 sub[col][j] *= pivot_inv;
200 inv[col][j] *= pivot_inv;
201 }
202
203 for i in 0..n {
204 if i == col {
205 continue;
206 }
207 let factor = sub[i][col];
208 if factor == F::ZERO {
209 continue;
210 }
211 let sub_col_row: Vec<F> = sub[col].clone();
212 let inv_col_row: Vec<F> = inv[col].clone();
213 for j in 0..n {
214 sub[i][j] -= factor * sub_col_row[j];
215 inv[i][j] -= factor * inv_col_row[j];
216 }
217 }
218 }
219
220 inv
221}
222
/// Factor the dense MDS matrix into RP sparse matrices.
///
/// # Algorithm (following HorizenLabs)
///
/// The algorithm works on M^T (the transposed MDS matrix) and iterates RP times.
/// In each iteration, it extracts the sparse components (v, ŵ) from the current
/// accumulated matrix, then "peels off" one sparse factor by multiplying M^T back in.
///
/// After all RP iterations:
/// - The accumulated remainder becomes the dense transition matrix m_i.
/// - The sparse components (v_r, ŵ_r) are reversed to match forward application order.
///
/// # Returns
///
/// A tuple (m_i, v_collection, ŵ_collection) where:
/// - m_i is a dense WIDTHxWIDTH transition matrix, applied once before the partial round loop.
/// - `v_collection[r]` holds the WIDTH-1 meaningful elements of the first column of
///   sparse factor S_r, stored in a length-WIDTH array padded with a trailing zero.
/// - `ŵ_collection[r]` holds the WIDTH-1 meaningful elements of the first row of
///   sparse factor S_r, stored in a length-WIDTH array padded with a trailing zero.
///
/// NOTE: if `rounds_p == 0`, the loop never executes and the returned m_i is the
/// all-zero matrix (m_i is only assigned inside the loop), so callers must pass
/// `rounds_p >= 1`.
#[allow(clippy::type_complexity)]
fn compute_equivalent_matrices<F: Field, const N: usize>(
    mds: &[[F; N]; N],
    rounds_p: usize,
) -> ([[F; N]; N], Vec<[F; N]>, Vec<[F; N]>) {
    let mut w_hat_collection: Vec<[F; N]> = Vec::with_capacity(rounds_p);
    let mut v_collection: Vec<[F; N]> = Vec::with_capacity(rounds_p);

    // Work on M^T.
    let mds_t = matrix_transpose(mds);
    let mut m_mul = mds_t;
    let mut m_i = [[F::ZERO; N]; N];

    for _ in 0..rounds_p {
        // Extract v = first row of m_mul (excluding [0,0]).
        // In the transposed domain, this corresponds to the first column of M''.
        // Stored in a flat [F; N] array, padded with zero at index N-1.
        let v_arr: [F; N] =
            core::array::from_fn(|j| if j < N - 1 { m_mul[0][j + 1] } else { F::ZERO });

        // Extract w = first column of m_mul (excluding [0,0]).
        let w: Vec<F> = (1..N).map(|i| m_mul[i][0]).collect();

        // Compute M̂^{-1} (inverse of the bottom-right submatrix).
        let m_hat_inv = submatrix_inverse::<F, N>(&m_mul);

        // Compute ŵ = M̂^{-1} * w (Eq. 5 in the paper).
        // Stored in a flat [F; N] array, padded with zero at index N-1.
        let w_hat_arr: [F; N] = core::array::from_fn(|i| {
            if i < N - 1 {
                dot_product(m_hat_inv[i].iter().copied(), w.iter().copied())
            } else {
                F::ZERO
            }
        });

        v_collection.push(v_arr);
        w_hat_collection.push(w_hat_arr);

        // Build m_i: identity-like matrix (zero out first row/column, set [0,0] = 1).
        // This is the M' factor that gets absorbed into the next iteration.
        m_i = m_mul;
        m_i[0][0] = F::ONE;
        for row in m_i.iter_mut().skip(1) {
            row[0] = F::ZERO;
        }
        for elem in m_i[0].iter_mut().skip(1) {
            *elem = F::ZERO;
        }

        // Accumulate: m_mul = M^T * m_i for the next iteration.
        m_mul = matrix_mul(&mds_t, &m_i);
    }

    // Transpose m_i back (HorizenLabs works in the transposed domain).
    let m_i_returned = matrix_transpose(&m_i);

    // Reverse the collections: HorizenLabs computes them in reverse order
    // (index RP-1 first, RP-2 second, ..., 0 last). After reversal, index 0
    // corresponds to the first partial round applied.
    v_collection.reverse();
    w_hat_collection.reverse();

    (m_i_returned, v_collection, w_hat_collection)
}
306
/// Compress round constants via backward substitution through M^{-1}.
///
/// # Algorithm
///
/// Starting from the last partial round's constants and working backward:
///
/// 1. Push the accumulated constant vector through M^{-1}.
/// 2. Extract the first element as the scalar constant for that round.
/// 3. Add the remaining elements to the previous round's constants.
///
/// After processing all rounds, the accumulated vector becomes the first partial
/// round's full WIDTH-vector of constants.
///
/// # Returns
///
/// A tuple of (full_vector, scalar_constants) where:
/// - The full vector has WIDTH elements, used for the first partial round.
/// - The scalar constants have RP-1 entries, one per remaining partial round.
///
/// # Panics
///
/// Panics if `partial_rc` is empty: the first step indexes
/// `partial_rc[rounds_p - 1]`, which underflows when `rounds_p == 0`.
fn equivalent_round_constants<F: Field, const N: usize>(
    partial_rc: &[[F; N]],
    mds_inv: &[[F; N]; N],
) -> ([F; N], Vec<F>) {
    let rounds_p = partial_rc.len();
    // opt_partial_rc[i] will hold the scalar constant for round i (slot 0 unused).
    let mut opt_partial_rc = F::zero_vec(rounds_p);

    // Start with the last partial round's full constant vector.
    let mut tmp = partial_rc[rounds_p - 1];

    // Process rounds in reverse: from second-to-last down to first.
    for i in (0..rounds_p - 1).rev() {
        // Push the accumulated constants backward through M^{-1}.
        let inv_cip = matrix_vec_mul(mds_inv, &tmp);

        // The first element becomes the scalar constant for round i+1.
        opt_partial_rc[i + 1] = inv_cip[0];

        // Load round i's constants and add the remaining backward-substituted values.
        // Element 0 is intentionally NOT accumulated: it was consumed as the scalar above.
        tmp = partial_rc[i];
        for j in 1..N {
            tmp[j] += inv_cip[j];
        }
    }

    // The accumulated vector is the first partial round's full constant vector.
    let first_round_constants = tmp;

    // Discard index 0 (round 0 uses the full vector, not a scalar).
    let opt_partial_rc = opt_partial_rc[1..].to_vec();

    (first_round_constants, opt_partial_rc)
}
358
359/// Forward constant substitution for the textbook partial round path.
360///
361/// In a partial round, only `state[0]` goes through the S-box. The constants for
362/// `state[1..WIDTH]` can be folded forward through the MDS matrix, reducing each
363/// partial round to a single scalar addition to `state[0]` plus one MDS multiply.
364///
365/// # Algorithm
366///
367/// Starting from round 0, for each partial round:
368/// 1. The scalar constant for that round is `rc[0] + acc[0]` (original constant
369/// plus accumulated offset from previous rounds).
370/// 2. The remaining offsets `[0, rc[1]+acc[1], ..., rc[W-1]+acc[W-1]]` are
371/// propagated through the MDS matrix to produce the accumulator for the next round.
372///
373/// After all rounds, the final accumulator is a residual vector that must be added
374/// to the state.
375///
376/// # Returns
377///
378/// A tuple of (scalar_constants, residual) where:
379/// - `scalar_constants` has RP entries, one per partial round (added to `state[0]`
380/// before the S-box).
381/// - `residual` is a WIDTH-vector added to the state after all partial rounds complete.
382pub fn forward_constant_substitution<F: Field, const N: usize>(
383 mds: &[[F; N]; N],
384 partial_rc: &[[F; N]],
385) -> (Vec<F>, [F; N]) {
386 let rounds_p = partial_rc.len();
387 let mut acc = [F::ZERO; N];
388 let mut scalar_constants = Vec::with_capacity(rounds_p);
389
390 for rc in partial_rc {
391 // Scalar constant = rc[0] + accumulated offset.
392 scalar_constants.push(rc[0] + acc[0]);
393
394 // Build remainder: [0, rc[1]+acc[1], ..., rc[W-1]+acc[W-1]].
395 let remainder: [F; N] =
396 core::array::from_fn(|i| if i == 0 { F::ZERO } else { rc[i] + acc[i] });
397
398 // Propagate through MDS.
399 acc = matrix_vec_mul(mds, &remainder);
400 }
401
402 (scalar_constants, acc)
403}
404
/// Compute all optimized partial round constants from raw parameters.
///
/// Combines the round constant compression and sparse matrix factorization
/// into a single entry point, keeping the individual helpers private.
///
/// # Returns
///
/// A tuple of:
/// - The compressed first-round constant vector (WIDTH elements).
/// - The optimized scalar round constants (RP-1 entries).
/// - The dense transition matrix m_i.
/// - The per-round sparse v vectors.
/// - The per-round pre-assembled sparse first rows `[mds[0][0], ŵ[0], ..., ŵ[N-2]]`
///   (NOT the raw ŵ vectors: `mds[0][0]` is prepended below so downstream code can
///   take a branch-free dot product over a full WIDTH-length row).
#[allow(clippy::type_complexity)]
pub(crate) fn compute_optimized_constants<F: Field, const N: usize>(
    mds: &[[F; N]; N],
    rounds_p: usize,
    partial_rc: &[[F; N]],
) -> ([F; N], Vec<F>, [[F; N]; N], Vec<[F; N]>, Vec<[F; N]>) {
    let mds_inv = matrix_inverse(mds);
    let (first_round_constants, opt_partial_rc) = equivalent_round_constants(partial_rc, &mds_inv);
    let (m_i, sparse_v, sparse_w_hat) = compute_equivalent_matrices(mds, rounds_p);

    // Pre-assemble full first rows: [mds_0_0, ŵ[0], ŵ[1], ..., ŵ[N-2]].
    // This enables branch-free dot product computation in cheap_matmul.
    let mds_0_0 = mds[0][0];
    let sparse_first_row = sparse_w_hat
        .iter()
        .map(|w| core::array::from_fn(|i| if i == 0 { mds_0_0 } else { w[i - 1] }))
        .collect();

    (
        first_round_constants,
        opt_partial_rc,
        m_i,
        sparse_v,
        sparse_first_row,
    )
}
444
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::{InjectiveMonomial, PrimeCharacteristicRing};
    use rand::rngs::SmallRng;
    // Fix: the `rand` crate's extension trait providing `.random()` is `Rng`;
    // there is no `RngExt` trait, so the previous import failed to compile.
    use rand::{Rng, SeedableRng};

    use super::*;
    use crate::cheap_matmul;

    type F = BabyBear;

    /// Verify that the matrix inverse produces a correct inverse: M * M^{-1} = I.
    #[test]
    fn test_matrix_inverse_roundtrip() {
        let mut rng = SmallRng::seed_from_u64(42);
        let m: [[F; 4]; 4] = rng.random();

        let m_inv = matrix_inverse(&m);
        let product = matrix_mul(&m, &m_inv);

        for (i, row) in product.iter().enumerate() {
            for (j, &val) in row.iter().enumerate() {
                if i == j {
                    assert_eq!(val, F::ONE, "Diagonal [{i}][{j}] should be 1");
                } else {
                    assert_eq!(val, F::ZERO, "Off-diagonal [{i}][{j}] should be 0");
                }
            }
        }
    }

    /// Verify equivalence between textbook and optimized partial rounds on a 4x4 state.
    ///
    /// Textbook form: each round adds the full constant vector, applies the S-box to
    /// the first element, and multiplies by the dense MDS matrix.
    ///
    /// Optimized form: adds the full constant vector once, applies the dense transition
    /// matrix once, then loops over rounds applying the S-box to the first element,
    /// a scalar constant, and the sparse matrix multiply.
    #[test]
    fn test_partial_rounds_equivalence_4x4() {
        let mut rng = SmallRng::seed_from_u64(42);
        let mds: [[F; 4]; 4] = rng.random();
        let rounds_p = 3;

        let partial_rc: Vec<[F; 4]> = (0..rounds_p).map(|_| rng.random()).collect();

        let mds_inv = matrix_inverse(&mds);

        let (first_rc, opt_rc) = equivalent_round_constants::<F, 4>(&partial_rc, &mds_inv);
        let (m_i, v_coll, w_hat_coll) = compute_equivalent_matrices::<F, 4>(&mds, rounds_p);

        let input: [F; 4] = rng.random();

        // Textbook partial rounds: add full constant vector, S-box on first element,
        // then dense MDS multiply.
        let mut textbook_state = input;
        for rc in partial_rc.iter().take(rounds_p) {
            for (s, &c) in textbook_state.iter_mut().zip(rc.iter()) {
                *s += c;
            }
            textbook_state[0] = InjectiveMonomial::<7>::injective_exp_n(&textbook_state[0]);
            textbook_state = matrix_vec_mul(&mds, &textbook_state);
        }

        // Optimized partial rounds: add full constant vector once, apply dense
        // transition matrix once, then loop with S-box + scalar constant + sparse multiply.
        let mut opt_state = input;
        for i in 0..4 {
            opt_state[i] += first_rc[i];
        }
        opt_state = matrix_vec_mul(&m_i, &opt_state);

        // Pre-assemble full first rows (as done in compute_optimized_constants).
        let mds_0_0 = mds[0][0];
        let sparse_first_rows: Vec<[F; 4]> = w_hat_coll
            .iter()
            .map(|w| core::array::from_fn(|i| if i == 0 { mds_0_0 } else { w[i - 1] }))
            .collect();

        for r in 0..rounds_p {
            opt_state[0] = InjectiveMonomial::<7>::injective_exp_n(&opt_state[0]);
            if r < rounds_p - 1 {
                opt_state[0] += opt_rc[r];
            }
            cheap_matmul(&mut opt_state, &sparse_first_rows[r], &v_coll[r]);
        }

        assert_eq!(
            textbook_state, opt_state,
            "Textbook and optimized partial rounds should match"
        );
    }

    /// Verify equivalence between textbook (full MDS per round) and textbook+scalar
    /// (forward constant substitution) partial rounds on a 4x4 state.
    #[test]
    fn test_forward_constant_substitution_equivalence_4x4() {
        let mut rng = SmallRng::seed_from_u64(123);
        let mds: [[F; 4]; 4] = rng.random();
        let rounds_p = 5;

        let partial_rc: Vec<[F; 4]> = (0..rounds_p).map(|_| rng.random()).collect();

        let (scalar_constants, residual) = forward_constant_substitution::<F, 4>(&mds, &partial_rc);

        let input: [F; 4] = rng.random();

        // Textbook: add full constant vector, S-box on first element, dense MDS.
        let mut textbook_state = input;
        for rc in &partial_rc {
            for (s, &c) in textbook_state.iter_mut().zip(rc.iter()) {
                *s += c;
            }
            textbook_state[0] = InjectiveMonomial::<7>::injective_exp_n(&textbook_state[0]);
            textbook_state = matrix_vec_mul(&mds, &textbook_state);
        }

        // Textbook+scalar: only scalar constant to state[0], same S-box + MDS,
        // then add residual at the end.
        let mut scalar_state = input;
        for &c in &scalar_constants {
            scalar_state[0] += c;
            scalar_state[0] = InjectiveMonomial::<7>::injective_exp_n(&scalar_state[0]);
            scalar_state = matrix_vec_mul(&mds, &scalar_state);
        }
        for (s, &r) in scalar_state.iter_mut().zip(residual.iter()) {
            *s += r;
        }

        assert_eq!(
            textbook_state, scalar_state,
            "Textbook and textbook+scalar partial rounds should match"
        );
    }
}
581}