1#![warn(
2 unused,
3 future_incompatible,
4 nonstandard_style,
5 rust_2018_idioms,
6 rust_2021_compatibility
7)]
8#![forbid(unsafe_code)]
9#![recursion_limit = "128"]
10
11use proc_macro::TokenStream;
12use syn::{
13 parse::{Parse, ParseStream},
14 Expr,
15};
16
17mod context;
18use context::{AssemblyVar, Context};
19
20use std::cell::RefCell;
21
/// Register budget compared against the requested limb count: `generate_impl` adds a spill
/// buffer once the limb count exceeds it, and the macro entry points additionally require
/// `num_limbs <= 3 * MAX_REGS`.
const MAX_REGS: usize = 6;
23
24struct AsmMulInput {
25 num_limbs: Box<Expr>,
26 a: Expr,
27 b: Expr,
28}
29
30impl Parse for AsmMulInput {
31 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
32 let input = input
33 .parse_terminated(Expr::parse, syn::token::Comma)?
34 .into_iter()
35 .collect::<Vec<_>>();
36 let num_limbs = input[0].clone();
37 let a = input[1].clone();
38 let b = input[2].clone();
39
40 let num_limbs = if let Expr::Group(syn::ExprGroup { expr, .. }) = num_limbs {
41 expr
42 } else {
43 Box::new(num_limbs)
44 };
45 let output = Self { num_limbs, a, b };
46 Ok(output)
47 }
48}
49
50#[proc_macro]
51pub fn x86_64_asm_mul(input: TokenStream) -> TokenStream {
52 let AsmMulInput { num_limbs, a, b } = syn::parse_macro_input!(input);
53 let num_limbs = if let Expr::Lit(syn::ExprLit {
54 lit: syn::Lit::Int(ref lit_int),
55 ..
56 }) = &*num_limbs
57 {
58 lit_int.base10_parse::<usize>().unwrap()
59 } else {
60 panic!("The number of limbs must be a literal");
61 };
62 if num_limbs <= 6 && num_limbs <= 3 * MAX_REGS {
63 let impl_block = generate_impl(num_limbs, true);
64
65 let inner_ts: Expr = syn::parse_str(&impl_block).unwrap();
66 let ts = quote::quote! {
67 let a = &mut #a;
68 let b = &#b;
69 #inner_ts
70 };
71 ts.into()
72 } else {
73 TokenStream::new()
74 }
75}
76
77struct AsmSquareInput {
78 num_limbs: Box<Expr>,
79 a: Expr,
80}
81
82impl Parse for AsmSquareInput {
83 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
84 let input = input
85 .parse_terminated(Expr::parse, syn::token::Comma)?
86 .into_iter()
87 .collect::<Vec<_>>();
88 let num_limbs = input[0].clone();
89 let a = input[1].clone();
90
91 let num_limbs = if let Expr::Group(syn::ExprGroup { expr, .. }) = num_limbs {
92 expr
93 } else {
94 Box::new(num_limbs)
95 };
96 let output = Self { num_limbs, a };
97 Ok(output)
98 }
99}
100
101#[proc_macro]
102pub fn x86_64_asm_square(input: TokenStream) -> TokenStream {
103 let AsmSquareInput { num_limbs, a } = syn::parse_macro_input!(input);
104 let num_limbs = if let Expr::Lit(syn::ExprLit {
105 lit: syn::Lit::Int(ref lit_int),
106 ..
107 }) = &*num_limbs
108 {
109 lit_int.base10_parse::<usize>().unwrap()
110 } else {
111 panic!("The number of limbs must be a literal");
112 };
113 if num_limbs <= 6 && num_limbs <= 3 * MAX_REGS {
114 let impl_block = generate_impl(num_limbs, false);
115
116 let inner_ts: Expr = syn::parse_str(&impl_block).unwrap();
117 let ts = quote::quote! {
118 let a = &mut #a;
119 #inner_ts
120 };
121 ts.into()
122 } else {
123 TokenStream::new()
124 }
125}
126
127fn construct_asm_mul(ctx: &Context<'_>, limbs: usize) -> Vec<String> {
128 let r: Vec<AssemblyVar> = Context::R.iter().map(|r| (*r).into()).collect();
129 let rax: AssemblyVar = Context::RAX.into();
130 let rcx: AssemblyVar = Context::RCX.into();
131 let rdx: AssemblyVar = Context::RDX.into();
132 let rsi: AssemblyVar = Context::RSI.into();
133 let a: AssemblyVar = ctx.get_decl("a").into();
134 let b: AssemblyVar = ctx.get_decl_with_fallback("b", "a").into(); let modulus: AssemblyVar = ctx.get_decl("modulus").into();
136 let mod_inv: AssemblyVar = ctx.get_decl("mod_inv").into();
137
138 let asm_instructions = RefCell::new(Vec::new());
139
140 let comment = |comment: &str| {
141 asm_instructions
142 .borrow_mut()
143 .push(format!("// {}", comment));
144 };
145
146 macro_rules! mulxq {
147 ($a: expr, $b: expr, $c: expr) => {
148 asm_instructions
149 .borrow_mut()
150 .push(format!("mulxq {}, {}, {}", &$a, &$b, &$c));
151 };
152 }
153
154 macro_rules! adcxq {
155 ($a: expr, $b: expr) => {
156 asm_instructions
157 .borrow_mut()
158 .push(format!("adcxq {}, {}", &$a, &$b));
159 };
160 }
161
162 macro_rules! adoxq {
163 ($a: expr, $b: expr) => {
164 asm_instructions
165 .borrow_mut()
166 .push(format!("adoxq {}, {}", &$a, &$b));
167 };
168 }
169
170 macro_rules! movq {
171 ($a: expr, $b: expr) => {{
172 asm_instructions
173 .borrow_mut()
174 .push(format!("movq {}, {}", &$a, &$b));
175 }};
176 }
177
178 macro_rules! xorq {
179 ($a: expr, $b: expr) => {
180 asm_instructions
181 .borrow_mut()
182 .push(format!("xorq {}, {}", &$a, &$b))
183 };
184 }
185
186 macro_rules! movq_zero {
187 ($a: expr) => {
188 asm_instructions
189 .borrow_mut()
190 .push(format!("movq $0, {}", &$a))
191 };
192 }
193
194 macro_rules! mul_1 {
195 ($a:expr, $b:ident, $limbs:expr) => {
196 comment("Mul 1 start");
197 movq!($a, rdx);
198 mulxq!($b[0], r[0], r[1]);
199 for j in 1..$limbs - 1 {
200 mulxq!($b[j], rax, r[((j + 1) % $limbs)]);
201 adcxq!(rax, r[j]);
202 }
203 mulxq!($b[$limbs - 1], rax, rcx);
204 movq_zero!(rsi);
205 adcxq!(rax, r[$limbs - 1]);
206 adcxq!(rsi, rcx);
207 comment("Mul 1 end")
208 };
209 }
210
211 macro_rules! mul_add_1 {
212 ($a:ident, $b:ident, $i:ident, $limbs:expr) => {
213 comment(&format!("mul_add_1 start for iteration {}", $i));
214 movq!($a[$i], rdx);
215 for j in 0..$limbs - 1 {
216 mulxq!($b[j], rax, rsi);
217 adcxq!(rax, r[(j + $i) % $limbs]);
218 adoxq!(rsi, r[(j + $i + 1) % $limbs]);
219 }
220 mulxq!($b[$limbs - 1], rax, rcx);
221 movq_zero!(rsi);
222 adcxq!(rax, r[($i + $limbs - 1) % $limbs]);
223 adoxq!(rsi, rcx);
224 adcxq!(rsi, rcx);
225 comment(&format!("mul_add_1 end for iteration {}", $i));
226 };
227 }
228
229 macro_rules! mul_add_shift_1 {
230 ($a:ident, $mod_inv:ident, $i:ident, $limbs:expr) => {
231 comment(&format!("mul_add_shift_1 start for iteration {}", $i));
232 movq!($mod_inv, rdx);
233 mulxq!(r[$i], rdx, rax);
234 mulxq!($a[0], rax, rsi);
235 adcxq!(r[$i % $limbs], rax);
236 adoxq!(rsi, r[($i + 1) % $limbs]);
237 for j in 1..$limbs - 1 {
238 mulxq!($a[j], rax, rsi);
239 adcxq!(rax, r[(j + $i) % $limbs]);
240 adoxq!(rsi, r[(j + $i + 1) % $limbs]);
241 }
242 mulxq!($a[$limbs - 1], rax, r[$i % $limbs]);
243 movq_zero!(rsi);
244 adcxq!(rax, r[($i + $limbs - 1) % $limbs]);
245 adoxq!(rcx, r[$i % $limbs]);
246 adcxq!(rsi, r[$i % $limbs]);
247 comment(&format!("mul_add_shift_1 end for iteration {}", $i));
248 };
249 }
250 {
251 let a1 = a.memory_accesses(limbs);
252 let b1 = b.memory_accesses(limbs);
253 let m1 = modulus.memory_accesses(limbs);
254
255 xorq!(rcx, rcx);
256 for i in 0..limbs {
257 if i == 0 {
258 mul_1!(a1[0], b1, limbs);
259 } else {
260 mul_add_1!(a1, b1, i, limbs);
261 }
262 mul_add_shift_1!(m1, mod_inv, i, limbs);
263 }
264
265 comment("Moving results into `a`");
266 for i in 0..limbs {
267 movq!(r[i], a1[i]);
268 }
269 }
270 asm_instructions.into_inner()
271}
272
273fn generate_impl(num_limbs: usize, is_mul: bool) -> String {
274 let mut ctx = Context::new();
275 ctx.add_declaration("a", "a");
276 if is_mul {
277 ctx.add_declaration("b", "b");
278 }
279 ctx.add_declaration("modulus", "&Self::MODULUS.0");
280 ctx.add_declaration("mod_inv", "Self::INV");
281
282 if num_limbs > MAX_REGS {
283 ctx.add_buffer(2 * num_limbs);
284 ctx.add_declaration("buf", "&mut spill_buffer");
285 }
286
287 let asm_instructions = construct_asm_mul(&ctx, num_limbs);
288
289 ctx.add_asm(&asm_instructions);
290 ctx.add_clobbers(
291 [Context::RAX, Context::RCX, Context::RSI, Context::RDX]
292 .iter()
293 .copied(),
294 );
295 ctx.add_clobbers(Context::R.iter().take(std::cmp::min(num_limbs, 8)).copied());
296 ctx.build()
297}
298
// Compiled only for test builds so that regular builds of the macro crate carry no test module.
#[cfg(test)]
mod tests {
    /// Smoke test: generating the 4-limb multiplication must not panic. The generated source
    /// is printed for manual inspection (visible with `cargo test -- --nocapture`).
    #[test]
    fn expand_muls() {
        let impl_block = super::generate_impl(4, true);
        println!("{}", impl_block);
    }
}