// ark_ff_macros/montgomery/square.rs

use quote::quote;

3pub(super) fn square_in_place_impl(
4    can_use_no_carry_mul_opt: bool,
5    num_limbs: usize,
6    modulus_limbs: &[u64],
7    modulus_has_spare_bit: bool,
8) -> proc_macro2::TokenStream {
9    let mut body = proc_macro2::TokenStream::new();
10    let mut default = proc_macro2::TokenStream::new();
11
12    let modulus_0 = modulus_limbs[0];
13    let double_num_limbs = 2 * num_limbs;
14    default.extend(quote! {
15        let mut r = [0u64; #double_num_limbs];
16        let mut carry = 0;
17    });
18    for i in 0..(num_limbs - 1) {
19        for j in (i + 1)..num_limbs {
20            let idx = i + j;
21            default.extend(quote! {
22                r[#idx] = fa::mac_with_carry(r[#idx], (a.0).0[#i], (a.0).0[#j], &mut carry);
23            })
24        }
25        default.extend(quote! {
26            r[#num_limbs + #i] = carry;
27            carry = 0;
28        });
29    }
30    default.extend(quote! { r[#double_num_limbs - 1] = r[#double_num_limbs - 2] >> 63; });
31    for i in 2..(double_num_limbs - 1) {
32        let idx = double_num_limbs - i;
33        default.extend(quote! { r[#idx] = (r[#idx] << 1) | (r[#idx - 1] >> 63); });
34    }
35    default.extend(quote! { r[1] <<= 1; });
36
37    for i in 0..num_limbs {
38        let idx = 2 * i;
39        default.extend(quote! {
40            r[#idx] = fa::mac_with_carry(r[#idx], (a.0).0[#i], (a.0).0[#i], &mut carry);
41            carry = fa::adc(&mut r[#idx + 1], 0, carry);
42        });
43    }
44    // Montgomery reduction
45    default.extend(quote! { let mut carry2 = 0; });
46    for i in 0..num_limbs {
47        default.extend(quote! {
48            let k = r[#i].wrapping_mul(Self::INV);
49            let mut carry = 0;
50            fa::mac_discard(r[#i], k, #modulus_0, &mut carry);
51        });
52        for (j, modulus_j) in modulus_limbs.iter().enumerate().take(num_limbs).skip(1) {
53            let idx = j + i;
54            default.extend(quote! {
55                r[#idx] = fa::mac_with_carry(r[#idx], k, #modulus_j, &mut carry);
56            });
57        }
58        default.extend(quote! { carry2 = fa::adc(&mut r[#num_limbs + #i], carry, carry2); });
59    }
60    default.extend(quote! { (a.0).0 = r[#num_limbs..].try_into().unwrap(); });
61
62    if num_limbs == 1 {
63        // We default to multiplying with `a` using the `Mul` impl
64        // for the N == 1 case
65        quote!({
66            *a *= *a;
67        })
68    } else if (2..=6).contains(&num_limbs) && can_use_no_carry_mul_opt {
69        body.extend(quote!({
70            if cfg!(all(
71                feature = "asm",
72                target_feature = "bmi2",
73                target_feature = "adx",
74                target_arch = "x86_64"
75            )) {
76                #[cfg(
77                    all(
78                        feature = "asm",
79                        target_feature = "bmi2",
80                        target_feature = "adx",
81                        target_arch = "x86_64"
82                    )
83                )]
84                #[allow(unsafe_code, unused_mut)]
85                {
86                    ark_ff::x86_64_asm_square!(#num_limbs, (a.0).0);
87                }
88            } else {
89                #[cfg(
90                    not(all(
91                        feature = "asm",
92                        target_feature = "bmi2",
93                        target_feature = "adx",
94                        target_arch = "x86_64"
95                    ))
96                )]
97                {
98                    #default
99                }
100            }
101        }));
102        body.extend(quote!(__subtract_modulus(a);));
103        body
104    } else {
105        body.extend(quote!( #default ));
106        if modulus_has_spare_bit {
107            body.extend(quote!(__subtract_modulus(a);));
108        } else {
109            body.extend(quote!(__subtract_modulus_with_carry(a, carry2 != 0);));
110        }
111        body
112    }
113}