ark_poly/evaluations/multivariate/multilinear/
sparse.rs

1//! multilinear polynomial represented in sparse evaluation form.
2
3use crate::{
4    evaluations::multivariate::multilinear::swap_bits, DenseMultilinearExtension,
5    MultilinearExtension, Polynomial,
6};
7use ark_ff::{Field, Zero};
8use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
9use ark_std::{
10    collections::BTreeMap,
11    fmt,
12    fmt::{Debug, Formatter},
13    ops::{Add, AddAssign, Index, Neg, Sub, SubAssign},
14    rand::Rng,
15    vec::*,
16    UniformRand,
17};
18use hashbrown::HashMap;
19#[cfg(feature = "parallel")]
20use rayon::prelude::*;
21
22use super::DefaultHasher;
23
24/// Stores a multilinear polynomial in sparse evaluation form.
25#[derive(Clone, PartialEq, Eq, Hash, Default, CanonicalSerialize, CanonicalDeserialize)]
26pub struct SparseMultilinearExtension<F: Field> {
27    /// tuples of index and value
28    pub evaluations: BTreeMap<usize, F>,
29    /// number of variables
30    pub num_vars: usize,
31    zero: F,
32}
33
34impl<F: Field> SparseMultilinearExtension<F> {
35    pub fn from_evaluations<'a>(
36        num_vars: usize,
37        evaluations: impl IntoIterator<Item = &'a (usize, F)>,
38    ) -> Self {
39        let bit_mask = 1 << num_vars;
40        // check
41        let evaluations = evaluations.into_iter();
42        let evaluations: Vec<_> = evaluations
43            .map(|(i, v): &(usize, F)| {
44                assert!(*i < bit_mask, "index out of range");
45                (*i, *v)
46            })
47            .collect();
48
49        Self {
50            evaluations: tuples_to_treemap(&evaluations),
51            num_vars,
52            zero: F::zero(),
53        }
54    }
55
56    /// Outputs an `l`-variate multilinear extension where value of evaluations
57    /// are sampled uniformly at random. The number of nonzero entries is
58    /// `num_nonzero_entries` and indices of those nonzero entries are
59    /// distributed uniformly at random.
60    ///
61    /// Note that this function uses rejection sampling. As number of nonzero
62    /// entries approach `2 ^ num_vars`, sampling will be very slow due to
63    /// large number of collisions.
64    pub fn rand_with_config<R: Rng>(
65        num_vars: usize,
66        num_nonzero_entries: usize,
67        rng: &mut R,
68    ) -> Self {
69        assert!(num_nonzero_entries <= (1 << num_vars));
70
71        let mut map =
72            HashMap::with_hasher(core::hash::BuildHasherDefault::<DefaultHasher>::default());
73        for _ in 0..num_nonzero_entries {
74            let mut index = usize::rand(rng) & ((1 << num_vars) - 1);
75            while map.get(&index).is_some() {
76                index = usize::rand(rng) & ((1 << num_vars) - 1);
77            }
78            map.entry(index).or_insert(F::rand(rng));
79        }
80        let mut buf = Vec::new();
81        for (arg, v) in map.iter() {
82            if *v != F::zero() {
83                buf.push((*arg, *v));
84            }
85        }
86        let evaluations = hashmap_to_treemap(&map);
87        Self {
88            num_vars,
89            evaluations,
90            zero: F::zero(),
91        }
92    }
93
94    /// Convert the sparse multilinear polynomial to dense form.
95    pub fn to_dense_multilinear_extension(&self) -> DenseMultilinearExtension<F> {
96        let mut evaluations: Vec<_> = (0..(1 << self.num_vars)).map(|_| F::zero()).collect();
97        for (&i, &v) in self.evaluations.iter() {
98            evaluations[i] = v;
99        }
100        DenseMultilinearExtension::from_evaluations_vec(self.num_vars, evaluations)
101    }
102}
103
104/// utility: precompute f(x) = eq(g,x)
105fn precompute_eq<F: Field>(g: &[F]) -> Vec<F> {
106    let dim = g.len();
107    let mut dp = vec![F::zero(); 1 << dim];
108    dp[0] = F::one() - g[0];
109    dp[1] = g[0];
110    for i in 1..dim {
111        for b in 0..(1 << i) {
112            let prev = dp[b];
113            dp[b + (1 << i)] = prev * g[i];
114            dp[b] = prev - dp[b + (1 << i)];
115        }
116    }
117    dp
118}
119
120impl<F: Field> MultilinearExtension<F> for SparseMultilinearExtension<F> {
121    fn num_vars(&self) -> usize {
122        self.num_vars
123    }
124
125    /// Outputs an `l`-variate multilinear extension where value of evaluations
126    /// are sampled uniformly at random. The number of nonzero entries is
127    /// `sqrt(2^num_vars)` and indices of those nonzero entries are distributed
128    /// uniformly at random.
129    fn rand<R: Rng>(num_vars: usize, rng: &mut R) -> Self {
130        Self::rand_with_config(num_vars, 1 << (num_vars / 2), rng)
131    }
132
133    fn relabel(&self, mut a: usize, mut b: usize, k: usize) -> Self {
134        if a > b {
135            // swap
136            core::mem::swap(&mut a, &mut b);
137        }
138        // sanity check
139        assert!(
140            a + k < self.num_vars && b + k < self.num_vars,
141            "invalid relabel argument"
142        );
143        if a == b || k == 0 {
144            return self.clone();
145        }
146        assert!(a + k <= b, "overlapped swap window is not allowed");
147        let ev: Vec<_> = cfg_iter!(self.evaluations)
148            .map(|(&i, &v)| (swap_bits(i, a, b, k), v))
149            .collect();
150        Self {
151            num_vars: self.num_vars,
152            evaluations: tuples_to_treemap(&ev),
153            zero: F::zero(),
154        }
155    }
156
157    fn fix_variables(&self, partial_point: &[F]) -> Self {
158        let dim = partial_point.len();
159        assert!(dim <= self.num_vars, "invalid partial point dimension");
160
161        let mut window = ark_std::log2(self.evaluations.len()) as usize;
162        if window == 0 {
163            window = 1;
164        }
165        let mut point = partial_point;
166        let mut last = treemap_to_hashmap(&self.evaluations);
167
168        // batch evaluation
169        while !point.is_empty() {
170            let focus_length = if point.len() > window {
171                window
172            } else {
173                point.len()
174            };
175            let focus = &point[..focus_length];
176            point = &point[focus_length..];
177            let pre = precompute_eq(focus);
178            let dim = focus.len();
179            let mut result =
180                HashMap::with_hasher(core::hash::BuildHasherDefault::<DefaultHasher>::default());
181            for src_entry in last.iter() {
182                let old_idx = *src_entry.0;
183                let gz = pre[old_idx & ((1 << dim) - 1)];
184                let new_idx = old_idx >> dim;
185                let dst_entry = result.entry(new_idx).or_insert(F::zero());
186                *dst_entry += gz * src_entry.1;
187            }
188            last = result;
189        }
190        let evaluations = hashmap_to_treemap(&last);
191        Self {
192            num_vars: self.num_vars - dim,
193            evaluations,
194            zero: F::zero(),
195        }
196    }
197
198    fn to_evaluations(&self) -> Vec<F> {
199        let mut evaluations: Vec<_> = (0..1 << self.num_vars).map(|_| F::zero()).collect();
200        self.evaluations
201            .iter()
202            .map(|(&i, &v)| evaluations[i] = v)
203            .last();
204        evaluations
205    }
206}
207
208impl<F: Field> Index<usize> for SparseMultilinearExtension<F> {
209    type Output = F;
210
211    /// Returns the evaluation of the polynomial at a point represented by
212    /// index.
213    ///
214    /// Index represents a vector in {0,1}^`num_vars` in little endian form. For
215    /// example, `0b1011` represents `P(1,1,0,1)`
216    ///
217    /// For Sparse multilinear polynomial, Lookup_evaluation takes log time to
218    /// the size of polynomial.
219    fn index(&self, index: usize) -> &Self::Output {
220        if let Some(v) = self.evaluations.get(&index) {
221            v
222        } else {
223            &self.zero
224        }
225    }
226}
227
228impl<F: Field> Polynomial<F> for SparseMultilinearExtension<F> {
229    type Point = Vec<F>;
230
231    fn degree(&self) -> usize {
232        self.num_vars
233    }
234
235    fn evaluate(&self, point: &Self::Point) -> F {
236        assert!(point.len() == self.num_vars);
237        self.fix_variables(point)[0]
238    }
239}
240
241impl<F: Field> Add for SparseMultilinearExtension<F> {
242    type Output = SparseMultilinearExtension<F>;
243
244    fn add(self, other: SparseMultilinearExtension<F>) -> Self {
245        &self + &other
246    }
247}
248
249impl<'a, 'b, F: Field> Add<&'a SparseMultilinearExtension<F>>
250    for &'b SparseMultilinearExtension<F>
251{
252    type Output = SparseMultilinearExtension<F>;
253
254    fn add(self, rhs: &'a SparseMultilinearExtension<F>) -> Self::Output {
255        // handle zero case
256        if self.is_zero() {
257            return rhs.clone();
258        }
259        if rhs.is_zero() {
260            return self.clone();
261        }
262
263        assert_eq!(
264            rhs.num_vars, self.num_vars,
265            "trying to add non-zero polynomial with different number of variables"
266        );
267        // simply merge the evaluations
268        let mut evaluations =
269            HashMap::with_hasher(core::hash::BuildHasherDefault::<DefaultHasher>::default());
270        for (&i, &v) in self.evaluations.iter().chain(rhs.evaluations.iter()) {
271            *(evaluations.entry(i).or_insert(F::zero())) += v;
272        }
273        let evaluations: Vec<_> = evaluations
274            .into_iter()
275            .filter(|(_, v)| !v.is_zero())
276            .collect();
277
278        Self::Output {
279            evaluations: tuples_to_treemap(&evaluations),
280            num_vars: self.num_vars,
281            zero: F::zero(),
282        }
283    }
284}
285
286impl<F: Field> AddAssign for SparseMultilinearExtension<F> {
287    fn add_assign(&mut self, other: Self) {
288        *self = &*self + &other;
289    }
290}
291
292impl<'a, F: Field> AddAssign<&'a SparseMultilinearExtension<F>> for SparseMultilinearExtension<F> {
293    fn add_assign(&mut self, other: &'a SparseMultilinearExtension<F>) {
294        *self = &*self + other;
295    }
296}
297
298impl<'a, F: Field> AddAssign<(F, &'a SparseMultilinearExtension<F>)>
299    for SparseMultilinearExtension<F>
300{
301    fn add_assign(&mut self, (f, other): (F, &'a SparseMultilinearExtension<F>)) {
302        if !self.is_zero() && !other.is_zero() {
303            assert_eq!(
304                other.num_vars, self.num_vars,
305                "trying to add non-zero polynomial with different number of variables"
306            );
307        }
308        let ev: Vec<_> = cfg_iter!(other.evaluations)
309            .map(|(i, v)| (*i, f * v))
310            .collect();
311        let other = Self {
312            num_vars: other.num_vars,
313            evaluations: tuples_to_treemap(&ev),
314            zero: F::zero(),
315        };
316        *self += &other;
317    }
318}
319
320impl<F: Field> Neg for SparseMultilinearExtension<F> {
321    type Output = SparseMultilinearExtension<F>;
322
323    fn neg(self) -> Self::Output {
324        let ev: Vec<_> = cfg_iter!(self.evaluations)
325            .map(|(i, v)| (*i, -*v))
326            .collect();
327        Self::Output {
328            num_vars: self.num_vars,
329            evaluations: tuples_to_treemap(&ev),
330            zero: F::zero(),
331        }
332    }
333}
334
335impl<F: Field> Sub for SparseMultilinearExtension<F> {
336    type Output = SparseMultilinearExtension<F>;
337
338    fn sub(self, other: SparseMultilinearExtension<F>) -> Self {
339        &self - &other
340    }
341}
342
343impl<'a, 'b, F: Field> Sub<&'a SparseMultilinearExtension<F>>
344    for &'b SparseMultilinearExtension<F>
345{
346    type Output = SparseMultilinearExtension<F>;
347
348    fn sub(self, rhs: &'a SparseMultilinearExtension<F>) -> Self::Output {
349        self + &rhs.clone().neg()
350    }
351}
352
353impl<F: Field> SubAssign for SparseMultilinearExtension<F> {
354    fn sub_assign(&mut self, other: Self) {
355        *self = &*self - &other;
356    }
357}
358
359impl<'a, F: Field> SubAssign<&'a SparseMultilinearExtension<F>> for SparseMultilinearExtension<F> {
360    fn sub_assign(&mut self, other: &'a SparseMultilinearExtension<F>) {
361        *self = &*self - other;
362    }
363}
364
365impl<F: Field> Zero for SparseMultilinearExtension<F> {
366    fn zero() -> Self {
367        Self {
368            num_vars: 0,
369            evaluations: tuples_to_treemap(&Vec::new()),
370            zero: F::zero(),
371        }
372    }
373
374    fn is_zero(&self) -> bool {
375        self.num_vars == 0 && self.evaluations.is_empty()
376    }
377}
378
379impl<F: Field> Debug for SparseMultilinearExtension<F> {
380    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
381        write!(
382            f,
383            "SparseMultilinearPolynomial(num_vars = {}, evaluations = [",
384            self.num_vars
385        )?;
386        let mut ev_iter = self.evaluations.iter();
387        for _ in 0..ark_std::cmp::min(8, self.evaluations.len()) {
388            write!(f, "{:?}", ev_iter.next())?;
389        }
390        if self.evaluations.len() > 8 {
391            write!(f, "...")?;
392        }
393        write!(f, "])")?;
394        Ok(())
395    }
396}
397
398/// Utility: Convert tuples to hashmap.
399fn tuples_to_treemap<F: Field>(tuples: &[(usize, F)]) -> BTreeMap<usize, F> {
400    BTreeMap::from_iter(tuples.iter().map(|(i, v)| (*i, *v)))
401}
402
403fn treemap_to_hashmap<F: Field>(
404    map: &BTreeMap<usize, F>,
405) -> HashMap<usize, F, core::hash::BuildHasherDefault<DefaultHasher>> {
406    HashMap::from_iter(map.iter().map(|(i, v)| (*i, *v)))
407}
408
409fn hashmap_to_treemap<F: Field, S>(map: &HashMap<usize, F, S>) -> BTreeMap<usize, F> {
410    BTreeMap::from_iter(map.iter().map(|(i, v)| (*i, *v)))
411}
412
#[cfg(test)]
mod tests {
    use crate::{
        evaluations::multivariate::multilinear::MultilinearExtension, Polynomial,
        SparseMultilinearExtension,
    };
    use ark_ff::{One, Zero};
    use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
    use ark_std::{ops::Neg, test_rng, vec::*, UniformRand};
    use ark_test_curves::bls12_381::Fr;

