1use crate::{
4 evaluations::multivariate::multilinear::{swap_bits, MultilinearExtension},
5 Polynomial,
6};
7use ark_ff::{Field, Zero};
8use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
9use ark_std::{
10 fmt,
11 fmt::Formatter,
12 iter::IntoIterator,
13 log2,
14 ops::{Add, AddAssign, Index, Mul, MulAssign, Neg, Sub, SubAssign},
15 rand::Rng,
16 slice::{Iter, IterMut},
17 vec::*,
18};
19#[cfg(feature = "parallel")]
20use rayon::prelude::*;
21
22#[derive(Clone, PartialEq, Eq, Hash, Default, CanonicalSerialize, CanonicalDeserialize)]
24pub struct DenseMultilinearExtension<F: Field> {
25 pub evaluations: Vec<F>,
27 pub num_vars: usize,
29}
30
31impl<F: Field> DenseMultilinearExtension<F> {
32 pub fn from_evaluations_slice(num_vars: usize, evaluations: &[F]) -> Self {
36 Self::from_evaluations_vec(num_vars, evaluations.to_vec())
37 }
38
39 pub fn from_evaluations_vec(num_vars: usize, evaluations: Vec<F>) -> Self {
58 assert_eq!(
60 evaluations.len(),
61 1 << num_vars,
62 "The size of evaluations should be 2^num_vars."
63 );
64
65 Self {
66 num_vars,
67 evaluations,
68 }
69 }
70 pub fn relabel_in_place(&mut self, mut a: usize, mut b: usize, k: usize) {
76 if a > b {
78 ark_std::mem::swap(&mut a, &mut b);
79 }
80 if a == b || k == 0 {
81 return;
82 }
83 assert!(b + k <= self.num_vars, "invalid relabel argument");
84 assert!(a + k <= b, "overlapped swap window is not allowed");
85 for i in 0..self.evaluations.len() {
86 let j = swap_bits(i, a, b, k);
87 if i < j {
88 self.evaluations.swap(i, j);
89 }
90 }
91 }
92
93 pub fn iter(&self) -> Iter<'_, F> {
95 self.evaluations.iter()
96 }
97
98 pub fn iter_mut(&mut self) -> IterMut<'_, F> {
100 self.evaluations.iter_mut()
101 }
102
103 pub fn concat(polys: impl IntoIterator<Item = impl AsRef<Self>> + Clone) -> Self {
133 let polys_iter_cloned = polys.clone().into_iter();
136
137 let total_len: usize = polys
138 .into_iter()
139 .map(|poly| poly.as_ref().evaluations.len())
140 .sum();
141
142 let next_pow_of_two = total_len.next_power_of_two();
143 let num_vars = log2(next_pow_of_two);
144 let mut evaluations: Vec<F> = Vec::with_capacity(next_pow_of_two);
145
146 for poly in polys_iter_cloned {
147 evaluations.extend_from_slice(&poly.as_ref().evaluations.as_slice());
148 }
149
150 evaluations.resize(next_pow_of_two, F::zero());
151
152 Self::from_evaluations_slice(num_vars as usize, &evaluations)
153 }
154}
155
156impl<F: Field> AsRef<DenseMultilinearExtension<F>> for DenseMultilinearExtension<F> {
157 fn as_ref(&self) -> &DenseMultilinearExtension<F> {
158 self
159 }
160}
161
162impl<F: Field> MultilinearExtension<F> for DenseMultilinearExtension<F> {
163 fn num_vars(&self) -> usize {
164 self.num_vars
165 }
166
167 fn rand<R: Rng>(num_vars: usize, rng: &mut R) -> Self {
168 Self::from_evaluations_vec(
169 num_vars,
170 (0..(1 << num_vars)).map(|_| F::rand(rng)).collect(),
171 )
172 }
173
174 fn relabel(&self, a: usize, b: usize, k: usize) -> Self {
175 let mut copied = self.clone();
176 copied.relabel_in_place(a, b, k);
177 copied
178 }
179
180 fn fix_variables(&self, partial_point: &[F]) -> Self {
204 assert!(
205 partial_point.len() <= self.num_vars,
206 "invalid size of partial point"
207 );
208 let mut poly = self.evaluations.to_vec();
209 let nv = self.num_vars;
210 let dim = partial_point.len();
211 for i in 1..dim + 1 {
213 let r = partial_point[i - 1];
214 for b in 0..(1 << (nv - i)) {
215 let left = poly[b << 1];
216 let right = poly[(b << 1) + 1];
217 poly[b] = left + r * (right - left);
218 }
219 }
220 Self::from_evaluations_slice(nv - dim, &poly[..(1 << (nv - dim))])
221 }
222
223 fn to_evaluations(&self) -> Vec<F> {
224 self.evaluations.to_vec()
225 }
226}
227
228impl<F: Field> Index<usize> for DenseMultilinearExtension<F> {
229 type Output = F;
230
231 fn index(&self, index: usize) -> &Self::Output {
238 &self.evaluations[index]
239 }
240}
241
242impl<F: Field> Add for DenseMultilinearExtension<F> {
243 type Output = DenseMultilinearExtension<F>;
244
245 fn add(self, other: DenseMultilinearExtension<F>) -> Self {
246 &self + &other
247 }
248}
249
250impl<'a, 'b, F: Field> Add<&'a DenseMultilinearExtension<F>> for &'b DenseMultilinearExtension<F> {
251 type Output = DenseMultilinearExtension<F>;
252
253 fn add(self, rhs: &'a DenseMultilinearExtension<F>) -> Self::Output {
254 if rhs.is_zero() {
256 return self.clone();
257 }
258 if self.is_zero() {
259 return rhs.clone();
260 }
261 assert_eq!(self.num_vars, rhs.num_vars);
262 let result: Vec<F> = cfg_iter!(self.evaluations)
263 .zip(cfg_iter!(rhs.evaluations))
264 .map(|(a, b)| *a + *b)
265 .collect();
266
267 Self::Output::from_evaluations_vec(self.num_vars, result)
268 }
269}
270
271impl<F: Field> AddAssign for DenseMultilinearExtension<F> {
272 fn add_assign(&mut self, other: Self) {
273 *self = &*self + &other;
274 }
275}
276
277impl<'a, F: Field> AddAssign<&'a DenseMultilinearExtension<F>> for DenseMultilinearExtension<F> {
278 fn add_assign(&mut self, other: &'a DenseMultilinearExtension<F>) {
279 *self = &*self + other;
280 }
281}
282
283impl<'a, F: Field> AddAssign<(F, &'a DenseMultilinearExtension<F>)>
284 for DenseMultilinearExtension<F>
285{
286 fn add_assign(&mut self, (f, other): (F, &'a DenseMultilinearExtension<F>)) {
287 let other = Self {
288 num_vars: other.num_vars,
289 evaluations: cfg_iter!(other.evaluations).map(|x| f * x).collect(),
290 };
291 *self = &*self + &other;
292 }
293}
294
295impl<F: Field> Neg for DenseMultilinearExtension<F> {
296 type Output = DenseMultilinearExtension<F>;
297
298 fn neg(self) -> Self::Output {
299 Self::Output {
300 num_vars: self.num_vars,
301 evaluations: cfg_iter!(self.evaluations).map(|x| -*x).collect(),
302 }
303 }
304}
305
306impl<F: Field> Sub for DenseMultilinearExtension<F> {
307 type Output = DenseMultilinearExtension<F>;
308
309 fn sub(self, other: DenseMultilinearExtension<F>) -> Self {
310 &self - &other
311 }
312}
313
314impl<'a, 'b, F: Field> Sub<&'a DenseMultilinearExtension<F>> for &'b DenseMultilinearExtension<F> {
315 type Output = DenseMultilinearExtension<F>;
316
317 fn sub(self, rhs: &'a DenseMultilinearExtension<F>) -> Self::Output {
318 self + &rhs.clone().neg()
319 }
320}
321
322impl<F: Field> SubAssign for DenseMultilinearExtension<F> {
323 fn sub_assign(&mut self, other: Self) {
324 *self = &*self - &other;
325 }
326}
327
328impl<'a, F: Field> SubAssign<&'a DenseMultilinearExtension<F>> for DenseMultilinearExtension<F> {
329 fn sub_assign(&mut self, other: &'a DenseMultilinearExtension<F>) {
330 *self = &*self - other;
331 }
332}
333
334impl<F: Field> Mul<F> for DenseMultilinearExtension<F> {
335 type Output = DenseMultilinearExtension<F>;
336
337 fn mul(self, scalar: F) -> Self::Output {
338 &self * &scalar
339 }
340}
341
342impl<'a, 'b, F: Field> Mul<&'a F> for &'b DenseMultilinearExtension<F> {
343 type Output = DenseMultilinearExtension<F>;
344
345 fn mul(self, scalar: &'a F) -> Self::Output {
346 if scalar.is_zero() {
347 return DenseMultilinearExtension::zero();
348 } else if scalar.is_one() {
349 return self.clone();
350 }
351 let result: Vec<F> = self.evaluations.iter().map(|&x| x * scalar).collect();
352
353 DenseMultilinearExtension {
354 num_vars: self.num_vars,
355 evaluations: result,
356 }
357 }
358}
359
360impl<F: Field> MulAssign<F> for DenseMultilinearExtension<F> {
361 fn mul_assign(&mut self, scalar: F) {
362 *self = &*self * &scalar
363 }
364}
365
366impl<'a, F: Field> MulAssign<&'a F> for DenseMultilinearExtension<F> {
367 fn mul_assign(&mut self, scalar: &'a F) {
368 *self = &*self * scalar
369 }
370}
371
372impl<F: Field> fmt::Debug for DenseMultilinearExtension<F> {
373 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
374 write!(f, "DenseML(nv = {}, evaluations = [", self.num_vars)?;
375 for i in 0..ark_std::cmp::min(4, self.evaluations.len()) {
376 write!(f, "{:?} ", self.evaluations[i])?;
377 }
378 if self.evaluations.len() < 4 {
379 write!(f, "])")?;
380 } else {
381 write!(f, "...])")?;
382 }
383 Ok(())
384 }
385}
386
387impl<F: Field> Zero for DenseMultilinearExtension<F> {
388 fn zero() -> Self {
389 Self {
390 num_vars: 0,
391 evaluations: vec![F::zero()],
392 }
393 }
394
395 fn is_zero(&self) -> bool {
396 self.num_vars == 0 && self.evaluations[0].is_zero()
397 }
398}
399
400impl<F: Field> Polynomial<F> for DenseMultilinearExtension<F> {
401 type Point = Vec<F>;
402
403 fn degree(&self) -> usize {
404 self.num_vars
405 }
406
407 fn evaluate(&self, point: &Self::Point) -> F {
427 assert!(point.len() == self.num_vars);
428 self.fix_variables(&point)[0]
429 }
430}
431
#[cfg(test)]
mod tests {
    use crate::{DenseMultilinearExtension, MultilinearExtension, Polynomial};
    use ark_ff::{Field, One, Zero};
    use ark_std::{ops::Neg, test_rng, vec::*, UniformRand};
    use ark_test_curves::bls12_381::Fr;

