ark_poly/domain/radix2/
mod.rs

1//! This module defines `Radix2EvaluationDomain`, an `EvaluationDomain`
2//! for performing various kinds of polynomial arithmetic on top of
3//! fields that are FFT-friendly. `Radix2EvaluationDomain` supports
4//! FFTs of size at most `2^F::TWO_ADICITY`.
5
6pub use crate::domain::utils::Elements;
7use crate::domain::{DomainCoeff, EvaluationDomain};
8use ark_ff::FftField;
9use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
10use ark_std::{fmt, vec::*};
11
12mod fft;
13
/// Factor that determines whether the degree-aware FFT should be called.
///
/// An FFT over a domain of size `n` takes the degree-aware path when the input
/// has at most `n / DEGREE_AWARE_FFT_THRESHOLD_FACTOR` coefficients, i.e. when
/// `num_coeffs * DEGREE_AWARE_FFT_THRESHOLD_FACTOR <= n`. In that case the
/// input is sparse enough for the `O(n log d)` algorithm to beat the plain
/// `O(n log n)` one.
const DEGREE_AWARE_FFT_THRESHOLD_FACTOR: usize = 1 << 2;
16
17/// Defines a domain over which finite field (I)FFTs can be performed. Works
18/// only for fields that have a large multiplicative subgroup of size that is
19/// a power-of-2.
20#[derive(Copy, Clone, Hash, Eq, PartialEq, CanonicalSerialize, CanonicalDeserialize)]
21pub struct Radix2EvaluationDomain<F: FftField> {
22    /// The size of the domain.
23    pub size: u64,
24    /// `log_2(self.size)`.
25    pub log_size_of_group: u32,
26    /// Size of the domain as a field element.
27    pub size_as_field_element: F,
28    /// Inverse of the size in the field.
29    pub size_inv: F,
30    /// A generator of the subgroup.
31    pub group_gen: F,
32    /// Inverse of the generator of the subgroup.
33    pub group_gen_inv: F,
34    /// Offset that specifies the coset.
35    pub offset: F,
36    /// Inverse of the offset that specifies the coset.
37    pub offset_inv: F,
38    /// Constant coefficient for the vanishing polynomial.
39    /// Equals `self.offset^self.size`.
40    pub offset_pow_size: F,
41}
42
43impl<F: FftField> fmt::Debug for Radix2EvaluationDomain<F> {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        write!(f, "Radix-2 multiplicative subgroup of size {}", self.size)
46    }
47}
48
49impl<F: FftField> EvaluationDomain<F> for Radix2EvaluationDomain<F> {
50    type Elements = Elements<F>;
51
52    /// Construct a domain that is large enough for evaluations of a polynomial
53    /// having `num_coeffs` coefficients.
54    fn new(num_coeffs: usize) -> Option<Self> {
55        let size = if num_coeffs.is_power_of_two() {
56            num_coeffs
57        } else {
58            num_coeffs.checked_next_power_of_two()?
59        } as u64;
60        let log_size_of_group = size.trailing_zeros();
61
62        // libfqfft uses > https://github.com/scipr-lab/libfqfft/blob/e0183b2cef7d4c5deb21a6eaf3fe3b586d738fe0/libfqfft/evaluation_domain/domains/basic_radix2_domain.tcc#L33
63        if log_size_of_group > F::TWO_ADICITY {
64            return None;
65        }
66
67        // Compute the generator for the multiplicative subgroup.
68        // It should be the 2^(log_size_of_group) root of unity.
69        let group_gen = F::get_root_of_unity(size)?;
70        // Check that it is indeed the 2^(log_size_of_group) root of unity.
71        debug_assert_eq!(group_gen.pow([size]), F::one());
72        let size_as_field_element = F::from(size);
73        let size_inv = size_as_field_element.inverse()?;
74
75        Some(Radix2EvaluationDomain {
76            size,
77            log_size_of_group,
78            size_as_field_element,
79            size_inv,
80            group_gen,
81            group_gen_inv: group_gen.inverse()?,
82            offset: F::one(),
83            offset_inv: F::one(),
84            offset_pow_size: F::one(),
85        })
86    }
87
88    fn get_coset(&self, offset: F) -> Option<Self> {
89        Some(Radix2EvaluationDomain {
90            offset,
91            offset_inv: offset.inverse()?,
92            offset_pow_size: offset.pow([self.size]),
93            ..*self
94        })
95    }
96
97    fn compute_size_of_domain(num_coeffs: usize) -> Option<usize> {
98        let size = num_coeffs.checked_next_power_of_two()?;
99        if size.trailing_zeros() > F::TWO_ADICITY {
100            None
101        } else {
102            Some(size)
103        }
104    }
105
106    #[inline]
107    fn size(&self) -> usize {
108        usize::try_from(self.size).unwrap()
109    }
110
111    #[inline]
112    fn log_size_of_group(&self) -> u64 {
113        self.log_size_of_group as u64
114    }
115
116    #[inline]
117    fn size_inv(&self) -> F {
118        self.size_inv
119    }
120
121    #[inline]
122    fn group_gen(&self) -> F {
123        self.group_gen
124    }
125
126    #[inline]
127    fn group_gen_inv(&self) -> F {
128        self.group_gen_inv
129    }
130
131    #[inline]
132    fn coset_offset(&self) -> F {
133        self.offset
134    }
135
136    #[inline]
137    fn coset_offset_inv(&self) -> F {
138        self.offset_inv
139    }
140
141    #[inline]
142    fn coset_offset_pow_size(&self) -> F {
143        self.offset_pow_size
144    }
145
146    #[inline]
147    fn fft_in_place<T: DomainCoeff<F>>(&self, coeffs: &mut Vec<T>) {
148        if coeffs.len() * DEGREE_AWARE_FFT_THRESHOLD_FACTOR <= self.size() {
149            self.degree_aware_fft_in_place(coeffs);
150        } else {
151            coeffs.resize(self.size(), T::zero());
152            self.in_order_fft_in_place(coeffs);
153        }
154    }
155
156    #[inline]
157    fn ifft_in_place<T: DomainCoeff<F>>(&self, evals: &mut Vec<T>) {
158        evals.resize(self.size(), T::zero());
159        self.in_order_ifft_in_place(&mut *evals);
160    }
161
162    /// Return an iterator over the elements of the domain.
163    fn elements(&self) -> Elements<F> {
164        Elements {
165            cur_elem: self.offset,
166            cur_pow: 0,
167            size: self.size,
168            group_gen: self.group_gen,
169        }
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::DEGREE_AWARE_FFT_THRESHOLD_FACTOR;
    use crate::{
        polynomial::{univariate::*, DenseUVPolynomial, Polynomial},
        EvaluationDomain, Radix2EvaluationDomain,
    };
    use ark_ff::{FftField, Field, One, UniformRand, Zero};
    use ark_std::{collections::BTreeSet, rand::Rng, test_rng};
    use ark_test_curves::bls12_381::Fr;

    /// The explicit vanishing polynomial and the domain's closed-form evaluation
    /// agree at random points, for both the subgroup and a coset.
    #[test]
    fn vanishing_polynomial_evaluation() {
        let rng = &mut test_rng();
        for coeffs in 0..10 {
            let domain = Radix2EvaluationDomain::<Fr>::new(coeffs).unwrap();
            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
            let z = domain.vanishing_polynomial();
            let z_coset = coset_domain.vanishing_polynomial();
            for _ in 0..100 {
                let point: Fr = rng.gen();
                assert_eq!(
                    z.evaluate(&point),
                    domain.evaluate_vanishing_polynomial(point)
                );
                assert_eq!(
                    z_coset.evaluate(&point),
                    coset_domain.evaluate_vanishing_polynomial(point)
                );
            }
        }
    }

    /// The vanishing polynomial is zero at every element of the subgroup and of a coset.
    #[test]
    fn vanishing_polynomial_vanishes_on_domain() {
        for coeffs in 0..1000 {
            let domain = Radix2EvaluationDomain::<Fr>::new(coeffs).unwrap();
            let z = domain.vanishing_polynomial();
            for point in domain.elements() {
                assert!(z.evaluate(&point).is_zero())
            }

            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
            let z = coset_domain.vanishing_polynomial();
            for point in coset_domain.elements() {
                assert!(z.evaluate(&point).is_zero())
            }
        }
    }

    /// For every coset of every subdomain, the filter polynomial is one on the
    /// coset and zero elsewhere in the domain, and the fast evaluation matches it.
    #[test]
    fn filter_polynomial_test() {
        for log_domain_size in 1..=4 {
            let domain_size = 1 << log_domain_size;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            for log_subdomain_size in 1..=log_domain_size {
                let subdomain_size = 1 << log_subdomain_size;
                let subdomain = Radix2EvaluationDomain::<Fr>::new(subdomain_size).unwrap();

                // Obtain all possible offsets of `subdomain` within `domain`.
                let mut possible_offsets = vec![Fr::one()];
                let domain_generator = domain.group_gen();

                let mut offset = domain_generator;
                let subdomain_generator = subdomain.group_gen();
                while offset != subdomain_generator {
                    possible_offsets.push(offset);
                    offset *= domain_generator;
                }

                assert_eq!(possible_offsets.len(), domain_size / subdomain_size);

                // Get all possible cosets of `subdomain` within `domain`.
                let cosets = possible_offsets
                    .iter()
                    .map(|offset| subdomain.get_coset(*offset).unwrap());

                for coset in cosets {
                    let coset_elements = coset.elements().collect::<BTreeSet<_>>();
                    let filter_poly = domain.filter_polynomial(&coset);
                    assert_eq!(filter_poly.degree(), domain_size - subdomain_size);
                    for element in domain.elements() {
                        let evaluation = domain.evaluate_filter_polynomial(&coset, element);
                        assert_eq!(evaluation, filter_poly.evaluate(&element));
                        if coset_elements.contains(&element) {
                            assert_eq!(evaluation, Fr::one())
                        } else {
                            assert_eq!(evaluation, Fr::zero())
                        }
                    }
                }
            }
        }
    }

    /// `elements()` yields exactly `size()` items.
    #[test]
    fn size_of_elements() {
        for coeffs in 1..10 {
            let size = 1 << coeffs;
            let domain = Radix2EvaluationDomain::<Fr>::new(size).unwrap();
            let domain_size = domain.size();
            assert_eq!(domain_size, domain.elements().count());
        }
    }

    /// The i-th element is `g^i` for the subgroup and `offset * g^i` for the
    /// coset, and matches `element(i)`.
    #[test]
    fn elements_contents() {
        for coeffs in 1..10 {
            let size = 1 << coeffs;
            let domain = Radix2EvaluationDomain::<Fr>::new(size).unwrap();
            let offset = Fr::GENERATOR;
            let coset_domain = domain.get_coset(offset).unwrap();
            for (i, (x, coset_x)) in domain.elements().zip(coset_domain.elements()).enumerate() {
                assert_eq!(x, domain.group_gen.pow([i as u64]));
                assert_eq!(x, domain.element(i));
                assert_eq!(coset_x, offset * coset_domain.group_gen.pow([i as u64]));
                assert_eq!(coset_x, coset_domain.element(i));
            }
        }
    }

    /// Test that lagrange interpolation for a random polynomial at a random
    /// point works.
    #[test]
    fn non_systematic_lagrange_coefficients_test() {
        for domain_dim in 1..10 {
            let domain_size = 1 << domain_dim;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
            // Get random pt + lagrange coefficients
            let rand_pt = Fr::rand(&mut test_rng());
            let lagrange_coeffs = domain.evaluate_all_lagrange_coefficients(rand_pt);
            let coset_lagrange_coeffs = coset_domain.evaluate_all_lagrange_coefficients(rand_pt);

            // Sample the random polynomial, evaluate it over the domain and the random
            // point.
            let rand_poly = DensePolynomial::<Fr>::rand(domain_size - 1, &mut test_rng());
            let poly_evals = domain.fft(rand_poly.coeffs());
            let coset_poly_evals = coset_domain.fft(rand_poly.coeffs());
            let actual_eval = rand_poly.evaluate(&rand_pt);

            // Do lagrange interpolation, and compare against the actual evaluation
            let mut interpolated_eval = Fr::zero();
            let mut coset_interpolated_eval = Fr::zero();
            for i in 0..domain_size {
                interpolated_eval += lagrange_coeffs[i] * poly_evals[i];
                coset_interpolated_eval += coset_lagrange_coeffs[i] * coset_poly_evals[i];
            }
            assert_eq!(actual_eval, interpolated_eval);
            assert_eq!(actual_eval, coset_interpolated_eval);
        }
    }

    /// Test that lagrange coefficients for a point in the domain is correct
    #[test]
    fn systematic_lagrange_coefficients_test() {
        // This runs in time O(N^2) in the domain size, so keep the domain dimension
        // low. We generate lagrange coefficients for each element in the domain.
        for domain_dim in 1..5 {
            let domain_size = 1 << domain_dim;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
            for (i, (x, coset_x)) in domain.elements().zip(coset_domain.elements()).enumerate() {
                let lagrange_coeffs = domain.evaluate_all_lagrange_coefficients(x);
                let coset_lagrange_coeffs =
                    coset_domain.evaluate_all_lagrange_coefficients(coset_x);
                for (j, (y, coset_y)) in lagrange_coeffs
                    .into_iter()
                    .zip(coset_lagrange_coeffs)
                    .enumerate()
                {
                    // Lagrange coefficient for the evaluation point, which should be 1
                    if i == j {
                        assert_eq!(y, Fr::one());
                        assert_eq!(coset_y, Fr::one());
                    } else {
                        assert_eq!(y, Fr::zero());
                        assert_eq!(coset_y, Fr::zero());
                    }
                }
            }
        }
    }

    #[test]
    fn test_fft_correctness() {
        // Tests that the ffts output the correct result.
        // This assumes a correct polynomial evaluation at point procedure.
        // It tests consistency of FFT/IFFT, and coset_fft/coset_ifft,
        // along with testing that each individual evaluation is correct.

        // Runs in time O(degree^2)
        let log_degree = 5;
        let degree = 1 << log_degree;
        let rand_poly = DensePolynomial::<Fr>::rand(degree - 1, &mut test_rng());

        for log_domain_size in log_degree..(log_degree + 2) {
            let domain_size = 1 << log_domain_size;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let coset_domain =
                Radix2EvaluationDomain::<Fr>::new_coset(domain_size, Fr::GENERATOR).unwrap();
            let poly_evals = domain.fft(&rand_poly.coeffs);
            let poly_coset_evals = coset_domain.fft(&rand_poly.coeffs);

            for (i, (x, coset_x)) in domain.elements().zip(coset_domain.elements()).enumerate() {
                assert_eq!(poly_evals[i], rand_poly.evaluate(&x));
                assert_eq!(poly_coset_evals[i], rand_poly.evaluate(&coset_x));
            }

            let rand_poly_from_subgroup =
                DensePolynomial::from_coefficients_vec(domain.ifft(&poly_evals));
            let rand_poly_from_coset =
                DensePolynomial::from_coefficients_vec(coset_domain.ifft(&poly_coset_evals));

            assert_eq!(
                rand_poly, rand_poly_from_subgroup,
                "degree = {}, domain size = {}",
                degree, domain_size
            );
            assert_eq!(
                rand_poly, rand_poly_from_coset,
                "degree = {}, domain size = {}",
                degree, domain_size
            );
        }
    }

    #[test]
    fn degree_aware_fft_correctness() {
        // Test that the degree aware FFT (O(n log d)) matches the regular FFT (O(n log n)).
        let num_coeffs = 1 << 5;
        let rand_poly = DensePolynomial::<Fr>::rand(num_coeffs - 1, &mut test_rng());
        let domain_size = num_coeffs * DEGREE_AWARE_FFT_THRESHOLD_FACTOR;
        let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
        let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();

        let deg_aware_fft_evals = domain.fft(&rand_poly);
        let coset_deg_aware_fft_evals = coset_domain.fft(&rand_poly);

        for (i, (x, coset_x)) in domain.elements().zip(coset_domain.elements()).enumerate() {
            assert_eq!(deg_aware_fft_evals[i], rand_poly.evaluate(&x));
            assert_eq!(coset_deg_aware_fft_evals[i], rand_poly.evaluate(&coset_x));
        }
    }

    #[test]
    fn test_roots_of_unity() {
        // Tests that `roots_of_unity(group_gen)` returns the first half of
        // `domain.elements()`, and that every returned value lies in the domain.
        let max_degree = 10;
        for log_domain_size in 0..max_degree {
            let domain_size = 1 << log_domain_size;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let actual_roots = domain.roots_of_unity(domain.group_gen);
            for &value in &actual_roots {
                assert!(domain.evaluate_vanishing_polynomial(value).is_zero());
            }
            let expected_roots_elements = domain.elements();
            for (expected, &actual) in expected_roots_elements.zip(&actual_roots) {
                assert_eq!(expected, actual);
            }
            assert_eq!(actual_roots.len(), domain_size / 2);
        }
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn parallel_fft_consistency() {
        use ark_std::vec::*;

        // This implements the Cooley-Tukey FFT, derived from libfqfft.
        // The libfqfft implementation uses pseudocode from [CLRS 2nd Ed, pp. 864].
        fn serial_radix2_fft(a: &mut [Fr], omega: Fr, log_n: u32) {
            let n = u32::try_from(a.len())
                .expect("cannot perform FFTs on vectors of len >= (1 << 32)");
            assert_eq!(n, 1 << log_n);

            // swap coefficients in place
            for k in 0..n {
                let rk = crate::domain::utils::bitreverse(k, log_n);
                if k < rk {
                    a.swap(rk as usize, k as usize);
                }
            }

            let mut m = 1;
            for _i in 1..=log_n {
                // w_m is 2^i-th root of unity
                let w_m = omega.pow([(n / (2 * m)) as u64]);

                let mut k = 0;
                while k < n {
                    // w = w_m^j at the start of every loop iteration
                    let mut w = Fr::one();
                    for j in 0..m {
                        let mut t = a[(k + j + m) as usize];
                        t *= w;
                        let mut tmp = a[(k + j) as usize];
                        tmp -= t;
                        a[(k + j + m) as usize] = tmp;
                        a[(k + j) as usize] += t;
                        w *= &w_m;
                    }

                    k += 2 * m;
                }

                m *= 2;
            }
        }

        fn serial_radix2_ifft(a: &mut [Fr], omega: Fr, log_n: u32) {
            serial_radix2_fft(a, omega.inverse().unwrap(), log_n);
            let domain_size_inv = Fr::from(a.len() as u64).inverse().unwrap();
            for coeff in a.iter_mut() {
                *coeff *= domain_size_inv;
            }
        }

        fn serial_radix2_coset_fft(a: &mut [Fr], omega: Fr, log_n: u32) {
            let coset_shift = Fr::GENERATOR;
            let mut cur_pow = Fr::one();
            for coeff in a.iter_mut() {
                *coeff *= cur_pow;
                cur_pow *= coset_shift;
            }
            serial_radix2_fft(a, omega, log_n);
        }

        fn serial_radix2_coset_ifft(a: &mut [Fr], omega: Fr, log_n: u32) {
            serial_radix2_ifft(a, omega, log_n);
            let coset_shift = Fr::GENERATOR.inverse().unwrap();
            let mut cur_pow = Fr::one();
            for coeff in a.iter_mut() {
                *coeff *= cur_pow;
                cur_pow *= coset_shift;
            }
        }

        // Compares the (possibly parallel) domain FFTs against the serial
        // reference implementations above, for subgroup and coset, forward and inverse.
        fn test_consistency<R: Rng>(rng: &mut R, max_coeffs: u32) {
            for _ in 0..5 {
                for log_d in 0..max_coeffs {
                    let d = 1 << log_d;

                    let expected_poly = (0..d).map(|_| Fr::rand(rng)).collect::<Vec<_>>();
                    let mut expected_vec = expected_poly.clone();
                    let mut actual_vec = expected_vec.clone();

                    let domain = Radix2EvaluationDomain::new(d).unwrap();
                    let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();

                    serial_radix2_fft(&mut expected_vec, domain.group_gen, log_d);
                    domain.fft_in_place(&mut actual_vec);
                    assert_eq!(expected_vec, actual_vec);

                    serial_radix2_ifft(&mut expected_vec, domain.group_gen, log_d);
                    domain.ifft_in_place(&mut actual_vec);
                    assert_eq!(expected_vec, actual_vec);
                    assert_eq!(expected_vec, expected_poly);

                    serial_radix2_coset_fft(&mut expected_vec, domain.group_gen, log_d);
                    coset_domain.fft_in_place(&mut actual_vec);
                    assert_eq!(expected_vec, actual_vec);

                    serial_radix2_coset_ifft(&mut expected_vec, domain.group_gen, log_d);
                    coset_domain.ifft_in_place(&mut actual_vec);
                    assert_eq!(expected_vec, actual_vec);
                }
            }
        }

        let rng = &mut test_rng();

        test_consistency(rng, 10);
    }
}