p3_dft/radix_2_dit.rs
1use alloc::collections::BTreeMap;
2use alloc::sync::Arc;
3
4use p3_field::{Field, TwoAdicField};
5use p3_matrix::Matrix;
6use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
7use p3_matrix::util::reverse_matrix_index_bits;
8use p3_maybe_rayon::prelude::*;
9use p3_util::log2_strict_usize;
10use spin::RwLock;
11
12use crate::TwoAdicSubgroupDft;
13use crate::butterflies::{Butterfly, DitButterfly, TwiddleFreeButterfly};
14
15/// Radix-2 Decimation-in-Time FFT over a two-adic subgroup.
16///
17/// This struct implements a fast Fourier transform (FFT) using the Radix-2
18/// Decimation-in-Time (DIT) algorithm over a two-adic multiplicative subgroup of a finite field.
19/// It is optimized for a batch setting where multiple FFT's are being computed simultaneously.
20///
21/// Internally, the implementation memoizes twiddle factors (powers of the root of unity)
22/// for reuse across multiple transforms. This avoids redundant computation
23/// when performing FFTs of the same size.
24#[derive(Default, Clone, Debug)]
25pub struct Radix2Dit<F: TwoAdicField> {
26 /// Memoized twiddle factors indexed by `log2(n)`, where `n` is the DFT length.
27 ///
28 /// This allows fast lookup and reuse of previously computed twiddle values
29 /// (powers of a two-adic generator), which are expensive to recompute.
30 ///
31 /// `RwLock` is used to enable interior mutability for caching purposes along with thread
32 /// safety.
33 twiddles: Arc<RwLock<BTreeMap<usize, Arc<[F]>>>>,
34}
35
36impl<F: TwoAdicField> Radix2Dit<F> {
37 /// Returns the twiddle factors for a DFT of size `2^log_h`.
38 /// If they haven't been computed yet, this function computes and caches them.
39 fn get_or_compute_twiddles(&self, log_h: usize) -> Arc<[F]> {
40 // Fast path: Check if the twiddles already exist with a read lock.
41 if let Some(twiddles) = self.twiddles.read().get(&log_h) {
42 return twiddles.clone();
43 }
44 // Slow path: The twiddles were not found. We need to compute them.
45 // Acquire a write lock to ensure only one thread computes and inserts the values.
46 let mut w_lock = self.twiddles.write();
47 // Double-check: Another thread might have computed and inserted the twiddles
48 // while we were waiting for the write lock. The `entry` API handles this
49 // check and insertion atomically.
50 w_lock
51 .entry(log_h)
52 .or_insert_with(|| {
53 let n = 1 << log_h;
54 let root = F::two_adic_generator(log_h);
55 Arc::from(root.powers().take(n).collect())
56 })
57 .clone()
58 }
59}
60
61impl<F: TwoAdicField> TwoAdicSubgroupDft<F> for Radix2Dit<F> {
62 type Evaluations = RowMajorMatrix<F>;
63
64 fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
65 let h = mat.height();
66 let log_h = log2_strict_usize(h);
67
68 // Compute twiddle factors, or take memoized ones if already available.
69 let twiddles = self.get_or_compute_twiddles(log_h);
70
71 // DIT butterfly
72 reverse_matrix_index_bits(&mut mat);
73 for layer in 0..log_h {
74 dit_layer(&mut mat.as_view_mut(), layer, &twiddles);
75 }
76 mat
77 }
78}
79
80/// Applies one layer of the Radix-2 DIT FFT butterfly network.
81///
82/// Splits the matrix into blocks of rows and performs in-place butterfly operations
83/// on each block. Uses a `TwiddleFreeButterfly` for the first pair and `DitButterfly`
84/// with precomputed twiddles for the rest.
85///
86/// # Arguments
87/// - `mat`: Mutable matrix view with height as a power of two.
88/// - `layer`: Index of the current FFT layer (starting at 0).
89/// - `twiddles`: Precomputed twiddle factors for this layer.
90fn dit_layer<F: Field>(mat: &mut RowMajorMatrixViewMut<'_, F>, layer: usize, twiddles: &[F]) {
91 // Get the number of rows in the matrix (must be a power of two)
92 let h = mat.height();
93 // Compute reversed layer index to access twiddle indices correctly
94 let log_h = log2_strict_usize(h);
95 let layer_rev = log_h - 1 - layer;
96
97 // Each butterfly operates on 2 rows; this is the number of rows in half a block
98 let half_block_size = 1 << layer;
99 // Each block contains 2^layer * 2 rows; full size of the butterfly block
100 let block_size = half_block_size * 2;
101
102 // Process the matrix in blocks of rows of size `block_size`
103 mat.par_row_chunks_exact_mut(block_size)
104 .for_each(|mut block_chunks| {
105 // Split each block vertically into top (hi) and bottom (lo) halves
106 let (mut hi_chunks, mut lo_chunks) = block_chunks.split_rows_mut(half_block_size);
107 // For each pair of rows (hi, lo), apply a butterfly
108 hi_chunks
109 .par_rows_mut()
110 .zip(lo_chunks.par_rows_mut())
111 .enumerate()
112 .for_each(|(ind, (hi_chunk, lo_chunk))| {
113 if ind == 0 {
114 // The first pair doesn't require a twiddle factor
115 TwiddleFreeButterfly.apply_to_rows(hi_chunk, lo_chunk);
116 } else {
117 // Apply DIT butterfly using the twiddle factor at index `ind << layer_rev`
118 DitButterfly(twiddles[ind << layer_rev]).apply_to_rows(hi_chunk, lo_chunk);
119 }
120 });
121 });
122}