    /// Reference evaluation used to cross-check `evaluate`: folds the table
    /// one variable at a time using the `(1 - r) * left + r * right` form.
    fn evaluate_data_array<F: Field>(data: &[F], point: &[F]) -> F {
        if data.len() != (1 << point.len()) {
            panic!("Data size mismatch with number of variables. ")
        }

        let nv = point.len();
        let mut a = data.to_vec();

        for i in 1..nv + 1 {
            let r = point[i - 1];
            for b in 0..(1 << (nv - i)) {
                a[b] = a[b << 1] * (F::one() - r) + a[(b << 1) + 1] * r;
            }
        }
        a[0]
    }

    #[test]
    fn evaluate_at_a_point() {
        let mut rng = test_rng();
        let poly = DenseMultilinearExtension::rand(10, &mut rng);
        for _ in 0..10 {
            let point: Vec<_> = (0..10).map(|_| Fr::rand(&mut rng)).collect();
            assert_eq!(
                evaluate_data_array(&poly.evaluations, &point),
                poly.evaluate(&point)
            )
        }
    }

    #[test]
    fn relabel_polynomial() {
        let mut rng = test_rng();
        for _ in 0..20 {
            let mut poly = DenseMultilinearExtension::rand(10, &mut rng);
            let mut point: Vec<_> = (0..10).map(|_| Fr::rand(&mut rng)).collect();

            let expected = poly.evaluate(&point);

            // Relabeling a window onto itself must have no effect.
            poly.relabel_in_place(2, 2, 1);
            assert_eq!(expected, poly.evaluate(&point));

            // Swap two adjacent single variables.
            poly.relabel_in_place(3, 4, 1);
            point.swap(3, 4);
            assert_eq!(expected, poly.evaluate(&point));

            // Arguments given in descending order.
            poly.relabel_in_place(7, 5, 1);
            point.swap(7, 5);
            assert_eq!(expected, poly.evaluate(&point));

            // Windows of three variables each.
            poly.relabel_in_place(2, 5, 3);
            point.swap(2, 5);
            point.swap(3, 6);
            point.swap(4, 7);
            assert_eq!(expected, poly.evaluate(&point));

            // Windows of two variables, arguments in descending order.
            poly.relabel_in_place(7, 0, 2);
            point.swap(0, 7);
            point.swap(1, 8);
            assert_eq!(expected, poly.evaluate(&point));

            // First and last variable.
            poly.relabel_in_place(0, 9, 1);
            point.swap(0, 9);
            assert_eq!(expected, poly.evaluate(&point));
        }
    }

