ark_ff_asm/
lib.rs

1#![warn(
2    unused,
3    future_incompatible,
4    nonstandard_style,
5    rust_2018_idioms,
6    rust_2021_compatibility
7)]
8#![forbid(unsafe_code)]
9#![recursion_limit = "128"]
10
11use proc_macro::TokenStream;
12use syn::{
13    parse::{Parse, ParseStream},
14    Expr,
15};
16
17mod context;
18use context::{AssemblyVar, Context};
19
20use std::cell::RefCell;
21
22const MAX_REGS: usize = 6;
23
24struct AsmMulInput {
25    num_limbs: Box<Expr>,
26    a: Expr,
27    b: Expr,
28}
29
30impl Parse for AsmMulInput {
31    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
32        let input = input
33            .parse_terminated(Expr::parse, syn::token::Comma)?
34            .into_iter()
35            .collect::<Vec<_>>();
36        let num_limbs = input[0].clone();
37        let a = input[1].clone();
38        let b = input[2].clone();
39
40        let num_limbs = if let Expr::Group(syn::ExprGroup { expr, .. }) = num_limbs {
41            expr
42        } else {
43            Box::new(num_limbs)
44        };
45        let output = Self { num_limbs, a, b };
46        Ok(output)
47    }
48}
49
50#[proc_macro]
51pub fn x86_64_asm_mul(input: TokenStream) -> TokenStream {
52    let AsmMulInput { num_limbs, a, b } = syn::parse_macro_input!(input);
53    let num_limbs = if let Expr::Lit(syn::ExprLit {
54        lit: syn::Lit::Int(ref lit_int),
55        ..
56    }) = &*num_limbs
57    {
58        lit_int.base10_parse::<usize>().unwrap()
59    } else {
60        panic!("The number of limbs must be a literal");
61    };
62    if num_limbs <= 6 && num_limbs <= 3 * MAX_REGS {
63        let impl_block = generate_impl(num_limbs, true);
64
65        let inner_ts: Expr = syn::parse_str(&impl_block).unwrap();
66        let ts = quote::quote! {
67            let a = &mut #a;
68            let b = &#b;
69            #inner_ts
70        };
71        ts.into()
72    } else {
73        TokenStream::new()
74    }
75}
76
77struct AsmSquareInput {
78    num_limbs: Box<Expr>,
79    a: Expr,
80}
81
82impl Parse for AsmSquareInput {
83    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
84        let input = input
85            .parse_terminated(Expr::parse, syn::token::Comma)?
86            .into_iter()
87            .collect::<Vec<_>>();
88        let num_limbs = input[0].clone();
89        let a = input[1].clone();
90
91        let num_limbs = if let Expr::Group(syn::ExprGroup { expr, .. }) = num_limbs {
92            expr
93        } else {
94            Box::new(num_limbs)
95        };
96        let output = Self { num_limbs, a };
97        Ok(output)
98    }
99}
100
101#[proc_macro]
102pub fn x86_64_asm_square(input: TokenStream) -> TokenStream {
103    let AsmSquareInput { num_limbs, a } = syn::parse_macro_input!(input);
104    let num_limbs = if let Expr::Lit(syn::ExprLit {
105        lit: syn::Lit::Int(ref lit_int),
106        ..
107    }) = &*num_limbs
108    {
109        lit_int.base10_parse::<usize>().unwrap()
110    } else {
111        panic!("The number of limbs must be a literal");
112    };
113    if num_limbs <= 6 && num_limbs <= 3 * MAX_REGS {
114        let impl_block = generate_impl(num_limbs, false);
115
116        let inner_ts: Expr = syn::parse_str(&impl_block).unwrap();
117        let ts = quote::quote! {
118            let a = &mut #a;
119            #inner_ts
120        };
121        ts.into()
122    } else {
123        TokenStream::new()
124    }
125}
126
127fn construct_asm_mul(ctx: &Context<'_>, limbs: usize) -> Vec<String> {
128    let r: Vec<AssemblyVar> = Context::R.iter().map(|r| (*r).into()).collect();
129    let rax: AssemblyVar = Context::RAX.into();
130    let rcx: AssemblyVar = Context::RCX.into();
131    let rdx: AssemblyVar = Context::RDX.into();
132    let rsi: AssemblyVar = Context::RSI.into();
133    let a: AssemblyVar = ctx.get_decl("a").into();
134    let b: AssemblyVar = ctx.get_decl_with_fallback("b", "a").into(); // "b" is not available during squaring.
135    let modulus: AssemblyVar = ctx.get_decl("modulus").into();
136    let mod_inv: AssemblyVar = ctx.get_decl("mod_inv").into();
137
138    let asm_instructions = RefCell::new(Vec::new());
139
140    let comment = |comment: &str| {
141        asm_instructions
142            .borrow_mut()
143            .push(format!("// {}", comment));
144    };
145
146    macro_rules! mulxq {
147        ($a: expr, $b: expr, $c: expr) => {
148            asm_instructions
149                .borrow_mut()
150                .push(format!("mulxq {}, {}, {}", &$a, &$b, &$c));
151        };
152    }
153
154    macro_rules! adcxq {
155        ($a: expr, $b: expr) => {
156            asm_instructions
157                .borrow_mut()
158                .push(format!("adcxq {}, {}", &$a, &$b));
159        };
160    }
161
162    macro_rules! adoxq {
163        ($a: expr, $b: expr) => {
164            asm_instructions
165                .borrow_mut()
166                .push(format!("adoxq {}, {}", &$a, &$b));
167        };
168    }
169
170    macro_rules! movq {
171        ($a: expr, $b: expr) => {{
172            asm_instructions
173                .borrow_mut()
174                .push(format!("movq {}, {}", &$a, &$b));
175        }};
176    }
177
178    macro_rules! xorq {
179        ($a: expr, $b: expr) => {
180            asm_instructions
181                .borrow_mut()
182                .push(format!("xorq {}, {}", &$a, &$b))
183        };
184    }
185
186    macro_rules! movq_zero {
187        ($a: expr) => {
188            asm_instructions
189                .borrow_mut()
190                .push(format!("movq $0, {}", &$a))
191        };
192    }
193
194    macro_rules! mul_1 {
195        ($a:expr, $b:ident, $limbs:expr) => {
196            comment("Mul 1 start");
197            movq!($a, rdx);
198            mulxq!($b[0], r[0], r[1]);
199            for j in 1..$limbs - 1 {
200                mulxq!($b[j], rax, r[((j + 1) % $limbs)]);
201                adcxq!(rax, r[j]);
202            }
203            mulxq!($b[$limbs - 1], rax, rcx);
204            movq_zero!(rsi);
205            adcxq!(rax, r[$limbs - 1]);
206            adcxq!(rsi, rcx);
207            comment("Mul 1 end")
208        };
209    }
210
211    macro_rules! mul_add_1 {
212        ($a:ident, $b:ident, $i:ident, $limbs:expr) => {
213            comment(&format!("mul_add_1 start for iteration {}", $i));
214            movq!($a[$i], rdx);
215            for j in 0..$limbs - 1 {
216                mulxq!($b[j], rax, rsi);
217                adcxq!(rax, r[(j + $i) % $limbs]);
218                adoxq!(rsi, r[(j + $i + 1) % $limbs]);
219            }
220            mulxq!($b[$limbs - 1], rax, rcx);
221            movq_zero!(rsi);
222            adcxq!(rax, r[($i + $limbs - 1) % $limbs]);
223            adoxq!(rsi, rcx);
224            adcxq!(rsi, rcx);
225            comment(&format!("mul_add_1 end for iteration {}", $i));
226        };
227    }
228
229    macro_rules! mul_add_shift_1 {
230        ($a:ident, $mod_inv:ident, $i:ident, $limbs:expr) => {
231            comment(&format!("mul_add_shift_1 start for iteration {}", $i));
232            movq!($mod_inv, rdx);
233            mulxq!(r[$i], rdx, rax);
234            mulxq!($a[0], rax, rsi);
235            adcxq!(r[$i % $limbs], rax);
236            adoxq!(rsi, r[($i + 1) % $limbs]);
237            for j in 1..$limbs - 1 {
238                mulxq!($a[j], rax, rsi);
239                adcxq!(rax, r[(j + $i) % $limbs]);
240                adoxq!(rsi, r[(j + $i + 1) % $limbs]);
241            }
242            mulxq!($a[$limbs - 1], rax, r[$i % $limbs]);
243            movq_zero!(rsi);
244            adcxq!(rax, r[($i + $limbs - 1) % $limbs]);
245            adoxq!(rcx, r[$i % $limbs]);
246            adcxq!(rsi, r[$i % $limbs]);
247            comment(&format!("mul_add_shift_1 end for iteration {}", $i));
248        };
249    }
250    {
251        let a1 = a.memory_accesses(limbs);
252        let b1 = b.memory_accesses(limbs);
253        let m1 = modulus.memory_accesses(limbs);
254
255        xorq!(rcx, rcx);
256        for i in 0..limbs {
257            if i == 0 {
258                mul_1!(a1[0], b1, limbs);
259            } else {
260                mul_add_1!(a1, b1, i, limbs);
261            }
262            mul_add_shift_1!(m1, mod_inv, i, limbs);
263        }
264
265        comment("Moving results into `a`");
266        for i in 0..limbs {
267            movq!(r[i], a1[i]);
268        }
269    }
270    asm_instructions.into_inner()
271}
272
273fn generate_impl(num_limbs: usize, is_mul: bool) -> String {
274    let mut ctx = Context::new();
275    ctx.add_declaration("a", "a");
276    if is_mul {
277        ctx.add_declaration("b", "b");
278    }
279    ctx.add_declaration("modulus", "&Self::MODULUS.0");
280    ctx.add_declaration("mod_inv", "Self::INV");
281
282    if num_limbs > MAX_REGS {
283        ctx.add_buffer(2 * num_limbs);
284        ctx.add_declaration("buf", "&mut spill_buffer");
285    }
286
287    let asm_instructions = construct_asm_mul(&ctx, num_limbs);
288
289    ctx.add_asm(&asm_instructions);
290    ctx.add_clobbers(
291        [Context::RAX, Context::RCX, Context::RSI, Context::RDX]
292            .iter()
293            .copied(),
294    );
295    ctx.add_clobbers(Context::R.iter().take(std::cmp::min(num_limbs, 8)).copied());
296    ctx.build()
297}
298
mod tests {
    /// Prints the generated implementation for a four-limb multiplication.
    #[test]
    fn expand_muls() {
        let expansion = super::generate_impl(4, true);
        println!("{}", expansion);
    }
}