1pub(crate) mod consts;
20mod rng;
21
22use alloc::{boxed::Box, rc::Rc, vec::Vec};
23
24use bytemuck::CheckedBitPattern;
25use risc0_core::field::{
26 baby_bear::{BabyBear, BabyBearElem, BabyBearExtElem},
27 Elem, ExtElem,
28};
29
30use super::{HashFn, HashSuite, Rng, RngFactory};
31use crate::core::digest::{Digest, DIGEST_WORDS};
32
33pub use self::{
34 consts::{CELLS, M_INT_DIAG_HZN, ROUNDS_HALF_FULL, ROUNDS_PARTIAL, ROUND_CONSTANTS},
35 rng::Poseidon2Rng,
36};
37
/// Number of leading state cells that [`unpadded_hash`] returns as the hash
/// output.
pub const CELLS_OUT: usize = 8;

/// Number of input elements [`unpadded_hash`] writes into the state before
/// each call to [`poseidon2_mix`].
pub const CELLS_RATE: usize = 16;
43
44struct Poseidon2HashFn;
46
47impl HashFn<BabyBear> for Poseidon2HashFn {
48 fn hash_pair(&self, a: &Digest, b: &Digest) -> Box<Digest> {
49 let both: Vec<BabyBearElem> = a
50 .as_words()
51 .iter()
52 .chain(b.as_words())
53 .map(|w| BabyBearElem::new_raw(*w))
54 .collect();
55 assert!(both.len() == DIGEST_WORDS * 2);
56 for elem in &both {
57 assert!(elem.is_reduced());
58 }
59 to_digest(unpadded_hash(both.iter()))
60 }
61
62 fn hash_elem_slice(&self, slice: &[BabyBearElem]) -> Box<Digest> {
63 to_digest(unpadded_hash(slice.iter()))
64 }
65
66 fn hash_ext_elem_slice(&self, slice: &[BabyBearExtElem]) -> Box<Digest> {
67 to_digest(unpadded_hash(
68 slice.iter().flat_map(|ee| ee.subelems().iter()),
69 ))
70 }
71
72 fn is_digest_valid(&self, digest: &Digest) -> bool {
74 digest
75 .as_words()
76 .iter()
77 .all(<BabyBearElem as CheckedBitPattern>::is_valid_bit_pattern)
78 }
79}
80
81struct Poseidon2RngFactory;
82
83impl RngFactory<BabyBear> for Poseidon2RngFactory {
84 fn new_rng(&self) -> Box<dyn Rng<BabyBear>> {
85 Box::new(Poseidon2Rng::new())
86 }
87}
88
89pub struct Poseidon2HashSuite;
91
92impl Poseidon2HashSuite {
93 pub fn new_suite() -> HashSuite<BabyBear> {
95 HashSuite {
96 name: "poseidon2".into(),
97 hashfn: Rc::new(Poseidon2HashFn {}),
98 rng: Rc::new(Poseidon2RngFactory {}),
99 }
100 }
101}
102
103fn to_digest(elems: [BabyBearElem; CELLS_OUT]) -> Box<Digest> {
104 let mut state: [u32; DIGEST_WORDS] = [0; DIGEST_WORDS];
105 for i in 0..DIGEST_WORDS {
106 state[i] = elems[i].as_u32_montgomery();
107 }
108 Box::new(Digest::from(state))
109}
110
111fn add_round_constants_full(cells: &mut [BabyBearElem; CELLS], round: usize) {
112 for i in 0..CELLS {
113 cells[i] += ROUND_CONSTANTS[round * CELLS + i];
114 }
115}
116
117fn add_round_constants_partial(cells: &mut [BabyBearElem; CELLS], round: usize) {
118 cells[0] += ROUND_CONSTANTS[round * CELLS];
119}
120
121fn sbox(x: BabyBearElem) -> BabyBearElem {
122 let x2 = x * x;
123 let x4 = x2 * x2;
124 let x6 = x4 * x2;
125 x6 * x
126}
127
128fn do_full_sboxes(cells: &mut [BabyBearElem; CELLS]) {
129 for cell in cells.iter_mut() {
130 *cell = sbox(*cell);
131 }
132}
133
134fn do_partial_sboxes(cells: &mut [BabyBearElem; CELLS]) {
135 cells[0] = sbox(cells[0]);
136}
137
138fn multiply_by_m_int(cells: &mut [BabyBearElem; CELLS]) {
139 let sum: BabyBearElem = cells.iter().fold(BabyBearElem::ZERO, |acc, x| acc + *x);
141 for i in 0..CELLS {
142 cells[i] = sum + M_INT_DIAG_HZN[i] * cells[i];
143 }
144}
145
146fn multiply_by_4x4_circulant(x: &[BabyBearElem; 4]) -> [BabyBearElem; 4] {
147 let t0 = x[0] + x[1];
149 let t1 = x[2] + x[3];
150 let t2 = BabyBearElem::new(2) * x[1] + t1;
151 let t3 = BabyBearElem::new(2) * x[3] + t0;
152 let t4 = BabyBearElem::new(4) * t1 + t3;
153 let t5 = BabyBearElem::new(4) * t0 + t2;
154 let t6 = t3 + t5;
155 let t7 = t2 + t4;
156 [t6, t5, t7, t4]
157}
158
159fn multiply_by_m_ext(cells: &mut [BabyBearElem; CELLS]) {
160 let old_cells = *cells;
163 cells.fill(BabyBearElem::ZERO);
164 let mut tmp_sums = [BabyBearElem::ZERO; 4];
165
166 for i in 0..CELLS / 4 {
167 let chunk_array: [BabyBearElem; 4] = [
168 old_cells[i * 4],
169 old_cells[i * 4 + 1],
170 old_cells[i * 4 + 2],
171 old_cells[i * 4 + 3],
172 ];
173 let out = multiply_by_4x4_circulant(&chunk_array);
174 for j in 0..4 {
175 tmp_sums[j] += out[j];
176 cells[i * 4 + j] += out[j];
177 }
178 }
179 for i in 0..CELLS {
180 cells[i] += tmp_sums[i % 4];
181 }
182}
183
184fn full_round(cells: &mut [BabyBearElem; CELLS], round: usize) {
185 add_round_constants_full(cells, round);
186 do_full_sboxes(cells);
191 multiply_by_m_ext(cells);
192 }
194
195fn partial_round(cells: &mut [BabyBearElem; CELLS], round: usize) {
196 add_round_constants_partial(cells, round);
197 do_partial_sboxes(cells);
198 multiply_by_m_int(cells);
199}
200
201pub fn poseidon2_mix(cells: &mut [BabyBearElem; CELLS]) {
203 let mut round = 0;
204
205 multiply_by_m_ext(cells);
207 for _i in 0..ROUNDS_HALF_FULL {
211 full_round(cells, round);
212 round += 1;
213 }
214 for _i in 0..ROUNDS_PARTIAL {
216 partial_round(cells, round);
217 round += 1;
218 }
219 for _i in 0..ROUNDS_HALF_FULL {
222 full_round(cells, round);
223 round += 1;
224 }
225}
226
227pub fn unpadded_hash<'a, I>(iter: I) -> [BabyBearElem; CELLS_OUT]
231where
232 I: Iterator<Item = &'a BabyBearElem>,
233{
234 let mut state = [BabyBearElem::ZERO; CELLS];
235 let mut count = 0;
236 let mut unmixed = 0;
237 for val in iter {
238 state[unmixed] = *val;
239 count += 1;
240 unmixed += 1;
241 if unmixed == CELLS_RATE {
242 poseidon2_mix(&mut state);
243 unmixed = 0;
244 }
245 }
246 if unmixed != 0 || count == 0 {
247 for elem in state.iter_mut().take(CELLS_RATE).skip(unmixed) {
249 *elem = BabyBearElem::ZERO;
250 }
251 poseidon2_mix(&mut state);
252 }
253 state.as_slice()[0..CELLS_OUT].try_into().unwrap()
254}
255
#[cfg(test)]
mod tests {
    use test_log::test;

