Skip to main content

p3_monty_31/dft/
mod.rs

1//! An implementation of the FFT for `MontyField31`
2extern crate alloc;
3
4use alloc::sync::Arc;
5use alloc::vec::Vec;
6
7use itertools::izip;
8use p3_dft::TwoAdicSubgroupDft;
9use p3_field::{Field, PrimeCharacteristicRing};
10use p3_matrix::Matrix;
11use p3_matrix::bitrev::{BitReversedMatrixView, BitReversibleMatrix};
12use p3_matrix::dense::RowMajorMatrix;
13use p3_maybe_rayon::prelude::*;
14use p3_util::{log2_ceil_usize, log2_strict_usize};
15use spin::RwLock;
16use tracing::{debug_span, instrument};
17
18mod backward;
19mod forward;
20
21use crate::{FieldParameters, MontyField31, MontyParameters, TwoAdicData};
22
23/// Multiply each element of column `j` of `mat` by `shift**j`.
24#[instrument(level = "debug", skip_all)]
25fn coset_shift_and_scale_rows<F: Field>(
26    out: &mut [F],
27    out_ncols: usize,
28    mat: &[F],
29    ncols: usize,
30    shift: F,
31    scale: F,
32) {
33    let powers = shift.shifted_powers(scale).collect_n(ncols);
34    out.par_chunks_exact_mut(out_ncols)
35        .zip(mat.par_chunks_exact(ncols))
36        .for_each(|(out_row, in_row)| {
37            izip!(out_row.iter_mut(), in_row, &powers).for_each(|(out, &coeff, &weight)| {
38                *out = coeff * weight;
39            });
40        });
41}
42
43/// Recursive DFT, decimation-in-frequency in the forward direction,
44/// decimation-in-time in the backward (inverse) direction.
45#[derive(Clone, Debug, Default)]
46pub struct RecursiveDft<F> {
47    /// Forward twiddle tables
48    #[allow(clippy::type_complexity)]
49    twiddles: Arc<RwLock<Arc<[Vec<F>]>>>,
50    /// Inverse twiddle tables
51    #[allow(clippy::type_complexity)]
52    inv_twiddles: Arc<RwLock<Arc<[Vec<F>]>>>,
53}
54
55impl<MP: FieldParameters + TwoAdicData> RecursiveDft<MontyField31<MP>> {
56    pub fn new(n: usize) -> Self {
57        let res = Self {
58            twiddles: Arc::default(),
59            inv_twiddles: Arc::default(),
60        };
61        res.update_twiddles(n);
62        res
63    }
64
65    #[inline]
66    fn decimation_in_freq_dft(
67        mat: &mut [MontyField31<MP>],
68        ncols: usize,
69        twiddles: &[Vec<MontyField31<MP>>],
70    ) {
71        if ncols > 1 {
72            let lg_fft_len = log2_strict_usize(ncols);
73            let twiddles = &twiddles[..(lg_fft_len - 1)];
74
75            mat.par_chunks_exact_mut(ncols)
76                .for_each(|v| MontyField31::forward_fft(v, twiddles));
77        }
78    }
79
80    #[inline]
81    fn decimation_in_time_dft(
82        mat: &mut [MontyField31<MP>],
83        ncols: usize,
84        twiddles: &[Vec<MontyField31<MP>>],
85    ) {
86        if ncols > 1 {
87            let lg_fft_len = p3_util::log2_strict_usize(ncols);
88            let twiddles = &twiddles[..(lg_fft_len - 1)];
89
90            mat.par_chunks_exact_mut(ncols)
91                .for_each(|v| MontyField31::backward_fft(v, twiddles));
92        }
93    }
94
95    /// Compute twiddle factors, or take memoized ones if already available.
96    #[instrument(skip_all)]
97    fn update_twiddles(&self, fft_len: usize) {
98        // As we don't save the twiddles for the final layer where
99        // the only twiddle is 1, roots_of_unity_table(fft_len)
100        // returns a vector of twiddles of length log_2(fft_len) - 1.
101        // let curr_max_fft_len = 2 << self.twiddles.read().len();
102        let need = log2_strict_usize(fft_len);
103        let snapshot = self.twiddles.read().clone();
104        let have = snapshot.len() + 1;
105        if have >= need {
106            return;
107        }
108
109        let missing_twiddles = MontyField31::get_missing_twiddles(need, have);
110
111        let missing_inv_twiddles = missing_twiddles
112            .iter()
113            .map(|ts| {
114                core::iter::once(MontyField31::ONE)
115                    .chain(
116                        ts[1..]
117                            .iter()
118                            .rev()
119                            .map(|&t| MontyField31::new_monty(MP::PRIME - t.value)),
120                    )
121                    .collect()
122            })
123            .collect::<Vec<_>>();
124        // Helper closure to extend a table under its lock.
125        let have_minus_one = have - 1;
126        let extend_table = |lock: &RwLock<Arc<[Vec<_>]>>, missing: &[Vec<_>]| {
127            let mut w = lock.write();
128            let current_len = w.len();
129            // Double-check if an update is still needed after acquiring the write lock.
130            if (current_len + 1) < need {
131                let mut v = w.to_vec();
132                // Append only the portion needed in case another thread did a partial update.
133                let extend_from = current_len.saturating_sub(have_minus_one);
134                v.extend_from_slice(&missing[extend_from..]);
135                *w = v.into();
136            }
137        };
138        // Atomically update each table. This two-step process is the source of the race condition.
139        extend_table(&self.twiddles, &missing_twiddles);
140        extend_table(&self.inv_twiddles, &missing_inv_twiddles);
141    }
142
143    fn get_twiddles(&self) -> Arc<[Vec<MontyField31<MP>>]> {
144        self.twiddles.read().clone()
145    }
146
147    fn get_inv_twiddles(&self) -> Arc<[Vec<MontyField31<MP>>]> {
148        self.inv_twiddles.read().clone()
149    }
150}
151
152/// DFT implementation that uses DIT for the inverse "backward"
153/// direction and DIF for the "forward" direction.
154///
155/// The API mandates that the LDE is applied column-wise on the
156/// _row-major_ input. This is awkward for memory coherence, so the
157/// algorithm here transposes the input and operates on the rows in
158/// the typical way, then transposes back again for the output. Even
159/// for modestly large inputs, the cost of the two transposes
160/// outweighed by the improved performance from operating row-wise.
161///
162/// The choice of DIT for inverse and DIF for "forward" transform mean
163/// that a (coset) LDE
164///
165/// - IDFT / zero extend / DFT
166///
167/// expands to
168///
169///   - bit-reverse input
170///   - invDFT DIT
171///     - result is in "correct" order
172///   - coset shift and zero extend result
173///   - DFT DIF on result
174///     - output is bit-reversed, as required for FRI.
175///
176/// Hence the only bit-reversal that needs to take place is on the input.
177///
178impl<MP: MontyParameters + FieldParameters + TwoAdicData> TwoAdicSubgroupDft<MontyField31<MP>>
179    for RecursiveDft<MontyField31<MP>>
180{
181    type Evaluations = BitReversedMatrixView<RowMajorMatrix<MontyField31<MP>>>;
182
183    #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
184    fn dft_batch(&self, mut mat: RowMajorMatrix<MontyField31<MP>>) -> Self::Evaluations
185    where
186        MP: MontyParameters + FieldParameters + TwoAdicData,
187    {
188        let nrows = mat.height();
189        let ncols = mat.width();
190
191        if nrows <= 1 {
192            return mat.bit_reverse_rows();
193        }
194
195        let mut scratch = debug_span!("allocate scratch space")
196            .in_scope(|| RowMajorMatrix::default(nrows, ncols));
197
198        self.update_twiddles(nrows);
199        let twiddles = self.get_twiddles();
200
201        // transpose input
202        debug_span!("pre-transpose", nrows, ncols).in_scope(|| {
203            p3_util::transpose::transpose(&mat.values, &mut scratch.values, ncols, nrows);
204        });
205
206        debug_span!("dft batch", n_dfts = ncols, fft_len = nrows)
207            .in_scope(|| Self::decimation_in_freq_dft(&mut scratch.values, nrows, &twiddles));
208
209        // transpose output
210        debug_span!("post-transpose", nrows = ncols, ncols = nrows).in_scope(|| {
211            p3_util::transpose::transpose(&scratch.values, &mut mat.values, nrows, ncols);
212        });
213
214        mat.bit_reverse_rows()
215    }
216
217    #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
218    fn idft_batch(&self, mat: RowMajorMatrix<MontyField31<MP>>) -> RowMajorMatrix<MontyField31<MP>>
219    where
220        MP: MontyParameters + FieldParameters + TwoAdicData,
221    {
222        let nrows = mat.height();
223        let ncols = mat.width();
224        if nrows <= 1 {
225            return mat;
226        }
227
228        let mut scratch = debug_span!("allocate scratch space")
229            .in_scope(|| RowMajorMatrix::default(nrows, ncols));
230
231        let mut mat =
232            debug_span!("initial bitrev").in_scope(|| mat.bit_reverse_rows().to_row_major_matrix());
233
234        self.update_twiddles(nrows);
235        let inv_twiddles = self.get_inv_twiddles();
236
237        // transpose input
238        debug_span!("pre-transpose", nrows, ncols).in_scope(|| {
239            p3_util::transpose::transpose(&mat.values, &mut scratch.values, ncols, nrows);
240        });
241
242        debug_span!("idft", n_dfts = ncols, fft_len = nrows)
243            .in_scope(|| Self::decimation_in_time_dft(&mut scratch.values, nrows, &inv_twiddles));
244
245        // transpose output
246        debug_span!("post-transpose", nrows = ncols, ncols = nrows).in_scope(|| {
247            p3_util::transpose::transpose(&scratch.values, &mut mat.values, nrows, ncols);
248        });
249
250        let log_rows = log2_ceil_usize(nrows);
251        let inv_len = MontyField31::ONE.div_2exp_u64(log_rows as u64);
252        debug_span!("scale").in_scope(|| mat.scale(inv_len));
253        mat
254    }
255
256    #[instrument(skip_all, level = "debug", fields(dims = %mat.dimensions(), added_bits))]
257    fn coset_lde_batch(
258        &self,
259        mat: RowMajorMatrix<MontyField31<MP>>,
260        added_bits: usize,
261        shift: MontyField31<MP>,
262    ) -> Self::Evaluations {
263        let nrows = mat.height();
264        let ncols = mat.width();
265        let result_nrows = nrows << added_bits;
266
267        if nrows == 1 {
268            let dupd_rows = core::iter::repeat_n(mat.values, result_nrows)
269                .flatten()
270                .collect();
271            return RowMajorMatrix::new(dupd_rows, ncols).bit_reverse_rows();
272        }
273
274        let input_size = nrows * ncols;
275        let output_size = result_nrows * ncols;
276
277        let mat = mat.bit_reverse_rows().to_row_major_matrix();
278
279        // Allocate space for the output and the intermediate state.
280        let (mut output, mut padded) = debug_span!("allocate scratch space").in_scope(|| {
281            // Safety: These are pretty dodgy, but work because MontyField31 is #[repr(transparent)]
282            let output = MontyField31::<MP>::zero_vec(output_size);
283            let padded = MontyField31::<MP>::zero_vec(output_size);
284            (output, padded)
285        });
286
287        // `coeffs` will hold the result of the inverse FFT; use the
288        // output storage as scratch space.
289        let coeffs = &mut output[..input_size];
290
291        debug_span!("pre-transpose", nrows, ncols)
292            .in_scope(|| p3_util::transpose::transpose(&mat.values, coeffs, ncols, nrows));
293
294        // Apply inverse DFT; result is not yet normalised.
295        self.update_twiddles(result_nrows);
296        let inv_twiddles = self.get_inv_twiddles();
297        debug_span!("inverse dft batch", n_dfts = ncols, fft_len = nrows)
298            .in_scope(|| Self::decimation_in_time_dft(coeffs, nrows, &inv_twiddles));
299
300        // At this point the inverse FFT of each column of `mat` appears
301        // as a row in `coeffs`.
302
303        // Normalise inverse DFT and coset shift in one go.
304        let log_rows = log2_ceil_usize(nrows);
305        let inv_len = MontyField31::ONE.div_2exp_u64(log_rows as u64);
306        coset_shift_and_scale_rows(&mut padded, result_nrows, coeffs, nrows, shift, inv_len);
307
308        // `padded` is implicitly zero padded since it was initialised
309        // to zeros when declared above.
310
311        let twiddles = self.get_twiddles();
312
313        // Apply DFT
314        debug_span!("dft batch", n_dfts = ncols, fft_len = result_nrows)
315            .in_scope(|| Self::decimation_in_freq_dft(&mut padded, result_nrows, &twiddles));
316
317        // transpose output
318        debug_span!("post-transpose", nrows = ncols, ncols = result_nrows)
319            .in_scope(|| p3_util::transpose::transpose(&padded, &mut output, result_nrows, ncols));
320
321        RowMajorMatrix::new(output, ncols).bit_reverse_rows()
322    }
323}