1use alloc::vec;
2use alloc::vec::Vec;
3use core::fmt::{Debug, Display, Formatter};
4use core::hash::{Hash, Hasher};
5use core::iter::{Product, Sum};
6use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
7use core::{array, fmt, iter};
8
9use num_bigint::BigUint;
10use p3_challenger::UniformSamplingField;
11use p3_field::exponentiation::exp_1717986917;
12use p3_field::integers::QuotientMap;
13use p3_field::op_assign_macros::{
14 impl_add_assign, impl_div_methods, impl_mul_methods, impl_sub_assign,
15};
16use p3_field::{
17 Field, InjectiveMonomial, Packable, PermutationMonomial, PrimeCharacteristicRing, PrimeField,
18 PrimeField32, PrimeField64, RawDataSerializable, halve_u32, impl_raw_serializable_primefield32,
19 quotient_map_large_iint, quotient_map_large_uint, quotient_map_small_int,
20};
21use p3_util::{flatten_to_base, gcd_inversion_prime_field_32};
22use rand::Rng;
23use rand::distr::{Distribution, StandardUniform};
24use serde::de::Error;
25use serde::{Deserialize, Deserializer, Serialize};
26
/// The Mersenne prime `p = 2^31 - 1`, i.e. a `u32` with its lowest 31 bits set.
const P: u32 = u32::MAX >> 1;
29
/// An element of the prime field `F_p` with `p = 2^31 - 1`.
///
/// The stored `value` always fits in 31 bits but is not necessarily canonical: zero may be
/// stored either as `0` or as `2^31 - 1`.
#[derive(Copy, Clone, Default)]
// `zero_vec` reinterprets a `Vec<u32>` as a `Vec<Mersenne31>`, which relies on this layout.
#[repr(transparent)]
#[must_use]
pub struct Mersenne31 {
    /// Raw representation; must be below `2^31`.
    pub(crate) value: u32,
}
38
39impl Mersenne31 {
40 #[inline]
44 pub const fn new(value: u32) -> Self {
45 Self { value: value % P }
46 }
47
48 #[inline]
53 pub(crate) const fn new_reduced(value: u32) -> Self {
54 debug_assert!((value >> 31) == 0);
55 Self { value }
56 }
57
58 #[inline]
62 pub const fn new_checked(value: u32) -> Option<Self> {
63 if (value >> 31) == 0 {
64 Some(Self { value })
65 } else {
66 None
67 }
68 }
69
70 #[inline]
74 pub const fn new_array<const N: usize>(input: [u32; N]) -> [Self; N] {
75 let mut output = [Self::ZERO; N];
76 let mut i = 0;
77 while i < N {
78 output[i].value = input[i] % P;
79 i += 1;
80 }
81 output
82 }
83
84 pub const EXT_TWO_ADIC_GENERATORS: [[Self; 2]; 33] = [
87 [Self::ONE, Self::ZERO],
88 [Self::new(2_147_483_646), Self::new(0)],
89 [Self::new(0), Self::new(2_147_483_646)],
90 [Self::new(32_768), Self::new(2_147_450_879)],
91 [Self::new(590_768_354), Self::new(978_592_373)],
92 [Self::new(1_179_735_656), Self::new(1_241_207_368)],
93 [Self::new(1_567_857_810), Self::new(456_695_729)],
94 [Self::new(1_774_253_895), Self::new(1_309_288_441)],
95 [Self::new(736_262_640), Self::new(1_553_669_210)],
96 [Self::new(1_819_216_575), Self::new(1_662_816_114)],
97 [Self::new(1_323_191_254), Self::new(1_936_974_060)],
98 [Self::new(605_622_498), Self::new(1_964_232_216)],
99 [Self::new(343_674_985), Self::new(501_786_993)],
100 [Self::new(1_995_316_534), Self::new(149_306_621)],
101 [Self::new(2_107_600_913), Self::new(1_378_821_388)],
102 [Self::new(541_476_169), Self::new(2_101_081_972)],
103 [Self::new(2_135_874_973), Self::new(483_411_332)],
104 [Self::new(2_097_144_245), Self::new(1_684_033_590)],
105 [Self::new(1_662_322_247), Self::new(670_236_780)],
106 [Self::new(1_172_215_635), Self::new(595_888_646)],
107 [Self::new(241_940_101), Self::new(323_856_519)],
108 [Self::new(1_957_194_259), Self::new(2_139_647_100)],
109 [Self::new(1_957_419_629), Self::new(1_541_039_442)],
110 [Self::new(1_062_045_235), Self::new(1_824_580_421)],
111 [Self::new(1_929_382_196), Self::new(1_664_698_822)],
112 [Self::new(1_889_294_251), Self::new(331_248_939)],
113 [Self::new(1_214_231_414), Self::new(1_646_302_518)],
114 [Self::new(1_765_392_370), Self::new(461_136_547)],
115 [Self::new(1_629_751_483), Self::new(66_485_474)],
116 [Self::new(1_501_355_827), Self::new(1_439_063_420)],
117 [Self::new(509_778_402), Self::new(800_467_507)],
118 [Self::new(311_014_874), Self::new(1_584_694_829)],
119 [Self::new(1_166_849_849), Self::new(1_117_296_306)],
120 ];
121}
122
123impl PartialEq for Mersenne31 {
124 #[inline]
125 fn eq(&self, other: &Self) -> bool {
126 self.as_canonical_u32() == other.as_canonical_u32()
127 }
128}
129
130impl Eq for Mersenne31 {}
131
132impl Packable for Mersenne31 {}
133
134impl Hash for Mersenne31 {
135 fn hash<H: Hasher>(&self, state: &mut H) {
136 state.write_u32(self.to_unique_u32());
137 }
138}
139
140impl Ord for Mersenne31 {
141 #[inline]
142 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
143 self.as_canonical_u32().cmp(&other.as_canonical_u32())
144 }
145}
146
147impl PartialOrd for Mersenne31 {
148 #[inline]
149 fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
150 Some(self.cmp(other))
151 }
152}
153
154impl Display for Mersenne31 {
155 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
156 Display::fmt(&self.value, f)
157 }
158}
159
160impl Debug for Mersenne31 {
161 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
162 Debug::fmt(&self.value, f)
163 }
164}
165
166impl Distribution<Mersenne31> for StandardUniform {
167 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Mersenne31 {
168 loop {
169 let next_u31 = rng.next_u32() >> 1;
170 let is_canonical = next_u31 != Mersenne31::ORDER_U32;
171 if is_canonical {
172 return Mersenne31::new_reduced(next_u31);
173 }
174 }
175 }
176}
177
178impl Serialize for Mersenne31 {
179 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
180 serializer.serialize_u32(self.value)
182 }
183}
184
185impl<'a> Deserialize<'a> for Mersenne31 {
186 fn deserialize<D: Deserializer<'a>>(d: D) -> Result<Self, D::Error> {
187 let val = u32::deserialize(d)?;
188 if val <= P {
190 Ok(Self::new_reduced(val))
191 } else {
192 Err(D::Error::custom("Value is out of range"))
193 }
194 }
195}
196
// Raw byte-level (de)serialization, generated by the shared macro for 32-bit prime fields.
impl RawDataSerializable for Mersenne31 {
    impl_raw_serializable_primefield32!();
}
200
impl PrimeCharacteristicRing for Mersenne31 {
    // Mersenne31 is a prime field, so it serves as its own prime subfield.
    type PrimeSubfield = Self;

