Skip to main content

p3_mersenne_31/
mersenne_31.rs

1use alloc::vec;
2use alloc::vec::Vec;
3use core::fmt::{Debug, Display, Formatter};
4use core::hash::{Hash, Hasher};
5use core::iter::{Product, Sum};
6use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
7use core::{array, fmt, iter};
8
9use num_bigint::BigUint;
10use p3_challenger::UniformSamplingField;
11use p3_field::exponentiation::exp_1717986917;
12use p3_field::integers::QuotientMap;
13use p3_field::op_assign_macros::{
14    impl_add_assign, impl_div_methods, impl_mul_methods, impl_sub_assign,
15};
16use p3_field::{
17    Field, InjectiveMonomial, Packable, PermutationMonomial, PrimeCharacteristicRing, PrimeField,
18    PrimeField32, PrimeField64, RawDataSerializable, halve_u32, impl_raw_serializable_primefield32,
19    quotient_map_large_iint, quotient_map_large_uint, quotient_map_small_int,
20};
21use p3_util::{flatten_to_base, gcd_inversion_prime_field_32};
22use rand::Rng;
23use rand::distr::{Distribution, StandardUniform};
24use serde::de::Error;
25use serde::{Deserialize, Deserializer, Serialize};
26
/// The Mersenne prime `p = 2^31 - 1`, i.e. the largest value representable in 31 bits.
const P: u32 = u32::MAX >> 1;
29
/// An element of the prime field `F_p` with `p = 2^31 - 1`.
///
/// The element is stored as a single `u32` that always fits in 31 bits but is not necessarily
/// canonical: besides `0`, the value `2^31 - 1 = p` also encodes zero.
#[must_use]
#[repr(transparent)] // Same layout as `u32`; the unsafe cast in `zero_vec` relies on this.
#[derive(Clone, Copy, Default)]
pub struct Mersenne31 {
    /// Raw representative: guaranteed to be `< 2^31`, but possibly equal to `p`.
    pub(crate) value: u32,
}
38
39impl Mersenne31 {
40    /// Create a new field element from any `u32`.
41    ///
42    /// Any `u32` value is accepted and automatically reduced modulo P.
43    #[inline]
44    pub const fn new(value: u32) -> Self {
45        Self { value: value % P }
46    }
47
48    /// Create a field element from a value assumed to be < 2^31.
49    ///
50    /// # Safety
51    /// The element must lie in the range: `[0, 2^31 - 1]`.
52    #[inline]
53    pub(crate) const fn new_reduced(value: u32) -> Self {
54        debug_assert!((value >> 31) == 0);
55        Self { value }
56    }
57
58    /// Convert a u32 element into a Mersenne31 element.
59    ///
60    /// Returns `None` if the element does not lie in the range: `[0, 2^31 - 1]`.
61    #[inline]
62    pub const fn new_checked(value: u32) -> Option<Self> {
63        if (value >> 31) == 0 {
64            Some(Self { value })
65        } else {
66            None
67        }
68    }
69
70    /// Convert a `[u32; N]` array to an array of field elements.
71    ///
72    /// Const version of `input.map(Mersenne31::new)`.
73    #[inline]
74    pub const fn new_array<const N: usize>(input: [u32; N]) -> [Self; N] {
75        let mut output = [Self::ZERO; N];
76        let mut i = 0;
77        while i < N {
78            output[i].value = input[i] % P;
79            i += 1;
80        }
81        output
82    }
83
84    /// Precomputed table of generators for two-adic subgroups of the degree two extension field over Mersenne31.
85    /// The `i`'th element is a generator of the subgroup of order `2^i`.
86    pub const EXT_TWO_ADIC_GENERATORS: [[Self; 2]; 33] = [
87        [Self::ONE, Self::ZERO],
88        [Self::new(2_147_483_646), Self::new(0)],
89        [Self::new(0), Self::new(2_147_483_646)],
90        [Self::new(32_768), Self::new(2_147_450_879)],
91        [Self::new(590_768_354), Self::new(978_592_373)],
92        [Self::new(1_179_735_656), Self::new(1_241_207_368)],
93        [Self::new(1_567_857_810), Self::new(456_695_729)],
94        [Self::new(1_774_253_895), Self::new(1_309_288_441)],
95        [Self::new(736_262_640), Self::new(1_553_669_210)],
96        [Self::new(1_819_216_575), Self::new(1_662_816_114)],
97        [Self::new(1_323_191_254), Self::new(1_936_974_060)],
98        [Self::new(605_622_498), Self::new(1_964_232_216)],
99        [Self::new(343_674_985), Self::new(501_786_993)],
100        [Self::new(1_995_316_534), Self::new(149_306_621)],
101        [Self::new(2_107_600_913), Self::new(1_378_821_388)],
102        [Self::new(541_476_169), Self::new(2_101_081_972)],
103        [Self::new(2_135_874_973), Self::new(483_411_332)],
104        [Self::new(2_097_144_245), Self::new(1_684_033_590)],
105        [Self::new(1_662_322_247), Self::new(670_236_780)],
106        [Self::new(1_172_215_635), Self::new(595_888_646)],
107        [Self::new(241_940_101), Self::new(323_856_519)],
108        [Self::new(1_957_194_259), Self::new(2_139_647_100)],
109        [Self::new(1_957_419_629), Self::new(1_541_039_442)],
110        [Self::new(1_062_045_235), Self::new(1_824_580_421)],
111        [Self::new(1_929_382_196), Self::new(1_664_698_822)],
112        [Self::new(1_889_294_251), Self::new(331_248_939)],
113        [Self::new(1_214_231_414), Self::new(1_646_302_518)],
114        [Self::new(1_765_392_370), Self::new(461_136_547)],
115        [Self::new(1_629_751_483), Self::new(66_485_474)],
116        [Self::new(1_501_355_827), Self::new(1_439_063_420)],
117        [Self::new(509_778_402), Self::new(800_467_507)],
118        [Self::new(311_014_874), Self::new(1_584_694_829)],
119        [Self::new(1_166_849_849), Self::new(1_117_296_306)],
120    ];
121}
122
123impl PartialEq for Mersenne31 {
124    #[inline]
125    fn eq(&self, other: &Self) -> bool {
126        self.as_canonical_u32() == other.as_canonical_u32()
127    }
128}
129
130impl Eq for Mersenne31 {}
131
132impl Packable for Mersenne31 {}
133
134impl Hash for Mersenne31 {
135    fn hash<H: Hasher>(&self, state: &mut H) {
136        state.write_u32(self.to_unique_u32());
137    }
138}
139
140impl Ord for Mersenne31 {
141    #[inline]
142    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
143        self.as_canonical_u32().cmp(&other.as_canonical_u32())
144    }
145}
146
147impl PartialOrd for Mersenne31 {
148    #[inline]
149    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
150        Some(self.cmp(other))
151    }
152}
153
154impl Display for Mersenne31 {
155    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
156        Display::fmt(&self.value, f)
157    }
158}
159
160impl Debug for Mersenne31 {
161    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
162        Debug::fmt(&self.value, f)
163    }
164}
165
166impl Distribution<Mersenne31> for StandardUniform {
167    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Mersenne31 {
168        loop {
169            let next_u31 = rng.next_u32() >> 1;
170            let is_canonical = next_u31 != Mersenne31::ORDER_U32;
171            if is_canonical {
172                return Mersenne31::new_reduced(next_u31);
173            }
174        }
175    }
176}
177
178impl Serialize for Mersenne31 {
179    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
180        // No need to convert to canonical.
181        serializer.serialize_u32(self.value)
182    }
183}
184
185impl<'a> Deserialize<'a> for Mersenne31 {
186    fn deserialize<D: Deserializer<'a>>(d: D) -> Result<Self, D::Error> {
187        let val = u32::deserialize(d)?;
188        // Ensure that `val` satisfies our invariant. i.e. Not necessarily canonical, but must fit in 31 bits.
189        if val <= P {
190            Ok(Self::new_reduced(val))
191        } else {
192            Err(D::Error::custom("Value is out of range"))
193        }
194    }
195}
196
// Raw byte/word serialization helpers, generated by the shared 32-bit prime-field macro from
// `p3_field`.
// NOTE(review): whether the generated code emits the raw or the canonical value is decided
// inside `impl_raw_serializable_primefield32!` — confirm there.
impl RawDataSerializable for Mersenne31 {
    impl_raw_serializable_primefield32!();
}
200
201impl PrimeCharacteristicRing for Mersenne31 {
202    type PrimeSubfield = Self;
203
204    const ZERO: Self = Self { value: 0 };
205    const ONE: Self = Self { value: 1 };
206    const TWO: Self = Self { value: 2 };
207    const NEG_ONE: Self = Self {
208        value: Self::ORDER_U32 - 1,
209    };
210
211    #[inline]
212    fn from_prime_subfield(f: Self::PrimeSubfield) -> Self {
213        f
214    }
215
216    #[inline]
217    fn from_bool(b: bool) -> Self {
218        Self::new_reduced(b as u32)
219    }
220
221    #[inline]
222    fn halve(&self) -> Self {
223        Self::new_reduced(halve_u32::<P>(self.value))
224    }
225
226    #[inline]
227    fn mul_2exp_u64(&self, exp: u64) -> Self {
228        // In a Mersenne field, multiplication by 2^k is just a left rotation by k bits.
229        let exp = exp % 31;
230        let left = (self.value << exp) & ((1 << 31) - 1);
231        let right = self.value >> (31 - exp);
232        let rotated = left | right;
233        Self::new_reduced(rotated)
234    }
235
236    #[inline]
237    fn div_2exp_u64(&self, exp: u64) -> Self {
238        // In a Mersenne field, division by 2^k is just a right rotation by k bits.
239        let exp = (exp % 31) as u8;
240        let left = self.value >> exp;
241        let right = (self.value << (31 - exp)) & ((1 << 31) - 1);
242        let rotated = left | right;
243        Self::new_reduced(rotated)
244    }
245
246    #[inline]
247    fn sum_array<const N: usize>(input: &[Self]) -> Self {
248        assert_eq!(N, input.len());
249        // Benchmarking shows that for N <= 5 it's faster to sum the elements directly
250        // but for N > 5 it's faster to use the .sum() methods which passes through u64's
251        // allowing for delayed reductions.
252        match N {
253            0 => Self::ZERO,
254            1 => input[0],
255            2 => input[0] + input[1],
256            3 => input[0] + input[1] + input[2],
257            4 => (input[0] + input[1]) + (input[2] + input[3]),
258            5 => {
259                let lhs = input[0] + input[1];
260                let rhs = input[2] + input[3];
261                lhs + rhs + input[4]
262            }
263            _ => input.iter().copied().sum(),
264        }
265    }
266
267    #[inline]
268    fn zero_vec(len: usize) -> Vec<Self> {
269        // SAFETY:
270        // Due to `#[repr(transparent)]`, Mersenne31 and u32 have the same size, alignment
271        // and memory layout making `flatten_to_base` safe. This this will create
272        // a vector Mersenne31 elements with value set to 0.
273        unsafe { flatten_to_base(vec![0u32; len]) }
274    }
275}
276
277// Degree of the smallest permutation polynomial for Mersenne31.
278//
279// As p - 1 = 2×3^2×7×11×... the smallest choice for a degree D satisfying gcd(p - 1, D) = 1 is 5.
280impl InjectiveMonomial<5> for Mersenne31 {}
281
282impl PermutationMonomial<5> for Mersenne31 {
283    /// In the field `Mersenne31`, `a^{1/5}` is equal to a^{1717986917}.
284    ///
285    /// This follows from the calculation `5 * 1717986917 = 4*(2^31 - 2) + 1 = 1 mod p - 1`.
286    fn injective_exp_root_n(&self) -> Self {
287        exp_1717986917(*self)
288    }
289}
290
291impl Field for Mersenne31 {
292    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
293    type Packing = crate::PackedMersenne31Neon;
294    #[cfg(all(
295        target_arch = "x86_64",
296        target_feature = "avx2",
297        not(target_feature = "avx512f")
298    ))]
299    type Packing = crate::PackedMersenne31AVX2;
300    #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
301    type Packing = crate::PackedMersenne31AVX512;
302    #[cfg(not(any(
303        all(target_arch = "aarch64", target_feature = "neon"),
304        all(
305            target_arch = "x86_64",
306            target_feature = "avx2",
307            not(target_feature = "avx512f")
308        ),
309        all(target_arch = "x86_64", target_feature = "avx512f"),
310    )))]
311    type Packing = Self;
312
313    // Sage: GF(2^31 - 1).multiplicative_generator()
314    const GENERATOR: Self = Self::new(7);
315
316    #[inline]
317    fn is_zero(&self) -> bool {
318        self.value == 0 || self.value == Self::ORDER_U32
319    }
320
321    fn try_inverse(&self) -> Option<Self> {
322        if self.is_zero() {
323            return None;
324        }
325
326        // Number of bits in the Mersenne31 prime.
327        const NUM_PRIME_BITS: u32 = 31;
328
329        // gcd_inversion returns the inverse multiplied by 2^60 so we need to correct for that.
330        let inverse_i64 = gcd_inversion_prime_field_32::<NUM_PRIME_BITS>(self.value, P);
331        Some(Self::from_int(inverse_i64).div_2exp_u64(60))
332    }
333
334    #[inline]
335    fn order() -> BigUint {
336        P.into()
337    }
338}
339
// `QuotientMap<Int>` is implemented via macros for every integer type except `u32` and `i32`,
// which have hand-written impls further down.
// NOTE(review): judging by the arguments, the small types (u8/u16, i8/i16) are routed through the
// `u32`/`i32` impls. The two string arguments to each large-int macro appear to be the documented
// input ranges for the checked and unchecked canonical conversions, respectively. Confirm against
// the macro definitions in `p3_field`.
quotient_map_small_int!(Mersenne31, u32, [u8, u16]);
quotient_map_small_int!(Mersenne31, i32, [i8, i16]);
quotient_map_large_uint!(
    Mersenne31,
    u32,
    Mersenne31::ORDER_U32,
    "`[0, 2^31 - 2]`",
    "`[0, 2^31 - 1]`",
    [u64, u128]
);
quotient_map_large_iint!(
    Mersenne31,
    i32,
    "`[-2^30, 2^30]`",
    "`[1 - 2^31, 2^31 - 1]`",
    [(i64, u64), (i128, u128)]
);
358
359// We simple need to prove custom Mersenne31 impls for QuotientMap<u32> and QuotientMap<i32>
360impl QuotientMap<u32> for Mersenne31 {
361    /// Convert a given `u32` integer into an element of the `Mersenne31` field.
362    #[inline]
363    fn from_int(int: u32) -> Self {
364        // To reduce `n` to 31 bits, we clear its MSB, then add it back in its reduced form.
365        let msb = int & (1 << 31);
366        let msb_reduced = msb >> 31;
367        Self::new_reduced(int ^ msb) + Self::new_reduced(msb_reduced)
368    }
369
370    /// Convert a given `u32` integer into an element of the `Mersenne31` field.
371    ///
372    /// Returns none if the input does not lie in the range `[0, 2^31 - 1]`.
373    #[inline]
374    fn from_canonical_checked(int: u32) -> Option<Self> {
375        (int < Self::ORDER_U32).then(|| Self::new_reduced(int))
376    }
377
378    /// Convert a given `u32` integer into an element of the `Mersenne31` field.
379    ///
380    /// # Safety
381    /// The input must lie in the range: `[0, 2^31 - 1]`.
382    #[inline(always)]
383    unsafe fn from_canonical_unchecked(int: u32) -> Self {
384        debug_assert!(int < Self::ORDER_U32);
385        Self::new_reduced(int)
386    }
387}
388
389impl QuotientMap<i32> for Mersenne31 {
390    /// Convert a given `i32` integer into an element of the `Mersenne31` field.
391    #[inline]
392    fn from_int(int: i32) -> Self {
393        if int >= 0 {
394            Self::new_reduced(int as u32)
395        } else if int > (-1 << 31) {
396            Self::new_reduced(Self::ORDER_U32.wrapping_add_signed(int))
397        } else {
398            // The only other option is int = -(2^31) = -1 mod p.
399            Self::NEG_ONE
400        }
401    }
402
403    /// Convert a given `i32` integer into an element of the `Mersenne31` field.
404    ///
405    /// Returns none if the input does not lie in the range `(-2^30, 2^30)`.
406    #[inline]
407    fn from_canonical_checked(int: i32) -> Option<Self> {
408        const TWO_EXP_30: i32 = 1 << 30;
409        const NEG_TWO_EXP_30_PLUS_1: i32 = (-1 << 30) + 1;
410        match int {
411            0..TWO_EXP_30 => Some(Self::new_reduced(int as u32)),
412            NEG_TWO_EXP_30_PLUS_1..0 => {
413                Some(Self::new_reduced(Self::ORDER_U32.wrapping_add_signed(int)))
414            }
415            _ => None,
416        }
417    }
418
419    /// Convert a given `i32` integer into an element of the `Mersenne31` field.
420    ///
421    /// # Safety
422    /// The input must lie in the range: `[1 - 2^31, 2^31 - 1]`.
423    #[inline(always)]
424    unsafe fn from_canonical_unchecked(int: i32) -> Self {
425        if int >= 0 {
426            Self::new_reduced(int as u32)
427        } else {
428            Self::new_reduced(Self::ORDER_U32.wrapping_add_signed(int))
429        }
430    }
431}
432
433impl PrimeField for Mersenne31 {
434    fn as_canonical_biguint(&self) -> BigUint {
435        <Self as PrimeField32>::as_canonical_u32(self).into()
436    }
437}
438
439impl PrimeField32 for Mersenne31 {
440    const ORDER_U32: u32 = P;
441
442    #[inline]
443    fn as_canonical_u32(&self) -> u32 {
444        // Since our invariant guarantees that `value` fits in 31 bits, there is only one possible
445        // `value` that is not canonical, namely 2^31 - 1 = p = 0.
446        if self.value == Self::ORDER_U32 {
447            0
448        } else {
449            self.value
450        }
451    }
452}
453
454impl PrimeField64 for Mersenne31 {
455    const ORDER_U64: u64 = <Self as PrimeField32>::ORDER_U32 as u64;
456
457    #[inline]
458    fn as_canonical_u64(&self) -> u64 {
459        self.as_canonical_u32().into()
460    }
461}
462
463impl Add for Mersenne31 {
464    type Output = Self;
465
466    #[inline]
467    fn add(self, rhs: Self) -> Self {
468        // See the following for a way to compute the sum that avoids
469        // the conditional which may be preferable on some
470        // architectures.
471        // https://github.com/Plonky3/Plonky3/blob/6049a30c3b1f5351c3eb0f7c994dc97e8f68d10d/mersenne-31/src/lib.rs#L249
472
473        // Working with i32 means we get a flag which informs us if overflow happened.
474        let (sum_i32, over) = (self.value as i32).overflowing_add(rhs.value as i32);
475        let sum_u32 = sum_i32 as u32;
476        let sum_corr = sum_u32.wrapping_sub(Self::ORDER_U32);
477
478        // If self + rhs did not overflow, return it.
479        // If self + rhs overflowed, sum_corr = self + rhs - (2**31 - 1).
480        Self::new_reduced(if over { sum_corr } else { sum_u32 })
481    }
482}
483
484impl Sub for Mersenne31 {
485    type Output = Self;
486
487    #[inline]
488    fn sub(self, rhs: Self) -> Self {
489        let (mut sub, over) = self.value.overflowing_sub(rhs.value);
490
491        // If we didn't overflow we have the correct value.
492        // Otherwise we have added 2**32 = 2**31 + 1 mod 2**31 - 1.
493        // Hence we need to remove the most significant bit and subtract 1.
494        sub -= over as u32;
495        Self::new_reduced(sub & Self::ORDER_U32)
496    }
497}
498
499impl Neg for Mersenne31 {
500    type Output = Self;
501
502    #[inline]
503    fn neg(self) -> Self::Output {
504        // Can't underflow, since self.value is 31-bits and thus can't exceed ORDER.
505        Self::new_reduced(Self::ORDER_U32 - self.value)
506    }
507}
508
509impl Mul for Mersenne31 {
510    type Output = Self;
511
512    #[inline]
513    #[allow(clippy::cast_possible_truncation)]
514    fn mul(self, rhs: Self) -> Self {
515        let prod = u64::from(self.value) * u64::from(rhs.value);
516        from_u62(prod)
517    }
518}
519
// Remaining operator boilerplate, generated by the shared `p3_field::op_assign_macros` helpers.
// NOTE(review): by their names these provide `AddAssign`, `SubAssign`, the multiplication-related
// traits/methods and the division traits, presumably built on the `Add`/`Sub`/`Mul` impls above
// and the `Field` inverse. Confirm the exact generated items in `p3_field`.
impl_add_assign!(Mersenne31);
impl_sub_assign!(Mersenne31);
impl_mul_methods!(Mersenne31);
impl_div_methods!(Mersenne31, Mersenne31);
524
525impl Sum for Mersenne31 {
526    #[inline]
527    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
528        // This is faster than iter.reduce(|x, y| x + y).unwrap_or(Self::ZERO) for iterators of length >= 6.
529        // It assumes that iter.len() < 2^31.
530
531        // This sum will not overflow so long as iter.len() < 2^33.
532        let sum = iter.map(|x| x.value as u64).sum::<u64>();
533
534        // sum is < 2^62 provided iter.len() < 2^31.
535        from_u62(sum)
536    }
537}
538
539#[inline(always)]
540pub(crate) fn from_u62(input: u64) -> Mersenne31 {
541    debug_assert!(input < (1 << 62));
542    let input_lo = (input & ((1 << 31) - 1)) as u32;
543    let input_high = (input >> 31) as u32;
544    Mersenne31::new_reduced(input_lo) + Mersenne31::new_reduced(input_high)
545}
546
547impl UniformSamplingField for Mersenne31 {
548    const MAX_SINGLE_SAMPLE_BITS: usize = 16;
549    // For Mersenne31 uniform sampling really only makes sense if we allow rejection sampling.
550    // Sampling 16 bits already has a chance of 3e-5 to require a resample!
551    const SAMPLING_BITS_M: [u64; 64] = {
552        let prime: u64 = P as u64;
553        let mut a = [0u64; 64];
554        let mut k = 0;
555        while k < 64 {
556            if k == 0 {
557                a[k] = prime; // This value is irrelevant in practice. `bits = 0` returns 0 always.
558            } else {
559                // Create a mask to zero out the last k bits
560                let mask = !((1u64 << k) - 1);
561                a[k] = prime & mask;
562            }
563            k += 1;
564        }
565        a
566    };
567}
568
#[cfg(test)]
mod tests {
    use num_bigint::BigUint;
    use p3_field::{InjectiveMonomial, PermutationMonomial, PrimeCharacteristicRing};
    use p3_field_testing::{
        test_field, test_prime_field, test_prime_field_32, test_prime_field_64,
    };

