spongefish_poseidon/
lib.rs

//! This code has been blatantly stolen from `ark-crypto-primitives::sponge`,
//! written by William Lin, with contributions from Pratyush Mishra, Weikeng Chen, Yuwen Zhang, Kristian Sosnin, Merlyn, Wilson Nguyen, Hossein Moghaddas, and others.
3use std::fmt::Debug;
4
5use ark_ff::PrimeField;
6use spongefish::duplex_sponge::{DuplexSponge, Permutation, Unit};
7
8/// Poseidon Sponge.
9///
10/// The `NAME` const is to distinbuish between different bitsizes of the same Field.
11/// For instance Bls12_381 and Bn254 both have field type Fp<MontBackend<FrConfig, 4>, 4> but are different fields.
12#[derive(Clone)]
13pub struct PoseidonPermutation<const NAME: u32, F: PrimeField, const R: usize, const N: usize> {
14    /// Number of rounds in a full-round operation.
15    pub full_rounds: usize,
16    /// Number of rounds in a partial-round operation.
17    pub partial_rounds: usize,
18    /// Exponent used in S-boxes.
19    pub alpha: u64,
20    /// Additive Round keys. These are added before each MDS matrix application to make it an affine shift.
21    /// They are indexed by `ark[round_num][state_element_index]`
22    pub ark: &'static [[F; N]],
23    /// Maximally Distance Separating (MDS) Matrix.
24    pub mds: &'static [[F; N]],
25
26    /// Permutation state
27    pub state: [F; N],
28}
29
30pub type PoseidonHash<const NAME: u32, F, const R: usize, const N: usize> =
31    DuplexSponge<PoseidonPermutation<NAME, F, R, N>>;
32
33impl<const NAME: u32, F: PrimeField, const R: usize, const N: usize> AsRef<[F]>
34    for PoseidonPermutation<NAME, F, R, N>
35{
36    fn as_ref(&self) -> &[F] {
37        &self.state
38    }
39}
40
41impl<const NAME: u32, F: PrimeField, const R: usize, const N: usize> AsMut<[F]>
42    for PoseidonPermutation<NAME, F, R, N>
43{
44    fn as_mut(&mut self) -> &mut [F] {
45        &mut self.state
46    }
47}
48
49impl<const NAME: u32, F: PrimeField, const R: usize, const N: usize>
50    PoseidonPermutation<NAME, F, R, N>
51{
52    fn apply_s_box(&self, state: &mut [F], is_full_round: bool) {
53        // Full rounds apply the S Box (x^alpha) to every element of state
54        if is_full_round {
55            for elem in state {
56                *elem = elem.pow([self.alpha]);
57            }
58        }
59        // Partial rounds apply the S Box (x^alpha) to just the first element of state
60        else {
61            state[0] = state[0].pow([self.alpha]);
62        }
63    }
64
65    #[inline]
66    fn apply_ark(&self, state: &mut [F], round_number: usize) {
67        state.iter_mut().enumerate().for_each(|(i, state_elem)| {
68            *state_elem += self.ark[round_number][i];
69        });
70    }
71
72    #[allow(clippy::needless_range_loop)]
73    fn apply_mds(&self, state: &mut [F]) {
74        let mut new_state = [F::ZERO; N];
75        for i in 0..N {
76            let mut cur = F::zero();
77            for j in 0..N {
78                cur += state[j] * self.mds[i][j];
79            }
80            new_state[i] = cur;
81        }
82        state.clone_from_slice(&new_state);
83    }
84}
85
86impl<const NAME: u32, F: PrimeField, const R: usize, const N: usize> zeroize::Zeroize
87    for PoseidonPermutation<NAME, F, R, N>
88{
89    fn zeroize(&mut self) {
90        self.state.zeroize();
91    }
92}
93
94impl<const NAME: u32, F, const R: usize, const N: usize> Permutation
95    for PoseidonPermutation<NAME, F, R, N>
96where
97    Self: Default,
98    F: PrimeField + Unit,
99{
100    type U = F;
101    const N: usize = N;
102    const R: usize = R;
103
104    fn new(iv: [u8; 32]) -> Self {
105        assert!(N >= 1);
106        let mut sponge = Self::default();
107        sponge.state[R] = F::from_be_bytes_mod_order(&iv);
108        sponge
109    }
110
111    fn permute(&mut self) {
112        let full_rounds_over_2 = self.full_rounds / 2;
113        let mut state = self.state;
114        for i in 0..full_rounds_over_2 {
115            self.apply_ark(&mut state, i);
116            self.apply_s_box(&mut state, true);
117            self.apply_mds(&mut state);
118        }
119
120        for i in 0..self.partial_rounds {
121            self.apply_ark(&mut state, full_rounds_over_2 + i);
122            self.apply_s_box(&mut state, false);
123            self.apply_mds(&mut state);
124        }
125
126        for i in 0..full_rounds_over_2 {
127            self.apply_ark(&mut state, full_rounds_over_2 + self.partial_rounds + i);
128            self.apply_s_box(&mut state, true);
129            self.apply_mds(&mut state);
130        }
131        self.state = state;
132    }
133}
134
135impl<const NAME: u32, F: PrimeField, const R: usize, const N: usize> Debug
136    for PoseidonPermutation<NAME, F, R, N>
137{
138    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139        self.state.fmt(f)
140    }
141}
142
/// Initialization of constants.
///
/// `poseidon_permutation!(bits, Name, consts)` declares
/// `pub type Name = PoseidonPermutation<bits, consts::Field, { consts::R }, { consts::N }>`
/// and implements `Default` for it. The `Default` impl reads `R_F`, `R_P`, `ALPHA`, `ARK`
/// and `MDS` from `consts` and starts from an all-zero state.
///
/// `consts` is matched as a single token tree, so in practice it is a module identifier.
// NOTE(review): `unused` is presumably allowed because every caller is feature-gated, so
// the macro is unused when no field feature is enabled — confirm.
#[allow(unused)]
macro_rules! poseidon_permutation {
    ($bits: expr, $name: ident, $path: tt) => {
        pub type $name =
            crate::PoseidonPermutation<$bits, $path::Field, { $path::R }, { $path::N }>;

        impl Default for $name {
            fn default() -> Self {
                Self {
                    alpha: $path::ALPHA,
                    ark: $path::ARK,
                    mds: $path::MDS,
                    full_rounds: $path::R_F,
                    partial_rounds: $path::R_P,
                    state: [ark_ff::Zero::zero(); $path::N],
                }
            }
        }
    };
}
165
166#[cfg(feature = "bls12-381")]
167pub mod bls12_381;
168
169#[cfg(feature = "bn254")]
170pub mod bn254;
171
172#[cfg(feature = "solinas")]
173pub mod f64;
174
175/// Unit-tests.
176#[cfg(test)]
177mod tests;