ark_ff_asm/context/
mod.rs

1mod data_structures;
2pub use data_structures::*;
3
4#[derive(Clone)]
5pub struct Context<'a> {
6    assembly_instructions: Vec<String>,
7    declarations: Vec<Declaration<'a>>,
8    used_registers: Vec<Register<'a>>,
9}
10
11impl<'a> Context<'a> {
12    pub const RAX: Register<'static> = Register("rax");
13    pub const RSI: Register<'static> = Register("rsi");
14    pub const RCX: Register<'static> = Register("rcx");
15    pub const RDX: Register<'static> = Register("rdx");
16
17    pub const R: [Register<'static>; 8] = [
18        Register("r8"),
19        Register("r9"),
20        Register("r10"),
21        Register("r11"),
22        Register("r12"),
23        Register("r13"),
24        Register("r14"),
25        Register("r15"),
26    ];
27
28    pub fn new() -> Self {
29        Self {
30            assembly_instructions: Vec::new(),
31            declarations: Vec::new(),
32            used_registers: Vec::new(),
33        }
34    }
35
36    fn find(&self, name: &str) -> Option<&Declaration<'_>> {
37        self.declarations.iter().find(|item| item.name == name)
38    }
39
40    fn append(&mut self, other: &str) {
41        self.assembly_instructions.push(format!("\"{}\",", other));
42    }
43
44    fn instructions_to_string(&self) -> String {
45        self.assembly_instructions.join("\n")
46    }
47
48    fn get_decl_name(&self, name: &str) -> Option<&Declaration<'_>> {
49        self.find(name)
50    }
51
52    pub fn get_decl(&self, name: &str) -> Declaration<'_> {
53        *self.get_decl_name(name).unwrap()
54    }
55
56    pub fn get_decl_with_fallback(&self, name: &str, fallback_name: &str) -> Declaration<'_> {
57        self.get_decl_name(name)
58            .copied()
59            .unwrap_or_else(|| self.get_decl(fallback_name))
60    }
61
62    pub fn add_declaration(&mut self, name: &'a str, expr: &'a str) {
63        let declaration = Declaration { name, expr };
64        self.declarations.push(declaration);
65    }
66
67    pub fn add_buffer(&mut self, extra_reg: usize) {
68        self.append(&format!(
69            "let mut spill_buffer = core::mem::MaybeUninit::<[u64; {}]>::uninit();",
70            extra_reg
71        ));
72    }
73
74    pub fn add_asm(&mut self, asm_instructions: &[String]) {
75        for instruction in asm_instructions {
76            self.append(instruction)
77        }
78    }
79
80    pub fn add_clobbers(&mut self, clobbers: impl Iterator<Item = Register<'a>>) {
81        for clobber in clobbers {
82            self.add_clobber(clobber)
83        }
84    }
85
86    pub fn add_clobber(&mut self, clobber: Register<'a>) {
87        self.used_registers.push(clobber);
88    }
89
90    pub fn build(self) -> String {
91        let declarations: String = self
92            .declarations
93            .iter()
94            .map(ToString::to_string)
95            .collect::<Vec<String>>()
96            .join("\n");
97        let clobbers = self
98            .used_registers
99            .iter()
100            .map(|l| format!("out({}) _,", l))
101            .collect::<Vec<String>>()
102            .join("\n");
103        let options = "options(att_syntax)".to_string();
104        let assembly = self.instructions_to_string();
105        [
106            "unsafe {".to_string(),
107            "ark_std::arch::asm!(".to_string(),
108            assembly,
109            declarations,
110            clobbers,
111            options,
112            ")".to_string(),
113            "}".to_string(),
114        ]
115        .join("\n")
116    }
117}