p3_dft/
radix_2_dit_parallel.rs

1use alloc::collections::BTreeMap;
2use alloc::slice;
3use alloc::sync::Arc;
4use alloc::vec::Vec;
5use core::mem::{MaybeUninit, transmute};
6
7use itertools::{Itertools, izip};
8use p3_field::integers::QuotientMap;
9use p3_field::{Field, Powers, TwoAdicField};
10use p3_matrix::Matrix;
11use p3_matrix::bitrev::{BitReversalPerm, BitReversedMatrixView, BitReversibleMatrix};
12use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView, RowMajorMatrixViewMut};
13use p3_matrix::util::reverse_matrix_index_bits;
14use p3_maybe_rayon::prelude::*;
15use p3_util::{log2_strict_usize, reverse_bits_len, reverse_slice_index_bits};
16use spin::RwLock;
17use tracing::{debug_span, instrument};
18
19use crate::TwoAdicSubgroupDft;
20use crate::butterflies::{Butterfly, DitButterfly};
21
22/// A parallel FFT algorithm which divides a butterfly network's layers into two halves.
23///
24/// For the first half, we apply a butterfly network with smaller blocks in earlier layers,
25/// i.e. either DIT or Bowers G. Then we bit-reverse, and for the second half, we continue executing
26/// the same network but in bit-reversed order. This way we're always working with small blocks,
27/// so within each half, we can have a certain amount of parallelism with no cross-thread
28/// communication.
29#[derive(Default, Clone, Debug)]
30pub struct Radix2DitParallel<F> {
31    /// Twiddles based on roots of unity, used in the forward DFT.
32    twiddles: Arc<RwLock<BTreeMap<usize, Arc<VectorPair<F>>>>>,
33
34    /// A map from `(log_h, shift)` to forward DFT twiddles with that coset shift baked in.
35    #[allow(clippy::type_complexity)]
36    coset_twiddles: Arc<RwLock<BTreeMap<(usize, F), Arc<[Vec<F>]>>>>,
37
38    /// Twiddles based on inverse roots of unity, used in the inverse DFT.
39    inverse_twiddles: Arc<RwLock<BTreeMap<usize, Arc<VectorPair<F>>>>>,
40}
41
/// Twiddle factors stored twice: once in natural order and once with bit-reversed indices.
#[derive(Clone, Debug, Default)]
struct VectorPair<F> {
    /// The factors in their natural order.
    twiddles: Vec<F>,
    /// The same factors, permuted so that index `i` holds natural index `bitrev(i)`.
    bitrev_twiddles: Vec<F>,
}
48
49impl<F> Radix2DitParallel<F>
50where
51    F: TwoAdicField + Ord,
52{
53    fn get_or_compute_twiddles(&self, log_h: usize) -> Arc<VectorPair<F>> {
54        // Fast path: Check for the value with a cheap read lock.
55        if let Some(pair) = self.twiddles.read().get(&log_h) {
56            return pair.clone();
57        }
58
59        // Slow path: The value doesn't exist. Acquire a write lock.
60        let mut w_lock = self.twiddles.write();
61
62        // Double-check and compute if necessary.
63        w_lock
64            .entry(log_h)
65            .or_insert_with(|| {
66                let half_h = (1 << log_h) >> 1;
67                let root = F::two_adic_generator(log_h);
68                let twiddles = root.powers().collect_n(half_h);
69                let mut bitrev_twiddles = twiddles.clone();
70                reverse_slice_index_bits(&mut bitrev_twiddles);
71
72                Arc::new(VectorPair {
73                    twiddles,
74                    bitrev_twiddles,
75                })
76            })
77            .clone()
78    }
79
80    fn get_or_compute_coset_twiddles(&self, (log_h, shift): (usize, F)) -> Arc<[Vec<F>]> {
81        let key = (log_h, shift);
82        // Fast path: Try to get the value with a cheap read lock first.
83        if let Some(twiddles) = self.coset_twiddles.read().get(&key) {
84            return twiddles.clone();
85        }
86        // Slow path: The value isn't there, so we need to compute it.
87        // Acquire a write lock to ensure only one thread does the computation.
88        let mut w_lock = self.coset_twiddles.write();
89        // Double-check: Another thread might have inserted it while we waited for the lock.
90        // The `entry` API handles this check and insertion atomically.
91        w_lock
92            .entry(key)
93            .or_insert_with(|| {
94                let mid = log_h.div_ceil(2);
95                let h = 1 << log_h;
96                let root = F::two_adic_generator(log_h);
97                (0..log_h)
98                    .map(|layer| {
99                        let shift_power = shift.exp_power_of_2(layer);
100                        let powers = Powers {
101                            base: root.exp_power_of_2(layer),
102                            current: shift_power,
103                        };
104                        let mut twiddles = powers.collect_n(h >> (layer + 1));
105                        let layer_rev = log_h - 1 - layer;
106                        if layer_rev >= mid {
107                            reverse_slice_index_bits(&mut twiddles);
108                        }
109                        twiddles
110                    })
111                    .collect::<Vec<_>>()
112                    .into()
113            })
114            .clone()
115    }
116
117    fn get_or_compute_inverse_twiddles(&self, log_h: usize) -> Arc<VectorPair<F>> {
118        // Fast path: First, check for the value using a cheap read lock.
119        if let Some(pair) = self.inverse_twiddles.read().get(&log_h) {
120            return pair.clone();
121        }
122        // Slow path: The value doesn't exist. Acquire a write lock.
123        let mut w_lock = self.inverse_twiddles.write();
124        // Double-check: Another thread might have created the entry while we waited.
125        // The `entry` API handles this check and the insertion atomically.
126        w_lock
127            .entry(log_h)
128            .or_insert_with(|| {
129                // This computation only runs if the entry is truly vacant.
130                let half_h = (1 << log_h) >> 1;
131                let root_inv = F::two_adic_generator(log_h).inverse();
132                let twiddles = root_inv.powers().collect_n(half_h);
133                let mut bitrev_twiddles = twiddles.clone();
134                reverse_slice_index_bits(&mut bitrev_twiddles);
135
136                Arc::new(VectorPair {
137                    twiddles,
138                    bitrev_twiddles,
139                })
140            })
141            .clone()
142    }
143}
144
impl<F: TwoAdicField + Ord> TwoAdicSubgroupDft<F> for Radix2DitParallel<F> {
    type Evaluations = BitReversedMatrixView<RowMajorMatrix<F>>;

