1#![recursion_limit = "128"]
2
3use syn::{Block, Expr, ExprBlock, ExprForLoop, ExprLit, ExprRange, Item, ItemFn, Lit, Pat,
38 PatIdent, RangeLimits, Stmt, ExprIf, ExprLet, parse_quote};
39use syn::token::Brace;
40use proc_macro::TokenStream;
41use quote::quote;
42
43#[proc_macro_attribute]
45pub fn unroll_for_loops(_meta: TokenStream, input: TokenStream) -> TokenStream {
46 let item: Item = syn::parse(input).expect("Failed to parse input.");
47
48 if let Item::Fn(item_fn) = item {
49 let new_block = {
50 let &ItemFn {
51 block: ref box_block,
52 ..
53 } = &item_fn;
54 unroll_in_block(&**box_block)
55 };
56 let new_item = Item::Fn(ItemFn {
57 block: Box::new(new_block),
58 ..item_fn
59 });
60 quote! ( #new_item ).into()
61 } else {
62 quote! ( #item ).into()
63 }
64}
65
66fn unroll_in_block(block: &Block) -> Block {
68 let &Block {
69 ref brace_token,
70 ref stmts,
71 } = block;
72 let mut new_stmts = Vec::new();
73 for stmt in stmts.iter() {
74 if let &Stmt::Expr(ref expr) = stmt {
75 new_stmts.push(Stmt::Expr(unroll(expr)));
76 } else if let &Stmt::Semi(ref expr, semi) = stmt {
77 new_stmts.push(Stmt::Semi(unroll(expr), semi));
78 } else {
79 new_stmts.push((*stmt).clone());
80 }
81 }
82 Block {
83 brace_token: brace_token.clone(),
84 stmts: new_stmts,
85 }
86}
87
88fn unroll(expr: &Expr) -> Expr {
91 if let &Expr::ForLoop(ref for_loop) = expr {
93 let ExprForLoop {
94 ref attrs,
95 ref label,
96 ref pat,
97 expr: ref range_expr,
98 ref body,
99 ..
100 } = *for_loop;
101
102 let new_body = unroll_in_block(&*body);
103
104 let forloop_with_body = |body| {
105 Expr::ForLoop(ExprForLoop {
106 body,
107 ..(*for_loop).clone()
108 })
109 };
110
111 if let Pat::Ident(PatIdent {
112 ref by_ref,
113 ref mutability,
114 ref ident,
115 ref subpat,
116 ..
117 }) = *pat
118 {
119 if !by_ref.is_none() || !mutability.is_none() || !subpat.is_none() {
121 return forloop_with_body(new_body);
122 }
123 let idx = ident; if let Expr::Range(ExprRange {
126 from: ref mb_box_from,
127 ref limits,
128 to: ref mb_box_to,
129 ..
130 }) = **range_expr
131 {
132 let begin = if let Some(ref box_from) = *mb_box_from {
134 if let Expr::Lit(ExprLit {
135 lit: Lit::Int(ref lit_int),
136 ..
137 }) = **box_from
138 {
139 lit_int.base10_parse::<usize>().unwrap()
140 } else {
141 return forloop_with_body(new_body);
142 }
143 } else {
144 0
145 };
146
147 let end = if let Some(ref box_to) = *mb_box_to {
149 if let Expr::Lit(ExprLit {
150 lit: Lit::Int(ref lit_int),
151 ..
152 }) = **box_to
153 {
154 lit_int.base10_parse::<usize>().unwrap()
155 } else {
156 return forloop_with_body(new_body);
157 }
158 } else {
159 return forloop_with_body(new_body);
161 } + if let &RangeLimits::Closed(_) = limits {
162 1
163 } else {
164 0
165 };
166
167 let mut stmts = Vec::new();
168 for i in begin..end {
169 let declare_i: Stmt = parse_quote! {
170 #[allow(non_upper_case_globals)]
171 const #idx: usize = #i;
172 };
173 let mut augmented_body = new_body.clone();
174 augmented_body.stmts.insert(0, declare_i);
175 stmts.push(parse_quote! { #augmented_body });
176 }
177 let block = Block {
178 brace_token: Brace::default(),
179 stmts,
180 };
181 return Expr::Block(ExprBlock {
182 attrs: attrs.clone(),
183 label: label.clone(),
184 block,
185 });
186 } else {
187 forloop_with_body(new_body)
188 }
189 } else {
190 forloop_with_body(new_body)
191 }
192 } else if let &Expr::If(ref if_expr) = expr {
193 let ExprIf {
194 ref cond,
195 ref then_branch,
196 ref else_branch,
197 ..
198 } = *if_expr;
199 Expr::If(ExprIf {
200 cond: Box::new(unroll(&**cond)),
201 then_branch: unroll_in_block(&*then_branch),
202 else_branch: else_branch.as_ref().map(|x| (x.0, Box::new(unroll(&*x.1)))),
203 ..(*if_expr).clone()
204 })
205 } else if let &Expr::Let(ref let_expr) = expr {
206 let ExprLet {
207 ref expr,
208 ..
209 } = *let_expr;
210 Expr::Let(ExprLet {
211 expr: Box::new(unroll(&**expr)),
212 ..(*let_expr).clone()
213 })
214 } else if let &Expr::Block(ref expr_block) = expr {
215 let ExprBlock { ref block, .. } = *expr_block;
216 Expr::Block(ExprBlock {
217 block: unroll_in_block(&*block),
218 ..(*expr_block).clone()
219 })
220 } else {
221 (*expr).clone()
222 }
223}