1extern crate alloc;
3
4use alloc::sync::Arc;
5use alloc::vec::Vec;
6
7use itertools::izip;
8use p3_dft::TwoAdicSubgroupDft;
9use p3_field::{Field, PrimeCharacteristicRing};
10use p3_matrix::Matrix;
11use p3_matrix::bitrev::{BitReversedMatrixView, BitReversibleMatrix};
12use p3_matrix::dense::RowMajorMatrix;
13use p3_maybe_rayon::prelude::*;
14use p3_util::{log2_ceil_usize, log2_strict_usize};
15use spin::RwLock;
16use tracing::{debug_span, instrument};
17
18mod backward;
19mod forward;
20
21use crate::{FieldParameters, MontyField31, MontyParameters, TwoAdicData};
22
23#[instrument(level = "debug", skip_all)]
25fn coset_shift_and_scale_rows<F: Field>(
26 out: &mut [F],
27 out_ncols: usize,
28 mat: &[F],
29 ncols: usize,
30 shift: F,
31 scale: F,
32) {
33 let powers = shift.shifted_powers(scale).collect_n(ncols);
34 out.par_chunks_exact_mut(out_ncols)
35 .zip(mat.par_chunks_exact(ncols))
36 .for_each(|(out_row, in_row)| {
37 izip!(out_row.iter_mut(), in_row, &powers).for_each(|(out, &coeff, &weight)| {
38 *out = coeff * weight;
39 });
40 });
41}
42
43#[derive(Clone, Debug, Default)]
46pub struct RecursiveDft<F> {
47 #[allow(clippy::type_complexity)]
49 twiddles: Arc<RwLock<Arc<[Vec<F>]>>>,
50 #[allow(clippy::type_complexity)]
52 inv_twiddles: Arc<RwLock<Arc<[Vec<F>]>>>,
53}
54
55impl<MP: FieldParameters + TwoAdicData> RecursiveDft<MontyField31<MP>> {
56 pub fn new(n: usize) -> Self {
57 let res = Self {
58 twiddles: Arc::default(),
59 inv_twiddles: Arc::default(),
60 };
61 res.update_twiddles(n);
62 res
63 }
64
65 #[inline]
66 fn decimation_in_freq_dft(
67 mat: &mut [MontyField31<MP>],
68 ncols: usize,
69 twiddles: &[Vec<MontyField31<MP>>],
70 ) {
71 if ncols > 1 {
72 let lg_fft_len = log2_strict_usize(ncols);
73 let twiddles = &twiddles[..(lg_fft_len - 1)];
74
75 mat.par_chunks_exact_mut(ncols)
76 .for_each(|v| MontyField31::forward_fft(v, twiddles));
77 }
78 }
79
80 #[inline]
81 fn decimation_in_time_dft(
82 mat: &mut [MontyField31<MP>],
83 ncols: usize,
84 twiddles: &[Vec<MontyField31<MP>>],
85 ) {
86 if ncols > 1 {
87 let lg_fft_len = p3_util::log2_strict_usize(ncols);
88 let twiddles = &twiddles[..(lg_fft_len - 1)];
89
90 mat.par_chunks_exact_mut(ncols)
91 .for_each(|v| MontyField31::backward_fft(v, twiddles));
92 }
93 }
94
95 #[instrument(skip_all)]
97 fn update_twiddles(&self, fft_len: usize) {
98 let need = log2_strict_usize(fft_len);
103 let snapshot = self.twiddles.read().clone();
104 let have = snapshot.len() + 1;
105 if have >= need {
106 return;
107 }
108
109 let missing_twiddles = MontyField31::get_missing_twiddles(need, have);
110
111 let missing_inv_twiddles = missing_twiddles
112 .iter()
113 .map(|ts| {
114 core::iter::once(MontyField31::ONE)
115 .chain(
116 ts[1..]
117 .iter()
118 .rev()
119 .map(|&t| MontyField31::new_monty(MP::PRIME - t.value)),
120 )
121 .collect()
122 })
123 .collect::<Vec<_>>();
124 let extend_table = |lock: &RwLock<Arc<[Vec<_>]>>, missing: &[Vec<_>]| {
126 let mut w = lock.write();
127 let current_len = w.len();
128 if (current_len + 1) < need {
130 let mut v = w.to_vec();
131 let extend_from = current_len.saturating_sub(current_len);
133 v.extend_from_slice(&missing[extend_from..]);
134 *w = v.into();
135 }
136 };
137 extend_table(&self.twiddles, &missing_twiddles);
139 extend_table(&self.inv_twiddles, &missing_inv_twiddles);
140 }
141
142 fn get_twiddles(&self) -> Arc<[Vec<MontyField31<MP>>]> {
143 self.twiddles.read().clone()
144 }
145
146 fn get_inv_twiddles(&self) -> Arc<[Vec<MontyField31<MP>>]> {
147 self.inv_twiddles.read().clone()
148 }
149}
150
151impl<MP: MontyParameters + FieldParameters + TwoAdicData> TwoAdicSubgroupDft<MontyField31<MP>>
178 for RecursiveDft<MontyField31<MP>>
179{
180 type Evaluations = BitReversedMatrixView<RowMajorMatrix<MontyField31<MP>>>;
181
182 #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
183 fn dft_batch(&self, mut mat: RowMajorMatrix<MontyField31<MP>>) -> Self::Evaluations
184 where
185 MP: MontyParameters + FieldParameters + TwoAdicData,
186 {
187 let nrows = mat.height();
188 let ncols = mat.width();
189
190 if nrows <= 1 {
191 return mat.bit_reverse_rows();
192 }
193
194 let mut scratch = debug_span!("allocate scratch space")
195 .in_scope(|| RowMajorMatrix::default(nrows, ncols));
196
197 self.update_twiddles(nrows);
198 let twiddles = self.get_twiddles();
199
200 debug_span!("pre-transpose", nrows, ncols)
202 .in_scope(|| transpose::transpose(&mat.values, &mut scratch.values, ncols, nrows));
203
204 debug_span!("dft batch", n_dfts = ncols, fft_len = nrows)
205 .in_scope(|| Self::decimation_in_freq_dft(&mut scratch.values, nrows, &twiddles));
206
207 debug_span!("post-transpose", nrows = ncols, ncols = nrows)
209 .in_scope(|| transpose::transpose(&scratch.values, &mut mat.values, nrows, ncols));
210
211 mat.bit_reverse_rows()
212 }
213
214 #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
215 fn idft_batch(&self, mat: RowMajorMatrix<MontyField31<MP>>) -> RowMajorMatrix<MontyField31<MP>>
216 where
217 MP: MontyParameters + FieldParameters + TwoAdicData,
218 {
219 let nrows = mat.height();
220 let ncols = mat.width();
221 if nrows <= 1 {
222 return mat;
223 }
224
225 let mut scratch = debug_span!("allocate scratch space")
226 .in_scope(|| RowMajorMatrix::default(nrows, ncols));
227
228 let mut mat =
229 debug_span!("initial bitrev").in_scope(|| mat.bit_reverse_rows().to_row_major_matrix());
230
231 self.update_twiddles(nrows);
232 let inv_twiddles = self.get_inv_twiddles();
233
234 debug_span!("pre-transpose", nrows, ncols)
236 .in_scope(|| transpose::transpose(&mat.values, &mut scratch.values, ncols, nrows));
237
238 debug_span!("idft", n_dfts = ncols, fft_len = nrows)
239 .in_scope(|| Self::decimation_in_time_dft(&mut scratch.values, nrows, &inv_twiddles));
240
241 debug_span!("post-transpose", nrows = ncols, ncols = nrows)
243 .in_scope(|| transpose::transpose(&scratch.values, &mut mat.values, nrows, ncols));
244
245 let log_rows = log2_ceil_usize(nrows);
246 let inv_len = MontyField31::ONE.div_2exp_u64(log_rows as u64);
247 debug_span!("scale").in_scope(|| mat.scale(inv_len));
248 mat
249 }
250
251 #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
252 fn coset_lde_batch(
253 &self,
254 mat: RowMajorMatrix<MontyField31<MP>>,
255 added_bits: usize,
256 shift: MontyField31<MP>,
257 ) -> Self::Evaluations {
258 let nrows = mat.height();
259 let ncols = mat.width();
260 let result_nrows = nrows << added_bits;
261
262 if nrows == 1 {
263 let dupd_rows = core::iter::repeat_n(mat.values, result_nrows)
264 .flatten()
265 .collect();
266 return RowMajorMatrix::new(dupd_rows, ncols).bit_reverse_rows();
267 }
268
269 let input_size = nrows * ncols;
270 let output_size = result_nrows * ncols;
271
272 let mat = mat.bit_reverse_rows().to_row_major_matrix();
273
274 let (mut output, mut padded) = debug_span!("allocate scratch space").in_scope(|| {
276 let output = MontyField31::<MP>::zero_vec(output_size);
278 let padded = MontyField31::<MP>::zero_vec(output_size);
279 (output, padded)
280 });
281
282 let coeffs = &mut output[..input_size];
285
286 debug_span!("pre-transpose", nrows, ncols)
287 .in_scope(|| transpose::transpose(&mat.values, coeffs, ncols, nrows));
288
289 self.update_twiddles(result_nrows);
291 let inv_twiddles = self.get_inv_twiddles();
292 debug_span!("inverse dft batch", n_dfts = ncols, fft_len = nrows)
293 .in_scope(|| Self::decimation_in_time_dft(coeffs, nrows, &inv_twiddles));
294
295 let log_rows = log2_ceil_usize(nrows);
300 let inv_len = MontyField31::ONE.div_2exp_u64(log_rows as u64);
301 coset_shift_and_scale_rows(&mut padded, result_nrows, coeffs, nrows, shift, inv_len);
302
303 let twiddles = self.get_twiddles();
307
308 debug_span!("dft batch", n_dfts = ncols, fft_len = result_nrows)
310 .in_scope(|| Self::decimation_in_freq_dft(&mut padded, result_nrows, &twiddles));
311
312 debug_span!("post-transpose", nrows = ncols, ncols = result_nrows)
314 .in_scope(|| transpose::transpose(&padded, &mut output, result_nrows, ncols));
315
316 RowMajorMatrix::new(output, ncols).bit_reverse_rows()
317 }
318}