Skip to main content

ark_poly/domain/
mixed_radix.rs

//! This module provides a `MixedRadixEvaluationDomain` for performing
//! various types of polynomial arithmetic over FFT-friendly fields whose
//! two-adicity is too low for purely radix-2 FFTs.
//!
//! In such fields the multiplicative subgroup `G` generated by
//! `F::TWO_ADIC_ROOT_OF_UNITY` is not large enough to host the evaluation
//! domains of interest, so a power-of-two FFT alone cannot be used.
//!
//! The `MixedRadixEvaluationDomain` resolves this issue by using a larger
//! subgroup, obtained by combining `G` with another subgroup of size
//! `F::SMALL_SUBGROUP_BASE^(F::SMALL_SUBGROUP_BASE_ADICITY)`.
//! Together, these form a subgroup generated by `F::LARGE_SUBGROUP_ROOT_OF_UNITY`.
12
13pub use crate::domain::utils::Elements;
14use crate::domain::{
15    utils::{best_fft, bitreverse_permutation_in_place},
16    DomainCoeff, EvaluationDomain,
17};
18use ark_ff::{fields::utils::k_adicity, FftField, Field};
19use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
20use ark_std::{cmp::min, fmt, vec, vec::*};
21#[cfg(feature = "parallel")]
22use rayon::prelude::*;
23
24/// Defines a domain over which finite field (I)FFTs can be performed.
25///
26/// Works only for fields that have a multiplicative subgroup of size that is
27/// a power-of-2 and another small subgroup over a different base defined.
28#[derive(Copy, Clone, Hash, Eq, PartialEq, CanonicalSerialize, CanonicalDeserialize)]
29pub struct MixedRadixEvaluationDomain<F: Field> {
30    /// The size of the domain.
31    pub size: u64,
32    /// `log_2(self.size)`.
33    pub log_size_of_group: u32,
34    /// Size of the domain as a field element.
35    pub size_as_field_element: F,
36    /// Inverse of the size in the field.
37    pub size_inv: F,
38    /// A generator of the subgroup.
39    pub group_gen: F,
40    /// Inverse of the generator of the subgroup.
41    pub group_gen_inv: F,
42    /// Offset that specifies the coset.
43    pub offset: F,
44    /// Inverse of the offset that specifies the coset.
45    pub offset_inv: F,
46    /// Constant coefficient for the vanishing polynomial.
47    /// Equals `self.offset^self.size`.
48    pub offset_pow_size: F,
49}
50
51impl<F: Field> fmt::Debug for MixedRadixEvaluationDomain<F> {
52    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53        write!(
54            f,
55            "Mixed-radix multiplicative subgroup of size {}",
56            self.size
57        )
58    }
59}
60
61impl<F: FftField> EvaluationDomain<F> for MixedRadixEvaluationDomain<F> {
62    type Elements = Elements<F>;
63
64    /// Construct a domain that is large enough for evaluations of a polynomial
65    /// having `num_coeffs` coefficients.
66    fn new(num_coeffs: usize) -> Option<Self> {
67        // Compute the best size of our evaluation domain.
68        let size = best_mixed_domain_size::<F>(num_coeffs) as u64;
69        let small_subgroup_base = F::SMALL_SUBGROUP_BASE?;
70
71        // Compute the size of our evaluation domain
72        let q = u64::from(small_subgroup_base);
73        let q_adicity = k_adicity(q, size);
74        let q_part = q.checked_pow(q_adicity)?;
75
76        let two_adicity = k_adicity(2, size);
77        let log_size_of_group = two_adicity;
78        let two_part = 2u64.checked_pow(two_adicity)?;
79
80        if size != q_part * two_part {
81            return None;
82        }
83
84        // Compute the generator for the multiplicative subgroup.
85        // It should be the num_coeffs root of unity.
86        let group_gen = F::get_root_of_unity(size)?;
87        // Check that it is indeed the requested root of unity.
88        debug_assert_eq!(group_gen.pow([size]), F::one());
89        let size_as_field_element = F::from(size);
90        let size_inv = size_as_field_element.inverse()?;
91
92        Some(Self {
93            size,
94            log_size_of_group,
95            size_as_field_element,
96            size_inv,
97            group_gen,
98            group_gen_inv: group_gen.inverse()?,
99            offset: F::one(),
100            offset_inv: F::one(),
101            offset_pow_size: F::one(),
102        })
103    }
104
105    fn get_coset(&self, offset: F) -> Option<Self> {
106        Some(Self {
107            offset,
108            offset_inv: offset.inverse()?,
109            offset_pow_size: offset.pow([self.size]),
110            ..*self
111        })
112    }
113
114    fn compute_size_of_domain(num_coeffs: usize) -> Option<usize> {
115        let small_subgroup_base = F::SMALL_SUBGROUP_BASE?;
116
117        // Compute the best size of our evaluation domain.
118        let num_coeffs = best_mixed_domain_size::<F>(num_coeffs) as u64;
119
120        let q = u64::from(small_subgroup_base);
121        let q_adicity = k_adicity(q, num_coeffs);
122        let q_part = q.checked_pow(q_adicity)?;
123
124        let two_adicity = k_adicity(2, num_coeffs);
125        let two_part = 2u64.checked_pow(two_adicity)?;
126
127        (num_coeffs == q_part * two_part).then_some(num_coeffs as usize)
128    }
129
130    #[inline]
131    fn size(&self) -> usize {
132        self.size.try_into().unwrap()
133    }
134
135    #[inline]
136    fn log_size_of_group(&self) -> u64 {
137        self.log_size_of_group as u64
138    }
139
140    #[inline]
141    fn size_inv(&self) -> F {
142        self.size_inv
143    }
144
145    #[inline]
146    fn group_gen(&self) -> F {
147        self.group_gen
148    }
149
150    #[inline]
151    fn group_gen_inv(&self) -> F {
152        self.group_gen_inv
153    }
154
155    #[inline]
156    fn coset_offset(&self) -> F {
157        self.offset
158    }
159
160    #[inline]
161    fn coset_offset_inv(&self) -> F {
162        self.offset_inv
163    }
164
165    #[inline]
166    fn coset_offset_pow_size(&self) -> F {
167        self.offset_pow_size
168    }
169
170    #[inline]
171    fn fft_in_place<T: DomainCoeff<F>>(&self, coeffs: &mut Vec<T>) {
172        if !self.offset.is_one() {
173            Self::distribute_powers(coeffs, self.offset);
174        }
175        coeffs.resize(self.size(), T::zero());
176        best_fft(
177            coeffs,
178            self.group_gen,
179            self.log_size_of_group,
180            serial_mixed_radix_fft::<T, F>,
181        )
182    }
183
184    #[inline]
185    fn ifft_in_place<T: DomainCoeff<F>>(&self, evals: &mut Vec<T>) {
186        evals.resize(self.size(), T::zero());
187        best_fft(
188            evals,
189            self.group_gen_inv,
190            self.log_size_of_group,
191            serial_mixed_radix_fft::<T, F>,
192        );
193        if self.offset.is_one() {
194            ark_std::cfg_iter_mut!(evals).for_each(|val| *val *= self.size_inv);
195        } else {
196            Self::distribute_powers_and_mul_by_const(evals, self.offset_inv, self.size_inv);
197        }
198    }
199
200    /// Return an iterator over the elements of the domain.
201    fn elements(&self) -> Elements<F> {
202        Elements {
203            cur_elem: self.offset,
204            cur_pow: 0,
205            size: self.size,
206            group_gen: self.group_gen,
207        }
208    }
209}
210
fn mixed_radix_fft_permute(
    two_adicity: u32,
    q_adicity: u32,
    q: usize,
    n: usize,
    mut i: usize,
) -> usize {
    // Index permutation for the mixed-radix FFT: splitting into 2 groups
    // `two_adicity` times, then into `q` groups `q_adicity` times.
    //
    // Writing i = b_0 + 2 b_1 + ... + 2^{two_adicity-1} b_{two_adicity-1}
    //             + 2^two_adicity (x_0 + q x_1 + ... + q^{q_adicity-1} x_{q_adicity-1}),
    // the result is
    //   b_0 (n/2) + b_1 (n/2^2) + ... + b_{two_adicity-1} (n/2^two_adicity)
    //   + x_0 (n/2^two_adicity/q) + ... + x_{q_adicity-1} (n/2^two_adicity/q^q_adicity).
    // Each digit of `i` (least significant first) is re-weighted by a stride that
    // shrinks by that digit's radix.
    let radix_sequence = (0..two_adicity)
        .map(|_| 2)
        .chain((0..q_adicity).map(|_| q));

    let mut stride = n;
    let mut permuted = 0;
    for radix in radix_sequence {
        stride /= radix;
        let digit = i % radix;
        i /= radix;
        permuted += digit * stride;
    }
    permuted
}
243
244fn best_mixed_domain_size<F: FftField>(min_size: usize) -> usize {
245    let mut best = usize::MAX;
246    let small_subgroup_base_adicity = F::SMALL_SUBGROUP_BASE_ADICITY.unwrap();
247    let small_subgroup_base = usize::try_from(F::SMALL_SUBGROUP_BASE.unwrap()).unwrap();
248
249    for b in 0..=small_subgroup_base_adicity {
250        let mut r = small_subgroup_base.pow(b);
251
252        let mut two_adicity = 0;
253        while r < min_size {
254            r *= 2;
255            two_adicity += 1;
256        }
257
258        if two_adicity <= F::TWO_ADICITY {
259            best = min(best, r);
260        }
261    }
262
263    best
264}
265
266pub(crate) fn serial_mixed_radix_fft<T: DomainCoeff<F>, F: FftField>(
267    a: &mut [T],
268    omega: F,
269    two_adicity: u32,
270) {
271    // Conceptually, this FFT first splits into 2 sub-arrays two_adicity many times,
272    // and then splits into q sub-arrays q_adicity many times.
273
274    let n = a.len();
275    let q = usize::try_from(F::SMALL_SUBGROUP_BASE.unwrap()).unwrap();
276    let q_u64 = u64::from(F::SMALL_SUBGROUP_BASE.unwrap());
277    let n_u64 = n as u64;
278
279    let q_adicity = k_adicity(q_u64, n_u64);
280    let q_part = q_u64.checked_pow(q_adicity).unwrap();
281    let two_part = 2u64.checked_pow(two_adicity).unwrap();
282
283    assert_eq!(n_u64, q_part * two_part);
284
285    let mut m = 1; // invariant: m = 2^{s-1}
286
287    if q_adicity > 0 {
288        // If we're using the other radix, we have to do two things differently than in
289        // the radix 2 case. 1. Applying the index permutation is a bit more
290        // complicated. It isn't an involution (like it is in the radix 2 case)
291        // so we need to remember which elements we've moved as we go along
292        // and can't use the trick of just swapping when processing the first element of
293        // a 2-cycle.
294        //
295        // 2. We need to do q_adicity many merge passes, each of which is a bit more
296        // complicated than the specialized q=2 case.
297
298        // Applying the permutation
299        let mut seen = vec![false; n];
300        for k in 0..n {
301            let mut i = k;
302            let mut a_i = a[i];
303            while !seen[i] {
304                let dest = mixed_radix_fft_permute(two_adicity, q_adicity, q, n, i);
305
306                let a_dest = a[dest];
307                a[dest] = a_i;
308
309                seen[i] = true;
310
311                a_i = a_dest;
312                i = dest;
313            }
314        }
315
316        let omega_q = omega.pow([(n / q) as u64]);
317        let mut qth_roots = Vec::with_capacity(q);
318        qth_roots.push(F::one());
319        for i in 1..q {
320            qth_roots.push(qth_roots[i - 1] * omega_q);
321        }
322
323        let mut terms = vec![T::zero(); q - 1];
324
325        // Doing the q_adicity passes.
326        for _ in 0..q_adicity {
327            let w_m = omega.pow([(n / (q * m)) as u64]);
328            let mut k = 0;
329            while k < n {
330                let mut w_j = F::one(); // w_j is omega_m ^ j
331                for j in 0..m {
332                    let base_term = a[k + j];
333                    let mut w_j_i = w_j;
334                    for i in 1..q {
335                        terms[i - 1] = a[k + j + i * m];
336                        terms[i - 1] *= w_j_i;
337                        w_j_i *= w_j;
338                    }
339
340                    for i in 0..q {
341                        a[k + j + i * m] = base_term;
342                        for l in 1..q {
343                            let mut tmp = terms[l - 1];
344                            tmp *= qth_roots[(i * l) % q];
345                            a[k + j + i * m] += tmp;
346                        }
347                    }
348
349                    w_j *= w_m;
350                }
351
352                k += q * m;
353            }
354            m *= q;
355        }
356    } else {
357        bitreverse_permutation_in_place(a, two_adicity);
358    }
359
360    for _ in 0..two_adicity {
361        // w_m is 2^s-th root of unity now
362        let w_m = omega.pow([(n / (2 * m)) as u64]);
363
364        let mut k = 0;
365        while k < n {
366            let mut w = F::one();
367            for j in 0..m {
368                let mut t = a[(k + m) + j];
369                t *= w;
370                a[(k + m) + j] = a[k + j];
371                a[(k + m) + j] -= t;
372                a[k + j] += t;
373                w *= w_m;
374            }
375            k += 2 * m;
376        }
377        m *= 2;
378    }
379}
380
#[cfg(test)]
mod tests {
    use crate::{
        polynomial::{univariate::DensePolynomial, DenseUVPolynomial, Polynomial},
        EvaluationDomain, MixedRadixEvaluationDomain,
    };
    use ark_ff::{FftField, Field, One, UniformRand, Zero};
    use ark_std::{rand::Rng, test_rng};
    use ark_test_curves::bn384_small_two_adicity::Fq as Fr;

