Skip to main content

p3_mersenne_31/
mersenne_31.rs

1use alloc::vec;
2use alloc::vec::Vec;
3use core::fmt::{Debug, Display, Formatter};
4use core::hash::{Hash, Hasher};
5use core::iter::{Product, Sum};
6use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
7use core::{array, fmt, iter};
8
9use num_bigint::BigUint;
10use p3_challenger::UniformSamplingField;
11use p3_field::exponentiation::exp_1717986917;
12use p3_field::integers::QuotientMap;
13use p3_field::op_assign_macros::{
14    impl_add_assign, impl_div_methods, impl_mul_methods, impl_sub_assign,
15};
16use p3_field::{
17    Field, InjectiveMonomial, Packable, PermutationMonomial, PrimeCharacteristicRing, PrimeField,
18    PrimeField32, PrimeField64, RawDataSerializable, halve_u32, impl_raw_serializable_primefield32,
19    quotient_map_large_iint, quotient_map_large_uint, quotient_map_small_int,
20};
21use p3_util::{flatten_to_base, gcd_inversion_prime_field_32};
22use rand::Rng;
23use rand::distr::{Distribution, StandardUniform};
24use serde::de::Error;
25use serde::{Deserialize, Deserializer, Serialize};
26
/// The Mersenne31 prime `p = 2^31 - 1`.
const P: u32 = 2_u32.pow(31) - 1;

