Skip to main content

p3_monty_31/dft/
mod.rs

1//! An implementation of the FFT for `MontyField31`
2extern crate alloc;
3
4use alloc::sync::Arc;
5use alloc::vec::Vec;
6
7use itertools::izip;
8use p3_dft::TwoAdicSubgroupDft;
9use p3_field::{Field, PrimeCharacteristicRing};
10use p3_matrix::Matrix;
11use p3_matrix::bitrev::{BitReversedMatrixView, BitReversibleMatrix};
12use p3_matrix::dense::RowMajorMatrix;
13use p3_maybe_rayon::prelude::*;
14use p3_util::{log2_ceil_usize, log2_strict_usize};
15use spin::RwLock;
16use tracing::{debug_span, instrument};
17
18mod backward;
19mod forward;
20
21use crate::{FieldParameters, MontyField31, MontyParameters, TwoAdicData};
22
23/// Multiply each element of column `j` of `mat` by `shift**j`.
24#[instrument(level = "debug", skip_all)]
25fn coset_shift_and_scale_rows<F: Field>(
26    out: &mut [F],
27    out_ncols: usize,
28    mat: &[F],
29    ncols: usize,
30    shift: F,
31    scale: F,
32) {
33    debug_assert!(out.len().is_multiple_of(out_ncols));
34    debug_assert!(mat.len().is_multiple_of(ncols));
35    debug_assert!(out_ncols >= ncols);
36    debug_assert_eq!(out.len() / out_ncols, mat.len() / ncols);
37    let powers = shift.shifted_powers(scale).collect_n(ncols);
38    out.par_chunks_exact_mut(out_ncols)
39        .zip(mat.par_chunks_exact(ncols))
40        .for_each(|(out_row, in_row)| {
41            izip!(out_row.iter_mut(), in_row, &powers).for_each(|(out, &coeff, &weight)| {
42                *out = coeff * weight;
43            });
44        });
45}
46
/// Forward twiddle tables bundled with the matching inverse tables.
///
/// The two lists live in one value so that they can be replaced together
/// under a single lock; an observer never sees one list updated without
/// the other.
#[derive(Clone, Debug)]
struct TwiddlePair<F> {
    /// Tables of twiddle factors.
    twiddles: Arc<[Vec<F>]>,
    /// Tables of inverse twiddle factors, kept in step with `twiddles`.
    inv_twiddles: Arc<[Vec<F>]>,
}

