1#![allow(clippy::use_self)]
2
3extern crate alloc;
10use alloc::vec::Vec;
11
12use itertools::izip;
13use p3_field::{Field, PackedFieldPow2, PackedValue, PrimeCharacteristicRing};
14use p3_util::log2_strict_usize;
15
16use crate::utils::monty_reduce;
17use crate::{FieldParameters, MontyField31, TwoAdicData};
18
19#[inline(always)]
20fn backward_butterfly<T: PrimeCharacteristicRing + Copy>(x: T, y: T, roots: T) -> (T, T) {
21 let t = y * roots;
22 (x + t, x - t)
23}
24
25#[inline(always)]
26fn backward_butterfly_interleaved<const HALF_RADIX: usize, T: PackedFieldPow2>(
27 x: T,
28 y: T,
29 roots: T,
30) -> (T, T) {
31 let (x, y) = x.interleave(y, HALF_RADIX);
32 let (x, y) = backward_butterfly(x, y, roots);
33 x.interleave(y, HALF_RADIX)
34}
35
36#[inline]
37fn backward_pass_packed<T: PackedFieldPow2>(input: &mut [T], roots: &[T::Scalar]) {
38 let packed_roots = T::pack_slice(roots);
39 let n = input.len();
40 let (xs, ys) = unsafe { input.split_at_mut_unchecked(n / 2) };
41
42 izip!(xs, ys, packed_roots)
43 .for_each(|(x, y, &roots)| (*x, *y) = backward_butterfly(*x, *y, roots));
44}
45
46#[inline]
47fn backward_iterative_layer_1<T: PackedFieldPow2>(input: &mut [T], roots: &[T::Scalar]) {
48 let packed_roots = T::pack_slice(roots);
49 let n = input.len();
50 let (top_half, bottom_half) = unsafe { input.split_at_mut_unchecked(n / 2) };
51 let (xs, ys) = unsafe { top_half.split_at_mut_unchecked(n / 4) };
52 let (zs, ws) = unsafe { bottom_half.split_at_mut_unchecked(n / 4) };
53
54 izip!(xs, ys, zs, ws, packed_roots).for_each(|(x, y, z, w, &root)| {
55 (*x, *y) = backward_butterfly(*x, *y, root);
56 (*z, *w) = backward_butterfly(*z, *w, root);
57 });
58}
59
60#[inline]
61fn backward_iterative_packed<const HALF_RADIX: usize, T: PackedFieldPow2>(
62 input: &mut [T],
63 roots: &[T::Scalar],
64) {
65 let roots = T::from_fn(|i| roots[i % HALF_RADIX]);
68
69 input.chunks_exact_mut(2).for_each(|pair| {
70 let (x, y) = backward_butterfly_interleaved::<HALF_RADIX, _>(pair[0], pair[1], roots);
71 pair[0] = x;
72 pair[1] = y;
73 });
74}
75
76#[inline]
77fn backward_iterative_packed_radix_2<T: PackedFieldPow2>(input: &mut [T]) {
78 input.chunks_exact_mut(2).for_each(|pair| {
79 let x = pair[0];
80 let y = pair[1];
81 let (mut x, y) = x.interleave(y, 1);
82 let t = x - y; x += y;
84 let (x, y) = x.interleave(t, 1);
85 pair[0] = x;
86 pair[1] = y;
87 });
88}
89
impl<MP: FieldParameters + TwoAdicData> MontyField31<MP> {
    /// Apply one backward butterfly layer of half-width `m` across the whole input.
    ///
    /// `roots` holds the `m` inverse twiddle factors for this layer. Each
    /// contiguous chunk of `2 * m` scalars (i.e. `2 * packed_m` packed vectors)
    /// is split in half and combined pairwise with `backward_butterfly`.
    #[inline]
    fn backward_iterative_layer(
        packed_input: &mut [<Self as Field>::Packing],
        roots: &[Self],
        m: usize,
    ) {
        debug_assert_eq!(roots.len(), m);
        let packed_roots = <Self as Field>::Packing::pack_slice(roots);

        // Number of packed vectors making up each half-chunk of m scalars.
        // NOTE(review): assumes m is a multiple of the packing width here —
        // callers pass m >= WIDTH in this branch; confirm at call sites.
        let packed_m = m / <Self as Field>::Packing::WIDTH;
        packed_input
            .chunks_exact_mut(2 * packed_m)
            .for_each(|layer_chunk| {
                // SAFETY: each chunk produced by chunks_exact_mut has exactly
                // 2 * packed_m elements, so splitting at packed_m is in bounds.
                let (xs, ys) = unsafe { layer_chunk.split_at_mut_unchecked(packed_m) };

                izip!(xs, ys, packed_roots)
                    .for_each(|(x, y, &root)| (*x, *y) = backward_butterfly(*x, *y, root));
            });
    }