/// The prime field `F_p` where `p = 2^31 - 1`.
///
/// The serde encoding is canonical: every field element has exactly one valid byte representation.
#[derive(Copy, Clone, Default)]
#[repr(transparent)] // Important for reasoning about memory layout.
#[must_use]
pub struct Mersenne31 {
    /// Stored representative. It always fits in 31 bits but is not necessarily canonical:
    /// the value `p` itself is a second encoding of zero.
    pub(crate) value: u32,
}
40
41impl Mersenne31 {
42    /// Create a new field element from any `u32`.
43    ///
44    /// Any `u32` value is accepted and automatically reduced modulo P.
45    #[inline]
46    pub const fn new(value: u32) -> Self {
47        Self { value: value % P }
48    }
49
50    /// Create a field element from a value assumed to be < 2^31.
51    ///
52    /// # Safety
53    /// The element must lie in the range: `[0, 2^31 - 1]`.
54    #[inline]
55    pub(crate) const fn new_reduced(value: u32) -> Self {
56        debug_assert!((value >> 31) == 0);
57        Self { value }
58    }
59
60    /// Convert a u32 element into a Mersenne31 element.
61    ///
62    /// Returns `None` if the element does not lie in the range: `[0, 2^31 - 1]`.
63    #[inline]
64    pub const fn new_checked(value: u32) -> Option<Self> {
65        if (value >> 31) == 0 {
66            Some(Self { value })
67        } else {
68            None
69        }
70    }
71
72    /// Convert a `[u32; N]` array to an array of field elements.
73    ///
74    /// Const version of `input.map(Mersenne31::new)`.
75    #[inline]
76    pub const fn new_array<const N: usize>(input: [u32; N]) -> [Self; N] {
77        let mut output = [Self::ZERO; N];
78        let mut i = 0;
79        while i < N {
80            output[i].value = input[i] % P;
81            i += 1;
82        }
83        output
84    }
85
86    /// Precomputed table of generators for two-adic subgroups of the degree two extension field over Mersenne31.
87    /// The `i`'th element is a generator of the subgroup of order `2^i`.
88    pub const EXT_TWO_ADIC_GENERATORS: [[Self; 2]; 33] = [
89        [Self::ONE, Self::ZERO],
90        [Self::new(2_147_483_646), Self::new(0)],
91        [Self::new(0), Self::new(2_147_483_646)],
92        [Self::new(32_768), Self::new(2_147_450_879)],
93        [Self::new(590_768_354), Self::new(978_592_373)],
94        [Self::new(1_179_735_656), Self::new(1_241_207_368)],
95        [Self::new(1_567_857_810), Self::new(456_695_729)],
96        [Self::new(1_774_253_895), Self::new(1_309_288_441)],
97        [Self::new(736_262_640), Self::new(1_553_669_210)],
98        [Self::new(1_819_216_575), Self::new(1_662_816_114)],
99        [Self::new(1_323_191_254), Self::new(1_936_974_060)],
100        [Self::new(605_622_498), Self::new(1_964_232_216)],
101        [Self::new(343_674_985), Self::new(501_786_993)],
102        [Self::new(1_995_316_534), Self::new(149_306_621)],
103        [Self::new(2_107_600_913), Self::new(1_378_821_388)],
104        [Self::new(541_476_169), Self::new(2_101_081_972)],
105        [Self::new(2_135_874_973), Self::new(483_411_332)],
106        [Self::new(2_097_144_245), Self::new(1_684_033_590)],
107        [Self::new(1_662_322_247), Self::new(670_236_780)],
108        [Self::new(1_172_215_635), Self::new(595_888_646)],
109        [Self::new(241_940_101), Self::new(323_856_519)],
110        [Self::new(1_957_194_259), Self::new(2_139_647_100)],
111        [Self::new(1_957_419_629), Self::new(1_541_039_442)],
112        [Self::new(1_062_045_235), Self::new(1_824_580_421)],
113        [Self::new(1_929_382_196), Self::new(1_664_698_822)],
114        [Self::new(1_889_294_251), Self::new(331_248_939)],
115        [Self::new(1_214_231_414), Self::new(1_646_302_518)],
116        [Self::new(1_765_392_370), Self::new(461_136_547)],
117        [Self::new(1_629_751_483), Self::new(66_485_474)],
118        [Self::new(1_501_355_827), Self::new(1_439_063_420)],
119        [Self::new(509_778_402), Self::new(800_467_507)],
120        [Self::new(311_014_874), Self::new(1_584_694_829)],
121        [Self::new(1_166_849_849), Self::new(1_117_296_306)],
122    ];
123}
124
125impl PartialEq for Mersenne31 {
126    #[inline]
127    fn eq(&self, other: &Self) -> bool {
128        self.as_canonical_u32() == other.as_canonical_u32()
129    }
130}
131
132impl Eq for Mersenne31 {}
133
134impl Packable for Mersenne31 {}
135
136impl Hash for Mersenne31 {
137    fn hash<H: Hasher>(&self, state: &mut H) {
138        state.write_u32(self.to_unique_u32());
139    }
140}
141
142impl Ord for Mersenne31 {
143    #[inline]
144    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
145        self.as_canonical_u32().cmp(&other.as_canonical_u32())
146    }
147}
148
149impl PartialOrd for Mersenne31 {
150    #[inline]
151    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
152        Some(self.cmp(other))
153    }
154}
155
156impl Display for Mersenne31 {
157    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
158        Display::fmt(&self.value, f)
159    }
160}
161
162impl Debug for Mersenne31 {
163    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
164        Debug::fmt(&self.value, f)
165    }
166}
167
168impl Distribution<Mersenne31> for StandardUniform {
169    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Mersenne31 {
170        loop {
171            let next_u31 = rng.next_u32() >> 1;
172            let is_canonical = next_u31 != Mersenne31::ORDER_U32;
173            if is_canonical {
174                return Mersenne31::new_reduced(next_u31);
175            }
176        }
177    }
178}
179
180impl Serialize for Mersenne31 {
181    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
182        // Emit the canonical representative so every field element has one encoding.
183        serializer.serialize_u32(self.as_canonical_u32())
184    }
185}
186
187impl<'a> Deserialize<'a> for Mersenne31 {
188    fn deserialize<D: Deserializer<'a>>(d: D) -> Result<Self, D::Error> {
189        let val = u32::deserialize(d)?;
190        // Reject non-canonical encodings so a proof cannot be re-encoded without the witness.
191        // `val == P` is field-equal to 0, so only `[0, P)` is canonical.
192        if val < P {
193            Ok(Self::new_reduced(val))
194        } else {
195            Err(D::Error::custom("Value is out of range"))
196        }
197    }
198}
199
// Raw byte/word serialization is generated by `impl_raw_serializable_primefield32!`.
// NOTE(review): presumably this also provides `to_unique_u32` (used by `Hash` above) from the
// canonical `u32` representative; confirm in p3_field.
impl RawDataSerializable for Mersenne31 {
    impl_raw_serializable_primefield32!();
}
203
/// Ring structure of `Mersenne31`. Several methods exploit `2^31 ≡ 1 (mod P)`.
impl PrimeCharacteristicRing for Mersenne31 {
    type PrimeSubfield = Self;

