Skip to main content

p3_mds/
integrated_coset_mds.rs

1use alloc::vec::Vec;
2
3use p3_field::{Algebra, Field, Powers, TwoAdicField};
4use p3_symmetric::Permutation;
5use p3_util::{log2_strict_usize, reverse_slice_index_bits};
6
7use crate::MdsPermutation;
8use crate::butterflies::{bowers_g_layer, bowers_g_t_layer_integrated};
9
/// Reed-Solomon-based MDS permutation of width `N` whose coset shift is baked
/// into the forward-DFT twiddles.
///
/// Conceptually this is "interpolate over the size-`N` subgroup, then evaluate
/// on the coset `g·H` with `g = F::GENERATOR`". It differs from the textbook
/// coset-LDE pipeline in four ways:
/// - The inverse DFT is a bit-reversed DIF pass and the forward DFT a
///   bit-reversed DIT pass (the textbook order is DIT then DIF).
/// - Because of that pairing, neither the input nor the output is explicitly
///   bit-reversed. Relative to a plain coset LDE, both are in bit-reversed order.
/// - The final `1/N` normalisation is dropped. The result is therefore `N` times
///   the coset LDE, a constant factor that leaves the MDS property intact.
/// - Powers of the coset shift are multiplied into the forward twiddles, so no
///   separate per-element weighting pass is needed.
///
/// `N` must be a power of two; construction goes through `log2_strict_usize`.
#[derive(Debug, Clone)]
pub struct IntegratedCosetMds<F, const N: usize> {
    /// Inverse-DFT twiddles (powers of the inverse root of unity), stored in
    /// bit-reversed order and shared by every DIF layer.
    ifft_twiddles: Vec<F>,
    /// Forward-DFT twiddles, one table per layer. Entry `l` already carries the
    /// coset-shift factor appropriate to layer `l` alongside the root-of-unity
    /// powers for that layer.
    fft_twiddles: Vec<Vec<F>>,
}
27
28impl<F: TwoAdicField, const N: usize> Default for IntegratedCosetMds<F, N> {
29    fn default() -> Self {
30        let log_n = log2_strict_usize(N);
31
32        // Primitive N-th root of unity and its inverse.
33        let root = F::two_adic_generator(log_n);
34        let root_inv = root.inverse();
35        let coset_shift = F::GENERATOR;
36
37        // Inverse-DFT twiddles: powers of root^{-1}, bit-reversed.
38        let mut ifft_twiddles = root_inv.powers().collect_n(N / 2);
39        reverse_slice_index_bits(&mut ifft_twiddles);
40
41        // Forward-DFT twiddles: for each layer, combine the root power
42        // with the coset shift raised to the same power-of-2 exponent.
43        // This folds the separate weighting step into the DFT itself.
44        let fft_twiddles = (0..log_n)
45            .map(|layer| {
46                let powers = Powers {
47                    base: root.exp_power_of_2(layer),
48                    current: coset_shift.exp_power_of_2(layer),
49                };
50                let mut twiddles = powers.collect_n(N >> (layer + 1));
51                reverse_slice_index_bits(&mut twiddles);
52                twiddles
53            })
54            .collect();
55
56        Self {
57            ifft_twiddles,
58            fft_twiddles,
59        }
60    }
61}
62
63impl<F: Field, A: Algebra<F>, const N: usize> Permutation<[A; N]> for IntegratedCosetMds<F, N> {
64    fn permute_mut(&self, values: &mut [A; N]) {
65        let log_n = log2_strict_usize(N);
66
67        // Step 1: bit-reversed DIF (Bowers G) — acts as an inverse DFT.
68        for layer in 0..log_n {
69            bowers_g_layer(values, layer, &self.ifft_twiddles);
70        }
71
72        // Step 2: bit-reversed DIT (Bowers G^T) with integrated coset shifts.
73        //
74        // Each layer uses its own twiddle table that already includes the shift.
75        for layer in (0..log_n).rev() {
76            bowers_g_t_layer_integrated(values, layer, &self.fft_twiddles[layer]);
77        }
78    }
79}
80
81impl<F: Field, A: Algebra<F>, const N: usize> MdsPermutation<A, N> for IntegratedCosetMds<F, N> {}
82
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use core::array;

    use p3_baby_bear::BabyBear;
    use p3_dft::{NaiveDft, TwoAdicSubgroupDft};
    use p3_field::{PrimeCharacteristicRing, TwoAdicField};
    use p3_goldilocks::Goldilocks;
    use p3_symmetric::Permutation;
    use p3_util::reverse_slice_index_bits;
    use proptest::prelude::*;
    use rand::distr::{Distribution, StandardUniform};
    use rand::rngs::SmallRng;
    use rand::{RngExt, SeedableRng};

    use crate::integrated_coset_mds::IntegratedCosetMds;

    /// Expected output of `IntegratedCosetMds::<F, N>` on `input`, computed with
    /// the naive DFT.
    ///
    /// The integrated variant works on bit-reversed data and skips the `1/N`
    /// rescaling, so the reference:
    /// 1. bit-reverses `input`,
    /// 2. runs a naive coset LDE (no added bits) with shift `F::GENERATOR`,
    /// 3. bit-reverses the result back to the integrated variant's convention,
    /// 4. multiplies every entry by `N`.
    ///
    /// Every naive-comparison test goes through this single oracle so the
    /// comparisons cannot drift apart.
    fn naive_reference<F: TwoAdicField, const N: usize>(input: &[F; N]) -> Vec<F> {
        let mut input_rev = input.to_vec();
        reverse_slice_index_bits(&mut input_rev);

        let mut expected = NaiveDft.coset_lde(input_rev, 0, F::GENERATOR);
        reverse_slice_index_bits(&mut expected);

        // Compensate for the omitted 1/N rescaling.
        let scale = F::from_usize(N);
        expected.iter_mut().for_each(|x| *x *= scale);
        expected
    }

    /// Checks the permutation against `naive_reference` on a seeded random input.
    fn matches_naive_for<F, const N: usize>()
    where
        F: TwoAdicField,
        StandardUniform: Distribution<F>,
    {
        // Fixed seed so any failure is reproducible.
        let mut rng = SmallRng::seed_from_u64(1);
        let mut arr: [F; N] = array::from_fn(|_| rng.random());

        let expected = naive_reference(&arr);

        IntegratedCosetMds::<F, N>::default().permute_mut(&mut arr);
        assert_eq!(expected, arr);
    }

    macro_rules! matches_naive_test {
        ($name:ident, $field:ty, $n:expr) => {
            #[test]
            fn $name() {
                matches_naive_for::<$field, $n>();
            }
        };
    }

    matches_naive_test!(matches_naive_baby_bear_1, BabyBear, 1);
    matches_naive_test!(matches_naive_baby_bear_2, BabyBear, 2);
    matches_naive_test!(matches_naive_baby_bear_4, BabyBear, 4);
    matches_naive_test!(matches_naive_baby_bear_8, BabyBear, 8);
    matches_naive_test!(matches_naive_baby_bear_16, BabyBear, 16);
    matches_naive_test!(matches_naive_baby_bear_32, BabyBear, 32);

    matches_naive_test!(matches_naive_goldilocks_1, Goldilocks, 1);
    matches_naive_test!(matches_naive_goldilocks_2, Goldilocks, 2);
    matches_naive_test!(matches_naive_goldilocks_4, Goldilocks, 4);
    matches_naive_test!(matches_naive_goldilocks_8, Goldilocks, 8);
    matches_naive_test!(matches_naive_goldilocks_16, Goldilocks, 16);
    matches_naive_test!(matches_naive_goldilocks_32, Goldilocks, 32);

    /// The permutation is linear, so the all-zeros vector must map to itself.
    fn check_zeros_fixed<F: TwoAdicField, const N: usize>() {
        let mds = IntegratedCosetMds::<F, N>::default();
        let mut zeros = [F::ZERO; N];
        mds.permute_mut(&mut zeros);
        assert_eq!(zeros, [F::ZERO; N]);
    }

    #[test]
    fn all_zeros_baby_bear() {
        check_zeros_fixed::<BabyBear, 8>();
    }

    #[test]
    fn all_zeros_goldilocks() {
        check_zeros_fixed::<Goldilocks, 8>();
    }

    /// Asserts `MDS(a + b) == MDS(a) + MDS(b)`.
    fn check_linearity<F, const N: usize>(a: [F; N], b: [F; N])
    where
        F: TwoAdicField,
    {
        let mds = IntegratedCosetMds::<F, N>::default();

        // Apply the permutation to the element-wise sum.
        let mut sum: [F; N] = array::from_fn(|i| a[i] + b[i]);
        mds.permute_mut(&mut sum);

        // Apply it to each vector individually.
        let mut ra = a;
        mds.permute_mut(&mut ra);
        let mut rb = b;
        mds.permute_mut(&mut rb);

        let expected: [F; N] = array::from_fn(|i| ra[i] + rb[i]);
        assert_eq!(sum, expected);
    }

    /// Arbitrary BabyBear element built from a random `u32` (reduced mod p).
    fn arb_babybear() -> impl Strategy<Value = BabyBear> {
        prop::num::u32::ANY.prop_map(BabyBear::from_u32)
    }

    proptest! {
        #[test]
        fn integrated_coset_mds_linear_bb8(
            a in prop::array::uniform8(arb_babybear()),
            b in prop::array::uniform8(arb_babybear()),
        ) {
            check_linearity::<BabyBear, 8>(a, b);
        }

        #[test]
        fn integrated_coset_mds_linear_bb16(
            a in prop::array::uniform16(arb_babybear()),
            b in prop::array::uniform16(arb_babybear()),
        ) {
            check_linearity::<BabyBear, 16>(a, b);
        }

        #[test]
        fn integrated_coset_mds_matches_naive_random_bb8(
            input in prop::array::uniform8(arb_babybear()),
        ) {
            let expected = naive_reference(&input);

            let mut result = input;
            IntegratedCosetMds::<BabyBear, 8>::default().permute_mut(&mut result);
            prop_assert_eq!(expected, result);
        }
    }
}