1use alloc::sync::Arc;
4use alloc::vec::Vec;
5use core::iter;
6
7use itertools::Itertools;
8use p3_field::{Field, TwoAdicField, scale_slice_in_place_single_core};
9use p3_matrix::Matrix;
10use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
11use p3_matrix::util::reverse_matrix_index_bits;
12use p3_maybe_rayon::prelude::*;
13use p3_util::{as_base_slice, log2_strict_usize, reverse_slice_index_bits};
14use spin::RwLock;
15
16use crate::{
17 Butterfly, DifButterfly, DifButterflyZeros, DitButterfly, TwiddleFreeButterfly,
18 TwoAdicSubgroupDft,
19};
20
/// Number of DFT layers that `dft_layer_par_triple` processes in a single pass.
///
/// When the number of layers handled this way is not a multiple of this value,
/// the one or two leftover layers go through `dft_layer_par_extra_layers`.
const LAYERS_PER_GROUP: usize = 3;
23
/// Radix-2 FFT over row-major matrices that caches its twiddle tables between calls.
///
/// Both tables live behind `Arc<RwLock<..>>`, so clones of this struct share the
/// same cache and `&self` methods are able to grow it.
#[derive(Default, Clone, Debug)]
pub struct Radix2DFTSmallBatch<F> {
    /// Cached forward twiddles, one `Vec` per layer.
    ///
    /// For a table built for length `n`, entry `i` holds every `2^i`-th element of
    /// `[1, g, ..., g^{n/2 - 1}]` (with `g = F::two_adic_generator(log2(n))`),
    /// stored in bit-reversed order. The last entry is therefore `[1]`.
    // The nested shared-ownership type is deliberate, so the lint is silenced.
    #[allow(clippy::type_complexity)]
    twiddles: Arc<RwLock<Arc<[Vec<F>]>>>,

