Skip to main content

ark_ec/scalar_mul/
mod.rs

1pub mod glv;
2pub mod wnaf;
3
4pub mod variable_base;
5
6use crate::{AffineRepr, PrimeGroup};
7use ark_ff::{AdditiveGroup, BigInteger, PrimeField};
8use ark_std::{
9    cfg_iter, cfg_iter_mut,
10    ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign},
11    vec,
12    vec::*,
13};
14
15#[cfg(feature = "parallel")]
16use rayon::prelude::*;
17
18/// The result of this function is only approximately `ln(a)`
19/// [`Explanation of usage`]
20///
21/// [`Explanation of usage`]: https://github.com/scipr-lab/zexe/issues/79#issue-556220473
22const fn ln_without_floats(a: usize) -> usize {
23    // log2(a) * ln(2)
24    (ark_std::log2(a) * 69 / 100) as usize
25}
26
27/// Standard double-and-add method for multiplication by a scalar for generic additive groups.
28#[inline(always)]
29pub fn double_and_add<G: AdditiveGroup>(base: &G, scalar: impl AsRef<[u64]>) -> G {
30    let mut res = G::ZERO;
31    for b in ark_ff::BitIteratorBE::without_leading_zeros(scalar) {
32        res.double_in_place();
33        if b {
34            res += base
35        }
36    }
37    res
38}
39
40/// Standard double-and-add method for multiplication by a scalar for generic affine points.
41#[inline(always)]
42pub fn double_and_add_affine<P: AffineRepr>(base: &P, scalar: impl AsRef<[u64]>) -> P::Group {
43    let mut res = P::Group::ZERO;
44    for b in ark_ff::BitIteratorBE::without_leading_zeros(scalar) {
45        res.double_in_place();
46        if b {
47            res += base
48        }
49    }
50    res
51}
52
53pub trait ScalarMul:
54    PrimeGroup
55    + Add<Self::MulBase, Output = Self>
56    + AddAssign<Self::MulBase>
57    + for<'a> Add<&'a Self::MulBase, Output = Self>
58    + for<'a> AddAssign<&'a Self::MulBase>
59    + Sub<Self::MulBase, Output = Self>
60    + SubAssign<Self::MulBase>
61    + for<'a> Sub<&'a Self::MulBase, Output = Self>
62    + for<'a> SubAssign<&'a Self::MulBase>
63    + From<Self::MulBase>
64{
65    type MulBase: Send
66        + Sync
67        + Copy
68        + Eq
69        + core::hash::Hash
70        + Mul<Self::ScalarField, Output = Self>
71        + for<'a> Mul<&'a Self::ScalarField, Output = Self>
72        + Neg<Output = Self::MulBase>
73        + From<Self>;
74
75    const NEGATION_IS_CHEAP: bool;
76
77    fn batch_convert_to_mul_base(bases: &[Self]) -> Vec<Self::MulBase>;
78
79    /// Compute the vector v\[0\].G, v\[1\].G, ..., v\[n-1\].G, given:
80    /// - an element `g`
81    /// - a list `v` of n scalars
82    ///
83    /// # Example
84    /// ```
85    /// use ark_std::{One, UniformRand};
86    /// use ark_ec::pairing::Pairing;
87    /// use ark_test_curves::bls12_381::G1Projective as G;
88    /// use ark_test_curves::bls12_381::Fr;
89    /// use ark_ec::scalar_mul::ScalarMul;
90    ///
91    /// // Compute G, s.G, s^2.G, ..., s^9.G
92    /// let mut rng = ark_std::test_rng();
93    /// let max_degree = 10;
94    /// let s = Fr::rand(&mut rng);
95    /// let g = G::rand(&mut rng);
96    /// let mut powers_of_s = vec![Fr::one()];
97    /// let mut cur = s;
98    /// for _ in 0..max_degree {
99    ///     powers_of_s.push(cur);
100    ///     cur *= &s;
101    /// }
102    /// let powers_of_g = g.batch_mul(&powers_of_s);
103    /// let naive_powers_of_g: Vec<G> = powers_of_s.iter().map(|e| g * e).collect();
104    /// assert_eq!(powers_of_g, naive_powers_of_g);
105    /// ```
106    fn batch_mul(self, v: &[Self::ScalarField]) -> Vec<Self::MulBase> {
107        let table = BatchMulPreprocessing::new(self, v.len());
108        Self::batch_mul_with_preprocessing(&table, v)
109    }
110
111    /// Compute the vector v\[0\].G, v\[1\].G, ..., v\[n-1\].G, given:
112    /// - an element `g`
113    /// - a list `v` of n scalars
114    ///
115    /// This method allows the user to provide a precomputed table of multiples of `g`.
116    /// A more ergonomic way to call this would be to use [`BatchMulPreprocessing::batch_mul`].
117    ///
118    /// # Example
119    /// ```
120    /// use ark_std::{One, UniformRand};
121    /// use ark_ec::pairing::Pairing;
122    /// use ark_test_curves::bls12_381::G1Projective as G;
123    /// use ark_test_curves::bls12_381::Fr;
124    /// use ark_ec::scalar_mul::*;
125    ///
126    /// // Compute G, s.G, s^2.G, ..., s^9.G
127    /// let mut rng = ark_std::test_rng();
128    /// let max_degree = 10;
129    /// let s = Fr::rand(&mut rng);
130    /// let g = G::rand(&mut rng);
131    /// let mut powers_of_s = vec![Fr::one()];
132    /// let mut cur = s;
133    /// for _ in 0..max_degree {
134    ///     powers_of_s.push(cur);
135    ///     cur *= &s;
136    /// }
137    /// let table = BatchMulPreprocessing::new(g, powers_of_s.len());
138    /// let powers_of_g = G::batch_mul_with_preprocessing(&table, &powers_of_s);
139    /// let powers_of_g_2 = table.batch_mul(&powers_of_s);
140    /// let naive_powers_of_g: Vec<G> = powers_of_s.iter().map(|e| g * e).collect();
141    /// assert_eq!(powers_of_g, naive_powers_of_g);
142    /// assert_eq!(powers_of_g_2, naive_powers_of_g);
143    /// ```
144    fn batch_mul_with_preprocessing(
145        table: &BatchMulPreprocessing<Self>,
146        v: &[Self::ScalarField],
147    ) -> Vec<Self::MulBase> {
148        table.batch_mul(v)
149    }
150}
151
152/// Preprocessing used internally for batch scalar multiplication via [`ScalarMul::batch_mul`].
153/// - `window` is the window size used for the precomputation
154/// - `max_scalar_size` is the maximum size of the scalars that will be multiplied
155/// - `table` is the precomputed table of multiples of `base`
156pub struct BatchMulPreprocessing<T: ScalarMul> {
157    pub window: usize,
158    pub max_scalar_size: usize,
159    pub table: Vec<Vec<T::MulBase>>,
160}
161
162impl<T: ScalarMul> BatchMulPreprocessing<T> {
163    pub fn new(base: T, num_scalars: usize) -> Self {
164        let scalar_size = T::ScalarField::MODULUS_BIT_SIZE as usize;
165        Self::with_num_scalars_and_scalar_size(base, num_scalars, scalar_size)
166    }
167
168    pub fn with_num_scalars_and_scalar_size(
169        base: T,
170        num_scalars: usize,
171        max_scalar_size: usize,
172    ) -> Self {
173        let window = Self::compute_window_size(num_scalars);
174        let in_window = 1 << window;
175        let outerc = max_scalar_size.div_ceil(window);
176        let last_in_window = 1 << (max_scalar_size - (outerc - 1) * window);
177
178        let mut multiples_of_g = vec![vec![T::zero(); in_window]; outerc];
179
180        let mut g_outer = base;
181        let mut g_outers = Vec::with_capacity(outerc);
182        for _ in 0..outerc {
183            g_outers.push(g_outer);
184            for _ in 0..window {
185                g_outer.double_in_place();
186            }
187        }
188        cfg_iter_mut!(multiples_of_g)
189            .enumerate()
190            .take(outerc)
191            .zip(g_outers)
192            .for_each(|((outer, multiples_of_g), g_outer)| {
193                let cur_in_window = if outer == outerc - 1 {
194                    last_in_window
195                } else {
196                    in_window
197                };
198
199                let mut g_inner = T::zero();
200                for inner in multiples_of_g.iter_mut().take(cur_in_window) {
201                    *inner = g_inner;
202                    g_inner += &g_outer;
203                }
204            });
205        let table = cfg_iter!(multiples_of_g)
206            .map(|s| T::batch_convert_to_mul_base(s))
207            .collect();
208        Self {
209            window,
210            max_scalar_size,
211            table,
212        }
213    }
214
215    pub const fn compute_window_size(num_scalars: usize) -> usize {
216        if num_scalars < 32 {
217            3
218        } else {
219            ln_without_floats(num_scalars)
220        }
221    }
222
223    pub fn batch_mul(&self, v: &[T::ScalarField]) -> Vec<T::MulBase> {
224        let result: Vec<_> = cfg_iter!(v).map(|e| self.windowed_mul(e)).collect();
225        T::batch_convert_to_mul_base(&result)
226    }
227
228    fn windowed_mul(&self, scalar: &T::ScalarField) -> T {
229        let outerc = self.max_scalar_size.div_ceil(self.window);
230        let modulus_size = T::ScalarField::MODULUS_BIT_SIZE as usize;
231        let scalar_val = scalar.into_bigint().to_bits_le();
232
233        let mut res = T::from(self.table[0][0]);
234        for outer in 0..outerc {
235            let mut inner = 0usize;
236            for i in 0..self.window {
237                if outer * self.window + i < modulus_size && scalar_val[outer * self.window + i] {
238                    inner |= 1 << i;
239                }
240            }
241            res += &self.table[outer][inner];
242        }
243        res
244    }
245}