    const ZERO: Self = Self { value: 0 };
    const ONE: Self = Self { value: 1 };
    const TWO: Self = Self { value: 2 };
    // -1 is stored canonically as P - 1.
    const NEG_ONE: Self = Self {
        value: Self::ORDER_U32 - 1,
    };

    /// `Mersenne31` is its own prime subfield, so this is the identity.
    #[inline]
    fn from_prime_subfield(f: Self::PrimeSubfield) -> Self {
        f
    }

    /// Maps `false` to zero and `true` to one.
    #[inline]
    fn from_bool(b: bool) -> Self {
        Self::new_reduced(b as u32)
    }

    /// Returns `self / 2`, computed by `halve_u32::<P>` on the stored representative.
    #[inline]
    fn halve(&self) -> Self {
        Self::new_reduced(halve_u32::<P>(self.value))
    }

    /// Returns `self * 2^exp`.
    #[inline]
    fn mul_2exp_u64(&self, exp: u64) -> Self {
        // In a Mersenne field, multiplication by 2^k is just a left rotation by k bits.
        // The rotation is within the 31-bit word: because 2^31 ≡ 1 (mod P), only `exp % 31`
        // matters, and bits shifted out of the top re-enter at the bottom.
        let exp = exp % 31;
        let left = (self.value << exp) & ((1 << 31) - 1);
        // For exp == 0 this shifts by 31, which yields 0 because `value` fits in 31 bits.
        let right = self.value >> (31 - exp);
        let rotated = left | right;
        Self::new_reduced(rotated)
    }

    /// Returns `self / 2^exp`.
    #[inline]
    fn div_2exp_u64(&self, exp: u64) -> Self {
        // In a Mersenne field, division by 2^k is just a right rotation by k bits.
        let exp = (exp % 31) as u8;
        let left = self.value >> exp;
        // For exp == 0 this shifts by 31 and the mask then clears bit 31, so `right` is 0.
        let right = (self.value << (31 - exp)) & ((1 << 31) - 1);
        let rotated = left | right;
        Self::new_reduced(rotated)
    }

    #[inline]
    fn sum_array<const N: usize>(input: &[Self]) -> Self {
        assert_eq!(N, input.len());
        // Benchmarking shows that for N <= 5 it's faster to sum the elements directly
        // but for N > 5 it's faster to use the .sum() methods which passes through u64's
        // allowing for delayed reductions.
        match N {
            0 => Self::ZERO,
            1 => input[0],
            2 => input[0] + input[1],
            3 => input[0] + input[1] + input[2],
            4 => (input[0] + input[1]) + (input[2] + input[3]),
            5 => {
                let lhs = input[0] + input[1];
                let rhs = input[2] + input[3];
                lhs + rhs + input[4]
            }
            _ => input.iter().copied().sum(),
        }
    }