    use super::*;
    use crate::core::hash::poseidon2::consts::_M_EXT;

    // Local copy of the partial S-box layer used by the naive reference round.
    fn do_partial_sboxes(cells: &mut [BabyBearElem; CELLS]) {
        let first = &mut cells[0];
        *first = sbox(*first);
    }

    // Partial round that uses the naive internal matrix multiplication.
    fn partial_round_naive(cells: &mut [BabyBearElem; CELLS], round: usize) {
        add_round_constants_partial(cells, round);
        do_partial_sboxes(cells);
        multiply_by_m_int_naive(cells);
    }

    // Dense matrix-vector product with the row-major `_M_EXT` table.
    fn multiply_by_m_ext_naive(cells: &mut [BabyBearElem; CELLS]) {
        let input = *cells;
        for (row, cell) in cells.iter_mut().enumerate() {
            *cell = input
                .iter()
                .enumerate()
                .fold(BabyBearElem::ZERO, |acc, (col, value)| {
                    acc + _M_EXT[row * CELLS + col] * *value
                });
        }
    }

    // Dense product with the internal matrix: ones everywhere, plus
    // `M_INT_DIAG_HZN` added on the diagonal.
    fn multiply_by_m_int_naive(cells: &mut [BabyBearElem; CELLS]) {
        let input = *cells;
        for (row, cell) in cells.iter_mut().enumerate() {
            let mut acc = BabyBearElem::ZERO;
            for (col, value) in input.iter().enumerate() {
                acc += if row == col {
                    (M_INT_DIAG_HZN[row] + BabyBearElem::ONE) * *value
                } else {
                    *value
                };
            }
            *cell = acc;
        }
    }

