Skip to main content

p3_dft/
radix_2_dit_parallel.rs

1use alloc::collections::BTreeMap;
2use alloc::slice;
3use alloc::sync::Arc;
4use alloc::vec::Vec;
5use core::mem::{MaybeUninit, transmute};
6
7use itertools::{Itertools, izip};
8use p3_field::integers::QuotientMap;
9use p3_field::{Field, Powers, TwoAdicField};
10use p3_matrix::Matrix;
11use p3_matrix::bitrev::{BitReversalPerm, BitReversedMatrixView, BitReversibleMatrix};
12use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView, RowMajorMatrixViewMut};
13use p3_matrix::util::reverse_matrix_index_bits;
14use p3_maybe_rayon::prelude::*;
15use p3_util::{log2_strict_usize, reverse_bits_len, reverse_slice_index_bits};
16use spin::RwLock;
17use tracing::{debug_span, instrument};
18
19use crate::butterflies::{Butterfly, DitButterfly, ScaledDitButterfly, TwiddleFreeButterfly};
20use crate::{Layout, TwoAdicSubgroupDft};
21
22/// A parallel FFT algorithm which divides a butterfly network's layers into two halves.
23///
24/// For the first half, we apply a butterfly network with smaller blocks in earlier layers,
25/// i.e. either DIT or Bowers G. Then we bit-reverse, and for the second half, we continue executing
26/// the same network but in bit-reversed order. This way we're always working with small blocks,
27/// so within each half, we can have a certain amount of parallelism with no cross-thread
28/// communication.
29#[derive(Default, Clone, Debug)]
30pub struct Radix2DitParallel<F> {
31    /// Twiddles based on roots of unity, used in the forward DFT.
32    twiddles: Arc<RwLock<BTreeMap<usize, Arc<VectorPair<F>>>>>,
33
34    /// A map from `(log_h, shift)` to forward DFT twiddles with that coset shift baked in.
35    #[allow(clippy::type_complexity)]
36    coset_twiddles: Arc<RwLock<BTreeMap<(usize, F), Arc<[Vec<F>]>>>>,
37
38    /// Twiddles based on inverse roots of unity, used in the inverse DFT.
39    inverse_twiddles: Arc<RwLock<BTreeMap<usize, Arc<VectorPair<F>>>>>,
40}
41
/// A pair of vectors, one with twiddle factors in their natural order, the other bit-reversed.
///
/// Both vectors hold the same elements: `bitrev_twiddles` is `twiddles` permuted with
/// `reverse_slice_index_bits`.
#[derive(Default, Clone, Debug)]
struct VectorPair<F> {
    /// Twiddle factors in natural order (passed to `first_half`).
    twiddles: Vec<F>,
    /// The same twiddle factors in bit-reversed order (passed to `second_half`).
    bitrev_twiddles: Vec<F>,
}
48
49impl<F> Radix2DitParallel<F>
50where
51    F: TwoAdicField + Ord,
52{
53    fn get_or_compute_twiddles(&self, log_h: usize) -> Arc<VectorPair<F>> {
54        // Fast path: Check for the value with a cheap read lock.
55        if let Some(pair) = self.twiddles.read().get(&log_h) {
56            return pair.clone();
57        }
58
59        // Slow path: The value doesn't exist. Acquire a write lock.
60        let mut w_lock = self.twiddles.write();
61
62        // Double-check and compute if necessary.
63        w_lock
64            .entry(log_h)
65            .or_insert_with(|| {
66                let half_h = (1 << log_h) >> 1;
67                let root = F::two_adic_generator(log_h);
68                let twiddles = root.powers().collect_n(half_h);
69                let mut bitrev_twiddles = twiddles.clone();
70                reverse_slice_index_bits(&mut bitrev_twiddles);
71
72                Arc::new(VectorPair {
73                    twiddles,
74                    bitrev_twiddles,
75                })
76            })
77            .clone()
78    }
79
80    fn get_or_compute_coset_twiddles(&self, (log_h, shift): (usize, F)) -> Arc<[Vec<F>]> {
81        let key = (log_h, shift);
82        // Fast path: Try to get the value with a cheap read lock first.
83        if let Some(twiddles) = self.coset_twiddles.read().get(&key) {
84            return twiddles.clone();
85        }
86        // Slow path: The value isn't there, so we need to compute it.
87        // Acquire a write lock to ensure only one thread does the computation.
88        let mut w_lock = self.coset_twiddles.write();
89        // Double-check: Another thread might have inserted it while we waited for the lock.
90        // The `entry` API handles this check and insertion atomically.
91        w_lock
92            .entry(key)
93            .or_insert_with(|| {
94                let mid = log_h.div_ceil(2);
95                let h = 1 << log_h;
96                let root = F::two_adic_generator(log_h);
97                (0..log_h)
98                    .map(|layer| {
99                        let shift_power = shift.exp_power_of_2(layer);
100                        let powers = Powers {
101                            base: root.exp_power_of_2(layer),
102                            current: shift_power,
103                        };
104                        let mut twiddles = powers.collect_n(h >> (layer + 1));
105                        let layer_rev = log_h - 1 - layer;
106                        if layer_rev >= mid {
107                            reverse_slice_index_bits(&mut twiddles);
108                        }
109                        twiddles
110                    })
111                    .collect::<Vec<_>>()
112                    .into()
113            })
114            .clone()
115    }
116
117    fn get_or_compute_inverse_twiddles(&self, log_h: usize) -> Arc<VectorPair<F>> {
118        // Fast path: First, check for the value using a cheap read lock.
119        if let Some(pair) = self.inverse_twiddles.read().get(&log_h) {
120            return pair.clone();
121        }
122        // Slow path: The value doesn't exist. Acquire a write lock.
123        let mut w_lock = self.inverse_twiddles.write();
124        // Double-check: Another thread might have created the entry while we waited.
125        // The `entry` API handles this check and the insertion atomically.
126        w_lock
127            .entry(log_h)
128            .or_insert_with(|| {
129                // This computation only runs if the entry is truly vacant.
130                let half_h = (1 << log_h) >> 1;
131                let root_inv = F::two_adic_generator(log_h).inverse();
132                let twiddles = root_inv.powers().collect_n(half_h);
133                let mut bitrev_twiddles = twiddles.clone();
134                reverse_slice_index_bits(&mut bitrev_twiddles);
135
136                Arc::new(VectorPair {
137                    twiddles,
138                    bitrev_twiddles,
139                })
140            })
141            .clone()
142    }
143}
144
impl<F: TwoAdicField + Ord> TwoAdicSubgroupDft<F> for Radix2DitParallel<F> {
    /// The underlying matrix is left in bit-reversed row order; this view applies the
    /// bit-reversal permutation on access instead of permuting the data again.
    type Evaluations = BitReversedMatrixView<RowMajorMatrix<F>>;

