Skip to main content

p3_mersenne_31/
qm31.rs

1//! The degree-4 extension of Mersenne31, built as the binomial extension by
2//! `X² - (2 + i)` over the complex extension `Mersenne31[i]`, flattened to a
3//! degree-4 vector space over `Mersenne31`.
4//!
5//! The tower type `BinomialExtensionField<Complex<Mersenne31>, 2>` only knows
6//! it is an extension of `Complex<Mersenne31>`. This module supplies the
7//! `Algebra<Mersenne31>`, `BasedVectorSpace<Mersenne31>` and
8//! `ExtensionField<Mersenne31>` impls (plus the packed counterpart
9//! [`PackedQM31`]) that let it serve as the challenge field of a STARK over
10//! `Mersenne31`, with 4 · 31 = 124 bits of extension size.
11//!
12//! The flattened basis is `[1, i, u, iu]` (`u² = 2 + i`), i.e. the in-memory
13//! order of the nested `[[Mersenne31; 2]; 2]` representation.
14
15use alloc::vec::Vec;
16use core::iter::{Product, Sum};
17use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
18use core::slice;
19
20use p3_field::extension::{BinomialExtensionField, Complex};
21use p3_field::{
22    Algebra, BasedVectorSpace, ExtensionField, Field, PackedFieldExtension, PackedValue, Powers,
23    PrimeCharacteristicRing,
24};
25use p3_util::{as_base_slice, flatten_to_base, reconstitute_from_base};
26
27use crate::Mersenne31;
28
29/// The degree-4 extension of Mersenne31: `Mersenne31[i][u]` with `i² = -1`
30/// and `u² = 2 + i`.
31pub type QM31 = BinomialExtensionField<Complex<Mersenne31>, 2>;
32
33type CM31 = Complex<Mersenne31>;
34type PackedM31 = <Mersenne31 as Field>::Packing;
35type PackedCM31 = <CM31 as ExtensionField<Mersenne31>>::ExtensionPacking;
36
37/// The two `CM31` coefficients of a `QM31` element.
38#[inline(always)]
39fn coeffs(x: &QM31) -> [CM31; 2] {
40    let s = BasedVectorSpace::<CM31>::as_basis_coefficients_slice(x);
41    [s[0], s[1]]
42}
43
44/// Multiply a packed complex coefficient by `W = 2 + i` using only additions:
45/// `(a + bi)(2 + i) = (2a - b) + (a + 2b)i`.
46#[inline(always)]
47fn packed_mul_by_w(c: PackedCM31) -> PackedCM31 {
48    let s = BasedVectorSpace::<PackedM31>::as_basis_coefficients_slice(&c);
49    let (re, im) = (s[0], s[1]);
50    PackedCM31::new([re + re - im, re + im + im])
51}
52
53// ---------------------------------------------------------------------------
54// Scalar flattening: QM31 as an algebra / vector space / extension over M31
55// ---------------------------------------------------------------------------
56
57impl From<Mersenne31> for QM31 {
58    #[inline]
59    fn from(x: Mersenne31) -> Self {
60        Self::new([CM31::from(x), CM31::ZERO])
61    }
62}
63
64impl Add<Mersenne31> for QM31 {
65    type Output = Self;
66    #[inline]
67    fn add(self, rhs: Mersenne31) -> Self {
68        let [c0, c1] = coeffs(&self);
69        Self::new([c0 + rhs, c1])
70    }
71}
72
73impl AddAssign<Mersenne31> for QM31 {
74    #[inline]
75    fn add_assign(&mut self, rhs: Mersenne31) {
76        *self = *self + rhs;
77    }
78}
79
80impl Sub<Mersenne31> for QM31 {
81    type Output = Self;
82    #[inline]
83    fn sub(self, rhs: Mersenne31) -> Self {
84        let [c0, c1] = coeffs(&self);
85        Self::new([c0 - rhs, c1])
86    }
87}
88
89impl SubAssign<Mersenne31> for QM31 {
90    #[inline]
91    fn sub_assign(&mut self, rhs: Mersenne31) {
92        *self = *self - rhs;
93    }
94}
95
96impl Mul<Mersenne31> for QM31 {
97    type Output = Self;
98    #[inline]
99    fn mul(self, rhs: Mersenne31) -> Self {
100        let [c0, c1] = coeffs(&self);
101        Self::new([c0 * rhs, c1 * rhs])
102    }
103}
104
105impl MulAssign<Mersenne31> for QM31 {
106    #[inline]
107    fn mul_assign(&mut self, rhs: Mersenne31) {
108        *self = *self * rhs;
109    }
110}
111
112impl Algebra<Mersenne31> for QM31 {}
113
114impl BasedVectorSpace<Mersenne31> for QM31 {
115    const DIMENSION: usize = 4;
116
117    #[inline]
118    fn as_basis_coefficients_slice(&self) -> &[Mersenne31] {
119        // SAFETY: `QM31` is `repr(transparent)` over `[CM31; 2]` and `CM31`
120        // over `[Mersenne31; 2]`, so `QM31` is layout-identical to
121        // `[Mersenne31; 4]`.
122        unsafe { as_base_slice(slice::from_ref(self)) }
123    }
124
125    #[inline]
126    fn from_basis_coefficients_fn<Fn: FnMut(usize) -> Mersenne31>(mut f: Fn) -> Self {
127        Self::new(core::array::from_fn(|i| {
128            CM31::from_basis_coefficients_fn(|j| f(2 * i + j))
129        }))
130    }
131
132    #[inline]
133    fn from_basis_coefficients_iter<I: ExactSizeIterator<Item = Mersenne31>>(
134        mut iter: I,
135    ) -> Option<Self> {
136        (iter.len() == 4).then(|| Self::from_basis_coefficients_fn(|_| iter.next().unwrap()))
137    }
138
139    #[inline]
140    fn flatten_to_base(vec: Vec<Self>) -> Vec<Mersenne31> {
141        // SAFETY: `QM31` is layout-identical to `[Mersenne31; 4]` (see
142        // `as_basis_coefficients_slice`) and has the same alignment as `Mersenne31`.
143        unsafe { flatten_to_base(vec) }
144    }
145
146    #[inline]
147    fn reconstitute_from_base(vec: Vec<Mersenne31>) -> Vec<Self> {
148        // SAFETY: `QM31` is layout-identical to `[Mersenne31; 4]` (see
149        // `as_basis_coefficients_slice`) and has the same alignment as `Mersenne31`.
150        unsafe { reconstitute_from_base(vec) }
151    }
152}
153
154impl ExtensionField<Mersenne31> for QM31 {
155    type ExtensionPacking = PackedQM31;
156
157    #[inline]
158    fn is_in_basefield(&self) -> bool {
159        BasedVectorSpace::<Mersenne31>::as_basis_coefficients_slice(self)[1..]
160            .iter()
161            .all(Mersenne31::is_zero)
162    }
163
164    #[inline]
165    fn as_base(&self) -> Option<Mersenne31> {
166        ExtensionField::<Mersenne31>::is_in_basefield(self)
167            .then(|| BasedVectorSpace::<Mersenne31>::as_basis_coefficients_slice(self)[0])
168    }
169}
170
171// ---------------------------------------------------------------------------
172// PackedQM31: SIMD-lane-parallel QM31, two packed CM31 coefficients
173// ---------------------------------------------------------------------------
174
175/// Packed representation of [`QM31`]: two packed `Complex<Mersenne31>`
176/// coefficients, each holding `PackedM31::WIDTH` lanes.
177#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
178#[repr(transparent)]
179pub struct PackedQM31(pub [PackedCM31; 2]);
180
181impl PrimeCharacteristicRing for PackedQM31 {
182    type PrimeSubfield = Mersenne31;
183
184    const ZERO: Self = Self([PackedCM31::ZERO; 2]);
185    const ONE: Self = Self([PackedCM31::ONE, PackedCM31::ZERO]);
186    const TWO: Self = Self([PackedCM31::TWO, PackedCM31::ZERO]);
187    const NEG_ONE: Self = Self([PackedCM31::NEG_ONE, PackedCM31::ZERO]);
188
189    #[inline]
190    fn from_prime_subfield(val: Self::PrimeSubfield) -> Self {
191        Self([PackedCM31::from_prime_subfield(val), PackedCM31::ZERO])
192    }
193
194    #[inline]
195    fn halve(&self) -> Self {
196        Self(self.0.map(|c| c.halve()))
197    }
198
199    #[inline]
200    fn mul_2exp_u64(&self, exp: u64) -> Self {
201        Self(self.0.map(|c| c.mul_2exp_u64(exp)))
202    }
203
204    #[inline]
205    fn div_2exp_u64(&self, exp: u64) -> Self {
206        Self(self.0.map(|c| c.div_2exp_u64(exp)))
207    }
208
209    #[inline]
210    fn zero_vec(len: usize) -> Vec<Self> {
211        // SAFETY: `Self` is `repr(transparent)` over `[PackedCM31; 2]`.
212        unsafe { reconstitute_from_base(PackedCM31::zero_vec(len * 2)) }
213    }
214}
215
216impl Neg for PackedQM31 {
217    type Output = Self;
218    #[inline]
219    fn neg(self) -> Self {
220        Self(self.0.map(Neg::neg))
221    }
222}
223
224impl Add for PackedQM31 {
225    type Output = Self;
226    #[inline]
227    fn add(self, rhs: Self) -> Self {
228        Self([self.0[0] + rhs.0[0], self.0[1] + rhs.0[1]])
229    }
230}
231
232impl AddAssign for PackedQM31 {
233    #[inline]
234    fn add_assign(&mut self, rhs: Self) {
235        *self = *self + rhs;
236    }
237}
238
239impl Sub for PackedQM31 {
240    type Output = Self;
241    #[inline]
242    fn sub(self, rhs: Self) -> Self {
243        Self([self.0[0] - rhs.0[0], self.0[1] - rhs.0[1]])
244    }
245}
246
247impl SubAssign for PackedQM31 {
248    #[inline]
249    fn sub_assign(&mut self, rhs: Self) {
250        *self = *self - rhs;
251    }
252}
253
254impl Mul for PackedQM31 {
255    type Output = Self;
256    #[inline]
257    fn mul(self, rhs: Self) -> Self {
258        // Karatsuba over CM31 with the cheap W = 2 + i correction:
259        //   (a0 + a1 u)(b0 + b1 u) = a0 b0 + W a1 b1 + (a0 b1 + a1 b0) u
260        // with a0 b1 + a1 b0 = (a0 + a1)(b0 + b1) - a0 b0 - a1 b1,
261        // for 3 full CM31 multiplications instead of 4.
262        let [a0, a1] = self.0;
263        let [b0, b1] = rhs.0;
264        let m0 = a0 * b0;
265        let m1 = a1 * b1;
266        let m2 = (a0 + a1) * (b0 + b1);
267        Self([m0 + packed_mul_by_w(m1), m2 - m0 - m1])
268    }
269}
270
271impl MulAssign for PackedQM31 {
272    #[inline]
273    fn mul_assign(&mut self, rhs: Self) {
274        *self = *self * rhs;
275    }
276}
277
278impl core::ops::Div for PackedQM31 {
279    type Output = Self;
280    #[allow(clippy::suspicious_arithmetic_impl)]
281    #[inline]
282    fn div(self, rhs: Self) -> Self {
283        self * p3_field::invert_packed_extension::<Mersenne31, QM31>(rhs)
284    }
285}
286
287impl core::ops::DivAssign for PackedQM31 {
288    #[inline]
289    fn div_assign(&mut self, rhs: Self) {
290        *self = *self / rhs;
291    }
292}
293
294impl Sum for PackedQM31 {
295    #[inline]
296    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
297        iter.fold(Self::ZERO, |acc, x| acc + x)
298    }
299}
300
301impl Product for PackedQM31 {
302    #[inline]
303    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
304        iter.fold(Self::ONE, |acc, x| acc * x)
305    }
306}
307
308// --- Algebra<QM31> ---
309
310impl From<QM31> for PackedQM31 {
311    #[inline]
312    fn from(x: QM31) -> Self {
313        let [c0, c1] = coeffs(&x);
314        Self([c0.into(), c1.into()])
315    }
316}
317
318impl Add<QM31> for PackedQM31 {
319    type Output = Self;
320    #[inline]
321    fn add(self, rhs: QM31) -> Self {
322        let [b0, b1] = coeffs(&rhs);
323        Self([self.0[0] + b0, self.0[1] + b1])
324    }
325}
326
327impl AddAssign<QM31> for PackedQM31 {
328    #[inline]
329    fn add_assign(&mut self, rhs: QM31) {
330        *self = *self + rhs;
331    }
332}
333
334impl Sub<QM31> for PackedQM31 {
335    type Output = Self;
336    #[inline]
337    fn sub(self, rhs: QM31) -> Self {
338        let [b0, b1] = coeffs(&rhs);
339        Self([self.0[0] - b0, self.0[1] - b1])
340    }
341}
342
343impl SubAssign<QM31> for PackedQM31 {
344    #[inline]
345    fn sub_assign(&mut self, rhs: QM31) {
346        *self = *self - rhs;
347    }
348}
349
350impl Mul<QM31> for PackedQM31 {
351    type Output = Self;
352    #[inline]
353    fn mul(self, rhs: QM31) -> Self {
354        let [a0, a1] = self.0;
355        let [b0, b1] = coeffs(&rhs);
356        let m0 = a0 * b0;
357        let m1 = a1 * b1;
358        let m2 = (a0 + a1) * (b0 + b1);
359        Self([m0 + packed_mul_by_w(m1), m2 - m0 - m1])
360    }
361}
362
363impl MulAssign<QM31> for PackedQM31 {
364    #[inline]
365    fn mul_assign(&mut self, rhs: QM31) {
366        *self = *self * rhs;
367    }
368}
369
370impl Algebra<QM31> for PackedQM31 {}
371
372// --- Algebra<PackedM31> ---
373
374impl From<PackedM31> for PackedQM31 {
375    #[inline]
376    fn from(x: PackedM31) -> Self {
377        Self([x.into(), PackedCM31::ZERO])
378    }
379}
380
381impl Add<PackedM31> for PackedQM31 {
382    type Output = Self;
383    #[inline]
384    fn add(self, rhs: PackedM31) -> Self {
385        Self([self.0[0] + rhs, self.0[1]])
386    }
387}
388
389impl AddAssign<PackedM31> for PackedQM31 {
390    #[inline]
391    fn add_assign(&mut self, rhs: PackedM31) {
392        *self = *self + rhs;
393    }
394}
395
396impl Sub<PackedM31> for PackedQM31 {
397    type Output = Self;
398    #[inline]
399    fn sub(self, rhs: PackedM31) -> Self {
400        Self([self.0[0] - rhs, self.0[1]])
401    }
402}
403
404impl SubAssign<PackedM31> for PackedQM31 {
405    #[inline]
406    fn sub_assign(&mut self, rhs: PackedM31) {
407        *self = *self - rhs;
408    }
409}
410
411impl Mul<PackedM31> for PackedQM31 {
412    type Output = Self;
413    #[inline]
414    fn mul(self, rhs: PackedM31) -> Self {
415        Self([self.0[0] * rhs, self.0[1] * rhs])
416    }
417}
418
419impl MulAssign<PackedM31> for PackedQM31 {
420    #[inline]
421    fn mul_assign(&mut self, rhs: PackedM31) {
422        *self = *self * rhs;
423    }
424}
425
426impl Algebra<PackedM31> for PackedQM31 {}
427
428impl BasedVectorSpace<PackedM31> for PackedQM31 {
429    const DIMENSION: usize = 4;
430
431    #[inline]
432    fn as_basis_coefficients_slice(&self) -> &[PackedM31] {
433        // SAFETY: `PackedQM31` is `repr(transparent)` over `[PackedCM31; 2]`
434        // and `PackedCM31` over `[PackedM31; 2]`, so `PackedQM31` is
435        // layout-identical to `[PackedM31; 4]`.
436        unsafe { as_base_slice(slice::from_ref(self)) }
437    }
438
439    #[inline]
440    fn from_basis_coefficients_fn<Fn: FnMut(usize) -> PackedM31>(mut f: Fn) -> Self {
441        Self(core::array::from_fn(|i| {
442            PackedCM31::from_basis_coefficients_fn(|j| f(2 * i + j))
443        }))
444    }
445
446    #[inline]
447    fn from_basis_coefficients_iter<I: ExactSizeIterator<Item = PackedM31>>(
448        mut iter: I,
449    ) -> Option<Self> {
450        (iter.len() == 4).then(|| Self::from_basis_coefficients_fn(|_| iter.next().unwrap()))
451    }
452}
453
454impl rand::distr::Distribution<PackedQM31> for rand::distr::StandardUniform {
455    #[inline]
456    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> PackedQM31 {
457        PackedQM31(core::array::from_fn(|_| {
458            <PackedCM31 as BasedVectorSpace<PackedM31>>::from_basis_coefficients_fn(|_| {
459                self.sample(rng)
460            })
461        }))
462    }
463}
464
465impl PackedFieldExtension<Mersenne31, QM31> for PackedQM31 {
466    #[inline]
467    fn packed_ext_powers(base: QM31) -> Powers<Self> {
468        let width = PackedM31::WIDTH;
469        let powers = base.powers().collect_n(width + 1);
470        // Transpose the first WIDTH powers into the lanes.
471        let current = Self::from_ext_slice(&powers[..width]);
472        // Broadcast base^WIDTH as the per-step multiplier.
473        Powers {
474            base: powers[width].into(),
475            current,
476        }
477    }
478}
479
#[cfg(test)]
mod tests {
    use num_bigint::BigUint;
    use p3_field::PrimeCharacteristicRing;
    use p3_field_testing::{test_extension_field, test_field, test_packed_extension_field};

