unroll/
lib.rs

1#![recursion_limit = "128"]
2
3//! An attribute-like procedural macro for unrolling for loops with integer literal bounds.
4//! 
5//! This crate provides the [`unroll_for_loops`] attribute-like macro that can be applied to
6//! functions containing for-loops with integer bounds. This macro looks for loops to unroll and
7//! unrolls them at compile time.
8//! 
9//!
10//! ## Usage
11//! 
//! Just add `#[unroll_for_loops]` above the function whose for loops you would like to unroll.
//! All for loops with integer literal bounds will be unrolled, although this macro cannot yet
//! see inside complex code (e.g. for loops within closures).
15//! 
16//! 
17//! ## Example
18//! 
//! The following function computes a matrix-vector product and returns the result as an array.
//! Both for-loops (the outer and the nested one) are unrolled when `#[unroll_for_loops]` is
//! applied.
21//! 
22//! ```rust
23//! use unroll::unroll_for_loops;
24//!
25//! #[unroll_for_loops]
26//! fn mtx_vec_mul(mtx: &[[f64; 5]; 5], vec: &[f64; 5]) -> [f64; 5] {
27//!     let mut out = [0.0; 5];
28//!     for col in 0..5 {
29//!         for row in 0..5 {
30//!             out[row] += mtx[col][row] * vec[col];
31//!         }
32//!     }
33//!     out
34//! }
35//! ```
36
37use syn::{Block, Expr, ExprBlock, ExprForLoop, ExprLit, ExprRange, Item, ItemFn, Lit, Pat,
38          PatIdent, RangeLimits, Stmt, ExprIf, ExprLet, parse_quote};
39use syn::token::Brace;
40use proc_macro::TokenStream;
41use quote::quote;
42
43/// Attribute used to unroll for loops found inside a function block.
44#[proc_macro_attribute]
45pub fn unroll_for_loops(_meta: TokenStream, input: TokenStream) -> TokenStream {
46    let item: Item = syn::parse(input).expect("Failed to parse input.");
47
48    if let Item::Fn(item_fn) = item {
49        let new_block = {
50            let &ItemFn {
51                block: ref box_block,
52                ..
53            } = &item_fn;
54            unroll_in_block(&**box_block)
55        };
56        let new_item = Item::Fn(ItemFn {
57            block: Box::new(new_block),
58            ..item_fn
59        });
60        quote! ( #new_item ).into()
61    } else {
62        quote! ( #item ).into()
63    }
64}
65
66/// Routine to unroll for loops within a block
67fn unroll_in_block(block: &Block) -> Block {
68    let &Block {
69        ref brace_token,
70        ref stmts,
71    } = block;
72    let mut new_stmts = Vec::new();
73    for stmt in stmts.iter() {
74        if let &Stmt::Expr(ref expr) = stmt {
75            new_stmts.push(Stmt::Expr(unroll(expr)));
76        } else if let &Stmt::Semi(ref expr, semi) = stmt {
77            new_stmts.push(Stmt::Semi(unroll(expr), semi));
78        } else {
79            new_stmts.push((*stmt).clone());
80        }
81    }
82    Block {
83        brace_token: brace_token.clone(),
84        stmts: new_stmts,
85    }
86}
87
88/// Routine to unroll a for loop statement, or return the statement unchanged if it's not a for
89/// loop.
90fn unroll(expr: &Expr) -> Expr {
91    // impose a scope that we can break out of so we can return stmt without copying it.
92    if let &Expr::ForLoop(ref for_loop) = expr {
93        let ExprForLoop {
94            ref attrs,
95            ref label,
96            ref pat,
97            expr: ref range_expr,
98            ref body,
99            ..
100        } = *for_loop;
101
102        let new_body = unroll_in_block(&*body);
103
104        let forloop_with_body = |body| {
105            Expr::ForLoop(ExprForLoop {
106                body,
107                ..(*for_loop).clone()
108            })
109        };
110
111        if let Pat::Ident(PatIdent {
112            ref by_ref,
113            ref mutability,
114            ref ident,
115            ref subpat,
116            ..
117        }) = *pat
118        {
119            // Don't know how to deal with these so skip and return the original.
120            if !by_ref.is_none() || !mutability.is_none() || !subpat.is_none() {
121                return forloop_with_body(new_body);
122            }
123            let idx = ident; // got the index variable name
124
125            if let Expr::Range(ExprRange {
126                from: ref mb_box_from,
127                ref limits,
128                to: ref mb_box_to,
129                ..
130            }) = **range_expr
131            {
132                // Parse mb_box_from
133                let begin = if let Some(ref box_from) = *mb_box_from {
134                    if let Expr::Lit(ExprLit {
135                        lit: Lit::Int(ref lit_int),
136                        ..
137                    }) = **box_from
138                    {
139                        lit_int.base10_parse::<usize>().unwrap()
140                    } else {
141                        return forloop_with_body(new_body);
142                    }
143                } else {
144                    0
145                };
146
147                // Parse mb_box_to
148                let end = if let Some(ref box_to) = *mb_box_to {
149                    if let Expr::Lit(ExprLit {
150                        lit: Lit::Int(ref lit_int),
151                        ..
152                    }) = **box_to
153                    {
154                        lit_int.base10_parse::<usize>().unwrap()
155                    } else {
156                        return forloop_with_body(new_body);
157                    }
158                } else {
159                    // we need to know where the limit is to know how much to unroll by.
160                    return forloop_with_body(new_body);
161                } + if let &RangeLimits::Closed(_) = limits {
162                    1
163                } else {
164                    0
165                };
166
167                let mut stmts = Vec::new();
168                for i in begin..end {
169                    let declare_i: Stmt = parse_quote! {
170                        #[allow(non_upper_case_globals)]
171                        const #idx: usize = #i;
172                    };
173                    let mut augmented_body = new_body.clone();
174                    augmented_body.stmts.insert(0, declare_i);
175                    stmts.push(parse_quote! { #augmented_body });
176                }
177                let block = Block {
178                    brace_token: Brace::default(),
179                    stmts,
180                };
181                return Expr::Block(ExprBlock {
182                    attrs: attrs.clone(),
183                    label: label.clone(),
184                    block,
185                });
186            } else {
187                forloop_with_body(new_body)
188            }
189        } else {
190            forloop_with_body(new_body)
191        }
192    } else if let &Expr::If(ref if_expr) = expr {
193        let ExprIf {
194            ref cond,
195            ref then_branch,
196            ref else_branch,
197            ..
198        } = *if_expr;
199        Expr::If(ExprIf {
200            cond: Box::new(unroll(&**cond)),
201            then_branch: unroll_in_block(&*then_branch),
202            else_branch: else_branch.as_ref().map(|x| (x.0, Box::new(unroll(&*x.1)))),
203            ..(*if_expr).clone()
204        })
205    } else if let &Expr::Let(ref let_expr) = expr {
206        let ExprLet {
207            ref expr,
208            ..
209        } = *let_expr;
210        Expr::Let(ExprLet {
211            expr: Box::new(unroll(&**expr)),
212            ..(*let_expr).clone()
213        })
214    } else if let &Expr::Block(ref expr_block) = expr {
215        let ExprBlock { ref block, .. } = *expr_block;
216        Expr::Block(ExprBlock {
217            block: unroll_in_block(&*block),
218            ..(*expr_block).clone()
219        })
220    } else {
221        (*expr).clone()
222    }
223}