1use crate::{
4 evaluations::multivariate::multilinear::swap_bits, DenseMultilinearExtension,
5 MultilinearExtension, Polynomial,
6};
7use ark_ff::{Field, Zero};
8use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
9use ark_std::{
10 collections::BTreeMap,
11 fmt,
12 fmt::{Debug, Formatter},
13 ops::{Add, AddAssign, Index, Neg, Sub, SubAssign},
14 rand::Rng,
15 vec::*,
16 UniformRand,
17};
18use hashbrown::HashMap;
19#[cfg(feature = "parallel")]
20use rayon::prelude::*;
21
22use super::DefaultHasher;
23
/// A multilinear polynomial stored in sparse evaluation form.
///
/// Only explicitly stored evaluations live in `evaluations`; every index of the
/// boolean hypercube `{0,1}^num_vars` without an entry evaluates to zero.
///
/// NOTE(review): the private `zero` field participates in the derived `PartialEq`,
/// `Hash` and `CanonicalSerialize`/`CanonicalDeserialize` impls, so it is written to
/// and read back from serialized data — TODO confirm that is intended.
#[derive(Clone, PartialEq, Eq, Hash, Default, CanonicalSerialize, CanonicalDeserialize)]
pub struct SparseMultilinearExtension<F: Field> {
    /// Map from hypercube index to the evaluation stored at that point.
    /// Bit `i` of an index is the value of variable `i` (the `i`-th coordinate of a
    /// point passed to `fix_variables`/`evaluate`).
    pub evaluations: BTreeMap<usize, F>,
    /// Number of variables of the polynomial.
    pub num_vars: usize,
    /// A stored zero so that `Index` can hand out a reference for missing entries.
    zero: F,
}
33
34impl<F: Field> SparseMultilinearExtension<F> {
35 pub fn from_evaluations<'a>(
36 num_vars: usize,
37 evaluations: impl IntoIterator<Item = &'a (usize, F)>,
38 ) -> Self {
39 let bit_mask = 1 << num_vars;
40 let evaluations = evaluations.into_iter();
42 let evaluations: Vec<_> = evaluations
43 .map(|(i, v): &(usize, F)| {
44 assert!(*i < bit_mask, "index out of range");
45 (*i, *v)
46 })
47 .collect();
48
49 Self {
50 evaluations: tuples_to_treemap(&evaluations),
51 num_vars,
52 zero: F::zero(),
53 }
54 }
55
56 pub fn rand_with_config<R: Rng>(
65 num_vars: usize,
66 num_nonzero_entries: usize,
67 rng: &mut R,
68 ) -> Self {
69 assert!(num_nonzero_entries <= (1 << num_vars));
70
71 let mut map =
72 HashMap::with_hasher(core::hash::BuildHasherDefault::<DefaultHasher>::default());
73 for _ in 0..num_nonzero_entries {
74 let mut index = usize::rand(rng) & ((1 << num_vars) - 1);
75 while map.get(&index).is_some() {
76 index = usize::rand(rng) & ((1 << num_vars) - 1);
77 }
78 map.entry(index).or_insert(F::rand(rng));
79 }
80 let mut buf = Vec::new();
81 for (arg, v) in map.iter() {
82 if *v != F::zero() {
83 buf.push((*arg, *v));
84 }
85 }
86 let evaluations = hashmap_to_treemap(&map);
87 Self {
88 num_vars,
89 evaluations,
90 zero: F::zero(),
91 }
92 }
93
94 pub fn to_dense_multilinear_extension(&self) -> DenseMultilinearExtension<F> {
96 let mut evaluations: Vec<_> = (0..(1 << self.num_vars)).map(|_| F::zero()).collect();
97 for (&i, &v) in self.evaluations.iter() {
98 evaluations[i] = v;
99 }
100 DenseMultilinearExtension::from_evaluations_vec(self.num_vars, evaluations)
101 }
102}
103
104fn precompute_eq<F: Field>(g: &[F]) -> Vec<F> {
106 let dim = g.len();
107 let mut dp = vec![F::zero(); 1 << dim];
108 dp[0] = F::one() - g[0];
109 dp[1] = g[0];
110 for i in 1..dim {
111 for b in 0..(1 << i) {
112 let prev = dp[b];
113 dp[b + (1 << i)] = prev * g[i];
114 dp[b] = prev - dp[b + (1 << i)];
115 }
116 }
117 dp
118}
119
120impl<F: Field> MultilinearExtension<F> for SparseMultilinearExtension<F> {
121 fn num_vars(&self) -> usize {
122 self.num_vars
123 }
124
125 fn rand<R: Rng>(num_vars: usize, rng: &mut R) -> Self {
130 Self::rand_with_config(num_vars, 1 << (num_vars / 2), rng)
131 }
132
133 fn relabel(&self, mut a: usize, mut b: usize, k: usize) -> Self {
134 if a > b {
135 core::mem::swap(&mut a, &mut b);
137 }
138 assert!(
140 a + k < self.num_vars && b + k < self.num_vars,
141 "invalid relabel argument"
142 );
143 if a == b || k == 0 {
144 return self.clone();
145 }
146 assert!(a + k <= b, "overlapped swap window is not allowed");
147 let ev: Vec<_> = cfg_iter!(self.evaluations)
148 .map(|(&i, &v)| (swap_bits(i, a, b, k), v))
149 .collect();
150 Self {
151 num_vars: self.num_vars,
152 evaluations: tuples_to_treemap(&ev),
153 zero: F::zero(),
154 }
155 }
156
157 fn fix_variables(&self, partial_point: &[F]) -> Self {
158 let dim = partial_point.len();
159 assert!(dim <= self.num_vars, "invalid partial point dimension");
160
161 let mut window = ark_std::log2(self.evaluations.len()) as usize;
162 if window == 0 {
163 window = 1;
164 }
165 let mut point = partial_point;
166 let mut last = treemap_to_hashmap(&self.evaluations);
167
168 while !point.is_empty() {
170 let focus_length = if point.len() > window {
171 window
172 } else {
173 point.len()
174 };
175 let focus = &point[..focus_length];
176 point = &point[focus_length..];
177 let pre = precompute_eq(focus);
178 let dim = focus.len();
179 let mut result =
180 HashMap::with_hasher(core::hash::BuildHasherDefault::<DefaultHasher>::default());
181 for src_entry in last.iter() {
182 let old_idx = *src_entry.0;
183 let gz = pre[old_idx & ((1 << dim) - 1)];
184 let new_idx = old_idx >> dim;
185 let dst_entry = result.entry(new_idx).or_insert(F::zero());
186 *dst_entry += gz * src_entry.1;
187 }
188 last = result;
189 }
190 let evaluations = hashmap_to_treemap(&last);
191 Self {
192 num_vars: self.num_vars - dim,
193 evaluations,
194 zero: F::zero(),
195 }
196 }
197
198 fn to_evaluations(&self) -> Vec<F> {
199 let mut evaluations: Vec<_> = (0..1 << self.num_vars).map(|_| F::zero()).collect();
200 self.evaluations
201 .iter()
202 .map(|(&i, &v)| evaluations[i] = v)
203 .last();
204 evaluations
205 }
206}
207
208impl<F: Field> Index<usize> for SparseMultilinearExtension<F> {
209 type Output = F;
210
211 fn index(&self, index: usize) -> &Self::Output {
220 if let Some(v) = self.evaluations.get(&index) {
221 v
222 } else {
223 &self.zero
224 }
225 }
226}
227
228impl<F: Field> Polynomial<F> for SparseMultilinearExtension<F> {
229 type Point = Vec<F>;
230
231 fn degree(&self) -> usize {
232 self.num_vars
233 }
234
235 fn evaluate(&self, point: &Self::Point) -> F {
236 assert!(point.len() == self.num_vars);
237 self.fix_variables(point)[0]
238 }
239}
240
241impl<F: Field> Add for SparseMultilinearExtension<F> {
242 type Output = SparseMultilinearExtension<F>;
243
244 fn add(self, other: SparseMultilinearExtension<F>) -> Self {
245 &self + &other
246 }
247}
248
249impl<'a, 'b, F: Field> Add<&'a SparseMultilinearExtension<F>>
250 for &'b SparseMultilinearExtension<F>
251{
252 type Output = SparseMultilinearExtension<F>;
253
254 fn add(self, rhs: &'a SparseMultilinearExtension<F>) -> Self::Output {
255 if self.is_zero() {
257 return rhs.clone();
258 }
259 if rhs.is_zero() {
260 return self.clone();
261 }
262
263 assert_eq!(
264 rhs.num_vars, self.num_vars,
265 "trying to add non-zero polynomial with different number of variables"
266 );
267 let mut evaluations =
269 HashMap::with_hasher(core::hash::BuildHasherDefault::<DefaultHasher>::default());
270 for (&i, &v) in self.evaluations.iter().chain(rhs.evaluations.iter()) {
271 *(evaluations.entry(i).or_insert(F::zero())) += v;
272 }
273 let evaluations: Vec<_> = evaluations
274 .into_iter()
275 .filter(|(_, v)| !v.is_zero())
276 .collect();
277
278 Self::Output {
279 evaluations: tuples_to_treemap(&evaluations),
280 num_vars: self.num_vars,
281 zero: F::zero(),
282 }
283 }
284}
285
286impl<F: Field> AddAssign for SparseMultilinearExtension<F> {
287 fn add_assign(&mut self, other: Self) {
288 *self = &*self + &other;
289 }
290}
291
292impl<'a, F: Field> AddAssign<&'a SparseMultilinearExtension<F>> for SparseMultilinearExtension<F> {
293 fn add_assign(&mut self, other: &'a SparseMultilinearExtension<F>) {
294 *self = &*self + other;
295 }
296}
297
298impl<'a, F: Field> AddAssign<(F, &'a SparseMultilinearExtension<F>)>
299 for SparseMultilinearExtension<F>
300{
301 fn add_assign(&mut self, (f, other): (F, &'a SparseMultilinearExtension<F>)) {
302 if !self.is_zero() && !other.is_zero() {
303 assert_eq!(
304 other.num_vars, self.num_vars,
305 "trying to add non-zero polynomial with different number of variables"
306 );
307 }
308 let ev: Vec<_> = cfg_iter!(other.evaluations)
309 .map(|(i, v)| (*i, f * v))
310 .collect();
311 let other = Self {
312 num_vars: other.num_vars,
313 evaluations: tuples_to_treemap(&ev),
314 zero: F::zero(),
315 };
316 *self += &other;
317 }
318}
319
320impl<F: Field> Neg for SparseMultilinearExtension<F> {
321 type Output = SparseMultilinearExtension<F>;
322
323 fn neg(self) -> Self::Output {
324 let ev: Vec<_> = cfg_iter!(self.evaluations)
325 .map(|(i, v)| (*i, -*v))
326 .collect();
327 Self::Output {
328 num_vars: self.num_vars,
329 evaluations: tuples_to_treemap(&ev),
330 zero: F::zero(),
331 }
332 }
333}
334
335impl<F: Field> Sub for SparseMultilinearExtension<F> {
336 type Output = SparseMultilinearExtension<F>;
337
338 fn sub(self, other: SparseMultilinearExtension<F>) -> Self {
339 &self - &other
340 }
341}
342
343impl<'a, 'b, F: Field> Sub<&'a SparseMultilinearExtension<F>>
344 for &'b SparseMultilinearExtension<F>
345{
346 type Output = SparseMultilinearExtension<F>;
347
348 fn sub(self, rhs: &'a SparseMultilinearExtension<F>) -> Self::Output {
349 self + &rhs.clone().neg()
350 }
351}
352
353impl<F: Field> SubAssign for SparseMultilinearExtension<F> {
354 fn sub_assign(&mut self, other: Self) {
355 *self = &*self - &other;
356 }
357}
358
359impl<'a, F: Field> SubAssign<&'a SparseMultilinearExtension<F>> for SparseMultilinearExtension<F> {
360 fn sub_assign(&mut self, other: &'a SparseMultilinearExtension<F>) {
361 *self = &*self - other;
362 }
363}
364
365impl<F: Field> Zero for SparseMultilinearExtension<F> {
366 fn zero() -> Self {
367 Self {
368 num_vars: 0,
369 evaluations: tuples_to_treemap(&Vec::new()),
370 zero: F::zero(),
371 }
372 }
373
374 fn is_zero(&self) -> bool {
375 self.num_vars == 0 && self.evaluations.is_empty()
376 }
377}
378
379impl<F: Field> Debug for SparseMultilinearExtension<F> {
380 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
381 write!(
382 f,
383 "SparseMultilinearPolynomial(num_vars = {}, evaluations = [",
384 self.num_vars
385 )?;
386 let mut ev_iter = self.evaluations.iter();
387 for _ in 0..ark_std::cmp::min(8, self.evaluations.len()) {
388 write!(f, "{:?}", ev_iter.next())?;
389 }
390 if self.evaluations.len() > 8 {
391 write!(f, "...")?;
392 }
393 write!(f, "])")?;
394 Ok(())
395 }
396}
397
398fn tuples_to_treemap<F: Field>(tuples: &[(usize, F)]) -> BTreeMap<usize, F> {
400 BTreeMap::from_iter(tuples.iter().map(|(i, v)| (*i, *v)))
401}
402
403fn treemap_to_hashmap<F: Field>(
404 map: &BTreeMap<usize, F>,
405) -> HashMap<usize, F, core::hash::BuildHasherDefault<DefaultHasher>> {
406 HashMap::from_iter(map.iter().map(|(i, v)| (*i, *v)))
407}
408
409fn hashmap_to_treemap<F: Field, S>(map: &HashMap<usize, F, S>) -> BTreeMap<usize, F> {
410 BTreeMap::from_iter(map.iter().map(|(i, v)| (*i, *v)))
411}
412
#[cfg(test)]
mod tests {
    use crate::{
        evaluations::multivariate::multilinear::MultilinearExtension, Polynomial,
        SparseMultilinearExtension,
    };
    use ark_ff::{One, Zero};
    use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
    use ark_std::{ops::Neg, test_rng, vec::*, UniformRand};
    use ark_test_curves::bls12_381::Fr;

