Skip to main content

risc0_zkp/core/hash/poseidon2/
mod.rs

1// Copyright 2026 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! An implementation of Poseidon2 targeting the Baby Bear field.
16
17// Thank you to https://github.com/nhukc for the initial implementation of this code
18
19pub(crate) mod consts;
20mod rng;
21
22use alloc::{boxed::Box, rc::Rc, vec::Vec};
23
24use bytemuck::CheckedBitPattern;
25use risc0_core::field::{
26    baby_bear::{BabyBear, BabyBearElem, BabyBearExtElem},
27    Elem, ExtElem,
28};
29
30use super::{HashFn, HashSuite, Rng, RngFactory};
31use crate::core::digest::{Digest, DIGEST_WORDS};
32
33pub use self::{
34    consts::{CELLS, M_INT_DIAG_HZN, ROUNDS_HALF_FULL, ROUNDS_PARTIAL, ROUND_CONSTANTS},
35    rng::Poseidon2Rng,
36};
37
/// The 'rate' of the sponge, i.e. how much we can safely add/remove per mixing.
///
/// `unpadded_hash` writes up to this many input elements into the front of the
/// state before each call to `poseidon2_mix`.
pub const CELLS_RATE: usize = 16;
40
/// The size of the hash output in cells (~ 248 bits)
///
/// `unpadded_hash` returns the first `CELLS_OUT` cells of the final state.
pub const CELLS_OUT: usize = 8;
43
44/// A hash implementation for Poseidon2
45struct Poseidon2HashFn;
46
47impl HashFn<BabyBear> for Poseidon2HashFn {
48    fn hash_pair(&self, a: &Digest, b: &Digest) -> Box<Digest> {
49        let both: Vec<BabyBearElem> = a
50            .as_words()
51            .iter()
52            .chain(b.as_words())
53            .map(|w| BabyBearElem::new_raw(*w))
54            .collect();
55        assert!(both.len() == DIGEST_WORDS * 2);
56        for elem in &both {
57            assert!(elem.is_reduced());
58        }
59        to_digest(unpadded_hash(both.iter()))
60    }
61
62    fn hash_elem_slice(&self, slice: &[BabyBearElem]) -> Box<Digest> {
63        to_digest(unpadded_hash(slice.iter()))
64    }
65
66    fn hash_ext_elem_slice(&self, slice: &[BabyBearExtElem]) -> Box<Digest> {
67        to_digest(unpadded_hash(
68            slice.iter().flat_map(|ee| ee.subelems().iter()),
69        ))
70    }
71
72    /// Checks if all words in the digest are less than the Baby Bear modulus.
73    fn is_digest_valid(&self, digest: &Digest) -> bool {
74        digest
75            .as_words()
76            .iter()
77            .all(<BabyBearElem as CheckedBitPattern>::is_valid_bit_pattern)
78    }
79}
80
81struct Poseidon2RngFactory;
82
83impl RngFactory<BabyBear> for Poseidon2RngFactory {
84    fn new_rng(&self) -> Box<dyn Rng<BabyBear>> {
85        Box::new(Poseidon2Rng::new())
86    }
87}
88
89/// A hash suite using Poseidon2 for both MT hashes and RNG
90pub struct Poseidon2HashSuite;
91
92impl Poseidon2HashSuite {
93    /// Construct a new Poseidon2HashSuite
94    pub fn new_suite() -> HashSuite<BabyBear> {
95        HashSuite {
96            name: "poseidon2".into(),
97            hashfn: Rc::new(Poseidon2HashFn {}),
98            rng: Rc::new(Poseidon2RngFactory {}),
99        }
100    }
101}
102
103fn to_digest(elems: [BabyBearElem; CELLS_OUT]) -> Box<Digest> {
104    let mut state: [u32; DIGEST_WORDS] = [0; DIGEST_WORDS];
105    for i in 0..DIGEST_WORDS {
106        state[i] = elems[i].as_u32_montgomery();
107    }
108    Box::new(Digest::from(state))
109}
110
111fn add_round_constants_full(cells: &mut [BabyBearElem; CELLS], round: usize) {
112    for i in 0..CELLS {
113        cells[i] += ROUND_CONSTANTS[round * CELLS + i];
114    }
115}
116
117fn add_round_constants_partial(cells: &mut [BabyBearElem; CELLS], round: usize) {
118    cells[0] += ROUND_CONSTANTS[round * CELLS];
119}
120
121fn sbox(x: BabyBearElem) -> BabyBearElem {
122    let x2 = x * x;
123    let x4 = x2 * x2;
124    let x6 = x4 * x2;
125    x6 * x
126}
127
128fn do_full_sboxes(cells: &mut [BabyBearElem; CELLS]) {
129    for cell in cells.iter_mut() {
130        *cell = sbox(*cell);
131    }
132}
133
134fn do_partial_sboxes(cells: &mut [BabyBearElem; CELLS]) {
135    cells[0] = sbox(cells[0]);
136}
137
138fn multiply_by_m_int(cells: &mut [BabyBearElem; CELLS]) {
139    // Exploit the fact that off-diagonal entries of M_INT are all 1.
140    let sum: BabyBearElem = cells.iter().fold(BabyBearElem::ZERO, |acc, x| acc + *x);
141    for i in 0..CELLS {
142        cells[i] = sum + M_INT_DIAG_HZN[i] * cells[i];
143    }
144}
145
146fn multiply_by_4x4_circulant(x: &[BabyBearElem; 4]) -> [BabyBearElem; 4] {
147    // See appendix B of Poseidon2 paper.
148    let t0 = x[0] + x[1];
149    let t1 = x[2] + x[3];
150    let t2 = BabyBearElem::new(2) * x[1] + t1;
151    let t3 = BabyBearElem::new(2) * x[3] + t0;
152    let t4 = BabyBearElem::new(4) * t1 + t3;
153    let t5 = BabyBearElem::new(4) * t0 + t2;
154    let t6 = t3 + t5;
155    let t7 = t2 + t4;
156    [t6, t5, t7, t4]
157}
158
159fn multiply_by_m_ext(cells: &mut [BabyBearElem; CELLS]) {
160    // Optimized method for multiplication by M_EXT.
161    // See appendix B of Poseidon2 paper for additional details.
162    let old_cells = *cells;
163    cells.fill(BabyBearElem::ZERO);
164    let mut tmp_sums = [BabyBearElem::ZERO; 4];
165
166    for i in 0..CELLS / 4 {
167        let chunk_array: [BabyBearElem; 4] = [
168            old_cells[i * 4],
169            old_cells[i * 4 + 1],
170            old_cells[i * 4 + 2],
171            old_cells[i * 4 + 3],
172        ];
173        let out = multiply_by_4x4_circulant(&chunk_array);
174        for j in 0..4 {
175            tmp_sums[j] += out[j];
176            cells[i * 4 + j] += out[j];
177        }
178    }
179    for i in 0..CELLS {
180        cells[i] += tmp_sums[i % 4];
181    }
182}
183
/// Runs one full round on the state: add the round constants for `round` to
/// every cell, apply the S-box to every cell, then multiply by M_EXT.
fn full_round(cells: &mut [BabyBearElem; CELLS], round: usize) {
    add_round_constants_full(cells, round);
    // if round == 0 {
    //     tracing::trace!("After constants in full round 0: {cells:?}");
    // }

    do_full_sboxes(cells);
    multiply_by_m_ext(cells);
    // tracing::trace!("After mExt in full round {round}: {cells:?}");
}
194
/// Runs one partial round on the state: add the round constant for `round` to
/// the first cell, apply the S-box to the first cell, then multiply by M_INT.
fn partial_round(cells: &mut [BabyBearElem; CELLS], round: usize) {
    add_round_constants_partial(cells, round);
    do_partial_sboxes(cells);
    multiply_by_m_int(cells);
}
200
201/// The raw sponge mixing function
202pub fn poseidon2_mix(cells: &mut [BabyBearElem; CELLS]) {
203    let mut round = 0;
204
205    // First linear layer.
206    multiply_by_m_ext(cells);
207    // tracing::trace!("After initial mExt: {cells:?}");
208
209    // Do initial full rounds
210    for _i in 0..ROUNDS_HALF_FULL {
211        full_round(cells, round);
212        round += 1;
213    }
214    // Do partial rounds
215    for _i in 0..ROUNDS_PARTIAL {
216        partial_round(cells, round);
217        round += 1;
218    }
219    // tracing::trace!("After partial rounds: {cells:?}");
220    // Do remaining full rounds
221    for _i in 0..ROUNDS_HALF_FULL {
222        full_round(cells, round);
223        round += 1;
224    }
225}
226
227/// Perform an unpadded hash of a vector of elements.  Because this is unpadded
228/// collision resistance is only true for vectors of the same size.  If the size
229/// is variable, this is subject to length extension attacks.
230pub fn unpadded_hash<'a, I>(iter: I) -> [BabyBearElem; CELLS_OUT]
231where
232    I: Iterator<Item = &'a BabyBearElem>,
233{
234    let mut state = [BabyBearElem::ZERO; CELLS];
235    let mut count = 0;
236    let mut unmixed = 0;
237    for val in iter {
238        state[unmixed] = *val;
239        count += 1;
240        unmixed += 1;
241        if unmixed == CELLS_RATE {
242            poseidon2_mix(&mut state);
243            unmixed = 0;
244        }
245    }
246    if unmixed != 0 || count == 0 {
247        // Zero pad to get a CELLS_RATE-aligned number of inputs
248        for elem in state.iter_mut().take(CELLS_RATE).skip(unmixed) {
249            *elem = BabyBearElem::ZERO;
250        }
251        poseidon2_mix(&mut state);
252    }
253    state.as_slice()[0..CELLS_OUT].try_into().unwrap()
254}
255
256#[cfg(test)]
257mod tests {
258    use test_log::test;
259
260    use super::*;
261    use crate::core::hash::poseidon2::consts::_M_EXT;
262
263    fn do_partial_sboxes(cells: &mut [BabyBearElem; CELLS]) {
264        cells[0] = sbox(cells[0]);
265    }
266
    /// Partial round that uses the naive M_INT multiplication
    /// (`multiply_by_m_int_naive`) for its linear layer.
    fn partial_round_naive(cells: &mut [BabyBearElem; CELLS], round: usize) {
        add_round_constants_partial(cells, round);
        do_partial_sboxes(cells);
        multiply_by_m_int_naive(cells);
    }
272
273    fn multiply_by_m_ext_naive(cells: &mut [BabyBearElem; CELLS]) {
274        let old_cells = *cells;
275        for i in 0..CELLS {
276            let mut tot = BabyBearElem::ZERO;
277            for j in 0..CELLS {
278                tot += _M_EXT[i * CELLS + j] * old_cells[j];
279            }
280            cells[i] = tot;
281        }
282    }
283
284    fn multiply_by_m_int_naive(cells: &mut [BabyBearElem; CELLS]) {
285        let old_cells = *cells;
286        for i in 0..CELLS {
287            let mut tot = BabyBearElem::ZERO;
288            for (j, old_cell) in old_cells.iter().enumerate().take(CELLS) {
289                if i == j {
290                    tot += (M_INT_DIAG_HZN[i] + BabyBearElem::ONE) * *old_cell;
291                } else {
292                    tot += *old_cell;
293                }
294            }
295            cells[i] = tot;
296        }
297    }
298
299    // Naive version of poseidon2
300    fn poseidon2_mix_naive(cells: &mut [BabyBearElem; CELLS]) {
301        let mut round = 0;
302        multiply_by_m_ext_naive(cells);
303        for _i in 0..ROUNDS_HALF_FULL {
304            full_round(cells, round);
305            round += 1;
306        }
307        for _i in 0..ROUNDS_PARTIAL {
308            partial_round_naive(cells, round);
309            round += 1;
310        }
311        for _i in 0..ROUNDS_HALF_FULL {
312            full_round(cells, round);
313            round += 1;
314        }
315    }
316
317    #[test]
318    fn compare_naive() {
319        // Make a fixed input
320        let mut test_in_1 = [BabyBearElem::ONE; CELLS];
321        // Copy it
322        let mut test_in_2 = test_in_1;
323        // Try two versions
324        poseidon2_mix_naive(&mut test_in_1);
325        poseidon2_mix(&mut test_in_2);
326        tracing::debug!("test_in_1: {:?}", test_in_1);
327        tracing::debug!("test_in_2: {:?}", test_in_2);
328        // Verify they are the same
329        assert_eq!(test_in_1, test_in_2);
330    }
331
332    macro_rules! baby_bear_array {
333        [$($x:literal),* $(,)?] => {
334            [$(BabyBearElem::new($x)),* ]
335        }
336    }
337
338    #[test]
339    fn poseidon2_test_vectors() {
340        let buf: &mut [BabyBearElem; CELLS] = &mut baby_bear_array![
341            0x00000000, 0x00000001, 0x00000002, 0x00000003, 0x00000004, 0x00000005, 0x00000006,
342            0x00000007, 0x00000008, 0x00000009, 0x0000000A, 0x0000000B, 0x0000000C, 0x0000000D,
343            0x0000000E, 0x0000000F, 0x00000010, 0x00000011, 0x00000012, 0x00000013, 0x00000014,
344            0x00000015, 0x00000016, 0x00000017,
345        ];
346        tracing::debug!("input: {:?}", buf);
347        poseidon2_mix(buf);
348        let goal: [u32; CELLS] = [
349            0x2ed3e23d, 0x12921fb0, 0x0e659e79, 0x61d81dc9, 0x32bae33b, 0x62486ae3, 0x1e681b60,
350            0x24b91325, 0x2a2ef5b9, 0x50e8593e, 0x5bc818ec, 0x10691997, 0x35a14520, 0x2ba6a3c5,
351            0x279d47ec, 0x55014e81, 0x5953a67f, 0x2f403111, 0x6b8828ff, 0x1801301f, 0x2749207a,
352            0x3dc9cf21, 0x3c985ba2, 0x57a99864,
353        ];
354        for i in 0..CELLS {
355            assert_eq!(buf[i].as_u32(), goal[i]);
356        }
357
358        tracing::debug!("output: {:?}", buf);
359    }
360
361    // Test against golden values from an independent interpreter version of Poseidon2
362    #[test]
363    fn hash_elem_slice_compare_golden() {
364        let buf: [BabyBearElem; 32] = baby_bear_array![
365            943718400, 1887436800, 2013125296, 1761607679, 692060158, 1761607634, 566231037,
366            1509949437, 440401916, 1384120316, 314572795, 1258291195, 188743674, 1132462074,
367            62914553, 1006632953, 1950351353, 880803832, 1824522232, 754974711, 1698693111,
368            629145590, 1572863990, 503316469, 1447034869, 377487348, 1321205748, 251658227,
369            1195376627, 125829106, 1069547506, 2013265906,
370        ];
371        let suite = Poseidon2HashSuite::new_suite();
372        let result = suite.hashfn.hash_elem_slice(&buf);
373        let goal: [u32; DIGEST_WORDS] = [
374            (BabyBearElem::from(0x722baada_u32)).as_u32_montgomery(),
375            (BabyBearElem::from(0x5b352fed_u32)).as_u32_montgomery(),
376            (BabyBearElem::from(0x3684017b_u32)).as_u32_montgomery(),
377            (BabyBearElem::from(0x540d4a7b_u32)).as_u32_montgomery(),
378            (BabyBearElem::from(0x44ffd422_u32)).as_u32_montgomery(),
379            (BabyBearElem::from(0x48615f97_u32)).as_u32_montgomery(),
380            (BabyBearElem::from(0x1a496f45_u32)).as_u32_montgomery(),
381            (BabyBearElem::from(0x203ca999_u32)).as_u32_montgomery(),
382        ];
383        for (i, word) in goal.iter().enumerate() {
384            assert_eq!(result.as_words()[i], *word, "At entry {i}");
385        }
386    }
387
388    #[test]
389    fn hash_elem_slice_compare_golden_unaligned() {
390        let buf: [BabyBearElem; 17] = baby_bear_array![
391            943718400, 1887436800, 2013125296, 1761607679, 692060158, 1635778558, 566231037,
392            1509949437, 440401916, 1384120316, 314572795, 1258291195, 188743674, 1132462074,
393            62914553, 1006632953, 1950351353,
394        ];
395        let suite = Poseidon2HashSuite::new_suite();
396        let result = suite.hashfn.hash_elem_slice(&buf);
397        let goal: [u32; DIGEST_WORDS] = [
398            (BabyBearElem::from(0x622615d7_u32)).as_u32_montgomery(),
399            (BabyBearElem::from(0x1cfe9764_u32)).as_u32_montgomery(),
400            (BabyBearElem::from(0x166cb1c9_u32)).as_u32_montgomery(),
401            (BabyBearElem::from(0x76febcde_u32)).as_u32_montgomery(),
402            (BabyBearElem::from(0x6056219f_u32)).as_u32_montgomery(),
403            (BabyBearElem::from(0x326359cf_u32)).as_u32_montgomery(),
404            (BabyBearElem::from(0x5c2cca75_u32)).as_u32_montgomery(),
405            (BabyBearElem::from(0x233dc3ff_u32)).as_u32_montgomery(),
406        ];
407        for (i, word) in goal.iter().enumerate() {
408            assert_eq!(result.as_words()[i], *word, "At entry {i}");
409        }
410    }
411}