1use alloc::vec::Vec;
2
3use p3_field::{Algebra, Field, TwoAdicField};
4use p3_symmetric::Permutation;
5use p3_util::{log2_strict_usize, reverse_slice_index_bits};
6
7use crate::MdsPermutation;
8use crate::butterflies::{bowers_g_layer, bowers_g_t_layer};
9
/// A Reed-Solomon based MDS permutation over arrays of length `N`.
///
/// The input is treated as the evaluations of a polynomial of degree `< N` over the order-`N`
/// two-adic subgroup `H = <ω>`. The output is `N` times that polynomial's evaluations over the
/// coset `g·H`, where `g = F::GENERATOR`. The extra factor `N` appears because the inverse
/// transform is never normalised by `1/N`.
///
/// Reading the input as the message symbols of a systematic Reed-Solomon code (evaluations on `H`)
/// and the output as its parity symbols (evaluations on `g·H`) gives the MDS property, assuming
/// `g·H` is disjoint from `H`.
///
/// `N` must be a power of two; the `Default` constructor derives `log2(N)` strictly.
#[derive(Clone, Debug)]
pub struct CosetMds<F, const N: usize> {
    /// The first `N / 2` powers of the order-`N` root of unity `ω`, in bit-reversed order;
    /// consumed by the forward Bowers G network.
    fft_twiddles: Vec<F>,
    /// The first `N / 2` powers of `ω⁻¹`, in bit-reversed order; consumed by the Bowers G^T
    /// network.
    ifft_twiddles: Vec<F>,
    /// Coset weights `g^0, g^1, ..., g^(N-1)` for `g = F::GENERATOR`, in bit-reversed order.
    weights: [F; N],
}
33
34impl<F, const N: usize> Default for CosetMds<F, N>
35where
36 F: TwoAdicField,
37{
38 fn default() -> Self {
39 let log_n = log2_strict_usize(N);
40
41 let root = F::two_adic_generator(log_n);
43 let root_inv = root.inverse();
44
45 let mut fft_twiddles: Vec<F> = root.powers().collect_n(N / 2);
47 let mut ifft_twiddles: Vec<F> = root_inv.powers().collect_n(N / 2);
48 reverse_slice_index_bits(&mut fft_twiddles);
49 reverse_slice_index_bits(&mut ifft_twiddles);
50
51 let shift = F::GENERATOR;
53 let mut weights: [F; N] = shift.powers().collect_n(N).try_into().unwrap();
54 reverse_slice_index_bits(&mut weights);
55
56 Self {
57 fft_twiddles,
58 ifft_twiddles,
59 weights,
60 }
61 }
62}
63
64impl<F: Field, A: Algebra<F>, const N: usize> Permutation<[A; N]> for CosetMds<F, N> {
65 fn permute_mut(&self, values: &mut [A; N]) {
66 bowers_g_t(values, &self.ifft_twiddles);
68
69 for (value, weight) in values.iter_mut().zip(self.weights) {
71 *value = value.clone() * weight;
72 }
73
74 bowers_g(values, &self.fft_twiddles);
76 }
77}
78
79impl<F: Field, A: Algebra<F>, const N: usize> MdsPermutation<A, N> for CosetMds<F, N> {}
80
81#[inline]
85fn bowers_g<F: Field, A: Algebra<F>, const N: usize>(values: &mut [A; N], twiddles: &[F]) {
86 let log_n = log2_strict_usize(N);
87 for log_half_block_size in 0..log_n {
89 bowers_g_layer(values, log_half_block_size, twiddles);
90 }
91}
92
93#[inline]
97fn bowers_g_t<F: Field, A: Algebra<F>, const N: usize>(values: &mut [A; N], twiddles: &[F]) {
98 let log_n = log2_strict_usize(N);
99 for log_half_block_size in (0..log_n).rev() {
101 bowers_g_t_layer(values, log_half_block_size, twiddles);
102 }
103}
104
#[cfg(test)]
mod tests {
    use core::array;

    use p3_baby_bear::BabyBear;
    use p3_dft::{NaiveDft, TwoAdicSubgroupDft};
    use p3_field::{PrimeCharacteristicRing, TwoAdicField};
    use p3_goldilocks::Goldilocks;
    use p3_symmetric::Permutation;
    use proptest::prelude::*;
    use rand::distr::{Distribution, StandardUniform};
    use rand::rngs::SmallRng;
    use rand::{RngExt, SeedableRng};

    use crate::coset_mds::CosetMds;