    #[inline]
    fn dot_product<const N: usize>(lhs: &[Self; N], rhs: &[Self; N]) -> Self {
        // Accumulate products as u64 to avoid per-multiply reductions.
        // Each stored value is at most P = 2^31 - 1 (P itself is the redundant encoding of
        // zero), so each product is at most P^2 < 2^62, and a sum of 4 products is at most
        // 4 * P^2 = 2^64 - 2^34 + 4 < 2^64, which fits in u64.
        match N {
            0 => Self::ZERO,
            1 => lhs[0] * rhs[0],
            2 => {
                let sum = (lhs[0].value as u64) * (rhs[0].value as u64)
                    + (lhs[1].value as u64) * (rhs[1].value as u64);
                Self::new_reduced(reduce_64(sum))
            }
            3 => {
                let sum = (lhs[0].value as u64) * (rhs[0].value as u64)
                    + (lhs[1].value as u64) * (rhs[1].value as u64)
                    + (lhs[2].value as u64) * (rhs[2].value as u64);
                Self::new_reduced(reduce_64(sum))
            }
            4 => {
                let sum = (lhs[0].value as u64) * (rhs[0].value as u64)
                    + (lhs[1].value as u64) * (rhs[1].value as u64)
                    + (lhs[2].value as u64) * (rhs[2].value as u64)
                    + (lhs[3].value as u64) * (rhs[3].value as u64);
                Self::new_reduced(reduce_64(sum))
            }
            _ => {
                // Process in chunks of 4 with intermediate reductions.
                // Each chunk contributes less than 2^33 + 2^31 after `partial_reduce`, and the
                // (at most 3) remainder products add less than 3 * 2^62, so `acc` cannot
                // overflow for N < 2^30.
                let mut acc = 0u64;
                let mut i = 0;
                while i + 4 <= N {
                    let chunk_sum = (lhs[i].value as u64) * (rhs[i].value as u64)
                        + (lhs[i + 1].value as u64) * (rhs[i + 1].value as u64)
                        + (lhs[i + 2].value as u64) * (rhs[i + 2].value as u64)
                        + (lhs[i + 3].value as u64) * (rhs[i + 3].value as u64);
                    // Reduce chunk_sum to ~34 bits and add to accumulator
                    acc += partial_reduce(chunk_sum);
                    i += 4;
                }
                // Handle remainder
                while i < N {
                    acc += (lhs[i].value as u64) * (rhs[i].value as u64);
                    i += 1;
                }
                Self::new_reduced(reduce_64(acc))
            }
        }
    }

