1pub use crate::domain::utils::Elements;
10use crate::domain::{
11    DomainCoeff, EvaluationDomain, MixedRadixEvaluationDomain, Radix2EvaluationDomain,
12};
13use ark_ff::FftField;
14use ark_serialize::{
15    CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate,
16};
17use ark_std::{
18    io::{Read, Write},
19    vec::*,
20};
21
22#[derive(Copy, Clone, Hash, Eq, PartialEq, Debug)]
49pub enum GeneralEvaluationDomain<F: FftField> {
50    Radix2(Radix2EvaluationDomain<F>),
52    MixedRadix(MixedRadixEvaluationDomain<F>),
54}
55
/// Dispatches an `EvaluationDomain` method call to the concrete domain wrapped by `$self`.
///
/// `map!(self, method, a, b, ...)` expands to a `match` on `$self` that calls
/// `EvaluationDomain::method(domain, a, b, ...)` on the inner `Radix2` or
/// `MixedRadix` domain. Any number of extra arguments (including none) is
/// forwarded. Each argument expression is placed separately in both match
/// arms, so exactly one of them is evaluated.
macro_rules! map {
    ($self:expr, $f1:ident $(, $x:expr)*) => {
        match $self {
            // The comma is emitted inside the repetition so that two or more
            // forwarded arguments stay separated (`f(domain, a, b)`), and zero
            // arguments yield `f(domain)`.
            Self::Radix2(domain) => EvaluationDomain::$f1(domain $(, $x)*),
            Self::MixedRadix(domain) => EvaluationDomain::$f1(domain $(, $x)*),
        }
    };
}
64
65impl<F: FftField> CanonicalSerialize for GeneralEvaluationDomain<F> {
66    fn serialize_with_mode<W: Write>(
67        &self,
68        mut writer: W,
69        compress: Compress,
70    ) -> Result<(), SerializationError> {
71        let variant = match self {
72            GeneralEvaluationDomain::Radix2(_) => 0u8,
73            GeneralEvaluationDomain::MixedRadix(_) => 1u8,
74        };
75        variant.serialize_with_mode(&mut writer, compress)?;
76
77        match self {
78            GeneralEvaluationDomain::Radix2(domain) => {
79                domain.serialize_with_mode(&mut writer, compress)
80            },
81            GeneralEvaluationDomain::MixedRadix(domain) => {
82                domain.serialize_with_mode(&mut writer, compress)
83            },
84        }
85    }
86
87    fn serialized_size(&self, compress: Compress) -> usize {
88        let type_id = match self {
89            GeneralEvaluationDomain::Radix2(_) => 0u8,
90            GeneralEvaluationDomain::MixedRadix(_) => 1u8,
91        };
92
93        type_id.serialized_size(compress)
94            + match self {
95                GeneralEvaluationDomain::Radix2(domain) => domain.serialized_size(compress),
96                GeneralEvaluationDomain::MixedRadix(domain) => domain.serialized_size(compress),
97            }
98    }
99}
100
101impl<F: FftField> Valid for GeneralEvaluationDomain<F> {
102    fn check(&self) -> Result<(), SerializationError> {
103        Ok(())
104    }
105}
106
107impl<F: FftField> CanonicalDeserialize for GeneralEvaluationDomain<F> {
108    fn deserialize_with_mode<R: Read>(
109        mut reader: R,
110        compress: Compress,
111        validate: Validate,
112    ) -> Result<Self, SerializationError> {
113        match u8::deserialize_with_mode(&mut reader, compress, validate)? {
114            0 => Radix2EvaluationDomain::deserialize_with_mode(&mut reader, compress, validate)
115                .map(Self::Radix2),
116            1 => MixedRadixEvaluationDomain::deserialize_with_mode(&mut reader, compress, validate)
117                .map(Self::MixedRadix),
118            _ => Err(SerializationError::InvalidData),
119        }
120    }
121}
122
123impl<F: FftField> EvaluationDomain<F> for GeneralEvaluationDomain<F> {
124    type Elements = GeneralElements<F>;
125
126    fn new(num_coeffs: usize) -> Option<Self> {
133        let domain = Radix2EvaluationDomain::new(num_coeffs);
134        if let Some(domain) = domain {
135            return Some(GeneralEvaluationDomain::Radix2(domain));
136        }
137
138        if F::SMALL_SUBGROUP_BASE.is_some() {
139            return Some(GeneralEvaluationDomain::MixedRadix(
140                MixedRadixEvaluationDomain::new(num_coeffs)?,
141            ));
142        }
143
144        None
145    }
146
147    fn get_coset(&self, offset: F) -> Option<Self> {
148        Some(match self {
149            Self::Radix2(domain) => Self::Radix2(domain.get_coset(offset)?),
150            Self::MixedRadix(domain) => Self::MixedRadix(domain.get_coset(offset)?),
151        })
152    }
153
154    fn compute_size_of_domain(num_coeffs: usize) -> Option<usize> {
155        let domain_size = Radix2EvaluationDomain::<F>::compute_size_of_domain(num_coeffs);
156        if let Some(domain_size) = domain_size {
157            return Some(domain_size);
158        }
159
160        if F::SMALL_SUBGROUP_BASE.is_some() {
161            return MixedRadixEvaluationDomain::<F>::compute_size_of_domain(num_coeffs);
162        }
163
164        None
165    }
166
167    #[inline]
168    fn size(&self) -> usize {
169        map!(self, size)
170    }
171
172    #[inline]
173    fn log_size_of_group(&self) -> u64 {
174        map!(self, log_size_of_group) as u64
175    }
176
177    #[inline]
178    fn size_inv(&self) -> F {
179        map!(self, size_inv)
180    }
181
182    #[inline]
183    fn group_gen(&self) -> F {
184        map!(self, group_gen)
185    }
186
187    #[inline]
188    fn group_gen_inv(&self) -> F {
189        map!(self, group_gen_inv)
190    }
191
192    #[inline]
193    fn coset_offset(&self) -> F {
194        map!(self, coset_offset)
195    }
196
197    #[inline]
198    fn coset_offset_inv(&self) -> F {
199        map!(self, coset_offset_inv)
200    }
201
202    fn coset_offset_pow_size(&self) -> F {
203        map!(self, coset_offset_pow_size)
204    }
205
206    #[inline]
207    fn fft_in_place<T: DomainCoeff<F>>(&self, coeffs: &mut Vec<T>) {
208        map!(self, fft_in_place, coeffs)
209    }
210
211    #[inline]
212    fn ifft_in_place<T: DomainCoeff<F>>(&self, evals: &mut Vec<T>) {
213        map!(self, ifft_in_place, evals)
214    }
215
216    #[inline]
217    fn evaluate_all_lagrange_coefficients(&self, tau: F) -> Vec<F> {
218        map!(self, evaluate_all_lagrange_coefficients, tau)
219    }
220
221    #[inline]
222    fn vanishing_polynomial(&self) -> crate::univariate::SparsePolynomial<F> {
223        map!(self, vanishing_polynomial)
224    }
225
226    #[inline]
227    fn evaluate_vanishing_polynomial(&self, tau: F) -> F {
228        map!(self, evaluate_vanishing_polynomial, tau)
229    }
230
231    fn elements(&self) -> GeneralElements<F> {
233        GeneralElements(map!(self, elements))
234    }
235}
236
237pub struct GeneralElements<F: FftField>(Elements<F>);
239
240impl<F: FftField> Iterator for GeneralElements<F> {
241    type Item = F;
242
243    #[inline]
244    fn next(&mut self) -> Option<F> {
245        self.0.next()
246    }
247}
248
#[cfg(test)]
mod tests {
    use crate::{polynomial::Polynomial, EvaluationDomain, GeneralEvaluationDomain};
    use ark_ff::Zero;
    use ark_std::{rand::Rng, test_rng};
    use ark_test_curves::{bls12_381::Fr, bn384_small_two_adicity::Fr as BNFr};