    /// Reference output for `CosetMds`: the naive coset LDE of `input` (no blowup, shift
    /// `F::GENERATOR`), scaled by `N` because `CosetMds` skips the `1/N` normalisation.
    fn naive_scaled_coset_lde<F: TwoAdicField, const N: usize>(input: [F; N]) -> [F; N] {
        // With `added_bits = 0` the LDE has exactly `N` entries.
        let lde = NaiveDft.coset_lde(input.to_vec(), 0, F::GENERATOR);
        let scale = F::from_usize(N);
        array::from_fn(|i| lde[i] * scale)
    }

    /// Checks `CosetMds` against the naive reference on a fixed pseudo-random input.
    fn matches_naive_for<F, const N: usize>()
    where
        F: TwoAdicField,
        StandardUniform: Distribution<F>,
    {
        let mut rng = SmallRng::seed_from_u64(1);
        let mut arr: [F; N] = array::from_fn(|_| rng.random());

        let expected = naive_scaled_coset_lde(arr);
        CosetMds::<F, N>::default().permute_mut(&mut arr);
        assert_eq!(expected, arr);
    }

    macro_rules! matches_naive_test {
        ($name:ident, $field:ty, $n:expr) => {
            #[test]
            fn $name() {
                matches_naive_for::<$field, $n>();
            }
        };
    }

    matches_naive_test!(matches_naive_baby_bear_1, BabyBear, 1);
    matches_naive_test!(matches_naive_baby_bear_2, BabyBear, 2);
    matches_naive_test!(matches_naive_baby_bear_4, BabyBear, 4);
    matches_naive_test!(matches_naive_baby_bear_8, BabyBear, 8);
    matches_naive_test!(matches_naive_baby_bear_16, BabyBear, 16);
    matches_naive_test!(matches_naive_baby_bear_32, BabyBear, 32);

    matches_naive_test!(matches_naive_goldilocks_1, Goldilocks, 1);
    matches_naive_test!(matches_naive_goldilocks_2, Goldilocks, 2);
    matches_naive_test!(matches_naive_goldilocks_4, Goldilocks, 4);
    matches_naive_test!(matches_naive_goldilocks_8, Goldilocks, 8);
    matches_naive_test!(matches_naive_goldilocks_16, Goldilocks, 16);
    matches_naive_test!(matches_naive_goldilocks_32, Goldilocks, 32);

    /// A linear map must send the zero vector to the zero vector.
    fn check_zero_maps_to_zero<F: TwoAdicField, const N: usize>() {
        let mds = CosetMds::<F, N>::default();
        let mut zeros = [F::ZERO; N];
        mds.permute_mut(&mut zeros);
        assert_eq!(zeros, [F::ZERO; N]);
    }

    #[test]
    fn all_zeros_baby_bear() {
        check_zero_maps_to_zero::<BabyBear, 8>();
    }

    #[test]
    fn all_zeros_goldilocks() {
        check_zero_maps_to_zero::<Goldilocks, 8>();
    }

    /// Checks additivity: `M(a + b) == M(a) + M(b)`.
    fn check_linearity<F, const N: usize>(a: [F; N], b: [F; N])
    where
        F: TwoAdicField,
    {
        let mds = CosetMds::<F, N>::default();

        let mut sum: [F; N] = array::from_fn(|i| a[i] + b[i]);
        mds.permute_mut(&mut sum);

        let mut ra = a;
        mds.permute_mut(&mut ra);
        let mut rb = b;
        mds.permute_mut(&mut rb);

        let expected: [F; N] = array::from_fn(|i| ra[i] + rb[i]);
        assert_eq!(sum, expected);
    }

    fn arb_babybear() -> impl Strategy<Value = BabyBear> {
        prop::num::u32::ANY.prop_map(BabyBear::from_u32)
    }

    proptest! {
        #[test]
        fn coset_mds_linear_bb8(
            a in prop::array::uniform8(arb_babybear()),
            b in prop::array::uniform8(arb_babybear()),
        ) {
            check_linearity::<BabyBear, 8>(a, b);
        }

        #[test]
        fn coset_mds_linear_bb16(
            a in prop::array::uniform16(arb_babybear()),
            b in prop::array::uniform16(arb_babybear()),
        ) {
            check_linearity::<BabyBear, 16>(a, b);
        }

        #[test]
        fn coset_mds_matches_naive_random_bb8(input in prop::array::uniform8(arb_babybear())) {
            let mut result = input;
            CosetMds::<BabyBear, 8>::default().permute_mut(&mut result);
            prop_assert_eq!(naive_scaled_coset_lde(input), result);
        }
    }
}