1use crate::domain::{
5 radix2::{fft, EvaluationDomain, Radix2EvaluationDomain},
6 utils::compute_powers_serial,
7 DomainCoeff,
8};
9use ark_ff::FftField;
10use ark_std::{cfg_chunks_mut, vec::*};
11#[cfg(feature = "parallel")]
12use rayon::prelude::*;
13
/// The ordering of a transform's input and output: `I` = in (natural) order,
/// `O` = out of (bit-reversed) order. The first letter describes the input,
/// the second the output.
#[derive(PartialEq, Eq, Debug)]
enum FFTOrder {
    /// Both input and output are in natural order; a bit-reversal
    /// permutation (`derange`) is applied before/after the butterfly passes.
    II,
    /// Input in natural order, output in bit-reversed order.
    IO,
    /// Input in bit-reversed order, output in natural order.
    OI,
}
25
impl<F: FftField> Radix2EvaluationDomain<F> {
    /// In-place FFT of `coeffs` over this domain that exploits `coeffs`
    /// possibly being much shorter than the domain size `n`: the first
    /// butterfly layers of an out-of-order-input FFT are replaced by a cheap
    /// duplication step over the zero-padded, bit-reversed input.
    ///
    /// Panics (via `expect`) if `coeffs` is longer than the domain.
    pub(crate) fn degree_aware_fft_in_place<T: DomainCoeff<F>>(&self, coeffs: &mut Vec<T>) {
        // Coset domain: scale coeffs[i] by offset^i before transforming.
        if !self.offset.is_one() {
            Self::distribute_powers(&mut *coeffs, self.offset);
        }
        let n = self.size();
        let log_n = self.log_size_of_group;
        // Round the input length up to a power of two.
        let num_coeffs = if coeffs.len().is_power_of_two() {
            coeffs.len()
        } else {
            coeffs.len().checked_next_power_of_two().unwrap()
        };
        let log_d = ark_std::log2(num_coeffs);
        // Each of the num_coeffs "initial" values will be duplicated across
        // n / num_coeffs consecutive slots below.
        let duplicity_of_initials = 1 << log_n.checked_sub(log_d).expect("domain is too small");

        // Zero-pad up to the full domain size.
        coeffs.resize(n, T::zero());

        // Bit-reverse (w.r.t. the full log_n bits) the leading num_coeffs
        // entries; each pair is swapped exactly once via the i < ri check.
        for i in 0..num_coeffs as u64 {
            let ri = fft::bitrev(i, log_n);
            if i < ri {
                coeffs.swap(i as usize, ri as usize);
            }
        }

        // Duplicate each chunk's first entry across the whole chunk: on a
        // bit-reversed zero-padded input, the skipped initial butterfly
        // layers reduce to exactly this copy.
        if duplicity_of_initials > 1 {
            ark_std::cfg_chunks_mut!(coeffs, duplicity_of_initials).for_each(|chunk| {
                let v = chunk[0];
                chunk[1..].fill(v);
            });
        }

        // Run the remaining out-of-order-input passes, starting at the gap
        // the skipped layers would have reached instead of 1.
        let start_gap = duplicity_of_initials;
        self.oi_helper(&mut *coeffs, self.group_gen, start_gap);
    }