    /// Sanity checks that random sparse polynomials look reasonable: two draws
    /// differ, and the number of stored entries is close to `2^(NV / 2)`.
    #[test]
    fn random_poly() {
        const NV: usize = 16;

        let mut rng = test_rng();
        // two random poly should be different
        let poly1 = SparseMultilinearExtension::<Fr>::rand(NV, &mut rng);
        let poly2 = SparseMultilinearExtension::<Fr>::rand(NV, &mut rng);
        assert_ne!(poly1, poly2);
        // test sparsity
        assert!(
            ((1 << (NV / 2)) >> 1) <= poly1.evaluations.len()
                && poly1.evaluations.len() <= ((1 << (NV / 2)) << 1),
            "polynomial size out of range: expected: [{},{}] ,actual: {}",
            ((1 << (NV / 2)) >> 1),
            ((1 << (NV / 2)) << 1),
            poly1.evaluations.len()
        );
    }

    #[test]
    /// Test if sparse multilinear polynomial evaluates correctly.
    /// This function assumes dense multilinear polynomial functions correctly.
    fn evaluate() {
        const NV: usize = 12;
        let mut rng = test_rng();
        for _ in 0..20 {
            let sparse = SparseMultilinearExtension::<Fr>::rand(NV, &mut rng);
            let dense = sparse.to_dense_multilinear_extension();
            let point: Vec<_> = (0..NV).map(|_| Fr::rand(&mut rng)).collect();
            assert_eq!(sparse.evaluate(&point), dense.evaluate(&point));
            let sparse_partial = sparse.fix_variables(&point[..3]);
            let dense_partial = dense.fix_variables(&point[..3]);
            let point2: Vec<_> = (0..(NV - 3)).map(|_| Fr::rand(&mut rng)).collect();
            assert_eq!(
                sparse_partial.evaluate(&point2),
                dense_partial.evaluate(&point2)
            );
        }
    }

