1use alloc::vec;
2use alloc::vec::Vec;
3use core::fmt::{Debug, Display, Formatter};
4use core::hash::{Hash, Hasher};
5use core::iter::{Product, Sum};
6use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
7use core::{array, fmt, iter};
8
9use num_bigint::BigUint;
10use p3_challenger::UniformSamplingField;
11use p3_field::exponentiation::exp_1717986917;
12use p3_field::integers::QuotientMap;
13use p3_field::op_assign_macros::{
14 impl_add_assign, impl_div_methods, impl_mul_methods, impl_sub_assign,
15};
16use p3_field::{
17 Field, InjectiveMonomial, Packable, PermutationMonomial, PrimeCharacteristicRing, PrimeField,
18 PrimeField32, PrimeField64, RawDataSerializable, halve_u32, impl_raw_serializable_primefield32,
19 quotient_map_large_iint, quotient_map_large_uint, quotient_map_small_int,
20};
21use p3_util::{flatten_to_base, gcd_inversion_prime_field_32};
22use rand::Rng;
23use rand::distr::{Distribution, StandardUniform};
24use serde::de::Error;
25use serde::{Deserialize, Deserializer, Serialize};
26
/// The Mersenne prime `P = 2^31 - 1` that defines the field.
const P: u32 = 0x7FFF_FFFF;
29
/// An element of the prime field `F_P` with `P = 2^31 - 1`.
///
/// The stored value always fits in 31 bits, i.e. it lies in `[0, P]`; both `0` and `P`
/// encode zero.
#[must_use]
#[repr(transparent)]
#[derive(Clone, Copy, Default)]
pub struct Mersenne31 {
    /// Raw representative in `[0, P]`; `P` is a redundant encoding of zero.
    pub(crate) value: u32,
}
40
41impl Mersenne31 {
42 #[inline]
46 pub const fn new(value: u32) -> Self {
47 Self { value: value % P }
48 }
49
50 #[inline]
55 pub(crate) const fn new_reduced(value: u32) -> Self {
56 debug_assert!((value >> 31) == 0);
57 Self { value }
58 }
59
60 #[inline]
64 pub const fn new_checked(value: u32) -> Option<Self> {
65 if (value >> 31) == 0 {
66 Some(Self { value })
67 } else {
68 None
69 }
70 }
71
72 #[inline]
76 pub const fn new_array<const N: usize>(input: [u32; N]) -> [Self; N] {
77 let mut output = [Self::ZERO; N];
78 let mut i = 0;
79 while i < N {
80 output[i].value = input[i] % P;
81 i += 1;
82 }
83 output
84 }
85
86 pub const EXT_TWO_ADIC_GENERATORS: [[Self; 2]; 33] = [
89 [Self::ONE, Self::ZERO],
90 [Self::new(2_147_483_646), Self::new(0)],
91 [Self::new(0), Self::new(2_147_483_646)],
92 [Self::new(32_768), Self::new(2_147_450_879)],
93 [Self::new(590_768_354), Self::new(978_592_373)],
94 [Self::new(1_179_735_656), Self::new(1_241_207_368)],
95 [Self::new(1_567_857_810), Self::new(456_695_729)],
96 [Self::new(1_774_253_895), Self::new(1_309_288_441)],
97 [Self::new(736_262_640), Self::new(1_553_669_210)],
98 [Self::new(1_819_216_575), Self::new(1_662_816_114)],
99 [Self::new(1_323_191_254), Self::new(1_936_974_060)],
100 [Self::new(605_622_498), Self::new(1_964_232_216)],
101 [Self::new(343_674_985), Self::new(501_786_993)],
102 [Self::new(1_995_316_534), Self::new(149_306_621)],
103 [Self::new(2_107_600_913), Self::new(1_378_821_388)],
104 [Self::new(541_476_169), Self::new(2_101_081_972)],
105 [Self::new(2_135_874_973), Self::new(483_411_332)],
106 [Self::new(2_097_144_245), Self::new(1_684_033_590)],
107 [Self::new(1_662_322_247), Self::new(670_236_780)],
108 [Self::new(1_172_215_635), Self::new(595_888_646)],
109 [Self::new(241_940_101), Self::new(323_856_519)],
110 [Self::new(1_957_194_259), Self::new(2_139_647_100)],
111 [Self::new(1_957_419_629), Self::new(1_541_039_442)],
112 [Self::new(1_062_045_235), Self::new(1_824_580_421)],
113 [Self::new(1_929_382_196), Self::new(1_664_698_822)],
114 [Self::new(1_889_294_251), Self::new(331_248_939)],
115 [Self::new(1_214_231_414), Self::new(1_646_302_518)],
116 [Self::new(1_765_392_370), Self::new(461_136_547)],
117 [Self::new(1_629_751_483), Self::new(66_485_474)],
118 [Self::new(1_501_355_827), Self::new(1_439_063_420)],
119 [Self::new(509_778_402), Self::new(800_467_507)],
120 [Self::new(311_014_874), Self::new(1_584_694_829)],
121 [Self::new(1_166_849_849), Self::new(1_117_296_306)],
122 ];
123}
124
125impl PartialEq for Mersenne31 {
126 #[inline]
127 fn eq(&self, other: &Self) -> bool {
128 self.as_canonical_u32() == other.as_canonical_u32()
129 }
130}
131
132impl Eq for Mersenne31 {}
133
134impl Packable for Mersenne31 {}
135
136impl Hash for Mersenne31 {
137 fn hash<H: Hasher>(&self, state: &mut H) {
138 state.write_u32(self.to_unique_u32());
139 }
140}
141
142impl Ord for Mersenne31 {
143 #[inline]
144 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
145 self.as_canonical_u32().cmp(&other.as_canonical_u32())
146 }
147}
148
149impl PartialOrd for Mersenne31 {
150 #[inline]
151 fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
152 Some(self.cmp(other))
153 }
154}
155
156impl Display for Mersenne31 {
157 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
158 Display::fmt(&self.value, f)
159 }
160}
161
162impl Debug for Mersenne31 {
163 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
164 Debug::fmt(&self.value, f)
165 }
166}
167
168impl Distribution<Mersenne31> for StandardUniform {
169 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Mersenne31 {
170 loop {
171 let next_u31 = rng.next_u32() >> 1;
172 let is_canonical = next_u31 != Mersenne31::ORDER_U32;
173 if is_canonical {
174 return Mersenne31::new_reduced(next_u31);
175 }
176 }
177 }
178}
179
180impl Serialize for Mersenne31 {
181 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
182 serializer.serialize_u32(self.as_canonical_u32())
184 }
185}
186
187impl<'a> Deserialize<'a> for Mersenne31 {
188 fn deserialize<D: Deserializer<'a>>(d: D) -> Result<Self, D::Error> {
189 let val = u32::deserialize(d)?;
190 if val < P {
193 Ok(Self::new_reduced(val))
194 } else {
195 Err(D::Error::custom("Value is out of range"))
196 }
197 }
198}
199
// Raw byte/word serialization, generated by the shared macro for 32-bit prime fields.
impl RawDataSerializable for Mersenne31 {
    impl_raw_serializable_primefield32!();
}
203
impl PrimeCharacteristicRing for Mersenne31 {
    // A prime field is its own prime subfield.
    type PrimeSubfield = Self;