    #[inline]
    fn zero_vec(len: usize) -> Vec<Self> {
        // SAFETY:
        // Due to `#[repr(transparent)]`, Mersenne31 and u32 have the same size, alignment
        // and memory layout making `flatten_to_base` safe. This will create a vector of
        // `len` Mersenne31 elements, each with value set to 0.
        unsafe { flatten_to_base(vec![0u32; len]) }
    }
}
328
329// Degree of the smallest permutation polynomial for Mersenne31.
330//
331// As p - 1 = 2×3^2×7×11×... the smallest choice for a degree D satisfying gcd(p - 1, D) = 1 is 5.
332impl InjectiveMonomial<5> for Mersenne31 {}
333
334impl PermutationMonomial<5> for Mersenne31 {
335    /// In the field `Mersenne31`, `a^{1/5}` is equal to a^{1717986917}.
336    ///
337    /// This follows from the calculation `5 * 1717986917 = 4*(2^31 - 2) + 1 = 1 mod p - 1`.
338    fn injective_exp_root_n(&self) -> Self {
339        exp_1717986917(*self)
340    }
341}
342
343impl Field for Mersenne31 {
344    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
345    type Packing = crate::PackedMersenne31Neon;
346    #[cfg(all(
347        target_arch = "x86_64",
348        target_feature = "avx2",
349        not(target_feature = "avx512f")
350    ))]
351    type Packing = crate::PackedMersenne31AVX2;
352    #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
353    type Packing = crate::PackedMersenne31AVX512;
354    #[cfg(not(any(
355        all(target_arch = "aarch64", target_feature = "neon"),
356        all(
357            target_arch = "x86_64",
358            target_feature = "avx2",
359            not(target_feature = "avx512f")
360        ),
361        all(target_arch = "x86_64", target_feature = "avx512f"),
362    )))]
363    type Packing = Self;
364
365    // Sage: GF(2^31 - 1).multiplicative_generator()
366    const GENERATOR: Self = Self::new(7);
367
368    #[inline]
369    fn is_zero(&self) -> bool {
370        self.value == 0 || self.value == Self::ORDER_U32
371    }
372
373    fn try_inverse(&self) -> Option<Self> {
374        if self.is_zero() {
375            return None;
376        }
377
378        // Number of bits in the Mersenne31 prime.
379        const NUM_PRIME_BITS: u32 = 31;
380
381        // gcd_inversion returns the inverse multiplied by 2^60 so we need to correct for that.
382        let inverse_i64 = gcd_inversion_prime_field_32::<NUM_PRIME_BITS>(self.value, P);
383        Some(Self::from_int(inverse_i64).div_2exp_u64(60))
384    }
385
386    #[inline]
387    fn order() -> BigUint {
388        P.into()
389    }
390
391    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
392    #[inline]
393    fn batched_columnwise_dot_product<EF, R, I, const N: usize>(
394        acc: &mut [EF::ExtensionPacking],
395        items: I,
396    ) where
397        EF: p3_field::ExtensionField<Self>,
398        R: Iterator<Item = Self::Packing>,
399        I: Iterator<Item = (R, [EF; N])>,
400    {
401        crate::aarch64_neon::batched_columnwise_dot_product::<EF, R, I, N>(acc, items);
402    }
403}
404
405// We can use some macros to implement QuotientMap<Int> for all integer types except for u32 and i32's.
406quotient_map_small_int!(Mersenne31, u32, [u8, u16]);
407quotient_map_small_int!(Mersenne31, i32, [i8, i16]);
408quotient_map_large_uint!(
409    Mersenne31,
410    u32,
411    Mersenne31::ORDER_U32,
412    "`[0, 2^31 - 2]`",
413    "`[0, 2^31 - 1]`",
414    [u64, u128]
415);
416quotient_map_large_iint!(
417    Mersenne31,
418    i32,
419    "`[-2^30, 2^30]`",
420    "`[1 - 2^31, 2^31 - 1]`",
421    [(i64, u64), (i128, u128)]
422);
423
424// We simple need to prove custom Mersenne31 impls for QuotientMap<u32> and QuotientMap<i32>
425impl QuotientMap<u32> for Mersenne31 {
426    /// Convert a given `u32` integer into an element of the `Mersenne31` field.
427    #[inline]
428    fn from_int(int: u32) -> Self {
429        // To reduce `n` to 31 bits, we clear its MSB, then add it back in its reduced form.
430        let msb = int & (1 << 31);
431        let msb_reduced = msb >> 31;
432        Self::new_reduced(int ^ msb) + Self::new_reduced(msb_reduced)
433    }
434
435    /// Convert a given `u32` integer into an element of the `Mersenne31` field.
436    ///
437    /// Returns none if the input does not lie in the range `[0, 2^31 - 1]`.
438    #[inline]
439    fn from_canonical_checked(int: u32) -> Option<Self> {
440        (int < Self::ORDER_U32).then(|| Self::new_reduced(int))
441    }
442
443    /// Convert a given `u32` integer into an element of the `Mersenne31` field.
444    ///
445    /// # Safety
446    /// The input must lie in the range: `[0, 2^31 - 1]`.
447    #[inline(always)]
448    unsafe fn from_canonical_unchecked(int: u32) -> Self {
449        debug_assert!(int < Self::ORDER_U32);
450        Self::new_reduced(int)
451    }
452}
453
454impl QuotientMap<i32> for Mersenne31 {
455    /// Convert a given `i32` integer into an element of the `Mersenne31` field.
456    #[inline]
457    fn from_int(int: i32) -> Self {
458        if int >= 0 {
459            Self::new_reduced(int as u32)
460        } else if int > (-1 << 31) {
461            Self::new_reduced(Self::ORDER_U32.wrapping_add_signed(int))
462        } else {
463            // The only other option is int = -(2^31) = -1 mod p.
464            Self::NEG_ONE
465        }
466    }
467
468    /// Convert a given `i32` integer into an element of the `Mersenne31` field.
469    ///
470    /// Returns none if the input does not lie in the range `(-2^30, 2^30)`.
471    #[inline]
472    fn from_canonical_checked(int: i32) -> Option<Self> {
473        const TWO_EXP_30: i32 = 1 << 30;
474        const NEG_TWO_EXP_30_PLUS_1: i32 = (-1 << 30) + 1;
475        match int {
476            0..TWO_EXP_30 => Some(Self::new_reduced(int as u32)),
477            NEG_TWO_EXP_30_PLUS_1..0 => {
478                Some(Self::new_reduced(Self::ORDER_U32.wrapping_add_signed(int)))
479            }
480            _ => None,
481        }
482    }
483
484    /// Convert a given `i32` integer into an element of the `Mersenne31` field.
485    ///
486    /// # Safety
487    /// The input must lie in the range: `[1 - 2^31, 2^31 - 1]`.
488    #[inline(always)]
489    unsafe fn from_canonical_unchecked(int: i32) -> Self {
490        if int >= 0 {
491            Self::new_reduced(int as u32)
492        } else {
493            Self::new_reduced(Self::ORDER_U32.wrapping_add_signed(int))
494        }
495    }
496}
497
498impl PrimeField for Mersenne31 {
499    fn as_canonical_biguint(&self) -> BigUint {
500        <Self as PrimeField32>::as_canonical_u32(self).into()
501    }
502}
503
504impl PrimeField32 for Mersenne31 {
505    const ORDER_U32: u32 = P;
506
507    #[inline]
508    fn as_canonical_u32(&self) -> u32 {
509        // Since our invariant guarantees that `value` fits in 31 bits, there is only one possible
510        // `value` that is not canonical, namely 2^31 - 1 = p = 0.
511        if self.value == Self::ORDER_U32 {
512            0
513        } else {
514            self.value
515        }
516    }
517}
518
519impl PrimeField64 for Mersenne31 {
520    const ORDER_U64: u64 = <Self as PrimeField32>::ORDER_U32 as u64;
521
522    #[inline]
523    fn as_canonical_u64(&self) -> u64 {
524        self.as_canonical_u32().into()
525    }
526}
527
528impl Add for Mersenne31 {
529    type Output = Self;
530
531    #[inline]
532    fn add(self, rhs: Self) -> Self {
533        // See the following for a way to compute the sum that avoids
534        // the conditional which may be preferable on some
535        // architectures.
536        // https://github.com/Plonky3/Plonky3/blob/6049a30c3b1f5351c3eb0f7c994dc97e8f68d10d/mersenne-31/src/lib.rs#L249
537
538        // Working with i32 means we get a flag which informs us if overflow happened.
539        let (sum_i32, over) = (self.value as i32).overflowing_add(rhs.value as i32);
540        let sum_u32 = sum_i32 as u32;
541        let sum_corr = sum_u32.wrapping_sub(Self::ORDER_U32);
542
543        // If self + rhs did not overflow, return it.
544        // If self + rhs overflowed, sum_corr = self + rhs - (2**31 - 1).
545        Self::new_reduced(if over { sum_corr } else { sum_u32 })
546    }
547}
548
549impl Sub for Mersenne31 {
550    type Output = Self;
551
552    #[inline]
553    fn sub(self, rhs: Self) -> Self {
554        let (mut sub, over) = self.value.overflowing_sub(rhs.value);
555
556        // If we didn't overflow we have the correct value.
557        // Otherwise we have added 2**32 = 2**31 + 1 mod 2**31 - 1.
558        // Hence we need to remove the most significant bit and subtract 1.
559        sub -= over as u32;
560        Self::new_reduced(sub & Self::ORDER_U32)
561    }
562}
563
564impl Neg for Mersenne31 {
565    type Output = Self;
566
567    #[inline]
568    fn neg(self) -> Self::Output {
569        // Can't underflow, since self.value is 31-bits and thus can't exceed ORDER.
570        Self::new_reduced(Self::ORDER_U32 - self.value)
571    }
572}
573
574impl Mul for Mersenne31 {
575    type Output = Self;
576
577    #[inline]
578    #[allow(clippy::cast_possible_truncation)]
579    fn mul(self, rhs: Self) -> Self {
580        let prod = u64::from(self.value) * u64::from(rhs.value);
581        from_u62(prod)
582    }
583}
584
// Operator impls generated by `p3_field::op_assign_macros` from the binary operators above.
// NOTE(review): the exact set of generated traits lives in those macros; presumably
// `AddAssign`, `SubAssign`, `MulAssign`, `Product`, `Div` and `DivAssign` (which matches the
// otherwise-unused `core::ops`/`core::iter` imports). Confirm there.
impl_add_assign!(Mersenne31);
impl_sub_assign!(Mersenne31);
impl_mul_methods!(Mersenne31);
impl_div_methods!(Mersenne31, Mersenne31);
589
590impl Sum for Mersenne31 {
591    #[inline]
592    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
593        // This is faster than iter.reduce(|x, y| x + y).unwrap_or(Self::ZERO) for iterators of length >= 6.
594        // It assumes that iter.len() < 2^31.
595
596        // This sum will not overflow so long as iter.len() < 2^33.
597        let sum = iter.map(|x| x.value as u64).sum::<u64>();
598
599        // sum is < 2^62 provided iter.len() < 2^31.
600        from_u62(sum)
601    }
602}
603
604/// Perform a partial reduction of a u64 value modulo P = 2^31 - 1.
605/// The result will be contained in [0, 2^34 - 1].
606#[inline(always)]
607pub(crate) const fn partial_reduce(val: u64) -> u64 {
608    // Refer to the full reduction process in `reduce_64`.
609    let lo = (val & (P as u64)) as u32;
610    let hi = val >> 31;
611    lo as u64 + hi
612}
613
614/// Reduce a u64 value modulo P = 2^31 - 1.
615/// Uses the identity: 2^31 ≡ 1 (mod P), so val ≡ (val & P) + (val >> 31) (mod P).
616/// Returns a value in [0, P].
617#[inline(always)]
618pub(crate) fn reduce_64(val: u64) -> u32 {
619    // First reduction: split into low 31 bits and high 33 bits
620    // For val < 2^64: hi < 2^33, lo < 2^31
621    // sum1 = lo + hi < 2^33 + 2^31 < 2^34
622    let lo = (val & (P as u64)) as u32;
623    let hi = val >> 31;
624    let sum1 = lo as u64 + hi;
625
626    // Second reduction: sum1 < 2^34
627    let lo2 = (sum1 & (P as u64)) as u32;
628    let hi2 = (sum1 >> 31) as u32; // hi2 < 2^3 = 8
629    let sum2 = lo2 + hi2; // sum2 < 2^31 + 8
630
631    // Final reduction to [0, P]
632    sum2.min(sum2.wrapping_sub(P))
633}
634
635#[inline(always)]
636pub(crate) fn from_u62(input: u64) -> Mersenne31 {
637    debug_assert!(input < (1 << 62));
638    let input_lo = (input & ((1 << 31) - 1)) as u32;
639    let input_high = (input >> 31) as u32;
640    Mersenne31::new_reduced(input_lo) + Mersenne31::new_reduced(input_high)
641}
642
643impl UniformSamplingField for Mersenne31 {
644    const MAX_SINGLE_SAMPLE_BITS: usize = 16;
645    // For Mersenne31 uniform sampling really only makes sense if we allow rejection sampling.
646    // Sampling 16 bits already has a chance of 3e-5 to require a resample!
647    const SAMPLING_BITS_M: [u64; 64] = {
648        let prime: u64 = P as u64;
649        let mut a = [0u64; 64];
650        let mut k = 0;
651        while k < 64 {
652            if k == 0 {
653                a[k] = prime; // This value is irrelevant in practice. `bits = 0` returns 0 always.
654            } else {
655                // Create a mask to zero out the last k bits
656                let mask = !((1u64 << k) - 1);
657                a[k] = prime & mask;
658            }
659            k += 1;
660        }
661        a
662    };
663}
664
#[cfg(test)]
mod tests {
    use num_bigint::BigUint;
    use p3_field::{InjectiveMonomial, PermutationMonomial, PrimeCharacteristicRing};
    use p3_field_testing::{
        test_field, test_prime_field, test_prime_field_32, test_prime_field_64,
    };