    /// Forward DFT of every column over the two-adic subgroup of order `mat.height()`.
    ///
    /// NOTE(review): `log2_strict_usize` presumably panics unless the height is a power of two —
    /// confirm.
    fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> Self::Evaluations {
        let h = mat.height();
        let log_h = log2_strict_usize(h);

        // Compute twiddle factors, or take memoized ones if already available.
        let twiddles = self.get_or_compute_twiddles(log_h);

        // Layers `0..mid` form the first half of the network, `mid..log_h` the second.
        let mid = log_h.div_ceil(2);

        // The first half looks like a normal DIT.
        reverse_matrix_index_bits(&mut mat);
        first_half(&mut mat, mid, &twiddles.twiddles);

        // For the second half, we flip the DIT, working in bit-reversed order.
        reverse_matrix_index_bits(&mut mat);
        second_half(&mut mat, mid, &twiddles.bitrev_twiddles, None);

        // Data is now bit-reversed; return a bit-reversal view rather than permuting again.
        mat.bit_reverse_rows()
    }

    /// Coset DFT with `shift`. The input rows are bit-reversed, `coset_dft` runs in place using
    /// twiddles with `shift` baked in, and the result is returned as a bit-reversal view.
    fn coset_dft_batch(&self, mut mat: RowMajorMatrix<F>, shift: F) -> Self::Evaluations {
        reverse_matrix_index_bits(&mut mat);
        coset_dft(self, &mut mat.as_view_mut(), shift);
        BitReversalPerm::new_view(mat)
    }

    /// Inverse coset DFT: a plain inverse DFT followed by `coset_shift_cols` with `1 / shift`.
    fn coset_idft_batch(&self, mat: RowMajorMatrix<F>, shift: F) -> RowMajorMatrix<F> {
        let mut coeffs = self.idft_batch(mat);
        crate::util::coset_shift_cols(&mut coeffs, shift.inverse());
        coeffs
    }