    /// `rand` yields distinct polynomials whose number of stored entries lies within a
    /// factor of two of `2^(NV / 2)`.
    #[test]
    fn random_poly() {
        const NV: usize = 16;

        let mut rng = test_rng();
        // Two consecutive samples from the same RNG stream are expected to differ.
        let poly1 = SparseMultilinearExtension::<Fr>::rand(NV, &mut rng);
        let poly2 = SparseMultilinearExtension::<Fr>::rand(NV, &mut rng);
        assert_ne!(poly1, poly2);
        // Loose size check: [2^(NV/2) / 2, 2^(NV/2) * 2].
        assert!(
            ((1 << (NV / 2)) >> 1) <= poly1.evaluations.len()
                && poly1.evaluations.len() <= ((1 << (NV / 2)) << 1),
            "polynomial size out of range: expected: [{},{}] ,actual: {}",
            ((1 << (NV / 2)) >> 1),
            ((1 << (NV / 2)) << 1),
            poly1.evaluations.len()
        );
    }

    /// Sparse and dense representations agree on full evaluation and after fixing the
    /// first three variables.
    #[test]
    fn evaluate() {
        const NV: usize = 12;
        let mut rng = test_rng();
        for _ in 0..20 {
            let sparse = SparseMultilinearExtension::<Fr>::rand(NV, &mut rng);
            let dense = sparse.to_dense_multilinear_extension();
            let point: Vec<_> = (0..NV).map(|_| Fr::rand(&mut rng)).collect();
            assert_eq!(sparse.evaluate(&point), dense.evaluate(&point));
            let sparse_partial = sparse.fix_variables(&point[..3].to_vec());
            let dense_partial = dense.fix_variables(&point[..3].to_vec());
            let point2: Vec<_> = (0..(NV - 3)).map(|_| Fr::rand(&mut rng)).collect();
            assert_eq!(
                sparse_partial.evaluate(&point2),
                dense_partial.evaluate(&point2)
            );
        }
    }

