Skip to main content

ark_ff_asm/
lib.rs

1#![warn(
2    unused,
3    future_incompatible,
4    nonstandard_style,
5    rust_2018_idioms,
6    rust_2021_compatibility
7)]
8#![forbid(unsafe_code)]
9#![recursion_limit = "128"]
10
11use proc_macro::TokenStream;
12use syn::{
13    parse::{Parse, ParseStream},
14    Expr,
15};
16
17mod context;
18use context::{AssemblyVar, Context};
19
20use std::cell::RefCell;
21
22struct AsmMulInput {
23    num_limbs: Box<Expr>,
24    a: Expr,
25    b: Expr,
26}
27
28impl Parse for AsmMulInput {
29    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
30        let input = input
31            .parse_terminated(Expr::parse, syn::token::Comma)?
32            .into_iter()
33            .collect::<Vec<_>>();
34        let num_limbs = input[0].clone();
35        let a = input[1].clone();
36        let b = input[2].clone();
37
38        let num_limbs = if let Expr::Group(syn::ExprGroup { expr, .. }) = num_limbs {
39            expr
40        } else {
41            Box::new(num_limbs)
42        };
43        let output = Self { num_limbs, a, b };
44        Ok(output)
45    }
46}
47
48#[proc_macro]
49pub fn x86_64_asm_mul(input: TokenStream) -> TokenStream {
50    let AsmMulInput { num_limbs, a, b } = syn::parse_macro_input!(input);
51    let num_limbs = if let Expr::Lit(syn::ExprLit {
52        lit: syn::Lit::Int(ref lit_int),
53        ..
54    }) = &*num_limbs
55    {
56        lit_int.base10_parse::<usize>().unwrap()
57    } else {
58        panic!("The number of limbs must be a literal");
59    };
60    if num_limbs <= 6 {
61        let impl_block = generate_impl(num_limbs, true);
62
63        let inner_ts: Expr = syn::parse_str(&impl_block).unwrap();
64        let ts = quote::quote! {
65            let a = &mut #a;
66            let b = &#b;
67            #inner_ts
68        };
69        ts.into()
70    } else {
71        TokenStream::new()
72    }
73}
74
75struct AsmSquareInput {
76    num_limbs: Box<Expr>,
77    a: Expr,
78}
79
80impl Parse for AsmSquareInput {
81    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
82        let input = input
83            .parse_terminated(Expr::parse, syn::token::Comma)?
84            .into_iter()
85            .collect::<Vec<_>>();
86        let num_limbs = input[0].clone();
87        let a = input[1].clone();
88
89        let num_limbs = if let Expr::Group(syn::ExprGroup { expr, .. }) = num_limbs {
90            expr
91        } else {
92            Box::new(num_limbs)
93        };
94        let output = Self { num_limbs, a };
95        Ok(output)
96    }
97}
98
99#[proc_macro]
100pub fn x86_64_asm_square(input: TokenStream) -> TokenStream {
101    let AsmSquareInput { num_limbs, a } = syn::parse_macro_input!(input);
102    let num_limbs = if let Expr::Lit(syn::ExprLit {
103        lit: syn::Lit::Int(ref lit_int),
104        ..
105    }) = &*num_limbs
106    {
107        lit_int.base10_parse::<usize>().unwrap()
108    } else {
109        panic!("The number of limbs must be a literal");
110    };
111    if num_limbs <= 6 {
112        let impl_block = generate_impl(num_limbs, false);
113
114        let inner_ts: Expr = syn::parse_str(&impl_block).unwrap();
115        let ts = quote::quote! {
116            let a = &mut #a;
117            #inner_ts
118        };
119        ts.into()
120    } else {
121        TokenStream::new()
122    }
123}
124
125fn construct_asm_mul(ctx: &Context<'_>, limbs: usize) -> Vec<String> {
126    let r: Vec<AssemblyVar> = Context::R.iter().map(|r| (*r).into()).collect();
127    let rax: AssemblyVar = Context::RAX.into();
128    let rcx: AssemblyVar = Context::RCX.into();
129    let rdx: AssemblyVar = Context::RDX.into();
130    let rsi: AssemblyVar = Context::RSI.into();
131    let a: AssemblyVar = ctx.get_decl("a").into();
132    let b: AssemblyVar = ctx.get_decl_with_fallback("b", "a").into(); // "b" is not available during squaring.
133    let modulus: AssemblyVar = ctx.get_decl("modulus").into();
134    let mod_inv: AssemblyVar = ctx.get_decl("mod_inv").into();
135
136    let asm_instructions = RefCell::new(Vec::new());
137
138    let comment = |comment: &str| {
139        asm_instructions.borrow_mut().push(format!("// {comment}"));
140    };
141
142    macro_rules! mulxq {
143        ($a: expr, $b: expr, $c: expr) => {
144            asm_instructions
145                .borrow_mut()
146                .push(format!("mulxq {}, {}, {}", &$a, &$b, &$c));
147        };
148    }
149
150    macro_rules! adcxq {
151        ($a: expr, $b: expr) => {
152            asm_instructions
153                .borrow_mut()
154                .push(format!("adcxq {}, {}", &$a, &$b));
155        };
156    }
157
158    macro_rules! adoxq {
159        ($a: expr, $b: expr) => {
160            asm_instructions
161                .borrow_mut()
162                .push(format!("adoxq {}, {}", &$a, &$b));
163        };
164    }
165
166    macro_rules! movq {
167        ($a: expr, $b: expr) => {{
168            asm_instructions
169                .borrow_mut()
170                .push(format!("movq {}, {}", &$a, &$b));
171        }};
172    }
173
174    macro_rules! xorq {
175        ($a: expr, $b: expr) => {
176            asm_instructions
177                .borrow_mut()
178                .push(format!("xorq {}, {}", &$a, &$b))
179        };
180    }
181
182    macro_rules! movq_zero {
183        ($a: expr) => {
184            asm_instructions
185                .borrow_mut()
186                .push(format!("movq $0, {}", &$a))
187        };
188    }
189
190    macro_rules! mul_1 {
191        ($a:expr, $b:ident, $limbs:expr) => {
192            comment("Mul 1 start");
193            movq!($a, rdx);
194            mulxq!($b[0], r[0], r[1]);
195            for j in 1..$limbs - 1 {
196                mulxq!($b[j], rax, r[((j + 1) % $limbs)]);
197                adcxq!(rax, r[j]);
198            }
199            mulxq!($b[$limbs - 1], rax, rcx);
200            movq_zero!(rsi);
201            adcxq!(rax, r[$limbs - 1]);
202            adcxq!(rsi, rcx);
203            comment("Mul 1 end")
204        };
205    }
206
207    macro_rules! mul_add_1 {
208        ($a:ident, $b:ident, $i:ident, $limbs:expr) => {
209            comment(&format!("mul_add_1 start for iteration {}", $i));
210            movq!($a[$i], rdx);
211            for j in 0..$limbs - 1 {
212                mulxq!($b[j], rax, rsi);
213                adcxq!(rax, r[(j + $i) % $limbs]);
214                adoxq!(rsi, r[(j + $i + 1) % $limbs]);
215            }
216            mulxq!($b[$limbs - 1], rax, rcx);
217            movq_zero!(rsi);
218            adcxq!(rax, r[($i + $limbs - 1) % $limbs]);
219            adoxq!(rsi, rcx);
220            adcxq!(rsi, rcx);
221            comment(&format!("mul_add_1 end for iteration {}", $i));
222        };
223    }
224
225    macro_rules! mul_add_shift_1 {
226        ($a:ident, $mod_inv:ident, $i:ident, $limbs:expr) => {
227            comment(&format!("mul_add_shift_1 start for iteration {}", $i));
228            movq!($mod_inv, rdx);
229            mulxq!(r[$i], rdx, rax);
230            mulxq!($a[0], rax, rsi);
231            adcxq!(r[$i % $limbs], rax);
232            adoxq!(rsi, r[($i + 1) % $limbs]);
233            for j in 1..$limbs - 1 {
234                mulxq!($a[j], rax, rsi);
235                adcxq!(rax, r[(j + $i) % $limbs]);
236                adoxq!(rsi, r[(j + $i + 1) % $limbs]);
237            }
238            mulxq!($a[$limbs - 1], rax, r[$i % $limbs]);
239            movq_zero!(rsi);
240            adcxq!(rax, r[($i + $limbs - 1) % $limbs]);
241            adoxq!(rcx, r[$i % $limbs]);
242            adcxq!(rsi, r[$i % $limbs]);
243            comment(&format!("mul_add_shift_1 end for iteration {}", $i));
244        };
245    }
246    {
247        let a1 = a.memory_accesses(limbs);
248        let b1 = b.memory_accesses(limbs);
249        let m1 = modulus.memory_accesses(limbs);
250
251        xorq!(rcx, rcx);
252        for i in 0..limbs {
253            if i == 0 {
254                mul_1!(a1[0], b1, limbs);
255            } else {
256                mul_add_1!(a1, b1, i, limbs);
257            }
258            mul_add_shift_1!(m1, mod_inv, i, limbs);
259        }
260
261        comment("Moving results into `a`");
262        for i in 0..limbs {
263            movq!(r[i], a1[i]);
264        }
265    }
266    asm_instructions.into_inner()
267}
268
269fn generate_impl(num_limbs: usize, is_mul: bool) -> String {
270    let mut ctx = Context::new();
271    ctx.add_declaration("a", "a");
272    if is_mul {
273        ctx.add_declaration("b", "b");
274    }
275    ctx.add_declaration("modulus", "&Self::MODULUS.0");
276    ctx.add_declaration("mod_inv", "Self::INV");
277
278    let asm_instructions = construct_asm_mul(&ctx, num_limbs);
279
280    ctx.add_asm(&asm_instructions);
281    ctx.add_clobbers(
282        [Context::RAX, Context::RCX, Context::RSI, Context::RDX]
283            .iter()
284            .copied(),
285    );
286    ctx.add_clobbers(Context::R.iter().take(std::cmp::min(num_limbs, 8)).copied());
287    ctx.build()
288}
289
// Compiled only for test builds so regular builds do not carry an empty module.
#[cfg(test)]
mod tests {
    /// Smoke test: generating the 4-limb multiplication must not panic and must
    /// yield some code; the output is printed for manual inspection.
    #[test]
    fn expand_muls() {
        let impl_block = super::generate_impl(4, true);
        assert!(
            !impl_block.is_empty(),
            "generate_impl returned an empty string"
        );
        println!("{impl_block}");
    }
}