    const ZERO: Self = Self { value: 0 };
    const ONE: Self = Self { value: 1 };
    const TWO: Self = Self { value: 2 };
    // `p - 1`, the canonical representative of `-1`.
    const NEG_ONE: Self = Self {
        value: Self::ORDER_U32 - 1,
    };

    #[inline]
    fn from_prime_subfield(f: Self::PrimeSubfield) -> Self {
        f
    }

    #[inline]
    fn from_bool(b: bool) -> Self {
        Self::new_reduced(b as u32)
    }

    /// Returns `self / 2`, computed by `halve_u32` with modulus `P` on the raw value.
    #[inline]
    fn halve(&self) -> Self {
        Self::new_reduced(halve_u32::<P>(self.value))
    }

    /// Returns `self * 2^exp`.
    ///
    /// Because `2^31 = 1 (mod p)`, `exp` may first be reduced modulo 31. Multiplying a 31-bit
    /// value by `2^exp` is then a left rotation of those 31 bits by `exp` positions.
    #[inline]
    fn mul_2exp_u64(&self, exp: u64) -> Self {
        let exp = exp % 31;
        // Shifted bits, with anything that moved above bit 30 discarded...
        let left = (self.value << exp) & ((1 << 31) - 1);
        // ...and those discarded top `exp` bits wrapped around to the bottom.
        let right = self.value >> (31 - exp);
        let rotated = left | right;
        Self::new_reduced(rotated)
    }

