1use ark_ff::prelude::*;
2use ark_std::{
3 borrow::Borrow,
4 cfg_chunks, cfg_into_iter, cfg_iter,
5 iterable::Iterable,
6 ops::{AddAssign, SubAssign},
7 vec,
8 vec::Vec,
9};
10
11#[cfg(feature = "parallel")]
12use rayon::prelude::*;
13
14pub mod stream_pippenger;
15pub use stream_pippenger::*;
16
17use super::ScalarMul;
18
// Hasher selection: `ahash::AHasher` needs full atomics support on the
// target, so targets lacking any atomic width fall back to `fnv::FnvHasher`.
#[cfg(all(
    target_has_atomic = "8",
    target_has_atomic = "16",
    target_has_atomic = "32",
    target_has_atomic = "64",
    target_has_atomic = "ptr"
))]
type DefaultHasher = ahash::AHasher;

#[cfg(not(all(
    target_has_atomic = "8",
    target_has_atomic = "16",
    target_has_atomic = "32",
    target_has_atomic = "64",
    target_has_atomic = "ptr"
)))]
type DefaultHasher = fnv::FnvHasher;
36
/// Multi-scalar multiplication (MSM): computes `sum_i scalars[i] * bases[i]`
/// for slices of group elements and scalars.
pub trait VariableBaseMSM: ScalarMul + for<'a> AddAssign<&'a Self::Bucket> {
    /// Accumulator ("bucket") type used internally by the MSM routines.
    /// It supports in-place addition/subtraction of both buckets and bases
    /// and converts back to `Self` via `Into`.
    type Bucket: Default
        + Copy
        + Clone
        + for<'a> AddAssign<&'a Self::Bucket>
        + for<'a> SubAssign<&'a Self::Bucket>
        + AddAssign<Self::MulBase>
        + SubAssign<Self::MulBase>
        + for<'a> AddAssign<&'a Self::MulBase>
        + for<'a> SubAssign<&'a Self::MulBase>
        + Send
        + Sync
        + Into<Self>;

    /// The additive identity of [`Self::Bucket`].
    const ZERO_BUCKET: Self::Bucket;

    /// Computes an MSM without checking that `bases.len() == scalars.len()`;
    /// downstream routines truncate to the shorter of the two lengths.
    fn msm_unchecked(bases: &[Self::MulBase], scalars: &[Self::ScalarField]) -> Self {
        // Convert the field elements to big-integer form once, up front.
        let bigints = cfg_into_iter!(scalars)
            .map(|s| s.into_bigint())
            .collect::<Vec<_>>();
        Self::msm_bigint(bases, bigints.as_slice())
    }

    /// Length-checked MSM.
    ///
    /// # Errors
    /// Returns `Err` with the shorter of the two lengths when `bases` and
    /// `scalars` differ in length.
    fn msm(bases: &[Self::MulBase], scalars: &[Self::ScalarField]) -> Result<Self, usize> {
        (bases.len() == scalars.len())
            .then(|| Self::msm_unchecked(bases, scalars))
            .ok_or_else(|| bases.len().min(scalars.len()))
    }

    /// MSM over scalars already in big-integer form.
    fn msm_bigint(
        bases: &[Self::MulBase],
        bigints: &[<Self::ScalarField as PrimeField>::BigInt],
    ) -> Self {
        msm_signed(bases, bigints)
    }

    /// MSM specialized to boolean (0/1) scalars.
    fn msm_u1(bases: &[Self::MulBase], scalars: &[bool]) -> Self {
        msm_binary(bases, scalars)
    }

    /// MSM specialized to `u8` scalars.
    fn msm_u8(bases: &[Self::MulBase], scalars: &[u8]) -> Self {
        msm_u8(bases, scalars)
    }

    /// MSM specialized to `u16` scalars.
    fn msm_u16(bases: &[Self::MulBase], scalars: &[u16]) -> Self {
        msm_u16(bases, scalars)
    }

