Skip to main content

p3_dft/
radix_2_bowers.rs

1use alloc::vec::Vec;
2
3use p3_field::{Field, PrimeCharacteristicRing, TwoAdicField};
4use p3_matrix::Matrix;
5use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
6use p3_matrix::util::reverse_matrix_index_bits;
7use p3_maybe_rayon::prelude::*;
8use p3_util::{flatten_to_base, log2_strict_usize, reverse_slice_index_bits};
9use tracing::instrument;
10
11use crate::butterflies::{Butterfly, DifButterfly, DitButterfly, TwiddleFreeButterfly};
12use crate::util::divide_by_height;
13use crate::{Layout, TwoAdicSubgroupDft};
14
/// The Bowers G FFT algorithm.
///
/// See: "Improved Twiddle Access for Fast Fourier Transforms"
///
/// This is a stateless, zero-sized marker type: the transform itself is provided by the
/// `TwoAdicSubgroupDft` implementation for this type. `Debug` is derived so that the type
/// can be logged and embedded in other `#[derive(Debug)]` types.
#[derive(Default, Clone, Debug)]
pub struct Radix2Bowers;
19
20impl<F: TwoAdicField> TwoAdicSubgroupDft<F> for Radix2Bowers {
21    type Evaluations = RowMajorMatrix<F>;
22
23    fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
24        reverse_matrix_index_bits(&mut mat);
25        bowers_g(&mut mat.as_view_mut());
26        mat
27    }
28
29    /// Compute the inverse DFT of each column in `mat`.
30    fn idft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
31        bowers_g_t(&mut mat.as_view_mut());
32        divide_by_height(&mut mat);
33        reverse_matrix_index_bits(&mut mat);
34        mat
35    }
36
37    fn lde_batch(&self, mut mat: RowMajorMatrix<F>, added_bits: usize) -> RowMajorMatrix<F> {
38        bowers_g_t(&mut mat.as_view_mut());
39        divide_by_height(&mut mat);
40        mat = mat.bit_reversed_zero_pad(added_bits);
41        bowers_g(&mut mat.as_view_mut());
42        mat
43    }
44
45    #[instrument(skip_all, level = "debug", fields(dims = %mat.dimensions(), added_bits))]
46    fn coset_lde_batch(
47        &self,
48        mut mat: RowMajorMatrix<F>,
49        added_bits: usize,
50        shift: F,
51    ) -> RowMajorMatrix<F> {
52        let h = mat.height();
53        let log_h = log2_strict_usize(h);
54        // It's cheaper to use div_2exp_u64 as this usually avoids an inversion.
55        // It's also cheaper to work in the PrimeSubfield whenever possible.
56        let h_inv_subfield = F::PrimeSubfield::ONE.div_2exp_u64(log_h as u64);
57        let h_inv = F::from_prime_subfield(h_inv_subfield);
58
59        bowers_g_t(&mut mat.as_view_mut());
60
61        // Rescale coefficients in two ways:
62        // - divide by height (since we're doing an inverse DFT)
63        // - multiply by powers of the coset shift (see default coset LDE impl for an explanation)
64        let mut weights = shift.shifted_powers(h_inv).collect_n(h);
65        // reverse_bits because mat is encoded in bit-reversed order
66        reverse_slice_index_bits(&mut weights);
67
68        mat.par_rows_mut()
69            .zip(weights.into_par_iter())
70            .for_each(|(row, weight)| row.iter_mut().for_each(|elem| *elem *= weight));
71
72        mat = mat.bit_reversed_zero_pad(added_bits);
73
74        bowers_g(&mut mat.as_view_mut());
75
76        mat
77    }
78
79    #[instrument(skip_all, level = "debug", fields(dims = %mat.dimensions(), added_bits))]
80    fn coset_lde_batch_with_transform<T>(
81        &self,
82        mut mat: RowMajorMatrix<F>,
83        added_bits: usize,
84        shift: F,
85        transform: T,
86    ) -> RowMajorMatrix<F>
87    where
88        T: FnOnce(&mut RowMajorMatrixViewMut<'_, F>, Layout),
89    {
90        let h = mat.height();
91        let log_h = log2_strict_usize(h);
92        let h_inv_subfield = F::PrimeSubfield::ONE.div_2exp_u64(log_h as u64);
93        let h_inv = F::from_prime_subfield(h_inv_subfield);
94
95        // Inverse butterfly leaves coefficients in bit-reversed memory.
96        bowers_g_t(&mut mat.as_view_mut());
97
98        // Normalise to true polynomial coefficients before invoking the closure.
99        // (Costs one extra O(h) pass relative to the no-op `coset_lde_batch` path,
100        // which fuses normalisation with the coset-shift weights.)
101        mat.values.par_iter_mut().for_each(|v| *v *= h_inv);
102
103        transform(&mut mat.as_view_mut(), Layout::BitReversed);
104
105        // Apply coset-shift weights in bit-reversed order.
106        let mut shift_powers = shift.powers().collect_n(h);
107        reverse_slice_index_bits(&mut shift_powers);
108        mat.par_rows_mut()
109            .zip(shift_powers.into_par_iter())
110            .for_each(|(row, sp)| row.iter_mut().for_each(|e| *e *= sp));
111
112        mat = mat.bit_reversed_zero_pad(added_bits);
113        bowers_g(&mut mat.as_view_mut());
114        mat
115    }
116}
117
118/// Executes the Bowers G network. This is like a DFT, except it assumes the input is in
119/// bit-reversed order.
120fn bowers_g<F: TwoAdicField>(mat: &mut RowMajorMatrixViewMut<'_, F>) {
121    let h = mat.height();
122    let log_h = log2_strict_usize(h);
123
124    let root = F::two_adic_generator(log_h);
125    let mut twiddles = root.powers().collect_n(h / 2);
126    reverse_slice_index_bits(&mut twiddles);
127    // SAFETY: DifButterfly is `repr(transparent)`
128    let twiddles: Vec<DifButterfly<F>> = unsafe { flatten_to_base(twiddles) };
129
130    for log_half_block_size in 0..log_h {
131        butterfly_layer(mat, 1 << log_half_block_size, &twiddles);
132    }
133}
134
135/// Executes the Bowers G^T network. This is like an inverse DFT, except we skip rescaling by
136/// 1/height, and the output is bit-reversed.
137fn bowers_g_t<F: TwoAdicField>(mat: &mut RowMajorMatrixViewMut<'_, F>) {
138    let h = mat.height();
139    let log_h = log2_strict_usize(h);
140
141    let root_inv = F::two_adic_generator(log_h).inverse();
142    let mut twiddles = root_inv.powers().collect_n(h / 2);
143    reverse_slice_index_bits(&mut twiddles);
144    // SAFETY: DitButterfly is `repr(transparent)`
145    let twiddles: Vec<DitButterfly<F>> = unsafe { flatten_to_base(twiddles) };
146
147    for log_half_block_size in (0..log_h).rev() {
148        butterfly_layer(mat, 1 << log_half_block_size, &twiddles);
149    }
150}
151
152fn butterfly_layer<F: Field, B: Butterfly<F>>(
153    mat: &mut RowMajorMatrixViewMut<'_, F>,
154    half_block_size: usize,
155    twiddles: &[B],
156) {
157    mat.par_row_chunks_exact_mut(2 * half_block_size)
158        .enumerate()
159        .for_each(|(block, mut chunks)| {
160            let (mut hi_chunks, mut lo_chunks) = chunks.split_rows_mut(half_block_size);
161            hi_chunks
162                .par_rows_mut()
163                .zip(lo_chunks.par_rows_mut())
164                .for_each(|(hi_chunk, lo_chunk)| {
165                    if block == 0 {
166                        TwiddleFreeButterfly.apply_to_rows(hi_chunk, lo_chunk);
167                    } else {
168                        twiddles[block].apply_to_rows(hi_chunk, lo_chunk);
169                    }
170                });
171        });
172}