p3_monty_31/dft/
backward.rs

1#![allow(clippy::use_self)]
2
3//! Discrete Fourier Transform, in-place, decimation-in-time
4//!
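//! Straightforward recursive algorithm. Sizes up to 32 are "unrolled" into scalar kernels,
//! sizes up to 1024 use a breadth-first loop over packed field elements, and larger sizes
//! recurse on their two halves.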
6//!
7//! Inspired by Bernstein's djbfft: https://cr.yp.to/djbfft.html
8
9extern crate alloc;
10use alloc::vec::Vec;
11
12use itertools::izip;
13use p3_field::{Field, PackedFieldPow2, PackedValue, PrimeCharacteristicRing};
14use p3_util::log2_strict_usize;
15
16use crate::utils::monty_reduce;
17use crate::{FieldParameters, MontyField31, TwoAdicData};
18
19#[inline(always)]
20fn backward_butterfly<T: PrimeCharacteristicRing + Copy>(x: T, y: T, roots: T) -> (T, T) {
21    let t = y * roots;
22    (x + t, x - t)
23}
24
25#[inline(always)]
26fn backward_butterfly_interleaved<const HALF_RADIX: usize, T: PackedFieldPow2>(
27    x: T,
28    y: T,
29    roots: T,
30) -> (T, T) {
31    let (x, y) = x.interleave(y, HALF_RADIX);
32    let (x, y) = backward_butterfly(x, y, roots);
33    x.interleave(y, HALF_RADIX)
34}
35
36#[inline]
37fn backward_pass_packed<T: PackedFieldPow2>(input: &mut [T], roots: &[T::Scalar]) {
38    let packed_roots = T::pack_slice(roots);
39    let n = input.len();
40    let (xs, ys) = unsafe { input.split_at_mut_unchecked(n / 2) };
41
42    izip!(xs, ys, packed_roots)
43        .for_each(|(x, y, &roots)| (*x, *y) = backward_butterfly(*x, *y, roots));
44}
45
46#[inline]
47fn backward_iterative_layer_1<T: PackedFieldPow2>(input: &mut [T], roots: &[T::Scalar]) {
48    let packed_roots = T::pack_slice(roots);
49    let n = input.len();
50    let (top_half, bottom_half) = unsafe { input.split_at_mut_unchecked(n / 2) };
51    let (xs, ys) = unsafe { top_half.split_at_mut_unchecked(n / 4) };
52    let (zs, ws) = unsafe { bottom_half.split_at_mut_unchecked(n / 4) };
53
54    izip!(xs, ys, zs, ws, packed_roots).for_each(|(x, y, z, w, &root)| {
55        (*x, *y) = backward_butterfly(*x, *y, root);
56        (*z, *w) = backward_butterfly(*z, *w, root);
57    });
58}
59
60#[inline]
61fn backward_iterative_packed<const HALF_RADIX: usize, T: PackedFieldPow2>(
62    input: &mut [T],
63    roots: &[T::Scalar],
64) {
65    // roots[0] == 1
66    // roots <-- [1, roots[1], ..., roots[HALF_RADIX-1], 1, roots[1], ...]
67    let roots = T::from_fn(|i| roots[i % HALF_RADIX]);
68
69    input.chunks_exact_mut(2).for_each(|pair| {
70        let (x, y) = backward_butterfly_interleaved::<HALF_RADIX, _>(pair[0], pair[1], roots);
71        pair[0] = x;
72        pair[1] = y;
73    });
74}
75
76#[inline]
77fn backward_iterative_packed_radix_2<T: PackedFieldPow2>(input: &mut [T]) {
78    input.chunks_exact_mut(2).for_each(|pair| {
79        let x = pair[0];
80        let y = pair[1];
81        let (mut x, y) = x.interleave(y, 1);
82        let t = x - y; // roots[0] == 1
83        x += y;
84        let (x, y) = x.interleave(t, 1);
85        pair[0] = x;
86        pair[1] = y;
87    });
88}
89
90impl<MP: FieldParameters + TwoAdicData> MontyField31<MP> {
91    /// Breadth-first DIT FFT for smallish vectors (must be >= 64)
92    #[inline]
93    fn backward_iterative_layer(
94        packed_input: &mut [<Self as Field>::Packing],
95        roots: &[Self],
96        m: usize,
97    ) {
98        debug_assert_eq!(roots.len(), m);
99        let packed_roots = <Self as Field>::Packing::pack_slice(roots);
100
101        // lg_m >= 4, so m = 2^lg_m >= 2^4, hence packing_width divides m
102        let packed_m = m / <Self as Field>::Packing::WIDTH;
103        packed_input
104            .chunks_exact_mut(2 * packed_m)
105            .for_each(|layer_chunk| {
106                let (xs, ys) = unsafe { layer_chunk.split_at_mut_unchecked(packed_m) };
107
108                izip!(xs, ys, packed_roots)
109                    .for_each(|(x, y, &root)| (*x, *y) = backward_butterfly(*x, *y, root));
110            });
111    }
112
113    #[inline]
114    fn backward_iterative_packed_radix_16(input: &mut [<Self as Field>::Packing]) {
115        // Rather surprisingly, a version similar where the separate
116        // loops in each call to backward_iterative_packed() are
117        // combined into one, was not only not faster, but was
118        // actually a bit slower.
119
120        // Radix 2
121        backward_iterative_packed_radix_2(input);
122
123        // Radix 4
124        let roots4 = [MP::INV_ROOTS_8.as_ref()[0], MP::INV_ROOTS_8.as_ref()[2]];
125        if <Self as Field>::Packing::WIDTH >= 4 {
126            backward_iterative_packed::<2, _>(input, &roots4);
127        } else {
128            Self::backward_iterative_layer(input, &roots4, 2);
129        }
130
131        // Radix 8
132        if <Self as Field>::Packing::WIDTH >= 8 {
133            backward_iterative_packed::<4, _>(input, MP::INV_ROOTS_8.as_ref());
134        } else {
135            Self::backward_iterative_layer(input, MP::INV_ROOTS_8.as_ref(), 4);
136        }
137
138        // Radix 16
139        if <Self as Field>::Packing::WIDTH >= 16 {
140            backward_iterative_packed::<8, _>(input, MP::INV_ROOTS_16.as_ref());
141        } else {
142            Self::backward_iterative_layer(input, MP::INV_ROOTS_16.as_ref(), 8);
143        }
144    }
145
146    fn backward_iterative(packed_input: &mut [<Self as Field>::Packing], root_table: &[Vec<Self>]) {
147        assert!(packed_input.len() >= 2);
148        let packing_width = <Self as Field>::Packing::WIDTH;
149        let n = packed_input.len() * packing_width;
150        let lg_n = log2_strict_usize(n);
151        debug_assert_eq!(root_table.len(), lg_n - 1);
152
153        // Start loop after doing radix 16 separately. This value is determined by the largest
154        // packing width we will encounter, which is 16 at the moment for AVX512. Specifically
155        // it is log_2(max{possible packing widths}) = lg(16) = 4.
156        const FIRST_LOOP_LAYER: usize = 4;
157
158        // How many layers have we specialised after the main loop
159        const NUM_SPECIALISATIONS: usize = 2;
160
161        // Needed to avoid overlap of the 2 specialisations at the start
162        // with the radix-16 specialisation at the end of the loop
163        assert!(lg_n >= FIRST_LOOP_LAYER + NUM_SPECIALISATIONS);
164
165        Self::backward_iterative_packed_radix_16(packed_input);
166
167        for lg_m in FIRST_LOOP_LAYER..(lg_n - NUM_SPECIALISATIONS) {
168            let m = 1 << lg_m;
169
170            let roots = &root_table[lg_m - 1];
171            debug_assert_eq!(roots.len(), m);
172
173            Self::backward_iterative_layer(packed_input, roots, m);
174        }
175        // Specialise the last few iterations; improves performance a little.
176        backward_iterative_layer_1(packed_input, &root_table[lg_n - 3]); // lg_m == lg_n - 2, s == 1
177        backward_pass_packed(packed_input, &root_table[lg_n - 2]); // lg_m == lg_n - 1, s == 0
178    }
179
180    #[inline]
181    fn backward_pass(input: &mut [Self], roots: &[Self]) {
182        let half_n = input.len() / 2;
183        assert_eq!(roots.len(), half_n);
184
185        // Safe because 0 <= half_n < input.len()
186        let (xs, ys) = unsafe { input.split_at_mut_unchecked(half_n) };
187
188        let s = xs[0] + ys[0];
189        let t = xs[0] - ys[0];
190        xs[0] = s;
191        ys[0] = t;
192
193        izip!(&mut xs[1..], &mut ys[1..], &roots[1..]).for_each(|(x, y, &root)| {
194            (*x, *y) = backward_butterfly(*x, *y, root);
195        });
196    }
197
198    #[inline(always)]
199    fn backward_2(a: &mut [Self]) {
200        assert_eq!(a.len(), 2);
201
202        let s = a[0] + a[1];
203        let t = a[0] - a[1];
204        a[0] = s;
205        a[1] = t;
206    }
207
208    #[inline(always)]
209    fn backward_4(a: &mut [Self]) {
210        assert_eq!(a.len(), 4);
211
212        // Read in bit-reversed order
213        let a0 = a[0];
214        let a2 = a[1];
215        let a1 = a[2];
216        let a3 = a[3];
217
218        // Expanding the calculation of t3 saves one instruction
219        let t1 = MP::PRIME + a1.value - a3.value;
220        let t3 = Self::new_monty(monty_reduce::<MP>(
221            t1 as u64 * MP::INV_ROOTS_8.as_ref()[2].value as u64,
222        ));
223        let t5 = a1 + a3;
224        let t4 = a0 + a2;
225        let t2 = a0 - a2;
226
227        a[0] = t4 + t5;
228        a[1] = t2 + t3;
229        a[2] = t4 - t5;
230        a[3] = t2 - t3;
231    }
232
233    #[inline(always)]
234    fn backward_8(a: &mut [Self]) {
235        assert_eq!(a.len(), 8);
236
237        // Safe because a.len() == 8
238        let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
239        Self::backward_4(a0);
240        Self::backward_4(a1);
241
242        Self::backward_pass(a, MP::INV_ROOTS_8.as_ref());
243    }
244
245    #[inline(always)]
246    fn backward_16(a: &mut [Self]) {
247        assert_eq!(a.len(), 16);
248
249        // Safe because a.len() == 16
250        let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
251        Self::backward_8(a0);
252        Self::backward_8(a1);
253
254        Self::backward_pass(a, MP::INV_ROOTS_16.as_ref());
255    }
256
257    #[inline(always)]
258    fn backward_32(a: &mut [Self], root_table: &[Vec<Self>]) {
259        assert_eq!(a.len(), 32);
260
261        // Safe because a.len() == 32
262        let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
263        Self::backward_16(a0);
264        Self::backward_16(a1);
265
266        Self::backward_pass(a, &root_table[root_table.len() - 1]);
267    }
268
269    /// Assumes `input.len() >= 64`.
270    /// current packing widths.
271    #[inline]
272    fn backward_fft_recur(input: &mut [<Self as Field>::Packing], root_table: &[Vec<Self>]) {
273        const ITERATIVE_FFT_THRESHOLD: usize = 1024;
274
275        let n = input.len() * <Self as Field>::Packing::WIDTH;
276        if n <= ITERATIVE_FFT_THRESHOLD {
277            Self::backward_iterative(input, root_table);
278        } else {
279            assert_eq!(n, 1 << (root_table.len() + 1));
280
281            // Safe because input.len() > ITERATIVE_FFT_THRESHOLD
282            let (a0, a1) = unsafe { input.split_at_mut_unchecked(input.len() / 2) };
283            Self::backward_fft_recur(a0, &root_table[..root_table.len() - 1]);
284            Self::backward_fft_recur(a1, &root_table[..root_table.len() - 1]);
285
286            backward_pass_packed(input, &root_table[root_table.len() - 1]);
287        }
288    }
289
290    #[inline]
291    pub fn backward_fft(input: &mut [Self], root_table: &[Vec<Self>]) {
292        let n = input.len();
293        if n == 1 {
294            return;
295        }
296
297        assert_eq!(n, 1 << (root_table.len() + 1));
298        match n {
299            32 => Self::backward_32(input, root_table),
300            16 => Self::backward_16(input),
301            8 => Self::backward_8(input),
302            4 => Self::backward_4(input),
303            2 => Self::backward_2(input),
304            _ => {
305                let packed_input = <Self as Field>::Packing::pack_slice_mut(input);
306                Self::backward_fft_recur(packed_input, root_table);
307            }
308        }
309    }
310}