use ark_ff::Field;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use unroll::unroll_for_loops;
/// Hashing interface over the field `F`.
///
/// The trait name suggests a sponge construction, but how input is absorbed
/// and squeezed is entirely up to the implementor.
pub trait Sponge<F>
where
    F: Field,
{
    /// Digest type returned by every hashing entry point of this trait.
    type Digest;

    /// Hashes an arbitrary byte string into a digest.
    fn hash(bytes: &[u8]) -> Self::Digest;

    /// Hashes a sequence of field elements into a digest.
    fn hash_field(elems: &[F]) -> Self::Digest;

    /// Combines exactly two digests into a single digest
    /// (presumably for Merkle-tree style use — confirm with implementors).
    fn merge(digests: &[Self::Digest; 2]) -> Self::Digest;
}
/// Compression interface over the field `F`.
///
/// NOTE(review): the name refers to the Jive compression mode; the exact
/// compression factor of `compress` is implementor-defined — confirm.
pub trait Jive<F>
where
    F: Field,
{
    /// Compresses `elems` into a (presumably shorter) vector of field elements.
    fn compress(elems: &[F]) -> Vec<F>;

    /// Compresses `elems` using the caller-chosen parameter `k`
    /// (presumably the compression factor — confirm with implementors).
    fn compress_k(elems: &[F], k: usize) -> Vec<F>;
}
/// The Anemoi permutation over the field `F`.
///
/// The state is a slice of `WIDTH` field elements split into two halves of
/// `NUM_COLUMNS` elements each. The first half is the `x` vector and the
/// second half is the `y` vector (this is how `ark_layer`, `mds_layer` and
/// `sbox_layer` index it).
///
/// Implementors supply the constants and `exp_by_inv_alpha`. Every other
/// method has a default implementation.
pub trait Anemoi<'a, F: Field> {
    /// Number of columns; each of the `x` and `y` halves has this length.
    const NUM_COLUMNS: usize;
    /// Number of rounds applied by `permutation`.
    const NUM_ROUNDS: usize;
    /// Total state length. Every layer `debug_assert!`s `state.len() == WIDTH`,
    /// and the two-halves layout implies `WIDTH == 2 * NUM_COLUMNS`.
    const WIDTH: usize;
    /// Rate; not read by the default methods of this trait.
    const RATE: usize;
    /// Output size; not read by the default methods of this trait.
    const OUTPUT_SIZE: usize;
    /// Row-major `NUM_COLUMNS x NUM_COLUMNS` matrix used by `mds_layer` only
    /// when `NUM_COLUMNS` has no hard-coded linear layer (i.e. outside 1..=6).
    const MDS: Option<&'a [F]> = None;
    /// Round constants added to the `x` half; round `r` uses the entries
    /// `r * NUM_COLUMNS .. (r + 1) * NUM_COLUMNS`.
    const ARK_C: &'a [F];
    /// Round constants added to the `y` half, indexed like `ARK_C`.
    const ARK_D: &'a [F];
    /// Small multiplier `g` used by `mul_by_generator`.
    const GROUP_GENERATOR: u32;
    /// Exponent of the power map computed by `exp_by_alpha`.
    const ALPHA: u32;
    /// Not read by the default methods; presumably the inverse exponent
    /// `1 / ALPHA` for use by `exp_by_inv_alpha` implementations — confirm.
    const INV_ALPHA: F;
    /// Not read by the default methods. NOTE(review): `sbox_layer` multiplies
    /// by `GROUP_GENERATOR`, i.e. it assumes `BETA == GROUP_GENERATOR`.
    const BETA: u32;
    /// Constant added to `x` at the end of each S-box evaluation.
    const DELTA: F;
    /// Not read by the default methods; defaults to 2.
    const QUAD: u32 = 2;

    /// Returns `GROUP_GENERATOR * x`.
    ///
    /// Common small generators use a short double-and-add chain; any other
    /// value falls back to a full field multiplication.
    fn mul_by_generator(x: &F) -> F {
        match Self::GROUP_GENERATOR {
            2 => x.double(),
            3 => x.double() + x,
            5 => x.double().double() + x,
            7 => (x.double() + x).double() + x,
            9 => x.double().double().double() + x,
            11 => (x.double().double() + x).double() + x,
            // 3x -> 6x -> 12x -> 13x
            13 => (x.double() + x).double().double() + x,
            15 => x.double().double().double().double() - x,
            17 => x.double().double().double().double() + x,
            _ => F::from(Self::GROUP_GENERATOR as u64) * x,
        }
    }

    /// Returns `x^ALPHA`.
    ///
    /// Common exponents use a square-and-multiply chain; any other value falls
    /// back to `Field::pow`.
    fn exp_by_alpha(x: F) -> F {
        match Self::ALPHA {
            3 => x.square() * x,
            5 => x.square().square() * x,
            7 => (x.square() * x).square() * x,
            11 => (x.square().square() * x).square() * x,
            // x^3 -> x^6 -> x^12 -> x^13
            13 => (x.square() * x).square().square() * x,
            17 => x.square().square().square().square() * x,
            _ => x.pow([Self::ALPHA as u64]),
        }
    }

    /// Implementor-provided; expected to compute `x^(1/ALPHA)`, the inverse of
    /// `exp_by_alpha`. This is not checked here, but `sbox_layer` only
    /// defines a permutation if it holds.
    fn exp_by_inv_alpha(x: F) -> F;

    /// Adds the round-`round_ctr` constants in place: `ARK_C` to the `x` half
    /// and `ARK_D` to the `y` half.
    ///
    /// # Panics
    ///
    /// Panics if `round_ctr >= NUM_ROUNDS`, or if the constant slices are too
    /// short for that round.
    #[inline(always)]
    #[unroll_for_loops]
    fn ark_layer(state: &mut [F], round_ctr: usize) {
        debug_assert!(state.len() == Self::WIDTH);
        assert!(round_ctr < Self::NUM_ROUNDS);
        let range = round_ctr * Self::NUM_COLUMNS..(round_ctr + 1) * Self::NUM_COLUMNS;
        let c = &Self::ARK_C[range.clone()];
        let d = &Self::ARK_D[range];
        for i in 0..Self::NUM_COLUMNS {
            state[i] += c[i];
            state[Self::NUM_COLUMNS + i] += d[i];
        }
    }

    /// Applies the linear layer in place.
    ///
    /// The `x` half is multiplied by the matrix. The `y` half is rotated left
    /// by one position and then multiplied by the same matrix. Finally the
    /// halves are mixed as `y += x; x += y`.
    ///
    /// `NUM_COLUMNS` 1..=6 use hard-coded matrices. Other widths use `MDS`.
    ///
    /// # Panics
    ///
    /// Panics if `NUM_COLUMNS` is outside 1..=6 and `MDS` is `None`.
    #[inline(always)]
    fn mds_layer(state: &mut [F]) {
        debug_assert!(state.len() == Self::WIDTH);
        match Self::NUM_COLUMNS {
            // One column: the matrix is the identity and the rotation is a
            // no-op, so only the mixing step remains.
            1 => {
                state[1] += state[0];
                state[0] += state[1];
            }
            // Two columns: M = [[1, g], [g, g^2 + 1]], applied as two in-place
            // updates. The `y` half is processed in reversed order and then
            // swapped back, which equals rotating it left by one first.
            2 => {
                state[0] += Self::mul_by_generator(&state[1]);
                state[1] += Self::mul_by_generator(&state[0]);
                state[3] += Self::mul_by_generator(&state[2]);
                state[2] += Self::mul_by_generator(&state[3]);
                state.swap(2, 3);
                state[2] += state[0];
                state[3] += state[1];
                state[0] += state[2];
                state[1] += state[3];
            }
            3 => {
                Self::mds_internal(&mut state[..Self::NUM_COLUMNS]);
                state[Self::NUM_COLUMNS..].rotate_left(1);
                Self::mds_internal(&mut state[Self::NUM_COLUMNS..]);
                state[3] += state[0];
                state[4] += state[1];
                state[5] += state[2];
                state[0] += state[3];
                state[1] += state[4];
                state[2] += state[5];
            }
            4 => {
                Self::mds_internal(&mut state[..Self::NUM_COLUMNS]);
                state[Self::NUM_COLUMNS..].rotate_left(1);
                Self::mds_internal(&mut state[Self::NUM_COLUMNS..]);
                state[4] += state[0];
                state[5] += state[1];
                state[6] += state[2];
                state[7] += state[3];
                state[0] += state[4];
                state[1] += state[5];
                state[2] += state[6];
                state[3] += state[7];
            }
            // Five columns: a circulant matrix expressed with additions and
            // doublings only.
            5 => {
                let x = state[..Self::NUM_COLUMNS].to_vec();
                let mut y = state[Self::NUM_COLUMNS..].to_vec();
                y.rotate_left(1);
                let sum_coeffs = x[0] + x[1] + x[2] + x[3] + x[4];
                state[0] = sum_coeffs + x[3] + (x[2] + x[3] + x[4].double()).double();
                state[1] = sum_coeffs + x[4] + (x[3] + x[4] + x[0].double()).double();
                state[2] = sum_coeffs + x[0] + (x[4] + x[0] + x[1].double()).double();
                state[3] = sum_coeffs + x[1] + (x[0] + x[1] + x[2].double()).double();
                state[4] = sum_coeffs + x[2] + (x[1] + x[2] + x[3].double()).double();
                let sum_coeffs = y[0] + y[1] + y[2] + y[3] + y[4];
                state[5] = sum_coeffs + y[3] + (y[2] + y[3] + y[4].double()).double();
                state[6] = sum_coeffs + y[4] + (y[3] + y[4] + y[0].double()).double();
                state[7] = sum_coeffs + y[0] + (y[4] + y[0] + y[1].double()).double();
                state[8] = sum_coeffs + y[1] + (y[0] + y[1] + y[2].double()).double();
                state[9] = sum_coeffs + y[2] + (y[1] + y[2] + y[3].double()).double();
                state[5] += state[0];
                state[6] += state[1];
                state[7] += state[2];
                state[8] += state[3];
                state[9] += state[4];
                state[0] += state[5];
                state[1] += state[6];
                state[2] += state[7];
                state[3] += state[8];
                state[4] += state[9];
            }
            // Six columns: a circulant matrix expressed with additions and
            // doublings only.
            6 => {
                let x = state[..Self::NUM_COLUMNS].to_vec();
                let mut y = state[Self::NUM_COLUMNS..].to_vec();
                y.rotate_left(1);
                let sum_coeffs = x[0] + x[1] + x[2] + x[3] + x[4] + x[5];
                state[0] =
                    sum_coeffs + x[3] + x[5] + (x[2] + x[3] + (x[4] + x[5]).double()).double();
                state[1] =
                    sum_coeffs + x[4] + x[0] + (x[3] + x[4] + (x[5] + x[0]).double()).double();
                state[2] =
                    sum_coeffs + x[5] + x[1] + (x[4] + x[5] + (x[0] + x[1]).double()).double();
                state[3] =
                    sum_coeffs + x[0] + x[2] + (x[5] + x[0] + (x[1] + x[2]).double()).double();
                state[4] =
                    sum_coeffs + x[1] + x[3] + (x[0] + x[1] + (x[2] + x[3]).double()).double();
                state[5] =
                    sum_coeffs + x[2] + x[4] + (x[1] + x[2] + (x[3] + x[4]).double()).double();
                let sum_coeffs = y[0] + y[1] + y[2] + y[3] + y[4] + y[5];
                state[6] =
                    sum_coeffs + y[3] + y[5] + (y[2] + y[3] + (y[4] + y[5]).double()).double();
                state[7] =
                    sum_coeffs + y[4] + y[0] + (y[3] + y[4] + (y[5] + y[0]).double()).double();
                state[8] =
                    sum_coeffs + y[5] + y[1] + (y[4] + y[5] + (y[0] + y[1]).double()).double();
                state[9] =
                    sum_coeffs + y[0] + y[2] + (y[5] + y[0] + (y[1] + y[2]).double()).double();
                state[10] =
                    sum_coeffs + y[1] + y[3] + (y[0] + y[1] + (y[2] + y[3]).double()).double();
                state[11] =
                    sum_coeffs + y[2] + y[4] + (y[1] + y[2] + (y[3] + y[4]).double()).double();
                state[6] += state[0];
                state[7] += state[1];
                state[8] += state[2];
                state[9] += state[3];
                state[10] += state[4];
                state[11] += state[5];
                state[0] += state[6];
                state[1] += state[7];
                state[2] += state[8];
                state[3] += state[9];
                state[4] += state[10];
                state[5] += state[11];
            }
            // Generic fallback: dense row-major matrix-vector products using
            // `MDS`, then the same `y += x; x += y` mixing.
            _ => {
                let mds = Self::MDS.expect("NO MDS matrix specified for this instance.");
                let mut result = vec![F::zero(); Self::WIDTH];
                for (index, r) in result.iter_mut().enumerate().take(Self::NUM_COLUMNS) {
                    for j in 0..Self::NUM_COLUMNS {
                        *r += mds[index * Self::NUM_COLUMNS + j] * state[j];
                    }
                }
                state[Self::NUM_COLUMNS..].rotate_left(1);
                for (index, r) in result.iter_mut().skip(Self::NUM_COLUMNS).enumerate() {
                    for j in 0..Self::NUM_COLUMNS {
                        *r += mds[index * Self::NUM_COLUMNS + j] * state[Self::NUM_COLUMNS + j];
                    }
                }
                for i in 0..Self::NUM_COLUMNS {
                    state[Self::NUM_COLUMNS + i] = result[i] + result[Self::NUM_COLUMNS + i];
                }
                for i in 0..Self::NUM_COLUMNS {
                    state[i] = result[i] + state[Self::NUM_COLUMNS + i];
                }
            }
        }
    }

    /// Multiplies a single half-state (`NUM_COLUMNS` elements) in place by the
    /// hard-coded matrix for 3 or 4 columns. Any other width leaves `state`
    /// unchanged.
    ///
    /// `mds_layer` calls this separately on the `x` half and on the rotated
    /// `y` half. That is why the expected length is `NUM_COLUMNS`, not
    /// `WIDTH`.
    #[inline(always)]
    fn mds_internal(state: &mut [F]) {
        debug_assert!(state.len() == Self::NUM_COLUMNS);
        match Self::NUM_COLUMNS {
            3 => {
                let tmp = state[0] + Self::mul_by_generator(&state[2]);
                state[2] += state[1];
                state[2] += Self::mul_by_generator(&state[0]);
                state[0] = tmp + state[2];
                state[1] += tmp;
            }
            4 => {
                state[0] += state[1];
                state[2] += state[3];
                state[3] += Self::mul_by_generator(&state[0]);
                state[1] = Self::mul_by_generator(&(state[1] + state[2]));
                state[0] += state[1];
                state[2] += Self::mul_by_generator(&state[3]);
                state[1] += state[2];
                state[3] += state[0];
            }
            _ => (),
        }
    }

    /// Applies the non-linear layer column by column.
    ///
    /// For each `i`, let `x = state[i]`, `y = state[NUM_COLUMNS + i]` and
    /// `g = GROUP_GENERATOR`. Then:
    ///
    /// ```text
    /// x <- x - g * y^2
    /// y <- y - exp_by_inv_alpha(x)
    /// x <- x + g * y^2 + DELTA
    /// ```
    ///
    /// The columns do not depend on each other, so both halves are updated in
    /// place without scratch vectors.
    #[inline(always)]
    #[unroll_for_loops]
    fn sbox_layer(state: &mut [F]) {
        debug_assert!(state.len() == Self::WIDTH);
        let (xs, ys) = state.split_at_mut(Self::NUM_COLUMNS);
        for (x, y) in xs.iter_mut().zip(ys.iter_mut()) {
            *x -= Self::mul_by_generator(&y.square());
            *y -= Self::exp_by_inv_alpha(*x);
            *x += Self::mul_by_generator(&y.square()) + Self::DELTA;
        }
    }

    /// One round of the permutation: the round constants, then the linear
    /// layer, then the S-box layer.
    fn round(state: &mut [F], round_ctr: usize) {
        debug_assert!(state.len() == Self::WIDTH);
        Self::ark_layer(state, round_ctr);
        Self::mds_layer(state);
        Self::sbox_layer(state);
    }

    /// Applies the full permutation in place: `NUM_ROUNDS` rounds followed by
    /// one final linear layer.
    fn permutation(state: &mut [F]) {
        debug_assert!(state.len() == Self::WIDTH);
        for i in 0..Self::NUM_ROUNDS {
            Self::round(state, i);
        }
        Self::mds_layer(state)
    }
}