    /// FFT with both input and output in natural order. For coset domains
    /// the input is first scaled by powers of `offset`.
    #[allow(unused)]
    pub(crate) fn in_order_fft_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T]) {
        if !self.offset.is_one() {
            Self::distribute_powers(x_s, self.offset);
        }
        self.fft_helper_in_place(x_s, FFTOrder::II);
    }

    /// Inverse FFT with both input and output in natural order. Every entry
    /// is scaled by `size_inv` afterwards; for coset domains the scaling by
    /// powers of `offset_inv` is fused into the same pass.
    pub(crate) fn in_order_ifft_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T]) {
        self.ifft_helper_in_place(x_s, FFTOrder::II);
        if self.offset.is_one() {
            ark_std::cfg_iter_mut!(x_s).for_each(|val| *val *= self.size_inv);
        } else {
            Self::distribute_powers_and_mul_by_const(x_s, self.offset_inv, self.size_inv);
        }
    }

    /// Dispatches a forward transform according to `ord`. `io_helper`
    /// produces bit-reversed output from natural input, so a fully in-order
    /// (II) transform is IO followed by a bit-reversal permutation.
    fn fft_helper_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T], ord: FFTOrder) {
        use FFTOrder::*;

        let log_len = ark_std::log2(x_s.len());

        if ord == OI {
            self.oi_helper(x_s, self.group_gen, 1);
        } else {
            self.io_helper(x_s, self.group_gen);
        }

        if ord == II {
            derange(x_s, log_len);
        }
    }

    /// Dispatches an inverse transform (using `group_gen_inv`) according to
    /// `ord`. Mirror of `fft_helper_in_place`: for II the bit-reversal
    /// permutation happens *before* the out-of-order-input passes.
    fn ifft_helper_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T], ord: FFTOrder) {
        use FFTOrder::*;

        let log_len = ark_std::log2(x_s.len());

        if ord == II {
            derange(x_s, log_len);
        }

        if ord == IO {
            self.io_helper(x_s, self.group_gen_inv);
        } else {
            self.oi_helper(x_s, self.group_gen_inv, 1);
        }
    }

    /// Returns the first `size / 2` powers of `root`
    /// (`[1, root, root^2, ...]`) — the twiddle factors for the butterflies.
    #[cfg(not(feature = "parallel"))]
    pub(super) fn roots_of_unity(&self, root: F) -> Vec<F> {
        compute_powers_serial((self.size as usize) / 2, root)
    }

    /// Parallel variant of `roots_of_unity`: for small domains it falls back
    /// to the serial computation; otherwise it builds the table of
    /// repeated squares `root^(2^i)` and expands it recursively in parallel.
    #[cfg(feature = "parallel")]
    pub(super) fn roots_of_unity(&self, root: F) -> Vec<F> {
        let log_size = ark_std::log2(self.size as usize);
        // Below this threshold the parallel recursion is not worth its overhead.
        if log_size <= LOG_ROOTS_OF_UNITY_PARALLEL_SIZE {
            compute_powers_serial((self.size as usize) / 2, root)
        } else {
            // log_powers[i] = root^(2^i), for i in 0..log_size-1.
            let mut temp = root;
            let log_powers: Vec<F> = (0..(log_size - 1))
                .map(|_| {
                    let old_value = temp;
                    temp.square_in_place();
                    old_value
                })
                .collect();

            // Expand the log-sized table into all size/2 powers.
            let mut powers = vec![F::zero(); 1 << (log_size - 1)];
            Self::roots_of_unity_recursive(&mut powers, &log_powers);
            powers
        }
    }

    /// Fills `out` (of length `2^log_powers.len()`) with all products of
    /// subsets of `log_powers`, i.e. the consecutive powers of the base
    /// root: the problem is split in half, both halves are computed with
    /// `rayon::join`, and the results are combined with an outer product.
    #[cfg(feature = "parallel")]
    fn roots_of_unity_recursive(out: &mut [F], log_powers: &[F]) {
        assert_eq!(out.len(), 1 << log_powers.len());
        // Base case: small enough to expand serially by repeated
        // multiplication with the lowest power.
        if log_powers.len() <= LOG_ROOTS_OF_UNITY_PARALLEL_SIZE as usize {
            out[0] = F::one();
            for idx in 1..out.len() {
                out[idx] = out[idx - 1] * log_powers[0];
            }
            return;
        }

        // Split the log-powers; the low half gets the (rounded-up) larger part.
        let (lr_lo, lr_hi) = log_powers.split_at((1 + log_powers.len()) / 2);
        let mut scr_lo = vec![F::default(); 1 << lr_lo.len()];
        let mut scr_hi = vec![F::default(); 1 << lr_hi.len()];
        // Compute both halves concurrently.
        rayon::join(
            || Self::roots_of_unity_recursive(&mut scr_lo, lr_lo),
            || Self::roots_of_unity_recursive(&mut scr_hi, lr_hi),
        );
        // Combine: out[j * |scr_lo| + k] = scr_hi[j] * scr_lo[k].
        out.par_chunks_mut(scr_lo.len())
            .zip(&scr_hi)
            .for_each(|(out_chunk, scr_hi)| {
                for (out_elem, scr_lo) in out_chunk.iter_mut().zip(&scr_lo) {
                    *out_elem = *scr_hi * scr_lo;
                }
            });
    }

    /// Butterfly used by the in-order-input passes:
    /// `(lo, hi) <- (lo + hi, (lo - hi) * root)`.
    #[inline(always)]
    fn butterfly_fn_io<T: DomainCoeff<F>>(((lo, hi), root): ((&mut T, &mut T), &F)) {
        let mut neg = *lo;
        neg -= *hi;

        *lo += *hi;

        *hi = neg;
        *hi *= *root;
    }

    /// Butterfly used by the out-of-order-input passes:
    /// `(lo, hi) <- (lo + root * hi, lo - root * hi)`.
    #[inline(always)]
    fn butterfly_fn_oi<T: DomainCoeff<F>>(((lo, hi), root): ((&mut T, &mut T), &F)) {
        *hi *= *root;

        let mut neg = *lo;
        neg -= *hi;

        *lo += *hi;

        *hi = neg;
    }

    /// Applies the butterfly `g` to every chunk of `xi`: each `chunk_size`
    /// chunk is split at `gap` into a `lo` and a `hi` half, and `g` is run
    /// on each `(lo[i], hi[i], roots[i * step])` triple. Parallelism is
    /// used only when the input and the per-chunk work are large enough.
    #[allow(clippy::too_many_arguments)]
    fn apply_butterfly<T: DomainCoeff<F>, G: Fn(((&mut T, &mut T), &F)) + Copy + Sync + Send>(
        g: G,
        xi: &mut [T],
        roots: &[F],
        step: usize,
        chunk_size: usize,
        num_chunks: usize,
        max_threads: usize,
        gap: usize,
    ) {
        // Small inputs: always serial, avoiding rayon's scheduling overhead.
        if xi.len() <= MIN_INPUT_SIZE_FOR_PARALLELIZATION {
            xi.chunks_mut(chunk_size).for_each(|cxi| {
                let (lo, hi) = cxi.split_at_mut(gap);
                lo.iter_mut()
                    .zip(hi)
                    .zip(roots.iter().step_by(step))
                    .for_each(g);
            });
        } else {
            cfg_chunks_mut!(xi, chunk_size).for_each(|cxi| {
                let (lo, hi) = cxi.split_at_mut(gap);
                // Parallelize *within* a chunk only when the chunk is big and
                // there are too few chunks to saturate the thread pool.
                if gap > MIN_GAP_SIZE_FOR_PARALLELIZATION && num_chunks < max_threads {
                    cfg_iter_mut!(lo)
                        .zip(hi)
                        .zip(cfg_iter!(roots).step_by(step))
                        .for_each(g);
                } else {
                    lo.iter_mut()
                        .zip(hi)
                        .zip(roots.iter().step_by(step))
                        .for_each(g);
                }
            });
        }
    }

    /// In-order-input butterfly passes: the gap between butterfly partners
    /// starts at `xi.len() / 2` and halves each round. Output is in
    /// bit-reversed order.
    fn io_helper<T: DomainCoeff<F>>(&self, xi: &mut [T], root: F) {
        let mut roots = self.roots_of_unity(root);
        let mut step = 1;
        let mut first = true;

        #[cfg(feature = "parallel")]
        let max_threads = rayon::current_num_threads();
        #[cfg(not(feature = "parallel"))]
        let max_threads = 1;

        let mut gap = xi.len() / 2;
        while gap > 0 {
            // Each chunk is one lo half plus one hi half.
            let chunk_size = 2 * gap;
            let num_chunks = xi.len() / chunk_size;

            // With many chunks the strided root accesses dominate, so thin
            // the roots table in place (keeping every (step*2)-th root) and
            // access it contiguously with step = 1. On the first round the
            // full table is already the right set, so no thinning is needed.
            if num_chunks >= MIN_NUM_CHUNKS_FOR_COMPACTION {
                if !first {
                    roots = cfg_into_iter!(roots).step_by(step * 2).collect();
                }
                step = 1;
                roots.shrink_to_fit();
            } else {
                step = num_chunks;
            }
            first = false;

            Self::apply_butterfly(
                Self::butterfly_fn_io,
                xi,
                &roots,
                step,
                chunk_size,
                num_chunks,
                max_threads,
                gap,
            );

            gap /= 2;
        }
    }

    /// Out-of-order-input butterfly passes: the gap starts at `start_gap`
    /// (1 for a full transform; larger when `degree_aware_fft_in_place` has
    /// already accounted for the initial layers) and doubles each round.
    /// Expects bit-reversed input and produces in-order output.
    fn oi_helper<T: DomainCoeff<F>>(&self, xi: &mut [T], root: F, start_gap: usize) {
        let roots_cache = self.roots_of_unity(root);

        // Scratch buffer for compacted roots; compaction only happens while
        // gap < len/2, so gap never exceeds this capacity.
        let compaction_max_size = core::cmp::min(
            roots_cache.len() / 2,
            roots_cache.len() / MIN_NUM_CHUNKS_FOR_COMPACTION,
        );
        let mut compacted_roots = vec![F::default(); compaction_max_size];

        #[cfg(feature = "parallel")]
        let max_threads = rayon::current_num_threads();
        #[cfg(not(feature = "parallel"))]
        let max_threads = 1;

        let mut gap = start_gap;
        while gap < xi.len() {
            // Each chunk is one lo half plus one hi half.
            let chunk_size = 2 * gap;
            let num_chunks = xi.len() / chunk_size;

            // With many chunks, copy the strided roots into the contiguous
            // scratch buffer (step = 1); otherwise stride through the full
            // cache with step = num_chunks. The cache itself is never
            // mutated, unlike io_helper's in-place thinning.
            let (roots, step) = if num_chunks >= MIN_NUM_CHUNKS_FOR_COMPACTION && gap < xi.len() / 2
            {
                cfg_iter!(roots_cache)
                    .step_by(num_chunks)
                    .zip(&mut compacted_roots[..gap])
                    .for_each(|(b, a)| *a = *b);

                (&compacted_roots[..gap], 1)
            } else {
                (&roots_cache[..], num_chunks)
            };

            Self::apply_butterfly(
                Self::butterfly_fn_oi,
                xi,
                roots,
                step,
                chunk_size,
                num_chunks,
                max_threads,
                gap,
            );

            gap *= 2;
        }
    }
}
355
/// Minimum number of butterfly chunks at which `io_helper`/`oi_helper`
/// compact the roots-of-unity table (so it is read contiguously with
/// step 1) instead of striding through the full cache.
const MIN_NUM_CHUNKS_FOR_COMPACTION: usize = 1 << 7;

