1use core::mem::MaybeUninit;
2
3use itertools::izip;
4use p3_field::{Field, PackedField, PackedValue};
5
6pub trait Butterfly<F: Field>: Copy + Send + Sync {
25 fn apply<PF: PackedField<Scalar = F>>(&self, x_1: PF, x_2: PF) -> (PF, PF);
35
36 #[inline]
40 fn apply_in_place<PF: PackedField<Scalar = F>>(&self, x_1: &mut PF, x_2: &mut PF) {
41 (*x_1, *x_2) = self.apply(*x_1, *x_2);
42 }
43
44 #[inline]
52 fn apply_to_rows(&self, row_1: &mut [F], row_2: &mut [F]) {
53 let (shorts_1, suffix_1) = F::Packing::pack_slice_with_suffix_mut(row_1);
54 let (shorts_2, suffix_2) = F::Packing::pack_slice_with_suffix_mut(row_2);
55 debug_assert_eq!(shorts_1.len(), shorts_2.len());
56 debug_assert_eq!(suffix_1.len(), suffix_2.len());
57 for (x_1, x_2) in shorts_1.iter_mut().zip(shorts_2) {
58 self.apply_in_place(x_1, x_2);
59 }
60 for (x_1, x_2) in suffix_1.iter_mut().zip(suffix_2) {
61 self.apply_in_place(x_1, x_2);
62 }
63 }
64
65 #[inline]
76 fn apply_to_rows_oop(
77 &self,
78 src_1: &[F],
79 dst_1: &mut [MaybeUninit<F>],
80 src_2: &[F],
81 dst_2: &mut [MaybeUninit<F>],
82 ) {
83 let (src_shorts_1, src_suffix_1) = F::Packing::pack_slice_with_suffix(src_1);
84 let (src_shorts_2, src_suffix_2) = F::Packing::pack_slice_with_suffix(src_2);
85 let (dst_shorts_1, dst_suffix_1) =
86 F::Packing::pack_maybe_uninit_slice_with_suffix_mut(dst_1);
87 let (dst_shorts_2, dst_suffix_2) =
88 F::Packing::pack_maybe_uninit_slice_with_suffix_mut(dst_2);
89 debug_assert_eq!(src_shorts_1.len(), src_shorts_2.len());
90 debug_assert_eq!(src_suffix_1.len(), src_suffix_2.len());
91 debug_assert_eq!(dst_shorts_1.len(), dst_shorts_2.len());
92 debug_assert_eq!(dst_suffix_1.len(), dst_suffix_2.len());
93 for (s_1, s_2, d_1, d_2) in izip!(src_shorts_1, src_shorts_2, dst_shorts_1, dst_shorts_2) {
94 let (res_1, res_2) = self.apply(*s_1, *s_2);
95 d_1.write(res_1);
96 d_2.write(res_2);
97 }
98 for (s_1, s_2, d_1, d_2) in izip!(src_suffix_1, src_suffix_2, dst_suffix_1, dst_suffix_2) {
99 let (res_1, res_2) = self.apply(*s_1, *s_2);
100 d_1.write(res_1);
101 d_2.write(res_2);
102 }
103 }
104}
105
106#[derive(Copy, Clone)]
117#[repr(transparent)] pub struct DifButterfly<F>(pub F);
119
120impl<F: Field> Butterfly<F> for DifButterfly<F> {
121 #[inline]
122 fn apply<PF: PackedField<Scalar = F>>(&self, x_1: PF, x_2: PF) -> (PF, PF) {
123 (x_1 + x_2, (x_1 - x_2) * self.0)
124 }
125
126 #[inline]
131 fn apply_to_rows(&self, row_1: &mut [F], row_2: &mut [F]) {
132 let (shorts_1, suffix_1) = F::Packing::pack_slice_with_suffix_mut(row_1);
133 let (shorts_2, suffix_2) = F::Packing::pack_slice_with_suffix_mut(row_2);
134 debug_assert_eq!(shorts_1.len(), shorts_2.len());
135 debug_assert_eq!(suffix_1.len(), suffix_2.len());
136 let twiddle_packed = F::Packing::from(self.0);
137 let mut c1 = shorts_1.chunks_exact_mut(4);
138 let mut c2 = shorts_2.chunks_exact_mut(4);
139 for (p1, p2) in (&mut c1).zip(&mut c2) {
140 let a1 = p1[0];
141 let b1 = p1[1];
142 let c1_ = p1[2];
143 let d1 = p1[3];
144 let a2 = p2[0];
145 let b2 = p2[1];
146 let c2_ = p2[2];
147 let d2 = p2[3];
148 p1[0] = a1 + a2;
149 p1[1] = b1 + b2;
150 p1[2] = c1_ + c2_;
151 p1[3] = d1 + d2;
152 p2[0] = (a1 - a2) * twiddle_packed;
153 p2[1] = (b1 - b2) * twiddle_packed;
154 p2[2] = (c1_ - c2_) * twiddle_packed;
155 p2[3] = (d1 - d2) * twiddle_packed;
156 }
157 for (x_1, x_2) in c1
158 .into_remainder()
159 .iter_mut()
160 .zip(c2.into_remainder().iter_mut())
161 {
162 let sum = *x_1 + *x_2;
163 *x_2 = (*x_1 - *x_2) * twiddle_packed;
164 *x_1 = sum;
165 }
166 for (x_1, x_2) in suffix_1.iter_mut().zip(suffix_2.iter_mut()) {
167 self.apply_in_place(x_1, x_2);
168 }
169 }
170}
171
172#[derive(Copy, Clone)]
183#[repr(transparent)] pub struct DifButterflyZeros<F>(pub F);
185
186impl<F: Field> Butterfly<F> for DifButterflyZeros<F> {
187 #[inline]
188 fn apply<PF: PackedField<Scalar = F>>(&self, x_1: PF, x_2: PF) -> (PF, PF) {
189 debug_assert!(x_2.as_slice().iter().all(|x| x.is_zero())); (x_1, x_1 * self.0)
191 }
192
193 #[inline]
194 fn apply_to_rows(&self, row_1: &mut [F], row_2: &mut [F]) {
195 let (shorts_1, suffix_1) = F::Packing::pack_slice_with_suffix(row_1);
196 let (shorts_2, suffix_2) = F::Packing::pack_slice_with_suffix_mut(row_2);
197 debug_assert_eq!(shorts_1.len(), shorts_2.len());
198 debug_assert_eq!(suffix_1.len(), suffix_2.len());
199 for (x_1, x_2) in shorts_1.iter().zip(shorts_2) {
200 debug_assert!(x_2.as_slice().iter().all(|x| x.is_zero())); *x_2 = *x_1 * self.0; }
203 for (x_1, x_2) in suffix_1.iter().zip(suffix_2) {
204 debug_assert!(x_2.is_zero());
205 *x_2 = *x_1 * self.0; }
207 }
208}
209
210#[derive(Copy, Clone)]
221#[repr(transparent)] pub struct DitButterfly<F>(pub F);
223
224impl<F: Field> Butterfly<F> for DitButterfly<F> {
225 #[inline]
226 fn apply<PF: PackedField<Scalar = F>>(&self, x_1: PF, x_2: PF) -> (PF, PF) {
227 let x_2_twiddle = x_2 * self.0;
228 (x_1 + x_2_twiddle, x_1 - x_2_twiddle)
229 }
230
231 #[inline]
238 fn apply_to_rows(&self, row_1: &mut [F], row_2: &mut [F]) {
239 let (shorts_1, suffix_1) = F::Packing::pack_slice_with_suffix_mut(row_1);
240 let (shorts_2, suffix_2) = F::Packing::pack_slice_with_suffix_mut(row_2);
241 debug_assert_eq!(shorts_1.len(), shorts_2.len());
242 debug_assert_eq!(suffix_1.len(), suffix_2.len());
243 let twiddle_packed = F::Packing::from(self.0);
244 let mut c1 = shorts_1.chunks_exact_mut(4);
245 let mut c2 = shorts_2.chunks_exact_mut(4);
246 for (p1, p2) in (&mut c1).zip(&mut c2) {
247 let a1 = p1[0];
248 let b1 = p1[1];
249 let c1_ = p1[2];
250 let d1 = p1[3];
251 let a2 = p2[0];
252 let b2 = p2[1];
253 let c2_ = p2[2];
254 let d2 = p2[3];
255 let a2t = a2 * twiddle_packed;
256 let b2t = b2 * twiddle_packed;
257 let c2t = c2_ * twiddle_packed;
258 let d2t = d2 * twiddle_packed;
259 p1[0] = a1 + a2t;
260 p2[0] = a1 - a2t;
261 p1[1] = b1 + b2t;
262 p2[1] = b1 - b2t;
263 p1[2] = c1_ + c2t;
264 p2[2] = c1_ - c2t;
265 p1[3] = d1 + d2t;
266 p2[3] = d1 - d2t;
267 }
268 for (x_1, x_2) in c1
269 .into_remainder()
270 .iter_mut()
271 .zip(c2.into_remainder().iter_mut())
272 {
273 let x_2_twiddle = *x_2 * twiddle_packed;
274 let new_x1 = *x_1 + x_2_twiddle;
275 *x_2 = *x_1 - x_2_twiddle;
276 *x_1 = new_x1;
277 }
278 for (x_1, x_2) in suffix_1.iter_mut().zip(suffix_2.iter_mut()) {
279 self.apply_in_place(x_1, x_2);
280 }
281 }
282
283 #[inline]
285 fn apply_to_rows_oop(
286 &self,
287 src_1: &[F],
288 dst_1: &mut [MaybeUninit<F>],
289 src_2: &[F],
290 dst_2: &mut [MaybeUninit<F>],
291 ) {
292 let (src_shorts_1, src_suffix_1) = F::Packing::pack_slice_with_suffix(src_1);
293 let (src_shorts_2, src_suffix_2) = F::Packing::pack_slice_with_suffix(src_2);
294 let (dst_shorts_1, dst_suffix_1) =
295 F::Packing::pack_maybe_uninit_slice_with_suffix_mut(dst_1);
296 let (dst_shorts_2, dst_suffix_2) =
297 F::Packing::pack_maybe_uninit_slice_with_suffix_mut(dst_2);
298 debug_assert_eq!(src_shorts_1.len(), src_shorts_2.len());
299 debug_assert_eq!(src_suffix_1.len(), src_suffix_2.len());
300 debug_assert_eq!(dst_shorts_1.len(), dst_shorts_2.len());
301 debug_assert_eq!(dst_suffix_1.len(), dst_suffix_2.len());
302 let twiddle_packed = F::Packing::from(self.0);
303 let n = src_shorts_1.len();
304 let n4 = n - (n & 3);
305 let mut i = 0;
306 while i < n4 {
307 let a1 = src_shorts_1[i];
308 let b1 = src_shorts_1[i + 1];
309 let c1 = src_shorts_1[i + 2];
310 let d1 = src_shorts_1[i + 3];
311 let a2 = src_shorts_2[i];
312 let b2 = src_shorts_2[i + 1];
313 let c2 = src_shorts_2[i + 2];
314 let d2 = src_shorts_2[i + 3];
315 let a2t = a2 * twiddle_packed;
316 let b2t = b2 * twiddle_packed;
317 let c2t = c2 * twiddle_packed;
318 let d2t = d2 * twiddle_packed;
319 dst_shorts_1[i].write(a1 + a2t);
320 dst_shorts_2[i].write(a1 - a2t);
321 dst_shorts_1[i + 1].write(b1 + b2t);
322 dst_shorts_2[i + 1].write(b1 - b2t);
323 dst_shorts_1[i + 2].write(c1 + c2t);
324 dst_shorts_2[i + 2].write(c1 - c2t);
325 dst_shorts_1[i + 3].write(d1 + d2t);
326 dst_shorts_2[i + 3].write(d1 - d2t);
327 i += 4;
328 }
329 while i < n {
330 let s1 = src_shorts_1[i];
331 let s2 = src_shorts_2[i];
332 let x_2_twiddle = s2 * twiddle_packed;
333 dst_shorts_1[i].write(s1 + x_2_twiddle);
334 dst_shorts_2[i].write(s1 - x_2_twiddle);
335 i += 1;
336 }
337 for (s_1, s_2, d_1, d_2) in izip!(src_suffix_1, src_suffix_2, dst_suffix_1, dst_suffix_2) {
338 let (res_1, res_2) = self.apply(*s_1, *s_2);
339 d_1.write(res_1);
340 d_2.write(res_2);
341 }
342 }
343}
344
345#[derive(Copy, Clone)]
364pub struct ScaledDitButterfly<F> {
365 pub twiddle: F,
366 pub scale: F,
367 pub twiddle_times_scale: F,
369}
370
371impl<F: Field> ScaledDitButterfly<F> {
372 #[inline]
374 pub fn new(twiddle: F, scale: F) -> Self {
375 Self {
376 twiddle,
377 scale,
378 twiddle_times_scale: twiddle * scale,
379 }
380 }
381}
382
383impl<F: Field> Butterfly<F> for ScaledDitButterfly<F> {
384 #[inline]
385 fn apply<PF: PackedField<Scalar = F>>(&self, x_1: PF, x_2: PF) -> (PF, PF) {
386 let x_1_scale = x_1 * self.scale;
392 let x_2_twiddle_scale = x_2 * self.twiddle_times_scale;
393 (x_1_scale + x_2_twiddle_scale, x_1_scale - x_2_twiddle_scale)
394 }
395
396 #[inline]
399 fn apply_to_rows(&self, row_1: &mut [F], row_2: &mut [F]) {
400 let (shorts_1, suffix_1) = F::Packing::pack_slice_with_suffix_mut(row_1);
401 let (shorts_2, suffix_2) = F::Packing::pack_slice_with_suffix_mut(row_2);
402 debug_assert_eq!(shorts_1.len(), shorts_2.len());
403 debug_assert_eq!(suffix_1.len(), suffix_2.len());
404 let scale_packed = F::Packing::from(self.scale);
405 let twiddle_times_scale_packed = F::Packing::from(self.twiddle_times_scale);
406 let mut c1 = shorts_1.chunks_exact_mut(4);
409 let mut c2 = shorts_2.chunks_exact_mut(4);
410 for (p1, p2) in (&mut c1).zip(&mut c2) {
411 let a1 = p1[0];
412 let b1 = p1[1];
413 let c1_ = p1[2];
414 let d1 = p1[3];
415 let a2 = p2[0];
416 let b2 = p2[1];
417 let c2_ = p2[2];
418 let d2 = p2[3];
419 let a1s = a1 * scale_packed;
420 let b1s = b1 * scale_packed;
421 let c1s = c1_ * scale_packed;
422 let d1s = d1 * scale_packed;
423 let a2t = a2 * twiddle_times_scale_packed;
424 let b2t = b2 * twiddle_times_scale_packed;
425 let c2t = c2_ * twiddle_times_scale_packed;
426 let d2t = d2 * twiddle_times_scale_packed;
427 p1[0] = a1s + a2t;
428 p2[0] = a1s - a2t;
429 p1[1] = b1s + b2t;
430 p2[1] = b1s - b2t;
431 p1[2] = c1s + c2t;
432 p2[2] = c1s - c2t;
433 p1[3] = d1s + d2t;
434 p2[3] = d1s - d2t;
435 }
436 for (x_1, x_2) in c1
437 .into_remainder()
438 .iter_mut()
439 .zip(c2.into_remainder().iter_mut())
440 {
441 let x_1_scale = *x_1 * scale_packed;
442 let x_2_twiddle_scale = *x_2 * twiddle_times_scale_packed;
443 *x_1 = x_1_scale + x_2_twiddle_scale;
444 *x_2 = x_1_scale - x_2_twiddle_scale;
445 }
446 for (x_1, x_2) in suffix_1.iter_mut().zip(suffix_2.iter_mut()) {
447 self.apply_in_place(x_1, x_2);
448 }
449 }
450}
451
452#[derive(Copy, Clone)]
464pub struct TwiddleFreeButterfly;
465
466impl<F: Field> Butterfly<F> for TwiddleFreeButterfly {
467 #[inline]
468 fn apply<PF: PackedField<Scalar = F>>(&self, x_1: PF, x_2: PF) -> (PF, PF) {
469 (x_1 + x_2, x_1 - x_2)
470 }
471}