    #[test]
    fn vanishing_polynomial_evaluation() {
        let rng = &mut test_rng();
        for coeffs in 0..12 {
            let domain = MixedRadixEvaluationDomain::<Fr>::new(coeffs).unwrap();
            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
            let z = domain.vanishing_polynomial();
            let coset_z = coset_domain.vanishing_polynomial();
            for _ in 0..100 {
                let point: Fr = rng.gen();
                assert_eq!(
                    z.evaluate(&point),
                    domain.evaluate_vanishing_polynomial(point)
                );
                assert_eq!(
                    coset_z.evaluate(&point),
                    coset_domain.evaluate_vanishing_polynomial(point)
                );
            }
        }
    }

    #[test]
    fn vanishing_polynomial_vanishes_on_domain() {
        for coeffs in 0..1000 {
            let domain = MixedRadixEvaluationDomain::<Fr>::new(coeffs).unwrap();
            let z = domain.vanishing_polynomial();
            for point in domain.elements() {
                assert!(z.evaluate(&point).is_zero())
            }

            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
            let z = coset_domain.vanishing_polynomial();
            for point in coset_domain.elements() {
                assert!(z.evaluate(&point).is_zero())
            }
        }
    }

    /// Test that lagrange interpolation for a random polynomial at a random
    /// point works.
    #[test]
    fn non_systematic_lagrange_coefficients_test() {
        for domain_dim in 1..10 {
            let domain_size = 1 << domain_dim;
            let domain = MixedRadixEvaluationDomain::<Fr>::new(domain_size).unwrap();
            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
            // Get random pt + lagrange coefficients
            let rand_pt = Fr::rand(&mut test_rng());
            let lagrange_coeffs = domain.evaluate_all_lagrange_coefficients(rand_pt);
            let coset_lagrange_coeffs = coset_domain.evaluate_all_lagrange_coefficients(rand_pt);

            // Sample the random polynomial, evaluate it over the domain and the random
            // point.
            let rand_poly = DensePolynomial::<Fr>::rand(domain_size - 1, &mut test_rng());
            let poly_evals = domain.fft(rand_poly.coeffs());
            let coset_poly_evals = coset_domain.fft(rand_poly.coeffs());
            let actual_eval = rand_poly.evaluate(&rand_pt);

            // Do lagrange interpolation, and compare against the actual evaluation
            let mut interpolated_eval = Fr::zero();
            let mut coset_interpolated_eval = Fr::zero();
            for i in 0..domain_size {
                interpolated_eval += lagrange_coeffs[i] * poly_evals[i];
                coset_interpolated_eval += coset_lagrange_coeffs[i] * coset_poly_evals[i];
            }
            assert_eq!(actual_eval, interpolated_eval);
            assert_eq!(actual_eval, coset_interpolated_eval);
        }
    }