    use super::*;

    type F = Mersenne31;
    type EF = QM31;

    const ZEROS: [EF; 1] = [EF::ZERO];
    const ONES: [EF; 1] = [EF::ONE];

    // The prime factorization of P^4 - 1 = (P - 1)(P + 1)(P^2 + 1). This is the
    // same multiplicative group as the quadratic-over-complex view tested in
    // `extension.rs`.
    fn multiplicative_group_prime_factorization() -> [(BigUint, u32); 11] {
        [
            (BigUint::from(2u8), 33),
            (BigUint::from(3u8), 2),
            (BigUint::from(5u8), 1),
            (BigUint::from(7u8), 1),
            (BigUint::from(11u8), 1),
            (BigUint::from(31u8), 1),
            (BigUint::from(151u8), 1),
            (BigUint::from(331u16), 1),
            (BigUint::from(733u16), 1),
            (BigUint::from(1709u16), 1),
            (BigUint::from(368140581013u64), 1),
        ]
    }

    test_field!(
        super::EF,
        &super::ZEROS,
        &super::ONES,
        &super::multiplicative_group_prime_factorization()
    );

    test_extension_field!(super::F, super::EF);

    type Pef = PackedQM31;
    const PACKED_ZEROS: [Pef; 1] = [Pef::ZERO];
    const PACKED_ONES: [Pef; 1] = [Pef::ONE];
    test_packed_extension_field!(
        super::F,
        super::EF,
        super::Pef,
        &super::PACKED_ZEROS,
        &super::PACKED_ONES
    );

