ark_ec/scalar_mul/variable_base/mod.rs

use ark_ff::prelude::*;
use ark_std::{borrow::Borrow, cfg_into_iter, iterable::Iterable, vec::*};

#[cfg(feature = "parallel")]
use rayon::prelude::*;

pub mod stream_pippenger;
pub use stream_pippenger::*;

use super::ScalarMul;

#[cfg(all(
    target_has_atomic = "8",
    target_has_atomic = "16",
    target_has_atomic = "32",
    target_has_atomic = "64",
    target_has_atomic = "ptr"
))]
type DefaultHasher = ahash::AHasher;

#[cfg(not(all(
    target_has_atomic = "8",
    target_has_atomic = "16",
    target_has_atomic = "32",
    target_has_atomic = "64",
    target_has_atomic = "ptr"
)))]
type DefaultHasher = fnv::FnvHasher;

pub trait VariableBaseMSM: ScalarMul {
    /// Computes an inner product between the [`PrimeField`] elements in `scalars`
    /// and the corresponding group elements in `bases`.
    ///
    /// If the slices have different lengths, the longer one is truncated so that
    /// only the first `min(scalars.len(), bases.len())` elements are used.
    ///
    /// Reference: [`VariableBaseMSM::msm`]
    fn msm_unchecked(bases: &[Self::MulBase], scalars: &[Self::ScalarField]) -> Self {
        let bigints = cfg_into_iter!(scalars)
            .map(|s| s.into_bigint())
            .collect::<Vec<_>>();
        Self::msm_bigint(bases, &bigints)
    }

    /// Performs multi-scalar multiplication.
    ///
    /// # Warning
    ///
    /// This method checks that `bases` and `scalars` have the same length.
    /// If they are unequal, it returns an error containing
    /// the shortest length over which the MSM can be performed.
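    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a concrete curve implementation such as
    /// `ark_bls12_381` (not a dependency of this crate) is available; the
    /// example is therefore not compiled here:
    ///
    /// ```ignore
    /// use ark_bls12_381::{Fr, G1Projective as G};
    /// use ark_ec::{scalar_mul::variable_base::VariableBaseMSM, CurveGroup};
    /// use ark_std::UniformRand;
    ///
    /// let mut rng = ark_std::test_rng();
    /// let bases = G::normalize_batch(&[G::rand(&mut rng), G::rand(&mut rng)]);
    /// let scalars = [Fr::rand(&mut rng), Fr::rand(&mut rng)];
    ///
    /// // Equal lengths: the MSM succeeds.
    /// assert!(G::msm(&bases, &scalars).is_ok());
    /// // Mismatched lengths: the error carries the usable length.
    /// assert_eq!(G::msm(&bases, &scalars[..1]), Err(1));
    /// ```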
    fn msm(bases: &[Self::MulBase], scalars: &[Self::ScalarField]) -> Result<Self, usize> {
        (bases.len() == scalars.len())
            .then(|| Self::msm_unchecked(bases, scalars))
            .ok_or(bases.len().min(scalars.len()))
    }

    /// Optimized implementation of multi-scalar multiplication.
    fn msm_bigint(
        bases: &[Self::MulBase],
        bigints: &[<Self::ScalarField as PrimeField>::BigInt],
    ) -> Self {
        if Self::NEGATION_IS_CHEAP {
            msm_bigint_wnaf(bases, bigints)
        } else {
            msm_bigint(bases, bigints)
        }
    }

    /// Streaming multi-scalar multiplication algorithm with hard-coded chunk
    /// size.
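    ///
    /// The scalar stream may be shorter than the base stream; in that case the
    /// scalars are paired with the *last* `scalars_stream.len()` bases, and any
    /// leading bases are ignored.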
    fn msm_chunks<I: ?Sized, J>(bases_stream: &J, scalars_stream: &I) -> Self
    where
        I: Iterable,
        I::Item: Borrow<Self::ScalarField>,
        J: Iterable,
        J::Item: Borrow<Self::MulBase>,
    {
        assert!(scalars_stream.len() <= bases_stream.len());

        // Set up iterators over both streams.
        let bases_init = bases_stream.iter();
        let mut scalars = scalars_stream.iter();

        // Align the streams: skip the leading bases that have no matching scalar.
        // TODO: change `skip` to `advance_by` once rust-lang/rust#77404 is stabilized.
        // See <https://github.com/rust-lang/rust/issues/77404>
        let mut bases = bases_init.skip(bases_stream.len() - scalars_stream.len());
        let step: usize = 1 << 20;
        let mut result = Self::zero();
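        // Process the streams in chunks of `step` elements; the loop bound is a
        // ceiling division, so a final partial chunk is handled as well.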
        for _ in 0..(scalars_stream.len() + step - 1) / step {
            let bases_step = (&mut bases)
                .take(step)
                .map(|b| *b.borrow())
                .collect::<Vec<_>>();
            let scalars_step = (&mut scalars)
                .take(step)
                .map(|s| s.borrow().into_bigint())
                .collect::<Vec<_>>();
            result += Self::msm_bigint(bases_step.as_slice(), scalars_step.as_slice());
        }
        result
    }
}

// Compute msm using windowed non-adjacent form
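// Each scalar is decomposed into signed base-2^c digits; a negative digit
// subtracts the base from its bucket instead of adding it, which is why this
// path is only taken when `NEGATION_IS_CHEAP` holds.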
fn msm_bigint_wnaf<V: VariableBaseMSM>(
    bases: &[V::MulBase],
    bigints: &[<V::ScalarField as PrimeField>::BigInt],
) -> V {
    let size = ark_std::cmp::min(bases.len(), bigints.len());
    let scalars = &bigints[..size];
    let bases = &bases[..size];

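    // Window size in bits: a fixed width of 3 for small inputs, otherwise
    // roughly ln(size) + 2, the usual Pippenger bucket-width heuristic.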
    let c = if size < 32 {
        3
    } else {
        super::ln_without_floats(size) + 2
    };

    let num_bits = V::ScalarField::MODULUS_BIT_SIZE as usize;
    let digits_count = (num_bits + c - 1) / c;
    #[cfg(feature = "parallel")]
    let scalar_digits = scalars
        .into_par_iter()
        .flat_map_iter(|s| make_digits(s, c, num_bits))
        .collect::<Vec<_>>();
    #[cfg(not(feature = "parallel"))]
    let scalar_digits = scalars
        .iter()
        .flat_map(|s| make_digits(s, c, num_bits))
        .collect::<Vec<_>>();
    let zero = V::zero();
    let window_sums: Vec<_> = ark_std::cfg_into_iter!(0..digits_count)
        .map(|i| {
            let mut buckets = vec![zero; 1 << c];
            for (digits, base) in scalar_digits.chunks(digits_count).zip(bases) {
                use ark_std::cmp::Ordering;
                // `digits` holds the signed digits of this scalar; pick out the
                // digit for window `i`.
                let scalar = digits[i];
                match 0.cmp(&scalar) {
                    Ordering::Less => buckets[(scalar - 1) as usize] += base,
                    Ordering::Greater => buckets[(-scalar - 1) as usize] -= base,
                    Ordering::Equal => (),
                }
            }

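            // Summing the suffix sums of the buckets yields
            // sum_{k} (k + 1) * buckets[k], i.e. each base weighted by the
            // magnitude of its digit, using about 2 * 2^c group additions.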
            let mut running_sum = V::zero();
            let mut res = V::zero();
            buckets.into_iter().rev().for_each(|b| {
                running_sum += &b;
                res += &running_sum;
            });
            res
        })
        .collect();

    // We store the sum for the lowest window.
    let lowest = *window_sums.first().unwrap();

    // We're traversing windows from high to low.
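    // Combine the windows Horner-style: the accumulator is doubled `c` times
    // between windows, so the sum for window `w` ends up scaled by 2^{c*w}.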
    lowest
        + &window_sums[1..]
            .iter()
            .rev()
            .fold(zero, |mut total, sum_i| {
                total += sum_i;
                for _ in 0..c {
                    total.double_in_place();
                }
                total
            })
}

/// Optimized implementation of multi-scalar multiplication.
fn msm_bigint<V: VariableBaseMSM>(
    bases: &[V::MulBase],
    bigints: &[<V::ScalarField as PrimeField>::BigInt],
) -> V {
    let size = ark_std::cmp::min(bases.len(), bigints.len());
    let scalars = &bigints[..size];
    let bases = &bases[..size];
    let scalars_and_bases_iter = scalars.iter().zip(bases).filter(|(s, _)| !s.is_zero());

    let c = if size < 32 {
        3
    } else {
        super::ln_without_floats(size) + 2
    };

    let num_bits = V::ScalarField::MODULUS_BIT_SIZE as usize;
    let one = V::ScalarField::one().into_bigint();

    let zero = V::zero();
    let window_starts: Vec<_> = (0..num_bits).step_by(c).collect();

    // Each window is of size `c`.
    // We divide up the bits 0..num_bits into windows of size `c`, and
    // in parallel process each such window.
    let window_sums: Vec<_> = ark_std::cfg_into_iter!(window_starts)
        .map(|w_start| {
            let mut res = zero;
            // We don't need the "zero" bucket, so we only have 2^c - 1 buckets.
            let mut buckets = vec![zero; (1 << c) - 1];
            // This clone is cheap, because the iterator contains just a
            // pointer and an index into the original vectors.
            scalars_and_bases_iter.clone().for_each(|(&scalar, base)| {
                if scalar == one {
                    // We only process unit scalars once in the first window.
                    if w_start == 0 {
                        res += base;
                    }
                } else {
                    let mut scalar = scalar;

                    // We right-shift by w_start, thus getting rid of the
                    // lower bits.
                    scalar >>= w_start as u32;

                    // We mod the remaining bits by 2^{window size}, thus taking `c` bits.
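                    // (`c` is far smaller than 64, so after the shift the
                    // window always fits in the lowest limb.)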
                    let scalar = scalar.as_ref()[0] % (1 << c);

                    // If the scalar is non-zero, we update the corresponding
                    // bucket.
                    // (Recall that `buckets` doesn't have a zero bucket.)
                    if scalar != 0 {
                        buckets[(scalar - 1) as usize] += base;
                    }
                }
            });

            // Compute sum_{i in 0..num_buckets} (sum_{j in i..num_buckets} bucket[j])
            // This is computed below for b buckets, using 2b curve additions.
            //
            // We could first normalize `buckets` and then use mixed-addition
            // here, but that's slower for the kinds of groups we care about
            // (Short Weierstrass curves and Twisted Edwards curves).
            // In the case of Short Weierstrass curves,
            // mixed addition saves ~4 field multiplications per addition.
            // However normalization (with the inversion batched) takes ~6
            // field multiplications per element,
            // hence batch normalization is a slowdown.

            // `running_sum` = sum_{j in i..num_buckets} bucket[j],
            // where we iterate backward from i = num_buckets to 0.
            let mut running_sum = V::zero();
            buckets.into_iter().rev().for_each(|b| {
                running_sum += &b;
                res += &running_sum;
            });
            res
        })
        .collect();

    // We store the sum for the lowest window.
    let lowest = *window_sums.first().unwrap();

    // We're traversing windows from high to low.
    lowest
        + &window_sums[1..]
            .iter()
            .rev()
            .fold(zero, |mut total, sum_i| {
                total += sum_i;
                for _ in 0..c {
                    total.double_in_place();
                }
                total
            })
}

// From: https://github.com/arkworks-rs/gemini/blob/main/src/kzg/msm/variable_base.rs#L20
fn make_digits(a: &impl BigInteger, w: usize, num_bits: usize) -> impl Iterator<Item = i64> + '_ {
    let scalar = a.as_ref();
    let radix: u64 = 1 << w;
    let window_mask: u64 = radix - 1;

    let mut carry = 0u64;
    let num_bits = if num_bits == 0 {
        a.num_bits() as usize
    } else {
        num_bits
    };
    let digits_count = (num_bits + w - 1) / w;

    (0..digits_count).into_iter().map(move |i| {
        // Construct a buffer of bits of the scalar, starting at `bit_offset`.
        let bit_offset = i * w;
        let u64_idx = bit_offset / 64;
        let bit_idx = bit_offset % 64;
        // Read the bits from the scalar
        let bit_buf = if bit_idx < 64 - w || u64_idx == scalar.len() - 1 {
            // This window's bits are contained in a single u64,
            // or it's the last u64 anyway.
            scalar[u64_idx] >> bit_idx
        } else {
            // Combine the current u64's bits with the bits from the next u64
            (scalar[u64_idx] >> bit_idx) | (scalar[1 + u64_idx] << (64 - bit_idx))
        };

        // Read the actual coefficient value from the window
        let coef = carry + (bit_buf & window_mask); // coef in [0, 2^w] (window bits plus carry)

        // Recenter coefficients from [0,2^w) to [-2^w/2, 2^w/2)
        carry = (coef + radix / 2) >> w;
        let mut digit = (coef as i64) - (carry << w) as i64;
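        // Worked example (illustrative): with w = 3 the radix is 8, so coef = 6
        // recenters to digit = 6 - 8 = -2 with carry = (6 + 4) >> 3 = 1 pushed
        // into the next window.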

        if i == digits_count - 1 {
            digit += (carry << w) as i64;
        }
        digit
    })
}