    /// Returns `self / 2^exp`: a right rotation of the 31-bit value by `exp % 31` positions.
    #[inline]
    fn div_2exp_u64(&self, exp: u64) -> Self {
        let exp = (exp % 31) as u8;
        let left = self.value >> exp;
        // The low `exp` bits shifted out re-enter at the top of the 31-bit window.
        let right = (self.value << (31 - exp)) & ((1 << 31) - 1);
        let rotated = left | right;
        Self::new_reduced(rotated)
    }

    /// Sums exactly `N` elements.
    ///
    /// # Panics
    /// Panics if `input.len() != N`.
    #[inline]
    fn sum_array<const N: usize>(input: &[Self]) -> Self {
        assert_eq!(N, input.len());
        // Small sizes use direct field additions. Larger ones go through `Sum`, which
        // accumulates raw values in a `u64` and reduces once at the end.
        match N {
            0 => Self::ZERO,
            1 => input[0],
            2 => input[0] + input[1],
            3 => input[0] + input[1] + input[2],
            4 => (input[0] + input[1]) + (input[2] + input[3]),
            5 => {
                let lhs = input[0] + input[1];
                let rhs = input[2] + input[3];
                lhs + rhs + input[4]
            }
            _ => input.iter().copied().sum(),
        }
    }

    /// Computes `sum(lhs[i] * rhs[i])` with delayed modular reduction.
    #[inline]
    fn dot_product<const N: usize>(lhs: &[Self; N], rhs: &[Self; N]) -> Self {
        // Every raw value is below 2^31, so each product is below 2^62. Any four products sum to
        // less than 2^64, so up to four terms can be accumulated in a u64 without overflow.
        match N {
            0 => Self::ZERO,
            1 => lhs[0] * rhs[0],
            2 => {
                let sum = (lhs[0].value as u64) * (rhs[0].value as u64)
                    + (lhs[1].value as u64) * (rhs[1].value as u64);
                Self::new_reduced(reduce_64(sum))
            }
            3 => {
                let sum = (lhs[0].value as u64) * (rhs[0].value as u64)
                    + (lhs[1].value as u64) * (rhs[1].value as u64)
                    + (lhs[2].value as u64) * (rhs[2].value as u64);
                Self::new_reduced(reduce_64(sum))
            }
            4 => {
                let sum = (lhs[0].value as u64) * (rhs[0].value as u64)
                    + (lhs[1].value as u64) * (rhs[1].value as u64)
                    + (lhs[2].value as u64) * (rhs[2].value as u64)
                    + (lhs[3].value as u64) * (rhs[3].value as u64);
                Self::new_reduced(reduce_64(sum))
            }
            _ => {
                let mut acc = 0u64;
                let mut i = 0;
                // Full chunks of four products. Each chunk is partially reduced to below
                // 2^33 + 2^31 before being added to the accumulator.
                while i + 4 <= N {
                    let chunk_sum = (lhs[i].value as u64) * (rhs[i].value as u64)
                        + (lhs[i + 1].value as u64) * (rhs[i + 1].value as u64)
                        + (lhs[i + 2].value as u64) * (rhs[i + 2].value as u64)
                        + (lhs[i + 3].value as u64) * (rhs[i + 3].value as u64);
                    acc += partial_reduce(chunk_sum);
                    i += 4;
                }
                // At most three leftover products (< 3 * 2^62) are added unreduced. Together
                // with the partially reduced chunks, this stays below 2^64 while N is below
                // roughly 2^30.
                while i < N {
                    acc += (lhs[i].value as u64) * (rhs[i].value as u64);
                    i += 1;
                }
                Self::new_reduced(reduce_64(acc))
            }
        }
    }