    /// Forward DFT of every column of `mat` over the two-adic subgroup of size `mat.height()`.
    ///
    /// The butterfly network is split at layer `mid = ceil(log_h / 2)`. The first half runs as
    /// an ordinary DIT on bit-reversed rows. The second half runs after a second row bit-reversal.
    /// The result is handed back through `bit_reverse_rows()` as a `BitReversedMatrixView`.
    fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> Self::Evaluations {
        let h = mat.height();
        let log_h = log2_strict_usize(h);

        // Compute twiddle factors, or take memoized ones if already available.
        let twiddles = self.get_or_compute_twiddles(log_h);

        let mid = log_h.div_ceil(2);

        // The first half looks like a normal DIT.
        reverse_matrix_index_bits(&mut mat);
        first_half(&mut mat, mid, &twiddles.twiddles);

        // For the second half, we flip the DIT, working in bit-reversed order.
        reverse_matrix_index_bits(&mut mat);
        second_half(&mut mat, mid, &twiddles.bitrev_twiddles, None);

        mat.bit_reverse_rows()
    }

    /// Low-degree extension of every column of `mat` onto `2^added_bits` cosets.
    ///
    /// `mat` is first run through an inverse DFT, scaled by `1/h`. Its buffer is then grown to
    /// `2^added_bits` blocks of `h` rows:
    /// - Block 0 receives, in place, the forward DFT over the coset shifted by `shift`.
    /// - Coset `i >= 1` (shift `g_big^i * shift`) is computed out of place into block
    ///   `reverse_bits_len(i, added_bits)`.
    ///
    /// The full buffer is returned through `BitReversalPerm::new_view`.
    #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits = added_bits))]
    fn coset_lde_batch(
        &self,
        mut mat: RowMajorMatrix<F>,
        added_bits: usize,
        shift: F,
    ) -> Self::Evaluations {
        let w = mat.width;
        let h = mat.height();
        let log_h = log2_strict_usize(h);
        let mid = log_h.div_ceil(2);

        let inverse_twiddles = self.get_or_compute_inverse_twiddles(log_h);

        // The first half looks like a normal DIT.
        reverse_matrix_index_bits(&mut mat);
        first_half(&mut mat, mid, &inverse_twiddles.twiddles);

        // For the second half, we flip the DIT, working in bit-reversed order.
        reverse_matrix_index_bits(&mut mat);
        // We'll also scale by 1/h, as per the usual inverse DFT algorithm.
        // If F isn't a prime field (and is thus an extension field), it's much cheaper to
        // invert in F::PrimeSubfield.
        // NOTE(review): if `h` is not invertible in the prime subfield, `scale` is `None` and no
        // 1/h scaling is applied. Presumably unreachable for two-adic fields; confirm.
        let h_inv_subfield = F::PrimeSubfield::from_int(h).try_inverse();
        let scale = h_inv_subfield.map(F::from_prime_subfield);
        second_half(&mut mat, mid, &inverse_twiddles.bitrev_twiddles, scale);
        // We skip the final bit-reversal, since the next FFT expects bit-reversed input.

        let lde_elems = w * (h << added_bits);
        let elems_to_add = lde_elems - w * h;
        // Grow capacity only (not length): the extra cosets are written into spare capacity.
        debug_span!("reserve_exact").in_scope(|| mat.values.reserve_exact(elems_to_add));

        // Generator of the subgroup of size `h << added_bits`; its powers select the cosets.
        let g_big = F::two_adic_generator(log_h + added_bits);

        let mat_ptr = mat.values.as_mut_ptr();
        // SAFETY: after `reserve_exact`, capacity >= len + elems_to_add >= lde_elems (assuming
        // len >= w * h), so offsetting by `w * h` stays within the allocation.
        // `MaybeUninit<F>` has the same layout as `F`, so the pointer cast is valid.
        let rest_ptr = unsafe { (mat_ptr as *mut MaybeUninit<F>).add(w * h) };
        // SAFETY: the first `w * h` elements are the matrix's existing, initialized contents.
        // This assumes `mat.values.len() == w * h`.
        let first_slice: &mut [F] = unsafe { slice::from_raw_parts_mut(mat_ptr, w * h) };
        // SAFETY: `w * h..lde_elems` lies inside the reserved allocation and is exposed only as
        // `MaybeUninit<F>`, so uninitialized contents are allowed. It does not overlap
        // `first_slice`, and `mat` is not touched again until `set_len` below.
        let rest_slice: &mut [MaybeUninit<F>] =
            unsafe { slice::from_raw_parts_mut(rest_ptr, lde_elems - w * h) };
        let mut first_coset_mat = RowMajorMatrixViewMut::new(first_slice, w);
        // One `h`-row destination view per extra coset: `2^added_bits - 1` views in total.
        let mut rest_cosets_mat = rest_slice
            .chunks_exact_mut(w * h)
            .map(|slice| RowMajorMatrixViewMut::new(slice, w))
            .collect_vec();

        // These out-of-place DFTs read the first coset's coefficients, so they must all run
        // before the in-place DFT below overwrites them.
        for coset_idx in 1..(1 << added_bits) {
            let total_shift = g_big.exp_u64(coset_idx as u64) * shift;
            // Bit reversal is a bijection on `0..2^added_bits` that fixes 0, so it maps the
            // extra cosets `1..2^added_bits` onto blocks `1..2^added_bits`.
            let coset_idx = reverse_bits_len(coset_idx, added_bits);
            let dest = &mut rest_cosets_mat[coset_idx - 1]; // - 1 because we removed the first matrix.
            coset_dft_oop(self, &first_coset_mat.as_view(), dest, total_shift);
        }

        // Now run a forward DFT on the very first coset, this time in-place.
        coset_dft(self, &mut first_coset_mat.as_view_mut(), shift);

        // SAFETY: the loop above wrote every one of the `2^added_bits - 1` blocks of
        // `rest_slice` exactly once (via `coset_dft_oop`), and the first `w * h` elements were
        // initialized from the start, so all `lde_elems` elements are initialized.
        unsafe {
            mat.values.set_len(lde_elems);
        }
        BitReversalPerm::new_view(mat)
    }
}
230
231#[instrument(level = "debug", skip_all)]
232fn coset_dft<F: TwoAdicField + Ord>(
233    dft: &Radix2DitParallel<F>,
234    mat: &mut RowMajorMatrixViewMut<'_, F>,
235    shift: F,
236) {
237    let log_h = log2_strict_usize(mat.height());
238    let mid = log_h.div_ceil(2);
239
240    // let twiddles = compute_factors((log_h, shift), &dft.coset_twiddles, compute_coset_twiddles);
241    let twiddles = dft.get_or_compute_coset_twiddles((log_h, shift));
242
243    // The first half looks like a normal DIT.
244    first_half_general(mat, mid, &twiddles);
245
246    // For the second half, we flip the DIT, working in bit-reversed order.
247    reverse_matrix_index_bits(mat);
248
249    second_half_general(mat, mid, &twiddles);
250}
251
252/// Like `coset_dft`, except out-of-place.
253#[instrument(level = "debug", skip_all)]
254fn coset_dft_oop<F: TwoAdicField + Ord>(
255    dft: &Radix2DitParallel<F>,
256    src: &RowMajorMatrixView<'_, F>,
257    dst_maybe: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
258    shift: F,
259) {
260    assert_eq!(src.dimensions(), dst_maybe.dimensions());
261
262    let log_h = log2_strict_usize(dst_maybe.height());
263
264    if log_h == 0 {
265        // This is an edge case where first_half_general_oop doesn't work, as it expects there to be
266        // at least one layer in the network, so we just copy instead.
267        let src_maybe = unsafe {
268            transmute::<&RowMajorMatrixView<'_, F>, &RowMajorMatrixView<'_, MaybeUninit<F>>>(src)
269        };
270        dst_maybe.copy_from(src_maybe);
271        return;
272    }
273
274    let mid = log_h.div_ceil(2);
275
276    let twiddles = dft.get_or_compute_coset_twiddles((log_h, shift));
277
278    // The first half looks like a normal DIT.
279    first_half_general_oop(src, dst_maybe, mid, &twiddles);
280
281    // dst is now initialized.
282    let dst = unsafe {
283        transmute::<&mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>, &mut RowMajorMatrixViewMut<'_, F>>(
284            dst_maybe,
285        )
286    };
287
288    // For the second half, we flip the DIT, working in bit-reversed order.
289    reverse_matrix_index_bits(dst);
290
291    second_half_general(dst, mid, &twiddles);
292}
293
294/// This can be used as the first half of a DIT butterfly network.
295#[instrument(level = "debug", skip_all)]
296fn first_half<F: Field>(mat: &mut RowMajorMatrix<F>, mid: usize, twiddles: &[F]) {
297    let log_h = log2_strict_usize(mat.height());
298
299    // max block size: 2^mid
300    mat.par_row_chunks_exact_mut(1 << mid)
301        .for_each(|mut submat| {
302            let mut backwards = false;
303            for layer in 0..mid {
304                let layer_rev = log_h - 1 - layer;
305                let layer_pow = 1 << layer_rev;
306                dit_layer(
307                    &mut submat,
308                    layer,
309                    twiddles.iter().step_by(layer_pow),
310                    backwards,
311                );
312                backwards = !backwards;
313            }
314        });
315}
316
317/// Like `first_half`, except supporting different twiddle factors per layer, enabling coset shifts
318/// to be baked into them.
319#[instrument(level = "debug", skip_all)]
320fn first_half_general<F: Field>(
321    mat: &mut RowMajorMatrixViewMut<'_, F>,
322    mid: usize,
323    twiddles: &[Vec<F>],
324) {
325    let log_h = log2_strict_usize(mat.height());
326    mat.par_row_chunks_exact_mut(1 << mid)
327        .for_each(|mut submat| {
328            let mut backwards = false;
329            for layer in 0..mid {
330                let layer_rev = log_h - 1 - layer;
331                dit_layer(&mut submat, layer, twiddles[layer_rev].iter(), backwards);
332                backwards = !backwards;
333            }
334        });
335}
336
337/// Like `first_half_general`, except out-of-place.
338///
339/// Assumes there's at least one layer in the network, i.e. `src.height() > 1`.
340/// Undefined behavior otherwise.
341#[instrument(level = "debug", skip_all)]
342fn first_half_general_oop<F: Field>(
343    src: &RowMajorMatrixView<'_, F>,
344    dst_maybe: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
345    mid: usize,
346    twiddles: &[Vec<F>],
347) {
348    let log_h = log2_strict_usize(src.height());
349    src.par_row_chunks_exact(1 << mid)
350        .zip(dst_maybe.par_row_chunks_exact_mut(1 << mid))
351        .for_each(|(src_submat, mut dst_submat_maybe)| {
352            debug_assert_eq!(src_submat.dimensions(), dst_submat_maybe.dimensions());
353
354            // The first layer is special, done out-of-place.
355            // (Recall from the mid definition that there must be at least one layer here.)
356            let layer_rev = log_h - 1;
357            dit_layer_oop(
358                &src_submat,
359                &mut dst_submat_maybe,
360                0,
361                twiddles[layer_rev].iter(),
362            );
363
364            // submat is now initialized.
365            let mut dst_submat = unsafe {
366                transmute::<RowMajorMatrixViewMut<'_, MaybeUninit<F>>, RowMajorMatrixViewMut<'_, F>>(
367                    dst_submat_maybe,
368                )
369            };
370
371            // Subsequent layers.
372            let mut backwards = true;
373            for layer in 1..mid {
374                let layer_rev = log_h - 1 - layer;
375                dit_layer(
376                    &mut dst_submat,
377                    layer,
378                    twiddles[layer_rev].iter(),
379                    backwards,
380                );
381                backwards = !backwards;
382            }
383        });
384}
385
386/// This can be used as the second half of a DIT butterfly network. It works in bit-reversed order.
387///
388/// The optional `scale` parameter is used to scale the matrix by a constant factor. Normally that
389/// would be a separate step, but it's best to merge it into a butterfly network to avoid a
390/// separate pass through main memory.
391#[instrument(level = "debug", skip_all)]
392#[inline(always)] // To avoid branch on scale
393fn second_half<F: Field>(
394    mat: &mut RowMajorMatrix<F>,
395    mid: usize,
396    twiddles_rev: &[F],
397    scale: Option<F>,
398) {
399    let log_h = log2_strict_usize(mat.height());
400
401    // max block size: 2^(log_h - mid)
402    mat.par_row_chunks_exact_mut(1 << (log_h - mid))
403        .enumerate()
404        .for_each(|(thread, mut submat)| {
405            let mut backwards = false;
406            if let Some(scale) = scale {
407                submat.scale(scale);
408            }
409            for layer in mid..log_h {
410                let first_block = thread << (layer - mid);
411                dit_layer_rev(
412                    &mut submat,
413                    log_h,
414                    layer,
415                    twiddles_rev[first_block..].iter().copied(),
416                    backwards,
417                );
418                backwards = !backwards;
419            }
420        });
421}
422
423/// Like `second_half`, except supporting different twiddle factors per layer, enabling coset shifts
424/// to be baked into them.
425#[instrument(level = "debug", skip_all)]
426fn second_half_general<F: Field>(
427    mat: &mut RowMajorMatrixViewMut<'_, F>,
428    mid: usize,
429    twiddles_rev: &[Vec<F>],
430) {
431    let log_h = log2_strict_usize(mat.height());
432    mat.par_row_chunks_exact_mut(1 << (log_h - mid))
433        .enumerate()
434        .for_each(|(thread, mut submat)| {
435            let mut backwards = false;
436            for layer in mid..log_h {
437                let layer_rev = log_h - 1 - layer;
438                let first_block = thread << (layer - mid);
439                dit_layer_rev(
440                    &mut submat,
441                    log_h,
442                    layer,
443                    twiddles_rev[layer_rev][first_block..].iter().copied(),
444                    backwards,
445                );
446                backwards = !backwards;
447            }
448        });
449}
450
451/// One layer of a DIT butterfly network.
452fn dit_layer<'a, F: Field>(
453    submat: &mut RowMajorMatrixViewMut<'_, F>,
454    layer: usize,
455    twiddles: impl Iterator<Item = &'a F> + Clone,
456    backwards: bool,
457) {
458    let half_block_size = 1 << layer;
459    let block_size = half_block_size * 2;
460    let width = submat.width();
461    debug_assert!(submat.height() >= block_size);
462
463    let process_block = move |block: &mut [F]| {
464        let (lows, highs) = block.split_at_mut(half_block_size * width);
465        for (lo, hi, twiddle) in izip!(
466            lows.chunks_mut(width),
467            highs.chunks_mut(width),
468            twiddles.clone()
469        ) {
470            DitButterfly(*twiddle).apply_to_rows(lo, hi);
471        }
472    };
473
474    let blocks = submat.values.chunks_mut(block_size * width);
475    if backwards {
476        for block in blocks.rev() {
477            process_block(block);
478        }
479    } else {
480        for block in blocks {
481            process_block(block);
482        }
483    }
484}
485
486/// One layer of a DIT butterfly network.
487fn dit_layer_oop<'a, F: Field>(
488    src: &RowMajorMatrixView<'_, F>,
489    dst: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
490    layer: usize,
491    twiddles: impl Iterator<Item = &'a F> + Clone,
492) {
493    debug_assert_eq!(src.dimensions(), dst.dimensions());
494    let half_block_size = 1 << layer;
495    let block_size = half_block_size * 2;
496    let width = dst.width();
497    debug_assert!(dst.height() >= block_size);
498
499    let process_blocks = move |src_block: &[F], dst_block: &mut [MaybeUninit<F>]| {
500        let (src_lows, src_highs) = src_block.split_at(half_block_size * width);
501        let (dst_lows, dst_highs) = dst_block.split_at_mut(half_block_size * width);
502
503        for (src_lo, dst_lo, src_hi, dst_hi, twiddle) in izip!(
504            src_lows.chunks(width),
505            dst_lows.chunks_mut(width),
506            src_highs.chunks(width),
507            dst_highs.chunks_mut(width),
508            twiddles.clone()
509        ) {
510            DitButterfly(*twiddle).apply_to_rows_oop(src_lo, dst_lo, src_hi, dst_hi);
511        }
512    };
513
514    let src_chunks = src.values.chunks(block_size * width);
515    let dst_chunks = dst.values.chunks_mut(block_size * width);
516
517    for (src_block, dst_block) in src_chunks.zip(dst_chunks) {
518        process_blocks(src_block, dst_block);
519    }
520}
521
522/// Like `dit_layer`, except the matrix and twiddles are encoded in bit-reversed order.
523/// This can also be viewed as a layer of the Bowers G^T network.
524fn dit_layer_rev<F: Field>(
525    submat: &mut RowMajorMatrixViewMut<'_, F>,
526    log_h: usize,
527    layer: usize,
528    twiddles_rev: impl DoubleEndedIterator<Item = F> + ExactSizeIterator,
529    backwards: bool,
530) {
531    let layer_rev = log_h - 1 - layer;
532
533    let half_block_size = 1 << layer_rev;
534    let block_size = half_block_size * 2;
535    let width = submat.width();
536    debug_assert!(submat.height() >= block_size);
537
538    let blocks_and_twiddles = submat
539        .values
540        .chunks_mut(block_size * width)
541        .zip(twiddles_rev);
542    if backwards {
543        for (block, twiddle) in blocks_and_twiddles.rev() {
544            let (lo, hi) = block.split_at_mut(half_block_size * width);
545            DitButterfly(twiddle).apply_to_rows(lo, hi);
546        }
547    } else {
548        for (block, twiddle) in blocks_and_twiddles {
549            let (lo, hi) = block.split_at_mut(half_block_size * width);
550            DitButterfly(twiddle).apply_to_rows(lo, hi);
551        }
552    }
553}