    /// The sparse vanishing polynomial and the domain's direct evaluation must
    /// agree at random points. This covers BLS12-381 `Fr` and a field with small
    /// two-adicity.
    #[test]
    fn vanishing_polynomial_evaluation() {
        let mut rng = test_rng();

        for num_coeffs in 0..10 {
            let domain = GeneralEvaluationDomain::<Fr>::new(num_coeffs).unwrap();
            let vanishing = domain.vanishing_polynomial();
            for _sample in 0..100 {
                let tau: Fr = rng.gen();
                let expected = domain.evaluate_vanishing_polynomial(tau);
                assert_eq!(vanishing.evaluate(&tau), expected);
            }
        }

        for num_coeffs in 15..17 {
            let domain = GeneralEvaluationDomain::<BNFr>::new(num_coeffs).unwrap();
            let vanishing = domain.vanishing_polynomial();
            for _sample in 0..100 {
                let tau: BNFr = rng.gen();
                let expected = domain.evaluate_vanishing_polynomial(tau);
                assert_eq!(vanishing.evaluate(&tau), expected);
            }
        }
    }

    /// Every element of the domain must be a root of its vanishing polynomial.
    #[test]
    fn vanishing_polynomial_vanishes_on_domain() {
        for num_coeffs in 0..1000 {
            let domain = GeneralEvaluationDomain::<Fr>::new(num_coeffs).unwrap();
            let vanishing = domain.vanishing_polynomial();
            let all_roots = domain
                .elements()
                .all(|elem| vanishing.evaluate(&elem).is_zero());
            assert!(all_roots);
        }
    }

    /// For power-of-two sizes, `elements()` must yield exactly `size()` items.
    #[test]
    fn size_of_elements() {
        for log_size in 1..10 {
            let domain = GeneralEvaluationDomain::<Fr>::new(1usize << log_size).unwrap();
            assert_eq!(domain.size(), domain.elements().count());
        }
    }
}