    /// Zero- and one-variable polynomials evaluate to the expected closed forms.
    #[test]
    fn evaluate_edge_cases() {
        let mut rng = test_rng();
        // Constant polynomial (num_vars = 0): evaluates to its single entry.
        let ev1 = Fr::rand(&mut rng);
        let poly1 = SparseMultilinearExtension::from_evaluations(0, &vec![(0, ev1)]);
        assert_eq!(poly1.evaluate(&[].into()), ev1);

        // One variable, both entries present: linear interpolation between them.
        let ev2 = vec![Fr::rand(&mut rng), Fr::rand(&mut rng)];
        let poly2 =
            SparseMultilinearExtension::from_evaluations(1, &vec![(0, ev2[0]), (1, ev2[1])]);

        let x = Fr::rand(&mut rng);
        assert_eq!(
            poly2.evaluate(&[x].into()),
            x * ev2[1] + (Fr::one() - x) * ev2[0]
        );

        // One variable, only index 1 stored: index 0 is implicitly zero.
        let ev3 = Fr::rand(&mut rng);
        let poly2 = SparseMultilinearExtension::from_evaluations(1, &vec![(1, ev3)]);

        let x = Fr::rand(&mut rng);
        assert_eq!(poly2.evaluate(&[x].into()), x * ev3);
    }