/// Minimum butterfly `gap` at which `apply_butterfly` parallelizes the
/// lo/hi zip *within* a chunk (used together with `num_chunks < max_threads`).
const MIN_GAP_SIZE_FOR_PARALLELIZATION: usize = 1 << 10;

/// Minimum total input length at which `apply_butterfly` uses the parallel
/// chunk iterator at all; smaller inputs always run serially.
const MIN_INPUT_SIZE_FOR_PARALLELIZATION: usize = 1 << 10;

/// Log2 of the domain size at or below which the parallel `roots_of_unity`
/// falls back to the serial power computation.
#[cfg(feature = "parallel")]
const LOG_ROOTS_OF_UNITY_PARALLEL_SIZE: u32 = 7;
371
/// Reverses the low `log_len` bits of `a`.
///
/// Uses `wrapping_shr` so that `log_len == 0` (a length-1 domain) does not
/// panic on the shift by 64: the shift amount wraps modulo 64 to 0.
#[inline]
fn bitrev(a: u64, log_len: u32) -> u64 {
    a.reverse_bits().wrapping_shr(64 - log_len)
}

/// In-place bit-reversal permutation: swaps the element at each index with
/// the element at its bit-reversed index, converting between natural and
/// bit-reversed orderings.
fn derange<T>(xi: &mut [T], log_len: u32) {
    // Slices of length < 2 are already trivially permuted. This guard also
    // prevents `xi.len() as u64 - 1` from underflowing on an empty slice
    // (a panic in debug builds; a near-u64::MAX loop with out-of-bounds
    // swaps in release builds).
    if xi.len() < 2 {
        return;
    }
    // Index 0 and index len-1 (all low bits set) are fixed points of the
    // bit-reversal, so the loop skips them.
    for idx in 1..(xi.len() as u64 - 1) {
        let ridx = bitrev(idx, log_len);
        // Swap each pair exactly once.
        if idx < ridx {
            xi.swap(idx as usize, ridx as usize);
        }
    }
}