    #[test]
    fn arithmetic() {
        const NV: usize = 10;
        let mut rng = test_rng();
        for _ in 0..20 {
            let scalar = Fr::rand(&mut rng);
            let point: Vec<_> = (0..NV).map(|_| Fr::rand(&mut rng)).collect();
            let poly1 = DenseMultilinearExtension::rand(NV, &mut rng);
            let poly2 = DenseMultilinearExtension::rand(NV, &mut rng);
            let v1 = poly1.evaluate(&point);
            let v2 = poly2.evaluate(&point);
            // addition
            assert_eq!((&poly1 + &poly2).evaluate(&point), v1 + v2);
            // subtraction
            assert_eq!((&poly1 - &poly2).evaluate(&point), v1 - v2);
            // negation
            assert_eq!(poly1.clone().neg().evaluate(&point), -v1);
            // scalar multiplication
            assert_eq!((&poly1 * &scalar).evaluate(&point), v1 * scalar);
            // add-assign
            {
                let mut poly1 = poly1.clone();
                poly1 += &poly2;
                assert_eq!(poly1.evaluate(&point), v1 + v2)
            }
            // sub-assign
            {
                let mut poly1 = poly1.clone();
                poly1 -= &poly2;
                assert_eq!(poly1.evaluate(&point), v1 - v2)
            }
            // add-assign with a scalar factor
            {
                let mut poly1 = poly1.clone();
                let scalar = Fr::rand(&mut rng);
                poly1 += (scalar, &poly2);
                assert_eq!(poly1.evaluate(&point), v1 + scalar * v2)
            }
            // zero is the additive identity on both sides
            {
                assert_eq!(&poly1 + &DenseMultilinearExtension::zero(), poly1);
                assert_eq!(&DenseMultilinearExtension::zero() + &poly1, poly1);
                {
                    let mut poly1_cloned = poly1.clone();
                    poly1_cloned += &DenseMultilinearExtension::zero();
                    assert_eq!(&poly1_cloned, &poly1);
                    let mut zero = DenseMultilinearExtension::zero();
                    let scalar = Fr::rand(&mut rng);
                    zero += (scalar, &poly1);
                    assert_eq!(zero.evaluate(&point), scalar * v1);
                }
            }
            // mul-assign by one, by a random scalar, and by zero
            {
                let mut poly1_cloned = poly1.clone();
                poly1_cloned *= Fr::one();
                assert_eq!(poly1_cloned.evaluate(&point), v1);
                poly1_cloned *= scalar;
                assert_eq!(poly1_cloned.evaluate(&point), v1 * scalar);
                poly1_cloned *= Fr::zero();
                assert_eq!(poly1_cloned, DenseMultilinearExtension::zero());
            }
        }
    }

