ark_ff_macros/
lib.rs

1#![warn(
2    unused,
3    future_incompatible,
4    nonstandard_style,
5    rust_2018_idioms,
6    rust_2021_compatibility
7)]
8#![forbid(unsafe_code)]
9
10use num_bigint::BigUint;
11use proc_macro::TokenStream;
12use syn::{Expr, ExprLit, Item, ItemFn, Lit, Meta};
13
14mod montgomery;
15mod unroll;
16
17pub(crate) mod utils;
18
19#[proc_macro]
20pub fn to_sign_and_limbs(input: TokenStream) -> TokenStream {
21    let num = utils::parse_string(input).expect("expected decimal string");
22    let (is_positive, limbs) = utils::str_to_limbs(&num);
23
24    let limbs: String = limbs.join(", ");
25    let limbs_and_sign = format!("({}", is_positive) + ", [" + &limbs + "])";
26    let tuple: Expr = syn::parse_str(&limbs_and_sign).unwrap();
27    quote::quote!(#tuple).into()
28}
29
30/// Derive the `MontConfig` trait.
31///
32/// The attributes available to this macro are
33/// * `modulus`: Specify the prime modulus underlying this prime field.
34/// * `generator`: Specify the generator of the multiplicative subgroup of this
35///   prime field. This value must be a quadratic non-residue in the field.
36/// * `small_subgroup_base` and `small_subgroup_power` (optional): If the field
37///   has insufficient two-adicity, specify an additional subgroup of size
38///   `small_subgroup_base.pow(small_subgroup_power)`.
39// This code was adapted from the `PrimeField` Derive Macro in ff-derive.
40#[proc_macro_derive(
41    MontConfig,
42    attributes(modulus, generator, small_subgroup_base, small_subgroup_power)
43)]
44pub fn mont_config(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
45    // Parse the type definition
46    let ast: syn::DeriveInput = syn::parse(input).unwrap();
47
48    // We're given the modulus p of the prime field
49    let modulus: BigUint = fetch_attr("modulus", &ast.attrs)
50        .expect("Please supply a modulus attribute")
51        .parse()
52        .expect("Modulus should be a number");
53
54    // We may be provided with a generator of p - 1 order. It is required that this
55    // generator be quadratic nonresidue.
56    let generator: BigUint = fetch_attr("generator", &ast.attrs)
57        .expect("Please supply a generator attribute")
58        .parse()
59        .expect("Generator should be a number");
60
61    let small_subgroup_base: Option<u32> = fetch_attr("small_subgroup_base", &ast.attrs)
62        .map(|s| s.parse().expect("small_subgroup_base should be a number"));
63
64    let small_subgroup_power: Option<u32> = fetch_attr("small_subgroup_power", &ast.attrs)
65        .map(|s| s.parse().expect("small_subgroup_power should be a number"));
66
67    montgomery::mont_config_helper(
68        modulus,
69        generator,
70        small_subgroup_base,
71        small_subgroup_power,
72        ast.ident,
73    )
74    .into()
75}
76
77const ARG_MSG: &str = "Failed to parse unroll threshold; must be a positive integer";
78
79/// Attribute used to unroll for loops found inside a function block.
80#[proc_macro_attribute]
81pub fn unroll_for_loops(args: TokenStream, input: TokenStream) -> TokenStream {
82    let unroll_by = match syn::parse2::<syn::Lit>(args.into()).expect(ARG_MSG) {
83        Lit::Int(int) => int.base10_parse().expect(ARG_MSG),
84        _ => panic!("{}", ARG_MSG),
85    };
86
87    let item: Item = syn::parse(input).expect("Failed to parse input.");
88
89    if let Item::Fn(item_fn) = item {
90        let new_block = {
91            let &ItemFn {
92                block: ref box_block,
93                ..
94            } = &item_fn;
95            unroll::unroll_in_block(box_block, unroll_by)
96        };
97        let new_item = Item::Fn(ItemFn {
98            block: Box::new(new_block),
99            ..item_fn
100        });
101        quote::quote! ( #new_item ).into()
102    } else {
103        quote::quote! ( #item ).into()
104    }
105}
106
107/// Fetch an attribute string from the derived struct.
108fn fetch_attr(name: &str, attrs: &[syn::Attribute]) -> Option<String> {
109    // Go over each attribute
110    for attr in attrs {
111        match attr.meta {
112            // If the attribute's path matches `name`, and if the attribute is of
113            // the form `#[name = "value"]`, return `value`
114            Meta::NameValue(ref nv) if nv.path.is_ident(name) => {
115                // Extract and return the string value.
116                // If `value` is not a string, return an error
117                if let Expr::Lit(ExprLit {
118                    lit: Lit::Str(ref s),
119                    ..
120                }) = nv.value
121                {
122                    return Some(s.value());
123                } else {
124                    panic!("attribute {name} should be a string")
125                }
126            },
127            _ => continue,
128        }
129    }
130    None
131}
132
/// Checks `utils::str_to_limbs` on powers of two and on a few fixed values.
///
/// The powers of two cover both signs and binary, octal, hex and decimal notation.
/// For each input the test checks the returned sign flag and the little-endian
/// `u64` limb strings.
#[test]
fn test_str_to_limbs() {
    use num_bigint::Sign::*;
    for i in 0..100 {
        for sign in [Plus, Minus] {
            let number = 1i128 << i;
            let signed_number = match sign {
                Minus => -number,
                Plus | NoSign => number,
            };
            for base in [2, 8, 16, 10] {
                let mut string = match base {
                    2 => format!("{:#b}", number),
                    8 => format!("{:#o}", number),
                    16 => format!("{:#x}", number),
                    10 => number.to_string(),
                    _ => unreachable!(),
                };
                if sign == Minus {
                    string.insert(0, '-');
                }
                let (is_positive, limbs) = utils::str_to_limbs(&string);

                // Values below 2^64 fit in a single limb; 2^64..2^99 need exactly two.
                let expected_len = if i > 63 { 2 } else { 1 };
                assert_eq!(limbs.len(), expected_len, "{signed_number}, {i}");

                assert_eq!(
                    limbs[0],
                    format!("{}u64", signed_number.abs() as u64),
                    "{signed_number}, {i}"
                );
                if i > 63 {
                    assert_eq!(
                        limbs[1],
                        format!("{}u64", (signed_number.abs() >> 64) as u64),
                        "{signed_number}, {i}"
                    );
                }

                assert_eq!(is_positive, sign == Plus);
            }
        }
    }
    let (is_positive, limbs) = utils::str_to_limbs("0");
    assert!(is_positive);
    assert_eq!(&limbs, &["0u64".to_string()]);

    let (is_positive, limbs) = utils::str_to_limbs("-5");
    assert!(!is_positive);
    assert_eq!(&limbs, &["5u64".to_string()]);

    let (is_positive, limbs) = utils::str_to_limbs("100");
    assert!(is_positive);
    assert_eq!(&limbs, &["100u64".to_string()]);

    let large_num = -((1i128 << 64) + 101234001234i128);
    let (is_positive, limbs) = utils::str_to_limbs(&large_num.to_string());
    assert!(!is_positive);
    assert_eq!(&limbs, &["101234001234u64".to_string(), "1u64".to_string()]);

    let num = "80949648264912719408558363140637477264845294720710499478137287262712535938301461879813459410946";
    let (is_positive, limbs) = utils::str_to_limbs(num);
    assert!(is_positive);
    let expected_limbs = [
        format!("{}u64", 0x8508c00000000002u64),
        format!("{}u64", 0x452217cc90000000u64),
        format!("{}u64", 0xc5ed1347970dec00u64),
        format!("{}u64", 0x619aaf7d34594aabu64),
        format!("{}u64", 0x9b3af05dd14f6ecu64),
    ];
    assert_eq!(&limbs, &expected_limbs);
}