
// ark_ec/scalar_mul/variable_base/mod.rs

use ark_ff::prelude::*;
use ark_std::{
    borrow::Borrow,
    cfg_chunks, cfg_into_iter, cfg_iter,
    iterable::Iterable,
    ops::{AddAssign, SubAssign},
    vec,
    vec::Vec,
};

#[cfg(feature = "parallel")]
use rayon::prelude::*;

pub mod stream_pippenger;
pub use stream_pippenger::*;

use super::ScalarMul;

#[cfg(all(
    target_has_atomic = "8",
    target_has_atomic = "16",
    target_has_atomic = "32",
    target_has_atomic = "64",
    target_has_atomic = "ptr"
))]
type DefaultHasher = ahash::AHasher;

#[cfg(not(all(
    target_has_atomic = "8",
    target_has_atomic = "16",
    target_has_atomic = "32",
    target_has_atomic = "64",
    target_has_atomic = "ptr"
)))]
type DefaultHasher = fnv::FnvHasher;

pub trait VariableBaseMSM: ScalarMul + for<'a> AddAssign<&'a Self::Bucket> {
    type Bucket: Default
        + Copy
        + Clone
        + for<'a> AddAssign<&'a Self::Bucket>
        + for<'a> SubAssign<&'a Self::Bucket>
        + AddAssign<Self::MulBase>
        + SubAssign<Self::MulBase>
        + for<'a> AddAssign<&'a Self::MulBase>
        + for<'a> SubAssign<&'a Self::MulBase>
        + Send
        + Sync
        + Into<Self>;

    const ZERO_BUCKET: Self::Bucket;

    /// Computes an inner product between the [`PrimeField`] elements in `scalars`
    /// and the corresponding group elements in `bases`.
    ///
    /// If the slices have different lengths, the longer one is truncated, so the
    /// MSM is computed over the first `scalars.len().min(bases.len())` pairs.
    ///
    /// Reference: [`VariableBaseMSM::msm`]
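    ///
    /// A minimal sketch of the truncation behavior, assuming the
    /// `ark_bls12_381` curve crate (not a dependency of this module):
    ///
    /// ```ignore
    /// use ark_bls12_381::{Fr, G1Affine, G1Projective};
    /// use ark_ec::{AffineRepr, VariableBaseMSM};
    ///
    /// let g = G1Affine::generator();
    /// // Three bases but two scalars: only the first two pairs contribute.
    /// let r = G1Projective::msm_unchecked(&[g, g, g], &[Fr::from(1u64), Fr::from(2u64)]);
    /// assert_eq!(r, G1Projective::from(g) * Fr::from(3u64));
    /// ```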
    fn msm_unchecked(bases: &[Self::MulBase], scalars: &[Self::ScalarField]) -> Self {
        let bigints = cfg_into_iter!(scalars)
            .map(|s| s.into_bigint())
            .collect::<Vec<_>>();
        Self::msm_bigint(bases, bigints.as_slice())
    }

    /// Performs multi-scalar multiplication.
    ///
    /// # Warning
    ///
    /// This method checks that `bases` and `scalars` have the same length.
    /// If they are unequal, it returns an error containing
    /// the shortest length over which the MSM can be performed.
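    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the `ark_bls12_381` curve crate (not a
    /// dependency of this module):
    ///
    /// ```ignore
    /// use ark_bls12_381::{Fr, G1Affine, G1Projective};
    /// use ark_ec::{CurveGroup, VariableBaseMSM};
    /// use ark_std::UniformRand;
    ///
    /// let mut rng = ark_std::test_rng();
    /// let bases: Vec<G1Affine> = (0..4)
    ///     .map(|_| G1Projective::rand(&mut rng).into_affine())
    ///     .collect();
    /// let scalars: Vec<Fr> = (0..4).map(|_| Fr::rand(&mut rng)).collect();
    /// assert!(G1Projective::msm(&bases, &scalars).is_ok());
    /// // Mismatched lengths: the error carries the usable (shorter) length.
    /// assert_eq!(G1Projective::msm(&bases, &scalars[..3]), Err(3));
    /// ```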
    fn msm(bases: &[Self::MulBase], scalars: &[Self::ScalarField]) -> Result<Self, usize> {
        (bases.len() == scalars.len())
            .then(|| Self::msm_unchecked(bases, scalars))
            .ok_or_else(|| bases.len().min(scalars.len()))
    }

    /// Optimized implementation of multi-scalar multiplication.
    fn msm_bigint(
        bases: &[Self::MulBase],
        bigints: &[<Self::ScalarField as PrimeField>::BigInt],
    ) -> Self {
        msm_signed(bases, bigints)
    }

    /// Performs multi-scalar multiplication when the scalars are known to be boolean.
    /// The default implementation is faster than [`Self::msm_bigint`].
    fn msm_u1(bases: &[Self::MulBase], scalars: &[bool]) -> Self {
        msm_binary(bases, scalars)
    }

    /// Performs multi-scalar multiplication when the scalars are known to be `u8`-sized.
    /// The default implementation is faster than [`Self::msm_bigint`].
    fn msm_u8(bases: &[Self::MulBase], scalars: &[u8]) -> Self {
        msm_u8(bases, scalars)
    }

    /// Performs multi-scalar multiplication when the scalars are known to be `u16`-sized.
    /// The default implementation is faster than [`Self::msm_bigint`].
    fn msm_u16(bases: &[Self::MulBase], scalars: &[u16]) -> Self {
        msm_u16(bases, scalars)
    }

    /// Performs multi-scalar multiplication when the scalars are known to be `u32`-sized.
    /// The default implementation is faster than [`Self::msm_bigint`].
    fn msm_u32(bases: &[Self::MulBase], scalars: &[u32]) -> Self {
        msm_u32(bases, scalars)
    }

    /// Performs multi-scalar multiplication when the scalars are known to be `u64`-sized.
    /// The default implementation is faster than [`Self::msm_bigint`].
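    ///
    /// A minimal sketch, assuming the `ark_bls12_381` curve crate (not a
    /// dependency of this module):
    ///
    /// ```ignore
    /// use ark_bls12_381::{Fr, G1Affine, G1Projective};
    /// use ark_ec::{AffineRepr, VariableBaseMSM};
    ///
    /// let g = G1Affine::generator();
    /// // 2·g + 3·g = 5·g, without converting the scalars to field elements.
    /// let r = G1Projective::msm_u64(&[g, g], &[2, 3]);
    /// assert_eq!(r, G1Projective::from(g) * Fr::from(5u64));
    /// ```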
    fn msm_u64(bases: &[Self::MulBase], scalars: &[u64]) -> Self {
        msm_u64(bases, scalars)
    }

    /// Streaming multi-scalar multiplication algorithm, processing the input
    /// in chunks of a hard-coded size (`2^20` elements).
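    ///
    /// A minimal sketch, assuming an [`Iterable`] implementation for borrowed
    /// slices (as in `ark_std`'s `iterable` module); `base_vec` and
    /// `scalar_vec` are hypothetical `Vec`s of affine points and scalars:
    ///
    /// ```ignore
    /// use ark_bls12_381::{Fr, G1Affine, G1Projective};
    /// use ark_ec::VariableBaseMSM;
    ///
    /// let bases: &[G1Affine] = &base_vec;
    /// let scalars: &[Fr] = &scalar_vec;
    /// // The streams are consumed chunk by chunk, 2^20 elements at a time.
    /// let r = G1Projective::msm_chunks(&bases, &scalars);
    /// ```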
    fn msm_chunks<I, J>(bases_stream: &J, scalars_stream: &I) -> Self
    where
        I: Iterable + ?Sized,
        I::Item: Borrow<Self::ScalarField>,
        J: Iterable,
        J::Item: Borrow<Self::MulBase>,
    {
        assert!(scalars_stream.len() <= bases_stream.len());

        // Set up iterators over the two streams.
        let bases_init = bases_stream.iter();
        let mut scalars = scalars_stream.iter();

        // Align the streams: skip the excess bases so that the i-th base
        // pairs with the i-th scalar.
        // TODO: change `skip` to `advance_by` once rust-lang/rust#77404 is fixed.
        // See <https://github.com/rust-lang/rust/issues/77404>
        let mut bases = bases_init.skip(bases_stream.len() - scalars_stream.len());
        let step: usize = 1 << 20;
        let mut result = Self::zero();
        for _ in 0..scalars_stream.len().div_ceil(step) {
            let bases_step = (&mut bases)
                .take(step)
                .map(|b| *b.borrow())
                .collect::<Vec<_>>();
            let scalars_step = (&mut scalars)
                .take(step)
                .map(|s| s.borrow().into_bigint())
                .collect::<Vec<_>>();
            result += Self::msm_bigint(bases_step.as_slice(), scalars_step.as_slice());
        }
        result
    }
}

#[inline]
fn large_value_unzip<A: Send + Sync, B: Send + Sync>(
    grouped: &[PackedIndex],
    f: impl Fn(usize) -> (A, B) + Send + Sync,
) -> (Vec<A>, Vec<B>) {
    cfg_iter!(grouped)
        .map(|&i| f(i.index()))
        .unzip::<_, _, Vec<_>, Vec<_>>()
}

#[inline]
fn small_value_unzip<A: Send + Sync, B: Send + Sync>(
    grouped: &[PackedIndex],
    f: impl Fn(usize, u16) -> (A, B) + Send + Sync,
) -> (Vec<A>, Vec<B>) {
    cfg_iter!(grouped)
        .map(|&i| f(i.index(), i.value()))
        .unzip::<_, _, Vec<_>, Vec<_>>()
}

#[inline(always)]
fn sub<B: BigInteger>(m: &B, scalar: &B) -> u64 {
    let mut negated = *m;
    negated.sub_with_borrow(scalar);
    negated.as_ref()[0]
}

// The 16-bit value field: ones in bits 44..=59; the low 44 bits (index) and
// the top 4 bits (group tag) are zero.
const VALUE_MASK: u64 = (u16::MAX as u64) << 44;

#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ScalarSize {
    U1 = 0,
    NegU1 = 1,
    U8 = 2,
    NegU8 = 3,
    U16 = 4,
    NegU16 = 5,
    U32 = 6,
    NegU32 = 7,
    U64 = 8,
    NegU64 = 9,
    BigInt = 10,
}

impl ScalarSize {
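    /// Returns the length of the prefix of `v` whose group tag is at most
    /// `self`, assuming `v` is sorted by group tag.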
    #[inline]
    fn partition_point(self, v: &[PackedIndex]) -> usize {
        v.partition_point(|i| i.group() < self as u8 + 1)
    }
}

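/// Packs a scalar's index, its [`ScalarSize`] group, and (for small scalars)
/// its 16-bit value into a single `u64`: bits 0..44 hold the index,
/// bits 44..60 the value, and bits 60..64 the group tag.
///
/// A round-trip sketch (illustrative):
///
/// ```ignore
/// let p = PackedIndex::new(5, ScalarSize::U16, 300);
/// assert_eq!((p.index(), p.group(), p.value()), (5, 4, 300));
/// ```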
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(transparent)]
pub struct PackedIndex(pub u64);

impl PackedIndex {
    #[inline(always)]
    fn new(index: usize, group: ScalarSize, value: u16) -> Self {
        // Pack the index, group, and value into a single u64.
        let index_bits = ((index as u64) << 20) >> 20; // keep the low 44 bits
        let group_bits = (group as u64) << 60;
        let value_bits = (value as u64) << 44;

        PackedIndex(index_bits | value_bits | group_bits)
    }

    /// Extracts the index from the packed value.
    #[inline(always)]
    fn index(self) -> usize {
        ((self.0 << 20) >> 20) as usize
    }

    /// Extracts the group from the packed value.
    #[inline(always)]
    fn group(self) -> u8 {
        (self.0 >> 60) as u8
    }

    /// Extracts the 16-bit value from the packed value.
    #[inline(always)]
    fn value(self) -> u16 {
        ((self.0 & VALUE_MASK) >> 44) as u16
    }
}

/// Computes a multi-scalar multiplication where each scalar may be negative,
/// zero, or positive: a scalar `s` whose negation `p - s` fits in few bits is
/// processed as that small negative value. Scalars are partitioned by
/// bit-width so that each class is handled by a specialized kernel; when
/// `V::NEGATION_IS_CHEAP` is `true`, the remaining big scalars use the
/// signed-digit (wNAF) path.
fn msm_signed<V: VariableBaseMSM>(
    bases: &[V::MulBase],
    scalars: &[<V::ScalarField as PrimeField>::BigInt],
) -> V {
    let size = bases.len().min(scalars.len());
    let bases = &bases[..size];
    let scalars = &scalars[..size];

    // Partition scalars according to their size.
    let mut grouped = cfg_iter!(scalars)
        .enumerate()
        .filter(|(_, scalar)| !scalar.is_zero())
        .map(|(i, scalar)| {
            use ScalarSize::*;
            let mut value = 0;
            let group = match scalar.num_bits() {
                0..=1 => U1,
                2..=8 => U8,
                9..=16 => U16,
                17..=32 => U32,
                33..=64 => U64,
                _ => {
                    let mut p_minus_scalar = V::ScalarField::MODULUS;
                    p_minus_scalar.sub_with_borrow(scalar);
                    let group = match p_minus_scalar.num_bits() {
                        0..=1 => NegU1,
                        2..=8 => NegU8,
                        9..=16 => NegU16,
                        17..=32 => NegU32,
                        33..=64 => NegU64,
                        _ => ScalarSize::BigInt,
                    };
                    if matches!(group, NegU1 | NegU8 | NegU16) {
                        value = p_minus_scalar.as_ref()[0] as u16
                    }
                    group
                },
            };
            if matches!(group, U1 | U8 | U16) {
                value = (scalar.as_ref()[0]) as u16;
            }
            PackedIndex::new(i, group, value)
        })
        .collect::<Vec<_>>();

    #[cfg(feature = "parallel")]
    grouped.par_sort_unstable_by_key(|i| i.group());
    #[cfg(not(feature = "parallel"))]
    grouped.sort_unstable_by_key(|i| i.group());

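    // `grouped` is now sorted by group tag, so each size class occupies a
    // contiguous run; `partition_point` returns the end of each run.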
    let (u1s, rest) = grouped.split_at(ScalarSize::U1.partition_point(&grouped));
    let (i1s, rest) = rest.split_at(ScalarSize::NegU1.partition_point(rest));
    let (u8s, rest) = rest.split_at(ScalarSize::U8.partition_point(rest));
    let (i8s, rest) = rest.split_at(ScalarSize::NegU8.partition_point(rest));
    let (u16s, rest) = rest.split_at(ScalarSize::U16.partition_point(rest));
    let (i16s, rest) = rest.split_at(ScalarSize::NegU16.partition_point(rest));
    let (u32s, rest) = rest.split_at(ScalarSize::U32.partition_point(rest));
    let (i32s, rest) = rest.split_at(ScalarSize::NegU32.partition_point(rest));
    let (u64s, rest) = rest.split_at(ScalarSize::U64.partition_point(rest));
    let (i64s, rest) = rest.split_at(ScalarSize::NegU64.partition_point(rest));
    let (bigints, _) = rest.split_at(ScalarSize::BigInt.partition_point(rest));

    let m = V::ScalarField::MODULUS;
    let mut add_result: V;
    let mut sub_result: V;

    // Handle the scalars in the range {-1, 0, 1}.
    let (ub, us) = small_value_unzip(u1s, |i, v| (bases[i], v == 1));
    let (ib, is) = small_value_unzip(i1s, |i, v| (bases[i], v == 1));
    add_result = msm_binary::<V>(&ub, &us);
    sub_result = msm_binary::<V>(&ib, &is);

    // Handle positive and negative u8 scalars.
    let (ub, us) = small_value_unzip(u8s, |i, v| (bases[i], v as u8));
    let (ib, is) = small_value_unzip(i8s, |i, v| (bases[i], v as u8));
    add_result += msm_u8::<V>(&ub, &us);
    sub_result += msm_u8::<V>(&ib, &is);

    // Handle positive and negative u16 scalars.
    let (ub, us) = small_value_unzip(u16s, |i, v| (bases[i], v));
    let (ib, is) = small_value_unzip(i16s, |i, v| (bases[i], v));
    add_result += msm_u16::<V>(&ub, &us);
    sub_result += msm_u16::<V>(&ib, &is);

    // Handle positive and negative u32 scalars.
    let (ub, us) = large_value_unzip(u32s, |i| (bases[i], scalars[i].as_ref()[0] as u32));
    let (ib, is) = large_value_unzip(i32s, |i| (bases[i], sub(&m, &scalars[i]) as u32));
    add_result += msm_u32::<V>(&ub, &us);
    sub_result += msm_u32::<V>(&ib, &is);

    // Handle positive and negative u64 scalars.
    let (ub, us) = large_value_unzip(u64s, |i| (bases[i], scalars[i].as_ref()[0]));
    let (ib, is) = large_value_unzip(i64s, |i| (bases[i], sub(&m, &scalars[i])));
    add_result += msm_u64::<V>(&ub, &us);
    sub_result += msm_u64::<V>(&ib, &is);

    // Handle the rest of the scalars.
    let (bf, sf) = large_value_unzip(bigints, |i| (bases[i], scalars[i]));
    if V::NEGATION_IS_CHEAP {
        add_result += msm_bigint_wnaf::<V>(&bf, &sf);
    } else {
        add_result += msm_bigint::<V>(&bf, &sf);
    }

    (add_result - sub_result).into()
}

fn preamble<A, B>(bases: &mut &[A], scalars: &mut &[B]) -> Option<usize> {
    let size = bases.len().min(scalars.len());
    if size == 0 {
        return None;
    }
    #[cfg(feature = "parallel")]
    let chunk_size = {
        let chunk_size = size / rayon::current_num_threads();
        if chunk_size == 0 {
            size
        } else {
            chunk_size
        }
    };
    #[cfg(not(feature = "parallel"))]
    let chunk_size = size;

    *bases = &bases[..size];
    *scalars = &scalars[..size];
    Some(chunk_size)
}

/// Computes multi-scalar multiplication where the scalars are boolean:
/// `true` contributes the corresponding base, `false` contributes nothing.
fn msm_binary<V: VariableBaseMSM>(mut bases: &[V::MulBase], mut scalars: &[bool]) -> V {
    let chunk_size = match preamble(&mut bases, &mut scalars) {
        Some(chunk_size) => chunk_size,
        None => return V::zero(),
    };

    // We only need to process the non-zero scalars.
    cfg_chunks!(bases, chunk_size)
        .zip(cfg_chunks!(scalars, chunk_size))
        .map(|(bases, scalars)| {
            let mut res = V::ZERO_BUCKET;
            for (base, _) in bases.iter().zip(scalars).filter(|(_, &s)| s) {
                res += base;
            }
            res.into()
        })
        .sum()
}

fn msm_u8<V: VariableBaseMSM>(mut bases: &[V::MulBase], mut scalars: &[u8]) -> V {
    let chunk_size = match preamble(&mut bases, &mut scalars) {
        Some(chunk_size) => chunk_size,
        None => return V::zero(),
    };
    cfg_chunks!(bases, chunk_size)
        .zip(cfg_chunks!(scalars, chunk_size))
        .map(|(bases, scalars)| msm_serial::<V>(bases, scalars))
        .sum()
}

fn msm_u16<V: VariableBaseMSM>(mut bases: &[V::MulBase], mut scalars: &[u16]) -> V {
    let chunk_size = match preamble(&mut bases, &mut scalars) {
        Some(chunk_size) => chunk_size,
        None => return V::zero(),
    };
    cfg_chunks!(bases, chunk_size)
        .zip(cfg_chunks!(scalars, chunk_size))
        .map(|(bases, scalars)| msm_serial::<V>(bases, scalars))
        .sum()
}

fn msm_u32<V: VariableBaseMSM>(mut bases: &[V::MulBase], mut scalars: &[u32]) -> V {
    let chunk_size = match preamble(&mut bases, &mut scalars) {
        Some(chunk_size) => chunk_size,
        None => return V::zero(),
    };
    cfg_chunks!(bases, chunk_size)
        .zip(cfg_chunks!(scalars, chunk_size))
        .map(|(bases, scalars)| msm_serial::<V>(bases, scalars))
        .sum()
}

fn msm_u64<V: VariableBaseMSM>(mut bases: &[V::MulBase], mut scalars: &[u64]) -> V {
    let chunk_size = match preamble(&mut bases, &mut scalars) {
        Some(chunk_size) => chunk_size,
        None => return V::zero(),
    };
    cfg_chunks!(bases, chunk_size)
        .zip(cfg_chunks!(scalars, chunk_size))
        .map(|(bases, scalars)| msm_serial::<V>(bases, scalars))
        .sum()
}

// Computes an MSM using a windowed signed-digit (wNAF-style) decomposition
// of the scalars.
fn msm_bigint_wnaf_parallel<V: VariableBaseMSM>(
    bases: &[V::MulBase],
    bigints: &[<V::ScalarField as PrimeField>::BigInt],
) -> V {
    let size = bases.len().min(bigints.len());
    let scalars = &bigints[..size];
    let bases = &bases[..size];

    let c = if size < 32 {
        3
    } else {
        super::ln_without_floats(size) + 2
    };

    let num_bits = V::ScalarField::MODULUS_BIT_SIZE as usize;
    let digits_count = num_bits.div_ceil(c);
    #[cfg(feature = "parallel")]
    let scalar_digits = scalars
        .into_par_iter()
        .flat_map_iter(|s| make_digits(s, c, num_bits))
        .collect::<Vec<_>>();
    #[cfg(not(feature = "parallel"))]
    let scalar_digits = scalars
        .iter()
        .flat_map(|s| make_digits(s, c, num_bits))
        .collect::<Vec<_>>();
    let zero = V::ZERO_BUCKET;
    let window_sums: Vec<_> = ark_std::cfg_into_iter!(0..digits_count)
        .map(|i| {
            let mut buckets = vec![zero; 1 << c];
            for (digits, base) in scalar_digits.chunks(digits_count).zip(bases) {
                use ark_std::cmp::Ordering;
                // A positive digit adds the base to its bucket; a negative
                // digit subtracts it, folding the cheap negation into the
                // bucket accumulation.
                let scalar = digits[i];
                match 0.cmp(&scalar) {
                    Ordering::Less => buckets[(scalar - 1) as usize] += base,
                    Ordering::Greater => buckets[(-scalar - 1) as usize] -= base,
                    Ordering::Equal => (),
                }
            }

            // Accumulate sum_i i * bucket[i] with a running suffix sum.
            let mut running_sum = V::ZERO_BUCKET;
            let mut res = V::ZERO_BUCKET;
            buckets.into_iter().rev().for_each(|b| {
                running_sum += &b;
                res += &running_sum;
            });
            res
        })
        .collect();

    // We store the sum for the lowest window.
    let lowest: V = (*window_sums.first().unwrap()).into();

    // We're traversing windows from high to low.
    lowest
        + window_sums[1..]
            .iter()
            .rev()
            .fold(V::zero(), |mut total, sum_i| {
                total += sum_i;
                for _ in 0..c {
                    total.double_in_place();
                }
                total
            })
}

#[cfg(feature = "parallel")]
const THREADS_PER_CHUNK: usize = 2;

/// Computes an MSM using the windowed non-adjacent form (wNAF) algorithm.
/// To improve parallelism, when the number of threads is at least 2, this
/// function splits the input into enough chunks that each chunk can be
/// processed by 2 threads.
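/// For example, with 8 rayon threads the input is split into 4 chunks, each
/// handled by its own 2-thread pool.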
fn msm_bigint_wnaf<V: VariableBaseMSM>(
    mut bases: &[V::MulBase],
    mut scalars: &[<V::ScalarField as PrimeField>::BigInt],
) -> V {
    let size = bases.len().min(scalars.len());
    if size == 0 {
        return V::zero();
    }

    #[cfg(feature = "parallel")]
    let chunk_size = {
        let cur_num_threads = rayon::current_num_threads();
        let num_chunks = if cur_num_threads < THREADS_PER_CHUNK {
            1
        } else {
            cur_num_threads / THREADS_PER_CHUNK
        };
        let chunk_size = size / num_chunks;
        if chunk_size == 0 {
            size
        } else {
            chunk_size
        }
    };
    #[cfg(not(feature = "parallel"))]
    let chunk_size = size;

    bases = &bases[..size];
    scalars = &scalars[..size];

    cfg_chunks!(bases, chunk_size)
        .zip(cfg_chunks!(scalars, chunk_size))
        .map(|(bases, scalars)| {
            #[cfg(feature = "parallel")]
            let result = rayon::ThreadPoolBuilder::new()
                .num_threads(THREADS_PER_CHUNK.min(rayon::current_num_threads()))
                .build()
                .unwrap()
                .install(|| msm_bigint_wnaf_parallel::<V>(bases, scalars));

            #[cfg(not(feature = "parallel"))]
            let result = msm_bigint_wnaf_parallel::<V>(bases, scalars);

            result
        })
        .sum()
}

/// Optimized implementation of multi-scalar multiplication.
fn msm_bigint<V: VariableBaseMSM>(
    mut bases: &[V::MulBase],
    mut scalars: &[<V::ScalarField as PrimeField>::BigInt],
) -> V {
    if preamble(&mut bases, &mut scalars).is_none() {
        return V::zero();
    }
    let size = scalars.len();
    let scalars_and_bases_iter = scalars.iter().zip(bases).filter(|(s, _)| !s.is_zero());

    let c = if size < 32 {
        3
    } else {
        super::ln_without_floats(size) + 2
    };

    let one = V::ScalarField::one().into_bigint();
    let zero = V::ZERO_BUCKET;
    let num_bits = V::ScalarField::MODULUS_BIT_SIZE as usize;

    // Each window is of size `c`.
    // We divide up the bits 0..num_bits into windows of size `c`, and
    // in parallel process each such window.
    let window_sums: Vec<_> = ark_std::cfg_into_iter!(0..num_bits)
        .step_by(c)
        .map(|w_start| {
            let mut res = zero;
            // We don't need the "zero" bucket, so we only have 2^c - 1 buckets.
            let mut buckets = vec![zero; (1 << c) - 1];
            // This clone is cheap, because the iterator contains just a
            // pointer and an index into the original vectors.
            scalars_and_bases_iter.clone().for_each(|(&scalar, base)| {
                if scalar == one {
                    // We only process unit scalars once in the first window.
                    if w_start == 0 {
                        res += base;
                    }
                } else {
                    let mut scalar = scalar;

                    // We right-shift by w_start, thus getting rid of the
                    // lower bits.
                    scalar >>= w_start as u32;

                    // We mod the remaining bits by 2^{window size}, thus taking `c` bits.
                    let scalar = scalar.as_ref()[0] % (1 << c);

                    // If the scalar is non-zero, we update the corresponding
                    // bucket.
                    // (Recall that `buckets` doesn't have a zero bucket.)
                    if scalar != 0 {
                        buckets[(scalar - 1) as usize] += base;
                    }
                }
            });

            // Compute sum_{i in 0..num_buckets} (sum_{j in i..num_buckets} bucket[j])
            // This is computed below for b buckets, using 2b curve additions.
            //
            // We could first normalize `buckets` and then use mixed-addition
            // here, but that's slower for the kinds of groups we care about
            // (Short Weierstrass curves and Twisted Edwards curves).
            // In the case of Short Weierstrass curves,
            // mixed addition saves ~4 field multiplications per addition.
            // However normalization (with the inversion batched) takes ~6
            // field multiplications per element,
            // hence batch normalization is a slowdown.

            // `running_sum` = sum_{j in i..num_buckets} bucket[j],
            // where we iterate backward from i = num_buckets to 0.
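            // E.g. with buckets [b1, b2, b3] this loop adds
            // b3 + (b3 + b2) + (b3 + b2 + b1) = 1*b1 + 2*b2 + 3*b3 to `res`,
            // i.e. each bucket weighted by its window value.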
            let mut running_sum = V::ZERO_BUCKET;
            buckets.into_iter().rev().for_each(|b| {
                running_sum += &b;
                res += &running_sum;
            });
            res
        })
        .collect();

    // We store the sum for the lowest window.
    let lowest = window_sums.first().copied().map_or(V::ZERO, Into::into);

    // We're traversing windows from high to low.
    lowest
        + &window_sums[1..]
            .iter()
            .rev()
            .fold(V::zero(), |mut total, sum_i| {
                total += sum_i;
                for _ in 0..c {
                    total.double_in_place();
                }
                total
            })
}

fn msm_serial<V: VariableBaseMSM>(
    bases: &[V::MulBase],
    scalars: &[impl Into<u64> + Copy + Send + Sync],
) -> V {
    let c = if bases.len() < 32 {
        3
    } else {
        super::ln_without_floats(bases.len()) + 2
    };

    let zero = V::ZERO_BUCKET;

    // Each window is of size `c`.
    // We divide up the bits 0..64 into windows of size `c`, and
    // process each such window serially.
    let two_to_c = 1 << c;
    let window_sums: Vec<_> = (0..(core::mem::size_of::<u64>() * 8))
        .step_by(c)
        .map(|w_start| {
            let mut res = zero;
            // We don't need the "zero" bucket, so we only have 2^c - 1 buckets.
            let mut buckets = vec![zero; two_to_c - 1];
            // Only the non-zero scalars need to be processed.
            scalars
                .iter()
                .zip(bases)
                .filter_map(|(&s, b)| {
                    let s = s.into();
                    (s != 0).then_some((s, b))
                })
                .for_each(|(scalar, base)| {
                    if scalar == 1 {
                        // We only process unit scalars once in the first window.
                        if w_start == 0 {
                            res += base;
                        }
                    } else {
                        let mut scalar = scalar;

                        // We right-shift by w_start, thus getting rid of the
                        // lower bits.
                        scalar >>= w_start as u32;

                        // We mod the remaining bits by 2^{window size}, thus taking `c` bits.
                        scalar %= two_to_c as u64;

                        // If the scalar is non-zero, we update the corresponding
                        // bucket.
                        // (Recall that `buckets` doesn't have a zero bucket.)
                        if scalar != 0 {
                            buckets[(scalar - 1) as usize] += base;
                        }
                    }
                });

            // Compute sum_{i in 0..num_buckets} (sum_{j in i..num_buckets} bucket[j])
            // This is computed below for b buckets, using 2b curve additions.
            //
            // We could first normalize `buckets` and then use mixed-addition
            // here, but that's slower for the kinds of groups we care about
            // (Short Weierstrass curves and Twisted Edwards curves).
            // In the case of Short Weierstrass curves,
            // mixed addition saves ~4 field multiplications per addition.
            // However normalization (with the inversion batched) takes ~6
            // field multiplications per element,
            // hence batch normalization is a slowdown.

            // `running_sum` = sum_{j in i..num_buckets} bucket[j],
            // where we iterate backward from i = num_buckets to 0.
            let mut running_sum = V::ZERO_BUCKET;
            buckets.into_iter().rev().for_each(|b| {
                running_sum += &b;
                res += &running_sum;
            });
            res
        })
        .collect();

    // We store the sum for the lowest window.
    let lowest = window_sums.first().copied().map_or(V::ZERO, Into::into);

    // We're traversing windows from high to low.
    lowest
        + &window_sums[1..]
            .iter()
            .rev()
            .fold(V::zero(), |mut total, sum_i| {
                total += sum_i;
                for _ in 0..c {
                    total.double_in_place();
                }
                total
            })
}

// From: https://github.com/arkworks-rs/gemini/blob/main/src/kzg/msm/variable_base.rs#L20
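// Decomposes `a` into signed base-2^w digits in [-2^(w-1), 2^(w-1)), with
// the top digit allowed to be larger. For example, with w = 2, the value
// 7 = 0b111 recodes to digits [-1, 2]: -1 + 2 * 4 = 7.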
fn make_digits(a: &impl BigInteger, w: usize, num_bits: usize) -> impl Iterator<Item = i64> + '_ {
    let scalar = a.as_ref();
    let radix: u64 = 1 << w;
    let window_mask: u64 = radix - 1;

    let mut carry = 0u64;
    let num_bits = if num_bits == 0 {
        a.num_bits() as usize
    } else {
        num_bits
    };
    let digits_count = num_bits.div_ceil(w);

    (0..digits_count).map(move |i| {
        // Construct a buffer of bits of the scalar, starting at `bit_offset`.
        let bit_offset = i * w;
        let u64_idx = bit_offset / 64;
        let bit_idx = bit_offset % 64;
        // Read the bits from the scalar
        let bit_buf = if bit_idx < 64 - w || u64_idx == scalar.len() - 1 {
            // This window's bits are contained in a single u64,
            // or it's the last u64 anyway.
            scalar[u64_idx] >> bit_idx
        } else {
            // Combine the current u64's bits with the bits from the next u64
            (scalar[u64_idx] >> bit_idx) | (scalar[1 + u64_idx] << (64 - bit_idx))
        };

        // Read the actual coefficient value from the window
        let coef = carry + (bit_buf & window_mask); // coef in [0, 2^w]

        // Recenter coefficients from [0, 2^w) to [-2^(w-1), 2^(w-1))
        carry = (coef + radix / 2) >> w;
        let mut digit = (coef as i64) - (carry << w) as i64;

        if i == digits_count - 1 {
            // The top digit absorbs the final carry so that the signed
            // digits still sum back to the original scalar.
            digit += (carry << w) as i64;
        }
        digit
    })
}