1use alloc::vec::Vec;
2
3use p3_field::{Field, PrimeCharacteristicRing, TwoAdicField};
4use p3_matrix::Matrix;
5use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
6use p3_matrix::util::reverse_matrix_index_bits;
7use p3_maybe_rayon::prelude::*;
8use p3_util::{flatten_to_base, log2_strict_usize, reverse_slice_index_bits};
9use tracing::instrument;
10
11use crate::TwoAdicSubgroupDft;
12use crate::butterflies::{Butterfly, DifButterfly, DitButterfly, TwiddleFreeButterfly};
13use crate::util::divide_by_height;
14
/// Zero-sized handle selecting the radix-2 Bowers FFT implementation.
///
/// The type carries no state; all work is done by the `TwoAdicSubgroupDft`
/// implementation in this file, which drives the `bowers_g` / `bowers_g_t`
/// butterfly networks.
#[derive(Default, Clone, Debug)]
pub struct Radix2Bowers;
19
20impl<F: TwoAdicField> TwoAdicSubgroupDft<F> for Radix2Bowers {
21 type Evaluations = RowMajorMatrix<F>;
22
23 fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
24 reverse_matrix_index_bits(&mut mat);
25 bowers_g(&mut mat.as_view_mut());
26 mat
27 }
28
29 fn idft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
31 bowers_g_t(&mut mat.as_view_mut());
32 divide_by_height(&mut mat);
33 reverse_matrix_index_bits(&mut mat);
34 mat
35 }
36
37 fn lde_batch(&self, mut mat: RowMajorMatrix<F>, added_bits: usize) -> RowMajorMatrix<F> {
38 bowers_g_t(&mut mat.as_view_mut());
39 divide_by_height(&mut mat);
40 mat = mat.bit_reversed_zero_pad(added_bits);
41 bowers_g(&mut mat.as_view_mut());
42 mat
43 }
44
45 #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
46 fn coset_lde_batch(
47 &self,
48 mut mat: RowMajorMatrix<F>,
49 added_bits: usize,
50 shift: F,
51 ) -> RowMajorMatrix<F> {
52 let h = mat.height();
53 let log_h = log2_strict_usize(h);
54 let h_inv_subfield = F::PrimeSubfield::ONE.div_2exp_u64(log_h as u64);
57 let h_inv = F::from_prime_subfield(h_inv_subfield);
58
59 bowers_g_t(&mut mat.as_view_mut());
60
61 let mut weights = shift.shifted_powers(h_inv).collect_n(h);
65 reverse_slice_index_bits(&mut weights);
67
68 mat.par_rows_mut()
69 .zip(weights.into_par_iter())
70 .for_each(|(row, weight)| row.iter_mut().for_each(|elem| *elem *= weight));
71
72 mat = mat.bit_reversed_zero_pad(added_bits);
73
74 bowers_g(&mut mat.as_view_mut());
75
76 mat
77 }
78}
79
80fn bowers_g<F: TwoAdicField>(mat: &mut RowMajorMatrixViewMut<'_, F>) {
83 let h = mat.height();
84 let log_h = log2_strict_usize(h);
85
86 let root = F::two_adic_generator(log_h);
87 let mut twiddles = root.powers().collect_n(h / 2);
88 reverse_slice_index_bits(&mut twiddles);
89 let twiddles: Vec<DifButterfly<F>> = unsafe { flatten_to_base(twiddles) };
91
92 for log_half_block_size in 0..log_h {
93 butterfly_layer(mat, 1 << log_half_block_size, &twiddles);
94 }
95}
96
97fn bowers_g_t<F: TwoAdicField>(mat: &mut RowMajorMatrixViewMut<'_, F>) {
100 let h = mat.height();
101 let log_h = log2_strict_usize(h);
102
103 let root_inv = F::two_adic_generator(log_h).inverse();
104 let mut twiddles = root_inv.powers().collect_n(h / 2);
105 reverse_slice_index_bits(&mut twiddles);
106 let twiddles: Vec<DitButterfly<F>> = unsafe { flatten_to_base(twiddles) };
108
109 for log_half_block_size in (0..log_h).rev() {
110 butterfly_layer(mat, 1 << log_half_block_size, &twiddles);
111 }
112}
113
114fn butterfly_layer<F: Field, B: Butterfly<F>>(
115 mat: &mut RowMajorMatrixViewMut<'_, F>,
116 half_block_size: usize,
117 twiddles: &[B],
118) {
119 mat.par_row_chunks_exact_mut(2 * half_block_size)
120 .enumerate()
121 .for_each(|(block, mut chunks)| {
122 let (mut hi_chunks, mut lo_chunks) = chunks.split_rows_mut(half_block_size);
123 hi_chunks
124 .par_rows_mut()
125 .zip(lo_chunks.par_rows_mut())
126 .for_each(|(hi_chunk, lo_chunk)| {
127 if block == 0 {
128 TwiddleFreeButterfly.apply_to_rows(hi_chunk, lo_chunk);
129 } else {
130 twiddles[block].apply_to_rows(hi_chunk, lo_chunk);
131 }
132 });
133 });
134}