    // Reference permutation: same round schedule as `poseidon2_mix`, but with
    // the naive initial external layer and naive partial rounds.
    fn poseidon2_mix_naive(cells: &mut [BabyBearElem; CELLS]) {
        let partial_start = ROUNDS_HALF_FULL;
        let final_start = partial_start + ROUNDS_PARTIAL;
        let total_rounds = final_start + ROUNDS_HALF_FULL;

        multiply_by_m_ext_naive(cells);
        for round in 0..partial_start {
            full_round(cells, round);
        }
        for round in partial_start..final_start {
            partial_round_naive(cells, round);
        }
        for round in final_start..total_rounds {
            full_round(cells, round);
        }
    }

    // Hashes `input` through the suite and checks every digest word against
    // the Montgomery encoding of the corresponding `expected` value.
    fn assert_hash_elem_slice(input: &[BabyBearElem], expected: [u32; DIGEST_WORDS]) {
        let suite = Poseidon2HashSuite::new_suite();
        let result = suite.hashfn.hash_elem_slice(input);
        for (i, raw) in expected.iter().enumerate() {
            let word = BabyBearElem::from(*raw).as_u32_montgomery();
            assert_eq!(result.as_words()[i], word, "At entry {i}");
        }
    }

    #[test]
    fn compare_naive() {
        let ones = [BabyBearElem::ONE; CELLS];
        let mut via_naive = ones;
        let mut via_fast = ones;
        poseidon2_mix_naive(&mut via_naive);
        poseidon2_mix(&mut via_fast);
        tracing::debug!("test_in_1: {:?}", via_naive);
        tracing::debug!("test_in_2: {:?}", via_fast);
        assert_eq!(via_naive, via_fast);
    }

    macro_rules! baby_bear_array {
        [$($x:literal),* $(,)?] => {
            [$(BabyBearElem::new($x)),* ]
        }
    }

    #[test]
    fn poseidon2_test_vectors() {
        let mut state: [BabyBearElem; CELLS] = baby_bear_array![
            0x00000000, 0x00000001, 0x00000002, 0x00000003, 0x00000004, 0x00000005, 0x00000006,
            0x00000007, 0x00000008, 0x00000009, 0x0000000A, 0x0000000B, 0x0000000C, 0x0000000D,
            0x0000000E, 0x0000000F, 0x00000010, 0x00000011, 0x00000012, 0x00000013, 0x00000014,
            0x00000015, 0x00000016, 0x00000017,
        ];
        tracing::debug!("input: {:?}", state);
        poseidon2_mix(&mut state);
        let goal: [u32; CELLS] = [
            0x2ed3e23d, 0x12921fb0, 0x0e659e79, 0x61d81dc9, 0x32bae33b, 0x62486ae3, 0x1e681b60,
            0x24b91325, 0x2a2ef5b9, 0x50e8593e, 0x5bc818ec, 0x10691997, 0x35a14520, 0x2ba6a3c5,
            0x279d47ec, 0x55014e81, 0x5953a67f, 0x2f403111, 0x6b8828ff, 0x1801301f, 0x2749207a,
            0x3dc9cf21, 0x3c985ba2, 0x57a99864,
        ];
        for (cell, expected) in state.iter().zip(goal.iter()) {
            assert_eq!(cell.as_u32(), *expected);
        }

        tracing::debug!("output: {:?}", state);
    }

    #[test]
    fn hash_elem_slice_compare_golden() {
        // 32 elements: exactly two full rate blocks.
        let input: [BabyBearElem; 32] = baby_bear_array![
            943718400, 1887436800, 2013125296, 1761607679, 692060158, 1761607634, 566231037,
            1509949437, 440401916, 1384120316, 314572795, 1258291195, 188743674, 1132462074,
            62914553, 1006632953, 1950351353, 880803832, 1824522232, 754974711, 1698693111,
            629145590, 1572863990, 503316469, 1447034869, 377487348, 1321205748, 251658227,
            1195376627, 125829106, 1069547506, 2013265906,
        ];
        assert_hash_elem_slice(
            &input,
            [
                0x722baada, 0x5b352fed, 0x3684017b, 0x540d4a7b, 0x44ffd422, 0x48615f97,
                0x1a496f45, 0x203ca999,
            ],
        );
    }

    #[test]
    fn hash_elem_slice_compare_golden_unaligned() {
        // 17 elements: one full rate block plus a single trailing element.
        let input: [BabyBearElem; 17] = baby_bear_array![
            943718400, 1887436800, 2013125296, 1761607679, 692060158, 1635778558, 566231037,
            1509949437, 440401916, 1384120316, 314572795, 1258291195, 188743674, 1132462074,
            62914553, 1006632953, 1950351353,
        ];
        assert_hash_elem_slice(
            &input,
            [
                0x622615d7, 0x1cfe9764, 0x166cb1c9, 0x76febcde, 0x6056219f, 0x326359cf,
                0x5c2cca75, 0x233dc3ff,
            ],
        );
    }
}