    /// The flattened M31 basis order must match the nested CM31 layout.
    #[test]
    fn flattened_basis_order_matches_nested_layout() {
        use p3_field::BasedVectorSpace;

        let x = QM31::new([
            Complex::new_complex(F::new(1), F::new(2)),
            Complex::new_complex(F::new(3), F::new(4)),
        ]);
        let flat = BasedVectorSpace::<F>::as_basis_coefficients_slice(&x);
        assert_eq!(flat, &[F::new(1), F::new(2), F::new(3), F::new(4)]);

        let rebuilt = <QM31 as BasedVectorSpace<F>>::from_basis_coefficients_slice(flat).unwrap();
        assert_eq!(rebuilt, x);
    }

    /// Packed multiplication must agree with scalar multiplication lane-wise.
    #[test]
    fn packed_mul_matches_scalar() {
        use p3_field::PackedFieldExtension;
        use rand::rngs::SmallRng;
        use rand::{RngExt, SeedableRng};

        let mut rng = SmallRng::seed_from_u64(1);
        let width = <PackedM31 as p3_field::PackedValue>::WIDTH;
        let xs: alloc::vec::Vec<QM31> = (0..width).map(|_| rng.random()).collect();
        let ys: alloc::vec::Vec<QM31> = (0..width).map(|_| rng.random()).collect();

        let px = <PackedQM31 as PackedFieldExtension<F, EF>>::from_ext_slice(&xs);
        let py = <PackedQM31 as PackedFieldExtension<F, EF>>::from_ext_slice(&ys);
        let prod = px * py;

        for lane in 0..width {
            assert_eq!(
                <PackedQM31 as PackedFieldExtension<F, EF>>::extract(&prod, lane),
                xs[lane] * ys[lane]
            );
        }
    }

