// p3_dft/naive.rs

use p3_field::TwoAdicField;
use p3_matrix::Matrix;
use p3_matrix::dense::RowMajorMatrix;
use p3_util::log2_strict_usize;

use crate::TwoAdicSubgroupDft;

/// A naive, quadratic-time implementation of [`TwoAdicSubgroupDft`].
///
/// Every output entry is computed by direct summation over all input rows, so a batch DFT
/// costs `O(height^2 * width)` field operations. This makes it simple to audit, but slow on
/// large inputs.
#[derive(Default, Clone, Debug)]
pub struct NaiveDft;

11impl<F: TwoAdicField> TwoAdicSubgroupDft<F> for NaiveDft {
12    type Evaluations = RowMajorMatrix<F>;
13    fn dft_batch(&self, mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
14        let w = mat.width();
15        let h = mat.height();
16        let log_h = log2_strict_usize(h);
17        let g = F::two_adic_generator(log_h);
18
19        let mut res = RowMajorMatrix::new(F::zero_vec(w * h), w);
20        for (res_r, point) in g.powers().take(h).enumerate() {
21            for (src_r, point_power) in point.powers().take(h).enumerate() {
22                for c in 0..w {
23                    res.values[res_r * w + c] += point_power * mat.values[src_r * w + c];
24                }
25            }
26        }
27
28        res
29    }
30}
31
#[cfg(test)]
mod tests {
    use alloc::vec;

    use p3_baby_bear::BabyBear;
    use p3_field::{Field, PrimeCharacteristicRing};
    use p3_goldilocks::Goldilocks;
    use p3_matrix::dense::RowMajorMatrix;
    use rand::SeedableRng;
    use rand::rngs::SmallRng;

    use crate::{NaiveDft, TwoAdicSubgroupDft};

    /// Hand-checked evaluations of degree-1 polynomials on the order-2 subgroup `{1, -1}`.
    #[test]
    fn basic() {
        type F = BabyBear;

        // A few polynomials, one per column:
        // 5 + 4x
        // 2 + 3x
        // 0
        let mat = RowMajorMatrix::new(
            vec![
                F::from_u8(5),
                F::from_u8(2),
                F::ZERO,
                F::from_u8(4),
                F::from_u8(3),
                F::ZERO,
            ],
            3,
        );

        let dft = NaiveDft.dft_batch(mat);
        // Expected evaluations on {1, -1}, per polynomial:
        // 9, 1
        // 5, -1
        // 0, 0
        assert_eq!(
            dft,
            RowMajorMatrix::new(
                vec![
                    F::from_u8(9),
                    F::from_u8(5),
                    F::ZERO,
                    F::ONE,
                    F::NEG_ONE,
                    F::ZERO,
                ],
                3,
            )
        );
    }

    /// With a single row, each column is a constant polynomial evaluated only at `x^0 = 1`.
    /// The DFT is therefore the identity.
    #[test]
    fn single_row_is_identity() {
        type F = BabyBear;

        let mat = RowMajorMatrix::new(vec![F::from_u8(7), F::ZERO, F::NEG_ONE], 3);
        assert_eq!(NaiveDft.dft_batch(mat.clone()), mat);
    }

    /// Applying the inverse DFT to a DFT must recover the original coefficients.
    #[test]
    fn dft_idft_consistency() {
        type F = Goldilocks;
        let mut rng = SmallRng::seed_from_u64(1);
        let original = RowMajorMatrix::<F>::rand(&mut rng, 8, 3);
        let dft = NaiveDft.dft_batch(original.clone());
        let idft = NaiveDft.idft_batch(dft);
        assert_eq!(original, idft);
    }

    /// Same round-trip property as above, over the coset shifted by `F::GENERATOR`.
    #[test]
    fn coset_dft_idft_consistency() {
        type F = Goldilocks;
        let generator = F::GENERATOR;
        let mut rng = SmallRng::seed_from_u64(1);
        let original = RowMajorMatrix::<F>::rand(&mut rng, 8, 3);
        let dft = NaiveDft.coset_dft_batch(original.clone(), generator);
        let idft = NaiveDft.coset_idft_batch(dft, generator);
        assert_eq!(original, idft);
    }
}