    /// Run the first four backward layers (an overall radix-16 step), using the
    /// packed interleaving specialisations whenever the packing width permits
    /// and falling back to the generic layer routine otherwise.
    #[inline]
    fn backward_iterative_packed_radix_16(input: &mut [<Self as Field>::Packing]) {
        // Layer 1 (radix-2): all twiddles are 1, handled without multiplies.
        backward_iterative_packed_radix_2(input);

        // Layer 2 (radix-4): the inverse 4th roots are every other inverse 8th
        // root (indices 0 and 2 of INV_ROOTS_8).
        let roots4 = [MP::INV_ROOTS_8.as_ref()[0], MP::INV_ROOTS_8.as_ref()[2]];
        if <Self as Field>::Packing::WIDTH >= 4 {
            backward_iterative_packed::<2, _>(input, &roots4);
        } else {
            Self::backward_iterative_layer(input, &roots4, 2);
        }

        // Layer 3 (radix-8).
        if <Self as Field>::Packing::WIDTH >= 8 {
            backward_iterative_packed::<4, _>(input, MP::INV_ROOTS_8.as_ref());
        } else {
            Self::backward_iterative_layer(input, MP::INV_ROOTS_8.as_ref(), 4);
        }

        // Layer 4 (radix-16).
        if <Self as Field>::Packing::WIDTH >= 16 {
            backward_iterative_packed::<8, _>(input, MP::INV_ROOTS_16.as_ref());
        } else {
            Self::backward_iterative_layer(input, MP::INV_ROOTS_16.as_ref(), 8);
        }
    }

    /// Full iterative backward FFT over `packed_input`, layer by layer.
    ///
    /// `root_table[i]` holds the 2^(i+1) inverse twiddles for layer i+1, so the
    /// table has `lg_n - 1` entries (checked below). The first four layers and
    /// the last two layers use specialised routines; the middle layers go
    /// through `backward_iterative_layer`.
    fn backward_iterative(packed_input: &mut [<Self as Field>::Packing], root_table: &[Vec<Self>]) {
        assert!(packed_input.len() >= 2);
        let packing_width = <Self as Field>::Packing::WIDTH;
        let n = packed_input.len() * packing_width;
        let lg_n = log2_strict_usize(n);
        debug_assert_eq!(root_table.len(), lg_n - 1);

        // The first FIRST_LOOP_LAYER layers are fused into the radix-16 step.
        const FIRST_LOOP_LAYER: usize = 4;

        // The last NUM_SPECIALISATIONS layers have dedicated routines below.
        const NUM_SPECIALISATIONS: usize = 2;

        // Required so the specialised head and tail layers do not overlap.
        assert!(lg_n >= FIRST_LOOP_LAYER + NUM_SPECIALISATIONS);

        Self::backward_iterative_packed_radix_16(packed_input);

        for lg_m in FIRST_LOOP_LAYER..(lg_n - NUM_SPECIALISATIONS) {
            let m = 1 << lg_m;

            let roots = &root_table[lg_m - 1];
            debug_assert_eq!(roots.len(), m);

            Self::backward_iterative_layer(packed_input, roots, m);
        }
        // Tail specialisations: the second-to-last layer (two half-transforms
        // sharing twiddles) and the final full-width cross pass.
        backward_iterative_layer_1(packed_input, &root_table[lg_n - 3]);
        backward_pass_packed(packed_input, &root_table[lg_n - 2]);
    }

    /// Scalar final cross pass: butterfly the lower half of `input` against the
    /// upper half using `roots` (one twiddle per pair).
    #[inline]
    fn backward_pass(input: &mut [Self], roots: &[Self]) {
        let half_n = input.len() / 2;
        assert_eq!(roots.len(), half_n);

        // SAFETY: half_n = input.len() / 2 <= input.len(), so the split is in bounds.
        let (xs, ys) = unsafe { input.split_at_mut_unchecked(half_n) };

        // Index 0 is handled as a plain add/sub, skipping the multiply —
        // presumably roots[0] == 1 (the loop below starts at index 1); confirm
        // against the root-table construction.
        let s = xs[0] + ys[0];
        let t = xs[0] - ys[0];
        xs[0] = s;
        ys[0] = t;

        izip!(&mut xs[1..], &mut ys[1..], &roots[1..]).for_each(|(x, y, &root)| {
            (*x, *y) = backward_butterfly(*x, *y, root);
        });
    }

    /// Length-2 backward transform: (a0, a1) -> (a0 + a1, a0 - a1).
    #[inline(always)]
    fn backward_2(a: &mut [Self]) {
        assert_eq!(a.len(), 2);

        let s = a[0] + a[1];
        let t = a[0] - a[1];
        a[0] = s;
        a[1] = t;
    }