    const ZERO: Self = Self { value: 0 };
    const ONE: Self = Self { value: 1 };
    const TWO: Self = Self { value: 2 };
    // -1 is represented canonically by P - 1.
    const NEG_ONE: Self = Self {
        value: Self::ORDER_U32 - 1,
    };

    /// Identity map, since `PrimeSubfield = Self`.
    #[inline]
    fn from_prime_subfield(f: Self::PrimeSubfield) -> Self {
        f
    }

    /// Maps `false` to `0` and `true` to `1`.
    #[inline]
    fn from_bool(b: bool) -> Self {
        Self::new_reduced(b as u32)
    }

    /// Divides by 2 using the shared `halve_u32` helper instantiated for `P`.
    #[inline]
    fn halve(&self) -> Self {
        Self::new_reduced(halve_u32::<P>(self.value))
    }

    /// Multiplies by `2^exp`.
    ///
    /// Because `2^31 ≡ 1 (mod P)`, this is a left rotation of the 31-bit representative by
    /// `exp mod 31` bits. The all-ones value `P` (redundant zero) rotates to itself.
    #[inline]
    fn mul_2exp_u64(&self, exp: u64) -> Self {
        let exp = exp % 31;
        // Bits shifted past bit 30 are masked off here...
        let left = (self.value << exp) & ((1 << 31) - 1);
        // ...and re-enter at the bottom here. For `exp == 0` this is a shift by 31, which is
        // in range for `u32` and yields 0 because the representative is `< 2^31`.
        let right = self.value >> (31 - exp);
        let rotated = left | right;
        Self::new_reduced(rotated)
    }

