ark_ff/fields/models/fp/
montgomery_backend.rs

1use super::{Fp, FpConfig};
2use crate::{
3    biginteger::arithmetic as fa, BigInt, BigInteger, PrimeField, SqrtPrecomputation, Zero,
4};
5use ark_ff_macros::unroll_for_loops;
6use ark_std::marker::PhantomData;
7
8/// A trait that specifies the constants and arithmetic procedures
9/// for Montgomery arithmetic over the prime field defined by `MODULUS`.
10///
11/// # Note
12/// Manual implementation of this trait is not recommended unless one wishes
13/// to specialize arithmetic methods. Instead, the
14/// [`MontConfig`][`ark_ff_macros::MontConfig`] derive macro should be used.
15pub trait MontConfig<const N: usize>: 'static + Sync + Send + Sized {
16    /// The modulus of the field.
17    const MODULUS: BigInt<N>;
18
19    /// Let `M` be the power of 2^64 nearest to `Self::MODULUS_BITS`. Then
20    /// `R = M % Self::MODULUS`.
21    const R: BigInt<N> = Self::MODULUS.montgomery_r();
22
23    /// R2 = R^2 % Self::MODULUS
24    const R2: BigInt<N> = Self::MODULUS.montgomery_r2();
25
26    /// INV = -MODULUS^{-1} mod 2^64
27    const INV: u64 = inv::<Self, N>();
28
29    /// A multiplicative generator of the field.
30    /// `Self::GENERATOR` is an element having multiplicative order
31    /// `Self::MODULUS - 1`.
32    const GENERATOR: Fp<MontBackend<Self, N>, N>;
33
34    /// Can we use the no-carry optimization for multiplication
35    /// outlined [here](https://hackmd.io/@gnark/modular_multiplication)?
36    ///
37    /// This optimization applies if
38    /// (a) `Self::MODULUS[N-1] < u64::MAX >> 1`, and
39    /// (b) the bits of the modulus are not all 1.
40    #[doc(hidden)]
41    const CAN_USE_NO_CARRY_MUL_OPT: bool = can_use_no_carry_mul_optimization::<Self, N>();
42
43    /// Can we use the no-carry optimization for squaring
44    /// outlined [here](https://hackmd.io/@gnark/modular_multiplication)?
45    ///
46    /// This optimization applies if
47    /// (a) `Self::MODULUS[N-1] < u64::MAX >> 2`, and
48    /// (b) the bits of the modulus are not all 1.
49    #[doc(hidden)]
50    const CAN_USE_NO_CARRY_SQUARE_OPT: bool = can_use_no_carry_mul_optimization::<Self, N>();
51
52    /// Does the modulus have a spare unused bit
53    ///
54    /// This condition applies if
55    /// (a) `Self::MODULUS[N-1] >> 63 == 0`
56    #[doc(hidden)]
57    const MODULUS_HAS_SPARE_BIT: bool = modulus_has_spare_bit::<Self, N>();
58
59    /// 2^s root of unity computed by GENERATOR^t
60    const TWO_ADIC_ROOT_OF_UNITY: Fp<MontBackend<Self, N>, N>;
61
62    /// An integer `b` such that there exists a multiplicative subgroup
63    /// of size `b^k` for some integer `k`.
64    const SMALL_SUBGROUP_BASE: Option<u32> = None;
65
66    /// The integer `k` such that there exists a multiplicative subgroup
67    /// of size `Self::SMALL_SUBGROUP_BASE^k`.
68    const SMALL_SUBGROUP_BASE_ADICITY: Option<u32> = None;
69
70    /// GENERATOR^((MODULUS-1) / (2^s *
71    /// SMALL_SUBGROUP_BASE^SMALL_SUBGROUP_BASE_ADICITY)).
72    /// Used for mixed-radix FFT.
73    const LARGE_SUBGROUP_ROOT_OF_UNITY: Option<Fp<MontBackend<Self, N>, N>> = None;
74
75    /// Precomputed material for use when computing square roots.
76    /// The default is to use the standard Tonelli-Shanks algorithm.
77    const SQRT_PRECOMP: Option<SqrtPrecomputation<Fp<MontBackend<Self, N>, N>>> =
78        sqrt_precomputation::<N, Self>();
79
80    /// (MODULUS + 1) / 4 when MODULUS % 4 == 3. Used for square root precomputations.
81    #[doc(hidden)]
82    const MODULUS_PLUS_ONE_DIV_FOUR: Option<BigInt<N>> = {
83        match Self::MODULUS.mod_4() == 3 {
84            true => {
85                let (modulus_plus_one, carry) =
86                    Self::MODULUS.const_add_with_carry(&BigInt::<N>::one());
87                let mut result = modulus_plus_one.divide_by_2_round_down();
88                // Since modulus_plus_one is even, dividing by 2 results in a MSB of 0.
89                // Thus we can set MSB to `carry` to get the correct result of (MODULUS + 1) // 2:
90                result.0[N - 1] |= (carry as u64) << 63;
91                Some(result.divide_by_2_round_down())
92            },
93            false => None,
94        }
95    };
96
97    /// Sets `a = a + b`.
98    #[inline(always)]
99    fn add_assign(a: &mut Fp<MontBackend<Self, N>, N>, b: &Fp<MontBackend<Self, N>, N>) {
100        // This cannot exceed the backing capacity.
101        let c = a.0.add_with_carry(&b.0);
102        // However, it may need to be reduced
103        if Self::MODULUS_HAS_SPARE_BIT {
104            a.subtract_modulus()
105        } else {
106            a.subtract_modulus_with_carry(c)
107        }
108    }
109
110    /// Sets `a = a - b`.
111    #[inline(always)]
112    fn sub_assign(a: &mut Fp<MontBackend<Self, N>, N>, b: &Fp<MontBackend<Self, N>, N>) {
113        // If `other` is larger than `self`, add the modulus to self first.
114        if b.0 > a.0 {
115            a.0.add_with_carry(&Self::MODULUS);
116        }
117        a.0.sub_with_borrow(&b.0);
118    }
119
120    /// Sets `a = 2 * a`.
121    #[inline(always)]
122    fn double_in_place(a: &mut Fp<MontBackend<Self, N>, N>) {
123        // This cannot exceed the backing capacity.
124        let c = a.0.mul2();
125        // However, it may need to be reduced.
126        if Self::MODULUS_HAS_SPARE_BIT {
127            a.subtract_modulus()
128        } else {
129            a.subtract_modulus_with_carry(c)
130        }
131    }
132
133    /// Sets `a = -a`.
134    #[inline(always)]
135    fn neg_in_place(a: &mut Fp<MontBackend<Self, N>, N>) {
136        if !a.is_zero() {
137            let mut tmp = Self::MODULUS;
138            tmp.sub_with_borrow(&a.0);
139            a.0 = tmp;
140        }
141    }
142
143    /// This modular multiplication algorithm uses Montgomery
144    /// reduction for efficient implementation. It also additionally
145    /// uses the "no-carry optimization" outlined
146    /// [here](https://hackmd.io/@gnark/modular_multiplication) if
147    /// `Self::MODULUS` has (a) a non-zero MSB, and (b) at least one
148    /// zero bit in the rest of the modulus.
149    #[unroll_for_loops(12)]
150    #[inline(always)]
151    fn mul_assign(a: &mut Fp<MontBackend<Self, N>, N>, b: &Fp<MontBackend<Self, N>, N>) {
152        // No-carry optimisation applied to CIOS
153        if Self::CAN_USE_NO_CARRY_MUL_OPT {
154            if N <= 6
155                && N > 1
156                && cfg!(all(
157                    feature = "asm",
158                    target_feature = "bmi2",
159                    target_feature = "adx",
160                    target_arch = "x86_64"
161                ))
162            {
163                #[cfg(
164                    all(
165                        feature = "asm",
166                        target_feature = "bmi2",
167                        target_feature = "adx",
168                        target_arch = "x86_64"
169                    )
170                )]
171                #[allow(unsafe_code, unused_mut)]
172                #[rustfmt::skip]
173
174                // Tentatively avoid using assembly for `N == 1`.
175                match N {
176                    2 => { ark_ff_asm::x86_64_asm_mul!(2, (a.0).0, (b.0).0); },
177                    3 => { ark_ff_asm::x86_64_asm_mul!(3, (a.0).0, (b.0).0); },
178                    4 => { ark_ff_asm::x86_64_asm_mul!(4, (a.0).0, (b.0).0); },
179                    5 => { ark_ff_asm::x86_64_asm_mul!(5, (a.0).0, (b.0).0); },
180                    6 => { ark_ff_asm::x86_64_asm_mul!(6, (a.0).0, (b.0).0); },
181                    _ => unsafe { ark_std::hint::unreachable_unchecked() },
182                };
183            } else {
184                let mut r = [0u64; N];
185
186                for i in 0..N {
187                    let mut carry1 = 0u64;
188                    r[0] = fa::mac(r[0], (a.0).0[0], (b.0).0[i], &mut carry1);
189
190                    let k = r[0].wrapping_mul(Self::INV);
191
192                    let mut carry2 = 0u64;
193                    fa::mac_discard(r[0], k, Self::MODULUS.0[0], &mut carry2);
194
195                    for j in 1..N {
196                        r[j] = fa::mac_with_carry(r[j], (a.0).0[j], (b.0).0[i], &mut carry1);
197                        r[j - 1] = fa::mac_with_carry(r[j], k, Self::MODULUS.0[j], &mut carry2);
198                    }
199                    r[N - 1] = carry1 + carry2;
200                }
201                (a.0).0.copy_from_slice(&r);
202            }
203            a.subtract_modulus();
204        } else {
205            // Alternative implementation
206            // Implements CIOS.
207            let (carry, res) = a.mul_without_cond_subtract(b);
208            *a = res;
209
210            if Self::MODULUS_HAS_SPARE_BIT {
211                a.subtract_modulus_with_carry(carry);
212            } else {
213                a.subtract_modulus();
214            }
215        }
216    }
217
218    #[inline(always)]
219    #[unroll_for_loops(12)]
220    fn square_in_place(a: &mut Fp<MontBackend<Self, N>, N>) {
221        if N == 1 {
222            // We default to multiplying with `a` using the `Mul` impl
223            // for the N == 1 case
224            *a *= *a;
225            return;
226        }
227        if Self::CAN_USE_NO_CARRY_SQUARE_OPT
228            && (2..=6).contains(&N)
229            && cfg!(all(
230                feature = "asm",
231                target_feature = "bmi2",
232                target_feature = "adx",
233                target_arch = "x86_64"
234            ))
235        {
236            #[cfg(all(
237                feature = "asm",
238                target_feature = "bmi2",
239                target_feature = "adx",
240                target_arch = "x86_64"
241            ))]
242            #[allow(unsafe_code, unused_mut)]
243            #[rustfmt::skip]
244            match N {
245                2 => { ark_ff_asm::x86_64_asm_square!(2, (a.0).0); },
246                3 => { ark_ff_asm::x86_64_asm_square!(3, (a.0).0); },
247                4 => { ark_ff_asm::x86_64_asm_square!(4, (a.0).0); },
248                5 => { ark_ff_asm::x86_64_asm_square!(5, (a.0).0); },
249                6 => { ark_ff_asm::x86_64_asm_square!(6, (a.0).0); },
250                _ => unsafe { ark_std::hint::unreachable_unchecked() },
251            };
252            a.subtract_modulus();
253            return;
254        }
255
256        let mut r = crate::const_helpers::MulBuffer::<N>::zeroed();
257
258        let mut carry = 0;
259        for i in 0..(N - 1) {
260            for j in (i + 1)..N {
261                r[i + j] = fa::mac_with_carry(r[i + j], (a.0).0[i], (a.0).0[j], &mut carry);
262            }
263            r.b1[i] = carry;
264            carry = 0;
265        }
266
267        r.b1[N - 1] = r.b1[N - 2] >> 63;
268        for i in 2..(2 * N - 1) {
269            r[2 * N - i] = (r[2 * N - i] << 1) | (r[2 * N - (i + 1)] >> 63);
270        }
271        r.b0[1] <<= 1;
272
273        for i in 0..N {
274            r[2 * i] = fa::mac_with_carry(r[2 * i], (a.0).0[i], (a.0).0[i], &mut carry);
275            carry = fa::adc(&mut r[2 * i + 1], 0, carry);
276        }
277        // Montgomery reduction
278        let mut carry2 = 0;
279        for i in 0..N {
280            let k = r[i].wrapping_mul(Self::INV);
281            carry = 0;
282            fa::mac_discard(r[i], k, Self::MODULUS.0[0], &mut carry);
283            for j in 1..N {
284                r[j + i] = fa::mac_with_carry(r[j + i], k, Self::MODULUS.0[j], &mut carry);
285            }
286            carry2 = fa::adc(&mut r.b1[i], carry, carry2);
287        }
288        (a.0).0.copy_from_slice(&r.b1);
289        if Self::MODULUS_HAS_SPARE_BIT {
290            a.subtract_modulus();
291        } else {
292            a.subtract_modulus_with_carry(carry2 != 0);
293        }
294    }
295
296    fn inverse(a: &Fp<MontBackend<Self, N>, N>) -> Option<Fp<MontBackend<Self, N>, N>> {
297        if a.is_zero() {
298            return None;
299        }
300        // Guajardo Kumar Paar Pelzl
301        // Efficient Software-Implementation of Finite Fields with Applications to
302        // Cryptography
303        // Algorithm 16 (BEA for Inversion in Fp)
304
305        let one = BigInt::from(1u64);
306
307        let mut u = a.0;
308        let mut v = Self::MODULUS;
309        let mut b = Fp::new_unchecked(Self::R2); // Avoids unnecessary reduction step.
310        let mut c = Fp::zero();
311
312        while u != one && v != one {
313            while u.is_even() {
314                u.div2();
315
316                if b.0.is_even() {
317                    b.0.div2();
318                } else {
319                    let carry = b.0.add_with_carry(&Self::MODULUS);
320                    b.0.div2();
321                    if !Self::MODULUS_HAS_SPARE_BIT && carry {
322                        (b.0).0[N - 1] |= 1 << 63;
323                    }
324                }
325            }
326
327            while v.is_even() {
328                v.div2();
329
330                if c.0.is_even() {
331                    c.0.div2();
332                } else {
333                    let carry = c.0.add_with_carry(&Self::MODULUS);
334                    c.0.div2();
335                    if !Self::MODULUS_HAS_SPARE_BIT && carry {
336                        (c.0).0[N - 1] |= 1 << 63;
337                    }
338                }
339            }
340
341            if v < u {
342                u.sub_with_borrow(&v);
343                b -= &c;
344            } else {
345                v.sub_with_borrow(&u);
346                c -= &b;
347            }
348        }
349
350        if u == one {
351            Some(b)
352        } else {
353            Some(c)
354        }
355    }
356
357    fn from_bigint(r: BigInt<N>) -> Option<Fp<MontBackend<Self, N>, N>> {
358        let mut r = Fp::new_unchecked(r);
359        if r.is_zero() {
360            Some(r)
361        } else if r.is_geq_modulus() {
362            None
363        } else {
364            r *= &Fp::new_unchecked(Self::R2);
365            Some(r)
366        }
367    }
368
369    #[inline]
370    #[cfg_attr(not(target_family = "wasm"), unroll_for_loops(12))]
371    #[cfg_attr(target_family = "wasm", unroll_for_loops(6))]
372    #[allow(clippy::modulo_one)]
373    fn into_bigint(a: Fp<MontBackend<Self, N>, N>) -> BigInt<N> {
374        let mut r = (a.0).0;
375        // Montgomery Reduction
376        for i in 0..N {
377            let k = r[i].wrapping_mul(Self::INV);
378            let mut carry = 0;
379
380            fa::mac_with_carry(r[i], k, Self::MODULUS.0[0], &mut carry);
381            for j in 1..N {
382                r[(j + i) % N] =
383                    fa::mac_with_carry(r[(j + i) % N], k, Self::MODULUS.0[j], &mut carry);
384            }
385            r[i % N] = carry;
386        }
387
388        BigInt::new(r)
389    }
390
391    #[unroll_for_loops(12)]
392    fn sum_of_products<const M: usize>(
393        a: &[Fp<MontBackend<Self, N>, N>; M],
394        b: &[Fp<MontBackend<Self, N>, N>; M],
395    ) -> Fp<MontBackend<Self, N>, N> {
396        // Adapted from https://github.com/zkcrypto/bls12_381/pull/84 by @str4d.
397
398        // For a single `a x b` multiplication, operand scanning (schoolbook) takes each
399        // limb of `a` in turn, and multiplies it by all of the limbs of `b` to compute
400        // the result as a double-width intermediate representation, which is then fully
401        // reduced at the carry. Here however we have pairs of multiplications (a_i, b_i),
402        // the results of which are summed.
403        //
404        // The intuition for this algorithm is two-fold:
405        // - We can interleave the operand scanning for each pair, by processing the jth
406        //   limb of each `a_i` together. As these have the same offset within the overall
407        //   operand scanning flow, their results can be summed directly.
408        // - We can interleave the multiplication and reduction steps, resulting in a
409        //   single bitshift by the limb size after each iteration. This means we only
410        //   need to store a single extra limb overall, instead of keeping around all the
411        //   intermediate results and eventually having twice as many limbs.
412
413        let modulus_size = Self::MODULUS.const_num_bits() as usize;
414        if modulus_size >= 64 * N - 1 {
415            a.iter().zip(b).map(|(a, b)| *a * b).sum()
416        } else if M == 2 {
417            // Algorithm 2, line 2
418            let result = (0..N).fold(BigInt::zero(), |mut result, j| {
419                // Algorithm 2, line 3
420                let mut carry_a = 0;
421                let mut carry_b = 0;
422                for (a, b) in a.iter().zip(b) {
423                    let a = &a.0;
424                    let b = &b.0;
425                    let mut carry2 = 0;
426                    result.0[0] = fa::mac(result.0[0], a.0[j], b.0[0], &mut carry2);
427                    for k in 1..N {
428                        result.0[k] = fa::mac_with_carry(result.0[k], a.0[j], b.0[k], &mut carry2);
429                    }
430                    carry_b = fa::adc(&mut carry_a, carry_b, carry2);
431                }
432
433                let k = result.0[0].wrapping_mul(Self::INV);
434                let mut carry2 = 0;
435                fa::mac_discard(result.0[0], k, Self::MODULUS.0[0], &mut carry2);
436                for i in 1..N {
437                    result.0[i - 1] =
438                        fa::mac_with_carry(result.0[i], k, Self::MODULUS.0[i], &mut carry2);
439                }
440                result.0[N - 1] = fa::adc_no_carry(carry_a, carry_b, &mut carry2);
441                result
442            });
443            let mut result = Fp::new_unchecked(result);
444            result.subtract_modulus();
445            debug_assert_eq!(
446                a.iter().zip(b).map(|(a, b)| *a * b).sum::<Fp<_, N>>(),
447                result
448            );
449            result
450        } else {
451            let chunk_size = 2 * (N * 64 - modulus_size) - 1;
452            // chunk_size is at least 1, since MODULUS_BIT_SIZE is at most N * 64 - 1.
453            a.chunks(chunk_size)
454                .zip(b.chunks(chunk_size))
455                .map(|(a, b)| {
456                    // Algorithm 2, line 2
457                    let result = (0..N).fold(BigInt::zero(), |mut result, j| {
458                        // Algorithm 2, line 3
459                        let (temp, carry) = a.iter().zip(b).fold(
460                            (result, 0),
461                            |(mut temp, mut carry), (Fp(a, _), Fp(b, _))| {
462                                let mut carry2 = 0;
463                                temp.0[0] = fa::mac(temp.0[0], a.0[j], b.0[0], &mut carry2);
464                                for k in 1..N {
465                                    temp.0[k] =
466                                        fa::mac_with_carry(temp.0[k], a.0[j], b.0[k], &mut carry2);
467                                }
468                                carry = fa::adc_no_carry(carry, 0, &mut carry2);
469                                (temp, carry)
470                            },
471                        );
472
473                        let k = temp.0[0].wrapping_mul(Self::INV);
474                        let mut carry2 = 0;
475                        fa::mac_discard(temp.0[0], k, Self::MODULUS.0[0], &mut carry2);
476                        for i in 1..N {
477                            result.0[i - 1] =
478                                fa::mac_with_carry(temp.0[i], k, Self::MODULUS.0[i], &mut carry2);
479                        }
480                        result.0[N - 1] = fa::adc_no_carry(carry, 0, &mut carry2);
481                        result
482                    });
483                    let mut result = Fp::new_unchecked(result);
484                    result.subtract_modulus();
485                    debug_assert_eq!(
486                        a.iter().zip(b).map(|(a, b)| *a * b).sum::<Fp<_, N>>(),
487                        result
488                    );
489                    result
490                })
491                .sum()
492        }
493    }
494}
495
496/// Compute -M^{-1} mod 2^64.
497pub const fn inv<T: MontConfig<N>, const N: usize>() -> u64 {
498    // We compute this as follows.
499    // First, MODULUS mod 2^64 is just the lower 64 bits of MODULUS.
500    // Hence MODULUS mod 2^64 = MODULUS.0[0] mod 2^64.
501    //
502    // Next, computing the inverse mod 2^64 involves exponentiating by
503    // the multiplicative group order, which is euler_totient(2^64) - 1.
504    // Now, euler_totient(2^64) = 1 << 63, and so
505    // euler_totient(2^64) - 1 = (1 << 63) - 1 = 1111111... (63 digits).
506    // We compute this powering via standard square and multiply.
507    let mut inv = 1u64;
508    crate::const_for!((_i in 0..63) {
509        // Square
510        inv = inv.wrapping_mul(inv);
511        // Multiply
512        inv = inv.wrapping_mul(T::MODULUS.0[0]);
513    });
514    inv.wrapping_neg()
515}
516
517#[inline]
518pub const fn can_use_no_carry_mul_optimization<T: MontConfig<N>, const N: usize>() -> bool {
519    // Checking the modulus at compile time
520    let mut all_remaining_bits_are_one = T::MODULUS.0[N - 1] == u64::MAX >> 1;
521    crate::const_for!((i in 1..N) {
522        all_remaining_bits_are_one  &= T::MODULUS.0[N - i - 1] == u64::MAX;
523    });
524    modulus_has_spare_bit::<T, N>() && !all_remaining_bits_are_one
525}
526
527#[inline]
528pub const fn modulus_has_spare_bit<T: MontConfig<N>, const N: usize>() -> bool {
529    T::MODULUS.0[N - 1] >> 63 == 0
530}
531
532#[inline]
533pub const fn can_use_no_carry_square_optimization<T: MontConfig<N>, const N: usize>() -> bool {
534    // Checking the modulus at compile time
535    let top_two_bits_are_zero = T::MODULUS.0[N - 1] >> 62 == 0;
536    let mut all_remaining_bits_are_one = T::MODULUS.0[N - 1] == u64::MAX >> 2;
537    crate::const_for!((i in 1..N) {
538        all_remaining_bits_are_one  &= T::MODULUS.0[N - i - 1] == u64::MAX;
539    });
540    top_two_bits_are_zero && !all_remaining_bits_are_one
541}
542
543pub const fn sqrt_precomputation<const N: usize, T: MontConfig<N>>(
544) -> Option<SqrtPrecomputation<Fp<MontBackend<T, N>, N>>> {
545    match T::MODULUS.mod_4() {
546        3 => match T::MODULUS_PLUS_ONE_DIV_FOUR.as_ref() {
547            Some(BigInt(modulus_plus_one_div_four)) => Some(SqrtPrecomputation::Case3Mod4 {
548                modulus_plus_one_div_four,
549            }),
550            None => None,
551        },
552        _ => Some(SqrtPrecomputation::TonelliShanks {
553            two_adicity: <MontBackend<T, N>>::TWO_ADICITY,
554            quadratic_nonresidue_to_trace: T::TWO_ADIC_ROOT_OF_UNITY,
555            trace_of_modulus_minus_one_div_two:
556                &<Fp<MontBackend<T, N>, N>>::TRACE_MINUS_ONE_DIV_TWO.0,
557        }),
558    }
559}
560
/// Build a [`Fp<MontBackend<T, N>, N>`] element from a string literal.
///
/// The main use case is defining constant field elements. Outside of a
/// const context, prefer [`Fp::from_str`](`ark_std::str::FromStr::from_str`).
///
/// # Panics
///
/// When the integer denoted by the string does not fit into the limbs of
/// the target `Fp`, the macro produces a
/// * compile-time error if used in a const context
/// * run-time error otherwise.
///
/// # Usage
///
/// ```rust
/// # use ark_test_curves::MontFp;
/// # use ark_test_curves::bls12_381 as ark_bls12_381;
/// # use ark_std::{One, str::FromStr};
/// use ark_bls12_381::Fq;
/// const ONE: Fq = MontFp!("1");
/// const NEG_ONE: Fq = MontFp!("-1");
///
/// fn check_correctness() {
///     assert_eq!(ONE, Fq::one());
///     assert_eq!(Fq::from_str("1").unwrap(), ONE);
///     assert_eq!(NEG_ONE, -Fq::one());
/// }
/// ```
#[macro_export]
macro_rules! MontFp {
    ($literal:expr) => {{
        // Split the literal into its sign and little-endian limbs, then let
        // `from_sign_and_limbs` assemble the field element.
        let (sign_is_positive, limb_values) = $crate::ark_ff_macros::to_sign_and_limbs!($literal);
        $crate::Fp::from_sign_and_limbs(sign_is_positive, &limb_values)
    }};
}
596
597pub use ark_ff_macros::MontConfig;
598
599pub use MontFp;
600
601pub struct MontBackend<T: MontConfig<N>, const N: usize>(PhantomData<T>);
602
603impl<T: MontConfig<N>, const N: usize> FpConfig<N> for MontBackend<T, N> {
604    /// The modulus of the field.
605    const MODULUS: crate::BigInt<N> = T::MODULUS;
606
607    /// A multiplicative generator of the field.
608    /// `Self::GENERATOR` is an element having multiplicative order
609    /// `Self::MODULUS - 1`.
610    const GENERATOR: Fp<Self, N> = T::GENERATOR;
611
612    /// Additive identity of the field, i.e. the element `e`
613    /// such that, for all elements `f` of the field, `e + f = f`.
614    const ZERO: Fp<Self, N> = Fp::new_unchecked(BigInt([0u64; N]));
615
616    /// Multiplicative identity of the field, i.e. the element `e`
617    /// such that, for all elements `f` of the field, `e * f = f`.
618    const ONE: Fp<Self, N> = Fp::new_unchecked(T::R);
619
620    const TWO_ADICITY: u32 = Self::MODULUS.two_adic_valuation();
621    const TWO_ADIC_ROOT_OF_UNITY: Fp<Self, N> = T::TWO_ADIC_ROOT_OF_UNITY;
622    const SMALL_SUBGROUP_BASE: Option<u32> = T::SMALL_SUBGROUP_BASE;
623    const SMALL_SUBGROUP_BASE_ADICITY: Option<u32> = T::SMALL_SUBGROUP_BASE_ADICITY;
624    const LARGE_SUBGROUP_ROOT_OF_UNITY: Option<Fp<Self, N>> = T::LARGE_SUBGROUP_ROOT_OF_UNITY;
625    const SQRT_PRECOMP: Option<crate::SqrtPrecomputation<Fp<Self, N>>> = T::SQRT_PRECOMP;
626
627    fn add_assign(a: &mut Fp<Self, N>, b: &Fp<Self, N>) {
628        T::add_assign(a, b)
629    }
630
631    fn sub_assign(a: &mut Fp<Self, N>, b: &Fp<Self, N>) {
632        T::sub_assign(a, b)
633    }
634
635    fn double_in_place(a: &mut Fp<Self, N>) {
636        T::double_in_place(a)
637    }
638
639    fn neg_in_place(a: &mut Fp<Self, N>) {
640        T::neg_in_place(a)
641    }
642
643    /// This modular multiplication algorithm uses Montgomery
644    /// reduction for efficient implementation. It also additionally
645    /// uses the "no-carry optimization" outlined
646    /// [here](https://hackmd.io/@zkteam/modular_multiplication) if
647    /// `P::MODULUS` has (a) a non-zero MSB, and (b) at least one
648    /// zero bit in the rest of the modulus.
649    #[inline]
650    fn mul_assign(a: &mut Fp<Self, N>, b: &Fp<Self, N>) {
651        T::mul_assign(a, b)
652    }
653
654    fn sum_of_products<const M: usize>(a: &[Fp<Self, N>; M], b: &[Fp<Self, N>; M]) -> Fp<Self, N> {
655        T::sum_of_products(a, b)
656    }
657
658    #[inline]
659    #[allow(unused_braces, clippy::absurd_extreme_comparisons)]
660    fn square_in_place(a: &mut Fp<Self, N>) {
661        T::square_in_place(a)
662    }
663
664    fn inverse(a: &Fp<Self, N>) -> Option<Fp<Self, N>> {
665        T::inverse(a)
666    }
667
668    fn from_bigint(r: BigInt<N>) -> Option<Fp<Self, N>> {
669        T::from_bigint(r)
670    }
671
672    #[inline]
673    #[allow(clippy::modulo_one)]
674    fn into_bigint(a: Fp<Self, N>) -> BigInt<N> {
675        T::into_bigint(a)
676    }
677}
678
679impl<T: MontConfig<N>, const N: usize> Fp<MontBackend<T, N>, N> {
680    #[doc(hidden)]
681    pub const R: BigInt<N> = T::R;
682    #[doc(hidden)]
683    pub const R2: BigInt<N> = T::R2;
684    #[doc(hidden)]
685    pub const INV: u64 = T::INV;
686
687    /// Construct a new field element from its underlying
688    /// [`struct@BigInt`] data type.
689    #[inline]
690    pub const fn new(element: BigInt<N>) -> Self {
691        let mut r = Self(element, PhantomData);
692        if r.const_is_zero() {
693            r
694        } else {
695            r = r.mul(&Fp(T::R2, PhantomData));
696            r
697        }
698    }
699
700    /// Construct a new field element from its underlying
701    /// [`struct@BigInt`] data type.
702    ///
703    /// Unlike [`Self::new`], this method does not perform Montgomery reduction.
704    /// Thus, this method should be used only when constructing
705    /// an element from an integer that has already been put in
706    /// Montgomery form.
707    #[inline]
708    pub const fn new_unchecked(element: BigInt<N>) -> Self {
709        Self(element, PhantomData)
710    }
711
712    const fn const_is_zero(&self) -> bool {
713        self.0.const_is_zero()
714    }
715
716    #[doc(hidden)]
717    const fn const_neg(self) -> Self {
718        if !self.const_is_zero() {
719            Self::new_unchecked(Self::sub_with_borrow(&T::MODULUS, &self.0))
720        } else {
721            self
722        }
723    }
724
725    /// Interpret a set of limbs (along with a sign) as a field element.
726    /// For *internal* use only; please use the `ark_ff::MontFp` macro instead
727    /// of this method
728    #[doc(hidden)]
729    pub const fn from_sign_and_limbs(is_positive: bool, limbs: &[u64]) -> Self {
730        let mut repr = BigInt::<N>([0; N]);
731        assert!(limbs.len() <= N);
732        crate::const_for!((i in 0..(limbs.len())) {
733            repr.0[i] = limbs[i];
734        });
735        let res = Self::new(repr);
736        if is_positive {
737            res
738        } else {
739            res.const_neg()
740        }
741    }
742
743    const fn mul_without_cond_subtract(mut self, other: &Self) -> (bool, Self) {
744        let (mut lo, mut hi) = ([0u64; N], [0u64; N]);
745        crate::const_for!((i in 0..N) {
746            let mut carry = 0;
747            crate::const_for!((j in 0..N) {
748                let k = i + j;
749                if k >= N {
750                    hi[k - N] = mac_with_carry!(hi[k - N], (self.0).0[i], (other.0).0[j], &mut carry);
751                } else {
752                    lo[k] = mac_with_carry!(lo[k], (self.0).0[i], (other.0).0[j], &mut carry);
753                }
754            });
755            hi[i] = carry;
756        });
757        // Montgomery reduction
758        let mut carry2 = 0;
759        crate::const_for!((i in 0..N) {
760            let tmp = lo[i].wrapping_mul(T::INV);
761            let mut carry;
762            mac!(lo[i], tmp, T::MODULUS.0[0], &mut carry);
763            crate::const_for!((j in 1..N) {
764                let k = i + j;
765                if k >= N {
766                    hi[k - N] = mac_with_carry!(hi[k - N], tmp, T::MODULUS.0[j], &mut carry);
767                }  else {
768                    lo[k] = mac_with_carry!(lo[k], tmp, T::MODULUS.0[j], &mut carry);
769                }
770            });
771            hi[i] = adc!(hi[i], carry, &mut carry2);
772        });
773
774        crate::const_for!((i in 0..N) {
775            (self.0).0[i] = hi[i];
776        });
777        (carry2 != 0, self)
778    }
779
780    const fn mul(self, other: &Self) -> Self {
781        let (carry, res) = self.mul_without_cond_subtract(other);
782        if T::MODULUS_HAS_SPARE_BIT {
783            res.const_subtract_modulus()
784        } else {
785            res.const_subtract_modulus_with_carry(carry)
786        }
787    }
788
789    const fn const_is_valid(&self) -> bool {
790        crate::const_for!((i in 0..N) {
791            if (self.0).0[N - i - 1] < T::MODULUS.0[N - i - 1] {
792                return true
793            } else if (self.0).0[N - i - 1] > T::MODULUS.0[N - i - 1] {
794                return false
795            }
796        });
797        false
798    }
799
800    #[inline]
801    const fn const_subtract_modulus(mut self) -> Self {
802        if !self.const_is_valid() {
803            self.0 = Self::sub_with_borrow(&self.0, &T::MODULUS);
804        }
805        self
806    }
807
808    #[inline]
809    const fn const_subtract_modulus_with_carry(mut self, carry: bool) -> Self {
810        if carry || !self.const_is_valid() {
811            self.0 = Self::sub_with_borrow(&self.0, &T::MODULUS);
812        }
813        self
814    }
815
816    const fn sub_with_borrow(a: &BigInt<N>, b: &BigInt<N>) -> BigInt<N> {
817        a.const_sub_with_borrow(b).0
818    }
819}
820
#[cfg(test)]
mod test {
    use ark_std::{str::FromStr, vec::*};
    use ark_test_curves::secp256k1::Fr;
    use num_bigint::{BigInt, BigUint, Sign};

