1pub use crate::domain::utils::{bitreverse_permutation_in_place, Elements};
8use crate::domain::{DomainCoeff, EvaluationDomain};
9use ark_ff::{FftField, Field};
10use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
11use ark_std::{fmt, vec::*};
12
13mod fft;
14
/// Ratio of domain size to input length at or above which the FFT switches to
/// the degree-aware algorithm: it is used when
/// `num_coeffs * DEGREE_AWARE_FFT_THRESHOLD_FACTOR <= domain_size`.
const DEGREE_AWARE_FFT_THRESHOLD_FACTOR: usize = 4;
17
/// An evaluation domain consisting of a multiplicative subgroup of power-of-two
/// size (or a coset `offset * H` of such a subgroup `H`), together with cached
/// values used by (I)FFTs over it.
///
/// NOTE: field order is significant — the derived `Hash`, `CanonicalSerialize`
/// and `CanonicalDeserialize` impls visit fields in declaration order.
#[derive(Copy, Clone, Hash, Eq, PartialEq, CanonicalSerialize, CanonicalDeserialize)]
pub struct Radix2EvaluationDomain<F: Field> {
    /// Number of elements in the domain; a power of two when built via `new`.
    pub size: u64,
    /// `log_2(size)`, computed as `size.trailing_zeros()`.
    pub log_size_of_group: u32,
    /// `size` converted into the field.
    pub size_as_field_element: F,
    /// Inverse of `size_as_field_element` in the field.
    pub size_inv: F,
    /// Generator of the subgroup: a root of unity of order `size`.
    pub group_gen: F,
    /// Inverse of `group_gen`.
    pub group_gen_inv: F,
    /// Coset offset; `F::ONE` for the plain subgroup.
    pub offset: F,
    /// Inverse of `offset`.
    pub offset_inv: F,
    /// `offset^size`, cached for coset computations.
    pub offset_pow_size: F,
}
43
44impl<F: Field> fmt::Debug for Radix2EvaluationDomain<F> {
45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46 write!(f, "Radix-2 multiplicative subgroup of size {}", self.size)
47 }
48}
49
50impl<F: FftField> EvaluationDomain<F> for Radix2EvaluationDomain<F> {
51 type Elements = Elements<F>;
52
53 fn new(num_coeffs: usize) -> Option<Self> {
56 let size = num_coeffs.next_power_of_two() as u64;
57
58 let log_size_of_group = size.trailing_zeros();
59
60 if log_size_of_group > F::TWO_ADICITY {
62 return None;
63 }
64
65 let group_gen = F::get_root_of_unity(size)?;
68 debug_assert_eq!(group_gen.pow([size]), F::one());
70 let size_as_field_element = F::from(size);
71
72 Some(Self {
73 size,
74 log_size_of_group,
75 size_as_field_element,
76 size_inv: size_as_field_element.inverse()?,
77 group_gen,
78 group_gen_inv: group_gen.inverse()?,
79 offset: F::ONE,
80 offset_inv: F::ONE,
81 offset_pow_size: F::ONE,
82 })
83 }
84
85 fn get_coset(&self, offset: F) -> Option<Self> {
86 Some(Self {
87 offset,
88 offset_inv: offset.inverse()?,
89 offset_pow_size: offset.pow([self.size]),
90 ..*self
91 })
92 }
93
94 fn compute_size_of_domain(num_coeffs: usize) -> Option<usize> {
95 let size = num_coeffs.checked_next_power_of_two()?;
96 (size.trailing_zeros() <= F::TWO_ADICITY).then_some(size)
97 }
98
99 #[inline]
100 fn size(&self) -> usize {
101 self.size.try_into().unwrap()
102 }
103
104 #[inline]
105 fn log_size_of_group(&self) -> u64 {
106 self.log_size_of_group as u64
107 }
108
109 #[inline]
110 fn size_inv(&self) -> F {
111 self.size_inv
112 }
113
114 #[inline]
115 fn group_gen(&self) -> F {
116 self.group_gen
117 }
118
119 #[inline]
120 fn group_gen_inv(&self) -> F {
121 self.group_gen_inv
122 }
123
124 #[inline]
125 fn coset_offset(&self) -> F {
126 self.offset
127 }
128
129 #[inline]
130 fn coset_offset_inv(&self) -> F {
131 self.offset_inv
132 }
133
134 #[inline]
135 fn coset_offset_pow_size(&self) -> F {
136 self.offset_pow_size
137 }
138
139 #[inline]
140 fn fft_in_place<T: DomainCoeff<F>>(&self, coeffs: &mut Vec<T>) {
141 if coeffs.len() * DEGREE_AWARE_FFT_THRESHOLD_FACTOR <= self.size() {
142 self.degree_aware_fft_in_place(coeffs);
143 } else {
144 coeffs.resize(self.size(), T::zero());
145 self.in_order_fft_in_place(coeffs);
146 }
147 }
148
149 #[inline]
150 fn ifft_in_place<T: DomainCoeff<F>>(&self, evals: &mut Vec<T>) {
151 evals.resize(self.size(), T::zero());
152 self.in_order_ifft_in_place(&mut *evals);
153 }
154
155 fn elements(&self) -> Elements<F> {
157 Elements {
158 cur_elem: self.offset,
159 cur_pow: 0,
160 size: self.size,
161 group_gen: self.group_gen,
162 }
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::DEGREE_AWARE_FFT_THRESHOLD_FACTOR;
    use crate::{
        polynomial::{univariate::*, DenseUVPolynomial, Polynomial},
        EvaluationDomain, Radix2EvaluationDomain,
    };
    use ark_ff::{FftField, Field, One, UniformRand, Zero};
    use ark_std::{collections::BTreeSet, rand::Rng, test_rng, vec};
    use ark_test_curves::bls12_381::Fr;

    #[test]
    fn vanishing_polynomial_evaluation() {
        let rng = &mut test_rng();
        for coeffs in 0..10 {
            let domain = Radix2EvaluationDomain::<Fr>::new(coeffs).unwrap();
            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
            let z = domain.vanishing_polynomial();
            let z_coset = coset_domain.vanishing_polynomial();
            for _ in 0..100 {
                let point: Fr = rng.gen();
                assert_eq!(
                    z.evaluate(&point),
                    domain.evaluate_vanishing_polynomial(point)
                );
                assert_eq!(
                    z_coset.evaluate(&point),
                    coset_domain.evaluate_vanishing_polynomial(point)
                );
            }
        }
    }

    #[test]
    fn vanishing_polynomial_vanishes_on_domain() {
        for coeffs in 0..1000 {
            let domain = Radix2EvaluationDomain::<Fr>::new(coeffs).unwrap();
            let z = domain.vanishing_polynomial();
            for point in domain.elements() {
                assert!(z.evaluate(&point).is_zero())
            }

            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
            let z = coset_domain.vanishing_polynomial();
            for point in coset_domain.elements() {
                assert!(z.evaluate(&point).is_zero())
            }
        }
    }

    #[test]
    fn filter_polynomial_test() {
        for log_domain_size in 1..=4 {
            let domain_size = 1 << log_domain_size;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            for log_subdomain_size in 1..=log_domain_size {
                let subdomain_size = 1 << log_subdomain_size;
                let subdomain = Radix2EvaluationDomain::<Fr>::new(subdomain_size).unwrap();

                // Collect every offset g^i of the subdomain inside the domain;
                // they stop repeating once g^i reaches the subdomain generator.
                let mut possible_offsets = vec![Fr::one()];
                let domain_generator = domain.group_gen();

                let mut offset = domain_generator;
                let subdomain_generator = subdomain.group_gen();
                while offset != subdomain_generator {
                    possible_offsets.push(offset);
                    offset *= domain_generator;
                }

                assert_eq!(possible_offsets.len(), domain_size / subdomain_size);

                let cosets = possible_offsets
                    .iter()
                    .map(|offset| subdomain.get_coset(*offset).unwrap());

                for coset in cosets {
                    let coset_elements = coset.elements().collect::<BTreeSet<_>>();
                    let filter_poly = domain.filter_polynomial(&coset);
                    assert_eq!(filter_poly.degree(), domain_size - subdomain_size);
                    for element in domain.elements() {
                        let evaluation = domain.evaluate_filter_polynomial(&coset, element);
                        assert_eq!(evaluation, filter_poly.evaluate(&element));
                        if coset_elements.contains(&element) {
                            assert_eq!(evaluation, Fr::one())
                        } else {
                            assert_eq!(evaluation, Fr::zero())
                        }
                    }
                }
            }
        }
    }

    #[test]
    fn size_of_elements() {
        for coeffs in 1..10 {
            let size = 1 << coeffs;
            let domain = Radix2EvaluationDomain::<Fr>::new(size).unwrap();
            let domain_size = domain.size();
            assert_eq!(domain_size, domain.elements().count());
        }
    }

    #[test]
    fn elements_contents() {
        for coeffs in 1..10 {
            let size = 1 << coeffs;
            let domain = Radix2EvaluationDomain::<Fr>::new(size).unwrap();
            let offset = Fr::GENERATOR;
            let coset_domain = domain.get_coset(offset).unwrap();
            for (i, (x, coset_x)) in domain.elements().zip(coset_domain.elements()).enumerate() {
                assert_eq!(x, domain.group_gen.pow([i as u64]));
                assert_eq!(x, domain.element(i));
                assert_eq!(coset_x, offset * coset_domain.group_gen.pow([i as u64]));
                assert_eq!(coset_x, coset_domain.element(i));
            }
        }
    }

    /// Lagrange coefficients at a random point must interpolate a random
    /// polynomial's evaluations back to its value at that point.
    #[test]
    fn non_systematic_lagrange_coefficients_test() {
        for domain_dim in 1..10 {
            let domain_size = 1 << domain_dim;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
            let rand_pt = Fr::rand(&mut test_rng());
            let lagrange_coeffs = domain.evaluate_all_lagrange_coefficients(rand_pt);
            let coset_lagrange_coeffs = coset_domain.evaluate_all_lagrange_coefficients(rand_pt);

            let rand_poly = DensePolynomial::<Fr>::rand(domain_size - 1, &mut test_rng());
            let poly_evals = domain.fft(rand_poly.coeffs());
            let coset_poly_evals = coset_domain.fft(rand_poly.coeffs());
            let actual_eval = rand_poly.evaluate(&rand_pt);

            let mut interpolated_eval = Fr::zero();
            let mut coset_interpolated_eval = Fr::zero();
            for i in 0..domain_size {
                interpolated_eval += lagrange_coeffs[i] * poly_evals[i];
                coset_interpolated_eval += coset_lagrange_coeffs[i] * coset_poly_evals[i];
            }
            assert_eq!(actual_eval, interpolated_eval);
            assert_eq!(actual_eval, coset_interpolated_eval);
        }
    }

    /// At the i-th domain point, the Lagrange coefficients must be the i-th
    /// unit vector.
    #[test]
    fn systematic_lagrange_coefficients_test() {
        for domain_dim in 1..5 {
            let domain_size = 1 << domain_dim;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
            for (i, (x, coset_x)) in domain.elements().zip(coset_domain.elements()).enumerate() {
                let lagrange_coeffs = domain.evaluate_all_lagrange_coefficients(x);
                let coset_lagrange_coeffs =
                    coset_domain.evaluate_all_lagrange_coefficients(coset_x);
                for (j, (y, coset_y)) in lagrange_coeffs
                    .into_iter()
                    .zip(coset_lagrange_coeffs)
                    .enumerate()
                {
                    if i == j {
                        assert_eq!(y, Fr::one());
                        assert_eq!(coset_y, Fr::one());
                    } else {
                        assert_eq!(y, Fr::zero());
                        assert_eq!(coset_y, Fr::zero());
                    }
                }
            }
        }
    }

    #[test]
    fn test_fft_correctness() {
        // Checks every FFT output naively, so keep the sizes small.
        let log_degree = 5;
        let degree = 1 << log_degree;
        let rand_poly = DensePolynomial::<Fr>::rand(degree - 1, &mut test_rng());

        for log_domain_size in log_degree..(log_degree + 2) {
            let domain_size = 1 << log_domain_size;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let coset_domain =
                Radix2EvaluationDomain::<Fr>::new_coset(domain_size, Fr::GENERATOR).unwrap();
            let poly_evals = domain.fft(&rand_poly.coeffs);
            let poly_coset_evals = coset_domain.fft(&rand_poly.coeffs);

            for (i, (x, coset_x)) in domain.elements().zip(coset_domain.elements()).enumerate() {
                assert_eq!(poly_evals[i], rand_poly.evaluate(&x));
                assert_eq!(poly_coset_evals[i], rand_poly.evaluate(&coset_x));
            }

            let rand_poly_from_subgroup =
                DensePolynomial::from_coefficients_vec(domain.ifft(&poly_evals));
            let rand_poly_from_coset =
                DensePolynomial::from_coefficients_vec(coset_domain.ifft(&poly_coset_evals));

            assert_eq!(
                rand_poly, rand_poly_from_subgroup,
                "degree = {}, domain size = {}",
                degree, domain_size
            );
            assert_eq!(
                rand_poly, rand_poly_from_coset,
                "degree = {}, domain size = {}",
                degree, domain_size
            );
        }
    }

    #[test]
    fn degree_aware_fft_correctness() {
        // Domain large enough that `fft` takes the degree-aware path.
        let num_coeffs = 1 << 5;
        let rand_poly = DensePolynomial::<Fr>::rand(num_coeffs - 1, &mut test_rng());
        let domain_size = num_coeffs * DEGREE_AWARE_FFT_THRESHOLD_FACTOR;
        let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
        let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();

        let deg_aware_fft_evals = domain.fft(&rand_poly);
        let coset_deg_aware_fft_evals = coset_domain.fft(&rand_poly);

        for (i, (x, coset_x)) in domain.elements().zip(coset_domain.elements()).enumerate() {
            assert_eq!(deg_aware_fft_evals[i], rand_poly.evaluate(&x));
            assert_eq!(coset_deg_aware_fft_evals[i], rand_poly.evaluate(&coset_x));
        }
    }

    #[test]
    fn test_roots_of_unity() {
        let max_log_domain_size = 10;
        for log_domain_size in 0..max_log_domain_size {
            let domain_size = 1 << log_domain_size;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let actual_roots = domain.roots_of_unity(domain.group_gen);
            for &value in &actual_roots {
                assert!(domain.evaluate_vanishing_polynomial(value).is_zero());
            }
            let expected_roots_elements = domain.elements();
            for (expected, &actual) in expected_roots_elements.zip(&actual_roots) {
                assert_eq!(expected, actual);
            }
            assert_eq!(actual_roots.len(), domain_size / 2);
        }
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn parallel_fft_consistency() {
        use ark_std::vec::Vec;

        /// Reference iterative Cooley-Tukey FFT on `a.len() == 1 << log_n` points.
        fn serial_radix2_fft(a: &mut [Fr], omega: Fr, log_n: u32) {
            use ark_std::convert::TryFrom;
            let n = u32::try_from(a.len())
                .expect("cannot perform FFTs larger on vectors of len > (1 << 32)");
            assert_eq!(n, 1 << log_n);

            crate::domain::utils::bitreverse_permutation_in_place(a, log_n);

            let mut m = 1;
            for _i in 1..=log_n {
                // Principal root of unity of order 2m.
                let w_m = omega.pow([(n / (2 * m)) as u64]);

                let mut k = 0;
                while k < n {
                    let mut w = Fr::one();
                    for j in 0..m {
                        let mut t = a[(k + j + m) as usize];
                        t *= w;
                        let mut tmp = a[(k + j) as usize];
                        tmp -= t;
                        a[(k + j + m) as usize] = tmp;
                        a[(k + j) as usize] += t;
                        w *= &w_m;
                    }

                    k += 2 * m;
                }

                m *= 2;
            }
        }

        fn serial_radix2_ifft(a: &mut [Fr], omega: Fr, log_n: u32) {
            serial_radix2_fft(a, omega.inverse().unwrap(), log_n);
            let domain_size_inv = Fr::from(a.len() as u64).inverse().unwrap();
            for coeff in a.iter_mut() {
                *coeff *= domain_size_inv;
            }
        }

        fn serial_radix2_coset_fft(a: &mut [Fr], omega: Fr, log_n: u32) {
            let coset_shift = Fr::GENERATOR;
            let mut cur_pow = Fr::one();
            for coeff in a.iter_mut() {
                *coeff *= cur_pow;
                cur_pow *= coset_shift;
            }
            serial_radix2_fft(a, omega, log_n);
        }

        fn serial_radix2_coset_ifft(a: &mut [Fr], omega: Fr, log_n: u32) {
            serial_radix2_ifft(a, omega, log_n);
            let coset_shift = Fr::GENERATOR.inverse().unwrap();
            let mut cur_pow = Fr::one();
            for coeff in a.iter_mut() {
                *coeff *= cur_pow;
                cur_pow *= coset_shift;
            }
        }

        fn test_consistency<R: Rng>(rng: &mut R, max_coeffs: u32) {
            for _ in 0..5 {
                for log_d in 0..max_coeffs {
                    let d = 1 << log_d;

                    let expected_poly = (0..d).map(|_| Fr::rand(rng)).collect::<Vec<_>>();
                    let mut expected_vec = expected_poly.clone();
                    let mut actual_vec = expected_vec.clone();

                    let domain = Radix2EvaluationDomain::new(d).unwrap();
                    let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();

                    serial_radix2_fft(&mut expected_vec, domain.group_gen, log_d);
                    domain.fft_in_place(&mut actual_vec);
                    assert_eq!(expected_vec, actual_vec);

                    serial_radix2_ifft(&mut expected_vec, domain.group_gen, log_d);
                    domain.ifft_in_place(&mut actual_vec);
                    assert_eq!(expected_vec, actual_vec);
                    assert_eq!(expected_vec, expected_poly);

                    serial_radix2_coset_fft(&mut expected_vec, domain.group_gen, log_d);
                    coset_domain.fft_in_place(&mut actual_vec);
                    assert_eq!(expected_vec, actual_vec);

                    serial_radix2_coset_ifft(&mut expected_vec, domain.group_gen, log_d);
                    coset_domain.ifft_in_place(&mut actual_vec);
                    assert_eq!(expected_vec, actual_vec);
                }
            }
        }

        let rng = &mut test_rng();

        test_consistency(rng, 10);
    }

    #[test]
    fn test_domain_creation() {
        let domain_8 = Radix2EvaluationDomain::<Fr>::new(8).unwrap();
        assert_eq!(domain_8.size(), 8);
        assert_eq!(domain_8.log_size_of_group(), 3);

        // Non-powers of two are rounded up.
        let domain_7 = Radix2EvaluationDomain::<Fr>::new(7).unwrap();
        assert_eq!(domain_7.size(), 8);

        // Zero coefficients still yields the trivial one-element domain.
        let domain_0 = Radix2EvaluationDomain::<Fr>::new(0).unwrap();
        assert_eq!(domain_0.size(), 1);
        assert_eq!(domain_0.log_size_of_group(), 0);
    }

    #[test]
    fn new_rejects_sizes_beyond_two_adicity() {
        // The largest supported domain has 2^TWO_ADICITY elements.
        let too_large = 1u64 << (Fr::TWO_ADICITY + 1);
        if let Ok(num_coeffs) = usize::try_from(too_large) {
            assert!(Radix2EvaluationDomain::<Fr>::new(num_coeffs).is_none());
            assert_eq!(
                Radix2EvaluationDomain::<Fr>::compute_size_of_domain(num_coeffs),
                None
            );
        }
        // Rounding up to a power of two overflows `usize` here.
        assert_eq!(
            Radix2EvaluationDomain::<Fr>::compute_size_of_domain(usize::MAX),
            None
        );
    }

    #[test]
    fn coset_caches_offset_data() {
        let domain = Radix2EvaluationDomain::<Fr>::new(8).unwrap();
        let coset = domain.get_coset(Fr::GENERATOR).unwrap();
        assert_eq!(coset.coset_offset(), Fr::GENERATOR);
        assert_eq!(coset.coset_offset() * coset.coset_offset_inv(), Fr::one());
        assert_eq!(coset.coset_offset_pow_size(), Fr::GENERATOR.pow([8]));
        assert_eq!(coset.size(), domain.size());
        assert_eq!(coset.group_gen(), domain.group_gen());
    }

    #[test]
    fn coset_with_zero_offset_is_rejected() {
        let domain = Radix2EvaluationDomain::<Fr>::new(8).unwrap();
        assert!(domain.get_coset(Fr::zero()).is_none());
    }

    #[test]
    fn test_root_of_unity() {
        let domain = Radix2EvaluationDomain::<Fr>::new(8).unwrap();
        let root = domain.group_gen();

        let expected = root.pow([8]);
        assert_eq!(expected, Fr::one());
    }

    #[test]
    fn test_inverse_root_of_unity() {
        let domain = Radix2EvaluationDomain::<Fr>::new(8).unwrap();
        let root = domain.group_gen();
        let root_inv = domain.group_gen_inv();

        let expected = root * root_inv;
        assert_eq!(expected, Fr::one());
    }

    #[test]
    fn test_size_inverse() {
        let domain = Radix2EvaluationDomain::<Fr>::new(8).unwrap();
        let size_inv = domain.size_inv();
        let expected = Fr::from(8).inverse().unwrap();

        assert_eq!(size_inv, expected);
    }

    #[test]
    fn test_fft_ifft_identity() {
        let domain = Radix2EvaluationDomain::<Fr>::new(8).unwrap();
        let mut coeffs = vec![
            Fr::from(1),
            Fr::from(2),
            Fr::from(3),
            Fr::from(4),
            Fr::from(5),
            Fr::from(6),
            Fr::from(7),
            Fr::from(8),
        ];

        let original = coeffs.clone();
        domain.fft_in_place(&mut coeffs);
        domain.ifft_in_place(&mut coeffs);

        assert_eq!(coeffs, original);
    }

    #[test]
    fn test_vanishing_polynomial() {
        let domain = Radix2EvaluationDomain::<Fr>::new(4).unwrap();
        let z = domain.vanishing_polynomial();

        for elem in domain.elements() {
            assert_eq!(z.evaluate(&elem), Fr::zero());
        }
    }

    #[test]
    fn test_compute_size_of_domain() {
        assert_eq!(
            Radix2EvaluationDomain::<Fr>::compute_size_of_domain(7),
            Some(8)
        );
        assert_eq!(
            Radix2EvaluationDomain::<Fr>::compute_size_of_domain(8),
            Some(8)
        );
        assert_eq!(
            Radix2EvaluationDomain::<Fr>::compute_size_of_domain(15),
            Some(16)
        );
    }
}