    /// Allocates `len` zero elements in a single zeroed allocation.
    #[inline]
    fn zero_vec(len: usize) -> Vec<Self> {
        // SAFETY: `Mersenne31` is `#[repr(transparent)]` over its single `u32` field, so a
        // `Vec<u32>` can be reinterpreted as a `Vec<Mersenne31>`, and a zero `u32` is the valid
        // element `ZERO`. (Assumes this layout match is exactly `flatten_to_base`'s safety
        // contract — confirm against its docs.)
        unsafe { flatten_to_base(vec![0u32; len]) }
    }
}
325
326impl InjectiveMonomial<5> for Mersenne31 {}
330
331impl PermutationMonomial<5> for Mersenne31 {
332 fn injective_exp_root_n(&self) -> Self {
336 exp_1717986917(*self)
337 }
338}
339
340impl Field for Mersenne31 {
341 #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
342 type Packing = crate::PackedMersenne31Neon;
343 #[cfg(all(
344 target_arch = "x86_64",
345 target_feature = "avx2",
346 not(target_feature = "avx512f")
347 ))]
348 type Packing = crate::PackedMersenne31AVX2;
349 #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
350 type Packing = crate::PackedMersenne31AVX512;
351 #[cfg(not(any(
352 all(target_arch = "aarch64", target_feature = "neon"),
353 all(
354 target_arch = "x86_64",
355 target_feature = "avx2",
356 not(target_feature = "avx512f")
357 ),
358 all(target_arch = "x86_64", target_feature = "avx512f"),
359 )))]
360 type Packing = Self;
361
362 const GENERATOR: Self = Self::new(7);
364
365 #[inline]
366 fn is_zero(&self) -> bool {
367 self.value == 0 || self.value == Self::ORDER_U32
368 }
369
370 fn try_inverse(&self) -> Option<Self> {
371 if self.is_zero() {
372 return None;
373 }
374
375 const NUM_PRIME_BITS: u32 = 31;
377
378 let inverse_i64 = gcd_inversion_prime_field_32::<NUM_PRIME_BITS>(self.value, P);
380 Some(Self::from_int(inverse_i64).div_2exp_u64(60))
381 }
382
383 #[inline]
384 fn order() -> BigUint {
385 P.into()
386 }
387}
388
// `QuotientMap` impls for integer types narrower than 32 bits. These are generated by a shared
// macro, presumably delegating to the `u32` / `i32` impls below.
quotient_map_small_int!(Mersenne31, u32, [u8, u16]);
quotient_map_small_int!(Mersenne31, i32, [i8, i16]);
// `QuotientMap` impls for 64- and 128-bit unsigned integers. The two string arguments are
// documentation text describing accepted input ranges (presumably for the checked and unchecked
// canonical conversions respectively — confirm in `quotient_map_large_uint!`).
quotient_map_large_uint!(
    Mersenne31,
    u32,
    Mersenne31::ORDER_U32,
    "`[0, 2^31 - 2]`",
    "`[0, 2^31 - 1]`",
    [u64, u128]
);
// The same for 64- and 128-bit signed integers; each signed type is paired with its unsigned
// counterpart.
quotient_map_large_iint!(
    Mersenne31,
    i32,
    "`[-2^30, 2^30]`",
    "`[1 - 2^31, 2^31 - 1]`",
    [(i64, u64), (i128, u128)]
);
407
408impl QuotientMap<u32> for Mersenne31 {
410 #[inline]
412 fn from_int(int: u32) -> Self {
413 let msb = int & (1 << 31);
415 let msb_reduced = msb >> 31;
416 Self::new_reduced(int ^ msb) + Self::new_reduced(msb_reduced)
417 }
418
419 #[inline]
423 fn from_canonical_checked(int: u32) -> Option<Self> {
424 (int < Self::ORDER_U32).then(|| Self::new_reduced(int))
425 }
426
427 #[inline(always)]
432 unsafe fn from_canonical_unchecked(int: u32) -> Self {
433 debug_assert!(int < Self::ORDER_U32);
434 Self::new_reduced(int)
435 }
436}
437
438impl QuotientMap<i32> for Mersenne31 {
439 #[inline]
441 fn from_int(int: i32) -> Self {
442 if int >= 0 {
443 Self::new_reduced(int as u32)
444 } else if int > (-1 << 31) {
445 Self::new_reduced(Self::ORDER_U32.wrapping_add_signed(int))
446 } else {
447 Self::NEG_ONE
449 }
450 }
451
452 #[inline]
456 fn from_canonical_checked(int: i32) -> Option<Self> {
457 const TWO_EXP_30: i32 = 1 << 30;
458 const NEG_TWO_EXP_30_PLUS_1: i32 = (-1 << 30) + 1;
459 match int {
460 0..TWO_EXP_30 => Some(Self::new_reduced(int as u32)),
461 NEG_TWO_EXP_30_PLUS_1..0 => {
462 Some(Self::new_reduced(Self::ORDER_U32.wrapping_add_signed(int)))
463 }
464 _ => None,
465 }
466 }
467
468 #[inline(always)]
473 unsafe fn from_canonical_unchecked(int: i32) -> Self {
474 if int >= 0 {
475 Self::new_reduced(int as u32)
476 } else {
477 Self::new_reduced(Self::ORDER_U32.wrapping_add_signed(int))
478 }
479 }
480}
481
482impl PrimeField for Mersenne31 {
483 fn as_canonical_biguint(&self) -> BigUint {
484 <Self as PrimeField32>::as_canonical_u32(self).into()
485 }
486}
487
488impl PrimeField32 for Mersenne31 {
489 const ORDER_U32: u32 = P;
490
491 #[inline]
492 fn as_canonical_u32(&self) -> u32 {
493 if self.value == Self::ORDER_U32 {
496 0
497 } else {
498 self.value
499 }
500 }
501}
502
503impl PrimeField64 for Mersenne31 {
504 const ORDER_U64: u64 = <Self as PrimeField32>::ORDER_U32 as u64;
505
506 #[inline]
507 fn as_canonical_u64(&self) -> u64 {
508 self.as_canonical_u32().into()
509 }
510}
511
impl Add for Mersenne31 {
    type Output = Self;

