ark_ff_macros/montgomery/
mul.rs

1use quote::quote;
2
3pub(super) fn mul_assign_impl(
4    can_use_no_carry_mul_opt: bool,
5    num_limbs: usize,
6    modulus_limbs: &[u64],
7    modulus_has_spare_bit: bool,
8) -> proc_macro2::TokenStream {
9    let mut body = proc_macro2::TokenStream::new();
10    let modulus_0 = modulus_limbs[0];
11    if can_use_no_carry_mul_opt {
12        // This modular multiplication algorithm uses Montgomery
13        // reduction for efficient implementation. It also additionally
14        // uses the "no-carry optimization" outlined
15        // [here](https://hackmd.io/@gnark/modular_multiplication) if
16        // `MODULUS` has (a) a non-zero MSB, and (b) at least one
17        // zero bit in the rest of the modulus.
18
19        let mut default = proc_macro2::TokenStream::new();
20        default.extend(quote! { let mut r = [0u64; #num_limbs]; });
21        for i in 0..num_limbs {
22            default.extend(quote! {
23                let mut carry1 = 0u64;
24                r[0] = fa::mac(r[0], (a.0).0[0], (b.0).0[#i], &mut carry1);
25                let k = r[0].wrapping_mul(Self::INV);
26                let mut carry2 = 0u64;
27                fa::mac_discard(r[0], k, #modulus_0, &mut carry2);
28            });
29            for (j, modulus_j) in modulus_limbs.iter().enumerate().take(num_limbs).skip(1) {
30                let idx = j - 1;
31                default.extend(quote! {
32                    r[#j] = fa::mac_with_carry(r[#j], (a.0).0[#j], (b.0).0[#i], &mut carry1);
33                    r[#idx] = fa::mac_with_carry(r[#j], k, #modulus_j, &mut carry2);
34                });
35            }
36            default.extend(quote!(r[#num_limbs - 1] = carry1 + carry2;));
37        }
38        default.extend(quote!((a.0).0 = r;));
39        // Avoid using assembly for `N == 1`.
40        if (2..=6).contains(&num_limbs) {
41            body.extend(quote!({
42                if cfg!(all(
43                    feature = "asm",
44                    target_feature = "bmi2",
45                    target_feature = "adx",
46                    target_arch = "x86_64"
47                )) {
48                    #[cfg(
49                        all(
50                            feature = "asm",
51                            target_feature = "bmi2",
52                            target_feature = "adx",
53                            target_arch = "x86_64"
54                        )
55                    )]
56                    #[allow(unsafe_code, unused_mut)]
57                    ark_ff::x86_64_asm_mul!(#num_limbs, (a.0).0, (b.0).0);
58                } else {
59                    #[cfg(
60                        not(all(
61                            feature = "asm",
62                            target_feature = "bmi2",
63                            target_feature = "adx",
64                            target_arch = "x86_64"
65                        ))
66                    )]
67                    {
68                        #default
69                    }
70                }
71            }))
72        } else {
73            body.extend(quote!({ #default }))
74        }
75        body.extend(quote!(__subtract_modulus(a);));
76    } else {
77        // We use standard CIOS
78        let double_limbs = num_limbs * 2;
79        body.extend(quote! {
80            let mut scratch = [0u64; #double_limbs];
81        });
82        for i in 0..num_limbs {
83            body.extend(quote! { let mut carry = 0u64; });
84            for j in 0..num_limbs {
85                let k = i + j;
86                body.extend(quote!{scratch[#k] = fa::mac_with_carry(scratch[#k], (a.0).0[#i], (b.0).0[#j], &mut carry);});
87            }
88            body.extend(quote! { scratch[#i + #num_limbs] = carry; });
89        }
90        body.extend(quote!( let mut carry2 = 0u64; ));
91        for i in 0..num_limbs {
92            body.extend(quote! {
93                let tmp = scratch[#i].wrapping_mul(Self::INV);
94                let mut carry = 0u64;
95                fa::mac(scratch[#i], tmp, #modulus_0, &mut carry);
96            });
97            for j in 1..num_limbs {
98                let modulus_j = modulus_limbs[j];
99                let k = i + j;
100                body.extend(quote!(scratch[#k] = fa::mac_with_carry(scratch[#k], tmp, #modulus_j, &mut carry);));
101            }
102            body.extend(quote!(carry2 = fa::adc(&mut scratch[#i + #num_limbs], carry, carry2);));
103        }
104        body.extend(quote! {
105            (a.0).0 = scratch[#num_limbs..].try_into().unwrap();
106        });
107        if modulus_has_spare_bit {
108            body.extend(quote!(__subtract_modulus(a);));
109        } else {
110            body.extend(quote!(__subtract_modulus_with_carry(a, carry2 != 0);));
111        }
112    }
113    body
114}