p3_dft/
radix_2_bowers.rs

1use alloc::vec::Vec;
2
3use p3_field::{Field, PrimeCharacteristicRing, TwoAdicField};
4use p3_matrix::Matrix;
5use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
6use p3_matrix::util::reverse_matrix_index_bits;
7use p3_maybe_rayon::prelude::*;
8use p3_util::{flatten_to_base, log2_strict_usize, reverse_slice_index_bits};
9use tracing::instrument;
10
11use crate::TwoAdicSubgroupDft;
12use crate::butterflies::{Butterfly, DifButterfly, DitButterfly, TwiddleFreeButterfly};
13use crate::util::divide_by_height;
14
/// The Bowers G FFT algorithm.
///
/// See: "Improved Twiddle Access for Fast Fourier Transforms"
///
/// This is a field-less unit type: it carries no configuration and no cached state, so
/// constructing it (via [`Default`]) or cloning it is free.
#[derive(Default, Clone, Debug)]
pub struct Radix2Bowers;
19
20impl<F: TwoAdicField> TwoAdicSubgroupDft<F> for Radix2Bowers {
21    type Evaluations = RowMajorMatrix<F>;
22
23    fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
24        reverse_matrix_index_bits(&mut mat);
25        bowers_g(&mut mat.as_view_mut());
26        mat
27    }
28
29    /// Compute the inverse DFT of each column in `mat`.
30    fn idft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
31        bowers_g_t(&mut mat.as_view_mut());
32        divide_by_height(&mut mat);
33        reverse_matrix_index_bits(&mut mat);
34        mat
35    }
36
37    fn lde_batch(&self, mut mat: RowMajorMatrix<F>, added_bits: usize) -> RowMajorMatrix<F> {
38        bowers_g_t(&mut mat.as_view_mut());
39        divide_by_height(&mut mat);
40        mat = mat.bit_reversed_zero_pad(added_bits);
41        bowers_g(&mut mat.as_view_mut());
42        mat
43    }
44
45    #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
46    fn coset_lde_batch(
47        &self,
48        mut mat: RowMajorMatrix<F>,
49        added_bits: usize,
50        shift: F,
51    ) -> RowMajorMatrix<F> {
52        let h = mat.height();
53        let log_h = log2_strict_usize(h);
54        // It's cheaper to use div_2exp_u64 as this usually avoids an inversion.
55        // It's also cheaper to work in the PrimeSubfield whenever possible.
56        let h_inv_subfield = F::PrimeSubfield::ONE.div_2exp_u64(log_h as u64);
57        let h_inv = F::from_prime_subfield(h_inv_subfield);
58
59        bowers_g_t(&mut mat.as_view_mut());
60
61        // Rescale coefficients in two ways:
62        // - divide by height (since we're doing an inverse DFT)
63        // - multiply by powers of the coset shift (see default coset LDE impl for an explanation)
64        let mut weights = shift.shifted_powers(h_inv).collect_n(h);
65        // reverse_bits because mat is encoded in bit-reversed order
66        reverse_slice_index_bits(&mut weights);
67
68        mat.par_rows_mut()
69            .zip(weights.into_par_iter())
70            .for_each(|(row, weight)| row.iter_mut().for_each(|elem| *elem *= weight));
71
72        mat = mat.bit_reversed_zero_pad(added_bits);
73
74        bowers_g(&mut mat.as_view_mut());
75
76        mat
77    }
78}
79
80/// Executes the Bowers G network. This is like a DFT, except it assumes the input is in
81/// bit-reversed order.
82fn bowers_g<F: TwoAdicField>(mat: &mut RowMajorMatrixViewMut<'_, F>) {
83    let h = mat.height();
84    let log_h = log2_strict_usize(h);
85
86    let root = F::two_adic_generator(log_h);
87    let mut twiddles = root.powers().collect_n(h / 2);
88    reverse_slice_index_bits(&mut twiddles);
89    // SAFETY: DifButterfly is `repr(transparent)`
90    let twiddles: Vec<DifButterfly<F>> = unsafe { flatten_to_base(twiddles) };
91
92    for log_half_block_size in 0..log_h {
93        butterfly_layer(mat, 1 << log_half_block_size, &twiddles);
94    }
95}
96
97/// Executes the Bowers G^T network. This is like an inverse DFT, except we skip rescaling by
98/// 1/height, and the output is bit-reversed.
99fn bowers_g_t<F: TwoAdicField>(mat: &mut RowMajorMatrixViewMut<'_, F>) {
100    let h = mat.height();
101    let log_h = log2_strict_usize(h);
102
103    let root_inv = F::two_adic_generator(log_h).inverse();
104    let mut twiddles = root_inv.powers().collect_n(h / 2);
105    reverse_slice_index_bits(&mut twiddles);
106    // SAFETY: DitButterfly is `repr(transparent)`
107    let twiddles: Vec<DitButterfly<F>> = unsafe { flatten_to_base(twiddles) };
108
109    for log_half_block_size in (0..log_h).rev() {
110        butterfly_layer(mat, 1 << log_half_block_size, &twiddles);
111    }
112}
113
114fn butterfly_layer<F: Field, B: Butterfly<F>>(
115    mat: &mut RowMajorMatrixViewMut<'_, F>,
116    half_block_size: usize,
117    twiddles: &[B],
118) {
119    mat.par_row_chunks_exact_mut(2 * half_block_size)
120        .enumerate()
121        .for_each(|(block, mut chunks)| {
122            let (mut hi_chunks, mut lo_chunks) = chunks.split_rows_mut(half_block_size);
123            hi_chunks
124                .par_rows_mut()
125                .zip(lo_chunks.par_rows_mut())
126                .for_each(|(hi_chunk, lo_chunk)| {
127                    if block == 0 {
128                        TwiddleFreeButterfly.apply_to_rows(hi_chunk, lo_chunk);
129                    } else {
130                        twiddles[block].apply_to_rows(hi_chunk, lo_chunk);
131                    }
132                });
133        });
134}