p3_mersenne_31/
radix_2_dit.rs

1use p3_dft::TwoAdicSubgroupDft;
2use p3_field::extension::Complex;
3use p3_field::{PrimeCharacteristicRing, PrimeField64, TwoAdicField};
4use p3_matrix::Matrix;
5use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
6use p3_matrix::util::reverse_matrix_index_bits;
7use p3_util::log2_strict_usize;
8
9use crate::Mersenne31;
10
/// Base field used in this file: an alias for the crate's `Mersenne31` type.
type F = Mersenne31;
/// `Complex` extension over `F`; the element type of the matrices transformed below.
type C = Complex<F>;
13
/// Radix-2 decimation-in-time DFT over `Complex<Mersenne31>`.
///
/// This is a field-less marker type; the transform itself is supplied by
/// its `TwoAdicSubgroupDft` implementation.
#[derive(Debug, Default, Clone)]
pub struct Mersenne31ComplexRadix2Dit;
16
17impl TwoAdicSubgroupDft<C> for Mersenne31ComplexRadix2Dit {
18    type Evaluations = RowMajorMatrix<C>;
19    fn dft_batch(&self, mut mat: RowMajorMatrix<C>) -> RowMajorMatrix<C> {
20        let h = mat.height();
21        let log_h = log2_strict_usize(h);
22
23        let root = C::two_adic_generator(log_h);
24        let twiddles = root.powers().collect_n(h / 2);
25
26        // DIT butterfly
27        reverse_matrix_index_bits(&mut mat);
28        for layer in 0..log_h {
29            dit_layer(&mut mat.as_view_mut(), layer, &twiddles);
30        }
31        mat
32    }
33}
34
35// NB: Most of what follows is copypasta from `dft/src/radix_2_dit.rs`.
36// This is ugly, but the alternative is finding another way to "inject"
37// the specialisation of the butterfly evaluation to Mersenne31Complex
38// (in `dit_butterfly_inner()` below) into the existing structure.
39
40/// One layer of a DIT butterfly network.
41fn dit_layer(mat: &mut RowMajorMatrixViewMut<'_, C>, layer: usize, twiddles: &[C]) {
42    let h = mat.height();
43    let log_h = log2_strict_usize(h);
44    let layer_rev = log_h - 1 - layer;
45
46    let half_block_size = 1 << layer;
47    let block_size = half_block_size * 2;
48
49    for j in (0..h).step_by(block_size) {
50        // Unroll i=0 case
51        let butterfly_hi = j;
52        let butterfly_lo = butterfly_hi + half_block_size;
53        twiddle_free_butterfly(mat, butterfly_hi, butterfly_lo);
54
55        for i in 1..half_block_size {
56            let butterfly_hi = j + i;
57            let butterfly_lo = butterfly_hi + half_block_size;
58            let twiddle = twiddles[i << layer_rev];
59            dit_butterfly(mat, butterfly_hi, butterfly_lo, twiddle);
60        }
61    }
62}
63
64#[inline]
65fn twiddle_free_butterfly(mat: &mut RowMajorMatrixViewMut<'_, C>, row_1: usize, row_2: usize) {
66    let ((shorts_1, suffix_1), (shorts_2, suffix_2)) = mat.packed_row_pair_mut(row_1, row_2);
67
68    // TODO: There's no special packing for Mersenne31Complex at the
69    // time of writing; when there is we'll want to expand this out
70    // into three separate loops.
71    let row_1 = shorts_1.iter_mut().chain(suffix_1);
72    let row_2 = shorts_2.iter_mut().chain(suffix_2);
73
74    for (x, y) in row_1.zip(row_2) {
75        let sum = *x + *y;
76        let diff = *x - *y;
77        *x = sum;
78        *y = diff;
79    }
80}
81
82#[inline]
83fn dit_butterfly(mat: &mut RowMajorMatrixViewMut<'_, C>, row_1: usize, row_2: usize, twiddle: C) {
84    let ((shorts_1, suffix_1), (shorts_2, suffix_2)) = mat.packed_row_pair_mut(row_1, row_2);
85
86    // TODO: There's no special packing for Mersenne31Complex at the
87    // time of writing; when there is we'll want to expand this out
88    // into three separate loops.
89    let row_1 = shorts_1.iter_mut().chain(suffix_1);
90    let row_2 = shorts_2.iter_mut().chain(suffix_2);
91
92    for (x, y) in row_1.zip(row_2) {
93        dit_butterfly_inner(x, y, twiddle);
94    }
95}
96
97/// Given x, y, and twiddle, return the "butterfly values"
98/// x' = x + y*twiddle and y' = x - y*twiddle.
99///
100/// NB: At the time of writing, replacing the straight-forward
101/// implementation
102///
103///    let sum = *x + *y * twiddle;
104///    let diff = *x - *y * twiddle;
105///    *x = sum;
106///    *y = diff;
107///
108/// with the one below approximately halved the runtime of a DFT over
109/// `Mersenne31Complex`.
110#[inline]
111fn dit_butterfly_inner(x: &mut C, y: &mut C, twiddle: C) {
112    // Adding any multiple of P doesn't change the result modulo P;
113    // we use this to ensure that the inputs to `from_u64`
114    // below are non-negative.
115    const P_SQR: i64 = (F::ORDER_U64 * F::ORDER_U64) as i64;
116    const TWO_P_SQR: i64 = 2 * P_SQR;
117
118    // Unpack the inputs;
119    //   x = x1 + i*x2
120    //   y = y1 + i*y2
121    //   twiddle = w1 + i*w2
122    let unpack = |x: C| (x.to_array()[0].value as i64, x.to_array()[1].value as i64);
123    let (x1, x2) = unpack(*x);
124    let (y1, y2) = unpack(*y);
125    let (w1, w2) = unpack(twiddle);
126
127    // x ± y*twiddle
128    // = (x1 + i*x2) ± (y1 + i*y2)*(w1 + i*w2)
129    // = (x1 ± (y1*w1 - y2*w2)) + i*(x2 ± (y2*w1 + y1*w2))
130    // = (x1 ± z1) + i*(x2 ± z2)
131    // where z1 + i*z2 = y*twiddle
132
133    // SAFE: multiplying `u64` values within the range of `Mersennes31` doesn't overflow:
134    // (2^31 - 1) * (2^31 - 1) = 2^62 - 2^32 + 1 < 2^64 - 1
135    let z1 = y1 * w1 - y2 * w2; // -P^2 <= z1 <= P^2
136
137    // NB: 2*P^2 + P < 2^63
138
139    // -P^2 <= x1 + z1 <= P^2 + P
140    let a1 = F::from_u64((P_SQR + x1 + z1) as u64);
141    // -P^2 <= x1 - z1 <= P^2 + P
142    let b1 = F::from_u64((P_SQR + x1 - z1) as u64);
143
144    // SAFE: multiplying `u64` values within the range of `Mersennes31` doesn't overflow:
145    // 2 * (2^31 - 1) * (2^31 - 1) = 2 * (2^62 - 2^32 + 1) < 2^64 - 1
146    let z2 = y2 * w1 + y1 * w2; // 0 <= z2 <= 2*P^2
147
148    // 0 <= x2 + z2 <= 2*P^2 + P
149    let a2 = F::from_u64((x2 + z2) as u64);
150    // -2*P^2 <= x2 - z2 <= P
151    let b2 = F::from_u64((TWO_P_SQR + x2 - z2) as u64);
152
153    *x = C::new_complex(a1, a2);
154    *y = C::new_complex(b1, b2);
155}