    /// MSM specialized to `u32` scalars.
    fn msm_u32(bases: &[Self::MulBase], scalars: &[u32]) -> Self {
        msm_u32(bases, scalars)
    }

    /// MSM specialized to `u64` scalars.
    fn msm_u64(bases: &[Self::MulBase], scalars: &[u64]) -> Self {
        msm_u64(bases, scalars)
    }

    /// Streaming MSM over iterables, processed in fixed-size batches so the
    /// whole input is never materialized at once. When `bases_stream` is
    /// longer than `scalars_stream`, the surplus *leading* bases are skipped,
    /// i.e. the scalars align with the trailing bases.
    ///
    /// # Panics
    /// Panics if `scalars_stream` is longer than `bases_stream`.
    fn msm_chunks<I, J>(bases_stream: &J, scalars_stream: &I) -> Self
    where
        I: Iterable + ?Sized,
        I::Item: Borrow<Self::ScalarField>,
        J: Iterable,
        J::Item: Borrow<Self::MulBase>,
    {
        assert!(scalars_stream.len() <= bases_stream.len());

        let bases_init = bases_stream.iter();
        let mut scalars = scalars_stream.iter();

        // Skip the surplus bases so both streams have the same length left.
        let mut bases = bases_init.skip(bases_stream.len() - scalars_stream.len());
        // Batch size: 2^20 elements per inner MSM call.
        let step: usize = 1 << 20;
        let mut result = Self::zero();
        for _ in 0..scalars_stream.len().div_ceil(step) {
            let bases_step = (&mut bases)
                .take(step)
                .map(|b| *b.borrow())
                .collect::<Vec<_>>();
            let scalars_step = (&mut scalars)
                .take(step)
                .map(|s| s.borrow().into_bigint())
                .collect::<Vec<_>>();
            result += Self::msm_bigint(bases_step.as_slice(), scalars_step.as_slice());
        }
        result
    }
}
152
153#[inline]
154fn large_value_unzip<A: Send + Sync, B: Send + Sync>(
155 grouped: &[PackedIndex],
156 f: impl Fn(usize) -> (A, B) + Send + Sync,
157) -> (Vec<A>, Vec<B>) {
158 cfg_iter!(grouped)
159 .map(|&i| f(i.index()))
160 .unzip::<_, _, Vec<_>, Vec<_>>()
161}
162
163#[inline]
164fn small_value_unzip<A: Send + Sync, B: Send + Sync>(
165 grouped: &[PackedIndex],
166 f: impl Fn(usize, u16) -> (A, B) + Send + Sync,
167) -> (Vec<A>, Vec<B>) {
168 cfg_iter!(grouped)
169 .map(|&i| f(i.index(), i.value()))
170 .unzip::<_, _, Vec<_>, Vec<_>>()
171}
172
173#[inline(always)]
174fn sub<B: BigInteger>(m: &B, scalar: &B) -> u64 {
175 let mut negated = *m;
176 negated.sub_with_borrow(scalar);
177 negated.as_ref()[0]
178}
179
/// Mask selecting the 16-bit packed value field (bits 44..60) of a [`PackedIndex`].
const VALUE_MASK: u64 = (u16::MAX as u64) << 44;

/// Size class of a scalar, ordered so that sorting [`PackedIndex`] entries by
/// group clusters each class into one contiguous run. The `Neg*` classes mark
/// scalars whose *negation* modulo the scalar field's characteristic is small.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ScalarSize {
    U1 = 0,
    NegU1 = 1,
    U8 = 2,
    NegU8 = 3,
    U16 = 4,
    NegU16 = 5,
    U32 = 6,
    NegU32 = 7,
    U64 = 8,
    NegU64 = 9,
    BigInt = 10,
}

impl ScalarSize {
    /// For `v` sorted by group, returns the number of leading entries whose
    /// group discriminant is at most `self`.
    #[inline]
    fn partition_point(self, v: &[PackedIndex]) -> usize {
        let bound = self as u8;
        v.partition_point(|entry| entry.group() <= bound)
    }
}

