1#![warn(
2 unused,
3 future_incompatible,
4 nonstandard_style,
5 rust_2018_idioms,
6 rust_2021_compatibility
7)]
8#![forbid(unsafe_code)]
9#![recursion_limit = "128"]
10
11use proc_macro::TokenStream;
12use syn::{
13 parse::{Parse, ParseStream},
14 Expr,
15};
16
17mod context;
18use context::{AssemblyVar, Context};
19
20use std::cell::RefCell;
21
22struct AsmMulInput {
23 num_limbs: Box<Expr>,
24 a: Expr,
25 b: Expr,
26}
27
28impl Parse for AsmMulInput {
29 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
30 let input = input
31 .parse_terminated(Expr::parse, syn::token::Comma)?
32 .into_iter()
33 .collect::<Vec<_>>();
34 let num_limbs = input[0].clone();
35 let a = input[1].clone();
36 let b = input[2].clone();
37
38 let num_limbs = if let Expr::Group(syn::ExprGroup { expr, .. }) = num_limbs {
39 expr
40 } else {
41 Box::new(num_limbs)
42 };
43 let output = Self { num_limbs, a, b };
44 Ok(output)
45 }
46}
47
48#[proc_macro]
49pub fn x86_64_asm_mul(input: TokenStream) -> TokenStream {
50 let AsmMulInput { num_limbs, a, b } = syn::parse_macro_input!(input);
51 let num_limbs = if let Expr::Lit(syn::ExprLit {
52 lit: syn::Lit::Int(ref lit_int),
53 ..
54 }) = &*num_limbs
55 {
56 lit_int.base10_parse::<usize>().unwrap()
57 } else {
58 panic!("The number of limbs must be a literal");
59 };
60 if num_limbs <= 6 {
61 let impl_block = generate_impl(num_limbs, true);
62
63 let inner_ts: Expr = syn::parse_str(&impl_block).unwrap();
64 let ts = quote::quote! {
65 let a = &mut #a;
66 let b = &#b;
67 #inner_ts
68 };
69 ts.into()
70 } else {
71 TokenStream::new()
72 }
73}
74
75struct AsmSquareInput {
76 num_limbs: Box<Expr>,
77 a: Expr,
78}
79
80impl Parse for AsmSquareInput {
81 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
82 let input = input
83 .parse_terminated(Expr::parse, syn::token::Comma)?
84 .into_iter()
85 .collect::<Vec<_>>();
86 let num_limbs = input[0].clone();
87 let a = input[1].clone();
88
89 let num_limbs = if let Expr::Group(syn::ExprGroup { expr, .. }) = num_limbs {
90 expr
91 } else {
92 Box::new(num_limbs)
93 };
94 let output = Self { num_limbs, a };
95 Ok(output)
96 }
97}
98
99#[proc_macro]
100pub fn x86_64_asm_square(input: TokenStream) -> TokenStream {
101 let AsmSquareInput { num_limbs, a } = syn::parse_macro_input!(input);
102 let num_limbs = if let Expr::Lit(syn::ExprLit {
103 lit: syn::Lit::Int(ref lit_int),
104 ..
105 }) = &*num_limbs
106 {
107 lit_int.base10_parse::<usize>().unwrap()
108 } else {
109 panic!("The number of limbs must be a literal");
110 };
111 if num_limbs <= 6 {
112 let impl_block = generate_impl(num_limbs, false);
113
114 let inner_ts: Expr = syn::parse_str(&impl_block).unwrap();
115 let ts = quote::quote! {
116 let a = &mut #a;
117 #inner_ts
118 };
119 ts.into()
120 } else {
121 TokenStream::new()
122 }
123}
124
/// Generates the instruction listing for a multi-limb multiplication of `a` by
/// `b` interleaved with reduction modulo `modulus`. Each returned `String` is
/// one instruction or one `// ...` comment line.
///
/// `ctx` must declare `a`, `modulus` and `mod_inv`. `b` is looked up with a
/// fallback to `a`, which is how squaring reuses this routine. The final result
/// is written back over `a`.
///
/// Instructions are emitted source operand first, destination last (e.g.
/// `movq src, dst`). The `q` suffixes make every operation 64-bit, so each limb
/// is one 64-bit word.
///
/// NOTE(review): the loop has a recognisable shape. Each row adds `a[i] * b`,
/// then adds `k * modulus` with `k = low(r[i] * mod_inv)`, then retires the
/// lowest word. This looks like word-by-word (CIOS) Montgomery multiplication;
/// confirm against the value bound to `mod_inv` by the caller.
///
/// NOTE(review): the sequence appears to assume `limbs >= 2`. With one limb,
/// `mul_1` multiplies by `b[0]` twice (as `$b[0]` and again as
/// `$b[$limbs - 1]`) and accumulates both products. `mul_add_shift_1` likewise
/// multiplies by `m[0]` twice.
///
/// No final comparison against, or subtraction of, `modulus` is emitted here.
fn construct_asm_mul(ctx: &Context<'_>, limbs: usize) -> Vec<String> {
    // Accumulator registers. Every index used below is `< limbs`, except
    // `r[1]` in `mul_1`, which lies outside that range only when `limbs == 1`.
    let r: Vec<AssemblyVar> = Context::R.iter().map(|r| (*r).into()).collect();
    // Scratch registers: `rdx` is the implicit multiplicand of `mulx`, `rax` and
    // `rsi` receive partial products, and `rcx` carries the top word of a row.
    let rax: AssemblyVar = Context::RAX.into();
    let rcx: AssemblyVar = Context::RCX.into();
    let rdx: AssemblyVar = Context::RDX.into();
    let rsi: AssemblyVar = Context::RSI.into();
    // Operands declared by the caller. `b` falls back to `a` when no separate
    // `b` was declared (the squaring case).
    let a: AssemblyVar = ctx.get_decl("a").into();
    let b: AssemblyVar = ctx.get_decl_with_fallback("b", "a").into();
    let modulus: AssemblyVar = ctx.get_decl("modulus").into();
    let mod_inv: AssemblyVar = ctx.get_decl("mod_inv").into();

    // Shared by the `comment` closure and every instruction macro below.
    // `RefCell` lets all of them append; a closure capturing `&mut Vec` would
    // hold an exclusive borrow for as long as it is alive.
    let asm_instructions = RefCell::new(Vec::new());

    // Emits a `// ...` line into the listing (kept in the generated asm for
    // readability of the expansion).
    let comment = |comment: &str| {
        asm_instructions.borrow_mut().push(format!("// {comment}"));
    };

    // One macro per mnemonic; each pushes a single formatted instruction line.

    // `mulxq src, lo, hi`: hi:lo = rdx * src (on x86, MULX leaves the flags untouched).
    macro_rules! mulxq {
        ($a: expr, $b: expr, $c: expr) => {
            asm_instructions
                .borrow_mut()
                .push(format!("mulxq {}, {}, {}", &$a, &$b, &$c));
        };
    }

    // `adcxq src, dst`: dst += src + CF (on x86, only CF is updated).
    macro_rules! adcxq {
        ($a: expr, $b: expr) => {
            asm_instructions
                .borrow_mut()
                .push(format!("adcxq {}, {}", &$a, &$b));
        };
    }

    // `adoxq src, dst`: dst += src + OF (on x86, only OF is updated), giving a
    // second carry chain that runs in parallel with the `adcx` one.
    macro_rules! adoxq {
        ($a: expr, $b: expr) => {
            asm_instructions
                .borrow_mut()
                .push(format!("adoxq {}, {}", &$a, &$b));
        };
    }

    // `movq src, dst`.
    macro_rules! movq {
        ($a: expr, $b: expr) => {{
            asm_instructions
                .borrow_mut()
                .push(format!("movq {}, {}", &$a, &$b));
        }};
    }

    // `xorq src, dst`; used once below to clear `rcx` and the flags.
    macro_rules! xorq {
        ($a: expr, $b: expr) => {
            asm_instructions
                .borrow_mut()
                .push(format!("xorq {}, {}", &$a, &$b))
        };
    }

    // Zeroes a register with `mov` rather than `xor`. On x86 `xor` also resets
    // CF/OF, which would presumably break the `adcx`/`adox` carry chains that
    // are still in flight where this is used.
    macro_rules! movq_zero {
        ($a: expr) => {
            asm_instructions
                .borrow_mut()
                .push(format!("movq $0, {}", &$a))
        };
    }

    // First row (i == 0). Leaves a[0] * b as a (limbs + 1)-word value in
    // r[0..limbs], with the top word in `rcx`. Low halves are chained through
    // CF; the high half of b[j] is written straight into r[j + 1].
    macro_rules! mul_1 {
        ($a:expr, $b:ident, $limbs:expr) => {
            comment("Mul 1 start");
            movq!($a, rdx);
            mulxq!($b[0], r[0], r[1]);
            for j in 1..$limbs - 1 {
                mulxq!($b[j], rax, r[((j + 1) % $limbs)]);
                adcxq!(rax, r[j]);
            }
            mulxq!($b[$limbs - 1], rax, rcx);
            movq_zero!(rsi);
            adcxq!(rax, r[$limbs - 1]);
            // Fold the final CF into the top word.
            adcxq!(rsi, rcx);
            comment("Mul 1 end")
        };
    }

    // Row i > 0. Adds a[i] * b into the accumulator, which is rotated by `i`
    // (word j lives in r[(i + j) % limbs]). Low halves are added through CF
    // (`adcx`) and high halves through OF (`adox`). The new top word ends up in
    // `rcx`, with both pending carries folded in.
    macro_rules! mul_add_1 {
        ($a:ident, $b:ident, $i:ident, $limbs:expr) => {
            comment(&format!("mul_add_1 start for iteration {}", $i));
            movq!($a[$i], rdx);
            for j in 0..$limbs - 1 {
                mulxq!($b[j], rax, rsi);
                adcxq!(rax, r[(j + $i) % $limbs]);
                adoxq!(rsi, r[(j + $i + 1) % $limbs]);
            }
            mulxq!($b[$limbs - 1], rax, rcx);
            movq_zero!(rsi);
            adcxq!(rax, r[($i + $limbs - 1) % $limbs]);
            adoxq!(rsi, rcx);
            adcxq!(rsi, rcx);
            comment(&format!("mul_add_1 end for iteration {}", $i));
        };
    }

    // Reduction step for row i. It sets k = low(r[i] * mod_inv) in `rdx` and
    // adds k * modulus into the rotated accumulator.
    // - The sum for the lowest word (r[i]) is computed into `rax` and
    //   discarded; only its carry is kept. NOTE(review): this is presumably
    //   because that word is zero by construction of `mod_inv`.
    // - The freed slot r[i] is reused as the new top word. It holds the high
    //   half of k * m[limbs - 1], plus `rcx` and the pending carries.
    macro_rules! mul_add_shift_1 {
        ($a:ident, $mod_inv:ident, $i:ident, $limbs:expr) => {
            comment(&format!("mul_add_shift_1 start for iteration {}", $i));
            movq!($mod_inv, rdx);
            // rdx = low(r[i] * mod_inv); the high half in `rax` is not used.
            mulxq!(r[$i], rdx, rax);
            mulxq!($a[0], rax, rsi);
            // rax = low(k * m[0]) + r[i] + CF; kept only for the carry it produces.
            adcxq!(r[$i % $limbs], rax);
            adoxq!(rsi, r[($i + 1) % $limbs]);
            for j in 1..$limbs - 1 {
                mulxq!($a[j], rax, rsi);
                adcxq!(rax, r[(j + $i) % $limbs]);
                adoxq!(rsi, r[(j + $i + 1) % $limbs]);
            }
            // High half of the last product goes straight into the retired slot r[i].
            mulxq!($a[$limbs - 1], rax, r[$i % $limbs]);
            movq_zero!(rsi);
            adcxq!(rax, r[($i + $limbs - 1) % $limbs]);
            // New top word: add the row's carry-out word (`rcx`) + OF, then CF.
            adoxq!(rcx, r[$i % $limbs]);
            adcxq!(rsi, r[$i % $limbs]);
            comment(&format!("mul_add_shift_1 end for iteration {}", $i));
        };
    }
    {
        // Per-limb memory operands for each array.
        let a1 = a.memory_accesses(limbs);
        let b1 = b.memory_accesses(limbs);
        let m1 = modulus.memory_accesses(limbs);

        // Zero `rcx`. On x86 `xor` also clears CF and OF, so the first carry
        // chains start clean (`rcx` itself is written by `mulx` before it is read).
        xorq!(rcx, rcx);
        for i in 0..limbs {
            if i == 0 {
                mul_1!(a1[0], b1, limbs);
            } else {
                mul_add_1!(a1, b1, i, limbs);
            }
            mul_add_shift_1!(m1, mod_inv, i, limbs);
        }

        // After `limbs` reduction steps the rotation offset is back at
        // `limbs % limbs == 0`, so r[0..limbs] holds the result in order.
        comment("Moving results into `a`");
        for i in 0..limbs {
            movq!(r[i], a1[i]);
        }
    }
    asm_instructions.into_inner()
}
268
269fn generate_impl(num_limbs: usize, is_mul: bool) -> String {
270 let mut ctx = Context::new();
271 ctx.add_declaration("a", "a");
272 if is_mul {
273 ctx.add_declaration("b", "b");
274 }
275 ctx.add_declaration("modulus", "&Self::MODULUS.0");
276 ctx.add_declaration("mod_inv", "Self::INV");
277
278 let asm_instructions = construct_asm_mul(&ctx, num_limbs);
279
280 ctx.add_asm(&asm_instructions);
281 ctx.add_clobbers(
282 [Context::RAX, Context::RCX, Context::RSI, Context::RDX]
283 .iter()
284 .copied(),
285 );
286 ctx.add_clobbers(Context::R.iter().take(std::cmp::min(num_limbs, 8)).copied());
287 ctx.build()
288}
289
#[cfg(test)]
mod tests {
    /// Expands the generator for limb counts 2 through 6, in both
    /// multiplication and squaring mode, and checks that every expansion yields
    /// a non-empty block.
    ///
    /// The 4-limb multiplication is printed for manual inspection; run with
    /// `cargo test -- --nocapture` to see it.
    #[test]
    fn expand_muls() {
        for num_limbs in 2..=6 {
            for is_mul in [true, false] {
                let impl_block = super::generate_impl(num_limbs, is_mul);
                assert!(
                    !impl_block.is_empty(),
                    "empty expansion for num_limbs = {num_limbs}, is_mul = {is_mul}"
                );
            }
        }
        println!("{}", super::generate_impl(4, true));
    }
}