    #[test]
    fn concat_two_equal_polys() {
        let mut rng = test_rng();
        let degree = 10;

        let poly_l = DenseMultilinearExtension::rand(degree, &mut rng);
        let poly_r = DenseMultilinearExtension::rand(degree, &mut rng);

        let merged = DenseMultilinearExtension::concat(&[&poly_l, &poly_r]);
        for _ in 0..10 {
            let point: Vec<_> = (0..(degree + 1)).map(|_| Fr::rand(&mut rng)).collect();

            // The extra (last) variable selects between the two halves.
            let expected = (Fr::ONE - point[10]) * poly_l.evaluate(&point[..10].to_vec())
                + point[10] * poly_r.evaluate(&point[..10].to_vec());
            assert_eq!(expected, merged.evaluate(&point));
        }
    }

    #[test]
    fn concat_unequal_polys() {
        let mut rng = test_rng();
        let degree = 10;
        let poly_l = DenseMultilinearExtension::rand(degree, &mut rng);
        // The right polynomial has one variable fewer than the left one.
        let poly_r = DenseMultilinearExtension::rand(degree - 1, &mut rng);

        let merged = DenseMultilinearExtension::concat(&[&poly_l, &poly_r]);

        for _ in 0..10 {
            let point: Vec<_> = (0..(degree + 1)).map(|_| Fr::rand(&mut rng)).collect();

            // The upper half holds `poly_r` followed by zero padding, hence
            // the additional `(1 - point[9])` factor on the right term.
            let expected = (Fr::ONE - point[10]) * poly_l.evaluate(&point[..10].to_vec())
                + point[10] * ((Fr::ONE - point[9]) * poly_r.evaluate(&point[..9].to_vec()));
            assert_eq!(expected, merged.evaluate(&point));
        }
    }

