1use alloc::collections::BTreeMap;
2use alloc::slice;
3use alloc::sync::Arc;
4use alloc::vec::Vec;
5use core::mem::{MaybeUninit, transmute};
6
7use itertools::{Itertools, izip};
8use p3_field::integers::QuotientMap;
9use p3_field::{Field, Powers, TwoAdicField};
10use p3_matrix::Matrix;
11use p3_matrix::bitrev::{BitReversalPerm, BitReversedMatrixView, BitReversibleMatrix};
12use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView, RowMajorMatrixViewMut};
13use p3_matrix::util::reverse_matrix_index_bits;
14use p3_maybe_rayon::prelude::*;
15use p3_util::{log2_strict_usize, reverse_bits_len, reverse_slice_index_bits};
16use spin::RwLock;
17use tracing::{debug_span, instrument};
18
19use crate::TwoAdicSubgroupDft;
20use crate::butterflies::{Butterfly, DitButterfly};
21
/// A DFT engine that memoizes its twiddle tables.
///
/// Every table sits behind an `Arc<RwLock<..>>`, so values produced by the derived `Clone`
/// share the same underlying tables as the value they were cloned from.
#[derive(Default, Clone, Debug)]
pub struct Radix2DitParallel<F> {
    /// Map from `log_h` to the twiddle pair built from `F::two_adic_generator(log_h)`.
    /// Read by `dft_batch`.
    twiddles: Arc<RwLock<BTreeMap<usize, Arc<VectorPair<F>>>>>,

    /// Map from `(log_h, shift)` to one twiddle vector per layer, each starting from a power
    /// of `shift`. Read by `coset_dft` and `coset_dft_oop`.
    // The nested cache type is long but private to this struct, so the lint is silenced
    // rather than introducing a type alias.
    #[allow(clippy::type_complexity)]
    coset_twiddles: Arc<RwLock<BTreeMap<(usize, F), Arc<[Vec<F>]>>>>,

