bls12_381/
scalar.rs

1//! This module provides an implementation of the BLS12-381 scalar field $\mathbb{F}_q$
2//! where `q = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001`
3
4use core::fmt;
5use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
6use rand_core::RngCore;
7
8use ff::{Field, PrimeField};
9use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
10
11#[cfg(feature = "bits")]
12use ff::{FieldBits, PrimeFieldBits};
13
14use crate::util::{adc, mac, sbb};
15
/// Represents an element of the scalar field $\mathbb{F}_q$ of the BLS12-381 elliptic
/// curve construction.
// The internal representation of this type is four 64-bit unsigned
// integers in little-endian order. `Scalar` values are always in
// Montgomery form; i.e., Scalar(a) = aR mod q, with R = 2^256.
//
// `Eq` is derived, but `PartialEq` is implemented by hand on top of
// `ConstantTimeEq` rather than derived, so `==` goes through `ct_eq`.
#[derive(Clone, Copy, Eq)]
pub struct Scalar(pub(crate) [u64; 4]);
23
24impl fmt::Debug for Scalar {
25    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
26        let tmp = self.to_bytes();
27        write!(f, "0x")?;
28        for &b in tmp.iter().rev() {
29            write!(f, "{:02x}", b)?;
30        }
31        Ok(())
32    }
33}
34
35impl fmt::Display for Scalar {
36    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
37        write!(f, "{:?}", self)
38    }
39}
40
41impl From<u64> for Scalar {
42    fn from(val: u64) -> Scalar {
43        Scalar([val, 0, 0, 0]) * R2
44    }
45}
46
47impl ConstantTimeEq for Scalar {
48    fn ct_eq(&self, other: &Self) -> Choice {
49        self.0[0].ct_eq(&other.0[0])
50            & self.0[1].ct_eq(&other.0[1])
51            & self.0[2].ct_eq(&other.0[2])
52            & self.0[3].ct_eq(&other.0[3])
53    }
54}
55
56impl PartialEq for Scalar {
57    #[inline]
58    fn eq(&self, other: &Self) -> bool {
59        bool::from(self.ct_eq(other))
60    }
61}
62
63impl ConditionallySelectable for Scalar {
64    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
65        Scalar([
66            u64::conditional_select(&a.0[0], &b.0[0], choice),
67            u64::conditional_select(&a.0[1], &b.0[1], choice),
68            u64::conditional_select(&a.0[2], &b.0[2], choice),
69            u64::conditional_select(&a.0[3], &b.0[3], choice),
70        ])
71    }
72}
73
/// Constant representing the modulus
/// q = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001
///
/// Stored as four little-endian 64-bit limbs holding the plain integer `q`
/// (this constant is NOT in Montgomery form).
const MODULUS: Scalar = Scalar([
    0xffff_ffff_0000_0001,
    0x53bd_a402_fffe_5bfe,
    0x3339_d808_09a1_d805,
    0x73ed_a753_299d_7d48,
]);

/// The modulus as u32 limbs.
///
/// The same integer `q` split into eight little-endian 32-bit limbs; only
/// compiled when the `bits` feature is enabled on targets whose pointer
/// width is not 64 bits, where it backs `PrimeFieldBits::char_le_bits`.
#[cfg(all(feature = "bits", not(target_pointer_width = "64")))]
const MODULUS_LIMBS_32: [u32; 8] = [
    0x0000_0001,
    0xffff_ffff,
    0xfffe_5bfe,
    0x53bd_a402,
    0x09a1_d805,
    0x3339_d808,
    0x299d_7d48,
    0x73ed_a753,
];
95
// The number of bits needed to represent the modulus.
// The top limb of q is 0x73ed_..., so q lies in [2^254, 2^255).
const MODULUS_BITS: u32 = 255;

// GENERATOR = 7: a generator of the multiplicative group of order q - 1
// that is also a quadratic nonresidue.
// Like every `Scalar`, it is stored in Montgomery form (7 * R mod q).
const GENERATOR: Scalar = Scalar([
    0x0000_000e_ffff_fff1,
    0x17e3_63d3_0018_9c0f,
    0xff9c_5787_6f84_57b0,
    0x3513_3220_8fc5_a8c4,
]);
106
107impl<'a> Neg for &'a Scalar {
108    type Output = Scalar;
109
110    #[inline]
111    fn neg(self) -> Scalar {
112        self.neg()
113    }
114}
115
116impl Neg for Scalar {
117    type Output = Scalar;
118
119    #[inline]
120    fn neg(self) -> Scalar {
121        -&self
122    }
123}
124
125impl<'a, 'b> Sub<&'b Scalar> for &'a Scalar {
126    type Output = Scalar;
127
128    #[inline]
129    fn sub(self, rhs: &'b Scalar) -> Scalar {
130        self.sub(rhs)
131    }
132}
133
134impl<'a, 'b> Add<&'b Scalar> for &'a Scalar {
135    type Output = Scalar;
136
137    #[inline]
138    fn add(self, rhs: &'b Scalar) -> Scalar {
139        self.add(rhs)
140    }
141}
142
143impl<'a, 'b> Mul<&'b Scalar> for &'a Scalar {
144    type Output = Scalar;
145
146    #[inline]
147    fn mul(self, rhs: &'b Scalar) -> Scalar {
148        self.mul(rhs)
149    }
150}
151
// Operator boilerplate from crate-local macros (defined elsewhere). They
// presumably derive the by-value / mixed-reference `+`, `-`, `*` operators and
// their `*Assign` forms (used in this file as e.g. `tmp *= &R2`, `R3 - R`)
// from the `&Scalar op &Scalar` impls above. NOTE(review): confirm in util.rs.
impl_binops_additive!(Scalar, Scalar);
impl_binops_multiplicative!(Scalar, Scalar);
154
/// INV = -(q^{-1} mod 2^64) mod 2^64
///
/// Used by `montgomery_reduce` to pick, at each step, the multiple of the
/// modulus that zeroes the lowest remaining limb. Verified by `test_inv`.
const INV: u64 = 0xffff_fffe_ffff_ffff;
157
/// R = 2^256 mod q
///
/// Because scalars are stored as `aR mod q`, this is also the internal
/// representation of the field element 1 (see `Scalar::one`).
const R: Scalar = Scalar([
    0x0000_0001_ffff_fffe,
    0x5884_b7fa_0003_4802,
    0x998c_4fef_ecbc_4ff5,
    0x1824_b159_acc5_056f,
]);