    /// Low-degree extension of `mat` onto `2^added_bits` cosets.
    ///
    /// Runs an inverse DFT (with the `1/h` scale folded into the second half), then calls
    /// `transform` on the coefficients, which are still in bit-reversed row order. Finally it
    /// runs one forward coset DFT per coset. All cosets are written into a single buffer of
    /// `h << added_bits` rows, as contiguous `h`-row chunks. Coset `c` (shift
    /// `g_big^c * shift`) is stored at chunk `reverse_bits_len(c, added_bits)`.
    #[instrument(skip_all, level = "debug", fields(dims = %mat.dimensions(), added_bits = added_bits))]
    fn coset_lde_batch_with_transform<T>(
        &self,
        mut mat: RowMajorMatrix<F>,
        added_bits: usize,
        shift: F,
        transform: T,
    ) -> Self::Evaluations
    where
        T: FnOnce(&mut RowMajorMatrixViewMut<'_, F>, Layout),
    {
        let w = mat.width;
        let h = mat.height();
        let log_h = log2_strict_usize(h);
        let mid = log_h.div_ceil(2);

        let inverse_twiddles = self.get_or_compute_inverse_twiddles(log_h);

        // The first half looks like a normal DIT.
        reverse_matrix_index_bits(&mut mat);
        first_half(&mut mat, mid, &inverse_twiddles.twiddles);

        // For the second half, we flip the DIT, working in bit-reversed order.
        reverse_matrix_index_bits(&mut mat);
        // We'll also scale by 1/h, as per the usual inverse DFT algorithm.
        // If F isn't a PrimeField, (and is thus an extension field) it's much cheaper to
        // invert in F::PrimeSubfield.
        // NOTE(review): `try_inverse` yields `None` only if `h` is zero in the prime subfield.
        // `h` is a power of two, so this presumably cannot happen in odd characteristic. If it
        // ever did, `scale` would be `None` and the 1/h scaling would be silently skipped.
        let h_inv_subfield = F::PrimeSubfield::from_int(h).try_inverse();
        let scale = h_inv_subfield.map(F::from_prime_subfield);
        second_half(&mut mat, mid, &inverse_twiddles.bitrev_twiddles, scale);
        // We skip the final bit-reversal, since the next FFT expects bit-reversed input.

        transform(&mut mat.as_view_mut(), Layout::BitReversed);

        // Grow the allocation to hold every coset, without initializing the new tail.
        let lde_elems = w * (h << added_bits);
        let elems_to_add = lde_elems - w * h;
        debug_span!("reserve_exact").in_scope(|| mat.values.reserve_exact(elems_to_add));

        let g_big = F::two_adic_generator(log_h + added_bits);

        // Split the allocation into the initialized first coset and the uninitialized rest.
        let mat_ptr = mat.values.as_mut_ptr();
        // SAFETY: after `reserve_exact`, capacity >= len + `elems_to_add`. The vector currently
        // holds the `w * h` matrix elements, so offsetting by `w * h` stays inside the allocation.
        let rest_ptr = unsafe { (mat_ptr as *mut MaybeUninit<F>).add(w * h) };
        // SAFETY: the first `w * h` elements are the vector's initialized contents.
        let first_slice: &mut [F] = unsafe { slice::from_raw_parts_mut(mat_ptr, w * h) };
        // SAFETY: `[w * h, lde_elems)` lies within the reserved capacity and does not overlap
        // `first_slice`. It is typed as `MaybeUninit<F>` because it is not yet initialized.
        let rest_slice: &mut [MaybeUninit<F>] =
            unsafe { slice::from_raw_parts_mut(rest_ptr, lde_elems - w * h) };
        let mut first_coset_mat = RowMajorMatrixViewMut::new(first_slice, w);
        // One `h`-row destination matrix per coset other than the first.
        let mut rest_cosets_mat = rest_slice
            .chunks_exact_mut(w * h)
            .map(|slice| RowMajorMatrixViewMut::new(slice, w))
            .collect_vec();

        // Coset `coset_idx` lands at chunk `reverse_bits_len(coset_idx, added_bits)`, which
        // keeps the concatenated output in bit-reversed order.
        for coset_idx in 1..(1 << added_bits) {
            let total_shift = g_big.exp_u64(coset_idx as u64) * shift;
            let coset_idx = reverse_bits_len(coset_idx, added_bits);
            let dest = &mut rest_cosets_mat[coset_idx - 1]; // - 1 because we removed the first matrix.
            coset_dft_oop(self, &first_coset_mat.as_view(), dest, total_shift);
        }

        // Now run a forward DFT on the very first coset, this time in-place.
        // This must run last: the loop above reads the coefficients from `first_coset_mat`.
        coset_dft(self, &mut first_coset_mat.as_view_mut(), shift);

        // SAFETY: We wrote all values above.
        // `coset_idx -> reverse_bits_len(coset_idx, added_bits)` is a bijection on
        // `1..2^added_bits`, so each chunk of `rest_slice` is fully written by exactly one
        // `coset_dft_oop` call, and the first `w * h` elements were initialized to begin with.
        unsafe {
            mat.values.set_len(lde_elems);
        }
        BitReversalPerm::new_view(mat)
    }
}
248
249#[instrument(level = "debug", skip_all)]
250fn coset_dft<F: TwoAdicField + Ord>(
251    dft: &Radix2DitParallel<F>,
252    mat: &mut RowMajorMatrixViewMut<'_, F>,
253    shift: F,
254) {
255    let log_h = log2_strict_usize(mat.height());
256    let mid = log_h.div_ceil(2);
257
258    let twiddles = dft.get_or_compute_coset_twiddles((log_h, shift));
259
260    // The first half looks like a normal DIT.
261    first_half_general(mat, mid, &twiddles);
262
263    // For the second half, we flip the DIT, working in bit-reversed order.
264    reverse_matrix_index_bits(mat);
265
266    second_half_general(mat, mid, &twiddles);
267}
268
/// Like `coset_dft`, except out-of-place.
///
/// Reads `src`, in the same row order `coset_dft` expects, and writes its coset DFT (with
/// `shift`) into `dst_maybe`. `dst_maybe` may be uninitialized on entry; every element is
/// written before this function returns.
///
/// # Panics
/// Panics if `src` and `dst_maybe` have different dimensions.
#[instrument(level = "debug", skip_all)]
fn coset_dft_oop<F: TwoAdicField + Ord>(
    dft: &Radix2DitParallel<F>,
    src: &RowMajorMatrixView<'_, F>,
    dst_maybe: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
    shift: F,
) {
    assert_eq!(src.dimensions(), dst_maybe.dimensions());

    let log_h = log2_strict_usize(dst_maybe.height());

    if log_h == 0 {
        // This is an edge case where first_half_general_oop doesn't work, as it expects there to be
        // at least one layer in the network, so we just copy instead.
        // (A size-1 coset DFT is the identity: the lone coefficient is the evaluation at `shift`.)
        // SAFETY: `MaybeUninit<F>` has the same size and alignment as `F`, and every `F` is a
        // valid `MaybeUninit<F>`. NOTE(review): this also assumes the matrix view struct has
        // identical layout for both element types, which `repr(Rust)` does not formally
        // guarantee — confirm.
        let src_maybe = unsafe {
            transmute::<&RowMajorMatrixView<'_, F>, &RowMajorMatrixView<'_, MaybeUninit<F>>>(src)
        };
        dst_maybe.copy_from(src_maybe);
        return;
    }

    // Layers `0..mid` form the first half of the network, `mid..log_h` the second.
    let mid = log_h.div_ceil(2);

    let twiddles = dft.get_or_compute_coset_twiddles((log_h, shift));

    // The first half looks like a normal DIT.
    // Its first layer is out-of-place from `src`, which initializes all of `dst_maybe`.
    first_half_general_oop(src, dst_maybe, mid, &twiddles);

    // dst is now initialized.
    // SAFETY: `first_half_general_oop` has written every element of `dst_maybe`, so viewing it
    // as initialized `F` is sound (same layout caveat as the edge case above).
    let dst = unsafe {
        transmute::<&mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>, &mut RowMajorMatrixViewMut<'_, F>>(
            dst_maybe,
        )
    };

