1pub(crate) mod consts;
20mod rng;
21
22use alloc::{boxed::Box, rc::Rc, vec::Vec};
23
24use risc0_core::field::{
25 baby_bear::{BabyBear, BabyBearElem, BabyBearExtElem},
26 Elem, ExtElem,
27};
28
29use super::{HashFn, HashSuite, Rng, RngFactory};
30use crate::core::digest::{Digest, DIGEST_WORDS};
31
32pub use self::{
33 consts::{CELLS, M_INT_DIAG_HZN, ROUNDS_HALF_FULL, ROUNDS_PARTIAL, ROUND_CONSTANTS},
34 rng::Poseidon2Rng,
35};
36
/// Number of leading state cells that `unpadded_hash` overwrites with input
/// before each call to `poseidon2_mix`; the remaining `CELLS - CELLS_RATE`
/// cells are never written by input.
pub const CELLS_RATE: usize = 16;

/// Number of leading state cells returned by `unpadded_hash` (and consumed by
/// `to_digest`) as the hash output.
pub const CELLS_OUT: usize = 8;
42
43struct Poseidon2HashFn;
45
46impl HashFn<BabyBear> for Poseidon2HashFn {
47 fn hash_pair(&self, a: &Digest, b: &Digest) -> Box<Digest> {
48 let both: Vec<BabyBearElem> = a
49 .as_words()
50 .iter()
51 .chain(b.as_words())
52 .map(|w| BabyBearElem::new_raw(*w))
53 .collect();
54 assert!(both.len() == DIGEST_WORDS * 2);
55 for elem in &both {
56 assert!(elem.is_reduced());
57 }
58 to_digest(unpadded_hash(both.iter()))
59 }
60
61 fn hash_elem_slice(&self, slice: &[BabyBearElem]) -> Box<Digest> {
62 to_digest(unpadded_hash(slice.iter()))
63 }
64
65 fn hash_ext_elem_slice(&self, slice: &[BabyBearExtElem]) -> Box<Digest> {
66 to_digest(unpadded_hash(
67 slice.iter().flat_map(|ee| ee.subelems().iter()),
68 ))
69 }
70}
71
72struct Poseidon2RngFactory;
73
74impl RngFactory<BabyBear> for Poseidon2RngFactory {
75 fn new_rng(&self) -> Box<dyn Rng<BabyBear>> {
76 Box::new(Poseidon2Rng::new())
77 }
78}
79
80pub struct Poseidon2HashSuite;
82
83impl Poseidon2HashSuite {
84 pub fn new_suite() -> HashSuite<BabyBear> {
86 HashSuite {
87 name: "poseidon2".into(),
88 hashfn: Rc::new(Poseidon2HashFn {}),
89 rng: Rc::new(Poseidon2RngFactory {}),
90 }
91 }
92}
93
94fn to_digest(elems: [BabyBearElem; CELLS_OUT]) -> Box<Digest> {
95 let mut state: [u32; DIGEST_WORDS] = [0; DIGEST_WORDS];
96 for i in 0..DIGEST_WORDS {
97 state[i] = elems[i].as_u32_montgomery();
98 }
99 Box::new(Digest::from(state))
100}
101
102fn add_round_constants_full(cells: &mut [BabyBearElem; CELLS], round: usize) {
103 for i in 0..CELLS {
104 cells[i] += ROUND_CONSTANTS[round * CELLS + i];
105 }
106}
107
108fn add_round_constants_partial(cells: &mut [BabyBearElem; CELLS], round: usize) {
109 cells[0] += ROUND_CONSTANTS[round * CELLS];
110}
111
112fn sbox(x: BabyBearElem) -> BabyBearElem {
113 let x2 = x * x;
114 let x4 = x2 * x2;
115 let x6 = x4 * x2;
116 x6 * x
117}
118
119fn do_full_sboxes(cells: &mut [BabyBearElem; CELLS]) {
120 for cell in cells.iter_mut() {
121 *cell = sbox(*cell);
122 }
123}
124
125fn do_partial_sboxes(cells: &mut [BabyBearElem; CELLS]) {
126 cells[0] = sbox(cells[0]);
127}
128
129fn multiply_by_m_int(cells: &mut [BabyBearElem; CELLS]) {
130 let sum: BabyBearElem = cells.iter().fold(BabyBearElem::ZERO, |acc, x| acc + *x);
132 for i in 0..CELLS {
133 cells[i] = sum + M_INT_DIAG_HZN[i] * cells[i];
134 }
135}
136
137fn multiply_by_4x4_circulant(x: &[BabyBearElem; 4]) -> [BabyBearElem; 4] {
138 let t0 = x[0] + x[1];
140 let t1 = x[2] + x[3];
141 let t2 = BabyBearElem::new(2) * x[1] + t1;
142 let t3 = BabyBearElem::new(2) * x[3] + t0;
143 let t4 = BabyBearElem::new(4) * t1 + t3;
144 let t5 = BabyBearElem::new(4) * t0 + t2;
145 let t6 = t3 + t5;
146 let t7 = t2 + t4;
147 [t6, t5, t7, t4]
148}
149
150fn multiply_by_m_ext(cells: &mut [BabyBearElem; CELLS]) {
151 let old_cells = *cells;
154 cells.fill(BabyBearElem::ZERO);
155 let mut tmp_sums = [BabyBearElem::ZERO; 4];
156
157 for i in 0..CELLS / 4 {
158 let chunk_array: [BabyBearElem; 4] = [
159 old_cells[i * 4],
160 old_cells[i * 4 + 1],
161 old_cells[i * 4 + 2],
162 old_cells[i * 4 + 3],
163 ];
164 let out = multiply_by_4x4_circulant(&chunk_array);
165 for j in 0..4 {
166 tmp_sums[j] += out[j];
167 cells[i * 4 + j] += out[j];
168 }
169 }
170 for i in 0..CELLS {
171 cells[i] += tmp_sums[i % 4];
172 }
173}
174
175fn full_round(cells: &mut [BabyBearElem; CELLS], round: usize) {
176 add_round_constants_full(cells, round);
177 do_full_sboxes(cells);
182 multiply_by_m_ext(cells);
183 }
185
186fn partial_round(cells: &mut [BabyBearElem; CELLS], round: usize) {
187 add_round_constants_partial(cells, round);
188 do_partial_sboxes(cells);
189 multiply_by_m_int(cells);
190}
191
192pub fn poseidon2_mix(cells: &mut [BabyBearElem; CELLS]) {
194 let mut round = 0;
195
196 multiply_by_m_ext(cells);
198 for _i in 0..ROUNDS_HALF_FULL {
202 full_round(cells, round);
203 round += 1;
204 }
205 for _i in 0..ROUNDS_PARTIAL {
207 partial_round(cells, round);
208 round += 1;
209 }
210 for _i in 0..ROUNDS_HALF_FULL {
213 full_round(cells, round);
214 round += 1;
215 }
216}
217
218pub fn unpadded_hash<'a, I>(iter: I) -> [BabyBearElem; CELLS_OUT]
222where
223 I: Iterator<Item = &'a BabyBearElem>,
224{
225 let mut state = [BabyBearElem::ZERO; CELLS];
226 let mut count = 0;
227 let mut unmixed = 0;
228 for val in iter {
229 state[unmixed] = *val;
230 count += 1;
231 unmixed += 1;
232 if unmixed == CELLS_RATE {
233 poseidon2_mix(&mut state);
234 unmixed = 0;
235 }
236 }
237 if unmixed != 0 || count == 0 {
238 for elem in state.iter_mut().take(CELLS_RATE).skip(unmixed) {
240 *elem = BabyBearElem::ZERO;
241 }
242 poseidon2_mix(&mut state);
243 }
244 state.as_slice()[0..CELLS_OUT].try_into().unwrap()
245}
246
#[cfg(test)]
mod tests {
    use test_log::test;

