ark_poly/domain/radix2/

// The code below is a port of Riad S. Wahby's excellent fffft library
// (https://github.com/kwantam/fffft) to the arkworks APIs.

use crate::domain::{
    radix2::{EvaluationDomain, Radix2EvaluationDomain},
    utils::compute_powers_serial,
    DomainCoeff,
};
use ark_ff::FftField;
use ark_std::{cfg_chunks_mut, cfg_into_iter, cfg_iter, cfg_iter_mut, vec::*};
#[cfg(feature = "parallel")]
use rayon::prelude::*;

#[derive(PartialEq, Eq, Debug)]
enum FFTOrder {
    /// Both the input and the output of the FFT must be in-order.
    II,
    /// The input of the FFT must be in-order, but the output does not have to
    /// be.
    IO,
    /// The input of the FFT can be out of order, but the output must be
    /// in-order.
    OI,
}
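
// For intuition (an illustrative note, not part of the upstream port): "out of
// order" here means bit-reversed index order. For a size-8 transform, the
// in-order values [x0, x1, ..., x7] appear in bit-reversed order as
// [x0, x4, x2, x6, x1, x5, x3, x7]; e.g. index 1 = 0b001 reverses to
// 0b100 = 4. The `derange` helper at the bottom of this file applies exactly
// this permutation.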

impl<F: FftField> Radix2EvaluationDomain<F> {
    /// Degree-aware FFT that runs in O(n log d) instead of O(n log n).
    /// Implementation copied from libiop.
    pub(crate) fn degree_aware_fft_in_place<T: DomainCoeff<F>>(&self, coeffs: &mut Vec<T>) {
        if !self.offset.is_one() {
            Self::distribute_powers(&mut *coeffs, self.offset);
        }
        let n = self.size();
        let log_n = self.log_size_of_group;
        let num_coeffs = if coeffs.len().is_power_of_two() {
            coeffs.len()
        } else {
            coeffs.len().checked_next_power_of_two().unwrap()
        };
        let log_d = ark_std::log2(num_coeffs);
        // When the polynomial is of size k*|coset|, for k < 2^i,
        // the first i iterations of Cooley-Tukey are easily predictable.
        // This is because they combine g(w^2) + w*h(w^2), where g or h always
        // refers to a coefficient that is 0.
        // Therefore those first i rounds simply copy the evaluations into more
        // locations, so we handle them during initialization and reduce the
        // number of loops that perform arithmetic.
        // The number of times we copy each initial non-zero element is as below:

        let duplicity_of_initials = 1 << log_n.checked_sub(log_d).expect("domain is too small");
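        // Worked example (illustrative): for a domain of size n = 8
        // (log_n = 3) and a polynomial with 2 coefficients (log_d = 1),
        // duplicity_of_initials = 1 << 2 = 4: each of the 2 bit-reversed
        // initial values is copied into 4 consecutive slots, and the
        // butterfly rounds below start at gap = 4 instead of gap = 1.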

        coeffs.resize(n, T::zero());

        // swap coefficients in place
        for i in 0..num_coeffs as u64 {
            let ri = bitrev(i, log_n);
            if i < ri {
                coeffs.swap(i as usize, ri as usize);
            }
        }

        // duplicate initial values
        if duplicity_of_initials > 1 {
            cfg_chunks_mut!(coeffs, duplicity_of_initials).for_each(|chunk| {
                let v = chunk[0];
                chunk[1..].fill(v);
            });
        }

        let start_gap = duplicity_of_initials;
        self.oi_helper(&mut *coeffs, self.group_gen, start_gap);
    }

    #[allow(unused)]
    pub(crate) fn in_order_fft_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T]) {
        if !self.offset.is_one() {
            Self::distribute_powers(x_s, self.offset);
        }
        self.fft_helper_in_place(x_s, FFTOrder::II);
    }

    pub(crate) fn in_order_ifft_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T]) {
        self.ifft_helper_in_place(x_s, FFTOrder::II);
        if self.offset.is_one() {
            cfg_iter_mut!(x_s).for_each(|val| *val *= self.size_inv);
        } else {
            Self::distribute_powers_and_mul_by_const(x_s, self.offset_inv, self.size_inv);
        }
    }

    fn fft_helper_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T], ord: FFTOrder) {
        use FFTOrder::*;

        let log_len = ark_std::log2(x_s.len());

        if ord == OI {
            self.oi_helper(x_s, self.group_gen, 1);
        } else {
            self.io_helper(x_s, self.group_gen);
        }

        if ord == II {
            derange(x_s, log_len);
        }
    }

    // Handles an IFFT, for inputs and outputs that may each be in order or
    // out of order. The results must all be divided by |x_s|,
    // which is left up to the caller to do.
    fn ifft_helper_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T], ord: FFTOrder) {
        use FFTOrder::*;

        let log_len = ark_std::log2(x_s.len());

        if ord == II {
            derange(x_s, log_len);
        }

        if ord == IO {
            self.io_helper(x_s, self.group_gen_inv);
        } else {
            self.oi_helper(x_s, self.group_gen_inv, 1);
        }
    }

    /// Computes the first `self.size / 2` roots of unity for the entire domain.
    /// e.g. for the domain [1, g, g^2, ..., g^{n - 1}], it computes
    /// [1, g, g^2, ..., g^{(n/2) - 1}]
    #[cfg(not(feature = "parallel"))]
    pub(super) fn roots_of_unity(&self, root: F) -> Vec<F> {
        compute_powers_serial((self.size as usize) / 2, root)
    }

    /// Computes the first `self.size / 2` roots of unity.
    #[cfg(feature = "parallel")]
    pub(super) fn roots_of_unity(&self, root: F) -> Vec<F> {
        // TODO: check if this method can replace parallel compute powers.
        let log_size = ark_std::log2(self.size as usize);
        // early exit for short inputs
        if log_size <= LOG_ROOTS_OF_UNITY_PARALLEL_SIZE {
            compute_powers_serial((self.size as usize) / 2, root)
        } else {
            let mut temp = root;
            // w, w^2, w^4, ..., w^(2^(log_size - 2))
            let log_powers: Vec<F> = (0..(log_size - 1))
                .map(|_| {
                    let old_value = temp;
                    temp.square_in_place();
                    old_value
                })
                .collect();

            // allocate the return array and start the recursion
            let mut powers = vec![F::zero(); 1 << (log_size - 1)];
            Self::roots_of_unity_recursive(&mut powers, &log_powers);
            powers
        }
    }

    #[cfg(feature = "parallel")]
    fn roots_of_unity_recursive(out: &mut [F], log_powers: &[F]) {
        assert_eq!(out.len(), 1 << log_powers.len());
        // base case: just compute the powers sequentially,
        // g = log_powers[0], out = [1, g, g^2, ...]
        if log_powers.len() <= LOG_ROOTS_OF_UNITY_PARALLEL_SIZE as usize {
            out[0] = F::one();
            for idx in 1..out.len() {
                out[idx] = out[idx - 1] * log_powers[0];
            }
            return;
        }

        // recursive case:
        // 1. split log_powers in half
        let (lr_lo, lr_hi) = log_powers.split_at((1 + log_powers.len()) / 2);
        let mut scr_lo = vec![F::default(); 1 << lr_lo.len()];
        let mut scr_hi = vec![F::default(); 1 << lr_hi.len()];
        // 2. compute each half individually
        rayon::join(
            || Self::roots_of_unity_recursive(&mut scr_lo, lr_lo),
            || Self::roots_of_unity_recursive(&mut scr_hi, lr_hi),
        );
        // 3. recombine halves
        // At this point, out is a blank slice.
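        // Worked example (illustrative, ignoring the base-case cutoff for
        // small inputs): if log_powers = [w, w^2, w^4], the split gives
        // lr_lo = [w, w^2] and lr_hi = [w^4], so scr_lo = [1, w, w^2, w^3]
        // and scr_hi = [1, w^4]. Scaling the j-th chunk of `out` by scr_hi[j]
        // elementwise against scr_lo yields out[4j + i] = w^(4j + i),
        // i.e. [1, w, ..., w^7].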
        out.par_chunks_mut(scr_lo.len())
            .zip(&scr_hi)
            .for_each(|(out_chunk, scr_hi)| {
                for (out_elem, scr_lo) in out_chunk.iter_mut().zip(&scr_lo) {
                    *out_elem = *scr_hi * scr_lo;
                }
            });
    }

    #[inline(always)]
    fn butterfly_fn_io<T: DomainCoeff<F>>(((lo, hi), root): ((&mut T, &mut T), &F)) {
        let mut neg = *lo;
        neg -= *hi;

        *lo += *hi;

        *hi = neg;
        *hi *= *root;
    }

    #[inline(always)]
    fn butterfly_fn_oi<T: DomainCoeff<F>>(((lo, hi), root): ((&mut T, &mut T), &F)) {
        *hi *= *root;

        let mut neg = *lo;
        neg -= *hi;

        *lo += *hi;

        *hi = neg;
    }
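
    // In equations (an illustrative note): `butterfly_fn_io` maps
    // (lo, hi) to (lo + hi, (lo - hi) * root), the decimation-in-frequency
    // butterfly, while `butterfly_fn_oi` maps (lo, hi) to
    // (lo + root * hi, lo - root * hi), the decimation-in-time butterfly.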

    #[allow(clippy::too_many_arguments)]
    fn apply_butterfly<T: DomainCoeff<F>, G: Fn(((&mut T, &mut T), &F)) + Copy + Sync + Send>(
        g: G,
        xi: &mut [T],
        roots: &[F],
        step: usize,
        chunk_size: usize,
        num_chunks: usize,
        max_threads: usize,
        gap: usize,
    ) {
        if xi.len() <= MIN_INPUT_SIZE_FOR_PARALLELIZATION {
            xi.chunks_mut(chunk_size).for_each(|cxi| {
                let (lo, hi) = cxi.split_at_mut(gap);
                lo.iter_mut()
                    .zip(hi)
                    .zip(roots.iter().step_by(step))
                    .for_each(g);
            });
        } else {
            cfg_chunks_mut!(xi, chunk_size).for_each(|cxi| {
                let (lo, hi) = cxi.split_at_mut(gap);
                // If the chunk is sufficiently big that parallelism helps,
                // we parallelize the butterfly operation within the chunk.
                if gap > MIN_GAP_SIZE_FOR_PARALLELIZATION && num_chunks < max_threads {
                    cfg_iter_mut!(lo)
                        .zip(hi)
                        .zip(cfg_iter!(roots).step_by(step))
                        .for_each(g);
                } else {
                    lo.iter_mut()
                        .zip(hi)
                        .zip(roots.iter().step_by(step))
                        .for_each(g);
                }
            });
        }
    }

    fn io_helper<T: DomainCoeff<F>>(&self, xi: &mut [T], root: F) {
        let mut roots = self.roots_of_unity(root);
        let mut step = 1;
        let mut first = true;

        #[cfg(feature = "parallel")]
        let max_threads = rayon::current_num_threads();
        #[cfg(not(feature = "parallel"))]
        let max_threads = 1;

        let mut gap = xi.len() / 2;
        while gap > 0 {
            // each butterfly cluster uses 2*gap positions
            let chunk_size = 2 * gap;
            let num_chunks = xi.len() / chunk_size;
            // Only compact roots to achieve cache locality/compactness if
            // the roots lookup is done a significant number of times,
            // which also implies a large lookup stride.
            if num_chunks >= MIN_NUM_CHUNKS_FOR_COMPACTION {
                if !first {
                    roots = cfg_into_iter!(roots).step_by(step * 2).collect();
                }
                step = 1;
                roots.shrink_to_fit();
            } else {
                step = num_chunks;
            }
            first = false;

            Self::apply_butterfly(
                Self::butterfly_fn_io,
                xi,
                &roots,
                step,
                chunk_size,
                num_chunks,
                max_threads,
                gap,
            );

            gap /= 2;
        }
    }

    fn oi_helper<T: DomainCoeff<F>>(&self, xi: &mut [T], root: F, start_gap: usize) {
        let roots_cache = self.roots_of_unity(root);

        // The `cmp::min` is only necessary for the case where
        // `MIN_NUM_CHUNKS_FOR_COMPACTION = 1`. Else, notice that we compact
        // the roots cache by a stride of at least `MIN_NUM_CHUNKS_FOR_COMPACTION`.

        let compaction_max_size = core::cmp::min(
            roots_cache.len() / 2,
            roots_cache.len() / MIN_NUM_CHUNKS_FOR_COMPACTION,
        );
        let mut compacted_roots = vec![F::default(); compaction_max_size];

        #[cfg(feature = "parallel")]
        let max_threads = rayon::current_num_threads();
        #[cfg(not(feature = "parallel"))]
        let max_threads = 1;

        let mut gap = start_gap;
        while gap < xi.len() {
            // each butterfly cluster uses 2*gap positions
            let chunk_size = 2 * gap;
            let num_chunks = xi.len() / chunk_size;
            // Only compact roots to achieve cache locality/compactness if
            // the roots lookup is done a significant number of times,
            // which also implies a large lookup stride.
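            // Worked example (illustrative): for xi.len() = 1024 and gap = 2,
            // num_chunks = 256, so each chunk's butterflies need only the
            // roots w^0 and w^256 out of the 512 cached roots. Copying those
            // `gap` strided roots into `compacted_roots` lets every chunk
            // read them from contiguous memory.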
            let (roots, step) = if num_chunks >= MIN_NUM_CHUNKS_FOR_COMPACTION && gap < xi.len() / 2
            {
                cfg_iter!(roots_cache)
                    .step_by(num_chunks)
                    .zip(&mut compacted_roots[..gap])
                    .for_each(|(b, a)| *a = *b);

                (&compacted_roots[..gap], 1)
            } else {
                (&roots_cache[..], num_chunks)
            };

            Self::apply_butterfly(
                Self::butterfly_fn_oi,
                xi,
                roots,
                step,
                chunk_size,
                num_chunks,
                max_threads,
                gap,
            );

            gap *= 2;
        }
    }
}

/// The minimum number of chunks at which root compaction
/// is beneficial.
const MIN_NUM_CHUNKS_FOR_COMPACTION: usize = 1 << 7;

/// The minimum size of a chunk at which parallelization of `butterfly`s is
/// beneficial. This value was chosen empirically.
const MIN_GAP_SIZE_FOR_PARALLELIZATION: usize = 1 << 10;

/// The minimum size of the input at which parallelization of `butterfly`s is
/// beneficial. This value was chosen empirically.
const MIN_INPUT_SIZE_FOR_PARALLELIZATION: usize = 1 << 10;

/// The log of the minimum size at which `roots_of_unity` switches to the
/// parallel recursive computation.
#[cfg(feature = "parallel")]
const LOG_ROOTS_OF_UNITY_PARALLEL_SIZE: u32 = 7;

#[inline]
fn bitrev(a: u64, log_len: u32) -> u64 {
    a.reverse_bits().wrapping_shr(64 - log_len)
}

fn derange<T>(xi: &mut [T], log_len: u32) {
    for idx in 1..(xi.len() as u64 - 1) {
        let ridx = bitrev(idx, log_len);
        if idx < ridx {
            xi.swap(idx as usize, ridx as usize);
        }
    }
}
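
// Illustrative sanity tests for the two helpers above: a minimal sketch whose
// module and test names are our own and not part of the upstream fffft port.
#[cfg(test)]
mod bitrev_tests {
    use super::{bitrev, derange};

    #[test]
    fn bitrev_reverses_low_bits() {
        // With log_len = 3, `bitrev` reverses the 3 low-order bits:
        // 0b001 -> 0b100 and 0b011 -> 0b110.
        assert_eq!(bitrev(1, 3), 4);
        assert_eq!(bitrev(3, 3), 6);
        // `bitrev` is an involution: applying it twice is the identity.
        for i in 0..8u64 {
            assert_eq!(bitrev(bitrev(i, 3), 3), i);
        }
    }

    #[test]
    fn derange_applies_bit_reversal_permutation() {
        // `derange` permutes a slice of length 2^log_len into bit-reversed
        // order; for length 8 that order is [0, 4, 2, 6, 1, 5, 3, 7].
        let mut xs = [0u64, 1, 2, 3, 4, 5, 6, 7];
        derange(&mut xs, 3);
        assert_eq!(xs, [0, 4, 2, 6, 1, 5, 3, 7]);
        // A second pass restores the original order.
        derange(&mut xs, 3);
        assert_eq!(xs, [0, 1, 2, 3, 4, 5, 6, 7]);
    }
}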