p3_monty_31/dft/
mod.rs

1//! An implementation of the FFT for `MontyField31`
2extern crate alloc;
3
4use alloc::sync::Arc;
5use alloc::vec::Vec;
6
7use itertools::izip;
8use p3_dft::TwoAdicSubgroupDft;
9use p3_field::{Field, PrimeCharacteristicRing};
10use p3_matrix::Matrix;
11use p3_matrix::bitrev::{BitReversedMatrixView, BitReversibleMatrix};
12use p3_matrix::dense::RowMajorMatrix;
13use p3_maybe_rayon::prelude::*;
14use p3_util::{log2_ceil_usize, log2_strict_usize};
15use spin::RwLock;
16use tracing::{debug_span, instrument};
17
18mod backward;
19mod forward;
20
21use crate::{FieldParameters, MontyField31, MontyParameters, TwoAdicData};
22
23/// Multiply each element of column `j` of `mat` by `shift**j`.
24#[instrument(level = "debug", skip_all)]
25fn coset_shift_and_scale_rows<F: Field>(
26    out: &mut [F],
27    out_ncols: usize,
28    mat: &[F],
29    ncols: usize,
30    shift: F,
31    scale: F,
32) {
33    let powers = shift.shifted_powers(scale).collect_n(ncols);
34    out.par_chunks_exact_mut(out_ncols)
35        .zip(mat.par_chunks_exact(ncols))
36        .for_each(|(out_row, in_row)| {
37            izip!(out_row.iter_mut(), in_row, &powers).for_each(|(out, &coeff, &weight)| {
38                *out = coeff * weight;
39            });
40        });
41}
42
/// A twiddle table shared between clones: the outer `Arc` shares the lock,
/// the inner `Arc` lets readers take a cheap snapshot of the current layers.
type TwiddleTable<F> = Arc<RwLock<Arc<[Vec<F>]>>>;

