1pub use crate::domain::utils::Elements;
7use crate::domain::{DomainCoeff, EvaluationDomain};
8use ark_ff::FftField;
9use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
10use ark_std::{fmt, vec::*};
11
12mod fft;
13
/// Ratio between domain size and input length at or above which the degree-aware
/// FFT is chosen: `fft_in_place` uses it when
/// `coeffs.len() * DEGREE_AWARE_FFT_THRESHOLD_FACTOR <= size`, i.e. when the
/// input fills at most a quarter of the domain.
const DEGREE_AWARE_FFT_THRESHOLD_FACTOR: usize = 4;
16
/// A radix-2 evaluation domain: a multiplicative subgroup of `F` whose size is a
/// power of two (when built via `new`), optionally shifted by a coset `offset`.
///
/// All derived quantities (inverses, logarithm, `offset^size`) are cached so that
/// FFT/IFFT routines do not have to recompute them.
///
/// NOTE(review): field order presumably determines the derived
/// `CanonicalSerialize` encoding (and `Hash`/`PartialEq` order) — do not reorder
/// fields without checking serialization compatibility.
#[derive(Copy, Clone, Hash, Eq, PartialEq, CanonicalSerialize, CanonicalDeserialize)]
pub struct Radix2EvaluationDomain<F: FftField> {
    /// Number of elements in the domain.
    pub size: u64,
    /// `log2(size)`, computed in `new` as the number of trailing zeros of `size`.
    pub log_size_of_group: u32,
    /// `size` embedded in the field, i.e. `F::from(size)`.
    pub size_as_field_element: F,
    /// Multiplicative inverse of `size_as_field_element`.
    pub size_inv: F,
    /// Generator of the subgroup; `F::get_root_of_unity(size)` when built via `new`.
    pub group_gen: F,
    /// Multiplicative inverse of `group_gen`.
    pub group_gen_inv: F,
    /// Coset offset; `F::one()` for the plain subgroup.
    pub offset: F,
    /// Multiplicative inverse of `offset`.
    pub offset_inv: F,
    /// `offset` raised to the power `size`.
    pub offset_pow_size: F,
}
42
43impl<F: FftField> fmt::Debug for Radix2EvaluationDomain<F> {
44 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45 write!(f, "Radix-2 multiplicative subgroup of size {}", self.size)
46 }
47}
48
49impl<F: FftField> EvaluationDomain<F> for Radix2EvaluationDomain<F> {
50 type Elements = Elements<F>;
51
52 fn new(num_coeffs: usize) -> Option<Self> {
55 let size = if num_coeffs.is_power_of_two() {
56 num_coeffs
57 } else {
58 num_coeffs.checked_next_power_of_two()?
59 } as u64;
60 let log_size_of_group = size.trailing_zeros();
61
62 if log_size_of_group > F::TWO_ADICITY {
64 return None;
65 }
66
67 let group_gen = F::get_root_of_unity(size)?;
70 debug_assert_eq!(group_gen.pow([size]), F::one());
72 let size_as_field_element = F::from(size);
73 let size_inv = size_as_field_element.inverse()?;
74
75 Some(Radix2EvaluationDomain {
76 size,
77 log_size_of_group,
78 size_as_field_element,
79 size_inv,
80 group_gen,
81 group_gen_inv: group_gen.inverse()?,
82 offset: F::one(),
83 offset_inv: F::one(),
84 offset_pow_size: F::one(),
85 })
86 }
87
88 fn get_coset(&self, offset: F) -> Option<Self> {
89 Some(Radix2EvaluationDomain {
90 offset,
91 offset_inv: offset.inverse()?,
92 offset_pow_size: offset.pow([self.size]),
93 ..*self
94 })
95 }
96
97 fn compute_size_of_domain(num_coeffs: usize) -> Option<usize> {
98 let size = num_coeffs.checked_next_power_of_two()?;
99 if size.trailing_zeros() > F::TWO_ADICITY {
100 None
101 } else {
102 Some(size)
103 }
104 }
105
106 #[inline]
107 fn size(&self) -> usize {
108 usize::try_from(self.size).unwrap()
109 }
110
111 #[inline]
112 fn log_size_of_group(&self) -> u64 {
113 self.log_size_of_group as u64
114 }
115
116 #[inline]
117 fn size_inv(&self) -> F {
118 self.size_inv
119 }
120
121 #[inline]
122 fn group_gen(&self) -> F {
123 self.group_gen
124 }
125
126 #[inline]
127 fn group_gen_inv(&self) -> F {
128 self.group_gen_inv
129 }
130
131 #[inline]
132 fn coset_offset(&self) -> F {
133 self.offset
134 }
135
136 #[inline]
137 fn coset_offset_inv(&self) -> F {
138 self.offset_inv
139 }
140
141 #[inline]
142 fn coset_offset_pow_size(&self) -> F {
143 self.offset_pow_size
144 }
145
146 #[inline]
147 fn fft_in_place<T: DomainCoeff<F>>(&self, coeffs: &mut Vec<T>) {
148 if coeffs.len() * DEGREE_AWARE_FFT_THRESHOLD_FACTOR <= self.size() {
149 self.degree_aware_fft_in_place(coeffs);
150 } else {
151 coeffs.resize(self.size(), T::zero());
152 self.in_order_fft_in_place(coeffs);
153 }
154 }
155
156 #[inline]
157 fn ifft_in_place<T: DomainCoeff<F>>(&self, evals: &mut Vec<T>) {
158 evals.resize(self.size(), T::zero());
159 self.in_order_ifft_in_place(&mut *evals);
160 }
161
162 fn elements(&self) -> Elements<F> {
164 Elements {
165 cur_elem: self.offset,
166 cur_pow: 0,
167 size: self.size,
168 group_gen: self.group_gen,
169 }
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::DEGREE_AWARE_FFT_THRESHOLD_FACTOR;
    use crate::{
        polynomial::{univariate::*, DenseUVPolynomial, Polynomial},
        EvaluationDomain, Radix2EvaluationDomain,
    };
    use ark_ff::{FftField, Field, One, UniformRand, Zero};
    use ark_std::{collections::BTreeSet, rand::Rng, test_rng};
    use ark_test_curves::bls12_381::Fr;

    /// Checks that `evaluate_vanishing_polynomial` agrees with evaluating the
    /// explicit `vanishing_polynomial()` at random points, for both the subgroup
    /// and its coset shifted by `Fr::GENERATOR`.
    #[test]
    fn vanishing_polynomial_evaluation() {
        let rng = &mut test_rng();
        for coeffs in 0..10 {
            let domain = Radix2EvaluationDomain::<Fr>::new(coeffs).unwrap();
            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
            let z = domain.vanishing_polynomial();
            let z_coset = coset_domain.vanishing_polynomial();
            for _ in 0..100 {
                let point: Fr = rng.gen();
                assert_eq!(
                    z.evaluate(&point),
                    domain.evaluate_vanishing_polynomial(point)
                );
                assert_eq!(
                    z_coset.evaluate(&point),
                    coset_domain.evaluate_vanishing_polynomial(point)
                );
            }
        }
    }

    /// Checks that the vanishing polynomial is zero at every element of the
    /// subgroup and of the coset, across many domain sizes.
    #[test]
    fn vanishing_polynomial_vanishes_on_domain() {
        for coeffs in 0..1000 {
            let domain = Radix2EvaluationDomain::<Fr>::new(coeffs).unwrap();
            let z = domain.vanishing_polynomial();
            for point in domain.elements() {
                assert!(z.evaluate(&point).is_zero())
            }

            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
            let z = coset_domain.vanishing_polynomial();
            for point in coset_domain.elements() {
                assert!(z.evaluate(&point).is_zero())
            }
        }
    }

    /// For every coset of a subdomain inside a larger domain, checks that the
    /// filter polynomial has the expected degree, that its fast evaluation matches
    /// the explicit polynomial, and that it is 1 on the coset and 0 elsewhere in
    /// the domain.
    #[test]
    fn filter_polynomial_test() {
        for log_domain_size in 1..=4 {
            let domain_size = 1 << log_domain_size;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            for log_subdomain_size in 1..=log_domain_size {
                let subdomain_size = 1 << log_subdomain_size;
                let subdomain = Radix2EvaluationDomain::<Fr>::new(subdomain_size).unwrap();

                // Enumerate coset representatives g^0, g^1, ... (g = domain
                // generator) until reaching the subdomain generator. This relies on
                // the subdomain generator being a power of g; the assertion below
                // confirms exactly one offset per coset was collected.
                let mut possible_offsets = vec![Fr::one()];
                let domain_generator = domain.group_gen();

                let mut offset = domain_generator;
                let subdomain_generator = subdomain.group_gen();
                while offset != subdomain_generator {
                    possible_offsets.push(offset);
                    offset *= domain_generator;
                }

                assert_eq!(possible_offsets.len(), domain_size / subdomain_size);

                let cosets = possible_offsets
                    .iter()
                    .map(|offset| subdomain.get_coset(*offset).unwrap());

                for coset in cosets {
                    let coset_elements = coset.elements().collect::<BTreeSet<_>>();
                    let filter_poly = domain.filter_polynomial(&coset);
                    assert_eq!(filter_poly.degree(), domain_size - subdomain_size);
                    for element in domain.elements() {
                        let evaluation = domain.evaluate_filter_polynomial(&coset, element);
                        assert_eq!(evaluation, filter_poly.evaluate(&element));
                        // Indicator behaviour: 1 on the coset, 0 on the rest of the domain.
                        if coset_elements.contains(&element) {
                            assert_eq!(evaluation, Fr::one())
                        } else {
                            assert_eq!(evaluation, Fr::zero())
                        }
                    }
                }
            }
        }
    }

    /// Checks that `elements()` yields exactly `size()` items.
    #[test]
    fn size_of_elements() {
        for coeffs in 1..10 {
            let size = 1 << coeffs;
            let domain = Radix2EvaluationDomain::<Fr>::new(size).unwrap();
            let domain_size = domain.size();
            assert_eq!(domain_size, domain.elements().count());
        }
    }

    /// Checks that the i-th element is `g^i` on the subgroup and `offset * g^i` on
    /// the coset, and that `element(i)` agrees with the iterator.
    #[test]
    fn elements_contents() {
        for coeffs in 1..10 {
            let size = 1 << coeffs;
            let domain = Radix2EvaluationDomain::<Fr>::new(size).unwrap();
            let offset = Fr::GENERATOR;
            let coset_domain = domain.get_coset(offset).unwrap();
            for (i, (x, coset_x)) in domain.elements().zip(coset_domain.elements()).enumerate() {
                assert_eq!(x, domain.group_gen.pow([i as u64]));
                assert_eq!(x, domain.element(i));
                assert_eq!(coset_x, offset * coset_domain.group_gen.pow([i as u64]));
                assert_eq!(coset_x, coset_domain.element(i));
            }
        }
    }

    /// Checks Lagrange interpolation at a random point (generally outside the
    /// domain): the sum of Lagrange coefficients times FFT evaluations must equal
    /// the polynomial's direct evaluation, on the subgroup and on the coset.
    #[test]
    fn non_systematic_lagrange_coefficients_test() {
        for domain_dim in 1..10 {
            let domain_size = 1 << domain_dim;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
            let rand_pt = Fr::rand(&mut test_rng());
            let lagrange_coeffs = domain.evaluate_all_lagrange_coefficients(rand_pt);
            let coset_lagrange_coeffs = coset_domain.evaluate_all_lagrange_coefficients(rand_pt);

            let rand_poly = DensePolynomial::<Fr>::rand(domain_size - 1, &mut test_rng());
            let poly_evals = domain.fft(rand_poly.coeffs());
            let coset_poly_evals = coset_domain.fft(rand_poly.coeffs());
            let actual_eval = rand_poly.evaluate(&rand_pt);

            let mut interpolated_eval = Fr::zero();
            let mut coset_interpolated_eval = Fr::zero();
            for i in 0..domain_size {
                interpolated_eval += lagrange_coeffs[i] * poly_evals[i];
                coset_interpolated_eval += coset_lagrange_coeffs[i] * coset_poly_evals[i];
            }
            assert_eq!(actual_eval, interpolated_eval);
            assert_eq!(actual_eval, coset_interpolated_eval);
        }
    }

    /// Checks that evaluating all Lagrange coefficients at the i-th domain element
    /// gives the i-th unit vector, on the subgroup and on the coset.
    #[test]
    fn systematic_lagrange_coefficients_test() {
        for domain_dim in 1..5 {
            let domain_size = 1 << domain_dim;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
            for (i, (x, coset_x)) in domain.elements().zip(coset_domain.elements()).enumerate() {
                let lagrange_coeffs = domain.evaluate_all_lagrange_coefficients(x);
                let coset_lagrange_coeffs =
                    coset_domain.evaluate_all_lagrange_coefficients(coset_x);
                for (j, (y, coset_y)) in lagrange_coeffs
                    .into_iter()
                    .zip(coset_lagrange_coeffs)
                    .enumerate()
                {
                    // L_j(x_i) is 1 exactly when i == j, else 0.
                    if i == j {
                        assert_eq!(y, Fr::one());
                        assert_eq!(coset_y, Fr::one());
                    } else {
                        assert_eq!(y, Fr::zero());
                        assert_eq!(coset_y, Fr::zero());
                    }
                }
            }
        }
    }

    /// Checks FFT outputs against direct evaluation at each domain element and
    /// that IFFT recovers the original polynomial, for domain sizes 2^5 and 2^6
    /// (subgroup and coset).
    #[test]
    fn test_fft_correctness() {
        let log_degree = 5;
        let degree = 1 << log_degree;
        let rand_poly = DensePolynomial::<Fr>::rand(degree - 1, &mut test_rng());

        for log_domain_size in log_degree..(log_degree + 2) {
            let domain_size = 1 << log_domain_size;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let coset_domain =
                Radix2EvaluationDomain::<Fr>::new_coset(domain_size, Fr::GENERATOR).unwrap();
            let poly_evals = domain.fft(&rand_poly.coeffs);
            let poly_coset_evals = coset_domain.fft(&rand_poly.coeffs);

            for (i, (x, coset_x)) in domain.elements().zip(coset_domain.elements()).enumerate() {
                assert_eq!(poly_evals[i], rand_poly.evaluate(&x));
                assert_eq!(poly_coset_evals[i], rand_poly.evaluate(&coset_x));
            }

            let rand_poly_from_subgroup =
                DensePolynomial::from_coefficients_vec(domain.ifft(&poly_evals));
            let rand_poly_from_coset =
                DensePolynomial::from_coefficients_vec(coset_domain.ifft(&poly_coset_evals));

            assert_eq!(
                rand_poly, rand_poly_from_subgroup,
                "degree = {}, domain size = {}",
                degree, domain_size
            );
            assert_eq!(
                rand_poly, rand_poly_from_coset,
                "degree = {}, domain size = {}",
                degree, domain_size
            );
        }
    }

    /// Checks FFT outputs against direct evaluation when the domain is
    /// `DEGREE_AWARE_FFT_THRESHOLD_FACTOR` times larger than the coefficient count,
    /// which is intended to exercise the degree-aware FFT path.
    /// NOTE(review): assumes `DensePolynomial::rand(d)` yields `d + 1` coefficients,
    /// so that the threshold is hit exactly — confirm.
    #[test]
    fn degree_aware_fft_correctness() {
        let num_coeffs = 1 << 5;
        let rand_poly = DensePolynomial::<Fr>::rand(num_coeffs - 1, &mut test_rng());
        let domain_size = num_coeffs * DEGREE_AWARE_FFT_THRESHOLD_FACTOR;
        let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
        let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();

        let deg_aware_fft_evals = domain.fft(&rand_poly);
        let coset_deg_aware_fft_evals = coset_domain.fft(&rand_poly);

        for (i, (x, coset_x)) in domain.elements().zip(coset_domain.elements()).enumerate() {
            assert_eq!(deg_aware_fft_evals[i], rand_poly.evaluate(&x));
            assert_eq!(coset_deg_aware_fft_evals[i], rand_poly.evaluate(&coset_x));
        }
    }

    /// Checks that `roots_of_unity(group_gen)` returns domain elements that match
    /// the prefix of `elements()`, and that it returns `domain_size / 2` of them.
    #[test]
    fn test_roots_of_unity() {
        let max_degree = 10;
        for log_domain_size in 0..max_degree {
            let domain_size = 1 << log_domain_size;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let actual_roots = domain.roots_of_unity(domain.group_gen);
            for &value in &actual_roots {
                assert!(domain.evaluate_vanishing_polynomial(value).is_zero());
            }
            let expected_roots_elements = domain.elements();
            // `zip` stops at the shorter side, so only the returned prefix is compared.
            for (expected, &actual) in expected_roots_elements.zip(&actual_roots) {
                assert_eq!(expected, actual);
            }
            assert_eq!(actual_roots.len(), domain_size / 2);
        }
    }

    /// Compares the library FFT/IFFT (subgroup and coset) against a simple serial
    /// reference implementation defined below. Only built with the `parallel`
    /// feature.
    #[test]
    #[cfg(feature = "parallel")]
    fn parallel_fft_consistency() {
        use ark_std::{test_rng, vec::*};
        use ark_test_curves::bls12_381::Fr;

        // Reference iterative Cooley-Tukey FFT: bit-reversal permutation followed
        // by `log_n` rounds of butterflies. `a.len()` must equal `1 << log_n`.
        fn serial_radix2_fft(a: &mut [Fr], omega: Fr, log_n: u32) {
            use ark_std::convert::TryFrom;
            let n = u32::try_from(a.len())
                .expect("cannot perform FFTs larger on vectors of len > (1 << 32)");
            assert_eq!(n, 1 << log_n);

            // Reorder inputs into bit-reversed index order; swapping only when
            // k < rk ensures each pair is swapped once.
            for k in 0..n {
                let rk = crate::domain::utils::bitreverse(k, log_n);
                if k < rk {
                    a.swap(rk as usize, k as usize);
                }
            }

            // `m` is the half-width of the butterflies in the current round.
            let mut m = 1;
            for _i in 1..=log_n {
                // Twiddle step for this round: omega^(n / (2m)).
                let w_m = omega.pow([(n / (2 * m)) as u64]);

                let mut k = 0;
                while k < n {
                    let mut w = Fr::one();
                    for j in 0..m {
                        // Butterfly: (u, v) -> (u + w*v, u - w*v).
                        let mut t = a[(k + j + m) as usize];
                        t *= w;
                        let mut tmp = a[(k + j) as usize];
                        tmp -= t;
                        a[(k + j + m) as usize] = tmp;
                        a[(k + j) as usize] += t;
                        w *= &w_m;
                    }

                    k += 2 * m;
                }

                m *= 2;
            }
        }

        // Reference inverse FFT: forward FFT with omega^{-1}, then scale by 1/n.
        fn serial_radix2_ifft(a: &mut [Fr], omega: Fr, log_n: u32) {
            serial_radix2_fft(a, omega.inverse().unwrap(), log_n);
            let domain_size_inv = Fr::from(a.len() as u64).inverse().unwrap();
            for coeff in a.iter_mut() {
                *coeff *= Fr::from(domain_size_inv);
            }
        }

        // Reference coset FFT: multiply coefficient i by GENERATOR^i, then FFT.
        fn serial_radix2_coset_fft(a: &mut [Fr], omega: Fr, log_n: u32) {
            let coset_shift = Fr::GENERATOR;
            let mut cur_pow = Fr::one();
            for coeff in a.iter_mut() {
                *coeff *= cur_pow;
                cur_pow *= coset_shift;
            }
            serial_radix2_fft(a, omega, log_n);
        }

        // Reference coset IFFT: IFFT, then multiply coefficient i by GENERATOR^{-i}.
        fn serial_radix2_coset_ifft(a: &mut [Fr], omega: Fr, log_n: u32) {
            serial_radix2_ifft(a, omega, log_n);
            let coset_shift = Fr::GENERATOR.inverse().unwrap();
            let mut cur_pow = Fr::one();
            for coeff in a.iter_mut() {
                *coeff *= cur_pow;
                cur_pow *= coset_shift;
            }
        }

        // Runs the reference and library transforms side by side. The vectors are
        // transformed in place and reused across steps, so each step's input is the
        // previous step's output; statement order is significant.
        fn test_consistency<R: Rng>(rng: &mut R, max_coeffs: u32) {
            for _ in 0..5 {
                for log_d in 0..max_coeffs {
                    let d = 1 << log_d;

                    let expected_poly = (0..d).map(|_| Fr::rand(rng)).collect::<Vec<_>>();
                    let mut expected_vec = expected_poly.clone();
                    let mut actual_vec = expected_vec.clone();

                    let domain = Radix2EvaluationDomain::new(d).unwrap();
                    let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();

                    serial_radix2_fft(&mut expected_vec, domain.group_gen, log_d);
                    domain.fft_in_place(&mut actual_vec);
                    assert_eq!(expected_vec, actual_vec);

                    // FFT followed by IFFT must round-trip to the original input.
                    serial_radix2_ifft(&mut expected_vec, domain.group_gen, log_d);
                    domain.ifft_in_place(&mut actual_vec);
                    assert_eq!(expected_vec, actual_vec);
                    assert_eq!(expected_vec, expected_poly);

                    serial_radix2_coset_fft(&mut expected_vec, domain.group_gen, log_d);
                    coset_domain.fft_in_place(&mut actual_vec);
                    assert_eq!(expected_vec, actual_vec);

                    serial_radix2_coset_ifft(&mut expected_vec, domain.group_gen, log_d);
                    coset_domain.ifft_in_place(&mut actual_vec);
                    assert_eq!(expected_vec, actual_vec);
                }
            }
        }

        let rng = &mut test_rng();

        test_consistency(rng, 10);
    }
}