    #[test]
    fn evaluate_edge_cases() {
        // test constant polynomial
        let mut rng = test_rng();
        let ev1 = Fr::rand(&mut rng);
        let poly1 = SparseMultilinearExtension::from_evaluations(0, &[(0, ev1)]);
        assert_eq!(poly1.evaluate(&[].into()), ev1);

        // test single-variate polynomial
        let ev2 = [Fr::rand(&mut rng), Fr::rand(&mut rng)];
        let poly2 = SparseMultilinearExtension::from_evaluations(1, &[(0, ev2[0]), (1, ev2[1])]);

        let x = Fr::rand(&mut rng);
        assert_eq!(
            poly2.evaluate(&[x].into()),
            x * ev2[1] + (Fr::one() - x) * ev2[0]
        );

        // test single-variate polynomial with one entry missing
        let ev3 = Fr::rand(&mut rng);
        let poly2 = SparseMultilinearExtension::from_evaluations(1, &[(1, ev3)]);

        let x = Fr::rand(&mut rng);
        assert_eq!(poly2.evaluate(&[x].into()), x * ev3);
    }

    #[test]
    fn index() {
        let mut rng = test_rng();
        let points = vec![
            (11, Fr::rand(&mut rng)),
            (117, Fr::rand(&mut rng)),
            (213, Fr::rand(&mut rng)),
            (255, Fr::rand(&mut rng)),
        ];
        let poly = SparseMultilinearExtension::from_evaluations(8, &points);
        for (i, v) in points {
            assert_eq!(poly[i], v);
        }
        // Indices that were never set must read as zero.
        assert_eq!(poly[0], Fr::zero());
        assert_eq!(poly[1], Fr::zero());
    }