    /// Cached twiddles for the inverse transform, laid out exactly like `twiddles`
    /// but holding the inverse of each entry.
    // The nested shared-ownership type is deliberate, so the lint is silenced.
    #[allow(clippy::type_complexity)]
    inv_twiddles: Arc<RwLock<Arc<[Vec<F>]>>>,
}
49
50impl<F: TwoAdicField> Radix2DFTSmallBatch<F> {
51 pub fn new(n: usize) -> Self {
55 let res = Self::default();
56 res.update_twiddles(n);
57 res
58 }
59
60 fn roots_of_unity_table(&self, n: usize) -> Vec<Vec<F>> {
68 let lg_n = log2_strict_usize(n);
69 let generator = F::two_adic_generator(lg_n);
70 let half_n = 1 << (lg_n - 1);
71 let nth_roots = generator.powers().collect_n(half_n);
73
74 (0..lg_n)
75 .map(|i| nth_roots.iter().step_by(1 << i).copied().collect())
76 .collect()
77 }
78
79 fn update_twiddles(&self, fft_len: usize) {
81 let curr_max_fft_len = 1 << self.twiddles.read().len();
86 if fft_len > curr_max_fft_len {
87 let mut new_twiddles = self.roots_of_unity_table(fft_len);
88 let mut new_inv_twiddles: Vec<Vec<F>> = new_twiddles
89 .iter()
90 .map(|ts| {
91 iter::once(F::ONE)
94 .chain(ts[1..].iter().rev().map(|&f| -f))
95 .collect()
96 })
97 .collect();
98
99 new_twiddles.iter_mut().for_each(|ts| {
100 reverse_slice_index_bits(ts);
101 });
102 new_inv_twiddles.iter_mut().for_each(|ts| {
103 reverse_slice_index_bits(ts);
104 });
105
106 {
107 let mut tw_lock = self.twiddles.write();
108 let cur_have = 1usize << tw_lock.len();
109 if fft_len > cur_have {
110 *tw_lock = Arc::from(new_twiddles); }
112 }
113 {
114 let mut inv_tw_lock = self.inv_twiddles.write();
115 let cur_have = 1usize << inv_tw_lock.len();
116 if fft_len > cur_have {
117 *inv_tw_lock = Arc::from(new_inv_twiddles); }
119 }
120 }
121 }
122}
123
124impl<F> TwoAdicSubgroupDft<F> for Radix2DFTSmallBatch<F>
125where
126 F: TwoAdicField,
127{
128 type Evaluations = RowMajorMatrix<F>;
129
130 fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> Self::Evaluations {
131 let h = mat.height();
132 let w = mat.width();
133 let log_h = log2_strict_usize(h);
134
135 self.update_twiddles(h);
136 let g = self.twiddles.read().clone(); let root_table = &g[g.len() - log_h..];
138
139 let num_par_rows = estimate_num_rows_in_l1::<F>(h, w);
144 let log_num_par_rows = log2_strict_usize(num_par_rows);
145 let chunk_size = num_par_rows * w;
146
147 let multi_layer_dit = MultiLayerDitButterfly {};
151
152 for (dit_0, dit_1, dit_2) in root_table[log_num_par_rows..]
155 .iter()
156 .rev()
157 .map(|slice| unsafe { as_base_slice::<DitButterfly<F>, F>(slice) }) .tuples()
159 {
160 dft_layer_par_triple(&mut mat.as_view_mut(), dit_0, dit_1, dit_2, multi_layer_dit);
161 }
162
163 let corr = (log_h - log_num_par_rows) % LAYERS_PER_GROUP;
166 dft_layer_par_extra_layers(
167 &mut mat.as_view_mut(),
168 &root_table[log_num_par_rows..log_num_par_rows + corr],
169 multi_layer_dit,
170 );
171
172 par_remaining_layers(&mut mat.values, chunk_size, &root_table[..log_num_par_rows]);
176
177 reverse_matrix_index_bits(&mut mat);
179 mat
180 }
181
182 fn idft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
183 let h = mat.height();
184 let w = mat.width();
185 let log_h = log2_strict_usize(h);
186
187 self.update_twiddles(h);
188 let g = self.inv_twiddles.read().clone(); let start = g
190 .len()
191 .checked_sub(log_h)
192 .expect("log_h exceeds inv_twiddles length");
193 let root_table = &g[start..];
194
195 let num_par_rows = estimate_num_rows_in_l1::<F>(h, w);
201 let log_num_par_rows = log2_strict_usize(num_par_rows);
202 let chunk_size = num_par_rows * w;
203
204 reverse_matrix_index_bits(&mut mat);
206
207 par_initial_layers(
213 &mut mat.values,
214 chunk_size,
215 &root_table[..log_num_par_rows],
216 log_h,
217 );
218
219 let multi_layer_dif = MultiLayerDifButterfly {};
223
224 let corr = (log_h - log_num_par_rows) % LAYERS_PER_GROUP;
227 dft_layer_par_extra_layers(
228 &mut mat.as_view_mut(),
229 &root_table[log_num_par_rows..log_num_par_rows + corr],
230 multi_layer_dif,
231 );
232
233 for (dif_0, dif_1, dif_2) in root_table[(log_num_par_rows + corr)..]
236 .iter()
237 .map(|slice| unsafe { as_base_slice::<DifButterfly<F>, F>(slice) }) .tuples()
239 {
240 dft_layer_par_triple(&mut mat.as_view_mut(), dif_2, dif_1, dif_0, multi_layer_dif);
241 }
242
243 mat
244 }
245
246 fn coset_lde_batch(
247 &self,
248 mut mat: RowMajorMatrix<F>,
249 added_bits: usize,
250 shift: F,
251 ) -> Self::Evaluations {
252 let h = mat.height();
253 let w = mat.width();
254 let log_h = log2_strict_usize(h);
255
256 self.update_twiddles(h << added_bits);
257 let g = self.twiddles.read().clone(); let start = g
259 .len()
260 .checked_sub(log_h + added_bits)
261 .expect("log_h exceeds twiddles length");
262 let root_table = &g[start..];
263 let g = self.inv_twiddles.read().clone(); let start = g
265 .len()
266 .checked_sub(log_h)
267 .expect("log_h exceeds inv_twiddles length");
268 let inv_root_table = &g[start..];
269 let output_height = h << added_bits;
270
271 let output_values = F::zero_vec(output_height * w);
273 let mut out = RowMajorMatrix::new(output_values, w);
274
275 let num_par_rows = estimate_num_rows_in_l1::<F>(h, w);
292 let num_inner_dit_layers = log2_strict_usize(num_par_rows);
293 let num_inner_dif_layers = num_inner_dit_layers + added_bits;
294
295 let multi_layer_dit = MultiLayerDitButterfly {};
298 for (dit_0, dit_1, dit_2) in inv_root_table[num_inner_dit_layers..]
299 .iter()
300 .rev()
301 .map(|slice| unsafe { as_base_slice::<DitButterfly<F>, F>(slice) }) .tuples()
303 {
304 dft_layer_par_triple(&mut mat.as_view_mut(), dit_0, dit_1, dit_2, multi_layer_dit);
305 }
306
307 let corr = (log_h - num_inner_dit_layers) % LAYERS_PER_GROUP;
310 dft_layer_par_extra_layers(
311 &mut mat.as_view_mut(),
312 &inv_root_table[num_inner_dit_layers..num_inner_dit_layers + corr],
313 multi_layer_dit,
314 );
315
316 par_middle_layers(
320 &mut mat.as_view_mut(),
321 &mut out.as_view_mut(),
322 num_par_rows,
323 &root_table[..(num_inner_dif_layers)],
324 &inv_root_table[..num_inner_dit_layers],
325 added_bits,
326 shift,
327 );
328
329 let multi_layer_dif = MultiLayerDifButterfly {};
331
332 dft_layer_par_extra_layers(
335 &mut out.as_view_mut(),
336 &root_table[num_inner_dif_layers..num_inner_dif_layers + corr],
337 multi_layer_dif,
338 );
339
340 for (dif_0, dif_1, dif_2) in root_table[(num_inner_dif_layers + corr)..]
343 .iter()
344 .map(|slice| unsafe { as_base_slice::<DifButterfly<F>, F>(slice) }) .tuples()
346 {
347 dft_layer_par_triple(&mut out.as_view_mut(), dif_2, dif_1, dif_0, multi_layer_dif);
348 }
349
350 out
351 }
352}
353
354#[inline]
367fn dft_layer_par<F: Field, B: Butterfly<F>>(
368 mat: &mut RowMajorMatrixViewMut<'_, F>,
369 twiddles: &[B],
370) {
371 debug_assert!(
372 mat.height().is_multiple_of(twiddles.len()),
373 "Matrix height must be divisible by the number of twiddles"
374 );
375 let size = mat.values.len();
376 let num_blocks = twiddles.len();
377
378 let outer_block_size = size / num_blocks;
379 let half_outer_block_size = outer_block_size / 2;
380
381 mat.values
382 .par_chunks_exact_mut(outer_block_size)
383 .enumerate()
384 .for_each(|(ind, block)| {
385 let (hi_chunk, lo_chunk) = block.split_at_mut(half_outer_block_size);
387
388 let num_threads = current_num_threads();
390 let inner_block_size = size / (2 * num_blocks).max(num_threads);
391
392 hi_chunk
393 .par_chunks_mut(inner_block_size)
394 .zip(lo_chunk.par_chunks_mut(inner_block_size))
395 .for_each(|(hi_chunk, lo_chunk)| {
396 if ind == 0 {
397 TwiddleFreeButterfly.apply_to_rows(hi_chunk, lo_chunk);
399 } else {
400 twiddles[ind].apply_to_rows(hi_chunk, lo_chunk);
402 }
403 });
404 });
405}
406
407#[inline]
412fn par_remaining_layers<F: Field>(mat: &mut [F], chunk_size: usize, root_table: &[Vec<F>]) {
413 mat.par_chunks_exact_mut(chunk_size)
414 .enumerate()
415 .for_each(|(index, chunk)| {
416 remaining_layers(chunk, root_table, index);
417 });
418}
419
420fn remaining_layers<F: Field>(chunk: &mut [F], root_table: &[Vec<F>], index: usize) {
422 for (layer, twiddles) in root_table.iter().rev().enumerate() {
423 let num_twiddles_per_block = 1 << layer;
424 let start = index * num_twiddles_per_block;
425 let twiddle_range = start..(start + num_twiddles_per_block);
426 let dit_twiddles: &[DitButterfly<F>] = unsafe { as_base_slice(&twiddles[twiddle_range]) };
428 dft_layer(chunk, dit_twiddles);
429 }
430}
431
432#[inline]
440fn par_initial_layers<F: Field>(
441 mat: &mut [F],
442 chunk_size: usize,
443 root_table: &[Vec<F>],
444 log_height: usize,
445) {
446 let inv_height = F::ONE.div_2exp_u64(log_height as u64);
447 mat.par_chunks_exact_mut(chunk_size)
448 .enumerate()
449 .for_each(|(index, chunk)| {
450 scale_slice_in_place_single_core(chunk, inv_height);
452 initial_layers(chunk, root_table, index);
453 });
454}
455
456#[inline]
458fn initial_layers<F: Field>(chunk: &mut [F], root_table: &[Vec<F>], index: usize) {
459 let num_rounds = root_table.len();
460
461 for (layer, twiddles) in root_table.iter().enumerate() {
462 let num_twiddles_per_block = 1 << (num_rounds - layer - 1);
463 let start = index * num_twiddles_per_block;
464 let twiddle_range = start..(start + num_twiddles_per_block);
465 let dif_twiddles: &[DifButterfly<F>] = unsafe { as_base_slice(&twiddles[twiddle_range]) };
467 dft_layer(chunk, dif_twiddles);
468 }
469}
470
471fn par_middle_layers<F: Field>(
477 in_mat: &mut RowMajorMatrixViewMut<'_, F>,
478 out_mat: &mut RowMajorMatrixViewMut<'_, F>,
479 num_par_rows: usize,
480 root_table: &[Vec<F>],
481 inv_root_table: &[Vec<F>],
482 added_bits: usize,
483 shift: F,
484) {
485 debug_assert_eq!(in_mat.width(), out_mat.width());
486 debug_assert_eq!(in_mat.height() << added_bits, out_mat.height());
487
488 let width = in_mat.width();
489 let height = in_mat.height();
490 let num_rounds = root_table.len();
491 let in_chunk_size = num_par_rows * width;
492 let out_chunk_size = in_chunk_size << added_bits;
493
494 let log_height = log2_strict_usize(height);
495 let inv_height = F::ONE.div_2exp_u64(log_height as u64);
496
497 let mut scaling = shift.shifted_powers(inv_height).collect_n(height);
498 reverse_slice_index_bits(&mut scaling);
499
500 in_mat
501 .values
502 .par_chunks_exact_mut(in_chunk_size)
503 .zip(out_mat.values.par_chunks_exact_mut(out_chunk_size))
504 .zip(scaling.par_chunks_exact_mut(num_par_rows))
505 .enumerate()
506 .for_each(|(index, ((in_chunk, out_chunk), scaling))| {
507 remaining_layers(in_chunk, inv_root_table, index);
508
509 in_chunk
511 .chunks_exact(width)
512 .zip(scaling)
513 .zip(out_chunk.chunks_exact_mut(width << added_bits))
514 .for_each(|((in_row, scale), out_row)| {
515 out_row
516 .iter_mut()
517 .zip(in_row.iter())
518 .for_each(|(out_val, in_val)| {
519 *out_val = *in_val * *scale;
520 });
521 });
522
523 for (layer, twiddles) in root_table[..added_bits].iter().enumerate() {
526 let num_twiddles_per_block = 1 << (num_rounds - layer - 1);
527 let start = index * num_twiddles_per_block;
528 let twiddle_range = start..(start + num_twiddles_per_block);
529
530 let dif_twiddles_zeros: &[DifButterflyZeros<F>] =
532 unsafe { as_base_slice(&twiddles[twiddle_range]) };
533 dft_layer_zeros(out_chunk, dif_twiddles_zeros, added_bits - layer - 1);
534 }
535
536 initial_layers(out_chunk, &root_table[added_bits..], index);
537 });
538}
539
540#[inline]
549fn dft_layer<F: Field, B: Butterfly<F>>(vec: &mut [F], twiddles: &[B]) {
550 debug_assert_eq!(
551 vec.len() % twiddles.len(),
552 0,
553 "Vector length must be divisible by the number of twiddles"
554 );
555 let size = vec.len();
556 let num_blocks = twiddles.len();
557
558 let block_size = size / num_blocks;
559 let half_block_size = block_size / 2;
560
561 vec.chunks_exact_mut(block_size)
562 .zip(twiddles)
563 .for_each(|(block, &twiddle)| {
564 let (hi_chunk, lo_chunk) = block.split_at_mut(half_block_size);
566
567 twiddle.apply_to_rows(hi_chunk, lo_chunk);
569 });
570}
571
572#[inline]
584fn dft_layer_par_double<F: Field, B: Butterfly<F>, M: MultiLayerButterfly<F, B>>(
585 mat: &mut RowMajorMatrixViewMut<'_, F>,
586 twiddles_small: &[B],
587 twiddles_large: &[B],
588 multi_butterfly: M,
589) {
590 debug_assert!(
591 mat.height().is_multiple_of(twiddles_small.len()),
592 "Matrix height must be divisible by the number of twiddles"
593 );
594 let size = mat.values.len();
595 let num_blocks = twiddles_small.len();
596
597 let outer_block_size = size / num_blocks;
598 let quarter_outer_block_size = outer_block_size / 4;
599
600 let inner_chunk_size =
603 (workload_size::<F>().next_power_of_two() / 4).min(quarter_outer_block_size);
604
605 mat.values
606 .par_chunks_exact_mut(outer_block_size)
607 .enumerate()
608 .for_each(|(ind, block)| {
609 let chunk_par_iters_0 = block
612 .chunks_exact_mut(quarter_outer_block_size)
613 .map(|chunk| chunk.par_chunks_mut(inner_chunk_size))
614 .collect::<Vec<_>>();
615 let chunk_par_iters_1 = zip_par_iter_vec(chunk_par_iters_0);
616 chunk_par_iters_1.into_iter().tuples().for_each(|(hi, lo)| {
617 hi.zip(lo).for_each(|chunks| {
618 multi_butterfly.apply_2_layers(chunks, ind, twiddles_small, twiddles_large);
619 });
620 });
621 });
622}
623
624#[inline]
637fn dft_layer_par_triple<F: Field, B: Butterfly<F>, M: MultiLayerButterfly<F, B>>(
638 mat: &mut RowMajorMatrixViewMut<'_, F>,
639 twiddles_small: &[B],
640 twiddles_med: &[B],
641 twiddles_large: &[B],
642 multi_butterfly: M,
643) {
644 debug_assert!(
645 mat.height().is_multiple_of(twiddles_small.len()),
646 "Matrix height must be divisible by the number of twiddles"
647 );
648 let size = mat.values.len();
649 let num_blocks = twiddles_small.len();
650
651 let outer_block_size = size / num_blocks;
652 let eighth_outer_block_size = outer_block_size / 8;
653
654 let inner_chunk_size =
657 (workload_size::<F>().next_power_of_two() / 8).min(eighth_outer_block_size);
658
659 mat.values
660 .par_chunks_exact_mut(outer_block_size)
661 .enumerate()
662 .for_each(|(ind, block)| {
663 let chunk_par_iters_0 = block
666 .chunks_exact_mut(eighth_outer_block_size)
667 .map(|chunk| chunk.par_chunks_mut(inner_chunk_size))
668 .collect::<Vec<_>>();
669 let chunk_par_iters_1 = zip_par_iter_vec(chunk_par_iters_0);
670 let chunk_par_iters_2 = zip_par_iter_vec(chunk_par_iters_1);
671 chunk_par_iters_2.into_iter().tuples().for_each(|(hi, lo)| {
672 hi.zip(lo).for_each(|chunks| {
673 multi_butterfly.apply_3_layers(
674 chunks,
675 ind,
676 twiddles_small,
677 twiddles_med,
678 twiddles_large,
679 );
680 });
681 });
682 });
683}
684
685fn dft_layer_par_extra_layers<F: Field, B: Butterfly<F>, M: MultiLayerButterfly<F, B>>(
690 mat: &mut RowMajorMatrixViewMut<'_, F>,
691 root_table: &[Vec<F>],
692 multi_layer: M,
693) {
694 match root_table.len() {
695 1 => {
696 let fft_layer: &[B] = unsafe { as_base_slice(&root_table[0]) };
698 dft_layer_par(&mut mat.as_view_mut(), fft_layer);
699 }
700 2 => {
701 let fft_layer_0: &[B] = unsafe { as_base_slice(&root_table[0]) };
702 let fft_layer_1: &[B] = unsafe { as_base_slice(&root_table[1]) };
703 dft_layer_par_double(
704 &mut mat.as_view_mut(),
705 fft_layer_1,
706 fft_layer_0,
707 multi_layer,
708 );
709 }
710 0 => {}
711 _ => unreachable!("The number of layers must be 0, 1 or 2"),
712 }
713}
714
715#[inline]
738fn dft_layer_zeros<F: Field, B: Butterfly<F>>(vec: &mut [F], twiddles: &[B], skip: usize) {
739 debug_assert_eq!(
740 vec.len() % twiddles.len(),
741 0,
742 "Vector length must be divisible by the number of twiddles"
743 );
744 let size = vec.len();
745 let num_blocks = twiddles.len();
746
747 let block_size = size / num_blocks;
748 let half_block_size = block_size / 2;
749
750 vec.chunks_exact_mut(block_size)
751 .zip(twiddles)
752 .step_by(1 << skip) .for_each(|(block, &twiddle)| {
754 let (hi_chunk, lo_chunk) = block.split_at_mut(half_block_size);
756
757 twiddle.apply_to_rows(hi_chunk, lo_chunk);
759 });
760}
761
/// Four mutable slices grouped as two pairs, the shape consumed by the
/// two-layer butterflies.
///
/// `dft_layer_par_double` produces one such value from the four quarters of a
/// block, in order: `((q0, q1), (q2, q3))`.
type DoubleLayerBlockDecomposition<'a, F> =
    ((&'a mut [F], &'a mut [F]), (&'a mut [F], &'a mut [F]));