/// R^2 = 2^512 mod q
///
/// A Montgomery multiplication by `R2` maps a plain integer `a` to its
/// Montgomery form: `(a * R^2) / R = aR`.
const R2: Scalar = Scalar([
    0xc999_e990_f3f2_9c6d,
    0x2b6c_edcb_8792_5c23,
    0x05d3_1496_7254_398f,
    0x0748_d9d9_9f59_ff11,
]);

/// R^3 = 2^768 mod q
///
/// Used by `from_u512` for the upper 256-bit half of a 512-bit input, which
/// carries an implicit factor of 2^256 = R.
const R3: Scalar = Scalar([
    0xc62c_1807_439b_73af,
    0x1b3e_0d18_8cf0_6990,
    0x73d1_3c71_c7b5_f418,
    0x6e2a_5bb9_c8db_33e9,
]);

/// 2^-1
///
/// The multiplicative inverse of 2, in Montgomery form (checked by
/// `test_constants`).
const TWO_INV: Scalar = Scalar([
    0x0000_0000_ffff_ffff,
    0xac42_5bfd_0001_a401,
    0xccc6_27f7_f65e_27fa,
    0x0c12_58ac_d662_82b7,
]);

// 2^S * t = MODULUS - 1 with t odd
// (the low limb of q - 1 is 0xffff_ffff_0000_0000: 32 trailing zero bits).
const S: u32 = 32;
192
/// GENERATOR^t where t * 2^s + 1 = q
/// with t odd. In other words, this
/// is a 2^s root of unity.
///
/// `GENERATOR = 7 mod q` is a generator
/// of the q - 1 order multiplicative
/// subgroup.
///
/// Stored in Montgomery form; `test_constants` checks that raising it to
/// the power 2^S yields one.
const ROOT_OF_UNITY: Scalar = Scalar([
    0xb9b5_8d8c_5f0e_466a,
    0x5b1b_4c80_1819_d7ec,
    0x0af5_3ae3_52a3_1e64,
    0x5bf3_adda_19e9_b27b,
]);

/// ROOT_OF_UNITY^-1
///
/// `test_constants` checks that `ROOT_OF_UNITY * ROOT_OF_UNITY_INV == 1`.
const ROOT_OF_UNITY_INV: Scalar = Scalar([
    0x4256_481a_dcf3_219a,
    0x45f3_7b7f_96b6_cad3,
    0xf9c3_f1d7_5f7a_3b27,
    0x2d2f_c049_658a_fd43,
]);

