p3_dft/
radix_2_small_batch.rs

1//! An FFT implementation optimized for small batch sizes.
2
3use alloc::sync::Arc;
4use alloc::vec::Vec;
5use core::iter;
6
7use itertools::Itertools;
8use p3_field::{Field, TwoAdicField, scale_slice_in_place_single_core};
9use p3_matrix::Matrix;
10use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
11use p3_matrix::util::reverse_matrix_index_bits;
12use p3_maybe_rayon::prelude::*;
13use p3_util::{as_base_slice, log2_strict_usize, reverse_slice_index_bits};
14use spin::RwLock;
15
16use crate::{
17    Butterfly, DifButterfly, DifButterflyZeros, DitButterfly, TwiddleFreeButterfly,
18    TwoAdicSubgroupDft,
19};
20
/// The number of layers to compute in each parallelization.
///
/// NOTE(review): the code in this file hard-codes groups of three: the layer loops call
/// `dft_layer_par_triple` on three-element `.tuples()`, and `dft_layer_par_extra_layers`
/// only handles 0, 1 or 2 leftover layers. This value must therefore stay `3` unless that
/// code is changed as well.
const LAYERS_PER_GROUP: usize = 3;
23
24/// An FFT algorithm which divides a butterfly network's layers into two halves.
25///
26/// Unlike other FFT algorithms, this algorithm is optimized for small batch sizes.
27/// It also stores its twiddle factors and only re-computes if it is asked to do a
28/// larger FFT.
29///
30/// Instead of parallelizing across rows, this algorithm parallelizes across groups of rows
31/// with the same twiddle factors. This allows it to make use of field packings far more than
32/// the standard methods even for low width matrices. Once the chunk size is small enough, it
33/// computes a large set of layers fully on a single thread, which avoids the overhead of
34/// passing data between threads.
35#[derive(Default, Clone, Debug)]
36pub struct Radix2DFTSmallBatch<F> {
37    /// Memoized twiddle factors for each length log_n.
38    ///
39    /// For each `i`, `twiddles[i]` contains a list of twiddles stored in
40    /// bit reversed order. The final set of twiddles `twiddles[-1]` is the
41    /// one element vectors `[1]` and more general `twiddles[-i]` has length `2^i`.
42    #[allow(clippy::type_complexity)]
43    twiddles: Arc<RwLock<Arc<[Vec<F>]>>>,
44
45    /// Similar to `twiddles`, but stored the inverses used for the inverse fft.
46    #[allow(clippy::type_complexity)]
47    inv_twiddles: Arc<RwLock<Arc<[Vec<F>]>>>,
48}
49
50impl<F: TwoAdicField> Radix2DFTSmallBatch<F> {
51    /// Create a new `Radix2DFTSmallBatch` instance with precomputed twiddles for the given size.
52    ///
53    /// The input `n` should be a power of two, representing the maximal FFT size you expect to handle.
54    pub fn new(n: usize) -> Self {
55        let res = Self::default();
56        res.update_twiddles(n);
57        res
58    }
59
60    /// Given a field element `gen` of order n where `n = 2^lg_n`,
61    /// return a vector of vectors `table` where table[i] is the
62    /// vector of twiddle factors for an fft of length n/2^i. The
63    /// values g_i^k for k >= i/2 are skipped as these are just the
64    /// negatives of the other roots (using g_i^{i/2} = -1). The
65    /// value gen^0 = 1 is included to aid consistency between the
66    /// packed and non-packed variants.
67    fn roots_of_unity_table(&self, n: usize) -> Vec<Vec<F>> {
68        let lg_n = log2_strict_usize(n);
69        let generator = F::two_adic_generator(lg_n);
70        let half_n = 1 << (lg_n - 1);
71        // nth_roots = [1, g, g^2, g^3, ..., g^{n/2 - 1}]
72        let nth_roots = generator.powers().collect_n(half_n);
73
74        (0..lg_n)
75            .map(|i| nth_roots.iter().step_by(1 << i).copied().collect())
76            .collect()
77    }
78
79    /// Compute twiddle and inv_twiddle factors, or take memoized ones if already available.
80    fn update_twiddles(&self, fft_len: usize) {
81        // TODO: This recomputes the entire table from scratch if we
82        // need it to be larger, which is wasteful.
83
84        // roots_of_unity_table(fft_len) returns a vector of twiddles of length log_2(fft_len).
85        let curr_max_fft_len = 1 << self.twiddles.read().len();
86        if fft_len > curr_max_fft_len {
87            let mut new_twiddles = self.roots_of_unity_table(fft_len);
88            let mut new_inv_twiddles: Vec<Vec<F>> = new_twiddles
89                .iter()
90                .map(|ts| {
91                    // The first twiddle is still one, instead of inverting, we can
92                    // just reverse and negate.
93                    iter::once(F::ONE)
94                        .chain(ts[1..].iter().rev().map(|&f| -f))
95                        .collect()
96                })
97                .collect();
98
99            new_twiddles.iter_mut().for_each(|ts| {
100                reverse_slice_index_bits(ts);
101            });
102            new_inv_twiddles.iter_mut().for_each(|ts| {
103                reverse_slice_index_bits(ts);
104            });
105
106            {
107                let mut tw_lock = self.twiddles.write();
108                let cur_have = 1usize << tw_lock.len();
109                if fft_len > cur_have {
110                    *tw_lock = Arc::from(new_twiddles); // move in the new table
111                }
112            }
113            {
114                let mut inv_tw_lock = self.inv_twiddles.write();
115                let cur_have = 1usize << inv_tw_lock.len();
116                if fft_len > cur_have {
117                    *inv_tw_lock = Arc::from(new_inv_twiddles); // move in the new table
118                }
119            }
120        }
121    }
122}
123
124impl<F> TwoAdicSubgroupDft<F> for Radix2DFTSmallBatch<F>
125where
126    F: TwoAdicField,
127{
128    type Evaluations = RowMajorMatrix<F>;
129
130    fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> Self::Evaluations {
131        let h = mat.height();
132        let w = mat.width();
133        let log_h = log2_strict_usize(h);
134
135        self.update_twiddles(h);
136        let g = self.twiddles.read().clone(); // Lock is dropped immediately
137        let root_table = &g[g.len() - log_h..];
138
139        // The strategy will be to do a standard round-by-round parallelization
140        // until the chunk size is smaller than `num_par_rows * mat.width()` after which we
141        // send `num_par_rows` chunks to each thread and do the remainder of the
142        // fft without transferring any more data between threads.
143        let num_par_rows = estimate_num_rows_in_l1::<F>(h, w);
144        let log_num_par_rows = log2_strict_usize(num_par_rows);
145        let chunk_size = num_par_rows * w;
146
147        // For the layers involving blocks larger than `num_par_rows`, we will
148        // parallelize across the blocks.
149
150        let multi_layer_dit = MultiLayerDitButterfly {};
151
152        // We do `LAYERS_PER_GROUP` layers of the DFT at once, to minimize how much data we need to transfer
153        // between threads.
154        for (dit_0, dit_1, dit_2) in root_table[log_num_par_rows..]
155            .iter()
156            .rev()
157            .map(|slice| unsafe { as_base_slice::<DitButterfly<F>, F>(slice) }) // Safe as DitButterfly is #[repr(transparent)]
158            .tuples()
159        {
160            dft_layer_par_triple(&mut mat.as_view_mut(), dit_0, dit_1, dit_2, multi_layer_dit);
161        }
162
163        // If the total number of layers is not a multiple of `LAYERS_PER_GROUP`,
164        // we need to handle the remaining layers separately.
165        let corr = (log_h - log_num_par_rows) % LAYERS_PER_GROUP;
166        dft_layer_par_extra_layers(
167            &mut mat.as_view_mut(),
168            &root_table[log_num_par_rows..log_num_par_rows + corr],
169            multi_layer_dit,
170        );
171
172        // Once the blocks are small enough, we can split the matrix
173        // into chunks of size `chunk_size` and process them in parallel.
174        // This avoids passing data between threads, which can be expensive.
175        par_remaining_layers(&mut mat.values, chunk_size, &root_table[..log_num_par_rows]);
176
177        // Finally we bit-reverse the matrix to ensure the output is in the correct order.
178        reverse_matrix_index_bits(&mut mat);
179        mat
180    }
181
182    fn idft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
183        let h = mat.height();
184        let w = mat.width();
185        let log_h = log2_strict_usize(h);
186
187        self.update_twiddles(h);
188        let g = self.inv_twiddles.read().clone(); // Lock is dropped immediately
189        let start = g
190            .len()
191            .checked_sub(log_h)
192            .expect("log_h exceeds inv_twiddles length");
193        let root_table = &g[start..];
194
195        // Find the number of rows which can roughly fit in L1 cache.
196        // The strategy is the same as `dft_batch` but in reverse.
197        // We start by moving `num_par_rows` rows onto each thread and doing
198        // `num_par_rows` layers of the DFT. After this we recombine and do
199        // a standard round-by-round parallelization for the remaining layers.
200        let num_par_rows = estimate_num_rows_in_l1::<F>(h, w);
201        let log_num_par_rows = log2_strict_usize(num_par_rows);
202        let chunk_size = num_par_rows * w;
203
204        // Need to start by bit-reversing the matrix.
205        reverse_matrix_index_bits(&mut mat);
206
207        // For the initial blocks, they are small enough that we can split the matrix
208        // into chunks of size `chunk_size` and process them in parallel.
209        // This avoids passing data between threads, which can be expensive.
210        // We also divide by the height of the matrix while the data is nicely partitioned
211        // on each core.
212        par_initial_layers(
213            &mut mat.values,
214            chunk_size,
215            &root_table[..log_num_par_rows],
216            log_h,
217        );
218
219        // For the layers involving blocks larger than `num_par_rows`, we will
220        // parallelize across the blocks.
221
222        let multi_layer_dif = MultiLayerDifButterfly {};
223
224        // If the total number of layers is not a multiple of `LAYERS_PER_GROUP`,
225        // we need to handle the initial layers separately.
226        let corr = (log_h - log_num_par_rows) % LAYERS_PER_GROUP;
227        dft_layer_par_extra_layers(
228            &mut mat.as_view_mut(),
229            &root_table[log_num_par_rows..log_num_par_rows + corr],
230            multi_layer_dif,
231        );
232
233        // We do `LAYERS_PER_GROUP` layers of the DFT at once, to minimize how much data we need to transfer
234        // between threads.
235        for (dif_0, dif_1, dif_2) in root_table[(log_num_par_rows + corr)..]
236            .iter()
237            .map(|slice| unsafe { as_base_slice::<DifButterfly<F>, F>(slice) }) // Safe as DifButterfly is #[repr(transparent)]
238            .tuples()
239        {
240            dft_layer_par_triple(&mut mat.as_view_mut(), dif_2, dif_1, dif_0, multi_layer_dif);
241        }
242
243        mat
244    }
245
246    fn coset_lde_batch(
247        &self,
248        mut mat: RowMajorMatrix<F>,
249        added_bits: usize,
250        shift: F,
251    ) -> Self::Evaluations {
252        let h = mat.height();
253        let w = mat.width();
254        let log_h = log2_strict_usize(h);
255
256        self.update_twiddles(h << added_bits);
257        let g = self.twiddles.read().clone(); // Lock is dropped immediately
258        let start = g
259            .len()
260            .checked_sub(log_h + added_bits)
261            .expect("log_h exceeds twiddles length");
262        let root_table = &g[start..];
263        let g = self.inv_twiddles.read().clone(); // Lock is dropped immediately
264        let start = g
265            .len()
266            .checked_sub(log_h)
267            .expect("log_h exceeds inv_twiddles length");
268        let inv_root_table = &g[start..];
269        let output_height = h << added_bits;
270
271        // The matrix which we will use to store the output.
272        let output_values = F::zero_vec(output_height * w);
273        let mut out = RowMajorMatrix::new(output_values, w);
274
275        // The strategy is reasonably straightforward.
276        // The rough idea is we want to squash together the dft and idft code.
277
278        // This lets us do all of the inner layers on a single thread reducing the amount
279        // of data we need to transfer.
280
281        // For technical reasons, we need to swap the twiddle factors, using the inverse
282        // twiddles for the initial layers and the normal twiddles for the final layers.
283        // This lets us interpret the initial transformation as the idft giving us coefficients
284        // and the final transformation as the dft giving us evaluations.
285
286        // Find the number of rows which can roughly fit in L1 cache.
287        // The strategy will be to do a standard round-by-round parallelization
288        // until the chunk size is smaller than `num_par_rows * mat.width()` after which we
289        // send `num_par_rows` chunks to each thread and do the remainder of the
290        // fft without transferring any more data between threads.
291        let num_par_rows = estimate_num_rows_in_l1::<F>(h, w);
292        let num_inner_dit_layers = log2_strict_usize(num_par_rows);
293        let num_inner_dif_layers = num_inner_dit_layers + added_bits;
294
295        // We will do large DFT/iDFT layers in batches of `LAYERS_PER_GROUP`. We start with
296        // the dit layers.
297        let multi_layer_dit = MultiLayerDitButterfly {};
298        for (dit_0, dit_1, dit_2) in inv_root_table[num_inner_dit_layers..]
299            .iter()
300            .rev()
301            .map(|slice| unsafe { as_base_slice::<DitButterfly<F>, F>(slice) }) // Safe as DitButterfly is #[repr(transparent)]
302            .tuples()
303        {
304            dft_layer_par_triple(&mut mat.as_view_mut(), dit_0, dit_1, dit_2, multi_layer_dit);
305        }
306
307        // If the total number of layers is not a multiple of `LAYERS_PER_GROUP`,
308        // we need to handle the remaining layers separately.
309        let corr = (log_h - num_inner_dit_layers) % LAYERS_PER_GROUP;
310        dft_layer_par_extra_layers(
311            &mut mat.as_view_mut(),
312            &inv_root_table[num_inner_dit_layers..num_inner_dit_layers + corr],
313            multi_layer_dit,
314        );
315
316        // Now do all the inner layers at once. This does the final `log_num_par_rows` of
317        // the initial transformation, then copies the values of mat to output, scales then
318        // and does the first `log_num_par_rows + added_bits` layers of the final transformation.
319        par_middle_layers(
320            &mut mat.as_view_mut(),
321            &mut out.as_view_mut(),
322            num_par_rows,
323            &root_table[..(num_inner_dif_layers)],
324            &inv_root_table[..num_inner_dit_layers],
325            added_bits,
326            shift,
327        );
328
329        // We are left with the final dif layers.
330        let multi_layer_dif = MultiLayerDifButterfly {};
331
332        // If the total number of layers is not a multiple of `LAYERS_PER_GROUP`,
333        // we need to handle the remaining layers separately.
334        dft_layer_par_extra_layers(
335            &mut out.as_view_mut(),
336            &root_table[num_inner_dif_layers..num_inner_dif_layers + corr],
337            multi_layer_dif,
338        );
339
340        // We do `LAYERS_PER_GROUP` layers of the DFT at once, to minimize how much data we need to transfer
341        // between threads.
342        for (dif_0, dif_1, dif_2) in root_table[(num_inner_dif_layers + corr)..]
343            .iter()
344            .map(|slice| unsafe { as_base_slice::<DifButterfly<F>, F>(slice) }) // Safe as DifButterfly is #[repr(transparent)]
345            .tuples()
346        {
347            dft_layer_par_triple(&mut out.as_view_mut(), dif_2, dif_1, dif_0, multi_layer_dif);
348        }
349
350        out
351    }
352}
353
354/// Applies one layer of the Radix-2 DIF FFT butterfly network making use of parallelization.
355///
356/// Splits the matrix into blocks of rows and performs in-place butterfly operations
357/// on each block. Uses a `TwiddleFreeButterfly` for the first pair and `DifButterfly`
358/// with precomputed twiddles for the rest.
359///
360/// Each block is processed in parallel, if the blocks are large enough they themselves
361/// are split into parallel sub-blocks.
362///
363/// # Arguments
364/// - `mat`: Mutable matrix whose height is a power of two.
365/// - `twiddles`: Precomputed twiddle factors for this layer.
366#[inline]
367fn dft_layer_par<F: Field, B: Butterfly<F>>(
368    mat: &mut RowMajorMatrixViewMut<'_, F>,
369    twiddles: &[B],
370) {
371    debug_assert!(
372        mat.height().is_multiple_of(twiddles.len()),
373        "Matrix height must be divisible by the number of twiddles"
374    );
375    let size = mat.values.len();
376    let num_blocks = twiddles.len();
377
378    let outer_block_size = size / num_blocks;
379    let half_outer_block_size = outer_block_size / 2;
380
381    mat.values
382        .par_chunks_exact_mut(outer_block_size)
383        .enumerate()
384        .for_each(|(ind, block)| {
385            // Split each block vertically into top (hi) and bottom (lo) halves
386            let (hi_chunk, lo_chunk) = block.split_at_mut(half_outer_block_size);
387
388            // If num_blocks is small, we probably are not using all available threads.
389            let num_threads = current_num_threads();
390            let inner_block_size = size / (2 * num_blocks).max(num_threads);
391
392            hi_chunk
393                .par_chunks_mut(inner_block_size)
394                .zip(lo_chunk.par_chunks_mut(inner_block_size))
395                .for_each(|(hi_chunk, lo_chunk)| {
396                    if ind == 0 {
397                        // The first pair doesn't require a twiddle factor
398                        TwiddleFreeButterfly.apply_to_rows(hi_chunk, lo_chunk);
399                    } else {
400                        // Apply DIT butterfly using the twiddle factor at index `ind - 1`
401                        twiddles[ind].apply_to_rows(hi_chunk, lo_chunk);
402                    }
403                });
404        });
405}
406
407/// Splits the matrix into chunks of size `chunk_size` and performs
408/// the remaining layers of the FFT in parallel on each chunk.
409///
410/// This avoids passing data between threads, which can be expensive.
411#[inline]
412fn par_remaining_layers<F: Field>(mat: &mut [F], chunk_size: usize, root_table: &[Vec<F>]) {
413    mat.par_chunks_exact_mut(chunk_size)
414        .enumerate()
415        .for_each(|(index, chunk)| {
416            remaining_layers(chunk, root_table, index);
417        });
418}
419
420/// Performs a collection of DIT layers on a chunk of the matrix.
421fn remaining_layers<F: Field>(chunk: &mut [F], root_table: &[Vec<F>], index: usize) {
422    for (layer, twiddles) in root_table.iter().rev().enumerate() {
423        let num_twiddles_per_block = 1 << layer;
424        let start = index * num_twiddles_per_block;
425        let twiddle_range = start..(start + num_twiddles_per_block);
426        // Safe as DitButterfly is #[repr(transparent)]
427        let dit_twiddles: &[DitButterfly<F>] = unsafe { as_base_slice(&twiddles[twiddle_range]) };
428        dft_layer(chunk, dit_twiddles);
429    }
430}
431
432/// Splits the matrix into chunks of size `chunk_size` and performs
433/// the initial layers of the iFFT in parallel on each chunk.
434///
435/// This avoids passing data between threads, which can be expensive.
436///
437/// Basically identical to [par_remaining_layers] but in reverse and we
438/// also divide by the height.
439#[inline]
440fn par_initial_layers<F: Field>(
441    mat: &mut [F],
442    chunk_size: usize,
443    root_table: &[Vec<F>],
444    log_height: usize,
445) {
446    let inv_height = F::ONE.div_2exp_u64(log_height as u64);
447    mat.par_chunks_exact_mut(chunk_size)
448        .enumerate()
449        .for_each(|(index, chunk)| {
450            // Divide all elements by the height of the matrix.
451            scale_slice_in_place_single_core(chunk, inv_height);
452            initial_layers(chunk, root_table, index);
453        });
454}
455
456/// Performs a collection of DIF layers on a chunk of the matrix.
457#[inline]
458fn initial_layers<F: Field>(chunk: &mut [F], root_table: &[Vec<F>], index: usize) {
459    let num_rounds = root_table.len();
460
461    for (layer, twiddles) in root_table.iter().enumerate() {
462        let num_twiddles_per_block = 1 << (num_rounds - layer - 1);
463        let start = index * num_twiddles_per_block;
464        let twiddle_range = start..(start + num_twiddles_per_block);
465        // Safe as DifButterfly is #[repr(transparent)]
466        let dif_twiddles: &[DifButterfly<F>] = unsafe { as_base_slice(&twiddles[twiddle_range]) };
467        dft_layer(chunk, dif_twiddles);
468    }
469}
470
471/// Splits the matrix into chunks of size `chunk_size` and performs
472/// the middle layers of a coset_lde in parallel on each chunk.
473///
474/// Similar to [par_remaining_layers] followed by [par_initial_layers]
475/// with a scaling and copying operation in between.
476fn par_middle_layers<F: Field>(
477    in_mat: &mut RowMajorMatrixViewMut<'_, F>,
478    out_mat: &mut RowMajorMatrixViewMut<'_, F>,
479    num_par_rows: usize,
480    root_table: &[Vec<F>],
481    inv_root_table: &[Vec<F>],
482    added_bits: usize,
483    shift: F,
484) {
485    debug_assert_eq!(in_mat.width(), out_mat.width());
486    debug_assert_eq!(in_mat.height() << added_bits, out_mat.height());
487
488    let width = in_mat.width();
489    let height = in_mat.height();
490    let num_rounds = root_table.len();
491    let in_chunk_size = num_par_rows * width;
492    let out_chunk_size = in_chunk_size << added_bits;
493
494    let log_height = log2_strict_usize(height);
495    let inv_height = F::ONE.div_2exp_u64(log_height as u64);
496
497    let mut scaling = shift.shifted_powers(inv_height).collect_n(height);
498    reverse_slice_index_bits(&mut scaling);
499
500    in_mat
501        .values
502        .par_chunks_exact_mut(in_chunk_size)
503        .zip(out_mat.values.par_chunks_exact_mut(out_chunk_size))
504        .zip(scaling.par_chunks_exact_mut(num_par_rows))
505        .enumerate()
506        .for_each(|(index, ((in_chunk, out_chunk), scaling))| {
507            remaining_layers(in_chunk, inv_root_table, index);
508
509            // Copy the values to the output matrix and scale appropriately.
510            in_chunk
511                .chunks_exact(width)
512                .zip(scaling)
513                .zip(out_chunk.chunks_exact_mut(width << added_bits))
514                .for_each(|((in_row, scale), out_row)| {
515                    out_row
516                        .iter_mut()
517                        .zip(in_row.iter())
518                        .for_each(|(out_val, in_val)| {
519                            *out_val = *in_val * *scale;
520                        });
521                });
522
523            // We can do something cheaper than standard DFT layers for the first `added_bits` layers.
524            // as there are a lot of zeroes in the out_chunk.
525            for (layer, twiddles) in root_table[..added_bits].iter().enumerate() {
526                let num_twiddles_per_block = 1 << (num_rounds - layer - 1);
527                let start = index * num_twiddles_per_block;
528                let twiddle_range = start..(start + num_twiddles_per_block);
529
530                // Safe as DifButterflyZeros is #[repr(transparent)]
531                let dif_twiddles_zeros: &[DifButterflyZeros<F>] =
532                    unsafe { as_base_slice(&twiddles[twiddle_range]) };
533                dft_layer_zeros(out_chunk, dif_twiddles_zeros, added_bits - layer - 1);
534            }
535
536            initial_layers(out_chunk, &root_table[added_bits..], index);
537        });
538}
539
540/// Applies one layer of the Radix-2 FFT butterfly network on a single core.
541///
542/// Splits the matrix into blocks of rows and performs in-place butterfly operations
543/// on each block.
544///
545/// # Arguments
546/// - `vec`: Mutable vector whose height is a power of two.
547/// - `twiddles`: Precomputed twiddle factors for this layer.
548#[inline]
549fn dft_layer<F: Field, B: Butterfly<F>>(vec: &mut [F], twiddles: &[B]) {
550    debug_assert_eq!(
551        vec.len() % twiddles.len(),
552        0,
553        "Vector length must be divisible by the number of twiddles"
554    );
555    let size = vec.len();
556    let num_blocks = twiddles.len();
557
558    let block_size = size / num_blocks;
559    let half_block_size = block_size / 2;
560
561    vec.chunks_exact_mut(block_size)
562        .zip(twiddles)
563        .for_each(|(block, &twiddle)| {
564            // Split each block vertically into top (hi) and bottom (lo) halves
565            let (hi_chunk, lo_chunk) = block.split_at_mut(half_block_size);
566
567            // Apply DIT butterfly
568            twiddle.apply_to_rows(hi_chunk, lo_chunk);
569        });
570}
571
572/// Applies two layers of the Radix-2 FFT butterfly network making use of parallelization.
573///
574/// Splits the matrix into blocks of rows and performs in-place butterfly operations
575/// on each block. Advantage of doing two layers at once is it reduces the amount of
576/// data transferred between threads.
577///
578/// # Arguments
579/// - `mat`: Mutable matrix whose height is a power of two.
580/// - `twiddles_small`: Precomputed twiddle factors for the layer with the smallest block size.
581/// - `twiddles_large`: Precomputed twiddle factors for the layer with the largest block size.
582/// - `multi_butterfly`: Multi-layer butterfly which applies the two layers in the correct order.
583#[inline]
584fn dft_layer_par_double<F: Field, B: Butterfly<F>, M: MultiLayerButterfly<F, B>>(
585    mat: &mut RowMajorMatrixViewMut<'_, F>,
586    twiddles_small: &[B],
587    twiddles_large: &[B],
588    multi_butterfly: M,
589) {
590    debug_assert!(
591        mat.height().is_multiple_of(twiddles_small.len()),
592        "Matrix height must be divisible by the number of twiddles"
593    );
594    let size = mat.values.len();
595    let num_blocks = twiddles_small.len();
596
597    let outer_block_size = size / num_blocks;
598    let quarter_outer_block_size = outer_block_size / 4;
599
600    // Estimate the optimal size of the inner chunks so that all data fits in L1 cache.
601    // Note that 4 inner chunks are processed in each parallel thread so we divide by 4.
602    let inner_chunk_size =
603        (workload_size::<F>().next_power_of_two() / 4).min(quarter_outer_block_size);
604
605    mat.values
606        .par_chunks_exact_mut(outer_block_size)
607        .enumerate()
608        .for_each(|(ind, block)| {
609            // Split each block into four quarters. Each quarter will be further split into
610            // sub-chunks processed in parallel.
611            let chunk_par_iters_0 = block
612                .chunks_exact_mut(quarter_outer_block_size)
613                .map(|chunk| chunk.par_chunks_mut(inner_chunk_size))
614                .collect::<Vec<_>>();
615            let chunk_par_iters_1 = zip_par_iter_vec(chunk_par_iters_0);
616            chunk_par_iters_1.into_iter().tuples().for_each(|(hi, lo)| {
617                hi.zip(lo).for_each(|chunks| {
618                    multi_butterfly.apply_2_layers(chunks, ind, twiddles_small, twiddles_large);
619                });
620            });
621        });
622}
623
624/// Applies three layers of a Radix-2 FFT butterfly network making use of parallelization.
625///
626/// Splits the matrix into blocks of rows and performs in-place butterfly operations
627/// on each block. Advantage of doing three layers at once is it reduces the amount of
628/// data transferred between threads.
629///
630/// # Arguments
631/// - `mat`: Mutable matrix whose height is a power of two.
632/// - `twiddles_small`: Precomputed twiddle factors for the layer with the smallest block size.
633/// - `twiddles_med`: Precomputed twiddle factors for the middle layer.
634/// - `twiddles_large`: Precomputed twiddle factors for the layer with the largest block size.
635/// - `multi_butterfly`: Multi-layer butterfly which applies the three layers in the correct order.
636#[inline]
637fn dft_layer_par_triple<F: Field, B: Butterfly<F>, M: MultiLayerButterfly<F, B>>(
638    mat: &mut RowMajorMatrixViewMut<'_, F>,
639    twiddles_small: &[B],
640    twiddles_med: &[B],
641    twiddles_large: &[B],
642    multi_butterfly: M,
643) {
644    debug_assert!(
645        mat.height().is_multiple_of(twiddles_small.len()),
646        "Matrix height must be divisible by the number of twiddles"
647    );
648    let size = mat.values.len();
649    let num_blocks = twiddles_small.len();
650
651    let outer_block_size = size / num_blocks;
652    let eighth_outer_block_size = outer_block_size / 8;
653
654    // Estimate the optimal size of the inner chunks so that all data fits in L1 cache.
655    // Note that 8 inner chunks are processed in each parallel thread so we divide by 8.
656    let inner_chunk_size =
657        (workload_size::<F>().next_power_of_two() / 8).min(eighth_outer_block_size);
658
659    mat.values
660        .par_chunks_exact_mut(outer_block_size)
661        .enumerate()
662        .for_each(|(ind, block)| {
663            // Split each block into eight equal parts. Each part will be further split into
664            // sub-chunks processed in parallel.
665            let chunk_par_iters_0 = block
666                .chunks_exact_mut(eighth_outer_block_size)
667                .map(|chunk| chunk.par_chunks_mut(inner_chunk_size))
668                .collect::<Vec<_>>();
669            let chunk_par_iters_1 = zip_par_iter_vec(chunk_par_iters_0);
670            let chunk_par_iters_2 = zip_par_iter_vec(chunk_par_iters_1);
671            chunk_par_iters_2.into_iter().tuples().for_each(|(hi, lo)| {
672                hi.zip(lo).for_each(|chunks| {
673                    multi_butterfly.apply_3_layers(
674                        chunks,
675                        ind,
676                        twiddles_small,
677                        twiddles_med,
678                        twiddles_large,
679                    );
680                });
681            });
682        });
683}
684
685/// Applies the remaining layers of the Radix-2 FFT butterfly network in parallel.
686///
687/// This function is used to correct for the fact that the total number of layers
688/// may not be a multiple of `LAYERS_PER_GROUP`.
689fn dft_layer_par_extra_layers<F: Field, B: Butterfly<F>, M: MultiLayerButterfly<F, B>>(
690    mat: &mut RowMajorMatrixViewMut<'_, F>,
691    root_table: &[Vec<F>],
692    multi_layer: M,
693) {
694    match root_table.len() {
695        1 => {
696            // Safe as DitButterfly is #[repr(transparent)]
697            let fft_layer: &[B] = unsafe { as_base_slice(&root_table[0]) };
698            dft_layer_par(&mut mat.as_view_mut(), fft_layer);
699        }
700        2 => {
701            let fft_layer_0: &[B] = unsafe { as_base_slice(&root_table[0]) };
702            let fft_layer_1: &[B] = unsafe { as_base_slice(&root_table[1]) };
703            dft_layer_par_double(
704                &mut mat.as_view_mut(),
705                fft_layer_1,
706                fft_layer_0,
707                multi_layer,
708            );
709        }
710        0 => {}
711        _ => unreachable!("The number of layers must be 0, 1 or 2"),
712    }
713}
714
715/// Applies one layer of the Radix-2 FFT butterfly network on a single core to
716/// a recently zero-padded matrix.
717///
718/// Splits the matrix into blocks of rows and performs in-place butterfly operations
719/// on each block.
720///
721/// Assume `added_bits = 2` and we are doing a decimation in frequency approach.
722/// Then the rows of our matrix look like:
723/// ```text
724/// [R0, 0, 0, 0, R1, 0, 0, 0, ...]
725/// ```
726/// Thus the first two butterfly layers can be implemented more simply as they map the matrix to:
727/// ```text
728/// After Layer 0: [R0, T00 * R0, 0, 0, R1, T01 * R1, 0, 0, ...]
729/// After Layer 1: [R0, T00 * R0, T10 * R0, T10 * T00 * R0, R1, T01 * R1, T11 * R1, T11 * T01 * R1, ...].
730/// ```
731///
732/// # Arguments
733/// - `vec`: Mutable vector whose height is a power of two.
734/// - `twiddles`: Precomputed twiddle factors for this layer.
735/// - `skip`: `(1 << skip) - 1` is the number of entirely zero
736///   blocks between each non zero block.
737#[inline]
738fn dft_layer_zeros<F: Field, B: Butterfly<F>>(vec: &mut [F], twiddles: &[B], skip: usize) {
739    debug_assert_eq!(
740        vec.len() % twiddles.len(),
741        0,
742        "Vector length must be divisible by the number of twiddles"
743    );
744    let size = vec.len();
745    let num_blocks = twiddles.len();
746
747    let block_size = size / num_blocks;
748    let half_block_size = block_size / 2;
749
750    vec.chunks_exact_mut(block_size)
751        .zip(twiddles)
752        .step_by(1 << skip) // Skip the zero blocks
753        .for_each(|(block, &twiddle)| {
754            // Split each block vertically into top (hi) and bottom (lo) halves
755            let (hi_chunk, lo_chunk) = block.split_at_mut(half_block_size);
756
757            // Apply DIF butterfly making use of the fact that `lo_chunk` is zero.
758            twiddle.apply_to_rows(hi_chunk, lo_chunk);
759        });
760}
761
/// Four disjoint mutable sub-blocks of one FFT block, nested as a pair of pairs.
///
/// `block.h.q` selects half `h` and, within it, quarter `q` (NOTE(review): layout inferred
/// from how the double-layer helpers pair the entries; confirm at the construction site).
type DoubleLayerBlockDecomposition<'blk, F> =
    ((&'blk mut [F], &'blk mut [F]), (&'blk mut [F], &'blk mut [F]));
