1use alloc::vec::Vec;
2
3use p3_field::{Field, PrimeCharacteristicRing, TwoAdicField};
4use p3_matrix::Matrix;
5use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
6use p3_matrix::util::reverse_matrix_index_bits;
7use p3_maybe_rayon::prelude::*;
8use p3_util::{flatten_to_base, log2_strict_usize, reverse_slice_index_bits};
9use tracing::instrument;
10
11use crate::butterflies::{Butterfly, DifButterfly, DitButterfly, TwiddleFreeButterfly};
12use crate::util::divide_by_height;
13use crate::{Layout, TwoAdicSubgroupDft};
14
/// Radix-2 DFT over two-adic subgroups, implemented with the Bowers G network and
/// its transpose G^T.
///
/// The type carries no state; all work is done by its [`TwoAdicSubgroupDft`] impl.
#[derive(Default, Clone, Debug)]
pub struct Radix2Bowers;
19
20impl<F: TwoAdicField> TwoAdicSubgroupDft<F> for Radix2Bowers {
21 type Evaluations = RowMajorMatrix<F>;
22
23 fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
24 reverse_matrix_index_bits(&mut mat);
25 bowers_g(&mut mat.as_view_mut());
26 mat
27 }
28
29 fn idft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
31 bowers_g_t(&mut mat.as_view_mut());
32 divide_by_height(&mut mat);
33 reverse_matrix_index_bits(&mut mat);
34 mat
35 }
36
37 fn lde_batch(&self, mut mat: RowMajorMatrix<F>, added_bits: usize) -> RowMajorMatrix<F> {
38 bowers_g_t(&mut mat.as_view_mut());
39 divide_by_height(&mut mat);
40 mat = mat.bit_reversed_zero_pad(added_bits);
41 bowers_g(&mut mat.as_view_mut());
42 mat
43 }
44
45 #[instrument(skip_all, level = "debug", fields(dims = %mat.dimensions(), added_bits))]
46 fn coset_lde_batch(
47 &self,
48 mut mat: RowMajorMatrix<F>,
49 added_bits: usize,
50 shift: F,
51 ) -> RowMajorMatrix<F> {
52 let h = mat.height();
53 let log_h = log2_strict_usize(h);
54 let h_inv_subfield = F::PrimeSubfield::ONE.div_2exp_u64(log_h as u64);
57 let h_inv = F::from_prime_subfield(h_inv_subfield);
58
59 bowers_g_t(&mut mat.as_view_mut());
60
61 let mut weights = shift.shifted_powers(h_inv).collect_n(h);
65 reverse_slice_index_bits(&mut weights);
67
68 mat.par_rows_mut()
69 .zip(weights.into_par_iter())
70 .for_each(|(row, weight)| row.iter_mut().for_each(|elem| *elem *= weight));
71
72 mat = mat.bit_reversed_zero_pad(added_bits);
73
74 bowers_g(&mut mat.as_view_mut());
75
76 mat
77 }
78
79 #[instrument(skip_all, level = "debug", fields(dims = %mat.dimensions(), added_bits))]
80 fn coset_lde_batch_with_transform<T>(
81 &self,
82 mut mat: RowMajorMatrix<F>,
83 added_bits: usize,
84 shift: F,
85 transform: T,
86 ) -> RowMajorMatrix<F>
87 where
88 T: FnOnce(&mut RowMajorMatrixViewMut<'_, F>, Layout),
89 {
90 let h = mat.height();
91 let log_h = log2_strict_usize(h);
92 let h_inv_subfield = F::PrimeSubfield::ONE.div_2exp_u64(log_h as u64);
93 let h_inv = F::from_prime_subfield(h_inv_subfield);
94
95 bowers_g_t(&mut mat.as_view_mut());
97
98 mat.values.par_iter_mut().for_each(|v| *v *= h_inv);
102
103 transform(&mut mat.as_view_mut(), Layout::BitReversed);
104
105 let mut shift_powers = shift.powers().collect_n(h);
107 reverse_slice_index_bits(&mut shift_powers);
108 mat.par_rows_mut()
109 .zip(shift_powers.into_par_iter())
110 .for_each(|(row, sp)| row.iter_mut().for_each(|e| *e *= sp));
111
112 mat = mat.bit_reversed_zero_pad(added_bits);
113 bowers_g(&mut mat.as_view_mut());
114 mat
115 }
116}
117
118fn bowers_g<F: TwoAdicField>(mat: &mut RowMajorMatrixViewMut<'_, F>) {
121 let h = mat.height();
122 let log_h = log2_strict_usize(h);
123
124 let root = F::two_adic_generator(log_h);
125 let mut twiddles = root.powers().collect_n(h / 2);
126 reverse_slice_index_bits(&mut twiddles);
127 let twiddles: Vec<DifButterfly<F>> = unsafe { flatten_to_base(twiddles) };
129
130 for log_half_block_size in 0..log_h {
131 butterfly_layer(mat, 1 << log_half_block_size, &twiddles);
132 }
133}
134
135fn bowers_g_t<F: TwoAdicField>(mat: &mut RowMajorMatrixViewMut<'_, F>) {
138 let h = mat.height();
139 let log_h = log2_strict_usize(h);
140
141 let root_inv = F::two_adic_generator(log_h).inverse();
142 let mut twiddles = root_inv.powers().collect_n(h / 2);
143 reverse_slice_index_bits(&mut twiddles);
144 let twiddles: Vec<DitButterfly<F>> = unsafe { flatten_to_base(twiddles) };
146
147 for log_half_block_size in (0..log_h).rev() {
148 butterfly_layer(mat, 1 << log_half_block_size, &twiddles);
149 }
150}
151
152fn butterfly_layer<F: Field, B: Butterfly<F>>(
153 mat: &mut RowMajorMatrixViewMut<'_, F>,
154 half_block_size: usize,
155 twiddles: &[B],
156) {
157 mat.par_row_chunks_exact_mut(2 * half_block_size)
158 .enumerate()
159 .for_each(|(block, mut chunks)| {
160 let (mut hi_chunks, mut lo_chunks) = chunks.split_rows_mut(half_block_size);
161 hi_chunks
162 .par_rows_mut()
163 .zip(lo_chunks.par_rows_mut())
164 .for_each(|(hi_chunk, lo_chunk)| {
165 if block == 0 {
166 TwiddleFreeButterfly.apply_to_rows(hi_chunk, lo_chunk);
167 } else {
168 twiddles[block].apply_to_rows(hi_chunk, lo_chunk);
169 }
170 });
171 });
172}