1use alloc::vec::Vec;
2
3use p3_field::{Algebra, Field, TwoAdicField};
4use p3_symmetric::Permutation;
5use p3_util::{log2_strict_usize, reverse_slice_index_bits};
6
7use crate::MdsPermutation;
8use crate::butterflies::{bowers_g_layer, bowers_g_t_layer};
9
/// A Reed-Solomon-based MDS permutation of width `N`, acting on `[A; N]` for any `A: Algebra<F>`.
///
/// The input is interpreted as the evaluations of a polynomial of degree `< N` on the
/// multiplicative subgroup of order `N`. The output is `N` times the evaluations of that same
/// polynomial on the coset of the subgroup shifted by `F::GENERATOR`. The factor `N` arises
/// because the inverse transform is left unnormalised.
///
/// Equivalently, the output is (a scaling of) the parity part of a systematic Reed-Solomon
/// codeword. Since Reed-Solomon codes are MDS, so is this map.
///
/// Construct it with [`Default::default`], which requires `F: TwoAdicField` and `N` to be a
/// power of two.
#[derive(Clone, Debug)]
pub struct CosetMds<F, const N: usize> {
    /// First `N / 2` powers of a primitive `N`-th root of unity, stored in bit-reversed order.
    /// Consumed by the forward Bowers network.
    fft_twiddles: Vec<F>,
    /// First `N / 2` powers of the inverse of that root, stored in bit-reversed order.
    /// Consumed by the transposed Bowers network.
    ifft_twiddles: Vec<F>,
    /// Powers `shift^0, ..., shift^(N - 1)` of the coset shift `F::GENERATOR`, stored in
    /// bit-reversed order.
    weights: [F; N],
}
22
23impl<F, const N: usize> Default for CosetMds<F, N>
24where
25 F: TwoAdicField,
26{
27 fn default() -> Self {
28 let log_n = log2_strict_usize(N);
29
30 let root = F::two_adic_generator(log_n);
31 let root_inv = root.inverse();
32 let mut fft_twiddles: Vec<F> = root.powers().collect_n(N / 2);
33 let mut ifft_twiddles: Vec<F> = root_inv.powers().collect_n(N / 2);
34 reverse_slice_index_bits(&mut fft_twiddles);
35 reverse_slice_index_bits(&mut ifft_twiddles);
36
37 let shift = F::GENERATOR;
38 let mut weights: [F; N] = shift.powers().collect_n(N).try_into().unwrap();
39 reverse_slice_index_bits(&mut weights);
40 Self {
41 fft_twiddles,
42 ifft_twiddles,
43 weights,
44 }
45 }
46}
47
48impl<F: TwoAdicField, A: Algebra<F>, const N: usize> Permutation<[A; N]> for CosetMds<F, N> {
49 fn permute(&self, mut input: [A; N]) -> [A; N] {
50 self.permute_mut(&mut input);
51 input
52 }
53
54 fn permute_mut(&self, values: &mut [A; N]) {
55 bowers_g_t(values, &self.ifft_twiddles);
57
58 for (value, weight) in values.iter_mut().zip(self.weights) {
60 *value = value.clone() * weight;
61 }
62
63 bowers_g(values, &self.fft_twiddles);
65 }
66}
67
68impl<F: TwoAdicField, A: Algebra<F>, const N: usize> MdsPermutation<A, N> for CosetMds<F, N> {}
69
70#[inline]
73fn bowers_g<F: Field, A: Algebra<F>, const N: usize>(values: &mut [A; N], twiddles: &[F]) {
74 let log_n = log2_strict_usize(N);
75 for log_half_block_size in 0..log_n {
76 bowers_g_layer(values, log_half_block_size, twiddles);
77 }
78}
79
80#[inline]
83fn bowers_g_t<F: Field, A: Algebra<F>, const N: usize>(values: &mut [A; N], twiddles: &[F]) {
84 let log_n = log2_strict_usize(N);
85 for log_half_block_size in (0..log_n).rev() {
86 bowers_g_t_layer(values, log_half_block_size, twiddles);
87 }
88}
89
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_dft::{NaiveDft, TwoAdicSubgroupDft};
    use p3_field::{Field, PrimeCharacteristicRing};
    use p3_symmetric::Permutation;
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};

    use crate::coset_mds::CosetMds;

    type F = BabyBear;

    /// Checks `CosetMds<F, N>` against the naive coset LDE (shift `F::GENERATOR`, no added bits)
    /// for one random input drawn from `seed`.
    ///
    /// `CosetMds` does not apply the `1/N` normalisation of the inverse transform, so the naive
    /// result is multiplied by `N` before comparing. `from_usize` is used instead of
    /// `from_u8(N as u8)`, which would silently truncate for `N >= 256`.
    fn check_matches_naive<const N: usize>(seed: u64) {
        let mut rng = SmallRng::seed_from_u64(seed);
        let mut arr: [F; N] = rng.random();

        let shift = F::GENERATOR;
        let scale = F::from_usize(N);
        let mut expected = NaiveDft.coset_lde(arr.to_vec(), 0, shift);
        expected.iter_mut().for_each(|x| *x *= scale);

        CosetMds::<F, N>::default().permute_mut(&mut arr);
        assert_eq!(expected, arr);
    }

    #[test]
    fn matches_naive() {
        check_matches_naive::<8>(1);
    }

    #[test]
    fn matches_naive_other_sizes() {
        check_matches_naive::<2>(2);
        check_matches_naive::<4>(3);
        check_matches_naive::<16>(4);
        check_matches_naive::<32>(5);
    }
}