    use super::*;
    use crate::core::hash::poseidon2::consts::_M_EXT;

    /// Reference full round: identical to `full_round` except that the
    /// external layer is the dense `_M_EXT` product. Using it in every full
    /// round (not only the initial layer) means `compare_naive` checks the
    /// fast `multiply_by_m_ext` on non-uniform, post-S-box states.
    fn full_round_naive(cells: &mut [BabyBearElem; CELLS], round: usize) {
        add_round_constants_full(cells, round);
        do_full_sboxes(cells);
        multiply_by_m_ext_naive(cells);
    }

    /// Reference partial round: identical to `partial_round` except that the
    /// internal layer is the dense `multiply_by_m_int_naive` product.
    fn partial_round_naive(cells: &mut [BabyBearElem; CELLS], round: usize) {
        add_round_constants_partial(cells, round);
        do_partial_sboxes(cells);
        multiply_by_m_int_naive(cells);
    }

    /// Dense matrix-vector product with `_M_EXT`, indexed row-major
    /// (`_M_EXT[i * CELLS + j]` is row `i`, column `j`).
    fn multiply_by_m_ext_naive(cells: &mut [BabyBearElem; CELLS]) {
        let old_cells = *cells;
        for i in 0..CELLS {
            let mut tot = BabyBearElem::ZERO;
            for j in 0..CELLS {
                tot += _M_EXT[i * CELLS + j] * old_cells[j];
            }
            cells[i] = tot;
        }
    }