    /// `PackedQM31 * QM31` has its own hand-written Karatsuba; it must agree
    /// lane-wise with scalar `QM31 * QM31`.
    #[test]
    fn packed_mul_by_ext_matches_scalar() {
        use p3_field::PackedFieldExtension;
        use rand::rngs::SmallRng;
        use rand::{RngExt, SeedableRng};

        let mut rng = SmallRng::seed_from_u64(2);
        let width = <PackedM31 as p3_field::PackedValue>::WIDTH;
        let xs: alloc::vec::Vec<QM31> = (0..width).map(|_| rng.random()).collect();
        let y: QM31 = rng.random();

        let px = <PackedQM31 as PackedFieldExtension<F, EF>>::from_ext_slice(&xs);
        let prod = px * y;

        for (lane, &x) in xs.iter().enumerate() {
            assert_eq!(
                <PackedQM31 as PackedFieldExtension<F, EF>>::extract(&prod, lane),
                x * y
            );
        }
    }

    /// `PackedQM31 * PackedM31` must agree lane-wise with `QM31 * Mersenne31`.
    #[test]
    fn packed_mul_by_packed_base_matches_scalar() {
        use p3_field::PackedFieldExtension;
        use rand::rngs::SmallRng;
        use rand::{RngExt, SeedableRng};

        let mut rng = SmallRng::seed_from_u64(3);
        let width = <PackedM31 as p3_field::PackedValue>::WIDTH;
        let xs: alloc::vec::Vec<QM31> = (0..width).map(|_| rng.random()).collect();
        let zs: alloc::vec::Vec<F> = (0..width).map(|_| rng.random()).collect();

        let px = <PackedQM31 as PackedFieldExtension<F, EF>>::from_ext_slice(&xs);
        let pz = <PackedM31 as p3_field::PackedValue>::from_fn(|lane| zs[lane]);
        let prod = px * pz;

        for (lane, (&x, &z)) in xs.iter().zip(&zs).enumerate() {
            assert_eq!(
                <PackedQM31 as PackedFieldExtension<F, EF>>::extract(&prod, lane),
                x * z
            );
        }
    }

    /// The zero-copy `flatten_to_base`/`reconstitute_from_base` overrides must match
    /// the basis-coefficient view element by element and round-trip exactly.
    #[test]
    fn flatten_reconstitute_roundtrip() {
        use alloc::vec::Vec;

        use rand::rngs::SmallRng;
        use rand::{RngExt, SeedableRng};

        let mut rng = SmallRng::seed_from_u64(7);
        let xs: Vec<QM31> = (0..23).map(|_| rng.random()).collect();

        let flat = <QM31 as BasedVectorSpace<F>>::flatten_to_base(xs.clone());
        // `QM31` has several `BasedVectorSpace` impls; name the `Mersenne31` one.
        let expected: Vec<F> = xs
            .iter()
            .flat_map(|x| BasedVectorSpace::<F>::as_basis_coefficients_slice(x).to_vec())
            .collect();
        assert_eq!(flat, expected);

        let back = <QM31 as BasedVectorSpace<F>>::reconstitute_from_base(flat);
        assert_eq!(back, xs);
    }
}