Skip to main content

p3_mds/
coset_mds.rs

1use alloc::vec::Vec;
2
3use p3_field::{Algebra, Field, TwoAdicField};
4use p3_symmetric::Permutation;
5use p3_util::{log2_strict_usize, reverse_slice_index_bits};
6
7use crate::MdsPermutation;
8use crate::butterflies::{bowers_g_layer, bowers_g_t_layer};
9
/// MDS permutation derived from a Reed-Solomon code.
///
/// The `N` inputs are read as the values of a polynomial of degree below `N`
/// on the two-adic subgroup of order `N`. The outputs are that polynomial's
/// values on a multiplicative coset of the same subgroup. Those coset values
/// are exactly the parity symbols of a systematic Reed-Solomon code. Because
/// Reed-Solomon codes are MDS, the resulting linear map is MDS as well.
///
/// # Algorithm
///
/// 1. Interpolate with the Bowers G^T network: an inverse DFT that neither
///    bit-reverses its output nor divides by `N`.
/// 2. Scale each coefficient by the matching power of the coset shift.
/// 3. Evaluate with the Bowers G network: a forward DFT on bit-reversed input.
///
/// Because the `1/N` factor of step 1 is never applied, every output equals
/// the true coset evaluation multiplied by `N`.
#[derive(Clone, Debug)]
pub struct CosetMds<F, const N: usize> {
    /// Bit-reversed twiddles for the forward (evaluation) DFT.
    fft_twiddles: Vec<F>,
    /// Bit-reversed twiddles for the inverse (interpolation) DFT.
    ifft_twiddles: Vec<F>,
    /// Bit-reversed powers `g^0, ..., g^{N-1}` of the coset shift `g`.
    weights: [F; N],
}
33
34impl<F, const N: usize> Default for CosetMds<F, N>
35where
36    F: TwoAdicField,
37{
38    fn default() -> Self {
39        let log_n = log2_strict_usize(N);
40
41        // Primitive N-th root of unity and its inverse.
42        let root = F::two_adic_generator(log_n);
43        let root_inv = root.inverse();
44
45        // Collect N/2 powers and bit-reverse for the Bowers network layout.
46        let mut fft_twiddles: Vec<F> = root.powers().collect_n(N / 2);
47        let mut ifft_twiddles: Vec<F> = root_inv.powers().collect_n(N / 2);
48        reverse_slice_index_bits(&mut fft_twiddles);
49        reverse_slice_index_bits(&mut ifft_twiddles);
50
51        // Coset shift weights: generator^0, generator^1, ..., generator^{N-1}, bit-reversed.
52        let shift = F::GENERATOR;
53        let mut weights: [F; N] = shift.powers().collect_n(N).try_into().unwrap();
54        reverse_slice_index_bits(&mut weights);
55
56        Self {
57            fft_twiddles,
58            ifft_twiddles,
59            weights,
60        }
61    }
62}
63
64impl<F: Field, A: Algebra<F>, const N: usize> Permutation<[A; N]> for CosetMds<F, N> {
65    fn permute_mut(&self, values: &mut [A; N]) {
66        // Step 1: inverse DFT (skip bit-reversal and 1/N rescaling).
67        bowers_g_t(values, &self.ifft_twiddles);
68
69        // Step 2: multiply each coefficient by the corresponding coset shift power.
70        for (value, weight) in values.iter_mut().zip(self.weights) {
71            *value = value.clone() * weight;
72        }
73
74        // Step 3: forward DFT on the now bit-reversed, shifted coefficients.
75        bowers_g(values, &self.fft_twiddles);
76    }
77}
78
79impl<F: Field, A: Algebra<F>, const N: usize> MdsPermutation<A, N> for CosetMds<F, N> {}
80
81/// Full Bowers G network (forward DFT on bit-reversed input).
82///
83/// Applies layers from smallest to largest block size.
84#[inline]
85fn bowers_g<F: Field, A: Algebra<F>, const N: usize>(values: &mut [A; N], twiddles: &[F]) {
86    let log_n = log2_strict_usize(N);
87    // Sweep from fine blocks (size 2) to the full array.
88    for log_half_block_size in 0..log_n {
89        bowers_g_layer(values, log_half_block_size, twiddles);
90    }
91}
92
93/// Full Bowers G^T network (inverse DFT without 1/N rescaling; output is bit-reversed).
94///
95/// Applies layers from largest to smallest block size.
96#[inline]
97fn bowers_g_t<F: Field, A: Algebra<F>, const N: usize>(values: &mut [A; N], twiddles: &[F]) {
98    let log_n = log2_strict_usize(N);
99    // Sweep from the full array down to fine blocks (size 2).
100    for log_half_block_size in (0..log_n).rev() {
101        bowers_g_t_layer(values, log_half_block_size, twiddles);
102    }
103}
104
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use core::array;

    use p3_baby_bear::BabyBear;
    use p3_dft::{NaiveDft, TwoAdicSubgroupDft};
    use p3_field::{Field, PrimeCharacteristicRing, TwoAdicField};
    use p3_goldilocks::Goldilocks;
    use p3_symmetric::Permutation;
    use proptest::prelude::*;
    use rand::distr::{Distribution, StandardUniform};
    use rand::rngs::SmallRng;
    use rand::{RngExt, SeedableRng};

    use crate::coset_mds::CosetMds;

    /// Reference output of `CosetMds::<F, N>` computed with the naive DFT.
    ///
    /// This is a coset LDE with zero added bits and shift `F::GENERATOR`.
    /// The Bowers-based implementation skips the 1/N rescaling, so the naive
    /// result is multiplied by `N` to match.
    fn naive_reference<F: TwoAdicField, const N: usize>(input: &[F; N]) -> Vec<F> {
        let mut expected = NaiveDft.coset_lde(input.to_vec(), 0, F::GENERATOR);
        let scale = F::from_usize(N);
        expected.iter_mut().for_each(|x| *x *= scale);
        expected
    }

    fn matches_naive_for<F, const N: usize>()
    where
        F: TwoAdicField,
        StandardUniform: Distribution<F>,
    {
        // Generate a random input array with a fixed seed for reproducibility.
        let mut rng = SmallRng::seed_from_u64(1);
        let mut arr: [F; N] = array::from_fn(|_| rng.random());

        let expected = naive_reference(&arr);

        // Apply the permutation under test and compare.
        CosetMds::<F, N>::default().permute_mut(&mut arr);
        assert_eq!(expected, arr);
    }

    macro_rules! matches_naive_test {
        ($name:ident, $field:ty, $n:expr) => {
            #[test]
            fn $name() {
                matches_naive_for::<$field, $n>();
            }
        };
    }

    matches_naive_test!(matches_naive_baby_bear_1, BabyBear, 1);
    matches_naive_test!(matches_naive_baby_bear_2, BabyBear, 2);
    matches_naive_test!(matches_naive_baby_bear_4, BabyBear, 4);
    matches_naive_test!(matches_naive_baby_bear_8, BabyBear, 8);
    matches_naive_test!(matches_naive_baby_bear_16, BabyBear, 16);
    matches_naive_test!(matches_naive_baby_bear_32, BabyBear, 32);

    matches_naive_test!(matches_naive_goldilocks_1, Goldilocks, 1);
    matches_naive_test!(matches_naive_goldilocks_2, Goldilocks, 2);
    matches_naive_test!(matches_naive_goldilocks_4, Goldilocks, 4);
    matches_naive_test!(matches_naive_goldilocks_8, Goldilocks, 8);
    matches_naive_test!(matches_naive_goldilocks_16, Goldilocks, 16);
    matches_naive_test!(matches_naive_goldilocks_32, Goldilocks, 32);

    #[test]
    fn all_zeros_baby_bear() {
        // All-zeros must map to all-zeros (the permutation is linear).
        let mds = CosetMds::<BabyBear, 8>::default();
        let mut zeros = [BabyBear::ZERO; 8];
        mds.permute_mut(&mut zeros);
        assert_eq!(zeros, [BabyBear::ZERO; 8]);
    }

    #[test]
    fn all_zeros_goldilocks() {
        // Same zero-preservation check on a different field.
        let mds = CosetMds::<Goldilocks, 8>::default();
        let mut zeros = [Goldilocks::ZERO; 8];
        mds.permute_mut(&mut zeros);
        assert_eq!(zeros, [Goldilocks::ZERO; 8]);
    }

    #[test]
    #[should_panic]
    fn non_power_of_two_size_panics() {
        // The radix-2 butterfly networks only support power-of-two sizes, and
        // construction rejects anything else.
        let _ = CosetMds::<BabyBear, 3>::default();
    }

    fn check_linearity<F, const N: usize>(a: [F; N], b: [F; N])
    where
        F: TwoAdicField,
    {
        let mds = CosetMds::<F, N>::default();

        // Apply the permutation to the element-wise sum.
        let mut sum: [F; N] = array::from_fn(|i| a[i] + b[i]);
        mds.permute_mut(&mut sum);

        // Apply to each vector individually.
        let mut ra = a;
        mds.permute_mut(&mut ra);
        let mut rb = b;
        mds.permute_mut(&mut rb);

        // Linearity: MDS(a + b) must equal MDS(a) + MDS(b).
        let expected: [F; N] = array::from_fn(|i| ra[i] + rb[i]);
        assert_eq!(sum, expected);
    }

    fn arb_babybear() -> impl Strategy<Value = BabyBear> {
        prop::num::u32::ANY.prop_map(BabyBear::from_u32)
    }

    proptest! {
        #[test]
        fn coset_mds_linear_bb8(
            a in prop::array::uniform8(arb_babybear()),
            b in prop::array::uniform8(arb_babybear()),
        ) {
            check_linearity::<BabyBear, 8>(a, b);
        }

        #[test]
        fn coset_mds_linear_bb16(
            a in prop::array::uniform16(arb_babybear()),
            b in prop::array::uniform16(arb_babybear()),
        ) {
            check_linearity::<BabyBear, 16>(a, b);
        }

        #[test]
        fn coset_mds_matches_naive_random_bb8(input in prop::array::uniform8(arb_babybear())) {
            // Naive reference, already scaled by N to match the un-rescaled Bowers output.
            let expected = naive_reference(&input);

            // Apply the permutation under test and compare.
            let mut result = input;
            CosetMds::<BabyBear, 8>::default().permute_mut(&mut result);
            prop_assert_eq!(expected, result);
        }
    }
}