    /// Dense internal-layer product: `M_INT_DIAG_HZN[i] + 1` on the diagonal
    /// and `1` everywhere else.
    fn multiply_by_m_int_naive(cells: &mut [BabyBearElem; CELLS]) {
        let old_cells = *cells;
        for i in 0..CELLS {
            let mut tot = BabyBearElem::ZERO;
            for (j, old_cell) in old_cells.iter().enumerate().take(CELLS) {
                if i == j {
                    tot += (M_INT_DIAG_HZN[i] + BabyBearElem::ONE) * *old_cell;
                } else {
                    tot += *old_cell;
                }
            }
            cells[i] = tot;
        }
    }

    /// Reference permutation built only from the naive layers above.
    fn poseidon2_mix_naive(cells: &mut [BabyBearElem; CELLS]) {
        let mut round = 0;
        multiply_by_m_ext_naive(cells);
        for _i in 0..ROUNDS_HALF_FULL {
            full_round_naive(cells, round);
            round += 1;
        }
        for _i in 0..ROUNDS_PARTIAL {
            partial_round_naive(cells, round);
            round += 1;
        }
        for _i in 0..ROUNDS_HALF_FULL {
            full_round_naive(cells, round);
            round += 1;
        }
    }

    /// Asserts that the optimized and naive permutations agree on `input`.
    fn assert_matches_naive(input: [BabyBearElem; CELLS]) {
        let mut naive = input;
        let mut fast = input;
        poseidon2_mix_naive(&mut naive);
        poseidon2_mix(&mut fast);
        tracing::debug!("naive: {:?}", naive);
        tracing::debug!("fast: {:?}", fast);
        assert_eq!(naive, fast);
    }

    #[test]
    fn compare_naive() {
        // A uniform input makes the initial external layer position-symmetric,
        // so also check an input with a distinct value in every cell.
        assert_matches_naive([BabyBearElem::ONE; CELLS]);

        let mut ramp = [BabyBearElem::ZERO; CELLS];
        for (cell, value) in ramp.iter_mut().zip(1u32..) {
            *cell = BabyBearElem::new(value);
        }
        assert_matches_naive(ramp);
    }

    macro_rules! baby_bear_array {
        [$($x:literal),* $(,)?] => {
            [$(BabyBearElem::new($x)),* ]
        }
    }

