Skip to main content

ark_poly/domain/radix2/
mod.rs

1//! This module defines `Radix2EvaluationDomain`, an `EvaluationDomain`
2//! for performing various kinds of polynomial arithmetic on top of
3//! fields that are FFT-friendly.
4//!
5//! `Radix2EvaluationDomain` supports FFTs of size at most `2^F::TWO_ADICITY`.
6
7pub use crate::domain::utils::{bitreverse_permutation_in_place, Elements};
8use crate::domain::{DomainCoeff, EvaluationDomain};
9use ark_ff::{FftField, Field};
10use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
11use ark_std::{fmt, vec::*};
12
13mod fft;
14
/// Factor that determines whether `fft_in_place` calls the degree-aware FFT.
///
/// The degree-aware FFT is used when
/// `coeffs.len() * DEGREE_AWARE_FFT_THRESHOLD_FACTOR <= domain.size()`, i.e.
/// when the input has at most one quarter as many coefficients as the domain
/// has points; otherwise the input is resized to the domain size and the
/// regular in-order FFT is used.
const DEGREE_AWARE_FFT_THRESHOLD_FACTOR: usize = 1 << 2;
17
/// Defines a domain over which finite field (I)FFTs can be performed. Works
/// only for fields that have a large multiplicative subgroup of size that is
/// a power-of-2.
///
/// A value describes either the subgroup `H = <group_gen>` of order `size`
/// (when `offset` is one, as produced by `EvaluationDomain::new`) or the coset
/// `offset * H` (as produced by `EvaluationDomain::get_coset`). Alongside the
/// defining values, the struct caches derived quantities (the size as a field
/// element, inverses, and `offset^size`) so they need not be recomputed.
///
/// NOTE(review): the derived `CanonicalSerialize`/`CanonicalDeserialize`
/// impls presumably encode the fields in declaration order, so reordering or
/// retyping fields would likely change the serialized format — confirm before
/// touching the layout.
#[derive(Copy, Clone, Hash, Eq, PartialEq, CanonicalSerialize, CanonicalDeserialize)]
pub struct Radix2EvaluationDomain<F: Field> {
    /// The size of the domain (a power of two).
    pub size: u64,
    /// `log_2(self.size)`.
    pub log_size_of_group: u32,
    /// Size of the domain as a field element.
    pub size_as_field_element: F,
    /// Inverse of the size in the field.
    pub size_inv: F,
    /// A generator of the subgroup.
    pub group_gen: F,
    /// Inverse of the generator of the subgroup.
    pub group_gen_inv: F,
    /// Offset that specifies the coset (one for the subgroup itself).
    pub offset: F,
    /// Inverse of the offset that specifies the coset.
    pub offset_inv: F,
    /// Constant coefficient for the vanishing polynomial.
    /// Equals `self.offset^self.size`.
    pub offset_pow_size: F,
}
43
44impl<F: Field> fmt::Debug for Radix2EvaluationDomain<F> {
45    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46        write!(f, "Radix-2 multiplicative subgroup of size {}", self.size)
47    }
48}
49
50impl<F: FftField> EvaluationDomain<F> for Radix2EvaluationDomain<F> {
51    type Elements = Elements<F>;
52
53    /// Construct a domain that is large enough for evaluations of a polynomial
54    /// having `num_coeffs` coefficients.
55    fn new(num_coeffs: usize) -> Option<Self> {
56        let size = num_coeffs.next_power_of_two() as u64;
57
58        let log_size_of_group = size.trailing_zeros();
59
60        // libfqfft uses > https://github.com/scipr-lab/libfqfft/blob/e0183b2cef7d4c5deb21a6eaf3fe3b586d738fe0/libfqfft/evaluation_domain/domains/basic_radix2_domain.tcc#L33
61        if log_size_of_group > F::TWO_ADICITY {
62            return None;
63        }
64
65        // Compute the generator for the multiplicative subgroup.
66        // It should be the 2^(log_size_of_group) root of unity.
67        let group_gen = F::get_root_of_unity(size)?;
68        // Check that it is indeed the 2^(log_size_of_group) root of unity.
69        debug_assert_eq!(group_gen.pow([size]), F::one());
70        let size_as_field_element = F::from(size);
71
72        Some(Self {
73            size,
74            log_size_of_group,
75            size_as_field_element,
76            size_inv: size_as_field_element.inverse()?,
77            group_gen,
78            group_gen_inv: group_gen.inverse()?,
79            offset: F::ONE,
80            offset_inv: F::ONE,
81            offset_pow_size: F::ONE,
82        })
83    }
84
85    fn get_coset(&self, offset: F) -> Option<Self> {
86        Some(Self {
87            offset,
88            offset_inv: offset.inverse()?,
89            offset_pow_size: offset.pow([self.size]),
90            ..*self
91        })
92    }
93
94    fn compute_size_of_domain(num_coeffs: usize) -> Option<usize> {
95        let size = num_coeffs.checked_next_power_of_two()?;
96        (size.trailing_zeros() <= F::TWO_ADICITY).then_some(size)
97    }
98
99    #[inline]
100    fn size(&self) -> usize {
101        self.size.try_into().unwrap()
102    }
103
104    #[inline]
105    fn log_size_of_group(&self) -> u64 {
106        self.log_size_of_group as u64
107    }
108
109    #[inline]
110    fn size_inv(&self) -> F {
111        self.size_inv
112    }
113
114    #[inline]
115    fn group_gen(&self) -> F {
116        self.group_gen
117    }
118
119    #[inline]
120    fn group_gen_inv(&self) -> F {
121        self.group_gen_inv
122    }
123
124    #[inline]
125    fn coset_offset(&self) -> F {
126        self.offset
127    }
128
129    #[inline]
130    fn coset_offset_inv(&self) -> F {
131        self.offset_inv
132    }
133
134    #[inline]
135    fn coset_offset_pow_size(&self) -> F {
136        self.offset_pow_size
137    }
138
139    #[inline]
140    fn fft_in_place<T: DomainCoeff<F>>(&self, coeffs: &mut Vec<T>) {
141        if coeffs.len() * DEGREE_AWARE_FFT_THRESHOLD_FACTOR <= self.size() {
142            self.degree_aware_fft_in_place(coeffs);
143        } else {
144            coeffs.resize(self.size(), T::zero());
145            self.in_order_fft_in_place(coeffs);
146        }
147    }
148
149    #[inline]
150    fn ifft_in_place<T: DomainCoeff<F>>(&self, evals: &mut Vec<T>) {
151        evals.resize(self.size(), T::zero());
152        self.in_order_ifft_in_place(&mut *evals);
153    }
154
155    /// Return an iterator over the elements of the domain.
156    fn elements(&self) -> Elements<F> {
157        Elements {
158            cur_elem: self.offset,
159            cur_pow: 0,
160            size: self.size,
161            group_gen: self.group_gen,
162        }
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::DEGREE_AWARE_FFT_THRESHOLD_FACTOR;
    use crate::{
        polynomial::{univariate::*, DenseUVPolynomial, Polynomial},
        EvaluationDomain, Radix2EvaluationDomain,
    };
    use ark_ff::{FftField, Field, One, UniformRand, Zero};
    use ark_std::{collections::BTreeSet, rand::Rng, test_rng, vec};
    use ark_test_curves::bls12_381::Fr;

    /// The closed-form `evaluate_vanishing_polynomial` must agree with
    /// evaluating the explicit vanishing polynomial at random points, on both
    /// the subgroup and a coset.
    #[test]
    fn vanishing_polynomial_evaluation() {
        let rng = &mut test_rng();
        for coeffs in 0..10 {
            let domain = Radix2EvaluationDomain::<Fr>::new(coeffs).unwrap();
            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
            let z = domain.vanishing_polynomial();
            let z_coset = coset_domain.vanishing_polynomial();
            for _ in 0..100 {
                let point: Fr = rng.gen();
                assert_eq!(
                    z.evaluate(&point),
                    domain.evaluate_vanishing_polynomial(point)
                );
                assert_eq!(
                    z_coset.evaluate(&point),
                    coset_domain.evaluate_vanishing_polynomial(point)
                );
            }
        }
    }

    /// The vanishing polynomial must be zero at every element of the subgroup
    /// and of a coset.
    #[test]
    fn vanishing_polynomial_vanishes_on_domain() {
        for coeffs in 0..1000 {
            let domain = Radix2EvaluationDomain::<Fr>::new(coeffs).unwrap();
            let z = domain.vanishing_polynomial();
            for point in domain.elements() {
                assert!(z.evaluate(&point).is_zero())
            }

            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
            let z = coset_domain.vanishing_polynomial();
            for point in coset_domain.elements() {
                assert!(z.evaluate(&point).is_zero())
            }
        }
    }

    /// For every coset of every subdomain, the filter polynomial must be one
    /// on that coset and zero on the rest of the domain. It must also agree
    /// with `evaluate_filter_polynomial`.
    #[test]
    fn filter_polynomial_test() {
        for log_domain_size in 1..=4 {
            let domain_size = 1 << log_domain_size;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            for log_subdomain_size in 1..=log_domain_size {
                let subdomain_size = 1 << log_subdomain_size;
                let subdomain = Radix2EvaluationDomain::<Fr>::new(subdomain_size).unwrap();

                // Obtain all possible offsets of `subdomain` within `domain`.
                let mut possible_offsets = vec![Fr::one()];
                let domain_generator = domain.group_gen();

                let mut offset = domain_generator;
                let subdomain_generator = subdomain.group_gen();
                while offset != subdomain_generator {
                    possible_offsets.push(offset);
                    offset *= domain_generator;
                }

                assert_eq!(possible_offsets.len(), domain_size / subdomain_size);

                // Get all possible cosets of `subdomain` within `domain`.
                let cosets = possible_offsets
                    .iter()
                    .map(|offset| subdomain.get_coset(*offset).unwrap());

                for coset in cosets {
                    let coset_elements = coset.elements().collect::<BTreeSet<_>>();
                    let filter_poly = domain.filter_polynomial(&coset);
                    assert_eq!(filter_poly.degree(), domain_size - subdomain_size);
                    for element in domain.elements() {
                        let evaluation = domain.evaluate_filter_polynomial(&coset, element);
                        assert_eq!(evaluation, filter_poly.evaluate(&element));
                        if coset_elements.contains(&element) {
                            assert_eq!(evaluation, Fr::one())
                        } else {
                            assert_eq!(evaluation, Fr::zero())
                        }
                    }
                }
            }
        }
    }

    /// `elements()` must yield exactly `size()` elements.
    #[test]
    fn size_of_elements() {
        for coeffs in 1..10 {
            let size = 1 << coeffs;
            let domain = Radix2EvaluationDomain::<Fr>::new(size).unwrap();
            let domain_size = domain.size();
            assert_eq!(domain_size, domain.elements().count());
        }
    }

    /// The i-th element is `group_gen^i` for the subgroup and
    /// `offset * group_gen^i` for a coset. It must match `element(i)`.
    #[test]
    fn elements_contents() {
        for coeffs in 1..10 {
            let size = 1 << coeffs;
            let domain = Radix2EvaluationDomain::<Fr>::new(size).unwrap();
            let offset = Fr::GENERATOR;
            let coset_domain = domain.get_coset(offset).unwrap();
            for (i, (x, coset_x)) in domain.elements().zip(coset_domain.elements()).enumerate() {
                assert_eq!(x, domain.group_gen.pow([i as u64]));
                assert_eq!(x, domain.element(i));
                assert_eq!(coset_x, offset * coset_domain.group_gen.pow([i as u64]));
                assert_eq!(coset_x, coset_domain.element(i));
            }
        }
    }

    /// Test that lagrange interpolation for a random polynomial at a random
    /// point works.
    #[test]
    fn non_systematic_lagrange_coefficients_test() {
        for domain_dim in 1..10 {
            let domain_size = 1 << domain_dim;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
            // Get random pt + lagrange coefficients
            let rand_pt = Fr::rand(&mut test_rng());
            let lagrange_coeffs = domain.evaluate_all_lagrange_coefficients(rand_pt);
            let coset_lagrange_coeffs = coset_domain.evaluate_all_lagrange_coefficients(rand_pt);

            // Sample the random polynomial, evaluate it over the domain and the random
            // point.
            let rand_poly = DensePolynomial::<Fr>::rand(domain_size - 1, &mut test_rng());
            let poly_evals = domain.fft(rand_poly.coeffs());
            let coset_poly_evals = coset_domain.fft(rand_poly.coeffs());
            let actual_eval = rand_poly.evaluate(&rand_pt);

            // Do lagrange interpolation, and compare against the actual evaluation
            let mut interpolated_eval = Fr::zero();
            let mut coset_interpolated_eval = Fr::zero();
            for i in 0..domain_size {
                interpolated_eval += lagrange_coeffs[i] * poly_evals[i];
                coset_interpolated_eval += coset_lagrange_coeffs[i] * coset_poly_evals[i];
            }
            assert_eq!(actual_eval, interpolated_eval);
            assert_eq!(actual_eval, coset_interpolated_eval);
        }
    }

    /// Test that lagrange coefficients for a point in the domain is correct
    #[test]
    fn systematic_lagrange_coefficients_test() {
        // This runs in time O(N^2) in the domain size, so keep the domain dimension
        // low. We generate lagrange coefficients for each element in the domain.
        for domain_dim in 1..5 {
            let domain_size = 1 << domain_dim;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
            for (i, (x, coset_x)) in domain.elements().zip(coset_domain.elements()).enumerate() {
                let lagrange_coeffs = domain.evaluate_all_lagrange_coefficients(x);
                let coset_lagrange_coeffs =
                    coset_domain.evaluate_all_lagrange_coefficients(coset_x);
                for (j, (y, coset_y)) in lagrange_coeffs
                    .into_iter()
                    .zip(coset_lagrange_coeffs)
                    .enumerate()
                {
                    // Lagrange coefficient for the evaluation point, which should be 1
                    if i == j {
                        assert_eq!(y, Fr::one());
                        assert_eq!(coset_y, Fr::one());
                    } else {
                        assert_eq!(y, Fr::zero());
                        assert_eq!(coset_y, Fr::zero());
                    }
                }
            }
        }
    }

    #[test]
    fn test_fft_correctness() {
        // Tests that the ffts output the correct result.
        // This assumes a correct polynomial evaluation at point procedure.
        // It tests consistency of FFT/IFFT, and coset_fft/coset_ifft,
        // along with testing that each individual evaluation is correct.

        // Runs in time O(degree^2)
        let log_degree = 5;
        let degree = 1 << log_degree;
        let rand_poly = DensePolynomial::<Fr>::rand(degree - 1, &mut test_rng());

        for log_domain_size in log_degree..(log_degree + 2) {
            let domain_size = 1 << log_domain_size;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let coset_domain =
                Radix2EvaluationDomain::<Fr>::new_coset(domain_size, Fr::GENERATOR).unwrap();
            let poly_evals = domain.fft(&rand_poly.coeffs);
            let poly_coset_evals = coset_domain.fft(&rand_poly.coeffs);

            for (i, (x, coset_x)) in domain.elements().zip(coset_domain.elements()).enumerate() {
                assert_eq!(poly_evals[i], rand_poly.evaluate(&x));
                assert_eq!(poly_coset_evals[i], rand_poly.evaluate(&coset_x));
            }

            let rand_poly_from_subgroup =
                DensePolynomial::from_coefficients_vec(domain.ifft(&poly_evals));
            let rand_poly_from_coset =
                DensePolynomial::from_coefficients_vec(coset_domain.ifft(&poly_coset_evals));

            assert_eq!(
                rand_poly, rand_poly_from_subgroup,
                "degree = {}, domain size = {}",
                degree, domain_size
            );
            assert_eq!(
                rand_poly, rand_poly_from_coset,
                "degree = {}, domain size = {}",
                degree, domain_size
            );
        }
    }

    #[test]
    fn degree_aware_fft_correctness() {
        // Test that the degree aware FFT (O(n log d)) matches the regular FFT (O(n log n)).
        let num_coeffs = 1 << 5;
        let rand_poly = DensePolynomial::<Fr>::rand(num_coeffs - 1, &mut test_rng());
        let domain_size = num_coeffs * DEGREE_AWARE_FFT_THRESHOLD_FACTOR;
        let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
        let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();

        let deg_aware_fft_evals = domain.fft(&rand_poly);
        let coset_deg_aware_fft_evals = coset_domain.fft(&rand_poly);

        for (i, (x, coset_x)) in domain.elements().zip(coset_domain.elements()).enumerate() {
            assert_eq!(deg_aware_fft_evals[i], rand_poly.evaluate(&x));
            assert_eq!(coset_deg_aware_fft_evals[i], rand_poly.evaluate(&coset_x));
        }
    }

    #[test]
    fn test_roots_of_unity() {
        // Tests that `roots_of_unity` returns the first half of
        // `domain.elements()`, i.e. `size / 2` roots, each of which lies in the domain.
        let max_degree = 10;
        for log_domain_size in 0..max_degree {
            let domain_size = 1 << log_domain_size;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let actual_roots = domain.roots_of_unity(domain.group_gen);
            for &value in &actual_roots {
                assert!(domain.evaluate_vanishing_polynomial(value).is_zero());
            }
            let expected_roots_elements = domain.elements();
            for (expected, &actual) in expected_roots_elements.zip(&actual_roots) {
                assert_eq!(expected, actual);
            }
            assert_eq!(actual_roots.len(), domain_size / 2);
        }
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn parallel_fft_consistency() {
        use ark_std::{test_rng, vec::*};
        use ark_test_curves::bls12_381::Fr;

        // This implements the Cooley-Tukey FFT, derived from libfqfft
        // The libfqfft implementation uses pseudocode from [CLRS 2n Ed, pp. 864].
        fn serial_radix2_fft(a: &mut [Fr], omega: Fr, log_n: u32) {
            use ark_std::convert::TryFrom;
            let n = u32::try_from(a.len())
                .expect("cannot perform FFTs larger on vectors of len > (1 << 32)");
            assert_eq!(n, 1 << log_n);

            // swap coefficients in place
            crate::domain::utils::bitreverse_permutation_in_place(a, log_n);

            let mut m = 1;
            for _i in 1..=log_n {
                // w_m is 2^i-th root of unity
                let w_m = omega.pow([(n / (2 * m)) as u64]);

                let mut k = 0;
                while k < n {
                    // w = w_m^j at the start of every loop iteration
                    let mut w = Fr::one();
                    for j in 0..m {
                        let mut t = a[(k + j + m) as usize];
                        t *= w;
                        let mut tmp = a[(k + j) as usize];
                        tmp -= t;
                        a[(k + j + m) as usize] = tmp;
                        a[(k + j) as usize] += t;
                        w *= &w_m;
                    }

                    k += 2 * m;
                }

                m *= 2;
            }
        }

        /// Serial reference inverse FFT: a forward FFT with `omega^-1`, then a
        /// scaling by `1 / n`.
        fn serial_radix2_ifft(a: &mut [Fr], omega: Fr, log_n: u32) {
            serial_radix2_fft(a, omega.inverse().unwrap(), log_n);
            let domain_size_inv = Fr::from(a.len() as u64).inverse().unwrap();
            for coeff in a.iter_mut() {
                *coeff *= domain_size_inv;
            }
        }

        /// Serial reference coset FFT: scale coefficient `i` by `GENERATOR^i`,
        /// then run the plain FFT.
        fn serial_radix2_coset_fft(a: &mut [Fr], omega: Fr, log_n: u32) {
            let coset_shift = Fr::GENERATOR;
            let mut cur_pow = Fr::one();
            for coeff in a.iter_mut() {
                *coeff *= cur_pow;
                cur_pow *= coset_shift;
            }
            serial_radix2_fft(a, omega, log_n);
        }

        /// Serial reference coset IFFT: plain IFFT, then scale coefficient `i`
        /// by `GENERATOR^-i`.
        fn serial_radix2_coset_ifft(a: &mut [Fr], omega: Fr, log_n: u32) {
            serial_radix2_ifft(a, omega, log_n);
            let coset_shift = Fr::GENERATOR.inverse().unwrap();
            let mut cur_pow = Fr::one();
            for coeff in a.iter_mut() {
                *coeff *= cur_pow;
                cur_pow *= coset_shift;
            }
        }

        fn test_consistency<R: Rng>(rng: &mut R, max_coeffs: u32) {
            for _ in 0..5 {
                for log_d in 0..max_coeffs {
                    let d = 1 << log_d;

                    let expected_poly = (0..d).map(|_| Fr::rand(rng)).collect::<Vec<_>>();
                    let mut expected_vec = expected_poly.clone();
                    let mut actual_vec = expected_vec.clone();

                    let domain = Radix2EvaluationDomain::new(d).unwrap();
                    let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();

                    serial_radix2_fft(&mut expected_vec, domain.group_gen, log_d);
                    domain.fft_in_place(&mut actual_vec);
                    assert_eq!(expected_vec, actual_vec);

                    serial_radix2_ifft(&mut expected_vec, domain.group_gen, log_d);
                    domain.ifft_in_place(&mut actual_vec);
                    assert_eq!(expected_vec, actual_vec);
                    assert_eq!(expected_vec, expected_poly);

                    serial_radix2_coset_fft(&mut expected_vec, domain.group_gen, log_d);
                    coset_domain.fft_in_place(&mut actual_vec);
                    assert_eq!(expected_vec, actual_vec);

                    serial_radix2_coset_ifft(&mut expected_vec, domain.group_gen, log_d);
                    coset_domain.ifft_in_place(&mut actual_vec);
                    assert_eq!(expected_vec, actual_vec);
                }
            }
        }

        let rng = &mut test_rng();

        test_consistency(rng, 10);
    }

    /// `new` rounds the requested coefficient count up to a power of two, and
    /// zero coefficients yields a single-element domain.
    #[test]
    fn test_domain_creation() {
        // Test powers of 2
        let domain_8 = Radix2EvaluationDomain::<Fr>::new(8).unwrap();
        assert_eq!(domain_8.size(), 8);
        assert_eq!(domain_8.log_size_of_group(), 3);

        // Test rounding to next power of 2
        let domain_7 = Radix2EvaluationDomain::<Fr>::new(7).unwrap();
        assert_eq!(domain_7.size(), 8);

        // Zero coefficients rounds up to the trivial domain of size 1.
        let domain_0 = Radix2EvaluationDomain::<Fr>::new(0).unwrap();
        assert_eq!(domain_0.size(), 1);
        assert_eq!(domain_0.log_size_of_group(), 0);
    }

    /// The generator of a size-8 domain must be a primitive 8th root of unity.
    #[test]
    fn test_root_of_unity() {
        let domain = Radix2EvaluationDomain::<Fr>::new(8).unwrap();
        let root = domain.group_gen();

        // Check order: root^8 should be 1 ...
        let expected = root.pow([8]);
        assert_eq!(expected, Fr::one());
        // ... and root^4 should not, so the order is exactly 8.
        assert_ne!(root.pow([4]), Fr::one());
    }

    /// The cached generator inverse must actually invert the generator.
    #[test]
    fn test_inverse_root_of_unity() {
        let domain = Radix2EvaluationDomain::<Fr>::new(8).unwrap();
        let root = domain.group_gen();
        let root_inv = domain.group_gen_inv();

        // root * root_inv should be 1
        let expected = root * root_inv;
        assert_eq!(expected, Fr::one());
    }

    /// The cached size inverse must equal `1 / size` in the field.
    #[test]
    fn test_size_inverse() {
        let domain = Radix2EvaluationDomain::<Fr>::new(8).unwrap();
        let size_inv = domain.size_inv();
        let expected = Fr::from(8).inverse().unwrap();

        assert_eq!(size_inv, expected);
    }

    /// An FFT followed by an IFFT over the same domain is the identity.
    #[test]
    fn test_fft_ifft_identity() {
        let domain = Radix2EvaluationDomain::<Fr>::new(8).unwrap();
        let mut coeffs = vec![
            Fr::from(1),
            Fr::from(2),
            Fr::from(3),
            Fr::from(4),
            Fr::from(5),
            Fr::from(6),
            Fr::from(7),
            Fr::from(8),
        ];

        let original = coeffs.clone();
        domain.fft_in_place(&mut coeffs);
        domain.ifft_in_place(&mut coeffs);

        // Check if IFFT(FFT(coeffs)) == coeffs
        assert_eq!(coeffs, original);
    }

    /// The vanishing polynomial of a small domain is zero on all its elements.
    #[test]
    fn test_vanishing_polynomial() {
        let domain = Radix2EvaluationDomain::<Fr>::new(4).unwrap();
        let z = domain.vanishing_polynomial();

        for elem in domain.elements() {
            assert_eq!(z.evaluate(&elem), Fr::zero());
        }
    }

    /// `compute_size_of_domain` rounds up to a power of two and reports
    /// overflow as `None` rather than panicking.
    #[test]
    fn test_compute_size_of_domain() {
        assert_eq!(
            Radix2EvaluationDomain::<Fr>::compute_size_of_domain(0),
            Some(1)
        );
        assert_eq!(
            Radix2EvaluationDomain::<Fr>::compute_size_of_domain(1),
            Some(1)
        );
        assert_eq!(
            Radix2EvaluationDomain::<Fr>::compute_size_of_domain(7),
            Some(8)
        );
        assert_eq!(
            Radix2EvaluationDomain::<Fr>::compute_size_of_domain(8),
            Some(8)
        );
        assert_eq!(
            Radix2EvaluationDomain::<Fr>::compute_size_of_domain(15),
            Some(16)
        );
        // Rounding `usize::MAX` up to a power of two overflows.
        assert_eq!(
            Radix2EvaluationDomain::<Fr>::compute_size_of_domain(usize::MAX),
            None
        );
    }

    /// Sizes up to `2^TWO_ADICITY` are accepted; anything larger must be
    /// rejected by both `compute_size_of_domain` and `new`. Cases whose size
    /// does not fit in this target's `usize` are skipped.
    #[test]
    fn rejects_sizes_beyond_two_adicity() {
        if let Some(max_supported) = 1usize.checked_shl(Fr::TWO_ADICITY) {
            assert_eq!(
                Radix2EvaluationDomain::<Fr>::compute_size_of_domain(max_supported),
                Some(max_supported)
            );
        }
        if let Some(too_large) = 1usize.checked_shl(Fr::TWO_ADICITY + 1) {
            assert_eq!(
                Radix2EvaluationDomain::<Fr>::compute_size_of_domain(too_large),
                None
            );
            assert!(Radix2EvaluationDomain::<Fr>::new(too_large).is_none());
        }
    }
}