    /// Test that lagrange coefficients for a point in the domain is correct
    #[test]
    fn systematic_lagrange_coefficients_test() {
        // This runs in time O(N^2) in the domain size, so keep the domain dimension
        // low. We generate lagrange coefficients for each element in the domain.
        for domain_dim in 1..5 {
            let domain_size = 1 << domain_dim;
            let domain = MixedRadixEvaluationDomain::<Fr>::new(domain_size).unwrap();
            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
            for (i, (x, coset_x)) in domain.elements().zip(coset_domain.elements()).enumerate() {
                let lagrange_coeffs = domain.evaluate_all_lagrange_coefficients(x);
                let coset_lagrange_coeffs =
                    coset_domain.evaluate_all_lagrange_coefficients(coset_x);
                for (j, (y, coset_y)) in lagrange_coeffs
                    .into_iter()
                    .zip(coset_lagrange_coeffs)
                    .enumerate()
                {
                    // Lagrange coefficient for the evaluation point, which should be 1
                    if i == j {
                        assert_eq!(y, Fr::one());
                        assert_eq!(coset_y, Fr::one());
                    } else {
                        assert_eq!(y, Fr::zero());
                        assert_eq!(coset_y, Fr::zero());
                    }
                }
            }
        }
    }

    /// Checks the FFT against direct polynomial evaluation, and the IFFT as its
    /// inverse, for non-power-of-two sizes. Power-of-two sizes (used by the
    /// tests above) only take the bit-reverse/radix-2 path, so these sizes are
    /// what exercise the radix-q passes of `serial_mixed_radix_fft`.
    #[test]
    fn fft_matches_direct_evaluation_on_mixed_sizes() {
        let rng = &mut test_rng();
        for num_coeffs in [3usize, 5, 6, 9, 12, 18, 24, 36, 72] {
            let domain = MixedRadixEvaluationDomain::<Fr>::new(num_coeffs).unwrap();
            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
            let poly = DensePolynomial::<Fr>::rand(num_coeffs - 1, rng);

            let evals = domain.fft(poly.coeffs());
            let coset_evals = coset_domain.fft(poly.coeffs());
            for (i, (x, coset_x)) in domain.elements().zip(coset_domain.elements()).enumerate() {
                assert_eq!(evals[i], poly.evaluate(&x));
                assert_eq!(coset_evals[i], poly.evaluate(&coset_x));
            }

            // The inverse transform must recover the original coefficients;
            // `from_coefficients_vec` drops the trailing zero padding.
            assert_eq!(
                DensePolynomial::from_coefficients_vec(domain.ifft(&evals)),
                poly
            );
            assert_eq!(
                DensePolynomial::from_coefficients_vec(coset_domain.ifft(&coset_evals)),
                poly
            );
        }
    }

