1use std::borrow::Borrow;
55
56use syn::{
57 parse_quote, token::Brace, Block, Expr, ExprBlock, ExprForLoop, ExprIf, ExprLet, ExprRange,
58 Pat, PatIdent, RangeLimits, Stmt,
59};
60
61pub(crate) fn unroll_in_block(block: &Block, unroll_by: usize) -> Block {
63 let &Block {
64 ref brace_token,
65 ref stmts,
66 } = block;
67 let mut new_stmts = Vec::new();
68 for stmt in stmts.iter() {
69 if let Stmt::Expr(expr, semi) = stmt {
70 new_stmts.push(Stmt::Expr(unroll(expr, unroll_by), *semi));
71 } else {
72 new_stmts.push((*stmt).clone());
73 }
74 }
75 Block {
76 brace_token: *brace_token,
77 stmts: new_stmts,
78 }
79}
80
81fn unroll(expr: &Expr, unroll_by: usize) -> Expr {
84 if let Expr::ForLoop(for_loop) = expr {
86 let ExprForLoop {
87 ref attrs,
88 ref label,
89 ref pat,
90 expr: ref range,
91 ref body,
92 ..
93 } = *for_loop;
94
95 let new_body = unroll_in_block(body, unroll_by);
96
97 let forloop_with_body = |body| {
98 Expr::ForLoop(ExprForLoop {
99 body,
100 ..(*for_loop).clone()
101 })
102 };
103
104 if let Pat::Ident(PatIdent {
105 by_ref,
106 mutability,
107 ident,
108 subpat,
109 ..
110 }) = *pat.clone()
111 {
112 if !by_ref.is_none() || !mutability.is_none() || !subpat.is_none() {
114 return forloop_with_body(new_body);
115 }
116 let idx = ident; if let Expr::Range(ExprRange {
119 start, limits, end, ..
120 }) = range.borrow()
121 {
122 let begin = match start {
124 Some(e) => e.clone(),
125 _ => Box::new(parse_quote!(0usize)),
126 };
127 let end = match end {
128 Some(e) => e.clone(),
129 _ => return forloop_with_body(new_body),
130 };
131 let end_is_closed = if let RangeLimits::Closed(_) = limits {
132 1usize
133 } else {
134 0
135 };
136 let end: Expr = parse_quote!(#end + #end_is_closed);
137
138 let preamble: Vec<Stmt> = parse_quote! {
139 let total_iters: usize = (#end).checked_sub(#begin).unwrap_or(0);
140 let num_loops = total_iters / #unroll_by;
141 let remainder = total_iters % #unroll_by;
142 };
143 let mut block = Block {
144 brace_token: Brace::default(),
145 stmts: preamble,
146 };
147 let mut loop_expr: ExprForLoop = parse_quote! {
148 for #idx in (0..num_loops) {
149 let mut #idx = #begin + #idx * #unroll_by;
150 }
151 };
152 let loop_block: Vec<Stmt> = parse_quote! {
153 if #idx < #end {
154 #new_body
155 }
156 #idx += 1;
157 };
158 let loop_body = (0..unroll_by).flat_map(|_| loop_block.clone());
159 loop_expr.body.stmts.extend(loop_body);
160 block.stmts.push(Stmt::Expr(Expr::ForLoop(loop_expr), None));
161
162 block
164 .stmts
165 .push(parse_quote! { let mut #idx = #begin + num_loops * #unroll_by; });
166 let post_loop_block: Vec<Stmt> = parse_quote! {
168 if #idx < #end {
169 #new_body
170 }
171 #idx += 1;
172 };
173 let post_loop = (0..unroll_by).flat_map(|_| post_loop_block.clone());
174 block.stmts.extend(post_loop);
175
176 let mut attrs = attrs.clone();
177 attrs.extend(vec![parse_quote!(#[allow(unused)])]);
178 Expr::Block(ExprBlock {
179 attrs,
180 label: label.clone(),
181 block,
182 })
183 } else {
184 forloop_with_body(new_body)
185 }
186 } else {
187 forloop_with_body(new_body)
188 }
189 } else if let Expr::If(if_expr) = expr {
190 let ExprIf {
191 ref cond,
192 ref then_branch,
193 ref else_branch,
194 ..
195 } = *if_expr;
196 Expr::If(ExprIf {
197 cond: Box::new(unroll(cond, unroll_by)),
198 then_branch: unroll_in_block(then_branch, unroll_by),
199 else_branch: else_branch
200 .as_ref()
201 .map(|x| (x.0, Box::new(unroll(&x.1, unroll_by)))),
202 ..(*if_expr).clone()
203 })
204 } else if let Expr::Let(let_expr) = expr {
205 let ExprLet { ref expr, .. } = *let_expr;
206 Expr::Let(ExprLet {
207 expr: Box::new(unroll(expr, unroll_by)),
208 ..(*let_expr).clone()
209 })
210 } else if let Expr::Block(expr_block) = expr {
211 let ExprBlock { ref block, .. } = *expr_block;
212 Expr::Block(ExprBlock {
213 block: unroll_in_block(block, unroll_by),
214 ..(*expr_block).clone()
215 })
216 } else {
217 (*expr).clone()
218 }
219}
220
// Smoke test: unrolls a small summation loop by a factor of 12 and prints the generated
// token stream. It makes no assertions about the output.
#[test]
fn test_expand() {
    use quote::ToTokens;

    let input: Block = parse_quote! {{
        let mut sum = 0;
        for i in 0..8 {
            sum += i;
        }
    }};

    let expanded = unroll_in_block(&input, 12);
    println!("{}", expanded.to_token_stream());
}