p3_dft/
radix_2_dit.rs

1use alloc::collections::BTreeMap;
2use alloc::sync::Arc;
3
4use p3_field::{Field, TwoAdicField};
5use p3_matrix::Matrix;
6use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
7use p3_matrix::util::reverse_matrix_index_bits;
8use p3_maybe_rayon::prelude::*;
9use p3_util::log2_strict_usize;
10use spin::RwLock;
11
12use crate::TwoAdicSubgroupDft;
13use crate::butterflies::{Butterfly, DitButterfly, TwiddleFreeButterfly};
14
15/// Radix-2 Decimation-in-Time FFT over a two-adic subgroup.
16///
17/// This struct implements a fast Fourier transform (FFT) using the Radix-2
18/// Decimation-in-Time (DIT) algorithm over a two-adic multiplicative subgroup of a finite field.
19/// It is optimized for a batch setting where multiple FFT's are being computed simultaneously.
20///
21/// Internally, the implementation memoizes twiddle factors (powers of the root of unity)
22/// for reuse across multiple transforms. This avoids redundant computation
23/// when performing FFTs of the same size.
24#[derive(Default, Clone, Debug)]
25pub struct Radix2Dit<F: TwoAdicField> {
26    /// Memoized twiddle factors indexed by `log2(n)`, where `n` is the DFT length.
27    ///
28    /// This allows fast lookup and reuse of previously computed twiddle values
29    /// (powers of a two-adic generator), which are expensive to recompute.
30    ///
31    /// `RwLock` is used to enable interior mutability for caching purposes along with thread
32    /// safety.
33    twiddles: Arc<RwLock<BTreeMap<usize, Arc<[F]>>>>,
34}
35
36impl<F: TwoAdicField> Radix2Dit<F> {
37    /// Returns the twiddle factors for a DFT of size `2^log_h`.
38    /// If they haven't been computed yet, this function computes and caches them.
39    fn get_or_compute_twiddles(&self, log_h: usize) -> Arc<[F]> {
40        // Fast path: Check if the twiddles already exist with a read lock.
41        if let Some(twiddles) = self.twiddles.read().get(&log_h) {
42            return twiddles.clone();
43        }
44        // Slow path: The twiddles were not found. We need to compute them.
45        // Acquire a write lock to ensure only one thread computes and inserts the values.
46        let mut w_lock = self.twiddles.write();
47        // Double-check: Another thread might have computed and inserted the twiddles
48        // while we were waiting for the write lock. The `entry` API handles this
49        // check and insertion atomically.
50        w_lock
51            .entry(log_h)
52            .or_insert_with(|| {
53                let n = 1 << log_h;
54                let root = F::two_adic_generator(log_h);
55                Arc::from(root.powers().take(n).collect())
56            })
57            .clone()
58    }
59}
60
61impl<F: TwoAdicField> TwoAdicSubgroupDft<F> for Radix2Dit<F> {
62    type Evaluations = RowMajorMatrix<F>;
63
64    fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
65        let h = mat.height();
66        let log_h = log2_strict_usize(h);
67
68        // Compute twiddle factors, or take memoized ones if already available.
69        let twiddles = self.get_or_compute_twiddles(log_h);
70
71        // DIT butterfly
72        reverse_matrix_index_bits(&mut mat);
73        for layer in 0..log_h {
74            dit_layer(&mut mat.as_view_mut(), layer, &twiddles);
75        }
76        mat
77    }
78}
79
80/// Applies one layer of the Radix-2 DIT FFT butterfly network.
81///
82/// Splits the matrix into blocks of rows and performs in-place butterfly operations
83/// on each block. Uses a `TwiddleFreeButterfly` for the first pair and `DitButterfly`
84/// with precomputed twiddles for the rest.
85///
86/// # Arguments
87/// - `mat`: Mutable matrix view with height as a power of two.
88/// - `layer`: Index of the current FFT layer (starting at 0).
89/// - `twiddles`: Precomputed twiddle factors for this layer.
90fn dit_layer<F: Field>(mat: &mut RowMajorMatrixViewMut<'_, F>, layer: usize, twiddles: &[F]) {
91    // Get the number of rows in the matrix (must be a power of two)
92    let h = mat.height();
93    // Compute reversed layer index to access twiddle indices correctly
94    let log_h = log2_strict_usize(h);
95    let layer_rev = log_h - 1 - layer;
96
97    // Each butterfly operates on 2 rows; this is the number of rows in half a block
98    let half_block_size = 1 << layer;
99    // Each block contains 2^layer * 2 rows; full size of the butterfly block
100    let block_size = half_block_size * 2;
101
102    // Process the matrix in blocks of rows of size `block_size`
103    mat.par_row_chunks_exact_mut(block_size)
104        .for_each(|mut block_chunks| {
105            // Split each block vertically into top (hi) and bottom (lo) halves
106            let (mut hi_chunks, mut lo_chunks) = block_chunks.split_rows_mut(half_block_size);
107            // For each pair of rows (hi, lo), apply a butterfly
108            hi_chunks
109                .par_rows_mut()
110                .zip(lo_chunks.par_rows_mut())
111                .enumerate()
112                .for_each(|(ind, (hi_chunk, lo_chunk))| {
113                    if ind == 0 {
114                        // The first pair doesn't require a twiddle factor
115                        TwiddleFreeButterfly.apply_to_rows(hi_chunk, lo_chunk);
116                    } else {
117                        // Apply DIT butterfly using the twiddle factor at index `ind << layer_rev`
118                        DitButterfly(twiddles[ind << layer_rev]).apply_to_rows(hi_chunk, lo_chunk);
119                    }
120                });
121        });
122}