    /// Adds two field elements; the result fits in 31 bits but may be the non-canonical zero.
    #[inline]
    fn add(self, rhs: Self) -> Self {
        // Both operands are below 2^31, so the true sum is below 2^32. Adding as i32 sets `over`
        // exactly when that sum is >= 2^31, i.e. when it no longer fits in 31 bits.
        let (sum_i32, over) = (self.value as i32).overflowing_add(rhs.value as i32);
        // Reinterpreting the wrapped i32 as u32 recovers the true sum.
        let sum_u32 = sum_i32 as u32;
        // sum - p. Only used on overflow, where it lies in [1, 2^31 - 1].
        let sum_corr = sum_u32.wrapping_sub(Self::ORDER_U32);

        // Without overflow the sum already fits in 31 bits (it may equal p, e.g. 1 + (p - 1)).
        Self::new_reduced(if over { sum_corr } else { sum_u32 })
    }
}

impl Sub for Mersenne31 {
    type Output = Self;

    /// Subtracts two field elements.
    #[inline]
    fn sub(self, rhs: Self) -> Self {
        // `over` is set when self < rhs; `sub` then holds self - rhs + 2^32.
        let (mut sub, over) = self.value.overflowing_sub(rhs.value);

        // On underflow, the implicit +2^32 must become +p = 2^32 - 1 - 2^31. So subtract 1
        // here, then clear bit 31 (which is set in that case) with the mask below. Without
        // underflow, `sub` is already below 2^31 and the mask is a no-op.
        sub -= over as u32;
        Self::new_reduced(sub & Self::ORDER_U32)
    }
}
547
548impl Neg for Mersenne31 {
549 type Output = Self;
550
551 #[inline]
552 fn neg(self) -> Self::Output {
553 Self::new_reduced(Self::ORDER_U32 - self.value)
555 }
556}
557
558impl Mul for Mersenne31 {
559 type Output = Self;
560
561 #[inline]
562 #[allow(clippy::cast_possible_truncation)]
563 fn mul(self, rhs: Self) -> Self {
564 let prod = u64::from(self.value) * u64::from(rhs.value);
565 from_u62(prod)
566 }
567}
568
// Derived operator impls built from the `Add`, `Sub` and `Mul` impls and from field inversion
// (presumably the `*Assign` and by-reference variants, plus `Product` and `Div` — confirm in
// `p3_field::op_assign_macros`).
impl_add_assign!(Mersenne31);
impl_sub_assign!(Mersenne31);
impl_mul_methods!(Mersenne31);
impl_div_methods!(Mersenne31, Mersenne31);
573
574impl Sum for Mersenne31 {
575 #[inline]
576 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
577 let sum = iter.map(|x| x.value as u64).sum::<u64>();
582
583 from_u62(sum)
585 }
586}
587
588#[inline(always)]
591pub(crate) const fn partial_reduce(val: u64) -> u64 {
592 let lo = (val & (P as u64)) as u32;
594 let hi = val >> 31;
595 lo as u64 + hi
596}
597
598#[inline(always)]
602pub(crate) fn reduce_64(val: u64) -> u32 {
603 let lo = (val & (P as u64)) as u32;
607 let hi = val >> 31;
608 let sum1 = lo as u64 + hi;
609
610 let lo2 = (sum1 & (P as u64)) as u32;
612 let hi2 = (sum1 >> 31) as u32; let sum2 = lo2 + hi2; sum2.min(sum2.wrapping_sub(P))
617}
618
619#[inline(always)]
620pub(crate) fn from_u62(input: u64) -> Mersenne31 {
621 debug_assert!(input < (1 << 62));
622 let input_lo = (input & ((1 << 31) - 1)) as u32;
623 let input_high = (input >> 31) as u32;
624 Mersenne31::new_reduced(input_lo) + Mersenne31::new_reduced(input_high)
625}
626
627impl UniformSamplingField for Mersenne31 {
628 const MAX_SINGLE_SAMPLE_BITS: usize = 16;
629 const SAMPLING_BITS_M: [u64; 64] = {
632 let prime: u64 = P as u64;
633 let mut a = [0u64; 64];
634 let mut k = 0;
635 while k < 64 {
636 if k == 0 {
637 a[k] = prime; } else {
639 let mask = !((1u64 << k) - 1);
641 a[k] = prime & mask;
642 }
643 k += 1;
644 }
645 a
646 };
647}
648
#[cfg(test)]
mod tests {
    use num_bigint::BigUint;
    use p3_field::{InjectiveMonomial, PermutationMonomial, PrimeCharacteristicRing};
    use p3_field_testing::{
        test_field, test_prime_field, test_prime_field_32, test_prime_field_64,
    };