765
766/// Performs an FFT layer on the sub-blocks using a single twiddle factor.
767#[inline]
768fn fft_double_layer_single_twiddle<F: Field, Fly: Butterfly<F>>(
769    block: &mut DoubleLayerBlockDecomposition<'_, F>,
770    butterfly: Fly,
771) {
772    butterfly.apply_to_rows(block.0.0, block.1.0);
773    butterfly.apply_to_rows(block.0.1, block.1.1);
774}
775
776/// Performs an FFT layer on the sub-blocks using a pair of twiddle factors.
777///
778/// The inputs are differentiated in order to allow the first input to potentially
779/// be a `TwiddleFreeButterfly`, which does not require a twiddle factor.
780#[inline]
781fn fft_double_layer_double_twiddle<F: Field, Fly0: Butterfly<F>, Fly1: Butterfly<F>>(
782    block: &mut DoubleLayerBlockDecomposition<'_, F>,
783    fly0: Fly0,
784    fly1: Fly1,
785) {
786    fly0.apply_to_rows(block.0.0, block.0.1);
787    fly1.apply_to_rows(block.1.0, block.1.1);
788}
789
790/// A type representing a decomposition of an FFT block into eight sub-blocks.
791type TripleLayerBlockDecomposition<'a, F> = (
792    ((&'a mut [F], &'a mut [F]), (&'a mut [F], &'a mut [F])),
793    ((&'a mut [F], &'a mut [F]), (&'a mut [F], &'a mut [F])),
794);
795
796/// Performs an FFT layer on the sub-blocks using a single twiddle factor.
797#[inline]
798fn fft_triple_layer_single_twiddle<F: Field, Fly: Butterfly<F>>(
799    block: &mut TripleLayerBlockDecomposition<'_, F>,
800    butterfly: Fly,
801) {
802    butterfly.apply_to_rows(block.0.0.0, block.1.0.0);
803    butterfly.apply_to_rows(block.0.0.1, block.1.0.1);
804    butterfly.apply_to_rows(block.0.1.0, block.1.1.0);
805    butterfly.apply_to_rows(block.0.1.1, block.1.1.1);
806}
807
808/// Performs an FFT layer on the sub-blocks using a pair of twiddle factors.
809///
810/// The inputs are differentiated in order to allow the first input to potentially
811/// be a `TwiddleFreeButterfly`, which does not require a twiddle factor.
812#[inline]
813fn fft_triple_layer_double_twiddle<F: Field, Fly0: Butterfly<F>, Fly1: Butterfly<F>>(
814    block: &mut TripleLayerBlockDecomposition<'_, F>,
815    fly0: Fly0,
816    fly1: Fly1,
817) {
818    fly0.apply_to_rows(block.0.0.0, block.0.1.0);
819    fly0.apply_to_rows(block.0.0.1, block.0.1.1);
820    fly1.apply_to_rows(block.1.0.0, block.1.1.0);
821    fly1.apply_to_rows(block.1.0.1, block.1.1.1);
822}
823
824/// Performs an FFT layer on the sub-blocks using a four twiddle factors.
825///
826/// The inputs are differentiated in order to allow the first input to potentially
827/// be a `TwiddleFreeButterfly`, which does not require a twiddle factor.
828#[inline]
829fn fft_triple_layer_quad_twiddle<F: Field, Fly0: Butterfly<F>, Flies: Butterfly<F>>(
830    block: &mut TripleLayerBlockDecomposition<'_, F>,
831    fly0: Fly0,
832    butterflies: &[Flies],
833) {
834    debug_assert!(butterflies.len() == 3);
835    fly0.apply_to_rows(block.0.0.0, block.0.0.1);
836    butterflies[0].apply_to_rows(block.0.1.0, block.0.1.1);
837    butterflies[1].apply_to_rows(block.1.0.0, block.1.0.1);
838    butterflies[2].apply_to_rows(block.1.1.0, block.1.1.1);
839}
840
/// Number of `T` values that fit in an assumed 32 KiB L1 data cache.
///
/// The cache size is a fixed estimate, not a probed value. The result is used to determine
/// how many elements to place in each chunk that is processed in parallel.
#[must_use]
const fn workload_size<T: Sized>() -> usize {
    const ASSUMED_L1_BYTES: usize = 32 * 1024;
    let bytes_per_item = size_of::<T>();
    ASSUMED_L1_BYTES / bytes_per_item
}
850
851/// Estimates the optimal number of rows of a `RowMajorMatrix<T>` to take in each parallel chunk.
852///
853/// Designed to ensure that `<T> * estimate_num_rows_par() * width` is roughly the size of the L1 cache.
854///
855/// Assumes that height is a power of two and always outputs a power of two.
856#[must_use]
857fn estimate_num_rows_in_l1<T: Sized>(height: usize, width: usize) -> usize {
858    (workload_size::<T>() / width)
859        .next_power_of_two()
860        .min(height) // Ensure we don't exceed the height of the matrix.
861}
862
863/// Given a vector of parallel iterators, zip all pairs together.
864///
865/// This lets us simulate the izip!() macro but for our possibly parallel iterators.
866///
867/// This function assumes that the input vector has an even number of elements. If
868/// it is given an odd number of elements, the last element will be ignored.
869#[inline]
870fn zip_par_iter_vec<I: IndexedParallelIterator>(
871    in_vec: Vec<I>,
872) -> Vec<impl IndexedParallelIterator<Item = (I::Item, I::Item)>> {
873    in_vec
874        .into_iter()
875        .tuples()
876        .map(|(hi, lo)| hi.zip(lo))
877        .collect::<Vec<_>>()
878}
879
880trait MultiLayerButterfly<F: Field, B: Butterfly<F>>: Copy + Send + Sync {
881    fn apply_2_layers(
882        &self,
883        chunk_decomposition: DoubleLayerBlockDecomposition<'_, F>,
884        ind: usize,
885        twiddles_small: &[B],
886        twiddles_large: &[B],
887    );
888
889    fn apply_3_layers(
890        &self,
891        chunk_decomposition: TripleLayerBlockDecomposition<'_, F>,
892        ind: usize,
893        twiddles_small: &[B],
894        twiddles_med: &[B],
895        twiddles_large: &[B],
896    );
897}
898
899#[derive(Debug, Clone, Copy)]
900struct MultiLayerDitButterfly;
901
902impl<F: Field> MultiLayerButterfly<F, DitButterfly<F>> for MultiLayerDitButterfly {
903    #[inline]
904    fn apply_2_layers(
905        &self,
906        mut blk_decomp: DoubleLayerBlockDecomposition<'_, F>,
907        ind: usize,
908        twiddles_small: &[DitButterfly<F>],
909        twiddles_large: &[DitButterfly<F>],
910    ) {
911        if ind == 0 {
912            fft_double_layer_single_twiddle(&mut blk_decomp, TwiddleFreeButterfly);
913            fft_double_layer_double_twiddle(
914                &mut blk_decomp,
915                TwiddleFreeButterfly,
916                twiddles_large[1],
917            );
918        } else {
919            fft_double_layer_single_twiddle(&mut blk_decomp, twiddles_small[ind]);
920            fft_double_layer_double_twiddle(
921                &mut blk_decomp,
922                twiddles_large[2 * ind],
923                twiddles_large[2 * ind + 1],
924            );
925        }
926    }
927
928    #[inline]
929    fn apply_3_layers(
930        &self,
931        mut blk_decomp: TripleLayerBlockDecomposition<'_, F>,
932        ind: usize,
933        twiddles_small: &[DitButterfly<F>],
934        twiddles_med: &[DitButterfly<F>],
935        twiddles_large: &[DitButterfly<F>],
936    ) {
937        if ind == 0 {
938            fft_triple_layer_single_twiddle(&mut blk_decomp, TwiddleFreeButterfly);
939            fft_triple_layer_double_twiddle(&mut blk_decomp, TwiddleFreeButterfly, twiddles_med[1]);
940            fft_triple_layer_quad_twiddle(
941                &mut blk_decomp,
942                TwiddleFreeButterfly,
943                &twiddles_large[1..4],
944            );
945        } else {
946            fft_triple_layer_single_twiddle(&mut blk_decomp, twiddles_small[ind]);
947            fft_triple_layer_double_twiddle(
948                &mut blk_decomp,
949                twiddles_med[2 * ind],
950                twiddles_med[2 * ind + 1],
951            );
952            fft_triple_layer_quad_twiddle(
953                &mut blk_decomp,
954                twiddles_large[4 * ind],
955                &twiddles_large[4 * ind + 1..4 * (ind + 1)],
956            );
957        }
958    }
959}
960
961#[derive(Debug, Clone, Copy)]
962struct MultiLayerDifButterfly;
963
964impl<F: Field> MultiLayerButterfly<F, DifButterfly<F>> for MultiLayerDifButterfly {
965    #[inline]
966    fn apply_2_layers(
967        &self,
968        mut blk_decomp: DoubleLayerBlockDecomposition<'_, F>,
969        ind: usize,
970        twiddles_small: &[DifButterfly<F>],
971        twiddles_large: &[DifButterfly<F>],
972    ) {
973        if ind == 0 {
974            fft_double_layer_double_twiddle(
975                &mut blk_decomp,
976                TwiddleFreeButterfly,
977                twiddles_large[1],
978            );
979            fft_double_layer_single_twiddle(&mut blk_decomp, TwiddleFreeButterfly);
980        } else {
981            fft_double_layer_double_twiddle(
982                &mut blk_decomp,
983                twiddles_large[2 * ind],
984                twiddles_large[2 * ind + 1],
985            );
986            fft_double_layer_single_twiddle(&mut blk_decomp, twiddles_small[ind]);
987        }
988    }
989
990    #[inline]
991    fn apply_3_layers(
992        &self,
993        mut blk_decomp: TripleLayerBlockDecomposition<'_, F>,
994        ind: usize,
995        twiddles_small: &[DifButterfly<F>],
996        twiddles_med: &[DifButterfly<F>],
997        twiddles_large: &[DifButterfly<F>],
998    ) {
999        if ind == 0 {
1000            fft_triple_layer_quad_twiddle(
1001                &mut blk_decomp,
1002                TwiddleFreeButterfly,
1003                &twiddles_large[1..4],
1004            );
1005            fft_triple_layer_double_twiddle(&mut blk_decomp, TwiddleFreeButterfly, twiddles_med[1]);
1006            fft_triple_layer_single_twiddle(&mut blk_decomp, TwiddleFreeButterfly);
1007        } else {
1008            fft_triple_layer_quad_twiddle(
1009                &mut blk_decomp,
1010                twiddles_large[4 * ind],
1011                &twiddles_large[4 * ind + 1..4 * (ind + 1)],
1012            );
1013            fft_triple_layer_double_twiddle(
1014                &mut blk_decomp,
1015                twiddles_med[2 * ind],
1016                twiddles_med[2 * ind + 1],
1017            );
1018            fft_triple_layer_single_twiddle(&mut blk_decomp, twiddles_small[ind]);
1019        }
1020    }
1021}