risc0_zkp/core/hash/poseidon2/
mod.rs

1// Copyright 2024 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! An implementation of the Poseidon2 permutation and hash, targeting the Baby Bear field.
16
17// Thank you to https://github.com/nhukc for the initial implementation of this code
18
19pub(crate) mod consts;
20mod rng;
21
22use alloc::{boxed::Box, rc::Rc, vec::Vec};
23
24use risc0_core::field::{
25    baby_bear::{BabyBear, BabyBearElem, BabyBearExtElem},
26    Elem, ExtElem,
27};
28
29use super::{HashFn, HashSuite, Rng, RngFactory};
30use crate::core::digest::{Digest, DIGEST_WORDS};
31
32pub use self::{
33    consts::{CELLS, M_INT_DIAG_HZN, ROUNDS_HALF_FULL, ROUNDS_PARTIAL, ROUND_CONSTANTS},
34    rng::Poseidon2Rng,
35};
36
/// Sponge rate: the number of leading state cells that absorb input between
/// two calls to [`poseidon2_mix`]. The remaining `CELLS - CELLS_RATE` cells are
/// never written by input and act as the sponge's capacity.
pub const CELLS_RATE: usize = 16;

/// Number of leading state cells returned as the hash output. At roughly 31 bits
/// per Baby Bear element this is about 248 bits.
pub const CELLS_OUT: usize = 8;
42
43/// A hash implementation for Poseidon2
44struct Poseidon2HashFn;
45
46impl HashFn<BabyBear> for Poseidon2HashFn {
47    fn hash_pair(&self, a: &Digest, b: &Digest) -> Box<Digest> {
48        let both: Vec<BabyBearElem> = a
49            .as_words()
50            .iter()
51            .chain(b.as_words())
52            .map(|w| BabyBearElem::new_raw(*w))
53            .collect();
54        assert!(both.len() == DIGEST_WORDS * 2);
55        for elem in &both {
56            assert!(elem.is_reduced());
57        }
58        to_digest(unpadded_hash(both.iter()))
59    }
60
61    fn hash_elem_slice(&self, slice: &[BabyBearElem]) -> Box<Digest> {
62        to_digest(unpadded_hash(slice.iter()))
63    }
64
65    fn hash_ext_elem_slice(&self, slice: &[BabyBearExtElem]) -> Box<Digest> {
66        to_digest(unpadded_hash(
67            slice.iter().flat_map(|ee| ee.subelems().iter()),
68        ))
69    }
70}
71
72struct Poseidon2RngFactory;
73
74impl RngFactory<BabyBear> for Poseidon2RngFactory {
75    fn new_rng(&self) -> Box<dyn Rng<BabyBear>> {
76        Box::new(Poseidon2Rng::new())
77    }
78}
79
80/// A hash suite using Poseidon2 for both MT hashes and RNG
81pub struct Poseidon2HashSuite;
82
83impl Poseidon2HashSuite {
84    /// Construct a new Poseidon2HashSuite
85    pub fn new_suite() -> HashSuite<BabyBear> {
86        HashSuite {
87            name: "poseidon2".into(),
88            hashfn: Rc::new(Poseidon2HashFn {}),
89            rng: Rc::new(Poseidon2RngFactory {}),
90        }
91    }
92}
93
94fn to_digest(elems: [BabyBearElem; CELLS_OUT]) -> Box<Digest> {
95    let mut state: [u32; DIGEST_WORDS] = [0; DIGEST_WORDS];
96    for i in 0..DIGEST_WORDS {
97        state[i] = elems[i].as_u32_montgomery();
98    }
99    Box::new(Digest::from(state))
100}
101
102fn add_round_constants_full(cells: &mut [BabyBearElem; CELLS], round: usize) {
103    for i in 0..CELLS {
104        cells[i] += ROUND_CONSTANTS[round * CELLS + i];
105    }
106}
107
108fn add_round_constants_partial(cells: &mut [BabyBearElem; CELLS], round: usize) {
109    cells[0] += ROUND_CONSTANTS[round * CELLS];
110}
111
112fn sbox(x: BabyBearElem) -> BabyBearElem {
113    let x2 = x * x;
114    let x4 = x2 * x2;
115    let x6 = x4 * x2;
116    x6 * x
117}
118
119fn do_full_sboxes(cells: &mut [BabyBearElem; CELLS]) {
120    for cell in cells.iter_mut() {
121        *cell = sbox(*cell);
122    }
123}
124
125fn do_partial_sboxes(cells: &mut [BabyBearElem; CELLS]) {
126    cells[0] = sbox(cells[0]);
127}
128
129fn multiply_by_m_int(cells: &mut [BabyBearElem; CELLS]) {
130    // Exploit the fact that off-diagonal entries of M_INT are all 1.
131    let sum: BabyBearElem = cells.iter().fold(BabyBearElem::ZERO, |acc, x| acc + *x);
132    for i in 0..CELLS {
133        cells[i] = sum + M_INT_DIAG_HZN[i] * cells[i];
134    }
135}
136
137fn multiply_by_4x4_circulant(x: &[BabyBearElem; 4]) -> [BabyBearElem; 4] {
138    // See appendix B of Poseidon2 paper.
139    let t0 = x[0] + x[1];
140    let t1 = x[2] + x[3];
141    let t2 = BabyBearElem::new(2) * x[1] + t1;
142    let t3 = BabyBearElem::new(2) * x[3] + t0;
143    let t4 = BabyBearElem::new(4) * t1 + t3;
144    let t5 = BabyBearElem::new(4) * t0 + t2;
145    let t6 = t3 + t5;
146    let t7 = t2 + t4;
147    [t6, t5, t7, t4]
148}
149
150fn multiply_by_m_ext(cells: &mut [BabyBearElem; CELLS]) {
151    // Optimized method for multiplication by M_EXT.
152    // See appendix B of Poseidon2 paper for additional details.
153    let old_cells = *cells;
154    cells.fill(BabyBearElem::ZERO);
155    let mut tmp_sums = [BabyBearElem::ZERO; 4];
156
157    for i in 0..CELLS / 4 {
158        let chunk_array: [BabyBearElem; 4] = [
159            old_cells[i * 4],
160            old_cells[i * 4 + 1],
161            old_cells[i * 4 + 2],
162            old_cells[i * 4 + 3],
163        ];
164        let out = multiply_by_4x4_circulant(&chunk_array);
165        for j in 0..4 {
166            tmp_sums[j] += out[j];
167            cells[i * 4 + j] += out[j];
168        }
169    }
170    for i in 0..CELLS {
171        cells[i] += tmp_sums[i % 4];
172    }
173}
174
175fn full_round(cells: &mut [BabyBearElem; CELLS], round: usize) {
176    add_round_constants_full(cells, round);
177    // if round == 0 {
178    //     tracing::trace!("After constants in full round 0: {cells:?}");
179    // }
180
181    do_full_sboxes(cells);
182    multiply_by_m_ext(cells);
183    // tracing::trace!("After mExt in full round {round}: {cells:?}");
184}
185
186fn partial_round(cells: &mut [BabyBearElem; CELLS], round: usize) {
187    add_round_constants_partial(cells, round);
188    do_partial_sboxes(cells);
189    multiply_by_m_int(cells);
190}
191
192/// The raw sponge mixing function
193pub fn poseidon2_mix(cells: &mut [BabyBearElem; CELLS]) {
194    let mut round = 0;
195
196    // First linear layer.
197    multiply_by_m_ext(cells);
198    // tracing::trace!("After initial mExt: {cells:?}");
199
200    // Do initial full rounds
201    for _i in 0..ROUNDS_HALF_FULL {
202        full_round(cells, round);
203        round += 1;
204    }
205    // Do partial rounds
206    for _i in 0..ROUNDS_PARTIAL {
207        partial_round(cells, round);
208        round += 1;
209    }
210    // tracing::trace!("After partial rounds: {cells:?}");
211    // Do remaining full rounds
212    for _i in 0..ROUNDS_HALF_FULL {
213        full_round(cells, round);
214        round += 1;
215    }
216}
217
218/// Perform an unpadded hash of a vector of elements.  Because this is unpadded
219/// collision resistance is only true for vectors of the same size.  If the size
220/// is variable, this is subject to length extension attacks.
221pub fn unpadded_hash<'a, I>(iter: I) -> [BabyBearElem; CELLS_OUT]
222where
223    I: Iterator<Item = &'a BabyBearElem>,
224{
225    let mut state = [BabyBearElem::ZERO; CELLS];
226    let mut count = 0;
227    let mut unmixed = 0;
228    for val in iter {
229        state[unmixed] = *val;
230        count += 1;
231        unmixed += 1;
232        if unmixed == CELLS_RATE {
233            poseidon2_mix(&mut state);
234            unmixed = 0;
235        }
236    }
237    if unmixed != 0 || count == 0 {
238        // Zero pad to get a CELLS_RATE-aligned number of inputs
239        for elem in state.iter_mut().take(CELLS_RATE).skip(unmixed) {
240            *elem = BabyBearElem::ZERO;
241        }
242        poseidon2_mix(&mut state);
243    }
244    state.as_slice()[0..CELLS_OUT].try_into().unwrap()
245}
246
247#[cfg(test)]
248mod tests {
249    use test_log::test;
250
251    use super::*;
252    use crate::core::hash::poseidon2::consts::_M_EXT;
253
254    fn do_partial_sboxes(cells: &mut [BabyBearElem; CELLS]) {
255        cells[0] = sbox(cells[0]);
256    }
257
258    fn partial_round_naive(cells: &mut [BabyBearElem; CELLS], round: usize) {
259        add_round_constants_partial(cells, round);
260        do_partial_sboxes(cells);
261        multiply_by_m_int_naive(cells);
262    }
263
264    fn multiply_by_m_ext_naive(cells: &mut [BabyBearElem; CELLS]) {
265        let old_cells = *cells;
266        for i in 0..CELLS {
267            let mut tot = BabyBearElem::ZERO;
268            for j in 0..CELLS {
269                tot += _M_EXT[i * CELLS + j] * old_cells[j];
270            }
271            cells[i] = tot;
272        }
273    }
274
275    fn multiply_by_m_int_naive(cells: &mut [BabyBearElem; CELLS]) {
276        let old_cells = *cells;
277        for i in 0..CELLS {
278            let mut tot = BabyBearElem::ZERO;
279            for (j, old_cell) in old_cells.iter().enumerate().take(CELLS) {
280                if i == j {
281                    tot += (M_INT_DIAG_HZN[i] + BabyBearElem::ONE) * *old_cell;
282                } else {
283                    tot += *old_cell;
284                }
285            }
286            cells[i] = tot;
287        }
288    }
289
290    // Naive version of poseidon2
291    fn poseidon2_mix_naive(cells: &mut [BabyBearElem; CELLS]) {
292        let mut round = 0;
293        multiply_by_m_ext_naive(cells);
294        for _i in 0..ROUNDS_HALF_FULL {
295            full_round(cells, round);
296            round += 1;
297        }
298        for _i in 0..ROUNDS_PARTIAL {
299            partial_round_naive(cells, round);
300            round += 1;
301        }
302        for _i in 0..ROUNDS_HALF_FULL {
303            full_round(cells, round);
304            round += 1;
305        }
306    }
307
308    #[test]
309    fn compare_naive() {
310        // Make a fixed input
311        let mut test_in_1 = [BabyBearElem::ONE; CELLS];
312        // Copy it
313        let mut test_in_2 = test_in_1;
314        // Try two versions
315        poseidon2_mix_naive(&mut test_in_1);
316        poseidon2_mix(&mut test_in_2);
317        tracing::debug!("test_in_1: {:?}", test_in_1);
318        tracing::debug!("test_in_2: {:?}", test_in_2);
319        // Verify they are the same
320        assert_eq!(test_in_1, test_in_2);
321    }
322
323    macro_rules! baby_bear_array {
324        [$($x:literal),* $(,)?] => {
325            [$(BabyBearElem::new($x)),* ]
326        }
327    }
328
329    #[test]
330    fn poseidon2_test_vectors() {
331        let buf: &mut [BabyBearElem; CELLS] = &mut baby_bear_array![
332            0x00000000, 0x00000001, 0x00000002, 0x00000003, 0x00000004, 0x00000005, 0x00000006,
333            0x00000007, 0x00000008, 0x00000009, 0x0000000A, 0x0000000B, 0x0000000C, 0x0000000D,
334            0x0000000E, 0x0000000F, 0x00000010, 0x00000011, 0x00000012, 0x00000013, 0x00000014,
335            0x00000015, 0x00000016, 0x00000017,
336        ];
337        tracing::debug!("input: {:?}", buf);
338        poseidon2_mix(buf);
339        let goal: [u32; CELLS] = [
340            0x2ed3e23d, 0x12921fb0, 0x0e659e79, 0x61d81dc9, 0x32bae33b, 0x62486ae3, 0x1e681b60,
341            0x24b91325, 0x2a2ef5b9, 0x50e8593e, 0x5bc818ec, 0x10691997, 0x35a14520, 0x2ba6a3c5,
342            0x279d47ec, 0x55014e81, 0x5953a67f, 0x2f403111, 0x6b8828ff, 0x1801301f, 0x2749207a,
343            0x3dc9cf21, 0x3c985ba2, 0x57a99864,
344        ];
345        for i in 0..CELLS {
346            assert_eq!(buf[i].as_u32(), goal[i]);
347        }
348
349        tracing::debug!("output: {:?}", buf);
350    }
351
352    // Test against golden values from an independent interpreter version of Poseidon2
353    #[test]
354    fn hash_elem_slice_compare_golden() {
355        let buf: [BabyBearElem; 32] = baby_bear_array![
356            943718400, 1887436800, 2013125296, 1761607679, 692060158, 1761607634, 566231037,
357            1509949437, 440401916, 1384120316, 314572795, 1258291195, 188743674, 1132462074,
358            62914553, 1006632953, 1950351353, 880803832, 1824522232, 754974711, 1698693111,
359            629145590, 1572863990, 503316469, 1447034869, 377487348, 1321205748, 251658227,
360            1195376627, 125829106, 1069547506, 2013265906,
361        ];
362        let suite = Poseidon2HashSuite::new_suite();
363        let result = suite.hashfn.hash_elem_slice(&buf);
364        let goal: [u32; DIGEST_WORDS] = [
365            (BabyBearElem::from(0x722baada_u32)).as_u32_montgomery(),
366            (BabyBearElem::from(0x5b352fed_u32)).as_u32_montgomery(),
367            (BabyBearElem::from(0x3684017b_u32)).as_u32_montgomery(),
368            (BabyBearElem::from(0x540d4a7b_u32)).as_u32_montgomery(),
369            (BabyBearElem::from(0x44ffd422_u32)).as_u32_montgomery(),
370            (BabyBearElem::from(0x48615f97_u32)).as_u32_montgomery(),
371            (BabyBearElem::from(0x1a496f45_u32)).as_u32_montgomery(),
372            (BabyBearElem::from(0x203ca999_u32)).as_u32_montgomery(),
373        ];
374        for (i, word) in goal.iter().enumerate() {
375            assert_eq!(result.as_words()[i], *word, "At entry {i}");
376        }
377    }
378
379    #[test]
380    fn hash_elem_slice_compare_golden_unaligned() {
381        let buf: [BabyBearElem; 17] = baby_bear_array![
382            943718400, 1887436800, 2013125296, 1761607679, 692060158, 1635778558, 566231037,
383            1509949437, 440401916, 1384120316, 314572795, 1258291195, 188743674, 1132462074,
384            62914553, 1006632953, 1950351353,
385        ];
386        let suite = Poseidon2HashSuite::new_suite();
387        let result = suite.hashfn.hash_elem_slice(&buf);
388        let goal: [u32; DIGEST_WORDS] = [
389            (BabyBearElem::from(0x622615d7_u32)).as_u32_montgomery(),
390            (BabyBearElem::from(0x1cfe9764_u32)).as_u32_montgomery(),
391            (BabyBearElem::from(0x166cb1c9_u32)).as_u32_montgomery(),
392            (BabyBearElem::from(0x76febcde_u32)).as_u32_montgomery(),
393            (BabyBearElem::from(0x6056219f_u32)).as_u32_montgomery(),
394            (BabyBearElem::from(0x326359cf_u32)).as_u32_montgomery(),
395            (BabyBearElem::from(0x5c2cca75_u32)).as_u32_montgomery(),
396            (BabyBearElem::from(0x233dc3ff_u32)).as_u32_montgomery(),
397        ];
398        for (i, word) in goal.iter().enumerate() {
399            assert_eq!(result.as_words()[i], *word, "At entry {i}");
400        }
401    }
402}