p3_dft/butterflies.rs
1use core::mem::MaybeUninit;
2
3use itertools::izip;
4use p3_field::{Field, PackedField, PackedValue};
5
6/// A butterfly operation used in NTT to combine two values into a new pair.
7///
8/// This trait defines how to transform two elements (or vectors of elements)
9/// according to the structure of a butterfly gate.
10///
11/// In an NTT, butterflies are the core units that recursively combine values
12/// across layers. Each butterfly computes:
13/// ```text
14/// (a + b * twiddle, a - b * twiddle) // DIT
15/// or
16/// (a + b, (a - b) * twiddle) // DIF
17/// ```
18/// The transformation can be applied:
19/// - in-place (mutating input values)
20/// - to full rows of values (arrays of field elements)
21/// - out-of-place (writing results to separate destination buffers)
22///
23/// Different butterfly variants (DIT, DIF, or twiddle-free) define the exact formula.
24pub trait Butterfly<F: Field>: Copy + Send + Sync {
25 /// Applies the butterfly transformation to two packed field values.
26 ///
27 /// This method takes two inputs `x_1` and `x_2` and returns two outputs `(y_1, y_2)`
28 /// depending on the butterfly type.
29 /// ```text
30 /// Example (DIF):
31 /// Input: x_1 = a, x_2 = b
32 /// Output: (a + b, (a - b) * twiddle)
33 /// ```
34 fn apply<PF: PackedField<Scalar = F>>(&self, x_1: PF, x_2: PF) -> (PF, PF);
35
36 /// Applies the butterfly in-place to two packed values.
37 ///
38 /// Mutates both `x_1` and `x_2` directly, storing the result of `apply`.
39 #[inline]
40 fn apply_in_place<PF: PackedField<Scalar = F>>(&self, x_1: &mut PF, x_2: &mut PF) {
41 (*x_1, *x_2) = self.apply(*x_1, *x_2);
42 }
43
44 /// Applies the butterfly transformation to two rows of scalar field values.
45 ///
46 /// Each row is a slice of `F`. This function processes the rows in packed
47 /// chunks using SIMD where possible, and falls back to scalar operations
48 /// for the suffix (remaining elements).
49 ///
50 /// The transformation is done in-place.
51 #[inline]
52 fn apply_to_rows(&self, row_1: &mut [F], row_2: &mut [F]) {
53 let (shorts_1, suffix_1) = F::Packing::pack_slice_with_suffix_mut(row_1);
54 let (shorts_2, suffix_2) = F::Packing::pack_slice_with_suffix_mut(row_2);
55 debug_assert_eq!(shorts_1.len(), shorts_2.len());
56 debug_assert_eq!(suffix_1.len(), suffix_2.len());
57 for (x_1, x_2) in shorts_1.iter_mut().zip(shorts_2) {
58 self.apply_in_place(x_1, x_2);
59 }
60 for (x_1, x_2) in suffix_1.iter_mut().zip(suffix_2) {
61 self.apply_in_place(x_1, x_2);
62 }
63 }
64
65 /// Applies the butterfly out-of-place to two source rows.
66 ///
67 /// This version does not overwrite the source. Instead, it writes the
68 /// result of each butterfly to separate destination slices (which may
69 /// be uninitialized memory).
70 ///
71 /// This is useful when performing LDE's where the size of the output is larger than the size of the input.
72 ///
73 /// - `src_1`, `src_2`: input slices
74 /// - `dst_1`, `dst_2`: output slices to write to (must be MaybeUninit)
75 #[inline]
76 fn apply_to_rows_oop(
77 &self,
78 src_1: &[F],
79 dst_1: &mut [MaybeUninit<F>],
80 src_2: &[F],
81 dst_2: &mut [MaybeUninit<F>],
82 ) {
83 let (src_shorts_1, src_suffix_1) = F::Packing::pack_slice_with_suffix(src_1);
84 let (src_shorts_2, src_suffix_2) = F::Packing::pack_slice_with_suffix(src_2);
85 let (dst_shorts_1, dst_suffix_1) =
86 F::Packing::pack_maybe_uninit_slice_with_suffix_mut(dst_1);
87 let (dst_shorts_2, dst_suffix_2) =
88 F::Packing::pack_maybe_uninit_slice_with_suffix_mut(dst_2);
89 debug_assert_eq!(src_shorts_1.len(), src_shorts_2.len());
90 debug_assert_eq!(src_suffix_1.len(), src_suffix_2.len());
91 debug_assert_eq!(dst_shorts_1.len(), dst_shorts_2.len());
92 debug_assert_eq!(dst_suffix_1.len(), dst_suffix_2.len());
93 for (s_1, s_2, d_1, d_2) in izip!(src_shorts_1, src_shorts_2, dst_shorts_1, dst_shorts_2) {
94 let (res_1, res_2) = self.apply(*s_1, *s_2);
95 d_1.write(res_1);
96 d_2.write(res_2);
97 }
98 for (s_1, s_2, d_1, d_2) in izip!(src_suffix_1, src_suffix_2, dst_suffix_1, dst_suffix_2) {
99 let (res_1, res_2) = self.apply(*s_1, *s_2);
100 d_1.write(res_1);
101 d_2.write(res_2);
102 }
103 }
104}
105
106/// DIF (Decimation-In-Frequency) butterfly operation.
107///
108/// Used in the *output-ordering* variant of NTT.
109/// This butterfly computes:
110/// ```text
111/// output_1 = x1 + x2
112/// output_2 = (x1 - x2) * twiddle
113/// ```
114/// The twiddle factor is applied after subtraction.
115/// Suitable for DIF-style recursive transforms.
116#[derive(Copy, Clone)]
117#[repr(transparent)] // Allows safe transmutes from F to this.
118pub struct DifButterfly<F>(pub F);
119
120impl<F: Field> Butterfly<F> for DifButterfly<F> {
121 #[inline]
122 fn apply<PF: PackedField<Scalar = F>>(&self, x_1: PF, x_2: PF) -> (PF, PF) {
123 (x_1 + x_2, (x_1 - x_2) * self.0)
124 }
125}
126
127/// DIF (Decimation-In-Frequency) butterfly operation where `x_2` is guaranteed to be zero.
128///
129/// Useful in scenarios where the input has just been padded with zeros.
130///
131/// Used in the *output-ordering* variant of NTT.
132/// This butterfly computes:
133/// ```text
134/// output_1 = x1
135/// output_2 = x1 * twiddle
136/// ```
137#[derive(Copy, Clone)]
138#[repr(transparent)] // Allows safe transmutes from F to this.
139pub struct DifButterflyZeros<F>(pub F);
140
141impl<F: Field> Butterfly<F> for DifButterflyZeros<F> {
142 #[inline]
143 fn apply<PF: PackedField<Scalar = F>>(&self, x_1: PF, x_2: PF) -> (PF, PF) {
144 debug_assert!(x_2.as_slice().iter().all(|x| x.is_zero())); // Slightly convoluted but PF may not implement equality.
145 (x_1, x_1 * self.0)
146 }
147
148 #[inline]
149 fn apply_to_rows(&self, row_1: &mut [F], row_2: &mut [F]) {
150 let (shorts_1, suffix_1) = F::Packing::pack_slice_with_suffix(row_1);
151 let (shorts_2, suffix_2) = F::Packing::pack_slice_with_suffix_mut(row_2);
152 debug_assert_eq!(shorts_1.len(), shorts_2.len());
153 debug_assert_eq!(suffix_1.len(), suffix_2.len());
154 for (x_1, x_2) in shorts_1.iter().zip(shorts_2) {
155 debug_assert!(x_2.as_slice().iter().all(|x| x.is_zero())); // Slightly convoluted but PF may not implement equality.
156 *x_2 = *x_1 * self.0; // x_2 is guaranteed to be zero, so we just set it to x_1 * twiddle.
157 }
158 for (x_1, x_2) in suffix_1.iter().zip(suffix_2) {
159 debug_assert!(x_2.is_zero());
160 *x_2 = *x_1 * self.0; // x_2 is guaranteed to be zero, so we just set it to x_1 * twiddle.
161 }
162 }
163}
164
165/// DIT (Decimation-In-Time) butterfly operation.
166///
167/// Used in the *input-ordering* variant of NTT/FFT.
168/// This butterfly computes:
169/// ```text
170/// output_1 = x1 + x2 * twiddle
171/// output_2 = x1 - x2 * twiddle
172/// ```
173/// The twiddle factor is applied to x2 before combining.
174/// Suitable for DIT-style recursive transforms.
175#[derive(Copy, Clone)]
176#[repr(transparent)] // Allows safe transmutes from F to this.
177pub struct DitButterfly<F>(pub F);
178
179impl<F: Field> Butterfly<F> for DitButterfly<F> {
180 #[inline]
181 fn apply<PF: PackedField<Scalar = F>>(&self, x_1: PF, x_2: PF) -> (PF, PF) {
182 let x_2_twiddle = x_2 * self.0;
183 (x_1 + x_2_twiddle, x_1 - x_2_twiddle)
184 }
185}
186
187/// Butterfly with no twiddle factor (`twiddle = 1`).
188///
189/// This is used when no root-of-unity scaling is needed.
190/// It works for either DIT or DIF, and is often used at
191/// the final or base level of a transform tree.
192///
193/// This butterfly computes:
194/// ```text
195/// - output_1 = x1 + x2
196/// - output_2 = x1 - x2
197/// ```
198#[derive(Copy, Clone)]
199pub struct TwiddleFreeButterfly;
200
201impl<F: Field> Butterfly<F> for TwiddleFreeButterfly {
202 #[inline]
203 fn apply<PF: PackedField<Scalar = F>>(&self, x_1: PF, x_2: PF) -> (PF, PF) {
204 (x_1 + x_2, x_1 - x_2)
205 }
206}