1use alloc::vec;
2
3use p3_field::TwoAdicField;
4use p3_matrix::Matrix;
5use p3_matrix::dense::RowMajorMatrix;
6use p3_util::log2_strict_usize;
7
8use crate::TwoAdicSubgroupDft;
9
/// A straightforward DFT that evaluates every column by direct summation, costing time
/// quadratic in the matrix height.
///
/// The type holds no state, so any value of it (e.g. `NaiveDft` or `NaiveDft::default()`)
/// behaves the same.
#[derive(Clone, Debug, Default)]
pub struct NaiveDft;
12
13impl<F: TwoAdicField> TwoAdicSubgroupDft<F> for NaiveDft {
14 type Evaluations = RowMajorMatrix<F>;
15 fn dft_batch(&self, mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
16 let w = mat.width();
17 let h = mat.height();
18 let log_h = log2_strict_usize(h);
19 let g = F::two_adic_generator(log_h);
20
21 let mut res = RowMajorMatrix::new(vec![F::ZERO; w * h], w);
22 for (res_r, point) in g.powers().take(h).enumerate() {
23 for (src_r, point_power) in point.powers().take(h).enumerate() {
24 for c in 0..w {
25 res.values[res_r * w + c] += point_power * mat.values[src_r * w + c];
26 }
27 }
28 }
29
30 res
31 }
32}
33
#[cfg(test)]
mod tests {
    use alloc::vec;

    use p3_baby_bear::BabyBear;
    use p3_field::{Field, PrimeCharacteristicRing};
    use p3_goldilocks::Goldilocks;
    use p3_matrix::dense::RowMajorMatrix;
    use rand::SeedableRng;
    use rand::rngs::SmallRng;

    use crate::{NaiveDft, TwoAdicSubgroupDft};

    /// Deterministic 8x3 random Goldilocks matrix (seed 1) shared by the round-trip tests.
    fn sample_matrix() -> RowMajorMatrix<Goldilocks> {
        let mut rng = SmallRng::seed_from_u64(1);
        RowMajorMatrix::rand(&mut rng, 8, 3)
    }

    /// With height 2 the evaluation points are 1 and -1, so the first output row is the
    /// column-wise sum of the two input rows and the second row is their difference.
    #[test]
    fn basic() {
        type F = BabyBear;

        // Input rows: [5, 2, 0] and [4, 3, 0].
        let input = RowMajorMatrix::new(
            vec![
                F::from_u8(5),
                F::from_u8(2),
                F::ZERO,
                F::from_u8(4),
                F::from_u8(3),
                F::ZERO,
            ],
            3,
        );

        // Sum row [9, 5, 0] followed by difference row [1, -1, 0].
        let expected = RowMajorMatrix::new(
            vec![
                F::from_u8(9),
                F::from_u8(5),
                F::ZERO,
                F::ONE,
                F::NEG_ONE,
                F::ZERO,
            ],
            3,
        );

        let evaluations = NaiveDft.dft_batch(input);
        assert_eq!(evaluations, expected);
    }

    /// A forward DFT followed by the inverse DFT must reproduce the input exactly.
    #[test]
    fn dft_idft_consistency() {
        let original = sample_matrix();
        let evaluations = NaiveDft.dft_batch(original.clone());
        let round_trip = NaiveDft.idft_batch(evaluations);
        assert_eq!(round_trip, original);
    }

    /// Same round trip, but on the coset shifted by the field's multiplicative generator.
    #[test]
    fn coset_dft_idft_consistency() {
        let shift = Goldilocks::GENERATOR;
        let original = sample_matrix();
        let evaluations = NaiveDft.coset_dft_batch(original.clone(), shift);
        let round_trip = NaiveDft.coset_idft_batch(evaluations, shift);
        assert_eq!(round_trip, original);
    }
}