    /// Map from `log_h` to the twiddle pair built from the inverse of
    /// `F::two_adic_generator(log_h)`. Read by `coset_lde_batch`.
    inverse_twiddles: Arc<RwLock<BTreeMap<usize, Arc<VectorPair<F>>>>>,
}
41
/// The same twiddle factors stored in two orderings.
#[derive(Default, Clone, Debug)]
struct VectorPair<F> {
    /// Twiddles in natural order: successive powers of the chosen root.
    twiddles: Vec<F>,
    /// A copy of `twiddles` permuted by `reverse_slice_index_bits`.
    bitrev_twiddles: Vec<F>,
}
48
49impl<F> Radix2DitParallel<F>
50where
51 F: TwoAdicField + Ord,
52{
53 fn get_or_compute_twiddles(&self, log_h: usize) -> Arc<VectorPair<F>> {
54 if let Some(pair) = self.twiddles.read().get(&log_h) {
56 return pair.clone();
57 }
58
59 let mut w_lock = self.twiddles.write();
61
62 w_lock
64 .entry(log_h)
65 .or_insert_with(|| {
66 let half_h = (1 << log_h) >> 1;
67 let root = F::two_adic_generator(log_h);
68 let twiddles = root.powers().collect_n(half_h);
69 let mut bitrev_twiddles = twiddles.clone();
70 reverse_slice_index_bits(&mut bitrev_twiddles);
71
72 Arc::new(VectorPair {
73 twiddles,
74 bitrev_twiddles,
75 })
76 })
77 .clone()
78 }
79
80 fn get_or_compute_coset_twiddles(&self, (log_h, shift): (usize, F)) -> Arc<[Vec<F>]> {
81 let key = (log_h, shift);
82 if let Some(twiddles) = self.coset_twiddles.read().get(&key) {
84 return twiddles.clone();
85 }
86 let mut w_lock = self.coset_twiddles.write();
89 w_lock
92 .entry(key)
93 .or_insert_with(|| {
94 let mid = log_h.div_ceil(2);
95 let h = 1 << log_h;
96 let root = F::two_adic_generator(log_h);
97 (0..log_h)
98 .map(|layer| {
99 let shift_power = shift.exp_power_of_2(layer);
100 let powers = Powers {
101 base: root.exp_power_of_2(layer),
102 current: shift_power,
103 };
104 let mut twiddles = powers.collect_n(h >> (layer + 1));
105 let layer_rev = log_h - 1 - layer;
106 if layer_rev >= mid {
107 reverse_slice_index_bits(&mut twiddles);
108 }
109 twiddles
110 })
111 .collect::<Vec<_>>()
112 .into()
113 })
114 .clone()
115 }
116
117 fn get_or_compute_inverse_twiddles(&self, log_h: usize) -> Arc<VectorPair<F>> {
118 if let Some(pair) = self.inverse_twiddles.read().get(&log_h) {
120 return pair.clone();
121 }
122 let mut w_lock = self.inverse_twiddles.write();
124 w_lock
127 .entry(log_h)
128 .or_insert_with(|| {
129 let half_h = (1 << log_h) >> 1;
131 let root_inv = F::two_adic_generator(log_h).inverse();
132 let twiddles = root_inv.powers().collect_n(half_h);
133 let mut bitrev_twiddles = twiddles.clone();
134 reverse_slice_index_bits(&mut bitrev_twiddles);
135
136 Arc::new(VectorPair {
137 twiddles,
138 bitrev_twiddles,
139 })
140 })
141 .clone()
142 }
143}
144
145impl<F: TwoAdicField + Ord> TwoAdicSubgroupDft<F> for Radix2DitParallel<F> {
146 type Evaluations = BitReversedMatrixView<RowMajorMatrix<F>>;
147
148 fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> Self::Evaluations {
149 let h = mat.height();
150 let log_h = log2_strict_usize(h);
151
152 let twiddles = self.get_or_compute_twiddles(log_h);
154
155 let mid = log_h.div_ceil(2);
156
157 reverse_matrix_index_bits(&mut mat);
159 first_half(&mut mat, mid, &twiddles.twiddles);
160
161 reverse_matrix_index_bits(&mut mat);
163 second_half(&mut mat, mid, &twiddles.bitrev_twiddles, None);
164
165 mat.bit_reverse_rows()
166 }
167
168 #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits = added_bits))]
169 fn coset_lde_batch(
170 &self,
171 mut mat: RowMajorMatrix<F>,
172 added_bits: usize,
173 shift: F,
174 ) -> Self::Evaluations {
175 let w = mat.width;
176 let h = mat.height();
177 let log_h = log2_strict_usize(h);
178 let mid = log_h.div_ceil(2);
179
180 let inverse_twiddles = self.get_or_compute_inverse_twiddles(log_h);
181
182 reverse_matrix_index_bits(&mut mat);
184 first_half(&mut mat, mid, &inverse_twiddles.twiddles);
185
186 reverse_matrix_index_bits(&mut mat);
188 let h_inv_subfield = F::PrimeSubfield::from_int(h).try_inverse();
192 let scale = h_inv_subfield.map(F::from_prime_subfield);
193 second_half(&mut mat, mid, &inverse_twiddles.bitrev_twiddles, scale);
194 let lde_elems = w * (h << added_bits);
197 let elems_to_add = lde_elems - w * h;
198 debug_span!("reserve_exact").in_scope(|| mat.values.reserve_exact(elems_to_add));
199
200 let g_big = F::two_adic_generator(log_h + added_bits);
201
202 let mat_ptr = mat.values.as_mut_ptr();
203 let rest_ptr = unsafe { (mat_ptr as *mut MaybeUninit<F>).add(w * h) };
204 let first_slice: &mut [F] = unsafe { slice::from_raw_parts_mut(mat_ptr, w * h) };
205 let rest_slice: &mut [MaybeUninit<F>] =
206 unsafe { slice::from_raw_parts_mut(rest_ptr, lde_elems - w * h) };
207 let mut first_coset_mat = RowMajorMatrixViewMut::new(first_slice, w);
208 let mut rest_cosets_mat = rest_slice
209 .chunks_exact_mut(w * h)
210 .map(|slice| RowMajorMatrixViewMut::new(slice, w))
211 .collect_vec();
212
213 for coset_idx in 1..(1 << added_bits) {
214 let total_shift = g_big.exp_u64(coset_idx as u64) * shift;
215 let coset_idx = reverse_bits_len(coset_idx, added_bits);
216 let dest = &mut rest_cosets_mat[coset_idx - 1]; coset_dft_oop(self, &first_coset_mat.as_view(), dest, total_shift);
218 }
219
220 coset_dft(self, &mut first_coset_mat.as_view_mut(), shift);
222
223 unsafe {
225 mat.values.set_len(lde_elems);
226 }
227 BitReversalPerm::new_view(mat)
228 }
229}
230
231#[instrument(level = "debug", skip_all)]
232fn coset_dft<F: TwoAdicField + Ord>(
233 dft: &Radix2DitParallel<F>,
234 mat: &mut RowMajorMatrixViewMut<'_, F>,
235 shift: F,
236) {
237 let log_h = log2_strict_usize(mat.height());
238 let mid = log_h.div_ceil(2);
239
240 let twiddles = dft.get_or_compute_coset_twiddles((log_h, shift));
242
243 first_half_general(mat, mid, &twiddles);
245
246 reverse_matrix_index_bits(mat);
248
249 second_half_general(mat, mid, &twiddles);
250}
251
252#[instrument(level = "debug", skip_all)]
254fn coset_dft_oop<F: TwoAdicField + Ord>(
255 dft: &Radix2DitParallel<F>,
256 src: &RowMajorMatrixView<'_, F>,
257 dst_maybe: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
258 shift: F,
259) {
260 assert_eq!(src.dimensions(), dst_maybe.dimensions());
261
262 let log_h = log2_strict_usize(dst_maybe.height());
263
264 if log_h == 0 {
265 let src_maybe = unsafe {
268 transmute::<&RowMajorMatrixView<'_, F>, &RowMajorMatrixView<'_, MaybeUninit<F>>>(src)
269 };
270 dst_maybe.copy_from(src_maybe);
271 return;
272 }
273
274 let mid = log_h.div_ceil(2);
275
276 let twiddles = dft.get_or_compute_coset_twiddles((log_h, shift));
277
278 first_half_general_oop(src, dst_maybe, mid, &twiddles);
280
281 let dst = unsafe {
283 transmute::<&mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>, &mut RowMajorMatrixViewMut<'_, F>>(
284 dst_maybe,
285 )
286 };
287
288 reverse_matrix_index_bits(dst);
290
291 second_half_general(dst, mid, &twiddles);
292}
293
294#[instrument(level = "debug", skip_all)]
296fn first_half<F: Field>(mat: &mut RowMajorMatrix<F>, mid: usize, twiddles: &[F]) {
297 let log_h = log2_strict_usize(mat.height());
298
299 mat.par_row_chunks_exact_mut(1 << mid)
301 .for_each(|mut submat| {
302 let mut backwards = false;
303 for layer in 0..mid {
304 let layer_rev = log_h - 1 - layer;
305 let layer_pow = 1 << layer_rev;
306 dit_layer(
307 &mut submat,
308 layer,
309 twiddles.iter().step_by(layer_pow),
310 backwards,
311 );
312 backwards = !backwards;
313 }
314 });
315}
316
317#[instrument(level = "debug", skip_all)]
320fn first_half_general<F: Field>(
321 mat: &mut RowMajorMatrixViewMut<'_, F>,
322 mid: usize,
323 twiddles: &[Vec<F>],
324) {
325 let log_h = log2_strict_usize(mat.height());
326 mat.par_row_chunks_exact_mut(1 << mid)
327 .for_each(|mut submat| {
328 let mut backwards = false;
329 for layer in 0..mid {
330 let layer_rev = log_h - 1 - layer;
331 dit_layer(&mut submat, layer, twiddles[layer_rev].iter(), backwards);
332 backwards = !backwards;
333 }
334 });
335}
336
337#[instrument(level = "debug", skip_all)]
342fn first_half_general_oop<F: Field>(
343 src: &RowMajorMatrixView<'_, F>,
344 dst_maybe: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
345 mid: usize,
346 twiddles: &[Vec<F>],
347) {
348 let log_h = log2_strict_usize(src.height());
349 src.par_row_chunks_exact(1 << mid)
350 .zip(dst_maybe.par_row_chunks_exact_mut(1 << mid))
351 .for_each(|(src_submat, mut dst_submat_maybe)| {
352 debug_assert_eq!(src_submat.dimensions(), dst_submat_maybe.dimensions());
353
354 let layer_rev = log_h - 1;
357 dit_layer_oop(
358 &src_submat,
359 &mut dst_submat_maybe,
360 0,
361 twiddles[layer_rev].iter(),
362 );
363
364 let mut dst_submat = unsafe {
366 transmute::<RowMajorMatrixViewMut<'_, MaybeUninit<F>>, RowMajorMatrixViewMut<'_, F>>(
367 dst_submat_maybe,
368 )
369 };
370
371 let mut backwards = true;
373 for layer in 1..mid {
374 let layer_rev = log_h - 1 - layer;
375 dit_layer(
376 &mut dst_submat,
377 layer,
378 twiddles[layer_rev].iter(),
379 backwards,
380 );
381 backwards = !backwards;
382 }
383 });
384}
385
386#[instrument(level = "debug", skip_all)]
392#[inline(always)] fn second_half<F: Field>(
394 mat: &mut RowMajorMatrix<F>,
395 mid: usize,
396 twiddles_rev: &[F],
397 scale: Option<F>,
398) {
399 let log_h = log2_strict_usize(mat.height());
400
401 mat.par_row_chunks_exact_mut(1 << (log_h - mid))
403 .enumerate()
404 .for_each(|(thread, mut submat)| {
405 let mut backwards = false;
406 if let Some(scale) = scale {
407 submat.scale(scale);
408 }
409 for layer in mid..log_h {
410 let first_block = thread << (layer - mid);
411 dit_layer_rev(
412 &mut submat,
413 log_h,
414 layer,
415 twiddles_rev[first_block..].iter().copied(),
416 backwards,
417 );
418 backwards = !backwards;
419 }
420 });
421}
422
423#[instrument(level = "debug", skip_all)]
426fn second_half_general<F: Field>(
427 mat: &mut RowMajorMatrixViewMut<'_, F>,
428 mid: usize,
429 twiddles_rev: &[Vec<F>],
430) {
431 let log_h = log2_strict_usize(mat.height());
432 mat.par_row_chunks_exact_mut(1 << (log_h - mid))
433 .enumerate()
434 .for_each(|(thread, mut submat)| {
435 let mut backwards = false;
436 for layer in mid..log_h {
437 let layer_rev = log_h - 1 - layer;
438 let first_block = thread << (layer - mid);
439 dit_layer_rev(
440 &mut submat,
441 log_h,
442 layer,
443 twiddles_rev[layer_rev][first_block..].iter().copied(),
444 backwards,
445 );
446 backwards = !backwards;
447 }
448 });
449}
450
451fn dit_layer<'a, F: Field>(
453 submat: &mut RowMajorMatrixViewMut<'_, F>,
454 layer: usize,
455 twiddles: impl Iterator<Item = &'a F> + Clone,
456 backwards: bool,
457) {
458 let half_block_size = 1 << layer;
459 let block_size = half_block_size * 2;
460 let width = submat.width();
461 debug_assert!(submat.height() >= block_size);
462
463 let process_block = move |block: &mut [F]| {
464 let (lows, highs) = block.split_at_mut(half_block_size * width);
465 for (lo, hi, twiddle) in izip!(
466 lows.chunks_mut(width),
467 highs.chunks_mut(width),
468 twiddles.clone()
469 ) {
470 DitButterfly(*twiddle).apply_to_rows(lo, hi);
471 }
472 };
473
474 let blocks = submat.values.chunks_mut(block_size * width);
475 if backwards {
476 for block in blocks.rev() {
477 process_block(block);
478 }
479 } else {
480 for block in blocks {
481 process_block(block);
482 }
483 }
484}
485
486fn dit_layer_oop<'a, F: Field>(
488 src: &RowMajorMatrixView<'_, F>,
489 dst: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
490 layer: usize,
491 twiddles: impl Iterator<Item = &'a F> + Clone,
492) {
493 debug_assert_eq!(src.dimensions(), dst.dimensions());
494 let half_block_size = 1 << layer;
495 let block_size = half_block_size * 2;
496 let width = dst.width();
497 debug_assert!(dst.height() >= block_size);
498
499 let process_blocks = move |src_block: &[F], dst_block: &mut [MaybeUninit<F>]| {
500 let (src_lows, src_highs) = src_block.split_at(half_block_size * width);
501 let (dst_lows, dst_highs) = dst_block.split_at_mut(half_block_size * width);
502
503 for (src_lo, dst_lo, src_hi, dst_hi, twiddle) in izip!(
504 src_lows.chunks(width),
505 dst_lows.chunks_mut(width),
506 src_highs.chunks(width),
507 dst_highs.chunks_mut(width),
508 twiddles.clone()
509 ) {
510 DitButterfly(*twiddle).apply_to_rows_oop(src_lo, dst_lo, src_hi, dst_hi);
511 }
512 };
513
514 let src_chunks = src.values.chunks(block_size * width);
515 let dst_chunks = dst.values.chunks_mut(block_size * width);
516
517 for (src_block, dst_block) in src_chunks.zip(dst_chunks) {
518 process_blocks(src_block, dst_block);
519 }
520}
521
522fn dit_layer_rev<F: Field>(
525 submat: &mut RowMajorMatrixViewMut<'_, F>,
526 log_h: usize,
527 layer: usize,
528 twiddles_rev: impl DoubleEndedIterator<Item = F> + ExactSizeIterator,
529 backwards: bool,
530) {
531 let layer_rev = log_h - 1 - layer;
532
533 let half_block_size = 1 << layer_rev;
534 let block_size = half_block_size * 2;
535 let width = submat.width();
536 debug_assert!(submat.height() >= block_size);
537
538 let blocks_and_twiddles = submat
539 .values
540 .chunks_mut(block_size * width)
541 .zip(twiddles_rev);
542 if backwards {
543 for (block, twiddle) in blocks_and_twiddles.rev() {
544 let (lo, hi) = block.split_at_mut(half_block_size * width);
545 DitButterfly(twiddle).apply_to_rows(lo, hi);
546 }
547 } else {
548 for (block, twiddle) in blocks_and_twiddles {
549 let (lo, hi) = block.split_at_mut(half_block_size * width);
550 DitButterfly(twiddle).apply_to_rows(lo, hi);
551 }
552 }
553}