p3_mersenne_31/
dft.rs

1//! Implementation of DFT for `Mersenne31`.
2//!
3//! Strategy follows: `<https://www.robinscheibler.org/2013/02/13/real-fft.html>`
4//! In short, fold a Mersenne31 DFT of length n into a Mersenne31Complex DFT
5//! of length n/2. Some pre/post-processing is necessary so that the result
6//! of the transform behaves as expected wrt the convolution theorem etc.
7//!
8//! Note that we don't return the final n/2 - 1 elements since we know that
9//! the "complex conjugate" of the (n-k)th element equals the kth element.
10//! The convolution theorem maintains this relationship and so these final
11//! n/2 - 1 elements are essentially redundant.
12
13use alloc::vec::Vec;
14
15use itertools::{Itertools, izip};
16use p3_dft::TwoAdicSubgroupDft;
17use p3_field::extension::Complex;
18use p3_field::{Field, PrimeCharacteristicRing, TwoAdicField};
19use p3_matrix::Matrix;
20use p3_matrix::dense::RowMajorMatrix;
21use p3_util::log2_strict_usize;
22
23use crate::Mersenne31;
24
// Shorthand for the base field that this module's DFT operates on.
type F = Mersenne31;
// Shorthand for the complex extension of the base field (referred to as
// `Mersenne31Complex` in the documentation below).
type C = Complex<Mersenne31>;
27
28/// Given an hxw matrix M = (m_{ij}) where h is even, return an
29/// (h/2)xw matrix N whose (k,l) entry is
30///
31///    Mersenne31Complex(m_{2k,l}, m_{2k+1,l})
32///
33/// i.e. the even rows become the real parts and the odd rows become
34/// the imaginary parts.
35///
36/// This packing is suitable as input to a Fourier Transform over the
37/// domain Mersenne31Complex; it is inverse to `idft_postprocess()`
38/// below.
39fn dft_preprocess(input: &RowMajorMatrix<F>) -> RowMajorMatrix<C> {
40    assert!(
41        input.height().is_multiple_of(2),
42        "input height must be even"
43    );
44    RowMajorMatrix::new(
45        input
46            .rows()
47            .tuples()
48            .flat_map(|(row_0, row_1)| {
49                // For each pair of rows in input, convert each
50                // two-element column into a Mersenne31Complex
51                // treating the first row as the real part and the
52                // second row as the imaginary part.
53                row_0.zip(row_1).map(|(x, y)| C::new_complex(x, y))
54            })
55            .collect(),
56        input.width(),
57    )
58}
59
60/// Transform the result of applying the DFT to the packed
61/// `Mersenne31` values so that the convolution theorem holds.
62///
63/// Source: https://www.robinscheibler.org/2013/02/13/real-fft.html
64///
65/// NB: This function and `idft_preprocess()` are inverses.
66fn dft_postprocess(input: &RowMajorMatrix<C>) -> RowMajorMatrix<C> {
67    let h = input.height();
68    let log2_h = log2_strict_usize(h); // checks that h is a power of two
69
70    // NB: The original real matrix had height 2h, hence log2(2h) = log2(h) + 1.
71    // omega is a 2h-th root of unity
72    let omega = C::two_adic_generator(log2_h + 1);
73    let mut omega_j = omega;
74
75    let mut output = Vec::with_capacity((h + 1) * input.width());
76    output.extend(
77        input
78            .first_row()
79            .unwrap() // The matrix is non-empty so this unwrap should never panic.
80            .into_iter()
81            .map(|x| C::new_real(x.real() + x.imag())),
82    );
83
84    for j in 1..h {
85        let row_iter = unsafe {
86            // Safety: We know that 0 < j < h = input.height()
87            izip!(input.row_unchecked(j), input.row_unchecked(h - j))
88        };
89        let row = row_iter.map(|(x, y)| {
90            let even = x + y.conjugate();
91            // odd = (x - y.conjugate()) * -i
92            let odd = C::new_complex(x.imag() + y.imag(), y.real() - x.real());
93            (even + odd * omega_j).halve()
94        });
95        output.extend(row);
96        omega_j *= omega;
97    }
98
99    output.extend(
100        input
101            .first_row()
102            .unwrap() // The matrix is non-empty so this unwrap should never panic.
103            .into_iter()
104            .map(|x| C::new_real(x.real() - x.imag())),
105    );
106    debug_assert_eq!(output.len(), (h + 1) * input.width());
107    RowMajorMatrix::new(output, input.width())
108}
109
110/// Undo the transform of the DFT matrix in `dft_postprocess()` so
111/// that the inverse DFT can be applied.
112///
113/// Source: https://www.robinscheibler.org/2013/02/13/real-fft.html
114///
115/// NB: This function and `dft_postprocess()` are inverses.
116fn idft_preprocess(input: &RowMajorMatrix<C>) -> RowMajorMatrix<C> {
117    let h = input.height() - 1;
118    let log2_h = log2_strict_usize(h); // checks that h is a power of two
119
120    // NB: The original real matrix had length 2h, hence log2(2h) = log2(h) + 1.
121    // omega is a 2n-th root of unity
122    let omega = C::two_adic_generator(log2_h + 1).inverse();
123    let mut omega_j = C::ONE;
124
125    let mut output = Vec::with_capacity(h * input.width());
126    // TODO: Specialise j = 0 and j = n (which we know must be real)?
127    for j in 0..h {
128        let row_iter = unsafe {
129            // Safety: We know that 0 = j < h < input.height()
130            izip!(input.row_unchecked(j), input.row_unchecked(h - j))
131        };
132        let row = row_iter.map(|(x, y)| {
133            let even = x + y.conjugate();
134            // odd = (x - y.conjugate()) * -i
135            let odd = C::new_complex(x.imag() + y.imag(), y.real() - x.real());
136            (even - odd * omega_j).halve()
137        });
138        output.extend(row);
139        omega_j *= omega;
140    }
141    RowMajorMatrix::new(output, input.width())
142}
143
144/// Given an (h/2)xw matrix M = (m_{kl}) = (a_{kl} + I*b_{kl}) (where
145/// I is the imaginary unit), return the hxw matrix N whose (i,j)
146/// entry is a_{i/2,j} if i is even and b_{(i-1)/2,j} if i is odd.
147///
148/// This function is inverse to `dft_preprocess()` above.
149fn idft_postprocess(input: &RowMajorMatrix<C>) -> RowMajorMatrix<F> {
150    // Allocate necessary `Vec`s upfront:
151    //   1) The actual output,
152    //   2) A temporary buf to store the imaginary parts.
153    //      This buf is filled and flushed per row
154    //      throughout postprocessing to save on allocations.
155    let mut output = Vec::with_capacity(input.width() * input.height() * 2);
156    let mut buf = Vec::with_capacity(input.width());
157
158    // Convert each row of input into two rows, the first row
159    // having the real parts of the input, the second row
160    // having the imaginary parts.
161    for row in input.rows() {
162        for ext in row {
163            output.push(ext.real());
164            buf.push(ext.imag());
165        }
166        output.append(&mut buf);
167    }
168
169    RowMajorMatrix::new(output, input.width())
170}
171
172/// The DFT for Mersenne31
173#[derive(Debug, Default, Clone)]
174pub struct Mersenne31Dft;
175
176impl Mersenne31Dft {
177    /// Compute the DFT of each column of `mat`.
178    ///
179    /// NB: The DFT works by packing pairs of `Mersenne31` values into
180    /// a `Mersenne31Complex` and doing a (half-length) DFT on the
181    /// result. In particular, the type of the result elements are in
182    /// the extension field, not the domain field.
183    pub fn dft_batch<Dft: TwoAdicSubgroupDft<C>>(mat: &RowMajorMatrix<F>) -> RowMajorMatrix<C> {
184        let dft = Dft::default();
185        dft_postprocess(&dft.dft_batch(dft_preprocess(mat)).to_row_major_matrix())
186    }
187
188    /// Compute the inverse DFT of each column of `mat`.
189    ///
190    /// NB: See comment on `dft_batch()` for information on packing.
191    pub fn idft_batch<Dft: TwoAdicSubgroupDft<C>>(mat: &RowMajorMatrix<C>) -> RowMajorMatrix<F> {
192        let dft = Dft::default();
193        idft_postprocess(&dft.idft_batch(idft_preprocess(mat)))
194    }
195}
196
#[cfg(test)]
mod tests {
    use rand::distr::{Distribution, StandardUniform};
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};

    use super::*;
    use crate::Mersenne31ComplexRadix2Dit;

    type Base = Mersenne31;
    type Dft = Mersenne31ComplexRadix2Dit;

    /// A forward transform followed by an inverse transform must give
    /// back the original column.
    #[test]
    fn consistency()
    where
        StandardUniform: Distribution<Base>,
    {
        const N: usize = 1 << 12;
        let samples: Vec<Base> = SmallRng::seed_from_u64(1)
            .sample_iter(StandardUniform)
            .take(N)
            .collect();
        let original = RowMajorMatrix::new_col(samples);

        let spectrum = Mersenne31Dft::dft_batch::<Dft>(&original);
        let round_trip = Mersenne31Dft::idft_batch::<Dft>(&spectrum);

        assert_eq!(original, round_trip);
    }

    /// The pointwise product of two spectra must transform back to the
    /// cyclic convolution of the two original columns.
    #[test]
    fn convolution()
    where
        StandardUniform: Distribution<Base>,
    {
        const N: usize = 1 << 6;
        // Draw both operands from a single seeded stream: the first N
        // samples form `a`, the next N form `b`.
        let samples: Vec<Base> = SmallRng::seed_from_u64(1)
            .sample_iter(StandardUniform)
            .take(2 * N)
            .collect();
        let (lhs, rhs) = samples.split_at(N);
        let a = RowMajorMatrix::new_col(lhs.to_vec());
        let b = RowMajorMatrix::new_col(rhs.to_vec());

        let fft_a = Mersenne31Dft::dft_batch::<Dft>(&a);
        let fft_b = Mersenne31Dft::dft_batch::<Dft>(&b);

        // Multiply the two spectra entry by entry.
        let pointwise: Vec<C> = fft_a
            .values
            .iter()
            .zip(fft_b.values.iter())
            .map(|(&x, &y)| x * y)
            .collect();
        let fft_c = RowMajorMatrix::new_col(pointwise);

        let c = Mersenne31Dft::idft_batch::<Dft>(&fft_c);

        // Reference result: the cyclic convolution computed directly from
        // its definition.
        let expected: Vec<Base> = (0..N)
            .map(|i| {
                (0..N).fold(Base::ZERO, |acc, j| {
                    acc + a.values[j] * b.values[(N + i - j) % N]
                })
            })
            .collect();

        assert_eq!(c.values, expected);
    }
}