1extern crate alloc;
3
4use alloc::sync::Arc;
5use alloc::vec::Vec;
6
7use itertools::izip;
8use p3_dft::TwoAdicSubgroupDft;
9use p3_field::{Field, PrimeCharacteristicRing};
10use p3_matrix::Matrix;
11use p3_matrix::bitrev::{BitReversedMatrixView, BitReversibleMatrix};
12use p3_matrix::dense::RowMajorMatrix;
13use p3_maybe_rayon::prelude::*;
14use p3_util::{log2_ceil_usize, log2_strict_usize};
15use spin::RwLock;
16use tracing::{debug_span, instrument};
17
18mod backward;
19mod forward;
20
21use crate::{FieldParameters, MontyField31, MontyParameters, TwoAdicData};
22
23#[instrument(level = "debug", skip_all)]
25fn coset_shift_and_scale_rows<F: Field>(
26 out: &mut [F],
27 out_ncols: usize,
28 mat: &[F],
29 ncols: usize,
30 shift: F,
31 scale: F,
32) {
33 debug_assert!(out.len().is_multiple_of(out_ncols));
34 debug_assert!(mat.len().is_multiple_of(ncols));
35 debug_assert!(out_ncols >= ncols);
36 debug_assert_eq!(out.len() / out_ncols, mat.len() / ncols);
37 let powers = shift.shifted_powers(scale).collect_n(ncols);
38 out.par_chunks_exact_mut(out_ncols)
39 .zip(mat.par_chunks_exact(ncols))
40 .for_each(|(out_row, in_row)| {
41 izip!(out_row.iter_mut(), in_row, &powers).for_each(|(out, &coeff, &weight)| {
42 *out = coeff * weight;
43 });
44 });
45}
46
/// Forward and inverse twiddle tables stored side by side, one `Vec` per level.
///
/// Both tables are `Arc<[Vec<F>]>`, so cloning a `TwiddlePair` only bumps
/// reference counts instead of copying the tables.
#[derive(Clone, Debug)]
struct TwiddlePair<F> {
    // Forward twiddles, one vector per level.
    twiddles: Arc<[Vec<F>]>,
    // Inverse twiddles; extended together with `twiddles`, so both hold the same number of levels.
    inv_twiddles: Arc<[Vec<F>]>,
}
55
56impl<F> Default for TwiddlePair<F> {
57 fn default() -> Self {
58 Self {
59 twiddles: Arc::from(Vec::new()),
60 inv_twiddles: Arc::from(Vec::new()),
61 }
62 }
63}
64
/// DFT engine that keeps a lazily grown cache of twiddle factors.
///
/// The cache lives behind `Arc<RwLock<_>>`, so every clone of a `RecursiveDft`
/// shares (and jointly grows) the same twiddle tables. `Default` starts with an
/// empty cache; `new(n)` additionally pre-fills it for transforms of length `n`.
#[derive(Clone, Debug, Default)]
pub struct RecursiveDft<F> {
    cache: Arc<RwLock<TwiddlePair<F>>>,
}
75
76impl<MP: FieldParameters + TwoAdicData> RecursiveDft<MontyField31<MP>> {
77 pub fn new(n: usize) -> Self {
78 let res = Self::default();
79 res.update_twiddles(n);
80 res
81 }
82
83 #[inline]
84 fn decimation_in_freq_dft(
85 mat: &mut [MontyField31<MP>],
86 ncols: usize,
87 twiddles: &[Vec<MontyField31<MP>>],
88 ) {
89 if ncols > 1 {
90 let lg_fft_len = log2_strict_usize(ncols);
91 let twiddles = &twiddles[..(lg_fft_len - 1)];
92
93 mat.par_chunks_exact_mut(ncols)
94 .for_each(|v| MontyField31::forward_fft(v, twiddles));
95 }
96 }
97
98 #[inline]
99 fn decimation_in_time_dft(
100 mat: &mut [MontyField31<MP>],
101 ncols: usize,
102 twiddles: &[Vec<MontyField31<MP>>],
103 ) {
104 if ncols > 1 {
105 let lg_fft_len = p3_util::log2_strict_usize(ncols);
106 let twiddles = &twiddles[..(lg_fft_len - 1)];
107
108 mat.par_chunks_exact_mut(ncols)
109 .for_each(|v| MontyField31::backward_fft(v, twiddles));
110 }
111 }
112
113 #[instrument(skip_all)]
115 fn update_twiddles(&self, fft_len: usize) {
116 let need = log2_strict_usize(fft_len);
120
121 let have = self.cache.read().twiddles.len() + 1;
123 if have >= need {
124 return;
125 }
126
127 let missing_twiddles = MontyField31::get_missing_twiddles(need, have);
128
129 let missing_inv_twiddles = missing_twiddles
130 .iter()
131 .map(|ts| {
132 core::iter::once(MontyField31::ONE)
133 .chain(
134 ts[1..]
135 .iter()
136 .rev()
137 .map(|&t| MontyField31::new_monty(MP::PRIME - t.value)),
138 )
139 .collect()
140 })
141 .collect::<Vec<_>>();
142
143 let have_minus_one = have - 1;
145 let mut cache = self.cache.write();
146 let current_len = cache.twiddles.len();
147 if (current_len + 1) < need {
149 let extend_from = current_len.saturating_sub(have_minus_one);
150
151 let mut tw = cache.twiddles.to_vec();
152 tw.extend_from_slice(&missing_twiddles[extend_from..]);
153
154 let mut inv_tw = cache.inv_twiddles.to_vec();
155 inv_tw.extend_from_slice(&missing_inv_twiddles[extend_from..]);
156
157 cache.twiddles = tw.into();
158 cache.inv_twiddles = inv_tw.into();
159 }
160 }
161
162 fn get_twiddles(&self) -> Arc<[Vec<MontyField31<MP>>]> {
163 self.cache.read().twiddles.clone()
164 }
165
166 fn get_inv_twiddles(&self) -> Arc<[Vec<MontyField31<MP>>]> {
167 self.cache.read().inv_twiddles.clone()
168 }
169}
170
171impl<MP: MontyParameters + FieldParameters + TwoAdicData> TwoAdicSubgroupDft<MontyField31<MP>>
198 for RecursiveDft<MontyField31<MP>>
199{
200 type Evaluations = BitReversedMatrixView<RowMajorMatrix<MontyField31<MP>>>;
201
202 #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
203 fn dft_batch(&self, mut mat: RowMajorMatrix<MontyField31<MP>>) -> Self::Evaluations
204 where
205 MP: MontyParameters + FieldParameters + TwoAdicData,
206 {
207 let nrows = mat.height();
208 let ncols = mat.width();
209
210 if nrows <= 1 {
211 return mat.bit_reverse_rows();
212 }
213
214 let mut scratch = debug_span!("allocate scratch space")
215 .in_scope(|| RowMajorMatrix::default(nrows, ncols));
216
217 self.update_twiddles(nrows);
218 let twiddles = self.get_twiddles();
219
220 debug_span!("pre-transpose", nrows, ncols).in_scope(|| {
222 p3_util::transpose::transpose(&mat.values, &mut scratch.values, ncols, nrows);
223 });
224
225 debug_span!("dft batch", n_dfts = ncols, fft_len = nrows)
226 .in_scope(|| Self::decimation_in_freq_dft(&mut scratch.values, nrows, &twiddles));
227
228 debug_span!("post-transpose", nrows = ncols, ncols = nrows).in_scope(|| {
230 p3_util::transpose::transpose(&scratch.values, &mut mat.values, nrows, ncols);
231 });
232
233 mat.bit_reverse_rows()
234 }
235
236 #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
237 fn idft_batch(&self, mat: RowMajorMatrix<MontyField31<MP>>) -> RowMajorMatrix<MontyField31<MP>>
238 where
239 MP: MontyParameters + FieldParameters + TwoAdicData,
240 {
241 let nrows = mat.height();
242 let ncols = mat.width();
243 if nrows <= 1 {
244 return mat;
245 }
246
247 let mut scratch = debug_span!("allocate scratch space")
248 .in_scope(|| RowMajorMatrix::default(nrows, ncols));
249
250 let mut mat =
251 debug_span!("initial bitrev").in_scope(|| mat.bit_reverse_rows().to_row_major_matrix());
252
253 self.update_twiddles(nrows);
254 let inv_twiddles = self.get_inv_twiddles();
255
256 debug_span!("pre-transpose", nrows, ncols).in_scope(|| {
258 p3_util::transpose::transpose(&mat.values, &mut scratch.values, ncols, nrows);
259 });
260
261 debug_span!("idft", n_dfts = ncols, fft_len = nrows)
262 .in_scope(|| Self::decimation_in_time_dft(&mut scratch.values, nrows, &inv_twiddles));
263
264 debug_span!("post-transpose", nrows = ncols, ncols = nrows).in_scope(|| {
266 p3_util::transpose::transpose(&scratch.values, &mut mat.values, nrows, ncols);
267 });
268
269 let log_rows = log2_ceil_usize(nrows);
270 let inv_len = MontyField31::ONE.div_2exp_u64(log_rows as u64);
271 debug_span!("scale").in_scope(|| mat.scale(inv_len));
272 mat
273 }
274
275 #[instrument(skip_all, level = "debug", fields(dims = %mat.dimensions(), added_bits))]
276 fn coset_lde_batch(
277 &self,
278 mat: RowMajorMatrix<MontyField31<MP>>,
279 added_bits: usize,
280 shift: MontyField31<MP>,
281 ) -> Self::Evaluations {
282 let nrows = mat.height();
283 let ncols = mat.width();
284 let result_nrows = nrows << added_bits;
285
286 if nrows == 1 {
287 let dupd_rows = core::iter::repeat_n(mat.values, result_nrows)
288 .flatten()
289 .collect();
290 return RowMajorMatrix::new(dupd_rows, ncols).bit_reverse_rows();
291 }
292
293 let input_size = nrows * ncols;
294 let output_size = result_nrows * ncols;
295
296 let mat = mat.bit_reverse_rows().to_row_major_matrix();
297
298 let (mut output, mut padded) = debug_span!("allocate scratch space").in_scope(|| {
300 let output = MontyField31::<MP>::zero_vec(output_size);
302 let padded = MontyField31::<MP>::zero_vec(output_size);
303 (output, padded)
304 });
305
306 let coeffs = &mut output[..input_size];
309
310 debug_span!("pre-transpose", nrows, ncols)
311 .in_scope(|| p3_util::transpose::transpose(&mat.values, coeffs, ncols, nrows));
312
313 self.update_twiddles(result_nrows);
315 let inv_twiddles = self.get_inv_twiddles();
316 debug_span!("inverse dft batch", n_dfts = ncols, fft_len = nrows)
317 .in_scope(|| Self::decimation_in_time_dft(coeffs, nrows, &inv_twiddles));
318
319 let log_rows = log2_ceil_usize(nrows);
324 let inv_len = MontyField31::ONE.div_2exp_u64(log_rows as u64);
325 coset_shift_and_scale_rows(&mut padded, result_nrows, coeffs, nrows, shift, inv_len);
326
327 let twiddles = self.get_twiddles();
331
332 debug_span!("dft batch", n_dfts = ncols, fft_len = result_nrows)
334 .in_scope(|| Self::decimation_in_freq_dft(&mut padded, result_nrows, &twiddles));
335
336 debug_span!("post-transpose", nrows = ncols, ncols = result_nrows)
338 .in_scope(|| p3_util::transpose::transpose(&padded, &mut output, result_nrows, ncols));
339
340 RowMajorMatrix::new(output, ncols).bit_reverse_rows()
341 }
342}