    use crate::Mersenne31;

    type F = Mersenne31;

    /// The modulus `p = 2^31 - 1`, which is also the redundant in-memory encoding of zero.
    const P_U32: u32 = (1 << 31) - 1;

    #[test]
    fn exp_root() {
        // Confirm that (x^5)^{1/5} = x for a few sample points.
        for x in [F::from_u32(0x34167c58), F::from_u32(0x61f3207b), F::TWO] {
            assert_eq!(x.injective_exp_n().injective_exp_root_n(), x);
        }
    }

    #[test]
    fn serialize_is_canonical() {
        // `new_reduced` keeps `P` as the redundant encoding of zero; `new` would reduce it to 0.
        let redundant_zero = F::new_reduced(P_U32);
        assert_eq!(redundant_zero, F::ZERO);

        // Both encodings of zero serialize to the single canonical form.
        for zero in [redundant_zero, F::ZERO] {
            assert_eq!(serde_json::to_string(&zero).unwrap(), "0");
        }
    }

    #[test]
    fn deserialize_rejects_non_canonical_encodings() {
        let encode = |raw: u32| serde_json::to_string(&raw).unwrap();

        // `P` is field-equal to 0, so its encoding is non-canonical and must be rejected.
        // This blocks re-encoding a proof as `P` without the witness.
        assert!(serde_json::from_str::<F>(&encode(P_U32)).is_err());

        // The largest canonical value, p - 1, still deserializes.
        let max_canonical: F = serde_json::from_str(&encode(P_U32 - 1)).unwrap();
        assert_eq!(max_canonical, F::new(P_U32 - 1));
    }

