p3_mersenne_31/
radix_2_dit.rs1use p3_dft::TwoAdicSubgroupDft;
2use p3_field::extension::Complex;
3use p3_field::{PrimeCharacteristicRing, PrimeField64, TwoAdicField};
4use p3_matrix::Matrix;
5use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
6use p3_matrix::util::reverse_matrix_index_bits;
7use p3_util::log2_strict_usize;
8
9use crate::Mersenne31;
10
11type F = Mersenne31;
12type C = Complex<F>;
13
/// Radix-2 decimation-in-time DFT over the complex extension of Mersenne31.
///
/// The type carries no state; all work happens in its
/// [`TwoAdicSubgroupDft`] implementation.
#[derive(Clone, Debug, Default)]
pub struct Mersenne31ComplexRadix2Dit;
16
17impl TwoAdicSubgroupDft<C> for Mersenne31ComplexRadix2Dit {
18 type Evaluations = RowMajorMatrix<C>;
19 fn dft_batch(&self, mut mat: RowMajorMatrix<C>) -> RowMajorMatrix<C> {
20 let h = mat.height();
21 let log_h = log2_strict_usize(h);
22
23 let root = C::two_adic_generator(log_h);
24 let twiddles = root.powers().collect_n(h / 2);
25
26 reverse_matrix_index_bits(&mut mat);
28 for layer in 0..log_h {
29 dit_layer(&mut mat.as_view_mut(), layer, &twiddles);
30 }
31 mat
32 }
33}
34
35fn dit_layer(mat: &mut RowMajorMatrixViewMut<'_, C>, layer: usize, twiddles: &[C]) {
42 let h = mat.height();
43 let log_h = log2_strict_usize(h);
44 let layer_rev = log_h - 1 - layer;
45
46 let half_block_size = 1 << layer;
47 let block_size = half_block_size * 2;
48
49 for j in (0..h).step_by(block_size) {
50 let butterfly_hi = j;
52 let butterfly_lo = butterfly_hi + half_block_size;
53 twiddle_free_butterfly(mat, butterfly_hi, butterfly_lo);
54
55 for i in 1..half_block_size {
56 let butterfly_hi = j + i;
57 let butterfly_lo = butterfly_hi + half_block_size;
58 let twiddle = twiddles[i << layer_rev];
59 dit_butterfly(mat, butterfly_hi, butterfly_lo, twiddle);
60 }
61 }
62}
63
64#[inline]
65fn twiddle_free_butterfly(mat: &mut RowMajorMatrixViewMut<'_, C>, row_1: usize, row_2: usize) {
66 let ((shorts_1, suffix_1), (shorts_2, suffix_2)) = mat.packed_row_pair_mut(row_1, row_2);
67
68 let row_1 = shorts_1.iter_mut().chain(suffix_1);
72 let row_2 = shorts_2.iter_mut().chain(suffix_2);
73
74 for (x, y) in row_1.zip(row_2) {
75 let sum = *x + *y;
76 let diff = *x - *y;
77 *x = sum;
78 *y = diff;
79 }
80}
81
82#[inline]
83fn dit_butterfly(mat: &mut RowMajorMatrixViewMut<'_, C>, row_1: usize, row_2: usize, twiddle: C) {
84 let ((shorts_1, suffix_1), (shorts_2, suffix_2)) = mat.packed_row_pair_mut(row_1, row_2);
85
86 let row_1 = shorts_1.iter_mut().chain(suffix_1);
90 let row_2 = shorts_2.iter_mut().chain(suffix_2);
91
92 for (x, y) in row_1.zip(row_2) {
93 dit_butterfly_inner(x, y, twiddle);
94 }
95}
96
97#[inline]
111fn dit_butterfly_inner(x: &mut C, y: &mut C, twiddle: C) {
112 const P_SQR: i64 = (F::ORDER_U64 * F::ORDER_U64) as i64;
116 const TWO_P_SQR: i64 = 2 * P_SQR;
117
118 let unpack = |x: C| (x.to_array()[0].value as i64, x.to_array()[1].value as i64);
123 let (x1, x2) = unpack(*x);
124 let (y1, y2) = unpack(*y);
125 let (w1, w2) = unpack(twiddle);
126
127 let z1 = y1 * w1 - y2 * w2; let a1 = F::from_u64((P_SQR + x1 + z1) as u64);
141 let b1 = F::from_u64((P_SQR + x1 - z1) as u64);
143
144 let z2 = y2 * w1 + y1 * w2; let a2 = F::from_u64((x2 + z2) as u64);
150 let b2 = F::from_u64((TWO_P_SQR + x2 - z2) as u64);
152
153 *x = C::new_complex(a1, a2);
154 *y = C::new_complex(b1, b2);
155}