/// A scalar's position packed together with its [`ScalarSize`] class and, for
/// classes of at most 16 bits, the scalar's value itself.
///
/// Bit layout (most to least significant):
/// 4-bit group | 16-bit value | 44-bit index.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(transparent)]
pub struct PackedIndex(pub u64);

impl PackedIndex {
    /// Bit offset of the value field (and width of the index field).
    const VALUE_SHIFT: u32 = 44;
    /// Bit offset of the group field.
    const GROUP_SHIFT: u32 = 60;
    /// Mask for the low 44 index bits.
    const INDEX_MASK: u64 = (1u64 << Self::VALUE_SHIFT) - 1;

    /// Packs `index` (truncated to its low 44 bits), `group`, and `value`.
    #[inline(always)]
    fn new(index: usize, group: ScalarSize, value: u16) -> Self {
        let packed = (index as u64 & Self::INDEX_MASK)
            | ((value as u64) << Self::VALUE_SHIFT)
            | ((group as u64) << Self::GROUP_SHIFT);
        PackedIndex(packed)
    }

    /// The packed 44-bit index.
    #[inline(always)]
    fn index(self) -> usize {
        (self.0 & Self::INDEX_MASK) as usize
    }

    /// The packed size-class discriminant.
    #[inline(always)]
    fn group(self) -> u8 {
        (self.0 >> Self::GROUP_SHIFT) as u8
    }

