1#![allow(clippy::use_self)]
2
3extern crate alloc;
10
11use alloc::vec::Vec;
12
13use itertools::izip;
14use p3_field::{Field, PackedFieldPow2, PackedValue, PrimeCharacteristicRing, TwoAdicField};
15use p3_util::log2_strict_usize;
16
17use crate::utils::monty_reduce;
18use crate::{FieldParameters, MontyField31, TwoAdicData};
19
20impl<MP: FieldParameters + TwoAdicData> MontyField31<MP> {
21 pub fn roots_of_unity_table(n: usize) -> Vec<Vec<Self>> {
29 let lg_n = log2_strict_usize(n);
30 let generator = Self::two_adic_generator(lg_n);
31 let half_n = 1 << (lg_n - 1);
32 let nth_roots = generator.powers().collect_n(half_n);
34
35 (0..(lg_n - 1))
36 .map(|i| nth_roots.iter().step_by(1 << i).copied().collect())
37 .rev()
38 .collect()
39 }
40
41 pub fn get_missing_twiddles(req_lg_n: usize, cur_lg_n: usize) -> Vec<Vec<Self>> {
42 let main_generator = Self::two_adic_generator(req_lg_n);
44
45 (cur_lg_n..req_lg_n)
46 .map(|level| {
47 let count = 1 << level;
50
51 let sub_generator_exp = 1 << (req_lg_n - level - 1);
55 let sub_generator = main_generator.exp_u64(sub_generator_exp as u64);
56
57 sub_generator.powers().collect_n(count)
59 })
60 .collect()
61 }
62}
63
64#[inline(always)]
65fn forward_butterfly<T: PrimeCharacteristicRing + Copy>(x: T, y: T, roots: T) -> (T, T) {
66 let t = x - y;
67 (x + y, t * roots)
68}
69
70#[inline(always)]
82fn monty_forward_butterfly<MP: FieldParameters + TwoAdicData>(
83 x: <MontyField31<MP> as Field>::Packing,
84 y: <MontyField31<MP> as Field>::Packing,
85 roots: <MontyField31<MP> as Field>::Packing,
86) -> (
87 <MontyField31<MP> as Field>::Packing,
88 <MontyField31<MP> as Field>::Packing,
89) {
90 #[cfg(target_arch = "aarch64")]
91 {
92 x.forward_butterfly(y, roots)
93 }
94 #[cfg(not(target_arch = "aarch64"))]
95 {
96 forward_butterfly(x, y, roots)
97 }
98}
99
100#[inline(always)]
101fn forward_butterfly_interleaved<const HALF_RADIX: usize, T: PackedFieldPow2>(
102 x: T,
103 y: T,
104 roots: T,
105) -> (T, T) {
106 let (x, y) = x.interleave(y, HALF_RADIX);
107 let (x, y) = forward_butterfly(x, y, roots);
108 x.interleave(y, HALF_RADIX)
109}
110
111#[inline]
112fn forward_iterative_packed<const HALF_RADIX: usize, T: PackedFieldPow2>(
113 input: &mut [T],
114 roots: &[T::Scalar],
115) {
116 let roots = T::from_fn(|i| roots[i % HALF_RADIX]);
119
120 input.chunks_exact_mut(2).for_each(|pair| {
121 let (x, y) = forward_butterfly_interleaved::<HALF_RADIX, _>(pair[0], pair[1], roots);
122 pair[0] = x;
123 pair[1] = y;
124 });
125}
126
127#[inline]
128fn forward_iterative_packed_radix_2<T: PackedFieldPow2>(input: &mut [T]) {
129 input.chunks_exact_mut(2).for_each(|pair| {
130 let x = pair[0];
131 let y = pair[1];
132 let (mut x, y) = x.interleave(y, 1);
133 let t = x - y; x += y;
135 let (x, y) = x.interleave(t, 1);
136 pair[0] = x;
137 pair[1] = y;
138 });
139}
140
141impl<MP: FieldParameters + TwoAdicData> MontyField31<MP> {
142 #[inline]
156 fn forward_iterative_layer(
157 packed_input: &mut [<Self as Field>::Packing],
158 roots: &[Self],
159 m: usize,
160 ) {
161 debug_assert_eq!(roots.len(), m);
162 let packed_roots = <Self as Field>::Packing::pack_slice(roots);
163
164 let packed_m = m / <Self as Field>::Packing::WIDTH;
166 packed_input
167 .chunks_exact_mut(2 * packed_m)
168 .for_each(|layer_chunk| {
169 let (xs, ys) = unsafe { layer_chunk.split_at_mut_unchecked(packed_m) };
170
171 izip!(xs, ys, packed_roots)
172 .for_each(|(x, y, &root)| (*x, *y) = monty_forward_butterfly(*x, *y, root));
173 });
174 }
175
176 #[inline]
186 fn monty_forward_pass_packed(input: &mut [<Self as Field>::Packing], roots: &[Self]) {
187 let packed_roots = <Self as Field>::Packing::pack_slice(roots);
188 let n = input.len();
189 let (xs, ys) = unsafe { input.split_at_mut_unchecked(n / 2) };
190
191 izip!(xs, ys, packed_roots)
192 .for_each(|(x, y, &roots)| (*x, *y) = monty_forward_butterfly(*x, *y, roots));
193 }
194
195 #[inline]
215 fn monty_forward_iterative_layer_1(input: &mut [<Self as Field>::Packing], roots: &[Self]) {
216 let packed_roots = <Self as Field>::Packing::pack_slice(roots);
217 let n = input.len();
218 let (top_half, bottom_half) = unsafe { input.split_at_mut_unchecked(n / 2) };
219 let (xs, ys) = unsafe { top_half.split_at_mut_unchecked(n / 4) };
220 let (zs, ws) = unsafe { bottom_half.split_at_mut_unchecked(n / 4) };
221
222 izip!(xs, ys, zs, ws, packed_roots).for_each(|(x, y, z, w, &root)| {
223 (*x, *y) = monty_forward_butterfly(*x, *y, root);
224 (*z, *w) = monty_forward_butterfly(*z, *w, root);
225 });
226 }
227
228 #[inline]
229 fn forward_iterative_packed_radix_16(input: &mut [<Self as Field>::Packing]) {
230 if <Self as Field>::Packing::WIDTH >= 16 {
237 forward_iterative_packed::<8, _>(input, MP::ROOTS_16.as_ref());
238 } else {
239 Self::forward_iterative_layer(input, MP::ROOTS_16.as_ref(), 8);
240 }
241
242 if <Self as Field>::Packing::WIDTH >= 8 {
244 forward_iterative_packed::<4, _>(input, MP::ROOTS_8.as_ref());
245 } else {
246 Self::forward_iterative_layer(input, MP::ROOTS_8.as_ref(), 4);
247 }
248
249 let roots4 = [MP::ROOTS_8.as_ref()[0], MP::ROOTS_8.as_ref()[2]];
251 if <Self as Field>::Packing::WIDTH >= 4 {
252 forward_iterative_packed::<2, _>(input, &roots4);
253 } else {
254 Self::forward_iterative_layer(input, &roots4, 2);
255 }
256
257 forward_iterative_packed_radix_2(input);
259 }
260
261 #[inline]
263 fn forward_iterative(packed_input: &mut [<Self as Field>::Packing], root_table: &[Vec<Self>]) {
264 assert!(packed_input.len() >= 2);
265 let packing_width = <Self as Field>::Packing::WIDTH;
266 let n = packed_input.len() * packing_width;
267 let lg_n = log2_strict_usize(n);
268 debug_assert_eq!(root_table.len(), lg_n - 1);
269
270 const LAST_LOOP_LAYER: usize = 4;
274
275 const NUM_SPECIALISATIONS: usize = 2;
277
278 assert!(lg_n >= LAST_LOOP_LAYER + NUM_SPECIALISATIONS);
281
282 Self::monty_forward_pass_packed(packed_input, &root_table[lg_n - 2]); Self::monty_forward_iterative_layer_1(packed_input, &root_table[lg_n - 3]); for lg_m in (LAST_LOOP_LAYER..(lg_n - NUM_SPECIALISATIONS)).rev() {
288 let m = 1 << lg_m;
289
290 let roots = &root_table[lg_m - 1];
291 debug_assert_eq!(roots.len(), m);
292
293 Self::forward_iterative_layer(packed_input, roots, m);
294 }
295
296 Self::forward_iterative_packed_radix_16(packed_input);
298 }
299
300 #[inline(always)]
301 fn forward_butterfly(x: Self, y: Self, w: Self) -> (Self, Self) {
302 let t = MP::PRIME + x.value - y.value;
303 (
304 x + y,
305 Self::new_monty(monty_reduce::<MP>(t as u64 * w.value as u64)),
306 )
307 }
308
309 #[inline]
310 fn forward_pass(input: &mut [Self], roots: &[Self]) {
311 let half_n = input.len() / 2;
312 assert_eq!(roots.len(), half_n);
313
314 let (xs, ys) = unsafe { input.split_at_mut_unchecked(half_n) };
316
317 let s = xs[0] + ys[0];
318 let t = xs[0] - ys[0];
319 xs[0] = s;
320 ys[0] = t;
321
322 izip!(&mut xs[1..], &mut ys[1..], &roots[1..]).for_each(|(x, y, &root)| {
323 (*x, *y) = Self::forward_butterfly(*x, *y, root);
324 });
325 }
326
327 #[inline(always)]
328 fn forward_2(a: &mut [Self]) {
329 assert_eq!(a.len(), 2);
330
331 let s = a[0] + a[1];
332 let t = a[0] - a[1];
333 a[0] = s;
334 a[1] = t;
335 }
336
337 #[inline(always)]
338 fn forward_4(a: &mut [Self]) {
339 assert_eq!(a.len(), 4);
340
341 let t1 = MP::PRIME + a[1].value - a[3].value;
343 let t3 = Self::new_monty(monty_reduce::<MP>(
344 t1 as u64 * MP::ROOTS_8.as_ref()[2].value as u64,
345 ));
346 let t5 = a[1] + a[3];
347 let t4 = a[0] + a[2];
348 let t2 = a[0] - a[2];
349
350 a[0] = t4 + t5;
352 a[1] = t4 - t5;
353 a[2] = t2 + t3;
354 a[3] = t2 - t3;
355 }
356
357 #[inline(always)]
358 fn forward_8(a: &mut [Self]) {
359 assert_eq!(a.len(), 8);
360
361 Self::forward_pass(a, MP::ROOTS_8.as_ref());
362
363 let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
365 Self::forward_4(a0);
366 Self::forward_4(a1);
367 }
368
369 #[inline(always)]
370 fn forward_16(a: &mut [Self]) {
371 assert_eq!(a.len(), 16);
372
373 Self::forward_pass(a, MP::ROOTS_16.as_ref());
374
375 let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
377 Self::forward_8(a0);
378 Self::forward_8(a1);
379 }
380
381 #[inline(always)]
382 fn forward_32(a: &mut [Self], root_table: &[Vec<Self>]) {
383 assert_eq!(a.len(), 32);
384
385 Self::forward_pass(a, &root_table[root_table.len() - 1]);
386
387 let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
389 Self::forward_16(a0);
390 Self::forward_16(a1);
391 }
392
393 #[inline]
395 fn forward_fft_recur(input: &mut [<Self as Field>::Packing], root_table: &[Vec<Self>]) {
396 const ITERATIVE_FFT_THRESHOLD: usize = 1024;
397
398 let n = input.len() * <Self as Field>::Packing::WIDTH;
399 if n <= ITERATIVE_FFT_THRESHOLD {
400 Self::forward_iterative(input, root_table);
401 } else {
402 assert_eq!(n, 1 << (root_table.len() + 1));
403 Self::monty_forward_pass_packed(input, &root_table[root_table.len() - 1]);
404
405 let (a0, a1) = unsafe { input.split_at_mut_unchecked(input.len() / 2) };
407
408 Self::forward_fft_recur(a0, &root_table[..root_table.len() - 1]);
409 Self::forward_fft_recur(a1, &root_table[..root_table.len() - 1]);
410 }
411 }
412
413 #[inline]
414 pub fn forward_fft(input: &mut [Self], root_table: &[Vec<Self>]) {
415 let n = input.len();
416 if n == 1 {
417 return;
418 }
419 assert_eq!(n, 1 << (root_table.len() + 1));
420 match n {
421 32 => Self::forward_32(input, root_table),
422 16 => Self::forward_16(input),
423 8 => Self::forward_8(input),
424 4 => Self::forward_4(input),
425 2 => Self::forward_2(input),
426 _ => {
427 let packed_input = <Self as Field>::Packing::pack_slice_mut(input);
428 Self::forward_fft_recur(packed_input, root_table);
429 }
430 }
431 }
432}