ark_ec/scalar_mul/variable_base/
mod.rs1use ark_ff::prelude::*;
2use ark_std::{borrow::Borrow, cfg_into_iter, iterable::Iterable, vec::*};
3
4#[cfg(feature = "parallel")]
5use rayon::prelude::*;
6
7pub mod stream_pippenger;
8pub use stream_pippenger::*;
9
10use super::ScalarMul;
11
12#[cfg(all(
13 target_has_atomic = "8",
14 target_has_atomic = "16",
15 target_has_atomic = "32",
16 target_has_atomic = "64",
17 target_has_atomic = "ptr"
18))]
19type DefaultHasher = ahash::AHasher;
20
21#[cfg(not(all(
22 target_has_atomic = "8",
23 target_has_atomic = "16",
24 target_has_atomic = "32",
25 target_has_atomic = "64",
26 target_has_atomic = "ptr"
27)))]
28type DefaultHasher = fnv::FnvHasher;
29
30pub trait VariableBaseMSM: ScalarMul {
31 fn msm_unchecked(bases: &[Self::MulBase], scalars: &[Self::ScalarField]) -> Self {
39 let bigints = cfg_into_iter!(scalars)
40 .map(|s| s.into_bigint())
41 .collect::<Vec<_>>();
42 Self::msm_bigint(bases, &bigints)
43 }
44
45 fn msm(bases: &[Self::MulBase], scalars: &[Self::ScalarField]) -> Result<Self, usize> {
53 (bases.len() == scalars.len())
54 .then(|| Self::msm_unchecked(bases, scalars))
55 .ok_or(bases.len().min(scalars.len()))
56 }
57
58 fn msm_bigint(
60 bases: &[Self::MulBase],
61 bigints: &[<Self::ScalarField as PrimeField>::BigInt],
62 ) -> Self {
63 if Self::NEGATION_IS_CHEAP {
64 msm_bigint_wnaf(bases, bigints)
65 } else {
66 msm_bigint(bases, bigints)
67 }
68 }
69
70 fn msm_chunks<I: ?Sized, J>(bases_stream: &J, scalars_stream: &I) -> Self
73 where
74 I: Iterable,
75 I::Item: Borrow<Self::ScalarField>,
76 J: Iterable,
77 J::Item: Borrow<Self::MulBase>,
78 {
79 assert!(scalars_stream.len() <= bases_stream.len());
80
81 let bases_init = bases_stream.iter();
83 let mut scalars = scalars_stream.iter();
84
85 let mut bases = bases_init.skip(bases_stream.len() - scalars_stream.len());
89 let step: usize = 1 << 20;
90 let mut result = Self::zero();
91 for _ in 0..(scalars_stream.len() + step - 1) / step {
92 let bases_step = (&mut bases)
93 .take(step)
94 .map(|b| *b.borrow())
95 .collect::<Vec<_>>();
96 let scalars_step = (&mut scalars)
97 .take(step)
98 .map(|s| s.borrow().into_bigint())
99 .collect::<Vec<_>>();
100 result += Self::msm_bigint(bases_step.as_slice(), scalars_step.as_slice());
101 }
102 result
103 }
104}
105
106fn msm_bigint_wnaf<V: VariableBaseMSM>(
108 bases: &[V::MulBase],
109 bigints: &[<V::ScalarField as PrimeField>::BigInt],
110) -> V {
111 let size = ark_std::cmp::min(bases.len(), bigints.len());
112 let scalars = &bigints[..size];
113 let bases = &bases[..size];
114
115 let c = if size < 32 {
116 3
117 } else {
118 super::ln_without_floats(size) + 2
119 };
120
121 let num_bits = V::ScalarField::MODULUS_BIT_SIZE as usize;
122 let digits_count = (num_bits + c - 1) / c;
123 #[cfg(feature = "parallel")]
124 let scalar_digits = scalars
125 .into_par_iter()
126 .flat_map_iter(|s| make_digits(s, c, num_bits))
127 .collect::<Vec<_>>();
128 #[cfg(not(feature = "parallel"))]
129 let scalar_digits = scalars
130 .iter()
131 .flat_map(|s| make_digits(s, c, num_bits))
132 .collect::<Vec<_>>();
133 let zero = V::zero();
134 let window_sums: Vec<_> = ark_std::cfg_into_iter!(0..digits_count)
135 .map(|i| {
136 let mut buckets = vec![zero; 1 << c];
137 for (digits, base) in scalar_digits.chunks(digits_count).zip(bases) {
138 use ark_std::cmp::Ordering;
139 let scalar = digits[i];
141 match 0.cmp(&scalar) {
142 Ordering::Less => buckets[(scalar - 1) as usize] += base,
143 Ordering::Greater => buckets[(-scalar - 1) as usize] -= base,
144 Ordering::Equal => (),
145 }
146 }
147
148 let mut running_sum = V::zero();
149 let mut res = V::zero();
150 buckets.into_iter().rev().for_each(|b| {
151 running_sum += &b;
152 res += &running_sum;
153 });
154 res
155 })
156 .collect();
157
158 let lowest = *window_sums.first().unwrap();
160
161 lowest
163 + &window_sums[1..]
164 .iter()
165 .rev()
166 .fold(zero, |mut total, sum_i| {
167 total += sum_i;
168 for _ in 0..c {
169 total.double_in_place();
170 }
171 total
172 })
173}
174
175fn msm_bigint<V: VariableBaseMSM>(
177 bases: &[V::MulBase],
178 bigints: &[<V::ScalarField as PrimeField>::BigInt],
179) -> V {
180 let size = ark_std::cmp::min(bases.len(), bigints.len());
181 let scalars = &bigints[..size];
182 let bases = &bases[..size];
183 let scalars_and_bases_iter = scalars.iter().zip(bases).filter(|(s, _)| !s.is_zero());
184
185 let c = if size < 32 {
186 3
187 } else {
188 super::ln_without_floats(size) + 2
189 };
190
191 let num_bits = V::ScalarField::MODULUS_BIT_SIZE as usize;
192 let one = V::ScalarField::one().into_bigint();
193
194 let zero = V::zero();
195 let window_starts: Vec<_> = (0..num_bits).step_by(c).collect();
196
197 let window_sums: Vec<_> = ark_std::cfg_into_iter!(window_starts)
201 .map(|w_start| {
202 let mut res = zero;
203 let mut buckets = vec![zero; (1 << c) - 1];
205 scalars_and_bases_iter.clone().for_each(|(&scalar, base)| {
208 if scalar == one {
209 if w_start == 0 {
211 res += base;
212 }
213 } else {
214 let mut scalar = scalar;
215
216 scalar >>= w_start as u32;
219
220 let scalar = scalar.as_ref()[0] % (1 << c);
222
223 if scalar != 0 {
227 buckets[(scalar - 1) as usize] += base;
228 }
229 }
230 });
231
232 let mut running_sum = V::zero();
247 buckets.into_iter().rev().for_each(|b| {
248 running_sum += &b;
249 res += &running_sum;
250 });
251 res
252 })
253 .collect();
254
255 let lowest = *window_sums.first().unwrap();
257
258 lowest
260 + &window_sums[1..]
261 .iter()
262 .rev()
263 .fold(zero, |mut total, sum_i| {
264 total += sum_i;
265 for _ in 0..c {
266 total.double_in_place();
267 }
268 total
269 })
270}
271
272fn make_digits(a: &impl BigInteger, w: usize, num_bits: usize) -> impl Iterator<Item = i64> + '_ {
274 let scalar = a.as_ref();
275 let radix: u64 = 1 << w;
276 let window_mask: u64 = radix - 1;
277
278 let mut carry = 0u64;
279 let num_bits = if num_bits == 0 {
280 a.num_bits() as usize
281 } else {
282 num_bits
283 };
284 let digits_count = (num_bits + w - 1) / w;
285
286 (0..digits_count).into_iter().map(move |i| {
287 let bit_offset = i * w;
289 let u64_idx = bit_offset / 64;
290 let bit_idx = bit_offset % 64;
291 let bit_buf = if bit_idx < 64 - w || u64_idx == scalar.len() - 1 {
293 scalar[u64_idx] >> bit_idx
296 } else {
297 (scalar[u64_idx] >> bit_idx) | (scalar[1 + u64_idx] << (64 - bit_idx))
299 };
300
301 let coef = carry + (bit_buf & window_mask); carry = (coef + radix / 2) >> w;
306 let mut digit = (coef as i64) - (carry << w) as i64;
307
308 if i == digits_count - 1 {
309 digit += (carry << w) as i64;
310 }
311 digit
312 })
313}