    /// `Index` returns stored values and zero for indices without an entry.
    #[test]
    fn index() {
        let mut rng = test_rng();
        let points = vec![
            (11, Fr::rand(&mut rng)),
            (117, Fr::rand(&mut rng)),
            (213, Fr::rand(&mut rng)),
            (255, Fr::rand(&mut rng)),
        ];
        let poly = SparseMultilinearExtension::from_evaluations(8, &points);
        points
            .into_iter()
            .map(|(i, v)| assert_eq!(poly[i], v))
            .last();
        assert_eq!(poly[0], Fr::zero());
        assert_eq!(poly[1], Fr::zero());
    }

    /// Arithmetic operators agree with the corresponding operations on evaluations,
    /// and `zero()` behaves as an additive identity.
    #[test]
    fn arithmetic() {
        const NV: usize = 18;
        let mut rng = test_rng();
        for _ in 0..20 {
            let point: Vec<_> = (0..NV).map(|_| Fr::rand(&mut rng)).collect();
            let poly1 = SparseMultilinearExtension::rand(NV, &mut rng);
            let poly2 = SparseMultilinearExtension::rand(NV, &mut rng);
            let v1 = poly1.evaluate(&point);
            let v2 = poly2.evaluate(&point);
            // add
            assert_eq!((&poly1 + &poly2).evaluate(&point), v1 + v2);
            // sub
            assert_eq!((&poly1 - &poly2).evaluate(&point), v1 - v2);
            // negate
            assert_eq!(poly1.clone().neg().evaluate(&point), -v1);
            // add assign
            {
                let mut poly1 = poly1.clone();
                poly1 += &poly2;
                assert_eq!(poly1.evaluate(&point), v1 + v2)
            }
            // sub assign
            {
                let mut poly1 = poly1.clone();
                poly1 -= &poly2;
                assert_eq!(poly1.evaluate(&point), v1 - v2)
            }
            // add assign with scalar
            {
                let mut poly1 = poly1.clone();
                let scalar = Fr::rand(&mut rng);
                poly1 += (scalar, &poly2);
                assert_eq!(poly1.evaluate(&point), v1 + scalar * v2)
            }
            // additive identity
            {
                assert_eq!(&poly1 + &SparseMultilinearExtension::zero(), poly1);
                assert_eq!(&SparseMultilinearExtension::zero() + &poly1, poly1);
                {
                    let mut poly1_cloned = poly1.clone();
                    poly1_cloned += &SparseMultilinearExtension::zero();
                    assert_eq!(&poly1_cloned, &poly1);
                    let mut zero = SparseMultilinearExtension::zero();
                    let scalar = Fr::rand(&mut rng);
                    zero += (scalar, &poly1);
                    assert_eq!(zero.evaluate(&point), scalar * v1);
                }
            }
        }
    }