/// Implemented by hand; this impl places no `Default` bound on `F`.
impl<F> Default for TwiddlePair<F> {
    /// Returns a pair with no tables on either side.
    fn default() -> Self {
        let empty = || Arc::<[Vec<F>]>::from(Vec::new());
        Self {
            twiddles: empty(),
            inv_twiddles: empty(),
        }
    }
}
64
/// Recursive DFT, decimation-in-frequency in the forward direction,
/// decimation-in-time in the backward (inverse) direction.
///
/// The only state is a twiddle cache held behind an `Arc`, so values
/// produced by the derived `Clone` share one cache with the original.
#[derive(Clone, Debug, Default)]
pub struct RecursiveDft<F> {
    /// Memoized twiddle factors, paired with their inverses.
    ///
    /// Both tables are stored behind a single lock so they are always
    /// updated atomically.
    cache: Arc<RwLock<TwiddlePair<F>>>,
}
75
76impl<MP: FieldParameters + TwoAdicData> RecursiveDft<MontyField31<MP>> {
77    pub fn new(n: usize) -> Self {
78        let res = Self::default();
79        res.update_twiddles(n);
80        res
81    }
82
83    #[inline]
84    fn decimation_in_freq_dft(
85        mat: &mut [MontyField31<MP>],
86        ncols: usize,
87        twiddles: &[Vec<MontyField31<MP>>],
88    ) {
89        if ncols > 1 {
90            let lg_fft_len = log2_strict_usize(ncols);
91            let twiddles = &twiddles[..(lg_fft_len - 1)];
92
93            mat.par_chunks_exact_mut(ncols)
94                .for_each(|v| MontyField31::forward_fft(v, twiddles));
95        }
96    }
97
98    #[inline]
99    fn decimation_in_time_dft(
100        mat: &mut [MontyField31<MP>],
101        ncols: usize,
102        twiddles: &[Vec<MontyField31<MP>>],
103    ) {
104        if ncols > 1 {
105            let lg_fft_len = p3_util::log2_strict_usize(ncols);
106            let twiddles = &twiddles[..(lg_fft_len - 1)];
107
108            mat.par_chunks_exact_mut(ncols)
109                .for_each(|v| MontyField31::backward_fft(v, twiddles));
110        }
111    }
112
113    /// Compute twiddle factors, or take memoized ones if already available.
114    #[instrument(skip_all)]
115    fn update_twiddles(&self, fft_len: usize) {
116        // As we don't save the twiddles for the final layer where
117        // the only twiddle is 1, roots_of_unity_table(fft_len)
118        // returns a vector of twiddles of length log_2(fft_len) - 1.
119        let need = log2_strict_usize(fft_len);
120
121        // Fast path: read lock to check if we already have enough.
122        let have = self.cache.read().twiddles.len() + 1;
123        if have >= need {
124            return;
125        }
126
127        let missing_twiddles = MontyField31::get_missing_twiddles(need, have);
128
129        let missing_inv_twiddles = missing_twiddles
130            .iter()
131            .map(|ts| {
132                core::iter::once(MontyField31::ONE)
133                    .chain(
134                        ts[1..]
135                            .iter()
136                            .rev()
137                            .map(|&t| MontyField31::new_monty(MP::PRIME - t.value)),
138                    )
139                    .collect()
140            })
141            .collect::<Vec<_>>();
142
143        // Slow path: acquire write lock and update both tables atomically.
144        let have_minus_one = have - 1;
145        let mut cache = self.cache.write();
146        let current_len = cache.twiddles.len();
147        // Double-check if an update is still needed after acquiring the write lock.
148        if (current_len + 1) < need {
149            let extend_from = current_len.saturating_sub(have_minus_one);
150
151            let mut tw = cache.twiddles.to_vec();
152            tw.extend_from_slice(&missing_twiddles[extend_from..]);
153
154            let mut inv_tw = cache.inv_twiddles.to_vec();
155            inv_tw.extend_from_slice(&missing_inv_twiddles[extend_from..]);
156
157            cache.twiddles = tw.into();
158            cache.inv_twiddles = inv_tw.into();
159        }
160    }
161
162    fn get_twiddles(&self) -> Arc<[Vec<MontyField31<MP>>]> {
163        self.cache.read().twiddles.clone()
164    }
165
166    fn get_inv_twiddles(&self) -> Arc<[Vec<MontyField31<MP>>]> {
167        self.cache.read().inv_twiddles.clone()
168    }
169}
170
171/// DFT implementation that uses DIT for the inverse "backward"
172/// direction and DIF for the "forward" direction.
173///
174/// The API mandates that the LDE is applied column-wise on the
175/// _row-major_ input. This is awkward for memory coherence, so the
176/// algorithm here transposes the input and operates on the rows in
177/// the typical way, then transposes back again for the output. Even
178/// for modestly large inputs, the cost of the two transposes
/// is outweighed by the improved performance from operating row-wise.
180///
/// The choice of DIT for inverse and DIF for "forward" transform means
182/// that a (coset) LDE
183///
184/// - IDFT / zero extend / DFT
185///
186/// expands to
187///
188///   - bit-reverse input
189///   - invDFT DIT
190///     - result is in "correct" order
191///   - coset shift and zero extend result
192///   - DFT DIF on result
193///     - output is bit-reversed, as required for FRI.
194///
195/// Hence the only bit-reversal that needs to take place is on the input.
196///
197impl<MP: MontyParameters + FieldParameters + TwoAdicData> TwoAdicSubgroupDft<MontyField31<MP>>
198    for RecursiveDft<MontyField31<MP>>
199{
200    type Evaluations = BitReversedMatrixView<RowMajorMatrix<MontyField31<MP>>>;
201
202    #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
203    fn dft_batch(&self, mut mat: RowMajorMatrix<MontyField31<MP>>) -> Self::Evaluations
204    where
205        MP: MontyParameters + FieldParameters + TwoAdicData,
206    {
207        let nrows = mat.height();
208        let ncols = mat.width();
209
210        if nrows <= 1 {
211            return mat.bit_reverse_rows();
212        }
213
214        let mut scratch = debug_span!("allocate scratch space")
215            .in_scope(|| RowMajorMatrix::default(nrows, ncols));
216
217        self.update_twiddles(nrows);
218        let twiddles = self.get_twiddles();
219
220        // transpose input
221        debug_span!("pre-transpose", nrows, ncols).in_scope(|| {
222            p3_util::transpose::transpose(&mat.values, &mut scratch.values, ncols, nrows);
223        });
224
225        debug_span!("dft batch", n_dfts = ncols, fft_len = nrows)
226            .in_scope(|| Self::decimation_in_freq_dft(&mut scratch.values, nrows, &twiddles));
227
228        // transpose output
229        debug_span!("post-transpose", nrows = ncols, ncols = nrows).in_scope(|| {
230            p3_util::transpose::transpose(&scratch.values, &mut mat.values, nrows, ncols);
231        });
232
233        mat.bit_reverse_rows()
234    }
235
236    #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
237    fn idft_batch(&self, mat: RowMajorMatrix<MontyField31<MP>>) -> RowMajorMatrix<MontyField31<MP>>
238    where
239        MP: MontyParameters + FieldParameters + TwoAdicData,
240    {
241        let nrows = mat.height();
242        let ncols = mat.width();
243        if nrows <= 1 {
244            return mat;
245        }
246
247        let mut scratch = debug_span!("allocate scratch space")
248            .in_scope(|| RowMajorMatrix::default(nrows, ncols));
249
250        let mut mat =
251            debug_span!("initial bitrev").in_scope(|| mat.bit_reverse_rows().to_row_major_matrix());
252
253        self.update_twiddles(nrows);
254        let inv_twiddles = self.get_inv_twiddles();
255
256        // transpose input
257        debug_span!("pre-transpose", nrows, ncols).in_scope(|| {
258            p3_util::transpose::transpose(&mat.values, &mut scratch.values, ncols, nrows);
259        });
260
261        debug_span!("idft", n_dfts = ncols, fft_len = nrows)
262            .in_scope(|| Self::decimation_in_time_dft(&mut scratch.values, nrows, &inv_twiddles));
263
264        // transpose output
265        debug_span!("post-transpose", nrows = ncols, ncols = nrows).in_scope(|| {
266            p3_util::transpose::transpose(&scratch.values, &mut mat.values, nrows, ncols);
267        });
268
269        let log_rows = log2_ceil_usize(nrows);
270        let inv_len = MontyField31::ONE.div_2exp_u64(log_rows as u64);
271        debug_span!("scale").in_scope(|| mat.scale(inv_len));
272        mat
273    }
274
275    #[instrument(skip_all, level = "debug", fields(dims = %mat.dimensions(), added_bits))]
276    fn coset_lde_batch(
277        &self,
278        mat: RowMajorMatrix<MontyField31<MP>>,
279        added_bits: usize,
280        shift: MontyField31<MP>,
281    ) -> Self::Evaluations {
282        let nrows = mat.height();
283        let ncols = mat.width();
284        let result_nrows = nrows << added_bits;
285
286        if nrows == 1 {
287            let dupd_rows = core::iter::repeat_n(mat.values, result_nrows)
288                .flatten()
289                .collect();
290            return RowMajorMatrix::new(dupd_rows, ncols).bit_reverse_rows();
291        }
292
293        let input_size = nrows * ncols;
294        let output_size = result_nrows * ncols;
295
296        let mat = mat.bit_reverse_rows().to_row_major_matrix();
297
298        // Allocate space for the output and the intermediate state.
299        let (mut output, mut padded) = debug_span!("allocate scratch space").in_scope(|| {
300            // Safety: These are pretty dodgy, but work because MontyField31 is #[repr(transparent)]
301            let output = MontyField31::<MP>::zero_vec(output_size);
302            let padded = MontyField31::<MP>::zero_vec(output_size);
303            (output, padded)
304        });
305
306        // `coeffs` will hold the result of the inverse FFT; use the
307        // output storage as scratch space.
308        let coeffs = &mut output[..input_size];
309
310        debug_span!("pre-transpose", nrows, ncols)
311            .in_scope(|| p3_util::transpose::transpose(&mat.values, coeffs, ncols, nrows));
312
313        // Apply inverse DFT; result is not yet normalised.
314        self.update_twiddles(result_nrows);
315        let inv_twiddles = self.get_inv_twiddles();
316        debug_span!("inverse dft batch", n_dfts = ncols, fft_len = nrows)
317            .in_scope(|| Self::decimation_in_time_dft(coeffs, nrows, &inv_twiddles));
318
319        // At this point the inverse FFT of each column of `mat` appears
320        // as a row in `coeffs`.
321
322        // Normalise inverse DFT and coset shift in one go.
323        let log_rows = log2_ceil_usize(nrows);
324        let inv_len = MontyField31::ONE.div_2exp_u64(log_rows as u64);
325        coset_shift_and_scale_rows(&mut padded, result_nrows, coeffs, nrows, shift, inv_len);
326
327        // `padded` is implicitly zero padded since it was initialised
328        // to zeros when declared above.
329
330        let twiddles = self.get_twiddles();
331
332        // Apply DFT
333        debug_span!("dft batch", n_dfts = ncols, fft_len = result_nrows)
334            .in_scope(|| Self::decimation_in_freq_dft(&mut padded, result_nrows, &twiddles));
335
336        // transpose output
337        debug_span!("post-transpose", nrows = ncols, ncols = result_nrows)
338            .in_scope(|| p3_util::transpose::transpose(&padded, &mut output, result_nrows, ncols));
339
340        RowMajorMatrix::new(output, ncols).bit_reverse_rows()
341    }
342}