765
766#[inline]
768fn fft_double_layer_single_twiddle<F: Field, Fly: Butterfly<F>>(
769 block: &mut DoubleLayerBlockDecomposition<'_, F>,
770 butterfly: Fly,
771) {
772 butterfly.apply_to_rows(block.0.0, block.1.0);
773 butterfly.apply_to_rows(block.0.1, block.1.1);
774}
775
776#[inline]
781fn fft_double_layer_double_twiddle<F: Field, Fly0: Butterfly<F>, Fly1: Butterfly<F>>(
782 block: &mut DoubleLayerBlockDecomposition<'_, F>,
783 fly0: Fly0,
784 fly1: Fly1,
785) {
786 fly0.apply_to_rows(block.0.0, block.0.1);
787 fly1.apply_to_rows(block.1.0, block.1.1);
788}
789
/// Eight mutable slices grouped as a pair of pairs of pairs, the shape
/// consumed by the three-layer butterflies.
///
/// `dft_layer_par_triple` produces one such value from the eight parts of a
/// block, in order: `(((p0, p1), (p2, p3)), ((p4, p5), (p6, p7)))`.
type TripleLayerBlockDecomposition<'a, F> = (
    ((&'a mut [F], &'a mut [F]), (&'a mut [F], &'a mut [F])),
    ((&'a mut [F], &'a mut [F]), (&'a mut [F], &'a mut [F])),
);
795
796#[inline]
798fn fft_triple_layer_single_twiddle<F: Field, Fly: Butterfly<F>>(
799 block: &mut TripleLayerBlockDecomposition<'_, F>,
800 butterfly: Fly,
801) {
802 butterfly.apply_to_rows(block.0.0.0, block.1.0.0);
803 butterfly.apply_to_rows(block.0.0.1, block.1.0.1);
804 butterfly.apply_to_rows(block.0.1.0, block.1.1.0);
805 butterfly.apply_to_rows(block.0.1.1, block.1.1.1);
806}
807
808#[inline]
813fn fft_triple_layer_double_twiddle<F: Field, Fly0: Butterfly<F>, Fly1: Butterfly<F>>(
814 block: &mut TripleLayerBlockDecomposition<'_, F>,
815 fly0: Fly0,
816 fly1: Fly1,
817) {
818 fly0.apply_to_rows(block.0.0.0, block.0.1.0);
819 fly0.apply_to_rows(block.0.0.1, block.0.1.1);
820 fly1.apply_to_rows(block.1.0.0, block.1.1.0);
821 fly1.apply_to_rows(block.1.0.1, block.1.1.1);
822}
823
824#[inline]
829fn fft_triple_layer_quad_twiddle<F: Field, Fly0: Butterfly<F>, Flies: Butterfly<F>>(
830 block: &mut TripleLayerBlockDecomposition<'_, F>,
831 fly0: Fly0,
832 butterflies: &[Flies],
833) {
834 debug_assert!(butterflies.len() == 3);
835 fly0.apply_to_rows(block.0.0.0, block.0.0.1);
836 butterflies[0].apply_to_rows(block.0.1.0, block.0.1.1);
837 butterflies[1].apply_to_rows(block.1.0.0, block.1.0.1);
838 butterflies[2].apply_to_rows(block.1.1.0, block.1.1.1);
839}
840
/// Number of values of type `T` that fit into a fixed 32 KiB budget.
///
/// The budget is the constant the rest of this file treats as the L1 cache
/// size. The division rounds down.
#[must_use]
const fn workload_size<T: Sized>() -> usize {
    // 32 KiB == 1 << 15 bytes.
    const L1_CACHE_BYTES: usize = 32 * 1024;
    let bytes_per_value = core::mem::size_of::<T>();
    L1_CACHE_BYTES / bytes_per_value
}
850
851#[must_use]
857fn estimate_num_rows_in_l1<T: Sized>(height: usize, width: usize) -> usize {
858 (workload_size::<T>() / width)
859 .next_power_of_two()
860 .min(height) }
862
863#[inline]
870fn zip_par_iter_vec<I: IndexedParallelIterator>(
871 in_vec: Vec<I>,
872) -> Vec<impl IndexedParallelIterator<Item = (I::Item, I::Item)>> {
873 in_vec
874 .into_iter()
875 .tuples()
876 .map(|(hi, lo)| hi.zip(lo))
877 .collect::<Vec<_>>()
878}
879
/// Strategy for applying two or three consecutive butterfly layers to one
/// decomposed piece of a block.
///
/// Implementors are zero-sized markers (`MultiLayerDitButterfly`,
/// `MultiLayerDifButterfly`) that fix the butterfly type `B` and the order in
/// which the layers are applied.
trait MultiLayerButterfly<F: Field, B: Butterfly<F>>: Copy + Send + Sync {
    /// Apply two layers to `chunk_decomposition`.
    ///
    /// `ind` is the index of the enclosing block; implementations index
    /// `twiddles_small` at `ind` and `twiddles_large` at `2 * ind` and
    /// `2 * ind + 1`.
    fn apply_2_layers(
        &self,
        chunk_decomposition: DoubleLayerBlockDecomposition<'_, F>,
        ind: usize,
        twiddles_small: &[B],
        twiddles_large: &[B],
    );

