1use alloc::vec::Vec;
2
3use p3_field::{Algebra, Field, Powers, TwoAdicField};
4use p3_symmetric::Permutation;
5use p3_util::{log2_strict_usize, reverse_slice_index_bits};
6
7use crate::MdsPermutation;
8use crate::butterflies::{bowers_g_layer, bowers_g_t_layer_integrated};
9
/// Precomputed twiddle tables for an `N`-element linear permutation built from two passes of
/// Bowers butterfly layers: one pass driven by powers of an inverse root of unity, and one pass
/// whose per-layer twiddles are additionally multiplied by powers of the coset shift
/// `F::GENERATOR`.
///
/// Instances are built by the `Default` impl in this file; the fields are private.
#[derive(Clone, Debug)]
pub struct IntegratedCosetMds<F, const N: usize> {
    // The first `N / 2` powers of the inverse of `F::two_adic_generator(log2(N))`, stored in
    // bit-reversed order. The same table is passed to every `bowers_g_layer` call.
    ifft_twiddles: Vec<F>,
    // One table per layer `0..log2(N)`; table `k` has `N >> (k + 1)` entries of the form
    // `shift^(2^k) * root^(2^k * i)`, stored in bit-reversed order. Indexed by layer in
    // `permute_mut`.
    fft_twiddles: Vec<Vec<F>>,
}
27
28impl<F: TwoAdicField, const N: usize> Default for IntegratedCosetMds<F, N> {
29 fn default() -> Self {
30 let log_n = log2_strict_usize(N);
31
32 let root = F::two_adic_generator(log_n);
34 let root_inv = root.inverse();
35 let coset_shift = F::GENERATOR;
36
37 let mut ifft_twiddles = root_inv.powers().collect_n(N / 2);
39 reverse_slice_index_bits(&mut ifft_twiddles);
40
41 let fft_twiddles = (0..log_n)
45 .map(|layer| {
46 let powers = Powers {
47 base: root.exp_power_of_2(layer),
48 current: coset_shift.exp_power_of_2(layer),
49 };
50 let mut twiddles = powers.collect_n(N >> (layer + 1));
51 reverse_slice_index_bits(&mut twiddles);
52 twiddles
53 })
54 .collect();
55
56 Self {
57 ifft_twiddles,
58 fft_twiddles,
59 }
60 }
61}
62
63impl<F: Field, A: Algebra<F>, const N: usize> Permutation<[A; N]> for IntegratedCosetMds<F, N> {
64 fn permute_mut(&self, values: &mut [A; N]) {
65 let log_n = log2_strict_usize(N);
66
67 for layer in 0..log_n {
69 bowers_g_layer(values, layer, &self.ifft_twiddles);
70 }
71
72 for layer in (0..log_n).rev() {
76 bowers_g_t_layer_integrated(values, layer, &self.fft_twiddles[layer]);
77 }
78 }
79}
80
81impl<F: Field, A: Algebra<F>, const N: usize> MdsPermutation<A, N> for IntegratedCosetMds<F, N> {}
82
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use core::array;

    use p3_baby_bear::BabyBear;
    use p3_dft::{NaiveDft, TwoAdicSubgroupDft};
    use p3_field::{PrimeCharacteristicRing, TwoAdicField};
    use p3_goldilocks::Goldilocks;
    use p3_symmetric::Permutation;
    use p3_util::reverse_slice_index_bits;
    use proptest::prelude::*;
    use rand::distr::{Distribution, StandardUniform};
    use rand::rngs::SmallRng;
    use rand::{RngExt, SeedableRng};

    use crate::integrated_coset_mds::IntegratedCosetMds;

    /// Reference output for `IntegratedCosetMds<F, N>` applied to `input`, computed with
    /// `NaiveDft`.
    ///
    /// The input is bit-reversed and passed to `NaiveDft::coset_lde` with no added bits and
    /// shift `F::GENERATOR`. The result is bit-reversed back and every entry is multiplied by `N`.
    /// Both the fixed-seed and the property-based comparisons use this single definition, so
    /// the two cannot drift apart.
    fn naive_reference<F: TwoAdicField, const N: usize>(input: &[F; N]) -> Vec<F> {
        let mut rev_input = input.to_vec();
        reverse_slice_index_bits(&mut rev_input);

        let mut lde = NaiveDft.coset_lde(rev_input, 0, F::GENERATOR);
        reverse_slice_index_bits(&mut lde);

        // The permutation's output is `N` times the naive LDE (presumably because it skips
        // the inverse transform's 1/N normalisation), so scale the reference to match.
        let scale = F::from_usize(N);
        lde.iter_mut().for_each(|x| *x *= scale);
        lde
    }

    /// Checks `IntegratedCosetMds<F, N>` against `naive_reference` on one pseudo-random input
    /// drawn from a fixed seed.
    fn matches_naive_for<F, const N: usize>()
    where
        F: TwoAdicField,
        StandardUniform: Distribution<F>,
    {
        let mut rng = SmallRng::seed_from_u64(1);
        let mut arr: [F; N] = array::from_fn(|_| rng.random());

        let expected = naive_reference(&arr);

        IntegratedCosetMds::<F, N>::default().permute_mut(&mut arr);
        assert_eq!(expected, arr);
    }

    macro_rules! matches_naive_test {
        ($name:ident, $field:ty, $n:expr) => {
            #[test]
            fn $name() {
                matches_naive_for::<$field, $n>();
            }
        };
    }

    matches_naive_test!(matches_naive_baby_bear_1, BabyBear, 1);
    matches_naive_test!(matches_naive_baby_bear_2, BabyBear, 2);
    matches_naive_test!(matches_naive_baby_bear_4, BabyBear, 4);
    matches_naive_test!(matches_naive_baby_bear_8, BabyBear, 8);
    matches_naive_test!(matches_naive_baby_bear_16, BabyBear, 16);
    matches_naive_test!(matches_naive_baby_bear_32, BabyBear, 32);

    matches_naive_test!(matches_naive_goldilocks_1, Goldilocks, 1);
    matches_naive_test!(matches_naive_goldilocks_2, Goldilocks, 2);
    matches_naive_test!(matches_naive_goldilocks_4, Goldilocks, 4);
    matches_naive_test!(matches_naive_goldilocks_8, Goldilocks, 8);
    matches_naive_test!(matches_naive_goldilocks_16, Goldilocks, 16);
    matches_naive_test!(matches_naive_goldilocks_32, Goldilocks, 32);

    #[test]
    fn all_zeros_baby_bear() {
        let mds = IntegratedCosetMds::<BabyBear, 8>::default();
        let mut zeros = [BabyBear::ZERO; 8];
        mds.permute_mut(&mut zeros);
        assert_eq!(zeros, [BabyBear::ZERO; 8]);
    }

    #[test]
    fn all_zeros_goldilocks() {
        let mds = IntegratedCosetMds::<Goldilocks, 8>::default();
        let mut zeros = [Goldilocks::ZERO; 8];
        mds.permute_mut(&mut zeros);
        assert_eq!(zeros, [Goldilocks::ZERO; 8]);
    }

    /// Asserts additivity: `mds(a + b) == mds(a) + mds(b)`.
    fn check_linearity<F, const N: usize>(a: [F; N], b: [F; N])
    where
        F: TwoAdicField,
    {
        let mds = IntegratedCosetMds::<F, N>::default();

        let mut sum: [F; N] = array::from_fn(|i| a[i] + b[i]);
        mds.permute_mut(&mut sum);

        let mut ra = a;
        mds.permute_mut(&mut ra);
        let mut rb = b;
        mds.permute_mut(&mut rb);

        let expected: [F; N] = array::from_fn(|i| ra[i] + rb[i]);
        assert_eq!(sum, expected);
    }

    fn arb_babybear() -> impl Strategy<Value = BabyBear> {
        prop::num::u32::ANY.prop_map(BabyBear::from_u32)
    }

    fn arb_goldilocks() -> impl Strategy<Value = Goldilocks> {
        prop::num::u64::ANY.prop_map(Goldilocks::from_u64)
    }

    proptest! {
        #[test]
        fn integrated_coset_mds_linear_bb8(
            a in prop::array::uniform8(arb_babybear()),
            b in prop::array::uniform8(arb_babybear()),
        ) {
            check_linearity::<BabyBear, 8>(a, b);
        }

        #[test]
        fn integrated_coset_mds_linear_bb16(
            a in prop::array::uniform16(arb_babybear()),
            b in prop::array::uniform16(arb_babybear()),
        ) {
            check_linearity::<BabyBear, 16>(a, b);
        }

        #[test]
        fn integrated_coset_mds_linear_gl8(
            a in prop::array::uniform8(arb_goldilocks()),
            b in prop::array::uniform8(arb_goldilocks()),
        ) {
            check_linearity::<Goldilocks, 8>(a, b);
        }

        #[test]
        fn integrated_coset_mds_matches_naive_random_bb8(
            input in prop::array::uniform8(arb_babybear()),
        ) {
            let expected = naive_reference(&input);
            let mut result = input;
            IntegratedCosetMds::<BabyBear, 8>::default().permute_mut(&mut result);
            prop_assert_eq!(expected, result);
        }

        #[test]
        fn integrated_coset_mds_matches_naive_random_gl8(
            input in prop::array::uniform8(arb_goldilocks()),
        ) {
            let expected = naive_reference(&input);
            let mut result = input;
            IntegratedCosetMds::<Goldilocks, 8>::default().permute_mut(&mut result);
            prop_assert_eq!(expected, result);
        }
    }
}