/// GENERATOR^{2^s} where t * 2^s + 1 = q with t odd.
/// In other words, this is a t root of unity.
///
/// `test_constants` checks that `DELTA^t == 1`.
const DELTA: Scalar = Scalar([
    0x70e3_10d3_d146_f96a,
    0x4b64_c089_19e2_99e6,
    0x51e1_1418_6a8b_970d,
    0x6185_d066_27c0_67cb,
]);
223
224impl Default for Scalar {
225    #[inline]
226    fn default() -> Self {
227        Self::zero()
228    }
229}
230
// With the `zeroize` feature, mark `Scalar` as a type whose `Default` value
// is all zeroes (`Default` here returns `Scalar::zero()`, i.e. `[0; 4]`),
// so the `zeroize` crate can clear it by overwriting with the default.
#[cfg(feature = "zeroize")]
impl zeroize::DefaultIsZeroes for Scalar {}
233
234impl Scalar {
235    /// Returns zero, the additive identity.
236    #[inline]
237    pub const fn zero() -> Scalar {
238        Scalar([0, 0, 0, 0])
239    }
240
241    /// Returns one, the multiplicative identity.
242    #[inline]
243    pub const fn one() -> Scalar {
244        R
245    }
246
247    /// Doubles this field element.
248    #[inline]
249    pub const fn double(&self) -> Scalar {
250        // TODO: This can be achieved more efficiently with a bitshift.
251        self.add(self)
252    }
253
254    /// Attempts to convert a little-endian byte representation of
255    /// a scalar into a `Scalar`, failing if the input is not canonical.
256    pub fn from_bytes(bytes: &[u8; 32]) -> CtOption<Scalar> {
257        let mut tmp = Scalar([0, 0, 0, 0]);
258
259        tmp.0[0] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[0..8]).unwrap());
260        tmp.0[1] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap());
261        tmp.0[2] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap());
262        tmp.0[3] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap());
263
264        // Try to subtract the modulus
265        let (_, borrow) = sbb(tmp.0[0], MODULUS.0[0], 0);
266        let (_, borrow) = sbb(tmp.0[1], MODULUS.0[1], borrow);
267        let (_, borrow) = sbb(tmp.0[2], MODULUS.0[2], borrow);
268        let (_, borrow) = sbb(tmp.0[3], MODULUS.0[3], borrow);
269
270        // If the element is smaller than MODULUS then the
271        // subtraction will underflow, producing a borrow value
272        // of 0xffff...ffff. Otherwise, it'll be zero.
273        let is_some = (borrow as u8) & 1;
274
275        // Convert to Montgomery form by computing
276        // (a.R^0 * R^2) / R = a.R
277        tmp *= &R2;
278
279        CtOption::new(tmp, Choice::from(is_some))
280    }
281
282    /// Converts an element of `Scalar` into a byte representation in
283    /// little-endian byte order.
284    pub fn to_bytes(&self) -> [u8; 32] {
285        // Turn into canonical form by computing
286        // (a.R) / R = a
287        let tmp = Scalar::montgomery_reduce(self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0);
288
289        let mut res = [0; 32];
290        res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes());
291        res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes());
292        res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes());
293        res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes());
294
295        res
296    }
297
298    /// Converts a 512-bit little endian integer into
299    /// a `Scalar` by reducing by the modulus.
300    pub fn from_bytes_wide(bytes: &[u8; 64]) -> Scalar {
301        Scalar::from_u512([
302            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[0..8]).unwrap()),
303            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap()),
304            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap()),
305            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap()),
306            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[32..40]).unwrap()),
307            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[40..48]).unwrap()),
308            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[48..56]).unwrap()),
309            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[56..64]).unwrap()),
310        ])
311    }
312
313    fn from_u512(limbs: [u64; 8]) -> Scalar {
314        // We reduce an arbitrary 512-bit number by decomposing it into two 256-bit digits
315        // with the higher bits multiplied by 2^256. Thus, we perform two reductions
316        //
317        // 1. the lower bits are multiplied by R^2, as normal
318        // 2. the upper bits are multiplied by R^2 * 2^256 = R^3
319        //
320        // and computing their sum in the field. It remains to see that arbitrary 256-bit
321        // numbers can be placed into Montgomery form safely using the reduction. The
322        // reduction works so long as the product is less than R=2^256 multiplied by
323        // the modulus. This holds because for any `c` smaller than the modulus, we have
324        // that (2^256 - 1)*c is an acceptable product for the reduction. Therefore, the
325        // reduction always works so long as `c` is in the field; in this case it is either the
326        // constant `R2` or `R3`.
327        let d0 = Scalar([limbs[0], limbs[1], limbs[2], limbs[3]]);
328        let d1 = Scalar([limbs[4], limbs[5], limbs[6], limbs[7]]);
329        // Convert to Montgomery form
330        d0 * R2 + d1 * R3
331    }
332
333    /// Converts from an integer represented in little endian
334    /// into its (congruent) `Scalar` representation.
335    pub const fn from_raw(val: [u64; 4]) -> Self {
336        (&Scalar(val)).mul(&R2)
337    }
338
339    /// Squares this element.
340    #[inline]
341    pub const fn square(&self) -> Scalar {
342        let (r1, carry) = mac(0, self.0[0], self.0[1], 0);
343        let (r2, carry) = mac(0, self.0[0], self.0[2], carry);
344        let (r3, r4) = mac(0, self.0[0], self.0[3], carry);
345
346        let (r3, carry) = mac(r3, self.0[1], self.0[2], 0);
347        let (r4, r5) = mac(r4, self.0[1], self.0[3], carry);
348
349        let (r5, r6) = mac(r5, self.0[2], self.0[3], 0);
350
351        let r7 = r6 >> 63;
352        let r6 = (r6 << 1) | (r5 >> 63);
353        let r5 = (r5 << 1) | (r4 >> 63);
354        let r4 = (r4 << 1) | (r3 >> 63);
355        let r3 = (r3 << 1) | (r2 >> 63);
356        let r2 = (r2 << 1) | (r1 >> 63);
357        let r1 = r1 << 1;
358
359        let (r0, carry) = mac(0, self.0[0], self.0[0], 0);
360        let (r1, carry) = adc(0, r1, carry);
361        let (r2, carry) = mac(r2, self.0[1], self.0[1], carry);
362        let (r3, carry) = adc(0, r3, carry);
363        let (r4, carry) = mac(r4, self.0[2], self.0[2], carry);
364        let (r5, carry) = adc(0, r5, carry);
365        let (r6, carry) = mac(r6, self.0[3], self.0[3], carry);
366        let (r7, _) = adc(0, r7, carry);
367
368        Scalar::montgomery_reduce(r0, r1, r2, r3, r4, r5, r6, r7)
369    }
370
371    /// Exponentiates `self` by `by`, where `by` is a
372    /// little-endian order integer exponent.
373    pub fn pow(&self, by: &[u64; 4]) -> Self {
374        let mut res = Self::one();
375        for e in by.iter().rev() {
376            for i in (0..64).rev() {
377                res = res.square();
378                let mut tmp = res;
379                tmp *= self;
380                res.conditional_assign(&tmp, (((*e >> i) & 0x1) as u8).into());
381            }
382        }
383        res
384    }
385
386    /// Exponentiates `self` by `by`, where `by` is a
387    /// little-endian order integer exponent.
388    ///
389    /// **This operation is variable time with respect
390    /// to the exponent.** If the exponent is fixed,
391    /// this operation is effectively constant time.
392    pub fn pow_vartime(&self, by: &[u64; 4]) -> Self {
393        let mut res = Self::one();
394        for e in by.iter().rev() {
395            for i in (0..64).rev() {
396                res = res.square();
397
398                if ((*e >> i) & 1) == 1 {
399                    res.mul_assign(self);
400                }
401            }
402        }
403        res
404    }
405
406    /// Computes the multiplicative inverse of this element,
407    /// failing if the element is zero.
408    pub fn invert(&self) -> CtOption<Self> {
409        #[inline(always)]
410        fn square_assign_multi(n: &mut Scalar, num_times: usize) {
411            for _ in 0..num_times {
412                *n = n.square();
413            }
414        }
415        // found using https://github.com/kwantam/addchain
416        let mut t0 = self.square();
417        let mut t1 = t0 * self;
418        let mut t16 = t0.square();
419        let mut t6 = t16.square();
420        let mut t5 = t6 * t0;
421        t0 = t6 * t16;
422        let mut t12 = t5 * t16;
423        let mut t2 = t6.square();
424        let mut t7 = t5 * t6;
425        let mut t15 = t0 * t5;
426        let mut t17 = t12.square();
427        t1 *= t17;
428        let mut t3 = t7 * t2;
429        let t8 = t1 * t17;
430        let t4 = t8 * t2;
431        let t9 = t8 * t7;
432        t7 = t4 * t5;
433        let t11 = t4 * t17;
434        t5 = t9 * t17;
435        let t14 = t7 * t15;
436        let t13 = t11 * t12;
437        t12 = t11 * t17;
438        t15 *= &t12;
439        t16 *= &t15;
440        t3 *= &t16;
441        t17 *= &t3;
442        t0 *= &t17;
443        t6 *= &t0;
444        t2 *= &t6;
445        square_assign_multi(&mut t0, 8);
446        t0 *= &t17;
447        square_assign_multi(&mut t0, 9);
448        t0 *= &t16;
449        square_assign_multi(&mut t0, 9);
450        t0 *= &t15;
451        square_assign_multi(&mut t0, 9);
452        t0 *= &t15;
453        square_assign_multi(&mut t0, 7);
454        t0 *= &t14;
455        square_assign_multi(&mut t0, 7);
456        t0 *= &t13;
457        square_assign_multi(&mut t0, 10);
458        t0 *= &t12;
459        square_assign_multi(&mut t0, 9);
460        t0 *= &t11;
461        square_assign_multi(&mut t0, 8);
462        t0 *= &t8;
463        square_assign_multi(&mut t0, 8);
464        t0 *= self;
465        square_assign_multi(&mut t0, 14);
466        t0 *= &t9;
467        square_assign_multi(&mut t0, 10);
468        t0 *= &t8;
469        square_assign_multi(&mut t0, 15);
470        t0 *= &t7;
471        square_assign_multi(&mut t0, 10);
472        t0 *= &t6;
473        square_assign_multi(&mut t0, 8);
474        t0 *= &t5;
475        square_assign_multi(&mut t0, 16);
476        t0 *= &t3;
477        square_assign_multi(&mut t0, 8);
478        t0 *= &t2;
479        square_assign_multi(&mut t0, 7);
480        t0 *= &t4;
481        square_assign_multi(&mut t0, 9);
482        t0 *= &t2;
483        square_assign_multi(&mut t0, 8);
484        t0 *= &t3;
485        square_assign_multi(&mut t0, 8);
486        t0 *= &t2;
487        square_assign_multi(&mut t0, 8);
488        t0 *= &t2;
489        square_assign_multi(&mut t0, 8);
490        t0 *= &t2;
491        square_assign_multi(&mut t0, 8);
492        t0 *= &t3;
493        square_assign_multi(&mut t0, 8);
494        t0 *= &t2;
495        square_assign_multi(&mut t0, 8);
496        t0 *= &t2;
497        square_assign_multi(&mut t0, 5);
498        t0 *= &t1;
499        square_assign_multi(&mut t0, 5);
500        t0 *= &t1;
501
502        CtOption::new(t0, !self.ct_eq(&Self::zero()))
503    }
504
505    #[inline(always)]
506    const fn montgomery_reduce(
507        r0: u64,
508        r1: u64,
509        r2: u64,
510        r3: u64,
511        r4: u64,
512        r5: u64,
513        r6: u64,
514        r7: u64,
515    ) -> Self {
516        // The Montgomery reduction here is based on Algorithm 14.32 in
517        // Handbook of Applied Cryptography
518        // <http://cacr.uwaterloo.ca/hac/about/chap14.pdf>.
519
520        let k = r0.wrapping_mul(INV);
521        let (_, carry) = mac(r0, k, MODULUS.0[0], 0);
522        let (r1, carry) = mac(r1, k, MODULUS.0[1], carry);
523        let (r2, carry) = mac(r2, k, MODULUS.0[2], carry);
524        let (r3, carry) = mac(r3, k, MODULUS.0[3], carry);
525        let (r4, carry2) = adc(r4, 0, carry);
526
527        let k = r1.wrapping_mul(INV);
528        let (_, carry) = mac(r1, k, MODULUS.0[0], 0);
529        let (r2, carry) = mac(r2, k, MODULUS.0[1], carry);
530        let (r3, carry) = mac(r3, k, MODULUS.0[2], carry);
531        let (r4, carry) = mac(r4, k, MODULUS.0[3], carry);
532        let (r5, carry2) = adc(r5, carry2, carry);
533
534        let k = r2.wrapping_mul(INV);
535        let (_, carry) = mac(r2, k, MODULUS.0[0], 0);
536        let (r3, carry) = mac(r3, k, MODULUS.0[1], carry);
537        let (r4, carry) = mac(r4, k, MODULUS.0[2], carry);
538        let (r5, carry) = mac(r5, k, MODULUS.0[3], carry);
539        let (r6, carry2) = adc(r6, carry2, carry);
540
541        let k = r3.wrapping_mul(INV);
542        let (_, carry) = mac(r3, k, MODULUS.0[0], 0);
543        let (r4, carry) = mac(r4, k, MODULUS.0[1], carry);
544        let (r5, carry) = mac(r5, k, MODULUS.0[2], carry);
545        let (r6, carry) = mac(r6, k, MODULUS.0[3], carry);
546        let (r7, _) = adc(r7, carry2, carry);
547
548        // Result may be within MODULUS of the correct value
549        (&Scalar([r4, r5, r6, r7])).sub(&MODULUS)
550    }
551
552    /// Multiplies `rhs` by `self`, returning the result.
553    #[inline]
554    pub const fn mul(&self, rhs: &Self) -> Self {
555        // Schoolbook multiplication
556
557        let (r0, carry) = mac(0, self.0[0], rhs.0[0], 0);
558        let (r1, carry) = mac(0, self.0[0], rhs.0[1], carry);
559        let (r2, carry) = mac(0, self.0[0], rhs.0[2], carry);
560        let (r3, r4) = mac(0, self.0[0], rhs.0[3], carry);
561
562        let (r1, carry) = mac(r1, self.0[1], rhs.0[0], 0);
563        let (r2, carry) = mac(r2, self.0[1], rhs.0[1], carry);
564        let (r3, carry) = mac(r3, self.0[1], rhs.0[2], carry);
565        let (r4, r5) = mac(r4, self.0[1], rhs.0[3], carry);
566
567        let (r2, carry) = mac(r2, self.0[2], rhs.0[0], 0);
568        let (r3, carry) = mac(r3, self.0[2], rhs.0[1], carry);
569        let (r4, carry) = mac(r4, self.0[2], rhs.0[2], carry);
570        let (r5, r6) = mac(r5, self.0[2], rhs.0[3], carry);
571
572        let (r3, carry) = mac(r3, self.0[3], rhs.0[0], 0);
573        let (r4, carry) = mac(r4, self.0[3], rhs.0[1], carry);
574        let (r5, carry) = mac(r5, self.0[3], rhs.0[2], carry);
575        let (r6, r7) = mac(r6, self.0[3], rhs.0[3], carry);
576
577        Scalar::montgomery_reduce(r0, r1, r2, r3, r4, r5, r6, r7)
578    }
579
580    /// Subtracts `rhs` from `self`, returning the result.
581    #[inline]
582    pub const fn sub(&self, rhs: &Self) -> Self {
583        let (d0, borrow) = sbb(self.0[0], rhs.0[0], 0);
584        let (d1, borrow) = sbb(self.0[1], rhs.0[1], borrow);
585        let (d2, borrow) = sbb(self.0[2], rhs.0[2], borrow);
586        let (d3, borrow) = sbb(self.0[3], rhs.0[3], borrow);
587
588        // If underflow occurred on the final limb, borrow = 0xfff...fff, otherwise
589        // borrow = 0x000...000. Thus, we use it as a mask to conditionally add the modulus.
590        let (d0, carry) = adc(d0, MODULUS.0[0] & borrow, 0);
591        let (d1, carry) = adc(d1, MODULUS.0[1] & borrow, carry);
592        let (d2, carry) = adc(d2, MODULUS.0[2] & borrow, carry);
593        let (d3, _) = adc(d3, MODULUS.0[3] & borrow, carry);
594
595        Scalar([d0, d1, d2, d3])
596    }
597
598    /// Adds `rhs` to `self`, returning the result.
599    #[inline]
600    pub const fn add(&self, rhs: &Self) -> Self {
601        let (d0, carry) = adc(self.0[0], rhs.0[0], 0);
602        let (d1, carry) = adc(self.0[1], rhs.0[1], carry);
603        let (d2, carry) = adc(self.0[2], rhs.0[2], carry);
604        let (d3, _) = adc(self.0[3], rhs.0[3], carry);
605
606        // Attempt to subtract the modulus, to ensure the value
607        // is smaller than the modulus.
608        (&Scalar([d0, d1, d2, d3])).sub(&MODULUS)
609    }
610
611    /// Negates `self`.
612    #[inline]
613    pub const fn neg(&self) -> Self {
614        // Subtract `self` from `MODULUS` to negate. Ignore the final
615        // borrow because it cannot underflow; self is guaranteed to
616        // be in the field.
617        let (d0, borrow) = sbb(MODULUS.0[0], self.0[0], 0);
618        let (d1, borrow) = sbb(MODULUS.0[1], self.0[1], borrow);
619        let (d2, borrow) = sbb(MODULUS.0[2], self.0[2], borrow);
620        let (d3, _) = sbb(MODULUS.0[3], self.0[3], borrow);
621
622        // `tmp` could be `MODULUS` if `self` was zero. Create a mask that is
623        // zero if `self` was zero, and `u64::max_value()` if self was nonzero.
624        let mask = (((self.0[0] | self.0[1] | self.0[2] | self.0[3]) == 0) as u64).wrapping_sub(1);
625
626        Scalar([d0 & mask, d1 & mask, d2 & mask, d3 & mask])
627    }
628}
629
630impl From<Scalar> for [u8; 32] {
631    fn from(value: Scalar) -> [u8; 32] {
632        value.to_bytes()
633    }
634}
635
636impl<'a> From<&'a Scalar> for [u8; 32] {
637    fn from(value: &'a Scalar) -> [u8; 32] {
638        value.to_bytes()
639    }
640}
641
642impl Field for Scalar {
643    const ZERO: Self = Self::zero();
644    const ONE: Self = Self::one();
645
646    fn random(mut rng: impl RngCore) -> Self {
647        let mut buf = [0; 64];
648        rng.fill_bytes(&mut buf);
649        Self::from_bytes_wide(&buf)
650    }
651
652    #[must_use]
653    fn square(&self) -> Self {
654        self.square()
655    }
656
657    #[must_use]
658    fn double(&self) -> Self {
659        self.double()
660    }
661
662    fn invert(&self) -> CtOption<Self> {
663        self.invert()
664    }
665
666    fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
667        ff::helpers::sqrt_ratio_generic(num, div)
668    }
669
670    fn sqrt(&self) -> CtOption<Self> {
671        // (t - 1) // 2 = 6104339283789297388802252303364915521546564123189034618274734669823
672        ff::helpers::sqrt_tonelli_shanks(
673            self,
674            &[
675                0x7fff_2dff_7fff_ffff,
676                0x04d0_ec02_a9de_d201,
677                0x94ce_bea4_199c_ec04,
678                0x0000_0000_39f6_d3a9,
679            ],
680        )
681    }
682
683    fn is_zero_vartime(&self) -> bool {
684        self.0 == Self::zero().0
685    }
686}
687
688impl PrimeField for Scalar {
689    type Repr = [u8; 32];
690
691    fn from_repr(r: Self::Repr) -> CtOption<Self> {
692        Self::from_bytes(&r)
693    }
694
695    fn to_repr(&self) -> Self::Repr {
696        self.to_bytes()
697    }
698
699    fn is_odd(&self) -> Choice {
700        Choice::from(self.to_bytes()[0] & 1)
701    }
702
703    const MODULUS: &'static str =
704        "0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001";
705    const NUM_BITS: u32 = MODULUS_BITS;
706    const CAPACITY: u32 = Self::NUM_BITS - 1;
707    const TWO_INV: Self = TWO_INV;
708    const MULTIPLICATIVE_GENERATOR: Self = GENERATOR;
709    const S: u32 = S;
710    const ROOT_OF_UNITY: Self = ROOT_OF_UNITY;
711    const ROOT_OF_UNITY_INV: Self = ROOT_OF_UNITY_INV;
712    const DELTA: Self = DELTA;
713}
714
// Limb storage for `PrimeFieldBits::ReprBits`: eight u32 limbs on targets
// whose pointer width is not 64 bits, four u64 limbs on 64-bit targets.
// Either way the limbs are little-endian and cover all 256 bits.
#[cfg(all(feature = "bits", not(target_pointer_width = "64")))]
type ReprBits = [u32; 8];