    #[test]
    fn arithmetic() {
        const NV: usize = 18;
        let mut rng = test_rng();
        for _ in 0..20 {
            let point: Vec<_> = (0..NV).map(|_| Fr::rand(&mut rng)).collect();
            let poly1 = SparseMultilinearExtension::rand(NV, &mut rng);
            let poly2 = SparseMultilinearExtension::rand(NV, &mut rng);
            let v1 = poly1.evaluate(&point);
            let v2 = poly2.evaluate(&point);
            // test add
            assert_eq!((&poly1 + &poly2).evaluate(&point), v1 + v2);
            // test sub
            assert_eq!((&poly1 - &poly2).evaluate(&point), v1 - v2);
            // test negate
            assert_eq!(poly1.clone().neg().evaluate(&point), -v1);
            // test add assign
            {
                let mut poly1 = poly1.clone();
                poly1 += &poly2;
                assert_eq!(poly1.evaluate(&point), v1 + v2)
            }
            // test sub assign
            {
                let mut poly1 = poly1.clone();
                poly1 -= &poly2;
                assert_eq!(poly1.evaluate(&point), v1 - v2)
            }
            // test add assign with scalar
            {
                let mut poly1 = poly1.clone();
                let scalar = Fr::rand(&mut rng);
                poly1 += (scalar, &poly2);
                assert_eq!(poly1.evaluate(&point), v1 + scalar * v2)
            }
            // test additive identity
            {
                assert_eq!(&poly1 + &SparseMultilinearExtension::zero(), poly1);
                assert_eq!(&SparseMultilinearExtension::zero() + &poly1, poly1);
                {
                    let mut poly1_cloned = poly1.clone();
                    poly1_cloned += &SparseMultilinearExtension::zero();
                    assert_eq!(&poly1_cloned, &poly1);
                    let mut zero = SparseMultilinearExtension::zero();
                    let scalar = Fr::rand(&mut rng);
                    zero += (scalar, &poly1);
                    assert_eq!(zero.evaluate(&point), scalar * v1);
                }
            }
        }
    }