    /// Apply three layers to `chunk_decomposition`.
    ///
    /// `ind` is the index of the enclosing block; implementations index
    /// `twiddles_small` at `ind`, `twiddles_med` at `2 * ind ..= 2 * ind + 1`
    /// and `twiddles_large` at `4 * ind ..= 4 * ind + 3`.
    fn apply_3_layers(
        &self,
        chunk_decomposition: TripleLayerBlockDecomposition<'_, F>,
        ind: usize,
        twiddles_small: &[B],
        twiddles_med: &[B],
        twiddles_large: &[B],
    );
}
898
/// Marker type selecting the `DitButterfly` implementation of
/// `MultiLayerButterfly`, which applies the layer with the fewest twiddles first.
#[derive(Debug, Clone, Copy)]
struct MultiLayerDitButterfly;
901
902impl<F: Field> MultiLayerButterfly<F, DitButterfly<F>> for MultiLayerDitButterfly {
903 #[inline]
904 fn apply_2_layers(
905 &self,
906 mut blk_decomp: DoubleLayerBlockDecomposition<'_, F>,
907 ind: usize,
908 twiddles_small: &[DitButterfly<F>],
909 twiddles_large: &[DitButterfly<F>],
910 ) {
911 if ind == 0 {
912 fft_double_layer_single_twiddle(&mut blk_decomp, TwiddleFreeButterfly);
913 fft_double_layer_double_twiddle(
914 &mut blk_decomp,
915 TwiddleFreeButterfly,
916 twiddles_large[1],
917 );
918 } else {
919 fft_double_layer_single_twiddle(&mut blk_decomp, twiddles_small[ind]);
920 fft_double_layer_double_twiddle(
921 &mut blk_decomp,
922 twiddles_large[2 * ind],
923 twiddles_large[2 * ind + 1],
924 );
925 }
926 }
927
928 #[inline]
929 fn apply_3_layers(
930 &self,
931 mut blk_decomp: TripleLayerBlockDecomposition<'_, F>,
932 ind: usize,
933 twiddles_small: &[DitButterfly<F>],
934 twiddles_med: &[DitButterfly<F>],
935 twiddles_large: &[DitButterfly<F>],
936 ) {
937 if ind == 0 {
938 fft_triple_layer_single_twiddle(&mut blk_decomp, TwiddleFreeButterfly);
939 fft_triple_layer_double_twiddle(&mut blk_decomp, TwiddleFreeButterfly, twiddles_med[1]);
940 fft_triple_layer_quad_twiddle(
941 &mut blk_decomp,
942 TwiddleFreeButterfly,
943 &twiddles_large[1..4],
944 );
945 } else {
946 fft_triple_layer_single_twiddle(&mut blk_decomp, twiddles_small[ind]);
947 fft_triple_layer_double_twiddle(
948 &mut blk_decomp,
949 twiddles_med[2 * ind],
950 twiddles_med[2 * ind + 1],
951 );
952 fft_triple_layer_quad_twiddle(
953 &mut blk_decomp,
954 twiddles_large[4 * ind],
955 &twiddles_large[4 * ind + 1..4 * (ind + 1)],
956 );
957 }
958 }
959}
960
/// Marker type selecting the `DifButterfly` implementation of
/// `MultiLayerButterfly`, which applies the layer with the most twiddles first.
#[derive(Debug, Clone, Copy)]
struct MultiLayerDifButterfly;
963
964impl<F: Field> MultiLayerButterfly<F, DifButterfly<F>> for MultiLayerDifButterfly {
965 #[inline]
966 fn apply_2_layers(
967 &self,
968 mut blk_decomp: DoubleLayerBlockDecomposition<'_, F>,
969 ind: usize,
970 twiddles_small: &[DifButterfly<F>],
971 twiddles_large: &[DifButterfly<F>],
972 ) {
973 if ind == 0 {
974 fft_double_layer_double_twiddle(
975 &mut blk_decomp,
976 TwiddleFreeButterfly,
977 twiddles_large[1],
978 );
979 fft_double_layer_single_twiddle(&mut blk_decomp, TwiddleFreeButterfly);
980 } else {
981 fft_double_layer_double_twiddle(
982 &mut blk_decomp,
983 twiddles_large[2 * ind],
984 twiddles_large[2 * ind + 1],
985 );
986 fft_double_layer_single_twiddle(&mut blk_decomp, twiddles_small[ind]);
987 }
988 }
989
990 #[inline]
991 fn apply_3_layers(
992 &self,
993 mut blk_decomp: TripleLayerBlockDecomposition<'_, F>,
994 ind: usize,
995 twiddles_small: &[DifButterfly<F>],
996 twiddles_med: &[DifButterfly<F>],
997 twiddles_large: &[DifButterfly<F>],
998 ) {
999 if ind == 0 {
1000 fft_triple_layer_quad_twiddle(
1001 &mut blk_decomp,
1002 TwiddleFreeButterfly,
1003 &twiddles_large[1..4],
1004 );
1005 fft_triple_layer_double_twiddle(&mut blk_decomp, TwiddleFreeButterfly, twiddles_med[1]);
1006 fft_triple_layer_single_twiddle(&mut blk_decomp, TwiddleFreeButterfly);
1007 } else {
1008 fft_triple_layer_quad_twiddle(
1009 &mut blk_decomp,
1010 twiddles_large[4 * ind],
1011 &twiddles_large[4 * ind + 1..4 * (ind + 1)],
1012 );
1013 fft_triple_layer_double_twiddle(
1014 &mut blk_decomp,
1015 twiddles_med[2 * ind],
1016 twiddles_med[2 * ind + 1],
1017 );
1018 fft_triple_layer_single_twiddle(&mut blk_decomp, twiddles_small[ind]);
1019 }
1020 }
1021}