    /// Divides by `2^exp`: a right rotation of the 31-bit representative by `exp mod 31` bits
    /// (the inverse of `mul_2exp_u64`).
    #[inline]
    fn div_2exp_u64(&self, exp: u64) -> Self {
        let exp = (exp % 31) as u8;
        let left = self.value >> exp;
        // The low `exp` bits wrap around to the top of the 31-bit window.
        let right = (self.value << (31 - exp)) & ((1 << 31) - 1);
        let rotated = left | right;
        Self::new_reduced(rotated)
    }

    /// Sums exactly `N` elements of `input`.
    ///
    /// Sizes up to 5 use an explicit expression, with pairwise grouping for 4 and 5 terms.
    /// Larger sizes fall back to the `Sum` impl.
    ///
    /// # Panics
    /// Panics if `input.len() != N`.
    #[inline]
    fn sum_array<const N: usize>(input: &[Self]) -> Self {
        assert_eq!(N, input.len());
        match N {
            0 => Self::ZERO,
            1 => input[0],
            2 => input[0] + input[1],
            3 => input[0] + input[1] + input[2],
            4 => (input[0] + input[1]) + (input[2] + input[3]),
            5 => {
                let lhs = input[0] + input[1];
                let rhs = input[2] + input[3];
                lhs + rhs + input[4]
            }
            _ => input.iter().copied().sum(),
        }
    }

    /// Computes `sum(lhs[i] * rhs[i])` with delayed reduction.
    ///
    /// Raw representatives are `<= P < 2^31`, so each product is `< 2^62`. Up to four
    /// products therefore fit in a `u64` (`4 * P^2 < 2^64`), which lets `N <= 4` reduce only
    /// once.
    #[inline]
    fn dot_product<const N: usize>(lhs: &[Self; N], rhs: &[Self; N]) -> Self {
        match N {
            0 => Self::ZERO,
            1 => lhs[0] * rhs[0],
            2 => {
                let sum = (lhs[0].value as u64) * (rhs[0].value as u64)
                    + (lhs[1].value as u64) * (rhs[1].value as u64);
                Self::new_reduced(reduce_64(sum))
            }
            3 => {
                let sum = (lhs[0].value as u64) * (rhs[0].value as u64)
                    + (lhs[1].value as u64) * (rhs[1].value as u64)
                    + (lhs[2].value as u64) * (rhs[2].value as u64);
                Self::new_reduced(reduce_64(sum))
            }
            4 => {
                let sum = (lhs[0].value as u64) * (rhs[0].value as u64)
                    + (lhs[1].value as u64) * (rhs[1].value as u64)
                    + (lhs[2].value as u64) * (rhs[2].value as u64)
                    + (lhs[3].value as u64) * (rhs[3].value as u64);
                Self::new_reduced(reduce_64(sum))
            }
            _ => {
                // Each 4-product chunk (< 2^64) is partially reduced to < 2^33 + 2^31 before it
                // is added to `acc`. Together with at most three unreduced tail products,
                // `acc` cannot overflow for N < 2^30.
                let mut acc = 0u64;
                let mut i = 0;
                while i + 4 <= N {
                    let chunk_sum = (lhs[i].value as u64) * (rhs[i].value as u64)
                        + (lhs[i + 1].value as u64) * (rhs[i + 1].value as u64)
                        + (lhs[i + 2].value as u64) * (rhs[i + 2].value as u64)
                        + (lhs[i + 3].value as u64) * (rhs[i + 3].value as u64);
                    acc += partial_reduce(chunk_sum);
                    i += 4;
                }
                // The remaining 0..=3 products are added unreduced.
                while i < N {
                    acc += (lhs[i].value as u64) * (rhs[i].value as u64);
                    i += 1;
                }
                Self::new_reduced(reduce_64(acc))
            }
        }
    }