    /// Each `relabel` call is mirrored by swapping the same coordinates of `point`, so
    /// the evaluation must stay equal to the original value.
    #[test]
    fn relabel() {
        let mut rng = test_rng();
        for _ in 0..20 {
            let mut poly = SparseMultilinearExtension::rand(10, &mut rng);
            let mut point: Vec<_> = (0..10).map(|_| Fr::rand(&mut rng)).collect();

            let expected = poly.evaluate(&point);

            // Identical windows: no-op.
            poly = poly.relabel(2, 2, 1);
            assert_eq!(expected, poly.evaluate(&point));

            poly = poly.relabel(3, 4, 1);
            point.swap(3, 4);
            assert_eq!(expected, poly.evaluate(&point));

            // Arguments given in descending order.
            poly = poly.relabel(7, 5, 1);
            point.swap(7, 5);
            assert_eq!(expected, poly.evaluate(&point));

            // Windows of width 3: 2..5 <-> 5..8.
            poly = poly.relabel(2, 5, 3);
            point.swap(2, 5);
            point.swap(3, 6);
            point.swap(4, 7);
            assert_eq!(expected, poly.evaluate(&point));

            // Windows of width 2: 0..2 <-> 7..9.
            poly = poly.relabel(7, 0, 2);
            point.swap(0, 7);
            point.swap(1, 8);
            assert_eq!(expected, poly.evaluate(&point));
        }
    }

    /// A compressed serialize/deserialize round trip preserves evaluations.
    #[test]
    fn serialize() {
        let mut rng = test_rng();
        for _ in 0..20 {
            let mut buf = Vec::new();
            let poly = SparseMultilinearExtension::<Fr>::rand(10, &mut rng);
            let point: Vec<_> = (0..10).map(|_| Fr::rand(&mut rng)).collect();
            let expected = poly.evaluate(&point);

            poly.serialize_compressed(&mut buf).unwrap();

            let poly2: SparseMultilinearExtension<Fr> =
                SparseMultilinearExtension::deserialize_compressed(&buf[..]).unwrap();
            assert_eq!(poly2.evaluate(&point), expected);
        }
    }
}