    /// The packed 16-bit value (zero for classes wider than 16 bits).
    #[inline(always)]
    fn value(self) -> u16 {
        ((self.0 & VALUE_MASK) >> Self::VALUE_SHIFT) as u16
    }
}
237
/// MSM that classifies every scalar by bit-length — or by the bit-length of
/// its negation `p - scalar`, since `scalar ≡ -(p - scalar) (mod p)` — and
/// dispatches each class to a width-specialized kernel. Contributions from
/// the "negated" classes are accumulated separately and subtracted at the end.
fn msm_signed<V: VariableBaseMSM>(
    bases: &[V::MulBase],
    scalars: &[<V::ScalarField as PrimeField>::BigInt],
) -> V {
    // Truncate both inputs to their common length.
    let size = bases.len().min(scalars.len());
    let bases = &bases[..size];
    let scalars = &scalars[..size];

    // Tag each nonzero scalar with its size class; for classes of at most
    // 16 bits, the scalar's value itself is packed alongside its index.
    let mut grouped = cfg_iter!(scalars)
        .enumerate()
        .filter(|(_, scalar)| !scalar.is_zero())
        .map(|(i, scalar)| {
            use ScalarSize::*;
            let mut value = 0;
            let group = match scalar.num_bits() {
                0..=1 => U1,
                2..=8 => U8,
                9..=16 => U16,
                17..=32 => U32,
                33..=64 => U64,
                _ => {
                    // Wide scalar: check whether its negation mod p is small.
                    let mut p_minus_scalar = V::ScalarField::MODULUS;
                    p_minus_scalar.sub_with_borrow(scalar);
                    let group = match p_minus_scalar.num_bits() {
                        0..=1 => NegU1,
                        2..=8 => NegU8,
                        9..=16 => NegU16,
                        17..=32 => NegU32,
                        33..=64 => NegU64,
                        _ => ScalarSize::BigInt,
                    };
                    // For small negations, pack the negated value (fits in u16).
                    if matches!(group, NegU1 | NegU8 | NegU16) {
                        value = p_minus_scalar.as_ref()[0] as u16
                    }
                    group
                },
            };
            // For small positive scalars, pack the value (fits in u16).
            if matches!(group, U1 | U8 | U16) {
                value = (scalar.as_ref()[0]) as u16;
            };
            PackedIndex::new(i, group, value)
        })
        .collect::<Vec<_>>();

    // Sort by size class so that each class forms a contiguous run.
    #[cfg(feature = "parallel")]
    grouped.par_sort_unstable_by_key(|i| i.group());
    #[cfg(not(feature = "parallel"))]
    grouped.sort_unstable_by_key(|i| i.group());

    // Peel off each class's run, in ascending discriminant order.
    let (u1s, rest) = grouped.split_at(ScalarSize::U1.partition_point(&grouped));
    let (i1s, rest) = rest.split_at(ScalarSize::NegU1.partition_point(rest));
    let (u8s, rest) = rest.split_at(ScalarSize::U8.partition_point(rest));
    let (i8s, rest) = rest.split_at(ScalarSize::NegU8.partition_point(rest));
    let (u16s, rest) = rest.split_at(ScalarSize::U16.partition_point(rest));
    let (i16s, rest) = rest.split_at(ScalarSize::NegU16.partition_point(rest));
    let (u32s, rest) = rest.split_at(ScalarSize::U32.partition_point(rest));
    let (i32s, rest) = rest.split_at(ScalarSize::NegU32.partition_point(rest));
    let (u64s, rest) = rest.split_at(ScalarSize::U64.partition_point(rest));
    let (i64s, rest) = rest.split_at(ScalarSize::NegU64.partition_point(rest));
    let (bigints, _) = rest.split_at(ScalarSize::BigInt.partition_point(rest));

    let m = V::ScalarField::MODULUS;
    // `add_result` accumulates the positive classes, `sub_result` the negated
    // ones; the final answer is their difference.
    let mut add_result: V;
    let mut sub_result: V;

    // 0/1 scalars (value packed in the PackedIndex).
    let (ub, us) = small_value_unzip(&u1s, |i, v| (bases[i], v == 1));
    let (ib, is) = small_value_unzip(&i1s, |i, v| (bases[i], v == 1));
    add_result = msm_binary::<V>(&ub, &us);
    sub_result = msm_binary::<V>(&ib, &is);

    // Up-to-8-bit scalars (value packed in the PackedIndex).
    let (ub, us) = small_value_unzip(u8s, |i, v| (bases[i], v as u8));
    let (ib, is) = small_value_unzip(i8s, |i, v| (bases[i], v as u8));
    add_result += msm_u8::<V>(&ub, &us);
    sub_result += msm_u8::<V>(&ib, &is);

    // Up-to-16-bit scalars (value packed in the PackedIndex).
    let (ub, us) = small_value_unzip(u16s, |i, v| (bases[i], v as u16));
    let (ib, is) = small_value_unzip(i16s, |i, v| (bases[i], v as u16));
    add_result += msm_u16::<V>(&ub, &us);
    sub_result += msm_u16::<V>(&ib, &is);

    // Wider scalars: re-read the value from `scalars` (or recompute the
    // negation via `sub`), since it doesn't fit in the packed 16 bits.
    let (ub, us) = large_value_unzip(u32s, |i| (bases[i], scalars[i].as_ref()[0] as u32));
    let (ib, is) = large_value_unzip(i32s, |i| (bases[i], sub(&m, &scalars[i]) as u32));
    add_result += msm_u32::<V>(&ub, &us);
    sub_result += msm_u32::<V>(&ib, &is);

    let (ub, us) = large_value_unzip(u64s, |i| (bases[i], scalars[i].as_ref()[0]));
    let (ib, is) = large_value_unzip(i64s, |i| (bases[i], sub(&m, &scalars[i])));
    add_result += msm_u64::<V>(&ub, &us);
    sub_result += msm_u64::<V>(&ib, &is);

    // Full-width scalars go through the generic Pippenger routines.
    let (bf, sf) = large_value_unzip(&bigints, |i| (bases[i], scalars[i]));
    if V::NEGATION_IS_CHEAP {
        add_result += msm_bigint_wnaf::<V>(&bf, &sf);
    } else {
        add_result += msm_bigint::<V>(&bf, &sf);
    }

    (add_result - sub_result).into()
}
348
/// Truncates `bases` and `scalars` in place to their common length and
/// returns the per-thread chunk size for that length, or `None` when either
/// slice is empty.
fn preamble<A, B>(bases: &mut &[A], scalars: &mut &[B]) -> Option<usize> {
    let size = bases.len().min(scalars.len());
    if size == 0 {
        return None;
    }

    *bases = &bases[..size];
    *scalars = &scalars[..size];

    #[cfg(not(feature = "parallel"))]
    let chunk_size = size;
    #[cfg(feature = "parallel")]
    let chunk_size = match size / rayon::current_num_threads() {
        // Fewer elements than threads: one chunk covering everything.
        0 => size,
        per_thread => per_thread,
    };

    Some(chunk_size)
}
370
371fn msm_binary<V: VariableBaseMSM>(mut bases: &[V::MulBase], mut scalars: &[bool]) -> V {
374 let chunk_size = match preamble(&mut bases, &mut scalars) {
375 Some(chunk_size) => chunk_size,
376 None => return V::zero(),
377 };
378
379 cfg_chunks!(bases, chunk_size)
381 .zip(cfg_chunks!(scalars, chunk_size))
382 .map(|(bases, scalars)| {
383 let mut res = V::ZERO_BUCKET;
384 for (base, _) in bases.iter().zip(scalars).filter(|(_, &s)| s) {
385 res += base;
386 }
387 res.into()
388 })
389 .sum()
390}
391
392fn msm_u8<V: VariableBaseMSM>(mut bases: &[V::MulBase], mut scalars: &[u8]) -> V {
393 let chunk_size = match preamble(&mut bases, &mut scalars) {
394 Some(chunk_size) => chunk_size,
395 None => return V::zero(),
396 };
397 cfg_chunks!(bases, chunk_size)
398 .zip(cfg_chunks!(scalars, chunk_size))
399 .map(|(bases, scalars)| msm_serial::<V>(bases, scalars))
400 .sum()
401}
402
403fn msm_u16<V: VariableBaseMSM>(mut bases: &[V::MulBase], mut scalars: &[u16]) -> V {
404 let chunk_size = match preamble(&mut bases, &mut scalars) {
405 Some(chunk_size) => chunk_size,
406 None => return V::zero(),
407 };
408 cfg_chunks!(bases, chunk_size)
409 .zip(cfg_chunks!(scalars, chunk_size))
410 .map(|(bases, scalars)| msm_serial::<V>(bases, scalars))
411 .sum()
412}
413
414fn msm_u32<V: VariableBaseMSM>(mut bases: &[V::MulBase], mut scalars: &[u32]) -> V {
415 let chunk_size = match preamble(&mut bases, &mut scalars) {
416 Some(chunk_size) => chunk_size,
417 None => return V::zero(),
418 };
419 cfg_chunks!(bases, chunk_size)
420 .zip(cfg_chunks!(scalars, chunk_size))
421 .map(|(bases, scalars)| msm_serial::<V>(bases, scalars))
422 .sum()
423}
424
425fn msm_u64<V: VariableBaseMSM>(mut bases: &[V::MulBase], mut scalars: &[u64]) -> V {
426 let chunk_size = match preamble(&mut bases, &mut scalars) {
427 Some(chunk_size) => chunk_size,
428 None => return V::zero(),
429 };
430 cfg_chunks!(bases, chunk_size)
431 .zip(cfg_chunks!(scalars, chunk_size))
432 .map(|(bases, scalars)| msm_serial::<V>(bases, scalars))
433 .sum()
434}
435
/// Pippenger MSM using the signed-digit decomposition from [`make_digits`]:
/// negative digits *subtract* the base from the bucket instead of adding, so
/// the same bucket count covers twice the digit range. Windows are processed
/// independently (in parallel when the `parallel` feature is enabled).
fn msm_bigint_wnaf_parallel<V: VariableBaseMSM>(
    bases: &[V::MulBase],
    bigints: &[<V::ScalarField as PrimeField>::BigInt],
) -> V {
    // Truncate both inputs to their common length.
    let size = bases.len().min(bigints.len());
    let scalars = &bigints[..size];
    let bases = &bases[..size];

    // Window size in bits: ~ln(size) + 2 for large inputs.
    let c = if size < 32 {
        3
    } else {
        super::ln_without_floats(size) + 2
    };

    let num_bits = V::ScalarField::MODULUS_BIT_SIZE as usize;
    let digits_count = num_bits.div_ceil(c);
    // All digits of all scalars, scalar-major: the digits of scalar j occupy
    // scalar_digits[j * digits_count .. (j + 1) * digits_count].
    #[cfg(feature = "parallel")]
    let scalar_digits = scalars
        .into_par_iter()
        .flat_map_iter(|s| make_digits(s, c, num_bits))
        .collect::<Vec<_>>();
    #[cfg(not(feature = "parallel"))]
    let scalar_digits = scalars
        .iter()
        .flat_map(|s| make_digits(s, c, num_bits))
        .collect::<Vec<_>>();
    let zero = V::ZERO_BUCKET;
    // One bucket pass per window.
    let window_sums: Vec<_> = ark_std::cfg_into_iter!(0..digits_count)
        .map(|i| {
            let mut buckets = vec![zero; 1 << c];
            for (digits, base) in scalar_digits.chunks(digits_count).zip(bases) {
                use ark_std::cmp::Ordering;
                let scalar = digits[i];
                match 0.cmp(&scalar) {
                    // Digit k > 0 adds the base into bucket k - 1 …
                    Ordering::Less => buckets[(scalar - 1) as usize] += base,
                    // … digit k < 0 subtracts it from bucket |k| - 1.
                    Ordering::Greater => buckets[(-scalar - 1) as usize] -= base,
                    Ordering::Equal => (),
                }
            }

            // Bucket reduction: running suffix sum computes
            // sum_k (k + 1) * buckets[k] without scalar multiplications.
            let mut running_sum = V::ZERO_BUCKET;
            let mut res = V::ZERO_BUCKET;
            buckets.into_iter().rev().for_each(|b| {
                running_sum += &b;
                res += &running_sum;
            });
            res
        })
        .collect();

    // `digits_count >= 1` (MODULUS_BIT_SIZE > 0), so `first()` cannot fail.
    let lowest: V = (*window_sums.first().unwrap()).into();

    // Combine windows from highest to lowest, doubling `c` times per level.
    lowest
        + (&window_sums[1..])
            .iter()
            .rev()
            .fold(V::zero(), |mut total, sum_i| {
                total += sum_i;
                for _ in 0..c {
                    total.double_in_place();
                }
                total
            })
}
504
// Number of rayon worker threads dedicated to each chunk in
// `msm_bigint_wnaf`; the chunks themselves run in parallel on the outer pool.
#[cfg(feature = "parallel")]
const THREADS_PER_CHUNK: usize = 2;
507
/// Driver for the signed-digit Pippenger MSM ([`msm_bigint_wnaf_parallel`]).
///
/// With the `parallel` feature, the input is split into roughly one chunk per
/// `THREADS_PER_CHUNK` available threads, and each chunk is run inside its own
/// small thread pool — combining chunk-level parallelism with bounded
/// window-level parallelism inside each chunk.
fn msm_bigint_wnaf<V: VariableBaseMSM>(
    mut bases: &[V::MulBase],
    mut scalars: &[<V::ScalarField as PrimeField>::BigInt],
) -> V {
    let size = bases.len().min(scalars.len());
    if size == 0 {
        return V::zero();
    }

    #[cfg(feature = "parallel")]
    let chunk_size = {
        let cur_num_threads = rayon::current_num_threads();
        let num_chunks = if cur_num_threads < THREADS_PER_CHUNK {
            1
        } else {
            cur_num_threads / THREADS_PER_CHUNK
        };
        let chunk_size = size / num_chunks;
        // Guard against a zero chunk size when size < num_chunks.
        if chunk_size == 0 {
            size
        } else {
            chunk_size
        }
    };
    #[cfg(not(feature = "parallel"))]
    let chunk_size = size;

    // Truncate both inputs to their common length.
    bases = &bases[..size];
    scalars = &scalars[..size];

    cfg_chunks!(bases, chunk_size)
        .zip(cfg_chunks!(scalars, chunk_size))
        .map(|(bases, scalars)| {
            // Each chunk gets a dedicated pool so that the window-level
            // parallelism inside it is capped at THREADS_PER_CHUNK.
            #[cfg(feature = "parallel")]
            let result = rayon::ThreadPoolBuilder::new()
                .num_threads(THREADS_PER_CHUNK.min(rayon::current_num_threads()))
                .build()
                .unwrap()
                .install(|| msm_bigint_wnaf_parallel::<V>(bases, scalars));

            #[cfg(not(feature = "parallel"))]
            let result = msm_bigint_wnaf_parallel::<V>(bases, scalars);

            result
        })
        .sum()
}
559
/// Classic (unsigned-digit) Pippenger MSM over big-integer scalars; used by
/// [`msm_signed`] when bucket negation is not cheap for `V`.
fn msm_bigint<V: VariableBaseMSM>(
    mut bases: &[V::MulBase],
    mut scalars: &[<V::ScalarField as PrimeField>::BigInt],
) -> V {
    if preamble(&mut bases, &mut scalars).is_none() {
        return V::zero();
    }
    let size = scalars.len();
    // Zero scalars contribute nothing; this filtered iterator is re-cloned
    // for every window pass below.
    let scalars_and_bases_iter = scalars.iter().zip(bases).filter(|(s, _)| !s.is_zero());

    // Window size in bits: ~ln(size) + 2 for large inputs.
    let c = if size < 32 {
        3
    } else {
        super::ln_without_floats(size) + 2
    };

    let one = V::ScalarField::one().into_bigint();
    let zero = V::ZERO_BUCKET;
    let num_bits = V::ScalarField::MODULUS_BIT_SIZE as usize;

    // One bucket pass per c-bit window (parallel over windows when enabled).
    let window_sums: Vec<_> = ark_std::cfg_into_iter!(0..num_bits)
        .step_by(c)
        .map(|w_start| {
            let mut res = zero;
            // 2^c - 1 buckets: bucket k - 1 collects bases with window digit k.
            let mut buckets = vec![zero; (1 << c) - 1];
            scalars_and_bases_iter.clone().for_each(|(&scalar, base)| {
                if scalar == one {
                    // A scalar of 1 only contributes to the lowest window.
                    if w_start == 0 {
                        res += base;
                    }
                } else {
                    let mut scalar = scalar;

                    // Shift this window's digit down to the low limb.
                    scalar >>= w_start as u32;

                    // Keep only the low c bits: this window's digit.
                    let scalar = scalar.as_ref()[0] % (1 << c);

                    if scalar != 0 {
                        buckets[(scalar - 1) as usize] += base;
                    }
                }
            });

            // Bucket reduction: running suffix sum computes
            // sum_k k * buckets[k - 1] without scalar multiplications.
            let mut running_sum = V::ZERO_BUCKET;
            buckets.into_iter().rev().for_each(|b| {
                running_sum += &b;
                res += &running_sum;
            });
            res
        })
        .collect();

    // Empty window list (size 0 was already handled) degrades to V::ZERO.
    let lowest = window_sums.first().copied().map_or(V::ZERO, Into::into);

    // Combine windows from highest to lowest, doubling `c` times per level.
    lowest
        + &window_sums[1..]
            .iter()
            .rev()
            .fold(V::zero(), |mut total, sum_i| {
                total += sum_i;
                for _ in 0..c {
                    total.double_in_place();
                }
                total
            })
}
656
/// Serial Pippenger MSM over scalars convertible to `u64`, sweeping all 64
/// scalar bits in `c`-bit windows.
fn msm_serial<V: VariableBaseMSM>(
    bases: &[V::MulBase],
    scalars: &[impl Into<u64> + Copy + Send + Sync],
) -> V {
    // Window size in bits: ~ln(n) + 2 for large inputs.
    let c = if bases.len() < 32 {
        3
    } else {
        super::ln_without_floats(bases.len()) + 2
    };

    let zero = V::ZERO_BUCKET;

    let two_to_c = 1 << c;
    // One bucket pass per c-bit window over the 64 scalar bits.
    let window_sums: Vec<_> = (0..(core::mem::size_of::<u64>() * 8))
        .step_by(c)
        .map(|w_start| {
            let mut res = zero;
            // 2^c - 1 buckets: bucket k - 1 collects bases with window digit k.
            let mut buckets = vec![zero; two_to_c - 1];
            scalars
                .iter()
                .zip(bases)
                .filter_map(|(&s, b)| {
                    let s = s.into();
                    // Zero scalars contribute nothing.
                    (s != 0).then_some((s, b))
                })
                .for_each(|(scalar, base)| {
                    if scalar == 1 {
                        // A scalar of 1 only contributes to the lowest window.
                        if w_start == 0 {
                            res += base;
                        }
                    } else {
                        let mut scalar = scalar;

                        // Shift this window's digit down to the low bits.
                        scalar >>= w_start as u32;

                        // Keep only the low c bits: this window's digit.
                        scalar %= two_to_c as u64;

                        if scalar != 0 {
                            buckets[(scalar - 1) as usize] += base;
                        }
                    }
                });

            // Bucket reduction: running suffix sum computes
            // sum_k k * buckets[k - 1] without scalar multiplications.
            let mut running_sum = V::ZERO_BUCKET;
            buckets.into_iter().rev().for_each(|b| {
                running_sum += &b;
                res += &running_sum;
            });
            res
        })
        .collect();

    let lowest = window_sums.first().copied().map_or(V::ZERO, Into::into);

    // Combine windows from highest to lowest, doubling `c` times per level.
    lowest
        + &window_sums[1..]
            .iter()
            .rev()
            .fold(V::zero(), |mut total, sum_i| {
                total += sum_i;
                for _ in 0..c {
                    total.double_in_place();
                }
                total
            })
}
752
/// Decomposes `a` into `ceil(num_bits / w)` signed `w`-bit digits `d_i` with
/// `a = sum_i d_i * 2^(i*w)`: whenever a raw window value exceeds `2^(w-1)`,
/// a carry is pushed into the next digit and the negative remainder is
/// emitted, so each digit lies in `[-2^(w-1), 2^(w-1)]`. The final digit
/// absorbs any leftover carry.
///
/// Passing `num_bits == 0` uses `a.num_bits()` instead.
fn make_digits(a: &impl BigInteger, w: usize, num_bits: usize) -> impl Iterator<Item = i64> + '_ {
    let scalar = a.as_ref();
    let radix: u64 = 1 << w;
    let window_mask: u64 = radix - 1;

    let mut carry = 0u64;
    let num_bits = if num_bits == 0 {
        a.num_bits() as usize
    } else {
        num_bits
    };
    let digits_count = num_bits.div_ceil(w);

    (0..digits_count).map(move |i| {
        // Locate the i-th w-bit window within the little-endian u64 limbs.
        let bit_offset = i * w;
        let u64_idx = bit_offset / 64;
        let bit_idx = bit_offset % 64;
        // Read the window; it may straddle two adjacent limbs.
        let bit_buf = if bit_idx < 64 - w || u64_idx == scalar.len() - 1 {
            // Window fits inside the current limb (or we're on the last one).
            scalar[u64_idx] >> bit_idx
        } else {
            // Combine this limb's high bits with the next limb's low bits.
            (scalar[u64_idx] >> bit_idx) | (scalar[1 + u64_idx] << (64 - bit_idx))
        };

        // Fold in the incoming carry; if the result is above radix/2, carry 1
        // into the next window and emit the (negative) remainder.
        let coef = carry + (bit_buf & window_mask);
        carry = (coef + radix / 2) >> w;
        let mut digit = (coef as i64) - (carry << w) as i64;

        // The most significant digit keeps the carry — no further window.
        if i == digits_count - 1 {
            digit += (carry << w) as i64;
        }
        digit
    })
}
794}