1use alloc::vec::Vec;
14
15use itertools::{Itertools, izip};
16use p3_dft::TwoAdicSubgroupDft;
17use p3_field::extension::Complex;
18use p3_field::{Field, PrimeCharacteristicRing, TwoAdicField};
19use p3_matrix::Matrix;
20use p3_matrix::dense::RowMajorMatrix;
21use p3_util::log2_strict_usize;
22
23use crate::Mersenne31;
24
25type F = Mersenne31;
26type C = Complex<Mersenne31>;
27
28fn dft_preprocess(input: &RowMajorMatrix<F>) -> RowMajorMatrix<C> {
40 assert!(
41 input.height().is_multiple_of(2),
42 "input height must be even"
43 );
44 RowMajorMatrix::new(
45 input
46 .rows()
47 .tuples()
48 .flat_map(|(row_0, row_1)| {
49 row_0.zip(row_1).map(|(x, y)| C::new_complex(x, y))
54 })
55 .collect(),
56 input.width(),
57 )
58}
59
60fn dft_postprocess(input: &RowMajorMatrix<C>) -> RowMajorMatrix<C> {
67 let h = input.height();
68 let log2_h = log2_strict_usize(h); let omega = C::two_adic_generator(log2_h + 1);
73 let mut omega_j = omega;
74
75 let mut output = Vec::with_capacity((h + 1) * input.width());
76 output.extend(
77 input
78 .first_row()
79 .unwrap() .into_iter()
81 .map(|x| C::new_real(x.real() + x.imag())),
82 );
83
84 for j in 1..h {
85 let row_iter = unsafe {
86 izip!(input.row_unchecked(j), input.row_unchecked(h - j))
88 };
89 let row = row_iter.map(|(x, y)| {
90 let even = x + y.conjugate();
91 let odd = C::new_complex(x.imag() + y.imag(), y.real() - x.real());
93 (even + odd * omega_j).halve()
94 });
95 output.extend(row);
96 omega_j *= omega;
97 }
98
99 output.extend(
100 input
101 .first_row()
102 .unwrap() .into_iter()
104 .map(|x| C::new_real(x.real() - x.imag())),
105 );
106 debug_assert_eq!(output.len(), (h + 1) * input.width());
107 RowMajorMatrix::new(output, input.width())
108}
109
110fn idft_preprocess(input: &RowMajorMatrix<C>) -> RowMajorMatrix<C> {
117 let h = input.height() - 1;
118 let log2_h = log2_strict_usize(h); let omega = C::two_adic_generator(log2_h + 1).inverse();
123 let mut omega_j = C::ONE;
124
125 let mut output = Vec::with_capacity(h * input.width());
126 for j in 0..h {
128 let row_iter = unsafe {
129 izip!(input.row_unchecked(j), input.row_unchecked(h - j))
131 };
132 let row = row_iter.map(|(x, y)| {
133 let even = x + y.conjugate();
134 let odd = C::new_complex(x.imag() + y.imag(), y.real() - x.real());
136 (even - odd * omega_j).halve()
137 });
138 output.extend(row);
139 omega_j *= omega;
140 }
141 RowMajorMatrix::new(output, input.width())
142}
143
144fn idft_postprocess(input: &RowMajorMatrix<C>) -> RowMajorMatrix<F> {
150 let mut output = Vec::with_capacity(input.width() * input.height() * 2);
156 let mut buf = Vec::with_capacity(input.width());
157
158 for row in input.rows() {
162 for ext in row {
163 output.push(ext.real());
164 buf.push(ext.imag());
165 }
166 output.append(&mut buf);
167 }
168
169 RowMajorMatrix::new(output, input.width())
170}
171
172#[derive(Debug, Default, Clone)]
174pub struct Mersenne31Dft;
175
176impl Mersenne31Dft {
177 pub fn dft_batch<Dft: TwoAdicSubgroupDft<C>>(mat: &RowMajorMatrix<F>) -> RowMajorMatrix<C> {
184 let dft = Dft::default();
185 dft_postprocess(&dft.dft_batch(dft_preprocess(mat)).to_row_major_matrix())
186 }
187
188 pub fn idft_batch<Dft: TwoAdicSubgroupDft<C>>(mat: &RowMajorMatrix<C>) -> RowMajorMatrix<F> {
192 let dft = Dft::default();
193 idft_postprocess(&dft.idft_batch(idft_preprocess(mat)))
194 }
195}
196
#[cfg(test)]
mod tests {
    use rand::distr::{Distribution, StandardUniform};
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};

    use super::*;
    use crate::Mersenne31ComplexRadix2Dit;

    type Base = Mersenne31;
    type Dft = Mersenne31ComplexRadix2Dit;

    /// Draws `len` pseudo-random base-field elements from an RNG seeded with 1.
    fn random_values(len: usize) -> Vec<Base>
    where
        StandardUniform: Distribution<Base>,
    {
        SmallRng::seed_from_u64(1)
            .sample_iter(StandardUniform)
            .take(len)
            .collect()
    }

    /// A forward DFT followed by an inverse DFT must reproduce the input.
    #[test]
    fn consistency() {
        const N: usize = 1 << 12;
        let input = RowMajorMatrix::new_col(random_values(N));
        let spectrum = Mersenne31Dft::dft_batch::<Dft>(&input);
        let round_trip = Mersenne31Dft::idft_batch::<Dft>(&spectrum);
        assert_eq!(input, round_trip);
    }

    /// Pointwise multiplication of spectra must match naive cyclic convolution.
    #[test]
    fn convolution() {
        const N: usize = 1 << 6;
        let values = random_values(2 * N);
        let (lhs, rhs) = values.split_at(N);
        let a = RowMajorMatrix::new_col(lhs.to_vec());
        let b = RowMajorMatrix::new_col(rhs.to_vec());

        let fft_a = Mersenne31Dft::dft_batch::<Dft>(&a);
        let fft_b = Mersenne31Dft::dft_batch::<Dft>(&b);

        let fft_c: Vec<C> = fft_a
            .values
            .iter()
            .zip(&fft_b.values)
            .map(|(&x, &y)| x * y)
            .collect();
        let c = Mersenne31Dft::idft_batch::<Dft>(&RowMajorMatrix::new_col(fft_c));

        // conv[i] = sum over j of a[j] * b[(i - j) mod N].
        let expected: Vec<Base> = (0..N)
            .map(|i| {
                (0..N).fold(Base::ZERO, |acc, j| {
                    acc + a.values[j] * b.values[(N + i - j) % N]
                })
            })
            .collect();

        assert_eq!(c.values, expected);
    }
}