    // For the second half, we flip the DIT, working in bit-reversed order.
    reverse_matrix_index_bits(dst);

    second_half_general(dst, mid, &twiddles);
}
310
311/// This can be used as the first half of a DIT butterfly network.
312///
313/// For layer 0, all twiddle factors are 1 (root^0 = 1), so we use `TwiddleFreeButterfly`
314/// to avoid a Montgomery multiply by 1 across the entire matrix.
315///
316/// For layers 1 to mid-1 included, the first twiddle in each block is also always 1 (`twiddles[0] = 1`),
317/// so we special-case the first row-pair of each block to use `TwiddleFreeButterfly` as well.
318#[instrument(level = "debug", skip_all)]
319fn first_half<F: Field>(mat: &mut RowMajorMatrix<F>, mid: usize, twiddles: &[F]) {
320    let log_h = log2_strict_usize(mat.height());
321
322    // max block size: 2^mid
323    mat.par_row_chunks_exact_mut(1 << mid)
324        .for_each(|mut submat| {
325            let mut backwards = false;
326            for layer in 0..mid {
327                if layer == 0 {
328                    // For layer 0, half_block_size=1 and each block clones the twiddle
329                    // iterator from the start, consuming only twiddles[0] = root^0 = 1.
330                    // Use TwiddleFreeButterfly to skip the multiply entirely.
331                    dit_layer_twiddle_free(&mut submat, backwards);
332                } else {
333                    let layer_rev = log_h - 1 - layer;
334                    let layer_pow = 1 << layer_rev;
335                    // For layers 1..mid-1, twiddles[0] = root^0 = 1 is always the first
336                    // twiddle consumed per block. Use the optimized version that applies
337                    // TwiddleFreeButterfly for the first row-pair of each block.
338                    dit_layer_first_one(
339                        &mut submat,
340                        layer,
341                        twiddles.iter().step_by(layer_pow),
342                        backwards,
343                    );
344                }
345                backwards = !backwards;
346            }
347        });
348}
349
350/// Like `first_half`, except supporting different twiddle factors per layer, enabling coset shifts
351/// to be baked into them.
352#[instrument(level = "debug", skip_all)]
353fn first_half_general<F: Field>(
354    mat: &mut RowMajorMatrixViewMut<'_, F>,
355    mid: usize,
356    twiddles: &[Vec<F>],
357) {
358    let log_h = log2_strict_usize(mat.height());
359    mat.par_row_chunks_exact_mut(1 << mid)
360        .for_each(|mut submat| {
361            let mut backwards = false;
362            for layer in 0..mid {
363                let layer_rev = log_h - 1 - layer;
364                dit_layer(&mut submat, layer, twiddles[layer_rev].iter(), backwards);
365                backwards = !backwards;
366            }
367        });
368}
369
/// Like `first_half_general`, except out-of-place.
///
/// Assumes there's at least one layer in the network, i.e. `src.height() > 1`.
///
/// Layer 0 reads from `src` and writes the matching chunk of `dst_maybe` via `dit_layer_oop`.
/// Layers `1..mid` then run in place on that now-initialized chunk.
///
/// # Panics
/// Panics (via `log2_strict_usize` and arithmetic underflow) if `src.height() < 2`.
#[instrument(level = "debug", skip_all)]
fn first_half_general_oop<F: Field>(
    src: &RowMajorMatrixView<'_, F>,
    dst_maybe: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
    mid: usize,
    twiddles: &[Vec<F>],
) {
    let log_h = log2_strict_usize(src.height());
    src.par_row_chunks_exact(1 << mid)
        .zip(dst_maybe.par_row_chunks_exact_mut(1 << mid))
        .for_each(|(src_submat, mut dst_submat_maybe)| {
            debug_assert_eq!(src_submat.dimensions(), dst_submat_maybe.dimensions());

            // The first layer is special, done out-of-place.
            // (Recall from the mid definition that there must be at least one layer here.)
            let layer_rev = log_h - 1;
            dit_layer_oop(
                &src_submat,
                &mut dst_submat_maybe,
                0,
                twiddles[layer_rev].iter(),
            );

            // submat is now initialized.
            // SAFETY: layer 0 above wrote every row pair of this chunk, so reading it as `F` is
            // sound. NOTE(review): this also relies on the view struct having the same layout
            // for `MaybeUninit<F>` and `F` elements — confirm.
            let mut dst_submat = unsafe {
                transmute::<RowMajorMatrixViewMut<'_, MaybeUninit<F>>, RowMajorMatrixViewMut<'_, F>>(
                    dst_submat_maybe,
                )
            };

            // Subsequent layers.
            // Layer 0 ran forwards, so layer 1 starts backwards and the direction keeps alternating.
            let mut backwards = true;
            for layer in 1..mid {
                let layer_rev = log_h - 1 - layer;
                dit_layer(
                    &mut dst_submat,
                    layer,
                    twiddles[layer_rev].iter(),
                    backwards,
                );
                backwards = !backwards;
            }
        });
}
420
421/// This can be used as the second half of a DIT butterfly network. It works in bit-reversed order.
422///
423/// The optional `scale` parameter is used to scale the matrix by a constant factor. Rather than
424/// doing a separate pass over memory, we fold the scaling into the first butterfly layer to
425/// eliminate an extra memory pass.
426#[instrument(level = "debug", skip_all)]
427#[inline(always)] // To avoid branch on scale
428fn second_half<F: Field>(
429    mat: &mut RowMajorMatrix<F>,
430    mid: usize,
431    twiddles_rev: &[F],
432    scale: Option<F>,
433) {
434    let log_h = log2_strict_usize(mat.height());
435
436    // max block size: 2^(log_h - mid)
437    mat.par_row_chunks_exact_mut(1 << (log_h - mid))
438        .enumerate()
439        .for_each(|(thread, mut submat)| {
440            let mut backwards = false;
441            if let Some(scale) = scale {
442                // Fold the scale into the first butterfly layer to avoid a separate
443                // memory pass. This merges the O(N) scaling step into the first O(N)
444                // butterfly pass.
445                let mut scale_applied = false;
446                for layer in mid..log_h {
447                    let first_block = thread << (layer - mid);
448                    if !scale_applied {
449                        scale_applied = true;
450                        dit_layer_rev_scaled(
451                            &mut submat,
452                            log_h,
453                            layer,
454                            twiddles_rev[first_block..].iter().copied(),
455                            backwards,
456                            Some(scale),
457                        );
458                    } else {
459                        dit_layer_rev(
460                            &mut submat,
461                            log_h,
462                            layer,
463                            twiddles_rev[first_block..].iter().copied(),
464                            backwards,
465                        );
466                    }
467                    backwards = !backwards;
468                }
469                // Handle case where there are no layers in the second half (mid == log_h).
470                if !scale_applied {
471                    submat.scale(scale);
472                }
473            } else {
474                for layer in mid..log_h {
475                    let first_block = thread << (layer - mid);
476                    dit_layer_rev(
477                        &mut submat,
478                        log_h,
479                        layer,
480                        twiddles_rev[first_block..].iter().copied(),
481                        backwards,
482                    );
483                    backwards = !backwards;
484                }
485            }
486        });
487}
488
489/// Like `second_half`, except supporting different twiddle factors per layer, enabling coset shifts
490/// to be baked into them.
491#[instrument(level = "debug", skip_all)]
492fn second_half_general<F: Field>(
493    mat: &mut RowMajorMatrixViewMut<'_, F>,
494    mid: usize,
495    twiddles_rev: &[Vec<F>],
496) {
497    let log_h = log2_strict_usize(mat.height());
498    mat.par_row_chunks_exact_mut(1 << (log_h - mid))
499        .enumerate()
500        .for_each(|(thread, mut submat)| {
501            let mut backwards = false;
502            for layer in mid..log_h {
503                let layer_rev = log_h - 1 - layer;
504                let first_block = thread << (layer - mid);
505                dit_layer_rev(
506                    &mut submat,
507                    log_h,
508                    layer,
509                    twiddles_rev[layer_rev][first_block..].iter().copied(),
510                    backwards,
511                );
512                backwards = !backwards;
513            }
514        });
515}
516
517/// One layer of a DIT butterfly network where all twiddle factors are 1 (i.e., layer 0).
518///
519/// This is equivalent to `dit_layer` with `layer=0` and `twiddles[0]=1`, but uses
520/// `TwiddleFreeButterfly` to avoid a Montgomery multiplication by 1 in the hot loop.
521///
522/// Correctness: For layer=0, `half_block_size=1` and each block clones the twiddle
523/// iterator from position 0, consuming only `twiddles[0] = generator^0 = 1`.
524/// Since multiplying by 1 is a no-op, `TwiddleFreeButterfly` gives identical results.
525fn dit_layer_twiddle_free<F: Field>(submat: &mut RowMajorMatrixViewMut<'_, F>, backwards: bool) {
526    // layer=0 means half_block_size=1, block_size=2.
527    let width = submat.width();
528    debug_assert!(submat.height() >= 2);
529
530    let process_block = move |block: &mut [F]| {
531        // Each block is exactly 2 rows: lo = block[0..width], hi = block[width..2*width]
532        let (lo, hi) = block.split_at_mut(width);
533        TwiddleFreeButterfly.apply_to_rows(lo, hi);
534    };
535
536    let blocks = submat.values.chunks_mut(2 * width);
537    if backwards {
538        for block in blocks.rev() {
539            process_block(block);
540        }
541    } else {
542        for block in blocks {
543            process_block(block);
544        }
545    }
546}
547
548/// One layer of a DIT butterfly network where the first twiddle factor per block is always 1.
549///
550/// This is used in `first_half` for layers 1..mid-1 of the standard (non-coset) DFT/inverse DFT,
551/// where `twiddles[0] = root^0 = 1`. The first row-pair of each block uses `TwiddleFreeButterfly`
552/// to avoid one Montgomery multiplication per block, while subsequent row-pairs use `DitButterfly`.
553///
554/// Correctness: The twiddle iterator yields `twiddles[0], twiddles[step], twiddles[2*step], ...`
555/// where `twiddles[0] = root^0 = 1`. Only used when this property holds.
556fn dit_layer_first_one<'a, F: Field>(
557    submat: &mut RowMajorMatrixViewMut<'_, F>,
558    layer: usize,
559    twiddles: impl Iterator<Item = &'a F> + Clone,
560    backwards: bool,
561) {
562    let half_block_size = 1 << layer;
563    let block_size = half_block_size * 2;
564    let width = submat.width();
565    debug_assert!(submat.height() >= block_size);
566    debug_assert!(
567        half_block_size >= 2,
568        "layer must be >= 1 for dit_layer_first_one"
569    );
570
571    let process_block = move |block: &mut [F]| {
572        let (lows, highs) = block.split_at_mut(half_block_size * width);
573        let mut tw_iter = twiddles.clone();
574        // First row-pair: twiddle is always 1, use TwiddleFreeButterfly to skip the multiply.
575        let _ = tw_iter.next(); // consume twiddles[0] = 1
576        let (lo0, lo_rest) = lows.split_at_mut(width);
577        let (hi0, hi_rest) = highs.split_at_mut(width);
578        TwiddleFreeButterfly.apply_to_rows(lo0, hi0);
579        // Remaining row-pairs use DitButterfly with their respective twiddle factors.
580        for (lo, hi, twiddle) in izip!(
581            lo_rest.chunks_mut(width),
582            hi_rest.chunks_mut(width),
583            tw_iter
584        ) {
585            DitButterfly(*twiddle).apply_to_rows(lo, hi);
586        }
587    };
588
589    let blocks = submat.values.chunks_mut(block_size * width);
590    if backwards {
591        for block in blocks.rev() {
592            process_block(block);
593        }
594    } else {
595        for block in blocks {
596            process_block(block);
597        }
598    }
599}
600
601/// One layer of a DIT butterfly network.
602fn dit_layer<'a, F: Field>(
603    submat: &mut RowMajorMatrixViewMut<'_, F>,
604    layer: usize,
605    twiddles: impl Iterator<Item = &'a F> + Clone,
606    backwards: bool,
607) {
608    let half_block_size = 1 << layer;
609    let block_size = half_block_size * 2;
610    let width = submat.width();
611    debug_assert!(submat.height() >= block_size);
612
613    let process_block = move |block: &mut [F]| {
614        let (lows, highs) = block.split_at_mut(half_block_size * width);
615        for (lo, hi, twiddle) in izip!(
616            lows.chunks_mut(width),
617            highs.chunks_mut(width),
618            twiddles.clone()
619        ) {
620            DitButterfly(*twiddle).apply_to_rows(lo, hi);
621        }
622    };
623
624    let blocks = submat.values.chunks_mut(block_size * width);
625    if backwards {
626        for block in blocks.rev() {
627            process_block(block);
628        }
629    } else {
630        for block in blocks {
631            process_block(block);
632        }
633    }
634}
635
636/// One layer of a DIT butterfly network, out-of-place.
637fn dit_layer_oop<'a, F: Field>(
638    src: &RowMajorMatrixView<'_, F>,
639    dst: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
640    layer: usize,
641    twiddles: impl Iterator<Item = &'a F> + Clone,
642) {
643    debug_assert_eq!(src.dimensions(), dst.dimensions());
644    let half_block_size = 1 << layer;
645    let block_size = half_block_size * 2;
646    let width = dst.width();
647    debug_assert!(dst.height() >= block_size);
648
649    let process_blocks = move |src_block: &[F], dst_block: &mut [MaybeUninit<F>]| {
650        let (src_lows, src_highs) = src_block.split_at(half_block_size * width);
651        let (dst_lows, dst_highs) = dst_block.split_at_mut(half_block_size * width);
652
653        for (src_lo, dst_lo, src_hi, dst_hi, twiddle) in izip!(
654            src_lows.chunks(width),
655            dst_lows.chunks_mut(width),
656            src_highs.chunks(width),
657            dst_highs.chunks_mut(width),
658            twiddles.clone()
659        ) {
660            DitButterfly(*twiddle).apply_to_rows_oop(src_lo, dst_lo, src_hi, dst_hi);
661        }
662    };
663
664    let src_chunks = src.values.chunks(block_size * width);
665    let dst_chunks = dst.values.chunks_mut(block_size * width);
666
667    for (src_block, dst_block) in src_chunks.zip(dst_chunks) {
668        process_blocks(src_block, dst_block);
669    }
670}
671
672/// Like `dit_layer_rev`, except with an optional scale factor folded into the butterfly.
673///
674/// This avoids an extra memory pass when scaling is required (e.g., 1/N in inverse DFT).
675/// When `scale` is `None`, this is identical to `dit_layer_rev`.
676///
677/// When `scale` is `Some(s)`, uses `ScaledDitButterfly::new(twiddle, s)` which precomputes
678/// `twiddle * scale` once per block, reducing multiplications in the hot loop from 3 to 2.
679fn dit_layer_rev_scaled<F: Field>(
680    submat: &mut RowMajorMatrixViewMut<'_, F>,
681    log_h: usize,
682    layer: usize,
683    twiddles_rev: impl DoubleEndedIterator<Item = F> + ExactSizeIterator,
684    backwards: bool,
685    scale: Option<F>,
686) {
687    let layer_rev = log_h - 1 - layer;
688
689    let half_block_size = 1 << layer_rev;
690    let block_size = half_block_size * 2;
691    let width = submat.width();
692    debug_assert!(submat.height() >= block_size);
693
694    match scale {
695        None => {
696            // No scaling: same as regular dit_layer_rev
697            let blocks_and_twiddles = submat
698                .values
699                .chunks_mut(block_size * width)
700                .zip(twiddles_rev);
701            if backwards {
702                for (block, twiddle) in blocks_and_twiddles.rev() {
703                    let (lo, hi) = block.split_at_mut(half_block_size * width);
704                    DitButterfly(twiddle).apply_to_rows(lo, hi);
705                }
706            } else {
707                for (block, twiddle) in blocks_and_twiddles {
708                    let (lo, hi) = block.split_at_mut(half_block_size * width);
709                    DitButterfly(twiddle).apply_to_rows(lo, hi);
710                }
711            }
712        }
713        Some(s) => {
714            // Fold scaling into the butterfly to avoid a separate memory pass.
715            // ScaledDitButterfly::new precomputes twiddle * scale once per block,
716            // so the hot loop only needs 2 multiplications instead of 3.
717            let blocks_and_twiddles = submat
718                .values
719                .chunks_mut(block_size * width)
720                .zip(twiddles_rev);
721            if backwards {
722                for (block, twiddle) in blocks_and_twiddles.rev() {
723                    let (lo, hi) = block.split_at_mut(half_block_size * width);
724                    ScaledDitButterfly::new(twiddle, s).apply_to_rows(lo, hi);
725                }
726            } else {
727                for (block, twiddle) in blocks_and_twiddles {
728                    let (lo, hi) = block.split_at_mut(half_block_size * width);
729                    ScaledDitButterfly::new(twiddle, s).apply_to_rows(lo, hi);
730                }
731            }
732        }
733    }
734}
735
736/// Like `dit_layer`, except the matrix and twiddles are encoded in bit-reversed order.
737/// This can also be viewed as a layer of the Bowers G^T network.
738fn dit_layer_rev<F: Field>(
739    submat: &mut RowMajorMatrixViewMut<'_, F>,
740    log_h: usize,
741    layer: usize,
742    twiddles_rev: impl DoubleEndedIterator<Item = F> + ExactSizeIterator,
743    backwards: bool,
744) {
745    let layer_rev = log_h - 1 - layer;
746
747    let half_block_size = 1 << layer_rev;
748    let block_size = half_block_size * 2;
749    let width = submat.width();
750    debug_assert!(submat.height() >= block_size);
751
752    let blocks_and_twiddles = submat
753        .values
754        .chunks_mut(block_size * width)
755        .zip(twiddles_rev);
756    if backwards {
757        for (block, twiddle) in blocks_and_twiddles.rev() {
758            let (lo, hi) = block.split_at_mut(half_block_size * width);
759            DitButterfly(twiddle).apply_to_rows(lo, hi);
760        }
761    } else {
762        for (block, twiddle) in blocks_and_twiddles {
763            let (lo, hi) = block.split_at_mut(half_block_size * width);
764            DitButterfly(twiddle).apply_to_rows(lo, hi);
765        }
766    }
767}
768
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::{PrimeCharacteristicRing, TwoAdicField};
    use p3_matrix::Matrix;
    use p3_matrix::dense::RowMajorMatrix;
    use rand::SeedableRng;
    use rand::rngs::SmallRng;

    use super::*;

    type F = BabyBear;

    #[test]
    fn coset_dft_idft_roundtrip() {
        let dft = Radix2DitParallel::<F>::default();
        let shift = F::GENERATOR;
        let mut rng = SmallRng::seed_from_u64(42);
        let original = RowMajorMatrix::<F>::rand(&mut rng, 16, 3);

        let evals = dft.coset_dft_batch(original.clone(), shift);
        let recovered = dft.coset_idft_batch(evals.to_row_major_matrix(), shift);

        assert_eq!(original, recovered);
    }

    #[test]
    fn coset_dft_matches_default_trait() {
        let dft = Radix2DitParallel::<F>::default();
        let shift = F::two_adic_generator(4) * F::GENERATOR;
        let mut rng = SmallRng::seed_from_u64(7);
        let mat = RowMajorMatrix::<F>::rand(&mut rng, 16, 4);

        let override_result = dft
            .coset_dft_batch(mat.clone(), shift)
            .to_row_major_matrix();

        let mut shifted = mat;
        crate::util::coset_shift_cols(&mut shifted, shift);
        let default_result = dft.dft_batch(shifted).to_row_major_matrix();

        assert_eq!(override_result, default_result);
    }

    /// Round-trips `dft_batch` / `idft_batch` over small heights. This includes heights 1 and 2,
    /// where one half of the butterfly network has no layers.
    #[test]
    fn dft_idft_roundtrip_small_heights() {
        let dft = Radix2DitParallel::<F>::default();
        let mut rng = SmallRng::seed_from_u64(1);
        for log_h in 0..6 {
            let original = RowMajorMatrix::<F>::rand(&mut rng, 1 << log_h, 2);
            let evals = dft.dft_batch(original.clone()).to_row_major_matrix();
            let recovered = dft.idft_batch(evals);
            assert_eq!(original, recovered, "log_h = {log_h}");
        }
    }

    /// Checks the fused LDE path, which writes all cosets into uninitialized memory, against the
    /// straightforward "inverse DFT, zero-pad, coset DFT" construction. It covers:
    /// - `log_h == 0`, which hits the copy path in `coset_dft_oop` and the direct-scale fallback
    ///   in `second_half`;
    /// - `added_bits == 0`, where there are no extra cosets.
    #[test]
    fn coset_lde_matches_naive() {
        let dft = Radix2DitParallel::<F>::default();
        let shift = F::GENERATOR;
        let mut rng = SmallRng::seed_from_u64(3);
        let width = 3;
        for log_h in 0..5 {
            for added_bits in 0..3 {
                let h = 1 << log_h;
                let mat = RowMajorMatrix::<F>::rand(&mut rng, h, width);

                let lde = dft
                    .coset_lde_batch_with_transform(mat.clone(), added_bits, shift, |_, _| {})
                    .to_row_major_matrix();

                let mut coeffs = dft.idft_batch(mat);
                coeffs.values.resize(width * (h << added_bits), F::ZERO);
                let expected = dft.coset_dft_batch(coeffs, shift).to_row_major_matrix();

                assert_eq!(lde, expected, "log_h = {log_h}, added_bits = {added_bits}");
            }
        }
    }
}