1extern crate alloc;
3
4use alloc::sync::Arc;
5use alloc::vec::Vec;
6
7use itertools::izip;
8use p3_dft::TwoAdicSubgroupDft;
9use p3_field::{Field, PrimeCharacteristicRing};
10use p3_matrix::Matrix;
11use p3_matrix::bitrev::{BitReversedMatrixView, BitReversibleMatrix};
12use p3_matrix::dense::RowMajorMatrix;
13use p3_maybe_rayon::prelude::*;
14use p3_util::{log2_ceil_usize, log2_strict_usize};
15use spin::RwLock;
16use tracing::{debug_span, instrument};
17
18mod backward;
19mod forward;
20
21use crate::{FieldParameters, MontyField31, MontyParameters, TwoAdicData};
22
23#[instrument(level = "debug", skip_all)]
25fn coset_shift_and_scale_rows<F: Field>(
26 out: &mut [F],
27 out_ncols: usize,
28 mat: &[F],
29 ncols: usize,
30 shift: F,
31 scale: F,
32) {
33 let powers = shift.shifted_powers(scale).collect_n(ncols);
34 out.par_chunks_exact_mut(out_ncols)
35 .zip(mat.par_chunks_exact(ncols))
36 .for_each(|(out_row, in_row)| {
37 izip!(out_row.iter_mut(), in_row, &powers).for_each(|(out, &coeff, &weight)| {
38 *out = coeff * weight;
39 });
40 });
41}
42
/// DFT engine that caches the twiddle tables it has computed so far.
///
/// The methods in this file are implemented for `F = MontyField31<MP>`.
/// Both tables live behind an `Arc<RwLock<..>>`, so values produced by the
/// derived `Clone` share (and jointly extend) the same cached tables.
#[derive(Clone, Debug, Default)]
pub struct RecursiveDft<F> {
    /// Tables handed to the forward transform (`forward_fft`).
    ///
    /// A transform of length `2^k` reads the first `k - 1` entries.
    // The nested `Arc<RwLock<Arc<[..]>>>` is intentional: the inner `Arc` lets
    // readers take a cheap snapshot, hence the lint is silenced.
    #[allow(clippy::type_complexity)]
    twiddles: Arc<RwLock<Arc<[Vec<F>]>>>,
    /// Tables handed to the inverse transform (`backward_fft`), laid out like
    /// `twiddles`.
    // Same type shape as `twiddles`; lint silenced for the same reason.
    #[allow(clippy::type_complexity)]
    inv_twiddles: Arc<RwLock<Arc<[Vec<F>]>>>,
}
54
55impl<MP: FieldParameters + TwoAdicData> RecursiveDft<MontyField31<MP>> {
56 pub fn new(n: usize) -> Self {
57 let res = Self {
58 twiddles: Arc::default(),
59 inv_twiddles: Arc::default(),
60 };
61 res.update_twiddles(n);
62 res
63 }
64
65 #[inline]
66 fn decimation_in_freq_dft(
67 mat: &mut [MontyField31<MP>],
68 ncols: usize,
69 twiddles: &[Vec<MontyField31<MP>>],
70 ) {
71 if ncols > 1 {
72 let lg_fft_len = log2_strict_usize(ncols);
73 let twiddles = &twiddles[..(lg_fft_len - 1)];
74
75 mat.par_chunks_exact_mut(ncols)
76 .for_each(|v| MontyField31::forward_fft(v, twiddles));
77 }
78 }
79
80 #[inline]
81 fn decimation_in_time_dft(
82 mat: &mut [MontyField31<MP>],
83 ncols: usize,
84 twiddles: &[Vec<MontyField31<MP>>],
85 ) {
86 if ncols > 1 {
87 let lg_fft_len = p3_util::log2_strict_usize(ncols);
88 let twiddles = &twiddles[..(lg_fft_len - 1)];
89
90 mat.par_chunks_exact_mut(ncols)
91 .for_each(|v| MontyField31::backward_fft(v, twiddles));
92 }
93 }
94
95 #[instrument(skip_all)]
97 fn update_twiddles(&self, fft_len: usize) {
98 let need = log2_strict_usize(fft_len);
103 let snapshot = self.twiddles.read().clone();
104 let have = snapshot.len() + 1;
105 if have >= need {
106 return;
107 }
108
109 let missing_twiddles = MontyField31::get_missing_twiddles(need, have);
110
111 let missing_inv_twiddles = missing_twiddles
112 .iter()
113 .map(|ts| {
114 core::iter::once(MontyField31::ONE)
115 .chain(
116 ts[1..]
117 .iter()
118 .rev()
119 .map(|&t| MontyField31::new_monty(MP::PRIME - t.value)),
120 )
121 .collect()
122 })
123 .collect::<Vec<_>>();
124 let have_minus_one = have - 1;
126 let extend_table = |lock: &RwLock<Arc<[Vec<_>]>>, missing: &[Vec<_>]| {
127 let mut w = lock.write();
128 let current_len = w.len();
129 if (current_len + 1) < need {
131 let mut v = w.to_vec();
132 let extend_from = current_len.saturating_sub(have_minus_one);
134 v.extend_from_slice(&missing[extend_from..]);
135 *w = v.into();
136 }
137 };
138 extend_table(&self.twiddles, &missing_twiddles);
140 extend_table(&self.inv_twiddles, &missing_inv_twiddles);
141 }
142
143 fn get_twiddles(&self) -> Arc<[Vec<MontyField31<MP>>]> {
144 self.twiddles.read().clone()
145 }
146
147 fn get_inv_twiddles(&self) -> Arc<[Vec<MontyField31<MP>>]> {
148 self.inv_twiddles.read().clone()
149 }
150}
151
152impl<MP: MontyParameters + FieldParameters + TwoAdicData> TwoAdicSubgroupDft<MontyField31<MP>>
179 for RecursiveDft<MontyField31<MP>>
180{
181 type Evaluations = BitReversedMatrixView<RowMajorMatrix<MontyField31<MP>>>;
182
183 #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
184 fn dft_batch(&self, mut mat: RowMajorMatrix<MontyField31<MP>>) -> Self::Evaluations
185 where
186 MP: MontyParameters + FieldParameters + TwoAdicData,
187 {
188 let nrows = mat.height();
189 let ncols = mat.width();
190
191 if nrows <= 1 {
192 return mat.bit_reverse_rows();
193 }
194
195 let mut scratch = debug_span!("allocate scratch space")
196 .in_scope(|| RowMajorMatrix::default(nrows, ncols));
197
198 self.update_twiddles(nrows);
199 let twiddles = self.get_twiddles();
200
201 debug_span!("pre-transpose", nrows, ncols).in_scope(|| {
203 p3_util::transpose::transpose(&mat.values, &mut scratch.values, ncols, nrows);
204 });
205
206 debug_span!("dft batch", n_dfts = ncols, fft_len = nrows)
207 .in_scope(|| Self::decimation_in_freq_dft(&mut scratch.values, nrows, &twiddles));
208
209 debug_span!("post-transpose", nrows = ncols, ncols = nrows).in_scope(|| {
211 p3_util::transpose::transpose(&scratch.values, &mut mat.values, nrows, ncols);
212 });
213
214 mat.bit_reverse_rows()
215 }
216
217 #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
218 fn idft_batch(&self, mat: RowMajorMatrix<MontyField31<MP>>) -> RowMajorMatrix<MontyField31<MP>>
219 where
220 MP: MontyParameters + FieldParameters + TwoAdicData,
221 {
222 let nrows = mat.height();
223 let ncols = mat.width();
224 if nrows <= 1 {
225 return mat;
226 }
227
228 let mut scratch = debug_span!("allocate scratch space")
229 .in_scope(|| RowMajorMatrix::default(nrows, ncols));
230
231 let mut mat =
232 debug_span!("initial bitrev").in_scope(|| mat.bit_reverse_rows().to_row_major_matrix());
233
234 self.update_twiddles(nrows);
235 let inv_twiddles = self.get_inv_twiddles();
236
237 debug_span!("pre-transpose", nrows, ncols).in_scope(|| {
239 p3_util::transpose::transpose(&mat.values, &mut scratch.values, ncols, nrows);
240 });
241
242 debug_span!("idft", n_dfts = ncols, fft_len = nrows)
243 .in_scope(|| Self::decimation_in_time_dft(&mut scratch.values, nrows, &inv_twiddles));
244
245 debug_span!("post-transpose", nrows = ncols, ncols = nrows).in_scope(|| {
247 p3_util::transpose::transpose(&scratch.values, &mut mat.values, nrows, ncols);
248 });
249
250 let log_rows = log2_ceil_usize(nrows);
251 let inv_len = MontyField31::ONE.div_2exp_u64(log_rows as u64);
252 debug_span!("scale").in_scope(|| mat.scale(inv_len));
253 mat
254 }
255
256 #[instrument(skip_all, level = "debug", fields(dims = %mat.dimensions(), added_bits))]
257 fn coset_lde_batch(
258 &self,
259 mat: RowMajorMatrix<MontyField31<MP>>,
260 added_bits: usize,
261 shift: MontyField31<MP>,
262 ) -> Self::Evaluations {
263 let nrows = mat.height();
264 let ncols = mat.width();
265 let result_nrows = nrows << added_bits;
266
267 if nrows == 1 {
268 let dupd_rows = core::iter::repeat_n(mat.values, result_nrows)
269 .flatten()
270 .collect();
271 return RowMajorMatrix::new(dupd_rows, ncols).bit_reverse_rows();
272 }
273
274 let input_size = nrows * ncols;
275 let output_size = result_nrows * ncols;
276
277 let mat = mat.bit_reverse_rows().to_row_major_matrix();
278
279 let (mut output, mut padded) = debug_span!("allocate scratch space").in_scope(|| {
281 let output = MontyField31::<MP>::zero_vec(output_size);
283 let padded = MontyField31::<MP>::zero_vec(output_size);
284 (output, padded)
285 });
286
287 let coeffs = &mut output[..input_size];
290
291 debug_span!("pre-transpose", nrows, ncols)
292 .in_scope(|| p3_util::transpose::transpose(&mat.values, coeffs, ncols, nrows));
293
294 self.update_twiddles(result_nrows);
296 let inv_twiddles = self.get_inv_twiddles();
297 debug_span!("inverse dft batch", n_dfts = ncols, fft_len = nrows)
298 .in_scope(|| Self::decimation_in_time_dft(coeffs, nrows, &inv_twiddles));
299
300 let log_rows = log2_ceil_usize(nrows);
305 let inv_len = MontyField31::ONE.div_2exp_u64(log_rows as u64);
306 coset_shift_and_scale_rows(&mut padded, result_nrows, coeffs, nrows, shift, inv_len);
307
308 let twiddles = self.get_twiddles();
312
313 debug_span!("dft batch", n_dfts = ncols, fft_len = result_nrows)
315 .in_scope(|| Self::decimation_in_freq_dft(&mut padded, result_nrows, &twiddles));
316
317 debug_span!("post-transpose", nrows = ncols, ncols = result_nrows)
319 .in_scope(|| p3_util::transpose::transpose(&padded, &mut output, result_nrows, ncols));
320
321 RowMajorMatrix::new(output, ncols).bit_reverse_rows()
322 }
323}