    // Mersenne31 has a redundant representation of Zero but no redundant representation of One.
    // The second entry deliberately uses `new_reduced` to *preserve* the `value == P` encoding:
    // the public `new` constructor reduces modulo P and would silently collapse it to 0,
    // making the redundant-representation tests no-ops (cf. #1684).
    const ZEROS: [Mersenne31; 2] = [Mersenne31::ZERO, Mersenne31::new_reduced(P_U32)];
    const ONES: [Mersenne31; 1] = [Mersenne31::ONE];

    // Prime factorization of the multiplicative group order:
    // p - 1 = 2 × 3^2 × 7 × 11 × 31 × 151 × 331.
    fn multiplicative_group_prime_factorization() -> [(BigUint, u32); 7] {
        [
            (2u32, 1u32),
            (3, 2),
            (7, 1),
            (11, 1),
            (31, 1),
            (151, 1),
            (331, 1),
        ]
        .map(|(prime, exponent)| (BigUint::from(prime), exponent))
    }

    test_field!(
        crate::Mersenne31,
        &super::ZEROS,
        &super::ONES,
        &super::multiplicative_group_prime_factorization()
    );
    test_prime_field!(crate::Mersenne31);
    test_prime_field_64!(crate::Mersenne31, &super::ZEROS, &super::ONES);
    test_prime_field_32!(crate::Mersenne31, &super::ZEROS, &super::ONES);
}