    #[test]
    fn size_of_elements() {
        for coeffs in 1..12 {
            let size = 1 << coeffs;
            let domain = MixedRadixEvaluationDomain::<Fr>::new(size).unwrap();
            let domain_size = domain.size();
            assert_eq!(domain_size, domain.elements().count());
        }
    }

    #[test]
    fn elements_contents() {
        for coeffs in 1..12 {
            let size = 1 << coeffs;
            let domain = MixedRadixEvaluationDomain::<Fr>::new(size).unwrap();
            let offset = Fr::GENERATOR;
            let coset_domain = domain.get_coset(offset).unwrap();
            for (i, (x, coset_x)) in domain.elements().zip(coset_domain.elements()).enumerate() {
                assert_eq!(x, domain.group_gen.pow([i as u64]));
                assert_eq!(x, domain.element(i));
                assert_eq!(coset_x, offset * coset_domain.group_gen.pow([i as u64]));
                assert_eq!(coset_x, coset_domain.element(i));
            }
        }
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn parallel_fft_consistency() {
        use super::serial_mixed_radix_fft;
        use crate::domain::utils::parallel_fft;
        use ark_ff::PrimeField;
        use ark_std::{test_rng, vec::Vec};
        use ark_test_curves::bn384_small_two_adicity::Fq as Fr;
        use core::cmp::min;

        fn test_consistency<F: PrimeField, R: Rng>(rng: &mut R, max_coeffs: u32) {
            for _ in 0..5 {
                for log_d in 0..max_coeffs {
                    let d = 1 << log_d;

                    let mut v1 = (0..d).map(|_| F::rand(rng)).collect::<Vec<_>>();
                    let mut v2 = v1.clone();

                    let domain = MixedRadixEvaluationDomain::new(v1.len()).unwrap();

                    for log_cpus in log_d..min(log_d + 1, 3) {
                        parallel_fft::<F, F>(
                            &mut v1,
                            domain.group_gen,
                            log_d,
                            log_cpus,
                            serial_mixed_radix_fft::<F, F>,
                        );
                        serial_mixed_radix_fft::<F, F>(&mut v2, domain.group_gen, log_d);

                        assert_eq!(v1, v2);
                    }
                }
            }
        }

        let rng = &mut test_rng();

        test_consistency::<Fr, _>(rng, 16);
    }

    #[test]
    fn test_root_of_unity() {
        let domain = MixedRadixEvaluationDomain::<Fr>::new(8).unwrap();
        let root = domain.group_gen();

        let expected = root.pow([domain.size() as u64]);
        assert_eq!(expected, Fr::one());
    }

    #[test]
    fn test_inverse_root_of_unity() {
        let domain = MixedRadixEvaluationDomain::<Fr>::new(8).unwrap();
        let root = domain.group_gen();
        let root_inv = domain.group_gen_inv();

        assert_eq!(root * root_inv, Fr::one());
    }

    #[test]
    fn test_size_inverse() {
        let domain = MixedRadixEvaluationDomain::<Fr>::new(8).unwrap();
        let size_inv = domain.size_inv();
        let expected = Fr::from(domain.size() as u64).inverse().unwrap();

        assert_eq!(size_inv, expected);
    }

    #[test]
    fn test_fft_ifft_identity() {
        let domain = MixedRadixEvaluationDomain::<Fr>::new(8).unwrap();
        let mut coeffs = ark_std::vec![
            Fr::from(1),
            Fr::from(2),
            Fr::from(3),
            Fr::from(4),
            Fr::from(5),
            Fr::from(6),
            Fr::from(7),
            Fr::from(8),
        ];

        let original = coeffs.clone();
        domain.fft_in_place(&mut coeffs);
        domain.ifft_in_place(&mut coeffs);

        assert_eq!(coeffs, original);
    }

    #[test]
    fn test_vanishing_polynomial() {
        let domain = MixedRadixEvaluationDomain::<Fr>::new(4).unwrap();
        let z = domain.vanishing_polynomial();

        for elem in domain.elements() {
            assert_eq!(z.evaluate(&elem), Fr::zero());
        }
    }

    #[test]
    fn test_get_coset() {
        let domain = MixedRadixEvaluationDomain::<Fr>::new(4).unwrap();
        let offset = Fr::from(3);
        let coset_domain = domain.get_coset(offset).unwrap();

        assert_eq!(coset_domain.coset_offset(), Fr::from(3));
    }

    #[test]
    fn test_compute_size_of_domain() {
        assert!(MixedRadixEvaluationDomain::<Fr>::compute_size_of_domain(7).is_some());
        assert!(MixedRadixEvaluationDomain::<Fr>::compute_size_of_domain(16).is_some());
    }

    /// `compute_size_of_domain` must report exactly the size `new` would pick.
    #[test]
    fn compute_size_of_domain_matches_new() {
        for num_coeffs in 0..200 {
            let expected = MixedRadixEvaluationDomain::<Fr>::new(num_coeffs).map(|d| d.size());
            assert_eq!(
                MixedRadixEvaluationDomain::<Fr>::compute_size_of_domain(num_coeffs),
                expected
            );
        }
    }
}
632}