    #[test]
    fn relabel() {
        let mut rng = test_rng();
        for _ in 0..20 {
            let mut poly = SparseMultilinearExtension::rand(10, &mut rng);
            let mut point: Vec<_> = (0..10).map(|_| Fr::rand(&mut rng)).collect();

            let expected = poly.evaluate(&point);

            poly = poly.relabel(2, 2, 1); // should have no effect
            assert_eq!(expected, poly.evaluate(&point));

            poly = poly.relabel(3, 4, 1); // should switch 3 and 4
            point.swap(3, 4);
            assert_eq!(expected, poly.evaluate(&point));

            poly = poly.relabel(7, 5, 1);
            point.swap(7, 5);
            assert_eq!(expected, poly.evaluate(&point));

            poly = poly.relabel(2, 5, 3);
            point.swap(2, 5);
            point.swap(3, 6);
            point.swap(4, 7);
            assert_eq!(expected, poly.evaluate(&point));

            poly = poly.relabel(7, 0, 2);
            point.swap(0, 7);
            point.swap(1, 8);
            assert_eq!(expected, poly.evaluate(&point));
        }
    }

    #[test]
    fn serialize() {
        let mut rng = test_rng();
        for _ in 0..20 {
            let mut buf = Vec::new();
            let poly = SparseMultilinearExtension::<Fr>::rand(10, &mut rng);
            let point: Vec<_> = (0..10).map(|_| Fr::rand(&mut rng)).collect();
            let expected = poly.evaluate(&point);

            poly.serialize_compressed(&mut buf).unwrap();

            let poly2: SparseMultilinearExtension<Fr> =
                SparseMultilinearExtension::deserialize_compressed(&buf[..]).unwrap();
            assert_eq!(poly2.evaluate(&point), expected);
        }
    }
}