ark_poly/domain/
mixed_radix.rs

1//! This module contains a `MixedRadixEvaluationDomain` for
2//! performing various kinds of polynomial arithmetic on top of
3//! fields that are FFT-friendly but do not have high-enough
4//! two-adicity to perform the FFT efficiently, i.e. the multiplicative
5//! subgroup `G` generated by `F::TWO_ADIC_ROOT_OF_UNITY` is not large enough.
6//! `MixedRadixEvaluationDomain` resolves
7//! this issue by using a larger subgroup obtained by combining
8//! `G` with another subgroup of size
9//! `F::SMALL_SUBGROUP_BASE^(F::SMALL_SUBGROUP_BASE_ADICITY)`,
10//! to obtain a subgroup generated by `F::LARGE_SUBGROUP_ROOT_OF_UNITY`.
11
12pub use crate::domain::utils::Elements;
13use crate::domain::{
14    utils::{best_fft, bitreverse},
15    DomainCoeff, EvaluationDomain,
16};
17use ark_ff::{fields::utils::k_adicity, FftField};
18use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
19use ark_std::{cmp::min, fmt, vec::*};
20#[cfg(feature = "parallel")]
21use rayon::prelude::*;
22
23/// Defines a domain over which finite field (I)FFTs can be performed. Works
24/// only for fields that have a multiplicative subgroup of size that is
25/// a power-of-2 and another small subgroup over a different base defined.
26#[derive(Copy, Clone, Hash, Eq, PartialEq, CanonicalSerialize, CanonicalDeserialize)]
27pub struct MixedRadixEvaluationDomain<F: FftField> {
28    /// The size of the domain.
29    pub size: u64,
30    /// `log_2(self.size)`.
31    pub log_size_of_group: u32,
32    /// Size of the domain as a field element.
33    pub size_as_field_element: F,
34    /// Inverse of the size in the field.
35    pub size_inv: F,
36    /// A generator of the subgroup.
37    pub group_gen: F,
38    /// Inverse of the generator of the subgroup.
39    pub group_gen_inv: F,
40    /// Offset that specifies the coset.
41    pub offset: F,
42    /// Inverse of the offset that specifies the coset.
43    pub offset_inv: F,
44    /// Constant coefficient for the vanishing polynomial.
45    /// Equals `self.offset^self.size`.
46    pub offset_pow_size: F,
47}
48
49impl<F: FftField> fmt::Debug for MixedRadixEvaluationDomain<F> {
50    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51        write!(
52            f,
53            "Mixed-radix multiplicative subgroup of size {}",
54            self.size
55        )
56    }
57}
58
59impl<F: FftField> EvaluationDomain<F> for MixedRadixEvaluationDomain<F> {
60    type Elements = Elements<F>;
61
62    /// Construct a domain that is large enough for evaluations of a polynomial
63    /// having `num_coeffs` coefficients.
64    fn new(num_coeffs: usize) -> Option<Self> {
65        // Compute the best size of our evaluation domain.
66        let size = best_mixed_domain_size::<F>(num_coeffs) as u64;
67        let small_subgroup_base = F::SMALL_SUBGROUP_BASE?;
68
69        // Compute the size of our evaluation domain
70        let q = u64::from(small_subgroup_base);
71        let q_adicity = k_adicity(q, size);
72        let q_part = q.checked_pow(q_adicity)?;
73
74        let two_adicity = k_adicity(2, size);
75        let log_size_of_group = two_adicity;
76        let two_part = 2u64.checked_pow(two_adicity)?;
77
78        if size != q_part * two_part {
79            return None;
80        }
81
82        // Compute the generator for the multiplicative subgroup.
83        // It should be the num_coeffs root of unity.
84        let group_gen = F::get_root_of_unity(size)?;
85        // Check that it is indeed the requested root of unity.
86        debug_assert_eq!(group_gen.pow([size]), F::one());
87        let size_as_field_element = F::from(size);
88        let size_inv = size_as_field_element.inverse()?;
89
90        Some(MixedRadixEvaluationDomain {
91            size,
92            log_size_of_group,
93            size_as_field_element,
94            size_inv,
95            group_gen,
96            group_gen_inv: group_gen.inverse()?,
97            offset: F::one(),
98            offset_inv: F::one(),
99            offset_pow_size: F::one(),
100        })
101    }
102
103    fn get_coset(&self, offset: F) -> Option<Self> {
104        Some(MixedRadixEvaluationDomain {
105            offset,
106            offset_inv: offset.inverse()?,
107            offset_pow_size: offset.pow([self.size]),
108            ..*self
109        })
110    }
111
112    fn compute_size_of_domain(num_coeffs: usize) -> Option<usize> {
113        let small_subgroup_base = F::SMALL_SUBGROUP_BASE?;
114
115        // Compute the best size of our evaluation domain.
116        let num_coeffs = best_mixed_domain_size::<F>(num_coeffs) as u64;
117
118        let q = u64::from(small_subgroup_base);
119        let q_adicity = k_adicity(q, num_coeffs);
120        let q_part = q.checked_pow(q_adicity)?;
121
122        let two_adicity = k_adicity(2, num_coeffs);
123        let two_part = 2u64.checked_pow(two_adicity)?;
124
125        if num_coeffs == q_part * two_part {
126            Some(num_coeffs as usize)
127        } else {
128            None
129        }
130    }
131
132    #[inline]
133    fn size(&self) -> usize {
134        usize::try_from(self.size).unwrap()
135    }
136
137    #[inline]
138    fn log_size_of_group(&self) -> u64 {
139        self.log_size_of_group as u64
140    }
141
142    #[inline]
143    fn size_inv(&self) -> F {
144        self.size_inv
145    }
146
147    #[inline]
148    fn group_gen(&self) -> F {
149        self.group_gen
150    }
151
152    #[inline]
153    fn group_gen_inv(&self) -> F {
154        self.group_gen_inv
155    }
156
157    #[inline]
158    fn coset_offset(&self) -> F {
159        self.offset
160    }
161
162    #[inline]
163    fn coset_offset_inv(&self) -> F {
164        self.offset_inv
165    }
166
167    #[inline]
168    fn coset_offset_pow_size(&self) -> F {
169        self.offset_pow_size
170    }
171
172    #[inline]
173    fn fft_in_place<T: DomainCoeff<F>>(&self, coeffs: &mut Vec<T>) {
174        if !self.offset.is_one() {
175            Self::distribute_powers(coeffs, self.offset);
176        }
177        coeffs.resize(self.size(), T::zero());
178        best_fft(
179            coeffs,
180            self.group_gen,
181            self.log_size_of_group,
182            serial_mixed_radix_fft::<T, F>,
183        )
184    }
185
186    #[inline]
187    fn ifft_in_place<T: DomainCoeff<F>>(&self, evals: &mut Vec<T>) {
188        evals.resize(self.size(), T::zero());
189        best_fft(
190            evals,
191            self.group_gen_inv,
192            self.log_size_of_group,
193            serial_mixed_radix_fft::<T, F>,
194        );
195        if self.offset.is_one() {
196            ark_std::cfg_iter_mut!(evals).for_each(|val| *val *= self.size_inv);
197        } else {
198            Self::distribute_powers_and_mul_by_const(evals, self.offset_inv, self.size_inv);
199        }
200    }
201
202    /// Return an iterator over the elements of the domain.
203    fn elements(&self) -> Elements<F> {
204        Elements {
205            cur_elem: self.offset,
206            cur_pow: 0,
207            size: self.size,
208            group_gen: self.group_gen,
209        }
210    }
211}
212
213fn mixed_radix_fft_permute(
214    two_adicity: u32,
215    q_adicity: u32,
216    q: usize,
217    n: usize,
218    mut i: usize,
219) -> usize {
220    // This is the permutation obtained by splitting into 2 groups two_adicity times
221    // and then q groups q_adicity many times. It can be efficiently described
222    // as follows i = 2^0 b_0 + 2^1 b_1 + ... + 2^{two_adicity - 1}
223    // b_{two_adicity - 1} + 2^two_adicity ( x_0 + q^1 x_1 + .. +
224    // q^{q_adicity-1} x_{q_adicity-1}) We want to return
225    // j = b_0 (n/2) + b_1 (n/ 2^2) + ... + b_{two_adicity-1} (n/ 2^two_adicity)
226    // + x_0 (n / 2^two_adicity / q) + .. + x_{q_adicity-1} (n / 2^two_adicity /
227    // q^q_adicity)
228    let mut res = 0;
229    let mut shift = n;
230
231    for _ in 0..two_adicity {
232        shift /= 2;
233        res += (i % 2) * shift;
234        i /= 2;
235    }
236
237    for _ in 0..q_adicity {
238        shift /= q;
239        res += (i % q) * shift;
240        i /= q;
241    }
242
243    res
244}
245
246fn best_mixed_domain_size<F: FftField>(min_size: usize) -> usize {
247    let mut best = usize::max_value();
248    let small_subgroup_base_adicity = F::SMALL_SUBGROUP_BASE_ADICITY.unwrap();
249    let small_subgroup_base = usize::try_from(F::SMALL_SUBGROUP_BASE.unwrap()).unwrap();
250
251    for b in 0..=small_subgroup_base_adicity {
252        let mut r = small_subgroup_base.pow(b);
253
254        let mut two_adicity = 0;
255        while r < min_size {
256            r *= 2;
257            two_adicity += 1;
258        }
259
260        if two_adicity <= F::TWO_ADICITY {
261            best = min(best, r);
262        }
263    }
264
265    best
266}
267
268pub(crate) fn serial_mixed_radix_fft<T: DomainCoeff<F>, F: FftField>(
269    a: &mut [T],
270    omega: F,
271    two_adicity: u32,
272) {
273    // Conceptually, this FFT first splits into 2 sub-arrays two_adicity many times,
274    // and then splits into q sub-arrays q_adicity many times.
275
276    let n = a.len();
277    let q = usize::try_from(F::SMALL_SUBGROUP_BASE.unwrap()).unwrap();
278    let q_u64 = u64::from(F::SMALL_SUBGROUP_BASE.unwrap());
279    let n_u64 = n as u64;
280
281    let q_adicity = k_adicity(q_u64, n_u64);
282    let q_part = q_u64.checked_pow(q_adicity).unwrap();
283    let two_part = 2u64.checked_pow(two_adicity).unwrap();
284
285    assert_eq!(n_u64, q_part * two_part);
286
287    let mut m = 1; // invariant: m = 2^{s-1}
288
289    if q_adicity > 0 {
290        // If we're using the other radix, we have to do two things differently than in
291        // the radix 2 case. 1. Applying the index permutation is a bit more
292        // complicated. It isn't an involution (like it is in the radix 2 case)
293        // so we need to remember which elements we've moved as we go along
294        // and can't use the trick of just swapping when processing the first element of
295        // a 2-cycle.
296        //
297        // 2. We need to do q_adicity many merge passes, each of which is a bit more
298        // complicated than the specialized q=2 case.
299
300        // Applying the permutation
301        let mut seen = vec![false; n];
302        for k in 0..n {
303            let mut i = k;
304            let mut a_i = a[i];
305            while !seen[i] {
306                let dest = mixed_radix_fft_permute(two_adicity, q_adicity, q, n, i);
307
308                let a_dest = a[dest];
309                a[dest] = a_i;
310
311                seen[i] = true;
312
313                a_i = a_dest;
314                i = dest;
315            }
316        }
317
318        let omega_q = omega.pow([(n / q) as u64]);
319        let mut qth_roots = Vec::with_capacity(q);
320        qth_roots.push(F::one());
321        for i in 1..q {
322            qth_roots.push(qth_roots[i - 1] * omega_q);
323        }
324
325        let mut terms = vec![T::zero(); q - 1];
326
327        // Doing the q_adicity passes.
328        for _ in 0..q_adicity {
329            let w_m = omega.pow([(n / (q * m)) as u64]);
330            let mut k = 0;
331            while k < n {
332                let mut w_j = F::one(); // w_j is omega_m ^ j
333                for j in 0..m {
334                    let base_term = a[k + j];
335                    let mut w_j_i = w_j;
336                    for i in 1..q {
337                        terms[i - 1] = a[k + j + i * m];
338                        terms[i - 1] *= w_j_i;
339                        w_j_i *= w_j;
340                    }
341
342                    for i in 0..q {
343                        a[k + j + i * m] = base_term;
344                        for l in 1..q {
345                            let mut tmp = terms[l - 1];
346                            tmp *= qth_roots[(i * l) % q];
347                            a[k + j + i * m] += tmp;
348                        }
349                    }
350
351                    w_j *= w_m;
352                }
353
354                k += q * m;
355            }
356            m *= q;
357        }
358    } else {
359        // swapping in place (from Storer's book)
360        for k in 0..n {
361            let rk = bitreverse(k as u32, two_adicity) as usize;
362            if k < rk {
363                a.swap(k, rk);
364            }
365        }
366    }
367
368    for _ in 0..two_adicity {
369        // w_m is 2^s-th root of unity now
370        let w_m = omega.pow([(n / (2 * m)) as u64]);
371
372        let mut k = 0;
373        while k < n {
374            let mut w = F::one();
375            for j in 0..m {
376                let mut t = a[(k + m) + j];
377                t *= w;
378                a[(k + m) + j] = a[k + j];
379                a[(k + m) + j] -= t;
380                a[k + j] += t;
381                w *= w_m;
382            }
383            k += 2 * m;
384        }
385        m *= 2;
386    }
387}
388
389#[cfg(test)]
390mod tests {
391    use crate::{
392        polynomial::{univariate::DensePolynomial, DenseUVPolynomial, Polynomial},
393        EvaluationDomain, MixedRadixEvaluationDomain,
394    };
395    use ark_ff::{FftField, Field, One, UniformRand, Zero};
396    use ark_std::{rand::Rng, test_rng};
397    use ark_test_curves::bn384_small_two_adicity::Fq as Fr;
398
399    #[test]
400    fn vanishing_polynomial_evaluation() {
401        let rng = &mut test_rng();
402        for coeffs in 0..12 {
403            let domain = MixedRadixEvaluationDomain::<Fr>::new(coeffs).unwrap();
404            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
405            let z = domain.vanishing_polynomial();
406            let coset_z = coset_domain.vanishing_polynomial();
407            for _ in 0..100 {
408                let point: Fr = rng.gen();
409                assert_eq!(
410                    z.evaluate(&point),
411                    domain.evaluate_vanishing_polynomial(point)
412                );
413                assert_eq!(
414                    coset_z.evaluate(&point),
415                    coset_domain.evaluate_vanishing_polynomial(point)
416                );
417            }
418        }
419    }
420
421    #[test]
422    fn vanishing_polynomial_vanishes_on_domain() {
423        for coeffs in 0..1000 {
424            let domain = MixedRadixEvaluationDomain::<Fr>::new(coeffs).unwrap();
425            let z = domain.vanishing_polynomial();
426            for point in domain.elements() {
427                assert!(z.evaluate(&point).is_zero())
428            }
429
430            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
431            let z = coset_domain.vanishing_polynomial();
432            for point in coset_domain.elements() {
433                assert!(z.evaluate(&point).is_zero())
434            }
435        }
436    }
437
438    /// Test that lagrange interpolation for a random polynomial at a random
439    /// point works.
440    #[test]
441    fn non_systematic_lagrange_coefficients_test() {
442        for domain_dim in 1..10 {
443            let domain_size = 1 << domain_dim;
444            let domain = MixedRadixEvaluationDomain::<Fr>::new(domain_size).unwrap();
445            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
446            // Get random pt + lagrange coefficients
447            let rand_pt = Fr::rand(&mut test_rng());
448            let lagrange_coeffs = domain.evaluate_all_lagrange_coefficients(rand_pt);
449            let coset_lagrange_coeffs = coset_domain.evaluate_all_lagrange_coefficients(rand_pt);
450
451            // Sample the random polynomial, evaluate it over the domain and the random
452            // point.
453            let rand_poly = DensePolynomial::<Fr>::rand(domain_size - 1, &mut test_rng());
454            let poly_evals = domain.fft(rand_poly.coeffs());
455            let coset_poly_evals = coset_domain.fft(rand_poly.coeffs());
456            let actual_eval = rand_poly.evaluate(&rand_pt);
457
458            // Do lagrange interpolation, and compare against the actual evaluation
459            let mut interpolated_eval = Fr::zero();
460            let mut coset_interpolated_eval = Fr::zero();
461            for i in 0..domain_size {
462                interpolated_eval += lagrange_coeffs[i] * poly_evals[i];
463                coset_interpolated_eval += coset_lagrange_coeffs[i] * coset_poly_evals[i];
464            }
465            assert_eq!(actual_eval, interpolated_eval);
466            assert_eq!(actual_eval, coset_interpolated_eval);
467        }
468    }
469
470    /// Test that lagrange coefficients for a point in the domain is correct
471    #[test]
472    fn systematic_lagrange_coefficients_test() {
473        // This runs in time O(N^2) in the domain size, so keep the domain dimension
474        // low. We generate lagrange coefficients for each element in the domain.
475        for domain_dim in 1..5 {
476            let domain_size = 1 << domain_dim;
477            let domain = MixedRadixEvaluationDomain::<Fr>::new(domain_size).unwrap();
478            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
479            for (i, (x, coset_x)) in domain.elements().zip(coset_domain.elements()).enumerate() {
480                let lagrange_coeffs = domain.evaluate_all_lagrange_coefficients(x);
481                let coset_lagrange_coeffs =
482                    coset_domain.evaluate_all_lagrange_coefficients(coset_x);
483                for (j, (y, coset_y)) in lagrange_coeffs
484                    .into_iter()
485                    .zip(coset_lagrange_coeffs)
486                    .enumerate()
487                {
488                    // Lagrange coefficient for the evaluation point, which should be 1
489                    if i == j {
490                        assert_eq!(y, Fr::one());
491                        assert_eq!(coset_y, Fr::one());
492                    } else {
493                        assert_eq!(y, Fr::zero());
494                        assert_eq!(coset_y, Fr::zero());
495                    }
496                }
497            }
498        }
499    }
500
501    #[test]
502    fn size_of_elements() {
503        for coeffs in 1..12 {
504            let size = 1 << coeffs;
505            let domain = MixedRadixEvaluationDomain::<Fr>::new(size).unwrap();
506            let domain_size = domain.size();
507            assert_eq!(domain_size, domain.elements().count());
508        }
509    }
510
511    #[test]
512    fn elements_contents() {
513        for coeffs in 1..12 {
514            let size = 1 << coeffs;
515            let domain = MixedRadixEvaluationDomain::<Fr>::new(size).unwrap();
516            let offset = Fr::GENERATOR;
517            let coset_domain = domain.get_coset(offset).unwrap();
518            for (i, (x, coset_x)) in domain.elements().zip(coset_domain.elements()).enumerate() {
519                assert_eq!(x, domain.group_gen.pow([i as u64]));
520                assert_eq!(x, domain.element(i));
521                assert_eq!(coset_x, offset * coset_domain.group_gen.pow([i as u64]));
522                assert_eq!(coset_x, coset_domain.element(i));
523            }
524        }
525    }
526
527    #[test]
528    #[cfg(feature = "parallel")]
529    fn parallel_fft_consistency() {
530        use super::serial_mixed_radix_fft;
531        use crate::domain::utils::parallel_fft;
532        use ark_ff::PrimeField;
533        use ark_std::{test_rng, vec::*};
534        use ark_test_curves::bn384_small_two_adicity::Fq as Fr;
535        use core::cmp::min;
536
537        fn test_consistency<F: PrimeField, R: Rng>(rng: &mut R, max_coeffs: u32) {
538            for _ in 0..5 {
539                for log_d in 0..max_coeffs {
540                    let d = 1 << log_d;
541
542                    let mut v1 = (0..d).map(|_| F::rand(rng)).collect::<Vec<_>>();
543                    let mut v2 = v1.clone();
544
545                    let domain = MixedRadixEvaluationDomain::new(v1.len()).unwrap();
546
547                    for log_cpus in log_d..min(log_d + 1, 3) {
548                        parallel_fft::<F, F>(
549                            &mut v1,
550                            domain.group_gen,
551                            log_d,
552                            log_cpus,
553                            serial_mixed_radix_fft::<F, F>,
554                        );
555                        serial_mixed_radix_fft::<F, F>(&mut v2, domain.group_gen, log_d);
556
557                        assert_eq!(v1, v2);
558                    }
559                }
560            }
561        }
562
563        let rng = &mut test_rng();
564
565        test_consistency::<Fr, _>(rng, 16);
566    }
567}