    #[test]
    fn poseidon2_test_vectors() {
        let buf: &mut [BabyBearElem; CELLS] = &mut baby_bear_array![
            0x00000000, 0x00000001, 0x00000002, 0x00000003, 0x00000004, 0x00000005, 0x00000006,
            0x00000007, 0x00000008, 0x00000009, 0x0000000A, 0x0000000B, 0x0000000C, 0x0000000D,
            0x0000000E, 0x0000000F, 0x00000010, 0x00000011, 0x00000012, 0x00000013, 0x00000014,
            0x00000015, 0x00000016, 0x00000017,
        ];
        tracing::debug!("input: {:?}", buf);
        poseidon2_mix(buf);
        let goal: [u32; CELLS] = [
            0x2ed3e23d, 0x12921fb0, 0x0e659e79, 0x61d81dc9, 0x32bae33b, 0x62486ae3, 0x1e681b60,
            0x24b91325, 0x2a2ef5b9, 0x50e8593e, 0x5bc818ec, 0x10691997, 0x35a14520, 0x2ba6a3c5,
            0x279d47ec, 0x55014e81, 0x5953a67f, 0x2f403111, 0x6b8828ff, 0x1801301f, 0x2749207a,
            0x3dc9cf21, 0x3c985ba2, 0x57a99864,
        ];
        for (i, (elem, want)) in buf.iter().zip(goal.iter()).enumerate() {
            assert_eq!(elem.as_u32(), *want, "At entry {i}");
        }

        tracing::debug!("output: {:?}", buf);
    }

    #[test]
    fn hash_elem_slice_compare_golden() {
        let buf: [BabyBearElem; 32] = baby_bear_array![
            943718400, 1887436800, 2013125296, 1761607679, 692060158, 1761607634, 566231037,
            1509949437, 440401916, 1384120316, 314572795, 1258291195, 188743674, 1132462074,
            62914553, 1006632953, 1950351353, 880803832, 1824522232, 754974711, 1698693111,
            629145590, 1572863990, 503316469, 1447034869, 377487348, 1321205748, 251658227,
            1195376627, 125829106, 1069547506, 2013265906,
        ];
        let suite = Poseidon2HashSuite::new_suite();
        let result = suite.hashfn.hash_elem_slice(&buf);
        let goal: [u32; DIGEST_WORDS] = [
            (BabyBearElem::from(0x722baada_u32)).as_u32_montgomery(),
            (BabyBearElem::from(0x5b352fed_u32)).as_u32_montgomery(),
            (BabyBearElem::from(0x3684017b_u32)).as_u32_montgomery(),
            (BabyBearElem::from(0x540d4a7b_u32)).as_u32_montgomery(),
            (BabyBearElem::from(0x44ffd422_u32)).as_u32_montgomery(),
            (BabyBearElem::from(0x48615f97_u32)).as_u32_montgomery(),
            (BabyBearElem::from(0x1a496f45_u32)).as_u32_montgomery(),
            (BabyBearElem::from(0x203ca999_u32)).as_u32_montgomery(),
        ];
        for (i, word) in goal.iter().enumerate() {
            assert_eq!(result.as_words()[i], *word, "At entry {i}");
        }
    }

    #[test]
    fn hash_elem_slice_compare_golden_unaligned() {
        let buf: [BabyBearElem; 17] = baby_bear_array![
            943718400, 1887436800, 2013125296, 1761607679, 692060158, 1635778558, 566231037,
            1509949437, 440401916, 1384120316, 314572795, 1258291195, 188743674, 1132462074,
            62914553, 1006632953, 1950351353,
        ];
        let suite = Poseidon2HashSuite::new_suite();
        let result = suite.hashfn.hash_elem_slice(&buf);
        let goal: [u32; DIGEST_WORDS] = [
            (BabyBearElem::from(0x622615d7_u32)).as_u32_montgomery(),
            (BabyBearElem::from(0x1cfe9764_u32)).as_u32_montgomery(),
            (BabyBearElem::from(0x166cb1c9_u32)).as_u32_montgomery(),
            (BabyBearElem::from(0x76febcde_u32)).as_u32_montgomery(),
            (BabyBearElem::from(0x6056219f_u32)).as_u32_montgomery(),
            (BabyBearElem::from(0x326359cf_u32)).as_u32_montgomery(),
            (BabyBearElem::from(0x5c2cca75_u32)).as_u32_montgomery(),
            (BabyBearElem::from(0x233dc3ff_u32)).as_u32_montgomery(),
        ];
        for (i, word) in goal.iter().enumerate() {
            assert_eq!(result.as_words()[i], *word, "At entry {i}");
        }
    }
}