    /// Length-4 backward transform, fully unrolled.
    #[inline(always)]
    fn backward_4(a: &mut [Self]) {
        assert_eq!(a.len(), 4);

        // Note the deliberately swapped loads: a[1] is bound to `a2` and a[2]
        // to `a1` — this reorders the inputs to match the output ordering of
        // the corresponding forward transform (TODO confirm against forward_4).
        let a0 = a[0];
        let a2 = a[1];
        let a1 = a[2];
        let a3 = a[3];

        // Compute (a1 - a3) without u32 underflow by first adding the prime;
        // assumes field values are < PRIME so the sum fits in u32 — confirm.
        let t1 = MP::PRIME + a1.value - a3.value;
        // Multiply the difference by INV_ROOTS_8[2] (the inverse 4th root) via
        // a single Montgomery reduction of the u64 product.
        let t3 = Self::new_monty(monty_reduce::<MP>(
            t1 as u64 * MP::INV_ROOTS_8.as_ref()[2].value as u64,
        ));
        let t5 = a1 + a3;
        let t4 = a0 + a2;
        let t2 = a0 - a2;

        a[0] = t4 + t5;
        a[1] = t2 + t3;
        a[2] = t4 - t5;
        a[3] = t2 - t3;
    }

    /// Length-8 backward transform: two length-4 transforms plus a cross pass.
    #[inline(always)]
    fn backward_8(a: &mut [Self]) {
        assert_eq!(a.len(), 8);

        // SAFETY: a.len() / 2 = 4 <= a.len(), so the split point is in bounds.
        let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
        Self::backward_4(a0);
        Self::backward_4(a1);

        Self::backward_pass(a, MP::INV_ROOTS_8.as_ref());
    }

    /// Length-16 backward transform: two length-8 transforms plus a cross pass.
    #[inline(always)]
    fn backward_16(a: &mut [Self]) {
        assert_eq!(a.len(), 16);

        // SAFETY: a.len() / 2 = 8 <= a.len(), so the split point is in bounds.
        let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
        Self::backward_8(a0);
        Self::backward_8(a1);

        Self::backward_pass(a, MP::INV_ROOTS_16.as_ref());
    }

    /// Length-32 backward transform: two length-16 transforms plus a cross pass
    /// using the last (widest) entry of `root_table`.
    #[inline(always)]
    fn backward_32(a: &mut [Self], root_table: &[Vec<Self>]) {
        assert_eq!(a.len(), 32);

        // SAFETY: a.len() / 2 = 16 <= a.len(), so the split point is in bounds.
        let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
        Self::backward_16(a0);
        Self::backward_16(a1);

        Self::backward_pass(a, &root_table[root_table.len() - 1]);
    }

    /// Recursive backward FFT over packed input: split in half, recurse on each
    /// half with a shortened root table, then do the combining cross pass.
    /// Falls back to the iterative kernel below the size threshold.
    #[inline]
    fn backward_fft_recur(input: &mut [<Self as Field>::Packing], root_table: &[Vec<Self>]) {
        // Below this many scalars the iterative form is used; presumably
        // chosen so the working set fits in cache — tuning not shown here.
        const ITERATIVE_FFT_THRESHOLD: usize = 1024;

        let n = input.len() * <Self as Field>::Packing::WIDTH;
        if n <= ITERATIVE_FFT_THRESHOLD {
            Self::backward_iterative(input, root_table);
        } else {
            // The root table must cover exactly lg(n) - 1 layers.
            assert_eq!(n, 1 << (root_table.len() + 1));

            // SAFETY: input.len() / 2 <= input.len(), so the split is in bounds.
            let (a0, a1) = unsafe { input.split_at_mut_unchecked(input.len() / 2) };
            Self::backward_fft_recur(a0, &root_table[..root_table.len() - 1]);
            Self::backward_fft_recur(a1, &root_table[..root_table.len() - 1]);

            backward_pass_packed(input, &root_table[root_table.len() - 1]);
        }
    }

    /// Public entry point for the backward (inverse, unscaled) FFT.
    ///
    /// `input.len()` must be a power of two matching `root_table`
    /// (`n == 2^(root_table.len() + 1)`), except `n == 1` which is a no-op.
    /// Small sizes dispatch to the unrolled scalar kernels; larger sizes pack
    /// the slice and use the recursive routine.
    #[inline]
    pub fn backward_fft(input: &mut [Self], root_table: &[Vec<Self>]) {
        let n = input.len();
        if n == 1 {
            return;
        }

        assert_eq!(n, 1 << (root_table.len() + 1));
        match n {
            32 => Self::backward_32(input, root_table),
            16 => Self::backward_16(input),
            8 => Self::backward_8(input),
            4 => Self::backward_4(input),
            2 => Self::backward_2(input),
            _ => {
                let packed_input = <Self as Field>::Packing::pack_slice_mut(input);
                Self::backward_fft_recur(packed_input, root_table);
            }
        }
    }
}