    /// Allocates `len` zero elements by zero-filling a `Vec<u32>` and reinterpreting it.
    #[inline]
    fn zero_vec(len: usize) -> Vec<Self> {
        // SAFETY: `Mersenne31` is `#[repr(transparent)]` over `u32`, and `0u32` is the
        // canonical encoding of zero. NOTE(review): this assumes `flatten_to_base` only
        // requires layout compatibility between `u32` and `Self`; confirm against its
        // `# Safety` documentation.
        unsafe { flatten_to_base(vec![0u32; len]) }
    }
}
328
329impl InjectiveMonomial<5> for Mersenne31 {}
333
334impl PermutationMonomial<5> for Mersenne31 {
335 fn injective_exp_root_n(&self) -> Self {
339 exp_1717986917(*self)
340 }
341}
342
impl Field for Mersenne31 {
    // SIMD packed type, chosen from the enabled target features. Exactly one branch is active:
    // NEON on aarch64, AVX-512 when `avx512f` is enabled, AVX2 when only `avx2` is enabled,
    // and the scalar type itself otherwise.
    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    type Packing = crate::PackedMersenne31Neon;
    #[cfg(all(
        target_arch = "x86_64",
        target_feature = "avx2",
        not(target_feature = "avx512f")
    ))]
    type Packing = crate::PackedMersenne31AVX2;
    #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
    type Packing = crate::PackedMersenne31AVX512;
    #[cfg(not(any(
        all(target_arch = "aarch64", target_feature = "neon"),
        all(
            target_arch = "x86_64",
            target_feature = "avx2",
            not(target_feature = "avx512f")
        ),
        all(target_arch = "x86_64", target_feature = "avx512f"),
    )))]
    type Packing = Self;

    // NOTE(review): 7 is presumably a generator of the multiplicative group F_P^*; confirm.
    const GENERATOR: Self = Self::new(7);

    /// Returns `true` for both encodings of zero: `0` and the redundant `P`.
    #[inline]
    fn is_zero(&self) -> bool {
        self.value == 0 || self.value == Self::ORDER_U32
    }

    /// Returns the multiplicative inverse, or `None` for zero.
    fn try_inverse(&self) -> Option<Self> {
        if self.is_zero() {
            return None;
        }

        // Bit length of P = 2^31 - 1.
        const NUM_PRIME_BITS: u32 = 31;

        // The `div_2exp_u64(60)` correction implies the helper returns the inverse scaled by
        // 2^60 as a signed i64. NOTE(review): confirm against `gcd_inversion_prime_field_32`.
        let inverse_i64 = gcd_inversion_prime_field_32::<NUM_PRIME_BITS>(self.value, P);
        Some(Self::from_int(inverse_i64).div_2exp_u64(60))
    }

    /// The field order `P` as a `BigUint`.
    #[inline]
    fn order() -> BigUint {
        P.into()
    }

    // On aarch64 with NEON this override delegates to the NEON-specialised routine; on other
    // targets it is compiled out and the trait's default implementation is used.
    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    #[inline]
    fn batched_columnwise_dot_product<EF, R, I, const N: usize>(
        acc: &mut [EF::ExtensionPacking],
        items: I,
    ) where
        EF: p3_field::ExtensionField<Self>,
        R: Iterator<Item = Self::Packing>,
        I: Iterator<Item = (R, [EF; N])>,
    {
        crate::aarch64_neon::batched_columnwise_dot_product::<EF, R, I, N>(acc, items);
    }
}
404
// `QuotientMap` impls for the remaining integer types, generated by shared p3_field macros.
// The small types are presumably widened and routed through the u32 / i32 impls below.
quotient_map_small_int!(Mersenne31, u32, [u8, u16]);
quotient_map_small_int!(Mersenne31, i32, [i8, i16]);
// The backtick-quoted string arguments appear to be range descriptions spliced into the
// generated documentation.
quotient_map_large_uint!(
    Mersenne31,
    u32,
    Mersenne31::ORDER_U32,
    "`[0, 2^31 - 2]`",
    "`[0, 2^31 - 1]`",
    [u64, u128]
);
// NOTE(review): `QuotientMap<i32>::from_canonical_checked` accepts `[1 - 2^30, 2^30 - 1]`,
// while the first range string here reads `[-2^30, 2^30]`. Confirm which range the macro
// expects this argument to describe.
quotient_map_large_iint!(
    Mersenne31,
    i32,
    "`[-2^30, 2^30]`",
    "`[1 - 2^31, 2^31 - 1]`",
    [(i64, u64), (i128, u128)]
);
423
424impl QuotientMap<u32> for Mersenne31 {
426 #[inline]
428 fn from_int(int: u32) -> Self {
429 let msb = int & (1 << 31);
431 let msb_reduced = msb >> 31;
432 Self::new_reduced(int ^ msb) + Self::new_reduced(msb_reduced)
433 }
434
435 #[inline]
439 fn from_canonical_checked(int: u32) -> Option<Self> {
440 (int < Self::ORDER_U32).then(|| Self::new_reduced(int))
441 }
442
443 #[inline(always)]
448 unsafe fn from_canonical_unchecked(int: u32) -> Self {
449 debug_assert!(int < Self::ORDER_U32);
450 Self::new_reduced(int)
451 }
452}
453
454impl QuotientMap<i32> for Mersenne31 {
455 #[inline]
457 fn from_int(int: i32) -> Self {
458 if int >= 0 {
459 Self::new_reduced(int as u32)
460 } else if int > (-1 << 31) {
461 Self::new_reduced(Self::ORDER_U32.wrapping_add_signed(int))
462 } else {
463 Self::NEG_ONE
465 }
466 }
467
468 #[inline]
472 fn from_canonical_checked(int: i32) -> Option<Self> {
473 const TWO_EXP_30: i32 = 1 << 30;
474 const NEG_TWO_EXP_30_PLUS_1: i32 = (-1 << 30) + 1;
475 match int {
476 0..TWO_EXP_30 => Some(Self::new_reduced(int as u32)),
477 NEG_TWO_EXP_30_PLUS_1..0 => {
478 Some(Self::new_reduced(Self::ORDER_U32.wrapping_add_signed(int)))
479 }
480 _ => None,
481 }
482 }
483
484 #[inline(always)]
489 unsafe fn from_canonical_unchecked(int: i32) -> Self {
490 if int >= 0 {
491 Self::new_reduced(int as u32)
492 } else {
493 Self::new_reduced(Self::ORDER_U32.wrapping_add_signed(int))
494 }
495 }
496}
497
498impl PrimeField for Mersenne31 {
499 fn as_canonical_biguint(&self) -> BigUint {
500 <Self as PrimeField32>::as_canonical_u32(self).into()
501 }
502}
503
504impl PrimeField32 for Mersenne31 {
505 const ORDER_U32: u32 = P;
506
507 #[inline]
508 fn as_canonical_u32(&self) -> u32 {
509 if self.value == Self::ORDER_U32 {
512 0
513 } else {
514 self.value
515 }
516 }
517}
518
519impl PrimeField64 for Mersenne31 {
520 const ORDER_U64: u64 = <Self as PrimeField32>::ORDER_U32 as u64;
521
522 #[inline]
523 fn as_canonical_u64(&self) -> u64 {
524 self.as_canonical_u32().into()
525 }
526}
527
impl Add for Mersenne31 {
    type Output = Self;

