1#![warn(
2 unused,
3 future_incompatible,
4 nonstandard_style,
5 rust_2018_idioms,
6 rust_2021_compatibility
7)]
8#![forbid(unsafe_code)]
9
10use num_bigint::BigUint;
11use proc_macro::TokenStream;
12use syn::{Expr, ExprLit, Item, ItemFn, Lit, Meta};
13
14mod montgomery;
15mod unroll;
16
17pub(crate) mod utils;
18
19#[proc_macro]
20pub fn to_sign_and_limbs(input: TokenStream) -> TokenStream {
21 let num = utils::parse_string(input).expect("expected decimal string");
22 let (is_positive, limbs) = utils::str_to_limbs(&num);
23
24 let limbs: String = limbs.join(", ");
25 let limbs_and_sign = format!("({}", is_positive) + ", [" + &limbs + "])";
26 let tuple: Expr = syn::parse_str(&limbs_and_sign).unwrap();
27 quote::quote!(#tuple).into()
28}
29
30#[proc_macro_derive(
41 MontConfig,
42 attributes(modulus, generator, small_subgroup_base, small_subgroup_power)
43)]
44pub fn mont_config(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
45 let ast: syn::DeriveInput = syn::parse(input).unwrap();
47
48 let modulus: BigUint = fetch_attr("modulus", &ast.attrs)
50 .expect("Please supply a modulus attribute")
51 .parse()
52 .expect("Modulus should be a number");
53
54 let generator: BigUint = fetch_attr("generator", &ast.attrs)
57 .expect("Please supply a generator attribute")
58 .parse()
59 .expect("Generator should be a number");
60
61 let small_subgroup_base: Option<u32> = fetch_attr("small_subgroup_base", &ast.attrs)
62 .map(|s| s.parse().expect("small_subgroup_base should be a number"));
63
64 let small_subgroup_power: Option<u32> = fetch_attr("small_subgroup_power", &ast.attrs)
65 .map(|s| s.parse().expect("small_subgroup_power should be a number"));
66
67 montgomery::mont_config_helper(
68 modulus,
69 generator,
70 small_subgroup_base,
71 small_subgroup_power,
72 ast.ident,
73 )
74 .into()
75}
76
77const ARG_MSG: &str = "Failed to parse unroll threshold; must be a positive integer";
78
79#[proc_macro_attribute]
81pub fn unroll_for_loops(args: TokenStream, input: TokenStream) -> TokenStream {
82 let unroll_by = match syn::parse2::<syn::Lit>(args.into()).expect(ARG_MSG) {
83 Lit::Int(int) => int.base10_parse().expect(ARG_MSG),
84 _ => panic!("{}", ARG_MSG),
85 };
86
87 let item: Item = syn::parse(input).expect("Failed to parse input.");
88
89 if let Item::Fn(item_fn) = item {
90 let new_block = {
91 let &ItemFn {
92 block: ref box_block,
93 ..
94 } = &item_fn;
95 unroll::unroll_in_block(box_block, unroll_by)
96 };
97 let new_item = Item::Fn(ItemFn {
98 block: Box::new(new_block),
99 ..item_fn
100 });
101 quote::quote! ( #new_item ).into()
102 } else {
103 quote::quote! ( #item ).into()
104 }
105}
106
107fn fetch_attr(name: &str, attrs: &[syn::Attribute]) -> Option<String> {
109 for attr in attrs {
111 match attr.meta {
112 Meta::NameValue(ref nv) if nv.path.is_ident(name) => {
115 if let Expr::Lit(ExprLit {
118 lit: Lit::Str(ref s),
119 ..
120 }) = nv.value
121 {
122 return Some(s.value());
123 } else {
124 panic!("attribute {name} should be a string")
125 }
126 },
127 _ => continue,
128 }
129 }
130 None
131}
132
/// Checks `utils::str_to_limbs` on:
/// - powers of two in several bases (binary, octal, hex, decimal), with both signs;
/// - a few hand-picked values, including zero and a multi-limb number.
#[test]
fn test_str_to_limbs() {
    use num_bigint::Sign::*;
    for i in 0..100 {
        for sign in [Plus, Minus] {
            let number = 1i128 << i;
            // Only `Plus` and `Minus` are iterated, so every non-`Minus` sign is positive.
            let signed_number = if sign == Minus { -number } else { number };
            for base in [2, 8, 16, 10] {
                let mut string = match base {
                    2 => format!("{number:#b}"),
                    8 => format!("{number:#o}"),
                    16 => format!("{number:#x}"),
                    10 => number.to_string(),
                    _ => unreachable!(),
                };
                if sign == Minus {
                    string.insert(0, '-');
                }
                let (is_positive, limbs) = utils::str_to_limbs(&string);
                // Limbs are little-endian: limb 0 is the low 64 bits of the magnitude.
                // The `as u64` truncation is intentional and extracts exactly those bits.
                assert_eq!(
                    limbs[0],
                    format!("{}u64", signed_number.abs() as u64),
                    "{signed_number}, {i}"
                );
                if i > 63 {
                    assert_eq!(
                        limbs[1],
                        format!("{}u64", (signed_number.abs() >> 64) as u64),
                        "{signed_number}, {i}"
                    );
                }

                assert_eq!(is_positive, sign == Plus);
            }
        }
    }
    let (is_positive, limbs) = utils::str_to_limbs("0");
    assert!(is_positive);
    assert_eq!(&limbs, &["0u64".to_string()]);

    let (is_positive, limbs) = utils::str_to_limbs("-5");
    assert!(!is_positive);
    assert_eq!(&limbs, &["5u64".to_string()]);

    let (is_positive, limbs) = utils::str_to_limbs("100");
    assert!(is_positive);
    assert_eq!(&limbs, &["100u64".to_string()]);

    let large_num = -((1i128 << 64) + 101234001234i128);
    let (is_positive, limbs) = utils::str_to_limbs(&large_num.to_string());
    assert!(!is_positive);
    assert_eq!(&limbs, &["101234001234u64".to_string(), "1u64".to_string()]);

    let num = "80949648264912719408558363140637477264845294720710499478137287262712535938301461879813459410946";
    let (is_positive, limbs) = utils::str_to_limbs(num);
    assert!(is_positive);
    let expected_limbs = [
        format!("{}u64", 0x8508c00000000002u64),
        format!("{}u64", 0x452217cc90000000u64),
        format!("{}u64", 0xc5ed1347970dec00u64),
        format!("{}u64", 0x619aaf7d34594aabu64),
        format!("{}u64", 0x9b3af05dd14f6ecu64),
    ];
    assert_eq!(&limbs, &expected_limbs);
}