    use crate::Mersenne31;

    type F = Mersenne31;

    /// Round-trips `x -> x^5 -> (x^5)^(1/5)` for a few elements.
    #[test]
    fn exp_root() {
        let m1 = F::from_u32(0x34167c58);
        let m2 = F::from_u32(0x61f3207b);

        assert_eq!(m1.injective_exp_n().injective_exp_root_n(), m1);
        assert_eq!(m2.injective_exp_n().injective_exp_root_n(), m2);
        assert_eq!(F::TWO.injective_exp_n().injective_exp_root_n(), F::TWO);
    }

    // Every 31-bit representation of zero: the canonical `0` and the non-canonical `2^31 - 1`.
    // `new_reduced` is used for the latter because `Mersenne31::new` reduces modulo `p`, which
    // would collapse it back to the canonical zero and leave that representation untested.
    const ZEROS: [Mersenne31; 2] = [
        Mersenne31::ZERO,
        Mersenne31::new_reduced((1_u32 << 31) - 1),
    ];
    // `1` has no second 31-bit representation (`1 + p` needs 32 bits).
    const ONES: [Mersenne31; 1] = [Mersenne31::ONE];

    /// Prime factorisation of `p - 1 = 2^31 - 2 = 2 * 3^2 * 7 * 11 * 31 * 151 * 331`.
    fn multiplicative_group_prime_factorization() -> [(BigUint, u32); 7] {
        [
            (BigUint::from(2u8), 1),
            (BigUint::from(3u8), 2),
            (BigUint::from(7u8), 1),
            (BigUint::from(11u8), 1),
            (BigUint::from(31u8), 1),
            (BigUint::from(151u8), 1),
            (BigUint::from(331u16), 1),
        ]
    }

    test_field!(
        crate::Mersenne31,
        &super::ZEROS,
        &super::ONES,
        &super::multiplicative_group_prime_factorization()
    );
    test_prime_field!(crate::Mersenne31);
    test_prime_field_64!(crate::Mersenne31, &super::ZEROS, &super::ONES);
    test_prime_field_32!(crate::Mersenne31, &super::ZEROS, &super::ONES);
}