    /// Field addition.
    ///
    /// Both operands are `<= P < 2^31`, so they are non-negative as `i32`. The signed add
    /// therefore overflows exactly when the true sum is `>= 2^31`. In that case, subtracting
    /// `P` brings the sum (at most `2P`) back into `[1, P]`; otherwise the sum is already
    /// `<= P`. The result may be the redundant zero encoding `P`.
    #[inline]
    fn add(self, rhs: Self) -> Self {
        let (sum_i32, over) = (self.value as i32).overflowing_add(rhs.value as i32);
        // Reinterpreting the (possibly wrapped) i32 bits as u32 recovers the true sum.
        let sum_u32 = sum_i32 as u32;
        let sum_corr = sum_u32.wrapping_sub(Self::ORDER_U32);

        Self::new_reduced(if over { sum_corr } else { sum_u32 })
    }
}
548
impl Sub for Mersenne31 {
    type Output = Self;

    /// Field subtraction.
    ///
    /// If `self < rhs`, the u32 subtraction wraps to `2^32 + a - b`. Subtracting one more and
    /// masking to 31 bits then gives `a - b + 2^31 - 1 = a - b + P`. Without a borrow, the
    /// difference is already in `[0, P]` and the mask leaves it unchanged.
    #[inline]
    fn sub(self, rhs: Self) -> Self {
        let (mut sub, over) = self.value.overflowing_sub(rhs.value);

        // Cannot underflow: when `over` is set, `sub >= 2^32 - P > 0`.
        sub -= over as u32;
        Self::new_reduced(sub & Self::ORDER_U32)
    }
}
563
564impl Neg for Mersenne31 {
565 type Output = Self;
566
567 #[inline]
568 fn neg(self) -> Self::Output {
569 Self::new_reduced(Self::ORDER_U32 - self.value)
571 }
572}
573
574impl Mul for Mersenne31 {
575 type Output = Self;
576
577 #[inline]
578 #[allow(clippy::cast_possible_truncation)]
579 fn mul(self, rhs: Self) -> Self {
580 let prod = u64::from(self.value) * u64::from(rhs.value);
581 from_u62(prod)
582 }
583}
584
// Operator boilerplate generated by shared p3_field macros from the Add/Sub/Mul impls in
// this file: `+=` and `-=`, plus (presumably) `*=` and the division operators built on
// `Field::try_inverse`.
impl_add_assign!(Mersenne31);
impl_sub_assign!(Mersenne31);
impl_mul_methods!(Mersenne31);
impl_div_methods!(Mersenne31, Mersenne31);
589
590impl Sum for Mersenne31 {
591 #[inline]
592 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
593 let sum = iter.map(|x| x.value as u64).sum::<u64>();
598
599 from_u62(sum)
601 }
602}
603
604#[inline(always)]
607pub(crate) const fn partial_reduce(val: u64) -> u64 {
608 let lo = (val & (P as u64)) as u32;
610 let hi = val >> 31;
611 lo as u64 + hi
612}
613
614#[inline(always)]
618pub(crate) fn reduce_64(val: u64) -> u32 {
619 let lo = (val & (P as u64)) as u32;
623 let hi = val >> 31;
624 let sum1 = lo as u64 + hi;
625
626 let lo2 = (sum1 & (P as u64)) as u32;
628 let hi2 = (sum1 >> 31) as u32; let sum2 = lo2 + hi2; sum2.min(sum2.wrapping_sub(P))
633}
634
635#[inline(always)]
636pub(crate) fn from_u62(input: u64) -> Mersenne31 {
637 debug_assert!(input < (1 << 62));
638 let input_lo = (input & ((1 << 31) - 1)) as u32;
639 let input_high = (input >> 31) as u32;
640 Mersenne31::new_reduced(input_lo) + Mersenne31::new_reduced(input_high)
641}
642
643impl UniformSamplingField for Mersenne31 {
644 const MAX_SINGLE_SAMPLE_BITS: usize = 16;
645 const SAMPLING_BITS_M: [u64; 64] = {
648 let prime: u64 = P as u64;
649 let mut a = [0u64; 64];
650 let mut k = 0;
651 while k < 64 {
652 if k == 0 {
653 a[k] = prime; } else {
655 let mask = !((1u64 << k) - 1);
657 a[k] = prime & mask;
658 }
659 k += 1;
660 }
661 a
662 };
663}
664
#[cfg(test)]
mod tests {
    use num_bigint::BigUint;
    use p3_field::{InjectiveMonomial, PermutationMonomial, PrimeCharacteristicRing};
    use p3_field_testing::{
        test_field, test_prime_field, test_prime_field_32, test_prime_field_64,
    };