    #[test]
    fn concat_two_iterators() {
        let mut rng = test_rng();
        let degree = 10;

        // Instead of merging two polynomials, merge two iterators of
        // polynomials. All four are drawn from the shared `rng`: a fresh
        // `test_rng()` per polynomial could, with a fixed seed, produce four
        // identical polynomials and hide ordering mistakes in `concat`.
        let polys_l: Vec<_> = (0..2)
            .map(|_| DenseMultilinearExtension::rand(degree - 2, &mut rng))
            .collect();
        let polys_r: Vec<_> = (0..2)
            .map(|_| DenseMultilinearExtension::rand(degree - 2, &mut rng))
            .collect();

        let merged = DenseMultilinearExtension::<Fr>::concat(polys_l.iter().chain(polys_r.iter()));

        for _ in 0..10 {
            let point: Vec<_> = (0..(degree)).map(|_| Fr::rand(&mut rng)).collect();

            let expected = (Fr::ONE - point[9])
                * ((Fr::ONE - point[8]) * polys_l[0].evaluate(&point[..8].to_vec())
                    + point[8] * polys_l[1].evaluate(&point[..8].to_vec()))
                + point[9]
                    * ((Fr::ONE - point[8]) * polys_r[0].evaluate(&point[..8].to_vec())
                        + point[8] * polys_r[1].evaluate(&point[..8].to_vec()));

            assert_eq!(expected, merged.evaluate(&point));
        }
    }
}