    /// Decimal value used to exercise the limb-based constructor.
    const DECIMAL: &str =
        "111192936301596926984056301862066282284536849596023571352007112326586892541694";

    /// Builds a field element from sign and limbs and checks that converting
    /// it back to an integer yields the original value.
    #[test]
    fn test_mont_macro_correctness() {
        let (is_positive, limbs) = str_to_limbs_u64(DECIMAL);
        let element = Fr::from_sign_and_limbs(is_positive, &limbs);

        let actual: BigUint = element.into();
        let expected = BigUint::from_str(DECIMAL).unwrap();

        assert_eq!(actual, expected);
    }

    /// Parses a decimal string into a sign flag (`true` unless negative) and
    /// little-endian 64-bit limbs.
    fn str_to_limbs_u64(num: &str) -> (bool, Vec<u64>) {
        let parsed = BigInt::from_str(num).expect("could not parse to bigint");
        let (sign, hex_digits) = parsed.to_radix_le(16);

        // Sixteen little-endian hex digits make up one 64-bit limb.
        let limbs: Vec<u64> = hex_digits
            .chunks(16)
            .map(|chunk| {
                chunk
                    .iter()
                    .enumerate()
                    .fold(0u64, |limb, (i, hexit)| limb + ((*hexit as u64) << (4 * i)))
            })
            .collect();

        (sign != Sign::Minus, limbs)
    }
}