    use crate::Mersenne31;

    type F = Mersenne31;

    /// `x ↦ x^5` followed by its inverse power must round-trip.
    #[test]
    fn exp_root() {
        // Arbitrary non-trivial field elements.
        let m1 = F::from_u32(0x34167c58);
        let m2 = F::from_u32(0x61f3207b);

        assert_eq!(m1.injective_exp_n().injective_exp_root_n(), m1);
        assert_eq!(m2.injective_exp_n().injective_exp_root_n(), m2);
        assert_eq!(F::TWO.injective_exp_n().injective_exp_root_n(), F::TWO);
    }

    /// The redundant zero encoding `P` must equal `ZERO` and serialize as `0`.
    #[test]
    fn serialize_is_canonical() {
        let redundant_zero = F::new_reduced((1 << 31) - 1);
        assert_eq!(redundant_zero, F::ZERO);

        assert_eq!(serde_json::to_string(&redundant_zero).unwrap(), "0");
        assert_eq!(serde_json::to_string(&F::ZERO).unwrap(), "0");
    }

    /// Deserialization must reject `P` and accept the largest canonical value `P - 1`.
    #[test]
    fn deserialize_rejects_non_canonical_encodings() {
        let p_json = serde_json::to_string(&((1u32 << 31) - 1)).unwrap();
        assert!(serde_json::from_str::<F>(&p_json).is_err());

        let max_canonical_json = serde_json::to_string(&((1u32 << 31) - 2)).unwrap();
        let max_canonical: F = serde_json::from_str(&max_canonical_json).unwrap();
        assert_eq!(max_canonical, F::new((1 << 31) - 2));
    }