    use crate::Mersenne31;

    type F = Mersenne31;

    #[test]
    fn exp_root() {
        // Confirm that (x^5)^{1/5} = x for a few sample elements.
        for x in [F::from_u32(0x34167c58), F::from_u32(0x61f3207b), F::TWO] {
            assert_eq!(x.injective_exp_n().injective_exp_root_n(), x);
        }
    }

    // Mersenne31 has a redundant representation of Zero but no redundant representation of One.
    // The second zero is built with a struct literal because `Mersenne31::new` reduces modulo P,
    // which would collapse `2^31 - 1` back to the canonical `0` and leave the redundant encoding
    // untested.
    const ZEROS: [Mersenne31; 2] = [
        Mersenne31::ZERO,
        Mersenne31 {
            value: (1_u32 << 31) - 1,
        },
    ];
    const ONES: [Mersenne31; 1] = [Mersenne31::ONE];

    // Prime factorization of the order of the multiplicative group, i.e. of P - 1 = 2^31 - 2.
    fn multiplicative_group_prime_factorization() -> [(BigUint, u32); 7] {
        [
            (BigUint::from(2u8), 1),
            (BigUint::from(3u8), 2),
            (BigUint::from(7u8), 1),
            (BigUint::from(11u8), 1),
            (BigUint::from(31u8), 1),
            (BigUint::from(151u8), 1),
            (BigUint::from(331u16), 1),
        ]
    }

    test_field!(
        crate::Mersenne31,
        &super::ZEROS,
        &super::ONES,
        &super::multiplicative_group_prime_factorization()
    );
    test_prime_field!(crate::Mersenne31);
    test_prime_field_64!(crate::Mersenne31, &super::ZEROS, &super::ONES);
    test_prime_field_32!(crate::Mersenne31, &super::ZEROS, &super::ONES);
}