// p3_field/packed/packed_traits.rs

1use alloc::vec::Vec;
2use core::iter::{Product, Sum};
3use core::mem::MaybeUninit;
4use core::ops::{Div, DivAssign};
5use core::{array, slice};
6
7use crate::field::Field;
8use crate::{Algebra, BasedVectorSpace, ExtensionField, Powers, PrimeCharacteristicRing};
9
/// Marker trait for scalar types that may be stored inside a packed value.
///
/// Being a dedicated marker, `Packable` lets this crate provide implementations for types
/// that would otherwise potentially conflict.
pub trait Packable: Copy + Default + Eq + PartialEq + Send + Sync + 'static {}
14
15/// A trait for array-like structs made up of multiple scalar elements.
16///
17/// # Safety
18/// - If `P` implements `PackedField` then `P` must be castable to/from `[P::Value; P::WIDTH]`
19///   without UB.
20pub unsafe trait PackedValue: 'static + Copy + Send + Sync {
21    /// The scalar type that is packed into this value.
22    type Value: Packable;
23
24    /// Number of scalar values packed together.
25    const WIDTH: usize;
26
27    /// Interprets a slice of scalar values as a packed value reference.
28    ///
29    /// # Panics:
30    /// This function will panic if `slice.len() != Self::WIDTH`
31    #[must_use]
32    fn from_slice(slice: &[Self::Value]) -> &Self;
33
34    /// Interprets a mutable slice of scalar values as a mutable packed value.
35    ///
36    /// # Panics:
37    /// This function will panic if `slice.len() != Self::WIDTH`
38    #[must_use]
39    fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self;
40
41    /// Constructs a packed value using a function to generate each element.
42    ///
43    /// Similar to `core:array::from_fn`.
44    #[must_use]
45    fn from_fn<F>(f: F) -> Self
46    where
47        F: FnMut(usize) -> Self::Value;
48
49    /// Returns the underlying scalar values as an immutable slice.
50    #[must_use]
51    fn as_slice(&self) -> &[Self::Value];
52
53    /// Returns the underlying scalar values as a mutable slice.
54    #[must_use]
55    fn as_slice_mut(&mut self) -> &mut [Self::Value];
56
57    /// Packs a slice of scalar values into a slice of packed values.
58    ///
59    /// # Panics
60    /// Panics if the slice length is not divisible by `WIDTH`.
61    #[inline]
62    #[must_use]
63    fn pack_slice(buf: &[Self::Value]) -> &[Self] {
64        // Sources vary, but this should be true on all platforms we care about.
65        const {
66            assert!(align_of::<Self>() <= align_of::<Self::Value>());
67        }
68        assert!(
69            buf.len().is_multiple_of(Self::WIDTH),
70            "Slice length (got {}) must be a multiple of packed field width ({}).",
71            buf.len(),
72            Self::WIDTH
73        );
74        let buf_ptr = buf.as_ptr().cast::<Self>();
75        let n = buf.len() / Self::WIDTH;
76        unsafe { slice::from_raw_parts(buf_ptr, n) }
77    }
78
79    /// Packs a slice into packed values and returns the packed portion and any remaining suffix.
80    #[inline]
81    #[must_use]
82    fn pack_slice_with_suffix(buf: &[Self::Value]) -> (&[Self], &[Self::Value]) {
83        let (packed, suffix) = buf.split_at(buf.len() - buf.len() % Self::WIDTH);
84        (Self::pack_slice(packed), suffix)
85    }
86
87    /// Converts a mutable slice of scalar values into a mutable slice of packed values.
88    ///
89    /// # Panics
90    /// Panics if the slice length is not divisible by `WIDTH`.
91    #[inline]
92    #[must_use]
93    fn pack_slice_mut(buf: &mut [Self::Value]) -> &mut [Self] {
94        const {
95            assert!(align_of::<Self>() <= align_of::<Self::Value>());
96        }
97        assert!(
98            buf.len().is_multiple_of(Self::WIDTH),
99            "Slice length (got {}) must be a multiple of packed field width ({}).",
100            buf.len(),
101            Self::WIDTH
102        );
103        let buf_ptr = buf.as_mut_ptr().cast::<Self>();
104        let n = buf.len() / Self::WIDTH;
105        unsafe { slice::from_raw_parts_mut(buf_ptr, n) }
106    }
107
108    /// Converts a mutable slice of possibly uninitialized scalar values into
109    /// a mutable slice of possibly uninitialized packed values.
110    ///
111    /// # Panics
112    /// Panics if the slice length is not divisible by `WIDTH`.
113    #[inline]
114    #[must_use]
115    fn pack_maybe_uninit_slice_mut(
116        buf: &mut [MaybeUninit<Self::Value>],
117    ) -> &mut [MaybeUninit<Self>] {
118        const {
119            assert!(align_of::<Self>() <= align_of::<Self::Value>());
120        }
121        assert!(
122            buf.len().is_multiple_of(Self::WIDTH),
123            "Slice length (got {}) must be a multiple of packed field width ({}).",
124            buf.len(),
125            Self::WIDTH
126        );
127        let buf_ptr = buf.as_mut_ptr().cast::<MaybeUninit<Self>>();
128        let n = buf.len() / Self::WIDTH;
129        unsafe { slice::from_raw_parts_mut(buf_ptr, n) }
130    }
131
132    /// Converts a mutable slice of scalar values into a pair:
133    /// - a slice of packed values covering the largest aligned portion,
134    /// - and a remainder slice of scalar values that couldn't be packed.
135    #[inline]
136    #[must_use]
137    fn pack_slice_with_suffix_mut(buf: &mut [Self::Value]) -> (&mut [Self], &mut [Self::Value]) {
138        let (packed, suffix) = buf.split_at_mut(buf.len() - buf.len() % Self::WIDTH);
139        (Self::pack_slice_mut(packed), suffix)
140    }
141
142    /// Converts a mutable slice of possibly uninitialized scalar values into a pair:
143    /// - a slice of possibly uninitialized packed values covering the largest aligned portion,
144    /// - and a remainder slice of possibly uninitialized scalar values that couldn't be packed.
145    #[inline]
146    #[must_use]
147    fn pack_maybe_uninit_slice_with_suffix_mut(
148        buf: &mut [MaybeUninit<Self::Value>],
149    ) -> (&mut [MaybeUninit<Self>], &mut [MaybeUninit<Self::Value>]) {
150        let (packed, suffix) = buf.split_at_mut(buf.len() - buf.len() % Self::WIDTH);
151        (Self::pack_maybe_uninit_slice_mut(packed), suffix)
152    }
153
154    /// Reinterprets a slice of packed values as a flat slice of scalar values.
155    ///
156    /// Each packed value contains `Self::WIDTH` scalar values, which are laid out
157    /// contiguously in memory. This function allows direct access to those scalars.
158    #[inline]
159    #[must_use]
160    fn unpack_slice(buf: &[Self]) -> &[Self::Value] {
161        const {
162            assert!(align_of::<Self>() >= align_of::<Self::Value>());
163        }
164        let buf_ptr = buf.as_ptr().cast::<Self::Value>();
165        let n = buf.len() * Self::WIDTH;
166        unsafe { slice::from_raw_parts(buf_ptr, n) }
167    }
168}
169
170unsafe impl<T: Packable, const WIDTH: usize> PackedValue for [T; WIDTH] {
171    type Value = T;
172    const WIDTH: usize = WIDTH;
173
174    #[inline]
175    fn from_slice(slice: &[Self::Value]) -> &Self {
176        assert_eq!(slice.len(), Self::WIDTH);
177        unsafe { &*slice.as_ptr().cast() }
178    }
179
180    #[inline]
181    fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self {
182        assert_eq!(slice.len(), Self::WIDTH);
183        unsafe { &mut *slice.as_mut_ptr().cast() }
184    }
185
186    #[inline]
187    fn from_fn<Fn>(f: Fn) -> Self
188    where
189        Fn: FnMut(usize) -> Self::Value,
190    {
191        core::array::from_fn(f)
192    }
193
194    #[inline]
195    fn as_slice(&self) -> &[Self::Value] {
196        self
197    }
198
199    #[inline]
200    fn as_slice_mut(&mut self) -> &mut [Self::Value] {
201        self
202    }
203}
204
205/// An array of field elements which can be packed into a vector for SIMD operations.
206///
207/// # Safety
208/// - See `PackedValue` above.
209pub unsafe trait PackedField: Algebra<Self::Scalar>
210    + PackedValue<Value = Self::Scalar>
211    // TODO: Implement packed / packed division
212    + Div<Self::Scalar, Output = Self>
213    + DivAssign<Self::Scalar>
214    + Sum<Self::Scalar>
215    + Product<Self::Scalar>
216{
217    type Scalar: Field;
218
219    /// Construct an iterator which returns powers of `base` packed into packed field elements.
220    ///
221    /// E.g. if `Self::WIDTH = 4`, returns: `[base^0, base^1, base^2, base^3], [base^4, base^5, base^6, base^7], ...`.
222    #[must_use]
223    fn packed_powers(base: Self::Scalar) -> Powers<Self> {
224        Self::packed_shifted_powers(base, Self::Scalar::ONE)
225    }
226
227    /// Construct an iterator which returns powers of `base` multiplied by `start` and packed into packed field elements.
228    ///
229    /// E.g. if `Self::WIDTH = 4`, returns: `[start, start*base, start*base^2, start*base^3], [start*base^4, start*base^5, start*base^6, start*base^7], ...`.
230    #[must_use]
231    fn packed_shifted_powers(base: Self::Scalar, start: Self::Scalar) -> Powers<Self> {
232        let mut current: Self = start.into();
233        let slice = current.as_slice_mut();
234        for i in 1..Self::WIDTH {
235            slice[i] = slice[i - 1] * base;
236        }
237
238        Powers {
239            base: base.exp_u64(Self::WIDTH as u64).into(),
240            current,
241        }
242    }
243
244    /// Compute a linear combination of a slice of base field elements and
245    /// a slice of packed field elements. The slices must have equal length
246    /// and it must be a compile time constant.
247    ///
248    /// # Panics
249    ///
250    /// May panic if the length of either slice is not equal to `N`.
251    #[must_use]
252    fn packed_linear_combination<const N: usize>(coeffs: &[Self::Scalar], vecs: &[Self]) -> Self {
253        assert_eq!(coeffs.len(), N);
254        assert_eq!(vecs.len(), N);
255        let combined: [Self; N] = array::from_fn(|i| vecs[i] * coeffs[i]);
256        Self::sum_array::<N>(&combined)
257    }
258}
259
260/// # Safety
261/// - `WIDTH` is assumed to be a power of 2.
262pub unsafe trait PackedFieldPow2: PackedField {
263    /// Take interpret two vectors as chunks of `block_len` elements. Unpack and interleave those
264    /// chunks. This is best seen with an example. If we have:
265    /// ```text
266    /// A = [x0, y0, x1, y1]
267    /// B = [x2, y2, x3, y3]
268    /// ```
269    ///
270    /// then
271    ///
272    /// ```text
273    /// interleave(A, B, 1) = ([x0, x2, x1, x3], [y0, y2, y1, y3])
274    /// ```
275    ///
276    /// Pairs that were adjacent in the input are at corresponding positions in the output.
277    ///
278    /// `r` lets us set the size of chunks we're interleaving. If we set `block_len = 2`, then for
279    ///
280    /// ```text
281    /// A = [x0, x1, y0, y1]
282    /// B = [x2, x3, y2, y3]
283    /// ```
284    ///
285    /// we obtain
286    ///
287    /// ```text
288    /// interleave(A, B, block_len) = ([x0, x1, x2, x3], [y0, y1, y2, y3])
289    /// ```
290    ///
291    /// We can also think about this as stacking the vectors, dividing them into 2x2 matrices, and
292    /// transposing those matrices.
293    ///
294    /// When `block_len = WIDTH`, this operation is a no-op.
295    ///
296    /// # Panics
297    /// This may panic if `block_len` does not divide `WIDTH`. Since `WIDTH` is specified to be a power of 2,
298    /// `block_len` must also be a power of 2. It cannot be 0 and it cannot exceed `WIDTH`.
299    #[must_use]
300    fn interleave(&self, other: Self, block_len: usize) -> (Self, Self);
301}
302
303/// Fix a field `F` a packing width `W` and an extension field `EF` of `F`.
304///
305/// By choosing a basis `B`, `EF` can be transformed into an array `[F; D]`.
306///
307/// A type should implement PackedFieldExtension if it can be transformed into `[F::Packing; D] ~ [[F; W]; D]`
308///
309/// This is interpreted by taking a transpose to get `[[F; D]; W]` which can then be reinterpreted
310/// as `[EF; W]` by making use of the chosen basis `B` again.
311pub trait PackedFieldExtension<
312    BaseField: Field,
313    ExtField: ExtensionField<BaseField, ExtensionPacking = Self>,
314>: Algebra<ExtField> + Algebra<BaseField::Packing> + BasedVectorSpace<BaseField::Packing>
315{
316    /// Given a slice of extension field `EF` elements of length `W`,
317    /// convert into the array `[[F; D]; W]` transpose to
318    /// `[[F; W]; D]` and then pack to get `[PF; D]`.
319    #[must_use]
320    fn from_ext_slice(ext_slice: &[ExtField]) -> Self;
321
322    /// Given a iterator of packed extension field elements, convert to an iterator of
323    /// extension field elements.
324    ///
325    /// This performs the inverse transformation to `from_ext_slice`.
326    #[inline]
327    #[must_use]
328    fn to_ext_iter(iter: impl IntoIterator<Item = Self>) -> impl Iterator<Item = ExtField> {
329        iter.into_iter().flat_map(|x| {
330            let packed_coeffs = x.as_basis_coefficients_slice();
331            (0..BaseField::Packing::WIDTH)
332                .map(|i| ExtField::from_basis_coefficients_fn(|j| packed_coeffs[j].as_slice()[i]))
333                .collect::<Vec<_>>() // PackedFieldExtension's should reimplement this to avoid this allocation.
334        })
335    }
336
337    /// Similar to `packed_powers`, construct an iterator which returns
338    /// powers of `base` packed into `PackedFieldExtension` elements.
339    #[must_use]
340    fn packed_ext_powers(base: ExtField) -> Powers<Self>;
341
342    /// Similar to `packed_ext_powers` but only returns `unpacked_len` powers of `base`.
343    ///
344    /// Note that the length of the returned iterator will be `unpacked_len / WIDTH` and
345    /// not `len` as the iterator is over packed extension field elements. If `unpacked_len`
346    /// is not divisible by `WIDTH`, `unpacked_len` will be rounded up to the next multiple of `WIDTH`.
347    #[must_use]
348    fn packed_ext_powers_capped(base: ExtField, unpacked_len: usize) -> impl Iterator<Item = Self> {
349        Self::packed_ext_powers(base).take(unpacked_len.div_ceil(BaseField::Packing::WIDTH))
350    }
351}
352
353unsafe impl<T: Packable> PackedValue for T {
354    type Value = Self;
355
356    const WIDTH: usize = 1;
357
358    #[inline]
359    fn from_slice(slice: &[Self::Value]) -> &Self {
360        assert_eq!(slice.len(), Self::WIDTH);
361        &slice[0]
362    }
363
364    #[inline]
365    fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self {
366        assert_eq!(slice.len(), Self::WIDTH);
367        &mut slice[0]
368    }
369
370    #[inline]
371    fn from_fn<Fn>(mut f: Fn) -> Self
372    where
373        Fn: FnMut(usize) -> Self::Value,
374    {
375        f(0)
376    }
377
378    #[inline]
379    fn as_slice(&self) -> &[Self::Value] {
380        slice::from_ref(self)
381    }
382
383    #[inline]
384    fn as_slice_mut(&mut self) -> &mut [Self::Value] {
385        slice::from_mut(self)
386    }
387}
388
389unsafe impl<F: Field> PackedField for F {
390    type Scalar = Self;
391}
392
393unsafe impl<F: Field> PackedFieldPow2 for F {
394    #[inline]
395    fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) {
396        match block_len {
397            1 => (*self, other),
398            _ => panic!("unsupported block length"),
399        }
400    }
401}
402
403impl<F: Field> PackedFieldExtension<F, F> for F::Packing {
404    #[inline]
405    fn from_ext_slice(ext_slice: &[F]) -> Self {
406        *F::Packing::from_slice(ext_slice)
407    }
408
409    #[inline]
410    fn packed_ext_powers(base: F) -> Powers<Self> {
411        F::Packing::packed_powers(base)
412    }
413}
414
415impl Packable for u8 {}
416
417impl Packable for u16 {}
418
419impl Packable for u32 {}
420
421impl Packable for u64 {}
422
423impl Packable for u128 {}