    // Both internal encodings of zero: canonical 0 and the redundant P = 2^31 - 1.
    const ZEROS: [Mersenne31; 2] = [Mersenne31::ZERO, Mersenne31::new_reduced((1_u32 << 31) - 1)];
    const ONES: [Mersenne31; 1] = [Mersenne31::ONE];

    /// Factorization of the multiplicative group order:
    /// `P - 1 = 2 · 3^2 · 7 · 11 · 31 · 151 · 331`.
    fn multiplicative_group_prime_factorization() -> [(BigUint, u32); 7] {
        [
            (BigUint::from(2u8), 1),
            (BigUint::from(3u8), 2),
            (BigUint::from(7u8), 1),
            (BigUint::from(11u8), 1),
            (BigUint::from(31u8), 1),
            (BigUint::from(151u8), 1),
            (BigUint::from(331u16), 1),
        ]
    }

    // Shared field test suites; they refer back to the constants above via `super::`.
    test_field!(
        crate::Mersenne31,
        &super::ZEROS,
        &super::ONES,
        &super::multiplicative_group_prime_factorization()
    );
    test_prime_field!(crate::Mersenne31);
    test_prime_field_64!(crate::Mersenne31, &super::ZEROS, &super::ONES);
    test_prime_field_32!(crate::Mersenne31, &super::ZEROS, &super::ONES);
}