Skip to main content

p3_monty_31/dft/
forward.rs

1#![allow(clippy::use_self)]
2
3//! Discrete Fourier Transform, in-place, decimation-in-frequency
4//!
5//! Straightforward recursive algorithm, "unrolled" up to size 256.
6//!
7//! Inspired by Bernstein's djbfft: https://cr.yp.to/djbfft.html
8
9extern crate alloc;
10
11use alloc::vec::Vec;
12
13use itertools::izip;
14use p3_field::{Field, PackedFieldPow2, PackedValue, PrimeCharacteristicRing, TwoAdicField};
15use p3_util::log2_strict_usize;
16
17use crate::utils::monty_reduce;
18use crate::{FieldParameters, MontyField31, TwoAdicData};
19
20impl<MP: FieldParameters + TwoAdicData> MontyField31<MP> {
21    /// Given a field element `gen` of order n where `n = 2^lg_n`,
22    /// return a vector of vectors `table` where table[i] is the
23    /// vector of twiddle factors for an fft of length n/2^i. The
24    /// values g_i^k for k >= i/2 are skipped as these are just the
25    /// negatives of the other roots (using g_i^{i/2} = -1).  The
26    /// value gen^0 = 1 is included to aid consistency between the
27    /// packed and non-packed variants.
28    pub fn roots_of_unity_table(n: usize) -> Vec<Vec<Self>> {
29        let lg_n = log2_strict_usize(n);
30        let generator = Self::two_adic_generator(lg_n);
31        let half_n = 1 << (lg_n - 1);
32        // nth_roots = [1, g, g^2, g^3, ..., g^{n/2 - 1}]
33        let nth_roots = generator.powers().collect_n(half_n);
34
35        (0..(lg_n - 1))
36            .map(|i| nth_roots.iter().step_by(1 << i).copied().collect())
37            .rev()
38            .collect()
39    }
40
41    pub fn get_missing_twiddles(req_lg_n: usize, cur_lg_n: usize) -> Vec<Vec<Self>> {
42        // Get the main generator for the largest required FFT size.
43        let main_generator = Self::two_adic_generator(req_lg_n);
44
45        (cur_lg_n..req_lg_n)
46            .map(|level| {
47                // For a given 'level', we're generating twiddles for a DIF pass
48                // where the number of butterflies is m = 2^level.
49                let count = 1 << level;
50
51                // The generator for this smaller FFT size is a power of the main generator.
52                //
53                // The exponent is 2^(req_lg_n - (level + 1)).
54                let sub_generator_exp = 1 << (req_lg_n - level - 1);
55                let sub_generator = main_generator.exp_u64(sub_generator_exp as u64);
56
57                // Now, we can collect the 'count' powers of this specific sub-generator.
58                sub_generator.powers().collect_n(count)
59            })
60            .collect()
61    }
62}
63
64#[inline(always)]
65fn forward_butterfly<T: PrimeCharacteristicRing + Copy>(x: T, y: T, roots: T) -> (T, T) {
66    let t = x - y;
67    (x + y, t * roots)
68}
69
70/// Architecture-dispatched DIF butterfly for packed `MontyField31` vectors.
71///
72/// The DIF butterfly computes `(x + y, (x - y) · ω)` where `ω` is a twiddle factor.
73///
74/// On aarch64, this delegates to `PackedMontyField31Neon::forward_butterfly`,
75/// which fuses the subtraction and multiplication to skip the modular reduction
76/// on `x - y`. See that method's documentation for the full rationale.
77///
78/// On other architectures, this falls back to the generic `forward_butterfly`.
79///
80/// TODO: apply the same fused sub+mul optimization for AVX2/AVX-512 backends.
#[inline(always)]
fn monty_forward_butterfly<MP: FieldParameters + TwoAdicData>(
    x: <MontyField31<MP> as Field>::Packing,
    y: <MontyField31<MP> as Field>::Packing,
    roots: <MontyField31<MP> as Field>::Packing,
) -> (
    <MontyField31<MP> as Field>::Packing,
    <MontyField31<MP> as Field>::Packing,
) {
    #[cfg(target_arch = "aarch64")]
    {
        // NEON backend: fused sub+mul method on the packed type skips the
        // intermediate modular reduction of `x - y`.
        x.forward_butterfly(y, roots)
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        // Generic fallback: (x + y, (x - y) * roots).
        forward_butterfly(x, y, roots)
    }
}
99
100#[inline(always)]
101fn forward_butterfly_interleaved<const HALF_RADIX: usize, T: PackedFieldPow2>(
102    x: T,
103    y: T,
104    roots: T,
105) -> (T, T) {
106    let (x, y) = x.interleave(y, HALF_RADIX);
107    let (x, y) = forward_butterfly(x, y, roots);
108    x.interleave(y, HALF_RADIX)
109}
110
111#[inline]
112fn forward_iterative_packed<const HALF_RADIX: usize, T: PackedFieldPow2>(
113    input: &mut [T],
114    roots: &[T::Scalar],
115) {
116    // roots[0] == 1
117    // roots <-- [1, roots[1], ..., roots[HALF_RADIX-1], 1, roots[1], ...]
118    let roots = T::from_fn(|i| roots[i % HALF_RADIX]);
119
120    input.chunks_exact_mut(2).for_each(|pair| {
121        let (x, y) = forward_butterfly_interleaved::<HALF_RADIX, _>(pair[0], pair[1], roots);
122        pair[0] = x;
123        pair[1] = y;
124    });
125}
126
127#[inline]
128fn forward_iterative_packed_radix_2<T: PackedFieldPow2>(input: &mut [T]) {
129    input.chunks_exact_mut(2).for_each(|pair| {
130        let x = pair[0];
131        let y = pair[1];
132        let (mut x, y) = x.interleave(y, 1);
133        let t = x - y; // roots[0] == 1
134        x += y;
135        let (x, y) = x.interleave(t, 1);
136        pair[0] = x;
137        pair[1] = y;
138    });
139}
140
141impl<MP: FieldParameters + TwoAdicData> MontyField31<MP> {
142    /// Apply one DIF layer of butterflies across the packed input.
143    ///
144    /// At FFT layer `lg_m`, the input is viewed as groups of `2m` elements.
145    /// Each group is split into a top half (`xs`) and bottom half (`ys`),
146    /// and we apply the DIF butterfly pairwise:
147    ///
148    /// ```text
149    ///     xs[i]  ──┬──(+)──->  xs[i]         (= x + y)
150    ///              │
151    ///     ys[i]  ──┴──(−)──->  ys[i] · ω[i]  (= (x − y) · ω)
152    /// ```
153    ///
154    /// Uses `monty_forward_butterfly` for the fused sub+mul optimization.
155    #[inline]
156    fn forward_iterative_layer(
157        packed_input: &mut [<Self as Field>::Packing],
158        roots: &[Self],
159        m: usize,
160    ) {
161        debug_assert_eq!(roots.len(), m);
162        let packed_roots = <Self as Field>::Packing::pack_slice(roots);
163
164        // lg_m >= 4, so m = 2^lg_m >= 2^4, hence packing_width divides m
165        let packed_m = m / <Self as Field>::Packing::WIDTH;
166        packed_input
167            .chunks_exact_mut(2 * packed_m)
168            .for_each(|layer_chunk| {
169                let (xs, ys) = unsafe { layer_chunk.split_at_mut_unchecked(packed_m) };
170
171                izip!(xs, ys, packed_roots)
172                    .for_each(|(x, y, &root)| (*x, *y) = monty_forward_butterfly(*x, *y, root));
173            });
174    }
175
176    /// First DIF pass: split the entire array in half and butterfly.
177    ///
178    /// This is a specialization of `forward_iterative_layer` for the very
179    /// first layer (`lg_m = lg_n - 1`), where `m = n/2`. The array is split
180    /// into exactly two halves and each pair of elements is butterflied with
181    /// the corresponding twiddle factor.
182    ///
183    /// Specializing this avoids the `chunks_exact_mut` overhead for the
184    /// common case of a single split.
185    #[inline]
186    fn monty_forward_pass_packed(input: &mut [<Self as Field>::Packing], roots: &[Self]) {
187        let packed_roots = <Self as Field>::Packing::pack_slice(roots);
188        let n = input.len();
189        let (xs, ys) = unsafe { input.split_at_mut_unchecked(n / 2) };
190
191        izip!(xs, ys, packed_roots)
192            .for_each(|(x, y, &roots)| (*x, *y) = monty_forward_butterfly(*x, *y, roots));
193    }
194
195    /// Second DIF pass: split into quarters and butterfly with shared roots.
196    ///
197    /// This is a specialization of `forward_iterative_layer` for the second
198    /// layer (`lg_m = lg_n - 2`), where `m = n/4`. The array has four
199    /// quarters, and the top-left/top-right pair shares the same twiddle
200    /// factors as the bottom-left/bottom-right pair:
201    ///
202    /// ```text
203    ///     ┌────────────────────────────┐
204    ///     │  xs  │  ys  │  zs  │  ws   │
205    ///     └────────────────────────────┘
206    ///       ↕ ω     ↕ ω    ↕ ω    ↕ ω
207    ///
208    ///     butterfly(xs[i], ys[i], ω[i])
209    ///     butterfly(zs[i], ws[i], ω[i])   ← same ω
210    /// ```
211    ///
212    /// Processing two butterfly pairs per loop iteration improves
213    /// instruction-level parallelism.
214    #[inline]
215    fn monty_forward_iterative_layer_1(input: &mut [<Self as Field>::Packing], roots: &[Self]) {
216        let packed_roots = <Self as Field>::Packing::pack_slice(roots);
217        let n = input.len();
218        let (top_half, bottom_half) = unsafe { input.split_at_mut_unchecked(n / 2) };
219        let (xs, ys) = unsafe { top_half.split_at_mut_unchecked(n / 4) };
220        let (zs, ws) = unsafe { bottom_half.split_at_mut_unchecked(n / 4) };
221
222        izip!(xs, ys, zs, ws, packed_roots).for_each(|(x, y, z, w, &root)| {
223            (*x, *y) = monty_forward_butterfly(*x, *y, root);
224            (*z, *w) = monty_forward_butterfly(*z, *w, root);
225        });
226    }
227
228    #[inline]
229    fn forward_iterative_packed_radix_16(input: &mut [<Self as Field>::Packing]) {
230        // Rather surprisingly, a version similar where the separate
231        // loops in each call to forward_iterative_packed() are
232        // combined into one, was not only not faster, but was
233        // actually a bit slower.
234
235        // Radix 16
236        if <Self as Field>::Packing::WIDTH >= 16 {
237            forward_iterative_packed::<8, _>(input, MP::ROOTS_16.as_ref());
238        } else {
239            Self::forward_iterative_layer(input, MP::ROOTS_16.as_ref(), 8);
240        }
241
242        // Radix 8
243        if <Self as Field>::Packing::WIDTH >= 8 {
244            forward_iterative_packed::<4, _>(input, MP::ROOTS_8.as_ref());
245        } else {
246            Self::forward_iterative_layer(input, MP::ROOTS_8.as_ref(), 4);
247        }
248
249        // Radix 4
250        let roots4 = [MP::ROOTS_8.as_ref()[0], MP::ROOTS_8.as_ref()[2]];
251        if <Self as Field>::Packing::WIDTH >= 4 {
252            forward_iterative_packed::<2, _>(input, &roots4);
253        } else {
254            Self::forward_iterative_layer(input, &roots4, 2);
255        }
256
257        // Radix 2
258        forward_iterative_packed_radix_2(input);
259    }
260
    /// Breadth-first DIF FFT for smallish vectors (must be >= 64)
    ///
    /// `root_table[i]` is expected to hold the `2^(i + 1)` twiddles for the
    /// layer with `lg_m == i + 1` (see the `debug_assert` in the loop below).
    #[inline]
    fn forward_iterative(packed_input: &mut [<Self as Field>::Packing], root_table: &[Vec<Self>]) {
        assert!(packed_input.len() >= 2);
        let packing_width = <Self as Field>::Packing::WIDTH;
        // n is the total number of scalar field elements.
        let n = packed_input.len() * packing_width;
        let lg_n = log2_strict_usize(n);
        debug_assert_eq!(root_table.len(), lg_n - 1);

        // Stop loop early to do radix 16 separately. This value is determined by the largest
        // packing width we will encounter, which is 16 at the moment for AVX512. Specifically
        // it is log_2(max{possible packing widths}) = lg(16) = 4.
        const LAST_LOOP_LAYER: usize = 4;

        // How many layers have we specialised before the main loop
        const NUM_SPECIALISATIONS: usize = 2;

        // Needed to avoid overlap of the 2 specialisations at the start
        // with the radix-16 specialisation at the end of the loop
        assert!(lg_n >= LAST_LOOP_LAYER + NUM_SPECIALISATIONS);

        // Specialise the first NUM_SPECIALISATIONS iterations; improves performance a little.
        Self::monty_forward_pass_packed(packed_input, &root_table[lg_n - 2]); // lg_m == lg_n - 1, s == 0
        Self::monty_forward_iterative_layer_1(packed_input, &root_table[lg_n - 3]); // lg_m == lg_n - 2, s == 1

        // Main loop: lg_m runs from lg_n - 3 down to LAST_LOOP_LAYER (= 4).
        for lg_m in (LAST_LOOP_LAYER..(lg_n - NUM_SPECIALISATIONS)).rev() {
            let m = 1 << lg_m;

            let roots = &root_table[lg_m - 1];
            debug_assert_eq!(roots.len(), m);

            Self::forward_iterative_layer(packed_input, roots, m);
        }

        // Last 4 layers
        Self::forward_iterative_packed_radix_16(packed_input);
    }
299
300    #[inline(always)]
301    fn forward_butterfly(x: Self, y: Self, w: Self) -> (Self, Self) {
302        let t = MP::PRIME + x.value - y.value;
303        (
304            x + y,
305            Self::new_monty(monty_reduce::<MP>(t as u64 * w.value as u64)),
306        )
307    }
308
309    #[inline]
310    fn forward_pass(input: &mut [Self], roots: &[Self]) {
311        let half_n = input.len() / 2;
312        assert_eq!(roots.len(), half_n);
313
314        // Safe because 0 <= half_n < a.len()
315        let (xs, ys) = unsafe { input.split_at_mut_unchecked(half_n) };
316
317        let s = xs[0] + ys[0];
318        let t = xs[0] - ys[0];
319        xs[0] = s;
320        ys[0] = t;
321
322        izip!(&mut xs[1..], &mut ys[1..], &roots[1..]).for_each(|(x, y, &root)| {
323            (*x, *y) = Self::forward_butterfly(*x, *y, root);
324        });
325    }
326
327    #[inline(always)]
328    fn forward_2(a: &mut [Self]) {
329        assert_eq!(a.len(), 2);
330
331        let s = a[0] + a[1];
332        let t = a[0] - a[1];
333        a[0] = s;
334        a[1] = t;
335    }
336
337    #[inline(always)]
338    fn forward_4(a: &mut [Self]) {
339        assert_eq!(a.len(), 4);
340
341        // Expanding the calculation of t3 saves one instruction
342        let t1 = MP::PRIME + a[1].value - a[3].value;
343        let t3 = Self::new_monty(monty_reduce::<MP>(
344            t1 as u64 * MP::ROOTS_8.as_ref()[2].value as u64,
345        ));
346        let t5 = a[1] + a[3];
347        let t4 = a[0] + a[2];
348        let t2 = a[0] - a[2];
349
350        // Return in bit-reversed order
351        a[0] = t4 + t5;
352        a[1] = t4 - t5;
353        a[2] = t2 + t3;
354        a[3] = t2 - t3;
355    }
356
357    #[inline(always)]
358    fn forward_8(a: &mut [Self]) {
359        assert_eq!(a.len(), 8);
360
361        Self::forward_pass(a, MP::ROOTS_8.as_ref());
362
363        // Safe because a.len() == 8
364        let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
365        Self::forward_4(a0);
366        Self::forward_4(a1);
367    }
368
369    #[inline(always)]
370    fn forward_16(a: &mut [Self]) {
371        assert_eq!(a.len(), 16);
372
373        Self::forward_pass(a, MP::ROOTS_16.as_ref());
374
375        // Safe because a.len() == 16
376        let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
377        Self::forward_8(a0);
378        Self::forward_8(a1);
379    }
380
381    #[inline(always)]
382    fn forward_32(a: &mut [Self], root_table: &[Vec<Self>]) {
383        assert_eq!(a.len(), 32);
384
385        Self::forward_pass(a, &root_table[root_table.len() - 1]);
386
387        // Safe because a.len() == 32
388        let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
389        Self::forward_16(a0);
390        Self::forward_16(a1);
391    }
392
393    /// Assumes `input.len() >= 64`.
394    #[inline]
395    fn forward_fft_recur(input: &mut [<Self as Field>::Packing], root_table: &[Vec<Self>]) {
396        const ITERATIVE_FFT_THRESHOLD: usize = 1024;
397
398        let n = input.len() * <Self as Field>::Packing::WIDTH;
399        if n <= ITERATIVE_FFT_THRESHOLD {
400            Self::forward_iterative(input, root_table);
401        } else {
402            assert_eq!(n, 1 << (root_table.len() + 1));
403            Self::monty_forward_pass_packed(input, &root_table[root_table.len() - 1]);
404
405            // Safe because input.len() > ITERATIVE_FFT_THRESHOLD
406            let (a0, a1) = unsafe { input.split_at_mut_unchecked(input.len() / 2) };
407
408            Self::forward_fft_recur(a0, &root_table[..root_table.len() - 1]);
409            Self::forward_fft_recur(a1, &root_table[..root_table.len() - 1]);
410        }
411    }
412
413    #[inline]
414    pub fn forward_fft(input: &mut [Self], root_table: &[Vec<Self>]) {
415        let n = input.len();
416        if n == 1 {
417            return;
418        }
419        assert_eq!(n, 1 << (root_table.len() + 1));
420        match n {
421            32 => Self::forward_32(input, root_table),
422            16 => Self::forward_16(input),
423            8 => Self::forward_8(input),
424            4 => Self::forward_4(input),
425            2 => Self::forward_2(input),
426            _ => {
427                let packed_input = <Self as Field>::Packing::pack_slice_mut(input);
428                Self::forward_fft_recur(packed_input, root_table);
429            }
430        }
431    }
432}