#[cfg(all(feature = "bits", target_pointer_width = "64"))]
type ReprBits = [u64; 4];
720
#[cfg(feature = "bits")]
impl PrimeFieldBits for Scalar {
    type ReprBits = ReprBits;

    /// Returns the bits of the canonical value, little-endian, packed into
    /// limbs of the target's native width.
    fn to_le_bits(&self) -> FieldBits<Self::ReprBits> {
        let bytes = self.to_bytes();

        #[cfg(not(target_pointer_width = "64"))]
        let limbs = {
            let mut words = [0u32; 8];
            for (word, chunk) in words.iter_mut().zip(bytes.chunks_exact(4)) {
                *word = u32::from_le_bytes(chunk.try_into().unwrap());
            }
            words
        };

        #[cfg(target_pointer_width = "64")]
        let limbs = {
            let mut words = [0u64; 4];
            for (word, chunk) in words.iter_mut().zip(bytes.chunks_exact(8)) {
                *word = u64::from_le_bytes(chunk.try_into().unwrap());
            }
            words
        };

        FieldBits::new(limbs)
    }

    /// Returns the bits of the characteristic `q`, using the same limb width
    /// as `ReprBits`.
    fn char_le_bits() -> FieldBits<Self::ReprBits> {
        #[cfg(not(target_pointer_width = "64"))]
        let limbs = MODULUS_LIMBS_32;

        #[cfg(target_pointer_width = "64")]
        let limbs = MODULUS.0;

        FieldBits::new(limbs)
    }
}
761
762impl<T> core::iter::Sum<T> for Scalar
763where
764    T: core::borrow::Borrow<Scalar>,
765{
766    fn sum<I>(iter: I) -> Self
767    where
768        I: Iterator<Item = T>,
769    {
770        iter.fold(Self::zero(), |acc, item| acc + item.borrow())
771    }
772}
773
774impl<T> core::iter::Product<T> for Scalar
775where
776    T: core::borrow::Borrow<Scalar>,
777{
778    fn product<I>(iter: I) -> Self
779    where
780        I: Iterator<Item = T>,
781    {
782        iter.fold(Self::one(), |acc, item| acc * item.borrow())
783    }
784}
785
#[test]
fn test_constants() {
    assert_eq!(
        Scalar::MODULUS,
        "0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001",
    );

    // 2 * 2^{-1} == 1
    let two = Scalar::from(2);
    assert_eq!(two * Scalar::TWO_INV, Scalar::ONE);

    // The root of unity and its stated inverse multiply to one.
    assert_eq!(Scalar::ROOT_OF_UNITY * Scalar::ROOT_OF_UNITY_INV, Scalar::ONE);

    // ROOT_OF_UNITY^{2^s} mod m == 1
    let two_pow_s = [1u64 << Scalar::S, 0, 0, 0];
    assert_eq!(Scalar::ROOT_OF_UNITY.pow(&two_pow_s), Scalar::ONE);

    // DELTA^{t} mod m == 1, with t = (q - 1) / 2^s as little-endian limbs.
    let t = [
        0xfffe_5bfe_ffff_ffff,
        0x09a1_d805_53bd_a402,
        0x299d_7d48_3339_d808,
        0x0000_0000_73ed_a753,
    ];
    assert_eq!(Scalar::DELTA.pow(&t), Scalar::ONE);
}
817
#[test]
fn test_inv() {
    // INV must be -(q^{-1}) mod 2^64. The unit group modulo 2^64 has order
    // 2^63, so q^{-1} = q^(2^63 - 1); each fold step maps x to x^2 * q,
    // and after 63 steps the accumulator equals q^(2^63 - 1).
    let q_inv = (0..63).fold(1u64, |acc, _| acc.wrapping_mul(acc).wrapping_mul(MODULUS.0[0]));

    assert_eq!(q_inv.wrapping_neg(), INV);
}
832
// `Debug` output is the canonical (non-Montgomery) value in big-endian hex.
// Needs the `std` feature because it uses `format!`.
#[cfg(feature = "std")]
#[test]
fn test_debug() {
    assert_eq!(
        format!("{:?}", Scalar::zero()),
        "0x0000000000000000000000000000000000000000000000000000000000000000"
    );
    assert_eq!(
        format!("{:?}", Scalar::one()),
        "0x0000000000000000000000000000000000000000000000000000000000000001"
    );
    // R2 holds R^2 mod q internally, which is the Montgomery form of
    // R = 2^256 mod q, so R is what gets printed.
    assert_eq!(
        format!("{:?}", R2),
        "0x1824b159acc5056f998c4fefecbc4ff55884b7fa0003480200000001fffffffe"
    );
}
849
#[test]
fn test_equality() {
    // Reflexivity on a few distinct values.
    for x in [Scalar::zero(), Scalar::one(), R2].iter() {
        assert_eq!(x, x);
    }

    // Distinct values compare unequal.
    assert_ne!(Scalar::zero(), Scalar::one());
    assert_ne!(Scalar::one(), R2);
}
859
// `to_bytes` emits the canonical (non-Montgomery) value, little-endian.
#[test]
fn test_to_bytes() {
    assert_eq!(
        Scalar::zero().to_bytes(),
        [
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0
        ]
    );

    assert_eq!(
        Scalar::one().to_bytes(),
        [
            1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0
        ]
    );

    // R2 is the Montgomery form of R = 2^256 mod q, so these are R's bytes.
    assert_eq!(
        R2.to_bytes(),
        [
            254, 255, 255, 255, 1, 0, 0, 0, 2, 72, 3, 0, 250, 183, 132, 88, 245, 79, 188, 236, 239,
            79, 140, 153, 111, 5, 197, 172, 89, 177, 36, 24
        ]
    );

    // -1 encodes as q - 1.
    assert_eq!(
        (-&Scalar::one()).to_bytes(),
        [
            0, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
            216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115
        ]
    );
}
894
// `from_bytes` accepts exactly the little-endian encodings of 0..q-1 and
// rejects q and anything larger.
#[test]
fn test_from_bytes() {
    assert_eq!(
        Scalar::from_bytes(&[
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0
        ])
        .unwrap(),
        Scalar::zero()
    );

    assert_eq!(
        Scalar::from_bytes(&[
            1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0
        ])
        .unwrap(),
        Scalar::one()
    );

    // The bytes of R = 2^256 mod q decode to the scalar whose Montgomery
    // representation is R^2, i.e. the constant R2.
    assert_eq!(
        Scalar::from_bytes(&[
            254, 255, 255, 255, 1, 0, 0, 0, 2, 72, 3, 0, 250, 183, 132, 88, 245, 79, 188, 236, 239,
            79, 140, 153, 111, 5, 197, 172, 89, 177, 36, 24
        ])
        .unwrap(),
        R2
    );

    // -1 should work
    assert!(bool::from(
        Scalar::from_bytes(&[
            0, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
            216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115
        ])
        .is_some()
    ));

    // modulus is invalid
    assert!(bool::from(
        Scalar::from_bytes(&[
            1, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
            216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115
        ])
        .is_none()
    ));

    // Anything larger than the modulus is invalid
    // (q + 1, then q with a middle byte bumped, then q with the top byte bumped).
    assert!(bool::from(
        Scalar::from_bytes(&[
            2, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
            216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115
        ])
        .is_none()
    ));
    assert!(bool::from(
        Scalar::from_bytes(&[
            1, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
            216, 58, 51, 72, 125, 157, 41, 83, 167, 237, 115
        ])
        .is_none()
    ));
    assert!(bool::from(
        Scalar::from_bytes(&[
            1, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
            216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 116
        ])
        .is_none()
    ));
}
965
/// Reducing the modulus itself, zero-extended to 512 bits, must give zero.
#[test]
fn test_from_u512_zero() {
    let mut wide = [0u64; 8];
    wide[..4].copy_from_slice(&MODULUS.0);

    let reduced = Scalar::from_u512(wide);
    assert_eq!(Scalar::zero(), reduced);
}
982
/// The integer 1 decodes to the Montgomery constant `R`.
#[test]
fn test_from_u512_r() {
    let one_wide = [1u64, 0, 0, 0, 0, 0, 0, 0];
    let decoded = Scalar::from_u512(one_wide);
    assert_eq!(R, decoded);
}
987
/// The integer 2^256 (a single 1 in limb 4) decodes to `R2`.
#[test]
fn test_from_u512_r2() {
    let mut wide = [0u64; 8];
    wide[4] = 1;
    assert_eq!(R2, Scalar::from_u512(wide));
}
992
/// The largest 512-bit input (every limb all ones) must decode to `R3 - R`.
#[test]
fn test_from_u512_max() {
    let all_ones = [u64::MAX; 8];
    let expected = R3 - R;
    assert_eq!(expected, Scalar::from_u512(all_ones));
}
1001
/// The little-endian bytes of `R2` in the low half of a 64-byte input decode to `R2`.
#[test]
fn test_from_bytes_wide_r2() {
    // The upper 32 bytes stay zero; only the low half carries the value.
    let mut wide = [0u8; 64];
    wide[..32].copy_from_slice(&[
        254, 255, 255, 255, 1, 0, 0, 0, 2, 72, 3, 0, 250, 183, 132, 88, 245, 79, 188, 236, 239,
        79, 140, 153, 111, 5, 197, 172, 89, 177, 36, 24,
    ]);

    assert_eq!(R2, Scalar::from_bytes_wide(&wide));
}
1013
/// The encoding of q - 1, zero-extended to 64 bytes, decodes to -1.
#[test]
fn test_from_bytes_wide_negative_one() {
    // Only the low 32 bytes are non-zero.
    let mut wide = [0u8; 64];
    wide[..32].copy_from_slice(&[
        0, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
        216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115,
    ]);

    let minus_one = -&Scalar::one();
    assert_eq!(minus_one, Scalar::from_bytes_wide(&wide));
}
1025
/// A 64-byte input of all `0xff` bytes must decode to the pinned internal limbs below.
#[test]
fn test_from_bytes_wide_maximum() {
    let expected = Scalar([
        0xc62c_1805_439b_73b1,
        0xc2b9_551e_8ced_218e,
        0xda44_ec81_daf9_a422,
        0x5605_aa60_1c16_2e79,
    ]);
    let all_ff = [0xffu8; 64];

    assert_eq!(expected, Scalar::from_bytes_wide(&all_ff));
}
1038
/// Zero is unchanged by negation and by adding, subtracting, or multiplying it with itself.
#[test]
fn test_zero() {
    let zero = Scalar::zero();

    assert_eq!(zero, -&zero);
    assert_eq!(zero, zero + zero);
    assert_eq!(zero, zero - zero);
    assert_eq!(zero, zero * zero);
}
1046
/// Raw internal limbs equal to `q - 1`, the largest canonical limb value.
///
/// Adding the raw value `Scalar([1, 0, 0, 0])` to it wraps around to zero,
/// which makes it a convenient stress value for reduction in the arithmetic tests.
#[cfg(test)]
const LARGEST: Scalar = Scalar([
    0xffff_ffff_0000_0000,
    0x53bd_a402_fffe_5bfe,
    0x3339_d808_09a1_d805,
    0x73ed_a753_299d_7d48,
]);
1054
/// Addition must reduce correctly at the top of the range.
#[test]
fn test_addition() {
    // (q - 1) + (q - 1) = 2q - 2, which must reduce to q - 2.
    let q_minus_two = Scalar([
        0xffff_fffe_ffff_ffff,
        0x53bd_a402_fffe_5bfe,
        0x3339_d808_09a1_d805,
        0x73ed_a753_299d_7d48,
    ]);
    let mut doubled = LARGEST;
    doubled += &LARGEST;
    assert_eq!(doubled, q_minus_two);

    // (q - 1) + 1 wraps around to zero.
    let mut wrapped = LARGEST;
    wrapped += &Scalar([1, 0, 0, 0]);
    assert_eq!(wrapped, Scalar::zero());
}
1075
/// Negation maps q - 1 and raw 1 onto each other and keeps zero at zero.
#[test]
fn test_negation() {
    let raw_one = Scalar([1, 0, 0, 0]);

    assert_eq!(-&LARGEST, raw_one);

    // Negating zero must yield zero, not the modulus.
    assert_eq!(-&Scalar::zero(), Scalar::zero());

    assert_eq!(-&raw_one, LARGEST);
}
1087
/// Subtraction of a value from itself gives zero, and `0 - x` agrees with `q - x`.
#[test]
fn test_subtraction() {
    let mut self_diff = LARGEST;
    self_diff -= &LARGEST;
    assert_eq!(self_diff, Scalar::zero());

    let mut from_zero = Scalar::zero();
    from_zero -= &LARGEST;

    let mut from_modulus = MODULUS;
    from_modulus -= &LARGEST;

    assert_eq!(from_zero, from_modulus);
}
1103
/// Cross-checks Montgomery multiplication against an independent double-and-add
/// over the canonical bits of the operand, for 100 successive values.
#[test]
fn test_multiplication() {
    let mut cur = LARGEST;

    for _ in 0..100 {
        let mut product = cur;
        product *= &cur;

        // Reference: walk the canonical bytes from most to least significant,
        // doubling the accumulator per bit and adding `cur` on each set bit.
        let expected = cur
            .to_bytes()
            .iter()
            .rev()
            .flat_map(|byte| (0..8).rev().map(move |i| ((byte >> i) & 1u8) == 1u8))
            .fold(Scalar::zero(), |acc, bit| {
                let doubled = acc + acc;
                if bit {
                    doubled + cur
                } else {
                    doubled
                }
            });

        assert_eq!(product, expected);

        cur += &LARGEST;
    }
}
1132
/// Cross-checks `square` against an independent double-and-add computation of
/// `cur * cur` over the canonical bits of `cur`, for 100 successive values.
#[test]
fn test_squaring() {
    let mut cur = LARGEST;

    for _ in 0..100 {
        // Compute the square directly; no mutable scratch binding is needed.
        let squared = cur.square();

        // Reference: walk the canonical bytes from most to least significant,
        // doubling the accumulator per bit and adding `cur` on each set bit.
        let mut expected = Scalar::zero();
        for bit in cur
            .to_bytes()
            .iter()
            .rev()
            .flat_map(|byte| (0..8).rev().map(move |i| ((byte >> i) & 1u8) == 1u8))
        {
            let prev = expected;
            expected.add_assign(&prev);

            if bit {
                expected.add_assign(&cur);
            }
        }

        assert_eq!(squared, expected);

        cur.add_assign(&LARGEST);
    }
}
1161
/// Inversion: zero has no inverse, ±1 are self-inverse, and `x * x^-1 == 1`
/// for the multiples `R2, 2*R2, ..., 100*R2`.
#[test]
fn test_inversion() {
    let one = Scalar::one();
    let minus_one = -&one;

    assert!(bool::from(Scalar::zero().invert().is_none()));
    assert_eq!(one.invert().unwrap(), one);
    assert_eq!(minus_one.invert().unwrap(), minus_one);

    let mut x = R2;
    for _ in 0..100 {
        let mut product = x.invert().unwrap();
        product.mul_assign(&x);
        assert_eq!(product, one);

        x.add_assign(&R2);
    }
}
1179
/// `invert` must agree with raising to the power q - 2, for both the
/// constant-time `pow` and `pow_vartime`.
#[test]
fn test_invert_is_pow() {
    // Little-endian limbs of q - 2.
    let q_minus_2 = [
        0xffff_fffe_ffff_ffff,
        0x53bd_a402_fffe_5bfe,
        0x3339_d808_09a1_d805,
        0x73ed_a753_299d_7d48,
    ];

    let mut base = R;
    for _ in 0..100 {
        let via_invert = base.invert().unwrap();
        let via_pow_vartime = base.pow_vartime(&q_minus_2);
        let via_pow = base.pow(&q_minus_2);

        assert_eq!(via_invert, via_pow_vartime);
        assert_eq!(via_pow_vartime, via_pow);

        // Shift by R so the next round checks a different base.
        base = via_invert + R;
    }
}
1206
/// `sqrt` of zero is zero. Walking down 100 consecutive values from a fixed
/// start, every root that is found must square back to its input, and exactly
/// 49 of those values must have no root.
#[test]
fn test_sqrt() {
    // The original wrapped this assertion in a pointless scope block.
    assert_eq!(Scalar::zero().sqrt().unwrap(), Scalar::zero());

    let mut square = Scalar([
        0x46cd_85a5_f273_077e,
        0x1d30_c47d_d68f_c735,
        0x77f6_56f6_0bec_a0eb,
        0x494a_a01b_df32_468d,
    ]);

    let mut none_count = 0;

    for _ in 0..100 {
        let candidate = square.sqrt();
        if bool::from(candidate.is_none()) {
            none_count += 1;
        } else {
            // Unwrap once and reuse the root rather than unwrapping twice.
            let root = candidate.unwrap();
            assert_eq!(root * root, square);
        }
        square -= Scalar::one();
    }

    assert_eq!(49, none_count);
}
1234
/// `from_raw` treats its limbs as an integer reduced modulo q, in Montgomery form.
#[test]
fn test_from_raw() {
    // An all-ones input must land on the same element as these pre-reduced limbs.
    let reduced = Scalar::from_raw([
        0x0001_ffff_fffd,
        0x5884_b7fa_0003_4802,
        0x998c_4fef_ecbc_4ff5,
        0x1824_b159_acc5_056f,
    ]);
    let all_ones = Scalar::from_raw([u64::MAX; 4]);
    assert_eq!(reduced, all_ones);

    // The modulus reduces to zero.
    assert_eq!(Scalar::from_raw(MODULUS.0), Scalar::zero());

    // The integer 1 becomes the Montgomery constant R.
    assert_eq!(Scalar::from_raw([1, 0, 0, 0]), R);
}
1251
/// `double` must agree with adding a value to itself.
#[test]
fn test_double() {
    let x = Scalar::from_raw([
        0x1fff_3231_233f_fffd,
        0x4884_b7fa_0003_4802,
        0x998c_4fef_ecbc_4ff3,
        0x1824_b159_acc5_0562,
    ]);

    let sum = x + x;
    assert_eq!(x.double(), sum);
}
1263
/// Zeroizing a scalar must leave it equal to zero.
#[cfg(feature = "zeroize")]
#[test]
fn test_zeroize() {
    use zeroize::Zeroize;

    let mut secret = Scalar::from_raw([
        0x1fff_3231_233f_fffd,
        0x4884_b7fa_0003_4802,
        0x998c_4fef_ecbc_4ff3,
        0x1824_b159_acc5_0562,
    ]);

    secret.zeroize();

    let wiped = secret.is_zero();
    assert!(bool::from(wiped));
}