ark_ff_macros/
unroll.rs

//! An attribute-like procedural macro for unrolling `for` loops over `usize`
//! ranges.
//!
//! This module implements the [`unroll_for_loops`](../attr.unroll_for_loops.html)
//! attribute-like macro, which can be applied to functions containing `for`
//! loops over ranges. The macro finds the loops it can unroll and rewrites them
//! at compile time so that each pass of the generated loop runs the body a
//! fixed number of times (the unroll factor), followed by a short tail that
//! handles any leftover iterations.
8//!
9//!
10//! ## Usage
11//!
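//! Add `#[unroll_for_loops(N)]` above the function whose `for` loops you would
//! like to unroll, where `N` is the unroll factor. A loop is a candidate when
//! its loop variable is a plain identifier and it iterates over a range with an
//! explicit end (`a..b`, `a..=b`, `..b` or `..=b`). The bounds do not have to
//! be literals, but they must be `usize`. The macro only sees loops that appear
//! as expression statements in the function body, in nested blocks, in
//! `if`/`else` branches, or in the bodies of other `for` loops. It can't see
//! inside more complex code (e.g. `for` loops within closures, `match` arms,
//! `let` initializers or `while` loops).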
16//!
17//!
18//! ## Example
19//!
20//! The following function computes a matrix-vector product and returns the
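//! result as an array. Both the outer loop over `col` and the nested loop over
//! `row` are unrolled when `#[unroll_for_loops(12)]` is applied. `mtx_vec_mul_2`
//! is the same function without the attribute; it is used to check that both
//! produce the same result.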
23//!
24//! ```rust
25//! use ark_ff_macros::unroll_for_loops;
26//!
27//! #[unroll_for_loops(12)]
28//! fn mtx_vec_mul(mtx: &[[f64; 5]; 5], vec: &[f64; 5]) -> [f64; 5] {
29//!     let mut out = [0.0; 5];
30//!     for col in 0..5 {
31//!         for row in 0..5 {
32//!             out[row] += mtx[col][row] * vec[col];
33//!         }
34//!     }
35//!     out
36//! }
37//!
38//! fn mtx_vec_mul_2(mtx: &[[f64; 5]; 5], vec: &[f64; 5]) -> [f64; 5] {
39//!     let mut out = [0.0; 5];
40//!     for col in 0..5 {
41//!         for row in 0..5 {
42//!             out[row] += mtx[col][row] * vec[col];
43//!         }
44//!     }
45//!     out
46//! }
47//! let a = [[1.0, 2.0, 3.0, 4.0, 5.0]; 5];
48//! let b = [7.9, 4.8, 3.8, 4.22, 5.2];
49//! assert_eq!(mtx_vec_mul(&a, &b), mtx_vec_mul_2(&a, &b));
50//! ```
51//!
52//! This code was adapted from the [`unroll`](https://crates.io/crates/unroll) crate.
53
54use std::borrow::Borrow;
55
56use syn::{
57    parse_quote, token::Brace, Block, Expr, ExprBlock, ExprForLoop, ExprIf, ExprLet, ExprRange,
58    Pat, PatIdent, RangeLimits, Stmt,
59};
60
61/// Routine to unroll for loops within a block
62pub(crate) fn unroll_in_block(block: &Block, unroll_by: usize) -> Block {
63    let &Block {
64        ref brace_token,
65        ref stmts,
66    } = block;
67    let mut new_stmts = Vec::new();
68    for stmt in stmts.iter() {
69        if let Stmt::Expr(expr, semi) = stmt {
70            new_stmts.push(Stmt::Expr(unroll(expr, unroll_by), *semi));
71        } else {
72            new_stmts.push((*stmt).clone());
73        }
74    }
75    Block {
76        brace_token: *brace_token,
77        stmts: new_stmts,
78    }
79}
80
81/// Routine to unroll a for loop statement, or return the statement unchanged if
82/// it's not a for loop.
83fn unroll(expr: &Expr, unroll_by: usize) -> Expr {
84    // impose a scope that we can break out of so we can return stmt without copying it.
85    if let Expr::ForLoop(for_loop) = expr {
86        let ExprForLoop {
87            ref attrs,
88            ref label,
89            ref pat,
90            expr: ref range,
91            ref body,
92            ..
93        } = *for_loop;
94
95        let new_body = unroll_in_block(body, unroll_by);
96
97        let forloop_with_body = |body| {
98            Expr::ForLoop(ExprForLoop {
99                body,
100                ..(*for_loop).clone()
101            })
102        };
103
104        if let Pat::Ident(PatIdent {
105            by_ref,
106            mutability,
107            ident,
108            subpat,
109            ..
110        }) = *pat.clone()
111        {
112            // Don't know how to deal with these so skip and return the original.
113            if !by_ref.is_none() || !mutability.is_none() || !subpat.is_none() {
114                return forloop_with_body(new_body);
115            }
116            let idx = ident; // got the index variable name
117
118            if let Expr::Range(ExprRange {
119                start, limits, end, ..
120            }) = range.borrow()
121            {
122                // Parse `start` in `start..end`.
123                let begin = match start {
124                    Some(e) => e.clone(),
125                    _ => Box::new(parse_quote!(0usize)),
126                };
127                let end = match end {
128                    Some(e) => e.clone(),
129                    _ => return forloop_with_body(new_body),
130                };
131                let end_is_closed = if let RangeLimits::Closed(_) = limits {
132                    1usize
133                } else {
134                    0
135                };
136                let end: Expr = parse_quote!(#end + #end_is_closed);
137
138                let preamble: Vec<Stmt> = parse_quote! {
139                    let total_iters: usize = (#end).checked_sub(#begin).unwrap_or(0);
140                    let num_loops = total_iters / #unroll_by;
141                    let remainder = total_iters % #unroll_by;
142                };
143                let mut block = Block {
144                    brace_token: Brace::default(),
145                    stmts: preamble,
146                };
147                let mut loop_expr: ExprForLoop = parse_quote! {
148                    for #idx in (0..num_loops) {
149                        let mut #idx = #begin + #idx * #unroll_by;
150                    }
151                };
152                let loop_block: Vec<Stmt> = parse_quote! {
153                    if #idx < #end {
154                        #new_body
155                    }
156                    #idx += 1;
157                };
158                let loop_body = (0..unroll_by).flat_map(|_| loop_block.clone());
159                loop_expr.body.stmts.extend(loop_body);
160                block.stmts.push(Stmt::Expr(Expr::ForLoop(loop_expr), None));
161
162                // idx = num_loops * unroll_by;
163                block
164                    .stmts
165                    .push(parse_quote! { let mut #idx = #begin + num_loops * #unroll_by; });
166                // if idx < remainder + num_loops * unroll_by { ... }
167                let post_loop_block: Vec<Stmt> = parse_quote! {
168                    if #idx < #end {
169                        #new_body
170                    }
171                    #idx += 1;
172                };
173                let post_loop = (0..unroll_by).flat_map(|_| post_loop_block.clone());
174                block.stmts.extend(post_loop);
175
176                let mut attrs = attrs.clone();
177                attrs.extend(vec![parse_quote!(#[allow(unused)])]);
178                Expr::Block(ExprBlock {
179                    attrs,
180                    label: label.clone(),
181                    block,
182                })
183            } else {
184                forloop_with_body(new_body)
185            }
186        } else {
187            forloop_with_body(new_body)
188        }
189    } else if let Expr::If(if_expr) = expr {
190        let ExprIf {
191            ref cond,
192            ref then_branch,
193            ref else_branch,
194            ..
195        } = *if_expr;
196        Expr::If(ExprIf {
197            cond: Box::new(unroll(cond, unroll_by)),
198            then_branch: unroll_in_block(then_branch, unroll_by),
199            else_branch: else_branch
200                .as_ref()
201                .map(|x| (x.0, Box::new(unroll(&x.1, unroll_by)))),
202            ..(*if_expr).clone()
203        })
204    } else if let Expr::Let(let_expr) = expr {
205        let ExprLet { ref expr, .. } = *let_expr;
206        Expr::Let(ExprLet {
207            expr: Box::new(unroll(expr, unroll_by)),
208            ..(*let_expr).clone()
209        })
210    } else if let Expr::Block(expr_block) = expr {
211        let ExprBlock { ref block, .. } = *expr_block;
212        Expr::Block(ExprBlock {
213            block: unroll_in_block(block, unroll_by),
214            ..(*expr_block).clone()
215        })
216    } else {
217        (*expr).clone()
218    }
219}
220
/// Expands a small range loop and checks the shape of the result. The `let`
/// statement must be passed through untouched, and the loop must be replaced
/// by a block expression. The expansion is printed to help debugging
/// (`cargo test -- --nocapture`).
#[test]
fn test_expand() {
    use quote::ToTokens;
    let for_loop: Block = parse_quote! {{
        let mut sum = 0;
        for i in 0..8 {
            sum += i;
        }
    }};
    let expanded = unroll_in_block(&for_loop, 12);
    println!("{}", expanded.to_token_stream());

    assert_eq!(expanded.stmts.len(), 2);
    assert_eq!(
        expanded.stmts[0].to_token_stream().to_string(),
        for_loop.stmts[0].to_token_stream().to_string()
    );
    assert!(matches!(expanded.stmts[1], Stmt::Expr(Expr::Block(_), None)));
}

/// A loop whose iterator is not a range expression must be left as a loop.
#[test]
fn test_non_range_loop_is_kept() {
    let iter_loop: Block = parse_quote! {{
        for x in v.iter() {
            sum += x;
        }
    }};
    let kept = unroll_in_block(&iter_loop, 4);
    assert_eq!(kept.stmts.len(), 1);
    assert!(matches!(kept.stmts[0], Stmt::Expr(Expr::ForLoop(_), None)));
}