Skip to main content

p3_dft/
butterflies.rs

1use core::mem::MaybeUninit;
2
3use itertools::izip;
4use p3_field::{Field, PackedField, PackedValue};
5
6/// A butterfly operation used in NTT to combine two values into a new pair.
7///
8/// This trait defines how to transform two elements (or vectors of elements)
9/// according to the structure of a butterfly gate.
10///
11/// In an NTT, butterflies are the core units that recursively combine values
12/// across layers. Each butterfly computes:
13/// ```text
14///   (a + b * twiddle, a - b * twiddle)   // DIT
15/// or
16///   (a + b, (a - b) * twiddle)           // DIF
17/// ```
18/// The transformation can be applied:
19/// - in-place (mutating input values)
20/// - to full rows of values (arrays of field elements)
21/// - out-of-place (writing results to separate destination buffers)
22///
23/// Different butterfly variants (DIT, DIF, or twiddle-free) define the exact formula.
24pub trait Butterfly<F: Field>: Copy + Send + Sync {
25    /// Applies the butterfly transformation to two packed field values.
26    ///
27    /// This method takes two inputs `x_1` and `x_2` and returns two outputs `(y_1, y_2)`
28    /// depending on the butterfly type.
29    /// ```text
30    /// Example (DIF):
31    ///   Input:  x_1 = a, x_2 = b
32    ///   Output: (a + b, (a - b) * twiddle)
33    /// ```
34    fn apply<PF: PackedField<Scalar = F>>(&self, x_1: PF, x_2: PF) -> (PF, PF);
35
36    /// Applies the butterfly in-place to two packed values.
37    ///
38    /// Mutates both `x_1` and `x_2` directly, storing the result of `apply`.
39    #[inline]
40    fn apply_in_place<PF: PackedField<Scalar = F>>(&self, x_1: &mut PF, x_2: &mut PF) {
41        (*x_1, *x_2) = self.apply(*x_1, *x_2);
42    }
43
44    /// Applies the butterfly transformation to two rows of scalar field values.
45    ///
46    /// Each row is a slice of `F`. This function processes the rows in packed
47    /// chunks using SIMD where possible, and falls back to scalar operations
48    /// for the suffix (remaining elements).
49    ///
50    /// The transformation is done in-place.
51    #[inline]
52    fn apply_to_rows(&self, row_1: &mut [F], row_2: &mut [F]) {
53        let (shorts_1, suffix_1) = F::Packing::pack_slice_with_suffix_mut(row_1);
54        let (shorts_2, suffix_2) = F::Packing::pack_slice_with_suffix_mut(row_2);
55        debug_assert_eq!(shorts_1.len(), shorts_2.len());
56        debug_assert_eq!(suffix_1.len(), suffix_2.len());
57        for (x_1, x_2) in shorts_1.iter_mut().zip(shorts_2) {
58            self.apply_in_place(x_1, x_2);
59        }
60        for (x_1, x_2) in suffix_1.iter_mut().zip(suffix_2) {
61            self.apply_in_place(x_1, x_2);
62        }
63    }
64
65    /// Applies the butterfly out-of-place to two source rows.
66    ///
67    /// This version does not overwrite the source. Instead, it writes the
68    /// result of each butterfly to separate destination slices (which may
69    /// be uninitialized memory).
70    ///
71    /// This is useful when performing LDE's where the size of the output is larger than the size of the input.
72    ///
73    /// - `src_1`, `src_2`: input slices
74    /// - `dst_1`, `dst_2`: output slices to write to (must be MaybeUninit)
75    #[inline]
76    fn apply_to_rows_oop(
77        &self,
78        src_1: &[F],
79        dst_1: &mut [MaybeUninit<F>],
80        src_2: &[F],
81        dst_2: &mut [MaybeUninit<F>],
82    ) {
83        let (src_shorts_1, src_suffix_1) = F::Packing::pack_slice_with_suffix(src_1);
84        let (src_shorts_2, src_suffix_2) = F::Packing::pack_slice_with_suffix(src_2);
85        let (dst_shorts_1, dst_suffix_1) =
86            F::Packing::pack_maybe_uninit_slice_with_suffix_mut(dst_1);
87        let (dst_shorts_2, dst_suffix_2) =
88            F::Packing::pack_maybe_uninit_slice_with_suffix_mut(dst_2);
89        debug_assert_eq!(src_shorts_1.len(), src_shorts_2.len());
90        debug_assert_eq!(src_suffix_1.len(), src_suffix_2.len());
91        debug_assert_eq!(dst_shorts_1.len(), dst_shorts_2.len());
92        debug_assert_eq!(dst_suffix_1.len(), dst_suffix_2.len());
93        for (s_1, s_2, d_1, d_2) in izip!(src_shorts_1, src_shorts_2, dst_shorts_1, dst_shorts_2) {
94            let (res_1, res_2) = self.apply(*s_1, *s_2);
95            d_1.write(res_1);
96            d_2.write(res_2);
97        }
98        for (s_1, s_2, d_1, d_2) in izip!(src_suffix_1, src_suffix_2, dst_suffix_1, dst_suffix_2) {
99            let (res_1, res_2) = self.apply(*s_1, *s_2);
100            d_1.write(res_1);
101            d_2.write(res_2);
102        }
103    }
104}
105
/// DIF (Decimation-In-Frequency) butterfly operation.
///
/// Used in the *output-ordering* variant of NTT.
/// The wrapped value is the twiddle factor. This butterfly computes:
/// ```text
///   output_1 = x1 + x2
///   output_2 = (x1 - x2) * twiddle
/// ```
/// The twiddle factor is applied after the subtraction.
///
/// `Debug` is derived so the butterfly can be printed in diagnostics and
/// embedded in other `Debug` types; it is available whenever `F: Debug`.
#[derive(Copy, Clone, Debug)]
#[repr(transparent)] // Allows safe transmutes from F to this.
pub struct DifButterfly<F>(pub F);
119
120impl<F: Field> Butterfly<F> for DifButterfly<F> {
121    #[inline]
122    fn apply<PF: PackedField<Scalar = F>>(&self, x_1: PF, x_2: PF) -> (PF, PF) {
123        (x_1 + x_2, (x_1 - x_2) * self.0)
124    }
125
126    /// Override `apply_to_rows` to pre-broadcast the twiddle factor into a packed field
127    /// once before the inner loop, and manually unroll it to expose multiple independent
128    /// sub-then-mul chains to the compiler's scheduler, hiding the multiplication latency.
129    /// Mirrors the [`DitButterfly`] override.
130    #[inline]
131    fn apply_to_rows(&self, row_1: &mut [F], row_2: &mut [F]) {
132        let (shorts_1, suffix_1) = F::Packing::pack_slice_with_suffix_mut(row_1);
133        let (shorts_2, suffix_2) = F::Packing::pack_slice_with_suffix_mut(row_2);
134        debug_assert_eq!(shorts_1.len(), shorts_2.len());
135        debug_assert_eq!(suffix_1.len(), suffix_2.len());
136        let twiddle_packed = F::Packing::from(self.0);
137        let mut c1 = shorts_1.chunks_exact_mut(4);
138        let mut c2 = shorts_2.chunks_exact_mut(4);
139        for (p1, p2) in (&mut c1).zip(&mut c2) {
140            let a1 = p1[0];
141            let b1 = p1[1];
142            let c1_ = p1[2];
143            let d1 = p1[3];
144            let a2 = p2[0];
145            let b2 = p2[1];
146            let c2_ = p2[2];
147            let d2 = p2[3];
148            p1[0] = a1 + a2;
149            p1[1] = b1 + b2;
150            p1[2] = c1_ + c2_;
151            p1[3] = d1 + d2;
152            p2[0] = (a1 - a2) * twiddle_packed;
153            p2[1] = (b1 - b2) * twiddle_packed;
154            p2[2] = (c1_ - c2_) * twiddle_packed;
155            p2[3] = (d1 - d2) * twiddle_packed;
156        }
157        for (x_1, x_2) in c1
158            .into_remainder()
159            .iter_mut()
160            .zip(c2.into_remainder().iter_mut())
161        {
162            let sum = *x_1 + *x_2;
163            *x_2 = (*x_1 - *x_2) * twiddle_packed;
164            *x_1 = sum;
165        }
166        for (x_1, x_2) in suffix_1.iter_mut().zip(suffix_2.iter_mut()) {
167            self.apply_in_place(x_1, x_2);
168        }
169    }
170}
171
/// DIF (Decimation-In-Frequency) butterfly operation where `x_2` is guaranteed to be zero.
///
/// Useful in scenarios where the input has just been padded with zeros.
///
/// Used in the *output-ordering* variant of NTT.
/// The wrapped value is the twiddle factor. This butterfly computes:
/// ```text
///   output_1 = x1
///   output_2 = x1 * twiddle
/// ```
///
/// `Debug` is derived so the butterfly can be printed in diagnostics and
/// embedded in other `Debug` types; it is available whenever `F: Debug`.
#[derive(Copy, Clone, Debug)]
#[repr(transparent)] // Allows safe transmutes from F to this.
pub struct DifButterflyZeros<F>(pub F);
185
186impl<F: Field> Butterfly<F> for DifButterflyZeros<F> {
187    #[inline]
188    fn apply<PF: PackedField<Scalar = F>>(&self, x_1: PF, x_2: PF) -> (PF, PF) {
189        debug_assert!(x_2.as_slice().iter().all(|x| x.is_zero())); // Slightly convoluted but PF may not implement equality.
190        (x_1, x_1 * self.0)
191    }
192
193    #[inline]
194    fn apply_to_rows(&self, row_1: &mut [F], row_2: &mut [F]) {
195        let (shorts_1, suffix_1) = F::Packing::pack_slice_with_suffix(row_1);
196        let (shorts_2, suffix_2) = F::Packing::pack_slice_with_suffix_mut(row_2);
197        debug_assert_eq!(shorts_1.len(), shorts_2.len());
198        debug_assert_eq!(suffix_1.len(), suffix_2.len());
199        for (x_1, x_2) in shorts_1.iter().zip(shorts_2) {
200            debug_assert!(x_2.as_slice().iter().all(|x| x.is_zero())); // Slightly convoluted but PF may not implement equality.
201            *x_2 = *x_1 * self.0; // x_2 is guaranteed to be zero, so we just set it to x_1 * twiddle. 
202        }
203        for (x_1, x_2) in suffix_1.iter().zip(suffix_2) {
204            debug_assert!(x_2.is_zero());
205            *x_2 = *x_1 * self.0; // x_2 is guaranteed to be zero, so we just set it to x_1 * twiddle. 
206        }
207    }
208}
209
/// DIT (Decimation-In-Time) butterfly operation.
///
/// Used in the *input-ordering* variant of NTT/FFT.
/// The wrapped value is the twiddle factor. This butterfly computes:
/// ```text
///   output_1 = x1 + x2 * twiddle
///   output_2 = x1 - x2 * twiddle
/// ```
/// The twiddle factor is applied to `x2` before combining.
///
/// `Debug` is derived so the butterfly can be printed in diagnostics and
/// embedded in other `Debug` types; it is available whenever `F: Debug`.
#[derive(Copy, Clone, Debug)]
#[repr(transparent)] // Allows safe transmutes from F to this.
pub struct DitButterfly<F>(pub F);
223
224impl<F: Field> Butterfly<F> for DitButterfly<F> {
225    #[inline]
226    fn apply<PF: PackedField<Scalar = F>>(&self, x_1: PF, x_2: PF) -> (PF, PF) {
227        let x_2_twiddle = x_2 * self.0;
228        (x_1 + x_2_twiddle, x_1 - x_2_twiddle)
229    }
230
231    /// Override `apply_to_rows` to pre-broadcast the twiddle factor into a packed field
232    /// once before the inner loop, avoiding a scalar-to-vector broadcast on each packed
233    /// multiplication. For wide rows (e.g., 256 columns with AVX512 width=16, giving 16
234    /// packed iterations per row-pair), this eliminates 15 redundant broadcasts per call.
235    /// Manually unroll the inner packed loop to expose multiple independent mul chains
236    /// to the compiler's scheduler, hiding the ~12–15 cyc Montgomery mul latency.
237    #[inline]
238    fn apply_to_rows(&self, row_1: &mut [F], row_2: &mut [F]) {
239        let (shorts_1, suffix_1) = F::Packing::pack_slice_with_suffix_mut(row_1);
240        let (shorts_2, suffix_2) = F::Packing::pack_slice_with_suffix_mut(row_2);
241        debug_assert_eq!(shorts_1.len(), shorts_2.len());
242        debug_assert_eq!(suffix_1.len(), suffix_2.len());
243        let twiddle_packed = F::Packing::from(self.0);
244        let mut c1 = shorts_1.chunks_exact_mut(4);
245        let mut c2 = shorts_2.chunks_exact_mut(4);
246        for (p1, p2) in (&mut c1).zip(&mut c2) {
247            let a1 = p1[0];
248            let b1 = p1[1];
249            let c1_ = p1[2];
250            let d1 = p1[3];
251            let a2 = p2[0];
252            let b2 = p2[1];
253            let c2_ = p2[2];
254            let d2 = p2[3];
255            let a2t = a2 * twiddle_packed;
256            let b2t = b2 * twiddle_packed;
257            let c2t = c2_ * twiddle_packed;
258            let d2t = d2 * twiddle_packed;
259            p1[0] = a1 + a2t;
260            p2[0] = a1 - a2t;
261            p1[1] = b1 + b2t;
262            p2[1] = b1 - b2t;
263            p1[2] = c1_ + c2t;
264            p2[2] = c1_ - c2t;
265            p1[3] = d1 + d2t;
266            p2[3] = d1 - d2t;
267        }
268        for (x_1, x_2) in c1
269            .into_remainder()
270            .iter_mut()
271            .zip(c2.into_remainder().iter_mut())
272        {
273            let x_2_twiddle = *x_2 * twiddle_packed;
274            let new_x1 = *x_1 + x_2_twiddle;
275            *x_2 = *x_1 - x_2_twiddle;
276            *x_1 = new_x1;
277        }
278        for (x_1, x_2) in suffix_1.iter_mut().zip(suffix_2.iter_mut()) {
279            self.apply_in_place(x_1, x_2);
280        }
281    }
282
283    /// Out-of-place variant with matching unroll factor.
284    #[inline]
285    fn apply_to_rows_oop(
286        &self,
287        src_1: &[F],
288        dst_1: &mut [MaybeUninit<F>],
289        src_2: &[F],
290        dst_2: &mut [MaybeUninit<F>],
291    ) {
292        let (src_shorts_1, src_suffix_1) = F::Packing::pack_slice_with_suffix(src_1);
293        let (src_shorts_2, src_suffix_2) = F::Packing::pack_slice_with_suffix(src_2);
294        let (dst_shorts_1, dst_suffix_1) =
295            F::Packing::pack_maybe_uninit_slice_with_suffix_mut(dst_1);
296        let (dst_shorts_2, dst_suffix_2) =
297            F::Packing::pack_maybe_uninit_slice_with_suffix_mut(dst_2);
298        debug_assert_eq!(src_shorts_1.len(), src_shorts_2.len());
299        debug_assert_eq!(src_suffix_1.len(), src_suffix_2.len());
300        debug_assert_eq!(dst_shorts_1.len(), dst_shorts_2.len());
301        debug_assert_eq!(dst_suffix_1.len(), dst_suffix_2.len());
302        let twiddle_packed = F::Packing::from(self.0);
303        let n = src_shorts_1.len();
304        let n4 = n - (n & 3);
305        let mut i = 0;
306        while i < n4 {
307            let a1 = src_shorts_1[i];
308            let b1 = src_shorts_1[i + 1];
309            let c1 = src_shorts_1[i + 2];
310            let d1 = src_shorts_1[i + 3];
311            let a2 = src_shorts_2[i];
312            let b2 = src_shorts_2[i + 1];
313            let c2 = src_shorts_2[i + 2];
314            let d2 = src_shorts_2[i + 3];
315            let a2t = a2 * twiddle_packed;
316            let b2t = b2 * twiddle_packed;
317            let c2t = c2 * twiddle_packed;
318            let d2t = d2 * twiddle_packed;
319            dst_shorts_1[i].write(a1 + a2t);
320            dst_shorts_2[i].write(a1 - a2t);
321            dst_shorts_1[i + 1].write(b1 + b2t);
322            dst_shorts_2[i + 1].write(b1 - b2t);
323            dst_shorts_1[i + 2].write(c1 + c2t);
324            dst_shorts_2[i + 2].write(c1 - c2t);
325            dst_shorts_1[i + 3].write(d1 + d2t);
326            dst_shorts_2[i + 3].write(d1 - d2t);
327            i += 4;
328        }
329        while i < n {
330            let s1 = src_shorts_1[i];
331            let s2 = src_shorts_2[i];
332            let x_2_twiddle = s2 * twiddle_packed;
333            dst_shorts_1[i].write(s1 + x_2_twiddle);
334            dst_shorts_2[i].write(s1 - x_2_twiddle);
335            i += 1;
336        }
337        for (s_1, s_2, d_1, d_2) in izip!(src_suffix_1, src_suffix_2, dst_suffix_1, dst_suffix_2) {
338            let (res_1, res_2) = self.apply(*s_1, *s_2);
339            d_1.write(res_1);
340            d_2.write(res_2);
341        }
342    }
343}
344
/// DIT (Decimation-In-Time) butterfly operation with a post-multiplication scale factor.
///
/// This butterfly computes:
/// ```text
///   output_1 = (x1 + x2 * twiddle) * scale
///   output_2 = (x1 - x2 * twiddle) * scale
/// ```
/// which is equivalent to:
/// ```text
///   output_1 = x1 * scale + x2 * (twiddle * scale)
///   output_2 = x1 * scale - x2 * (twiddle * scale)
/// ```
///
/// This is used to merge a uniform scaling step (e.g., 1/N normalization in inverse DFT)
/// into a butterfly pass, avoiding a separate memory pass over the data.
///
/// The struct stores `twiddle`, `scale` and the precomputed product
/// `twiddle_times_scale = twiddle * scale`, so that evaluating the second form
/// above needs 2 multiplications instead of 3.
///
/// All fields are public, so a value can be built with a struct literal; in
/// that case the caller is responsible for keeping
/// `twiddle_times_scale == twiddle * scale`. Going through
/// `ScaledDitButterfly::new` establishes this automatically.
///
/// `Debug` is derived so the butterfly can be printed in diagnostics and
/// embedded in other `Debug` types; it is available whenever `F: Debug`.
#[derive(Copy, Clone, Debug)]
pub struct ScaledDitButterfly<F> {
    /// The unscaled twiddle factor.
    pub twiddle: F,
    /// The uniform factor applied to both outputs.
    pub scale: F,
    /// Precomputed product `twiddle * scale` to reduce multiplications in the hot loop.
    pub twiddle_times_scale: F,
}
370
371impl<F: Field> ScaledDitButterfly<F> {
372    /// Construct a `ScaledDitButterfly`, precomputing `twiddle * scale`.
373    #[inline]
374    pub fn new(twiddle: F, scale: F) -> Self {
375        Self {
376            twiddle,
377            scale,
378            twiddle_times_scale: twiddle * scale,
379        }
380    }
381}
382
383impl<F: Field> Butterfly<F> for ScaledDitButterfly<F> {
384    #[inline]
385    fn apply<PF: PackedField<Scalar = F>>(&self, x_1: PF, x_2: PF) -> (PF, PF) {
386        // 2 multiplications instead of 3:
387        //   x1_s   = x1 * scale
388        //   x2_ts  = x2 * (twiddle * scale)   [precomputed]
389        //   out1   = x1_s + x2_ts
390        //   out2   = x1_s - x2_ts
391        let x_1_scale = x_1 * self.scale;
392        let x_2_twiddle_scale = x_2 * self.twiddle_times_scale;
393        (x_1_scale + x_2_twiddle_scale, x_1_scale - x_2_twiddle_scale)
394    }
395
396    /// Override `apply_to_rows` to pre-broadcast both `scale` and `twiddle_times_scale`
397    /// into packed fields once before the inner loop.
398    #[inline]
399    fn apply_to_rows(&self, row_1: &mut [F], row_2: &mut [F]) {
400        let (shorts_1, suffix_1) = F::Packing::pack_slice_with_suffix_mut(row_1);
401        let (shorts_2, suffix_2) = F::Packing::pack_slice_with_suffix_mut(row_2);
402        debug_assert_eq!(shorts_1.len(), shorts_2.len());
403        debug_assert_eq!(suffix_1.len(), suffix_2.len());
404        let scale_packed = F::Packing::from(self.scale);
405        let twiddle_times_scale_packed = F::Packing::from(self.twiddle_times_scale);
406        // ScaledDitButterfly has 2 muls per butterfly (scale + twiddle_scale), so unroll-4
407        // exposes 8 independent mul chains — better ILP than unroll-2's 4 chains.
408        let mut c1 = shorts_1.chunks_exact_mut(4);
409        let mut c2 = shorts_2.chunks_exact_mut(4);
410        for (p1, p2) in (&mut c1).zip(&mut c2) {
411            let a1 = p1[0];
412            let b1 = p1[1];
413            let c1_ = p1[2];
414            let d1 = p1[3];
415            let a2 = p2[0];
416            let b2 = p2[1];
417            let c2_ = p2[2];
418            let d2 = p2[3];
419            let a1s = a1 * scale_packed;
420            let b1s = b1 * scale_packed;
421            let c1s = c1_ * scale_packed;
422            let d1s = d1 * scale_packed;
423            let a2t = a2 * twiddle_times_scale_packed;
424            let b2t = b2 * twiddle_times_scale_packed;
425            let c2t = c2_ * twiddle_times_scale_packed;
426            let d2t = d2 * twiddle_times_scale_packed;
427            p1[0] = a1s + a2t;
428            p2[0] = a1s - a2t;
429            p1[1] = b1s + b2t;
430            p2[1] = b1s - b2t;
431            p1[2] = c1s + c2t;
432            p2[2] = c1s - c2t;
433            p1[3] = d1s + d2t;
434            p2[3] = d1s - d2t;
435        }
436        for (x_1, x_2) in c1
437            .into_remainder()
438            .iter_mut()
439            .zip(c2.into_remainder().iter_mut())
440        {
441            let x_1_scale = *x_1 * scale_packed;
442            let x_2_twiddle_scale = *x_2 * twiddle_times_scale_packed;
443            *x_1 = x_1_scale + x_2_twiddle_scale;
444            *x_2 = x_1_scale - x_2_twiddle_scale;
445        }
446        for (x_1, x_2) in suffix_1.iter_mut().zip(suffix_2.iter_mut()) {
447            self.apply_in_place(x_1, x_2);
448        }
449    }
450}
451
/// Butterfly with no twiddle factor (`twiddle = 1`).
///
/// This is used when no root-of-unity scaling is needed.
/// It works for either DIT or DIF, and is often used at
/// the final or base level of a transform tree.
///
/// This butterfly computes:
/// ```text
///   output_1 = x1 + x2
///   output_2 = x1 - x2
/// ```
///
/// `Debug` is derived so the butterfly can be printed in diagnostics and
/// embedded in other `Debug` types.
#[derive(Copy, Clone, Debug)]
pub struct TwiddleFreeButterfly;
465
466impl<F: Field> Butterfly<F> for TwiddleFreeButterfly {
467    #[inline]
468    fn apply<PF: PackedField<Scalar = F>>(&self, x_1: PF, x_2: PF) -> (PF, PF) {
469        (x_1 + x_2, x_1 - x_2)
470    }
471}