/// Recursive DFT that uses decimation-in-frequency for the forward
/// transform and decimation-in-time for the backward (inverse) transform.
///
/// Twiddle factors are memoized in two lock-protected tables. Because the
/// tables sit behind `Arc`s, cloning this struct shares them rather than
/// copying them.
#[derive(Clone, Debug, Default)]
pub struct RecursiveDft<F> {
    /// Twiddle layers used by the forward transform.
    twiddles: TwiddleTable<F>,
    /// Twiddle layers used by the inverse transform.
    inv_twiddles: TwiddleTable<F>,
}
54
55impl<MP: FieldParameters + TwoAdicData> RecursiveDft<MontyField31<MP>> {
56    pub fn new(n: usize) -> Self {
57        let res = Self {
58            twiddles: Arc::default(),
59            inv_twiddles: Arc::default(),
60        };
61        res.update_twiddles(n);
62        res
63    }
64
65    #[inline]
66    fn decimation_in_freq_dft(
67        mat: &mut [MontyField31<MP>],
68        ncols: usize,
69        twiddles: &[Vec<MontyField31<MP>>],
70    ) {
71        if ncols > 1 {
72            let lg_fft_len = log2_strict_usize(ncols);
73            let twiddles = &twiddles[..(lg_fft_len - 1)];
74
75            mat.par_chunks_exact_mut(ncols)
76                .for_each(|v| MontyField31::forward_fft(v, twiddles));
77        }
78    }
79
80    #[inline]
81    fn decimation_in_time_dft(
82        mat: &mut [MontyField31<MP>],
83        ncols: usize,
84        twiddles: &[Vec<MontyField31<MP>>],
85    ) {
86        if ncols > 1 {
87            let lg_fft_len = p3_util::log2_strict_usize(ncols);
88            let twiddles = &twiddles[..(lg_fft_len - 1)];
89
90            mat.par_chunks_exact_mut(ncols)
91                .for_each(|v| MontyField31::backward_fft(v, twiddles));
92        }
93    }
94
95    /// Compute twiddle factors, or take memoized ones if already available.
96    #[instrument(skip_all)]
97    fn update_twiddles(&self, fft_len: usize) {
98        // As we don't save the twiddles for the final layer where
99        // the only twiddle is 1, roots_of_unity_table(fft_len)
100        // returns a vector of twiddles of length log_2(fft_len) - 1.
101        // let curr_max_fft_len = 2 << self.twiddles.read().len();
102        let need = log2_strict_usize(fft_len);
103        let snapshot = self.twiddles.read().clone();
104        let have = snapshot.len() + 1;
105        if have >= need {
106            return;
107        }
108
109        let missing_twiddles = MontyField31::get_missing_twiddles(need, have);
110
111        let missing_inv_twiddles = missing_twiddles
112            .iter()
113            .map(|ts| {
114                core::iter::once(MontyField31::ONE)
115                    .chain(
116                        ts[1..]
117                            .iter()
118                            .rev()
119                            .map(|&t| MontyField31::new_monty(MP::PRIME - t.value)),
120                    )
121                    .collect()
122            })
123            .collect::<Vec<_>>();
124        // Helper closure to extend a table under its lock.
125        let extend_table = |lock: &RwLock<Arc<[Vec<_>]>>, missing: &[Vec<_>]| {
126            let mut w = lock.write();
127            let current_len = w.len();
128            // Double-check if an update is still needed after acquiring the write lock.
129            if (current_len + 1) < need {
130                let mut v = w.to_vec();
131                // Append only the portion needed in case another thread did a partial update.
132                let extend_from = current_len.saturating_sub(current_len);
133                v.extend_from_slice(&missing[extend_from..]);
134                *w = v.into();
135            }
136        };
137        // Atomically update each table. This two-step process is the source of the race condition.
138        extend_table(&self.twiddles, &missing_twiddles);
139        extend_table(&self.inv_twiddles, &missing_inv_twiddles);
140    }
141
142    fn get_twiddles(&self) -> Arc<[Vec<MontyField31<MP>>]> {
143        self.twiddles.read().clone()
144    }
145
146    fn get_inv_twiddles(&self) -> Arc<[Vec<MontyField31<MP>>]> {
147        self.inv_twiddles.read().clone()
148    }
149}
150
151/// DFT implementation that uses DIT for the inverse "backward"
152/// direction and DIF for the "forward" direction.
153///
154/// The API mandates that the LDE is applied column-wise on the
155/// _row-major_ input. This is awkward for memory coherence, so the
156/// algorithm here transposes the input and operates on the rows in
157/// the typical way, then transposes back again for the output. Even
/// for modestly large inputs, the cost of the two transposes is
/// outweighed by the improved performance from operating row-wise.
160///
/// The choice of DIT for the inverse and DIF for the "forward" transform
/// means that a (coset) LDE
163///
164/// - IDFT / zero extend / DFT
165///
166/// expands to
167///
168///   - bit-reverse input
169///   - invDFT DIT
170///     - result is in "correct" order
171///   - coset shift and zero extend result
172///   - DFT DIF on result
173///     - output is bit-reversed, as required for FRI.
174///
175/// Hence the only bit-reversal that needs to take place is on the input.
176///
177impl<MP: MontyParameters + FieldParameters + TwoAdicData> TwoAdicSubgroupDft<MontyField31<MP>>
178    for RecursiveDft<MontyField31<MP>>
179{
180    type Evaluations = BitReversedMatrixView<RowMajorMatrix<MontyField31<MP>>>;
181
182    #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
183    fn dft_batch(&self, mut mat: RowMajorMatrix<MontyField31<MP>>) -> Self::Evaluations
184    where
185        MP: MontyParameters + FieldParameters + TwoAdicData,
186    {
187        let nrows = mat.height();
188        let ncols = mat.width();
189
190        if nrows <= 1 {
191            return mat.bit_reverse_rows();
192        }
193
194        let mut scratch = debug_span!("allocate scratch space")
195            .in_scope(|| RowMajorMatrix::default(nrows, ncols));
196
197        self.update_twiddles(nrows);
198        let twiddles = self.get_twiddles();
199
200        // transpose input
201        debug_span!("pre-transpose", nrows, ncols)
202            .in_scope(|| transpose::transpose(&mat.values, &mut scratch.values, ncols, nrows));
203
204        debug_span!("dft batch", n_dfts = ncols, fft_len = nrows)
205            .in_scope(|| Self::decimation_in_freq_dft(&mut scratch.values, nrows, &twiddles));
206
207        // transpose output
208        debug_span!("post-transpose", nrows = ncols, ncols = nrows)
209            .in_scope(|| transpose::transpose(&scratch.values, &mut mat.values, nrows, ncols));
210
211        mat.bit_reverse_rows()
212    }
213
214    #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
215    fn idft_batch(&self, mat: RowMajorMatrix<MontyField31<MP>>) -> RowMajorMatrix<MontyField31<MP>>
216    where
217        MP: MontyParameters + FieldParameters + TwoAdicData,
218    {
219        let nrows = mat.height();
220        let ncols = mat.width();
221        if nrows <= 1 {
222            return mat;
223        }
224
225        let mut scratch = debug_span!("allocate scratch space")
226            .in_scope(|| RowMajorMatrix::default(nrows, ncols));
227
228        let mut mat =
229            debug_span!("initial bitrev").in_scope(|| mat.bit_reverse_rows().to_row_major_matrix());
230
231        self.update_twiddles(nrows);
232        let inv_twiddles = self.get_inv_twiddles();
233
234        // transpose input
235        debug_span!("pre-transpose", nrows, ncols)
236            .in_scope(|| transpose::transpose(&mat.values, &mut scratch.values, ncols, nrows));
237
238        debug_span!("idft", n_dfts = ncols, fft_len = nrows)
239            .in_scope(|| Self::decimation_in_time_dft(&mut scratch.values, nrows, &inv_twiddles));
240
241        // transpose output
242        debug_span!("post-transpose", nrows = ncols, ncols = nrows)
243            .in_scope(|| transpose::transpose(&scratch.values, &mut mat.values, nrows, ncols));
244
245        let log_rows = log2_ceil_usize(nrows);
246        let inv_len = MontyField31::ONE.div_2exp_u64(log_rows as u64);
247        debug_span!("scale").in_scope(|| mat.scale(inv_len));
248        mat
249    }
250
251    #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
252    fn coset_lde_batch(
253        &self,
254        mat: RowMajorMatrix<MontyField31<MP>>,
255        added_bits: usize,
256        shift: MontyField31<MP>,
257    ) -> Self::Evaluations {
258        let nrows = mat.height();
259        let ncols = mat.width();
260        let result_nrows = nrows << added_bits;
261
262        if nrows == 1 {
263            let dupd_rows = core::iter::repeat_n(mat.values, result_nrows)
264                .flatten()
265                .collect();
266            return RowMajorMatrix::new(dupd_rows, ncols).bit_reverse_rows();
267        }
268
269        let input_size = nrows * ncols;
270        let output_size = result_nrows * ncols;
271
272        let mat = mat.bit_reverse_rows().to_row_major_matrix();
273
274        // Allocate space for the output and the intermediate state.
275        let (mut output, mut padded) = debug_span!("allocate scratch space").in_scope(|| {
276            // Safety: These are pretty dodgy, but work because MontyField31 is #[repr(transparent)]
277            let output = MontyField31::<MP>::zero_vec(output_size);
278            let padded = MontyField31::<MP>::zero_vec(output_size);
279            (output, padded)
280        });
281
282        // `coeffs` will hold the result of the inverse FFT; use the
283        // output storage as scratch space.
284        let coeffs = &mut output[..input_size];
285
286        debug_span!("pre-transpose", nrows, ncols)
287            .in_scope(|| transpose::transpose(&mat.values, coeffs, ncols, nrows));
288
289        // Apply inverse DFT; result is not yet normalised.
290        self.update_twiddles(result_nrows);
291        let inv_twiddles = self.get_inv_twiddles();
292        debug_span!("inverse dft batch", n_dfts = ncols, fft_len = nrows)
293            .in_scope(|| Self::decimation_in_time_dft(coeffs, nrows, &inv_twiddles));
294
295        // At this point the inverse FFT of each column of `mat` appears
296        // as a row in `coeffs`.
297
298        // Normalise inverse DFT and coset shift in one go.
299        let log_rows = log2_ceil_usize(nrows);
300        let inv_len = MontyField31::ONE.div_2exp_u64(log_rows as u64);
301        coset_shift_and_scale_rows(&mut padded, result_nrows, coeffs, nrows, shift, inv_len);
302
303        // `padded` is implicitly zero padded since it was initialised
304        // to zeros when declared above.
305
306        let twiddles = self.get_twiddles();
307
308        // Apply DFT
309        debug_span!("dft batch", n_dfts = ncols, fft_len = result_nrows)
310            .in_scope(|| Self::decimation_in_freq_dft(&mut padded, result_nrows, &twiddles));
311
312        // transpose output
313        debug_span!("post-transpose", nrows = ncols, ncols = result_nrows)
314            .in_scope(|| transpose::transpose(&padded, &mut output, result_nrows, ncols));
315
316        RowMajorMatrix::new(output, ncols).bit_reverse_rows()
317    }
318}