Skip to main content

p3_mersenne_31/
mds.rs

//! MDS matrices over the Mersenne31 field, and permutations defined by them.
//!
//! NB: Not all sizes have fast implementations of their permutations.
//! Supported sizes: 8, 12, 16, 32, 64.
//! Sizes 8 and 12 are from Plonky2; size 16 was found as part of concurrent
//! work by Angus Gruen and Hamish Ivey-Law. Other sizes are from Ulrich Haböck's
//! database.
8
9use p3_field::PrimeCharacteristicRing;
10use p3_mds::MdsPermutation;
11use p3_mds::karatsuba_convolution::Convolve;
12use p3_mds::util::{dot_product, first_row_to_first_col};
13use p3_symmetric::Permutation;
14
15use crate::Mersenne31;
16
/// Zero-sized handle for the MDS permutations over the Mersenne31 field.
///
/// The `Permutation` / `MdsPermutation` implementations in this module cover
/// state widths 8, 12, 16, 32 and 64.
#[derive(Debug, Default, Clone)]
pub struct MdsMatrixMersenne31;
19
20/// Instantiate convolution for "small" RHS vectors over Mersenne31.
21///
22/// Here "small" means N = len(rhs) <= 16 and sum(r for r in rhs) <
23/// 2^24 (roughly), though in practice the sum will be less than 2^9.
24struct SmallConvolveMersenne31;
25impl Convolve<Mersenne31, i64, i64> for SmallConvolveMersenne31 {
26    const T_ZERO: i64 = 0;
27    const U_ZERO: i64 = 0;
28
29    #[inline(always)]
30    fn halve(val: i64) -> i64 {
31        val >> 1
32    }
33
34    /// Return the lift of an (almost) reduced Mersenne31 element.
35    /// The Mersenne31 implementation guarantees that
36    /// 0 <= input.value <= P < 2^31.
37    #[inline(always)]
38    fn read(input: Mersenne31) -> i64 {
39        input.value as i64
40    }
41
42    /// For a convolution of size N, |x| < N * 2^31 and (as per the
43    /// assumption above), |y| < 2^24. So the product is at most N * 2^55
44    /// which will not overflow for N <= 16.
45    #[inline(always)]
46    fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
47        dot_product(u, v)
48    }
49
50    /// The assumptions above mean z < N^2 * 2^55, which is at most
51    /// 2^63 when N <= 16.
52    ///
53    /// NB: Even though intermediate values could be negative, the
54    /// output must be non-negative since the inputs were
55    /// non-negative.
56    #[inline(always)]
57    fn reduce(z: i64) -> Mersenne31 {
58        debug_assert!(z >= 0);
59        Mersenne31::from_u64(z as u64)
60    }
61}
62
63/// Instantiate convolution for "large" RHS vectors over Mersenne31.
64///
65/// Here "large" means the elements can be as big as the field
66/// characteristic, and the size N of the RHS is <= 64.
67struct LargeConvolveMersenne31;
68impl Convolve<Mersenne31, i64, i64> for LargeConvolveMersenne31 {
69    const T_ZERO: i64 = 0;
70    const U_ZERO: i64 = 0;
71
72    #[inline(always)]
73    fn halve(val: i64) -> i64 {
74        val >> 1
75    }
76
77    /// Return the lift of an (almost) reduced Mersenne31 element.
78    /// The Mersenne31 implementation guarantees that
79    /// 0 <= input.value <= P < 2^31.
80    #[inline(always)]
81    fn read(input: Mersenne31) -> i64 {
82        input.value as i64
83    }
84
85    #[inline]
86    fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
87        // For a convolution of size N, |x|, |y| < N * 2^31, so the product
88        // could be as much as N^2 * 2^62. This will overflow an i64, so
89        // we first widen to i128.
90
91        let mut dp = 0i128;
92        for i in 0..N {
93            dp += u[i] as i128 * v[i] as i128;
94        }
95
96        const LOWMASK: i128 = (1 << 42) - 1; // Gets the bits lower than 42.
97        const HIGHMASK: i128 = !LOWMASK; // Gets all bits higher than 42.
98
99        let low_bits = (dp & LOWMASK) as i64; // low_bits < 2**42
100        let high_bits = ((dp & HIGHMASK) >> 31) as i64; // |high_bits| < 2**(n - 31)
101
102        // Proof that low_bits + high_bits is congruent to dp (mod p)
103        // and congruent to dp (mod 2^11):
104        //
105        // The individual bounds clearly show that low_bits +
106        // high_bits < 2**(n - 30).
107        //
108        // Next observe that low_bits + high_bits = input - (2**31 -
109        // 1) * (high_bits) = input mod P.
110        //
111        // Finally note that 2**11 divides high_bits and so low_bits +
112        // high_bits = low_bits mod 2**11 = input mod 2**11.
113
114        low_bits + high_bits
115    }
116
117    #[inline]
118    fn reduce(z: i64) -> Mersenne31 {
119        // After the dot product, the maximal size is N^2 * 2^62 < 2^74
120        // as N = 64 is the biggest size. So, after the partial
121        // reduction, the output z of parity dot satisfies |z| < 2^44
122        // (Where 44 is 74 - 30).
123        //
124        // In the recombining steps, conv maps (wo, w1) -> ((wo + w1)/2,
125        // (wo + w1)/2) which has no effect on the maximal size. (Indeed,
126        // it makes sizes almost strictly smaller).
127        //
128        // On the other hand, negacyclic_conv (ignoring the re-index)
129        // recombines as: (w0, w1, w2) -> (w0 + w1, w2 - w0 - w1). Hence
130        // if the input is <= K, the output is <= 3K.
131        //
132        // Thus the values appearing at the end are bounded by 3^n 2^44
133        // where n is the maximal number of negacyclic_conv recombination
134        // steps. When N = 64, we need to recombine for singed_conv_32,
135        // singed_conv_16, singed_conv_8 so the overall bound will be 3^3
136        // 2^44 < 32 * 2^44 < 2^49.
137        debug_assert!(z > -(1i64 << 49));
138        debug_assert!(z < (1i64 << 49));
139
140        const MASK: i64 = (1 << 31) - 1;
141        // Morally, our value is a i62 not a i64 as the top 3 bits are
142        // guaranteed to be equal.
143        //
144        // The masked value can equal 2^31 - 1 (the non-canonical representation of zero).
145        //
146        // So the constructor must accept it.
147        let low_bits = Mersenne31::new_reduced((z & MASK) as u32);
148
149        let high_bits = ((z >> 31) & MASK) as i32;
150        let sign_bits = (z >> 62) as i32;
151
152        // The sum lies in [0, 2^31 - 1].
153        //
154        // A negative `z` forces the upper-bit chunk to be at least 1.
155        //
156        // So the sign correction of -1 cannot drag the sum below zero.
157        let high = Mersenne31::new_reduced((high_bits + sign_bits) as u32);
158        low_bits + high
159    }
160}
161
162const MATRIX_CIRC_MDS_8_SML_ROW: [i64; 8] = [7, 1, 3, 8, 8, 3, 4, 9];
163
164impl Permutation<[Mersenne31; 8]> for MdsMatrixMersenne31 {
165    fn permute(&self, input: [Mersenne31; 8]) -> [Mersenne31; 8] {
166        const MATRIX_CIRC_MDS_8_SML_COL: [i64; 8] =
167            first_row_to_first_col(&MATRIX_CIRC_MDS_8_SML_ROW);
168        SmallConvolveMersenne31::apply(
169            input,
170            MATRIX_CIRC_MDS_8_SML_COL,
171            SmallConvolveMersenne31::conv8,
172        )
173    }
174}
175impl MdsPermutation<Mersenne31, 8> for MdsMatrixMersenne31 {}
176
177const MATRIX_CIRC_MDS_12_SML_ROW: [i64; 12] = [1, 1, 2, 1, 8, 9, 10, 7, 5, 9, 4, 10];
178
179impl Permutation<[Mersenne31; 12]> for MdsMatrixMersenne31 {
180    fn permute(&self, input: [Mersenne31; 12]) -> [Mersenne31; 12] {
181        const MATRIX_CIRC_MDS_12_SML_COL: [i64; 12] =
182            first_row_to_first_col(&MATRIX_CIRC_MDS_12_SML_ROW);
183        SmallConvolveMersenne31::apply(
184            input,
185            MATRIX_CIRC_MDS_12_SML_COL,
186            SmallConvolveMersenne31::conv12,
187        )
188    }
189}
190impl MdsPermutation<Mersenne31, 12> for MdsMatrixMersenne31 {}
191
192pub(crate) const MATRIX_CIRC_MDS_16_SML_ROW: [i64; 16] =
193    [1, 1, 51, 1, 11, 17, 2, 1, 101, 63, 15, 2, 67, 22, 13, 3];
194
195impl Permutation<[Mersenne31; 16]> for MdsMatrixMersenne31 {
196    fn permute(&self, input: [Mersenne31; 16]) -> [Mersenne31; 16] {
197        const MATRIX_CIRC_MDS_16_SML_COL: [i64; 16] =
198            first_row_to_first_col(&MATRIX_CIRC_MDS_16_SML_ROW);
199        SmallConvolveMersenne31::apply(
200            input,
201            MATRIX_CIRC_MDS_16_SML_COL,
202            SmallConvolveMersenne31::conv16,
203        )
204    }
205}
206impl MdsPermutation<Mersenne31, 16> for MdsMatrixMersenne31 {}
207
208#[rustfmt::skip]
209pub(crate) const MATRIX_CIRC_MDS_32_MERSENNE31_ROW: [i64; 32] = [
210    0x1896DC78, 0x559D1E29, 0x04EBD732, 0x3FF449D7,
211    0x2DB0E2CE, 0x26776B85, 0x76018E57, 0x1025FA13,
212    0x06486BAB, 0x37706EBA, 0x25EB966B, 0x113C24E5,
213    0x2AE20EC4, 0x5A27507C, 0x0CD38CF1, 0x761C10E5,
214    0x19E3EF1A, 0x032C730F, 0x35D8AF83, 0x651DF13B,
215    0x7EC3DB1A, 0x6A146994, 0x588F9145, 0x09B79455,
216    0x7FDA05EC, 0x19FE71A8, 0x6988947A, 0x624F1D31,
217    0x500BB628, 0x0B1428CE, 0x3A62E1D6, 0x77692387
218];
219
220impl Permutation<[Mersenne31; 32]> for MdsMatrixMersenne31 {
221    fn permute(&self, input: [Mersenne31; 32]) -> [Mersenne31; 32] {
222        const MATRIX_CIRC_MDS_32_MERSENNE31_COL: [i64; 32] =
223            first_row_to_first_col(&MATRIX_CIRC_MDS_32_MERSENNE31_ROW);
224        LargeConvolveMersenne31::apply(
225            input,
226            MATRIX_CIRC_MDS_32_MERSENNE31_COL,
227            LargeConvolveMersenne31::conv32,
228        )
229    }
230}
231impl MdsPermutation<Mersenne31, 32> for MdsMatrixMersenne31 {}
232
233#[rustfmt::skip]
234const MATRIX_CIRC_MDS_64_MERSENNE31_ROW: [i64; 64] = [
235    0x570227A5, 0x3702983F, 0x4B7B3B0A, 0x74F13DE3,
236    0x485314B0, 0x0157E2EC, 0x1AD2E5DE, 0x721515E3,
237    0x5452ADA3, 0x0C74B6C1, 0x67DA9450, 0x33A48369,
238    0x3BDBEE06, 0x7C678D5E, 0x160F16D3, 0x54888B8C,
239    0x666C7AA6, 0x113B89E2, 0x2A403CE2, 0x18F9DF42,
240    0x2A685E84, 0x49EEFDE5, 0x5D044806, 0x560A41F8,
241    0x69EF1BD0, 0x2CD15786, 0x62E07766, 0x22A231E2,
242    0x3CFCF40C, 0x4E8F63D8, 0x69657A15, 0x466B4B2D,
243    0x4194B4D2, 0x1E9A85EA, 0x39709C27, 0x4B030BF3,
244    0x655DCE1D, 0x251F8899, 0x5B2EA879, 0x1E10E42F,
245    0x31F5BE07, 0x2AFBB7F9, 0x3E11021A, 0x5D97A17B,
246    0x6F0620BD, 0x5DBFC31D, 0x76C4761D, 0x21938559,
247    0x33777473, 0x71F0E92C, 0x0B9872A1, 0x4C2411F9,
248    0x545B7C96, 0x20256BAF, 0x7B8B493E, 0x33AD525C,
249    0x15EAEA1C, 0x6D2D1A21, 0x06A81D14, 0x3FACEB4F,
250    0x130EC21C, 0x3C84C4F5, 0x50FD67C0, 0x30FDD85A,
251];
252
253impl Permutation<[Mersenne31; 64]> for MdsMatrixMersenne31 {
254    fn permute(&self, input: [Mersenne31; 64]) -> [Mersenne31; 64] {
255        const MATRIX_CIRC_MDS_64_MERSENNE31_COL: [i64; 64] =
256            first_row_to_first_col(&MATRIX_CIRC_MDS_64_MERSENNE31_ROW);
257        LargeConvolveMersenne31::apply(
258            input,
259            MATRIX_CIRC_MDS_64_MERSENNE31_COL,
260            LargeConvolveMersenne31::conv64,
261        )
262    }
263}
264impl MdsPermutation<Mersenne31, 64> for MdsMatrixMersenne31 {}
265
#[cfg(test)]
mod tests {
    use p3_field::PrimeCharacteristicRing;
    use p3_mds::karatsuba_convolution::Convolve;
    use p3_symmetric::Permutation;

    use super::{LargeConvolveMersenne31, MdsMatrixMersenne31, Mersenne31};

    /// The large reducer must accept inputs whose low 31 bits are all ones,
    /// because 2^31 - 1 is the non-canonical representation of zero.
    ///
    /// Both `z = -1` (all 64 bits set) and `z = P` (only the low 31 bits set)
    /// mask down to P under `(1 << 31) - 1`, which is the edge case exercised.
    #[test]
    fn large_convolve_reduce_accepts_p_representation() {
        const P: i64 = (1i64 << 31) - 1;

        // z = -1 is congruent to P - 1 modulo P.
        let minus_one = Mersenne31::ZERO - Mersenne31::ONE;
        assert_eq!(LargeConvolveMersenne31::reduce(-1), minus_one);

        // z = P is congruent to 0 modulo P.
        assert_eq!(LargeConvolveMersenne31::reduce(P), Mersenne31::ZERO);
    }

    /// Fixed input/output vector for the width-8 permutation.
    #[test]
    fn mersenne8() {
        let state: [Mersenne31; 8] = Mersenne31::new_array([
            1741044457, 327154658, 318297696, 1528828225, 468360260, 1271368222, 1906288587,
            1521884224,
        ]);

        let expected: [Mersenne31; 8] = Mersenne31::new_array([
            895992680, 1343855369, 2107796831, 266468728, 846686506, 252887121, 205223309,
            260248790,
        ]);

        assert_eq!(MdsMatrixMersenne31.permute(state), expected);
    }

    /// Fixed input/output vector for the width-12 permutation.
    #[test]
    fn mersenne12() {
        let state: [Mersenne31; 12] = Mersenne31::new_array([
            1232740094, 661555540, 11024822, 1620264994, 471137070, 276755041, 1316882747,
            1023679816, 1675266989, 743211887, 44774582, 1990989306,
        ]);

        let expected: [Mersenne31; 12] = Mersenne31::new_array([
            860812289, 399778981, 1228500858, 798196553, 673507779, 1116345060, 829764188,
            138346433, 578243475, 553581995, 578183208, 1527769050,
        ]);

        assert_eq!(MdsMatrixMersenne31.permute(state), expected);
    }

    /// Fixed input/output vector for the width-16 permutation.
    #[test]
    fn mersenne16() {
        let state: [Mersenne31; 16] = Mersenne31::new_array([
            1431168444, 963811518, 88067321, 381314132, 908628282, 1260098295, 980207659,
            150070493, 357706876, 2014609375, 387876458, 1621671571, 183146044, 107201572,
            166536524, 2078440788,
        ]);

        let expected: [Mersenne31; 16] = Mersenne31::new_array([
            1858869691, 1607793806, 1200396641, 1400502985, 1511630695, 187938132, 1332411488,
            2041577083, 2014246632, 802022141, 796807132, 1647212930, 813167618, 1867105010,
            508596277, 1457551581,
        ]);

        assert_eq!(MdsMatrixMersenne31.permute(state), expected);
    }

    /// Fixed input/output vector for the width-32 permutation.
    #[test]
    fn mersenne32() {
        let state: [Mersenne31; 32] = Mersenne31::new_array([
            873912014, 1112497426, 300405095, 4255553, 1234979949, 156402357, 1952135954,
            718195399, 1041748465, 683604342, 184275751, 1184118518, 214257054, 1293941921,
            64085758, 710448062, 1133100009, 350114887, 1091675272, 671421879, 1226105999,
            546430131, 1298443967, 1787169653, 2129310791, 1560307302, 471771931, 1191484402,
            1550203198, 1541319048, 229197040, 839673789,
        ]);

        let expected: [Mersenne31; 32] = Mersenne31::new_array([
            1439049928, 890642852, 694402307, 713403244, 553213342, 1049445650, 321709533,
            1195683415, 2118492257, 623077773, 96734062, 990488164, 1674607608, 749155000,
            353377854, 966432998, 1114654884, 1370359248, 1624965859, 685087760, 1631836645,
            1615931812, 2061986317, 1773551151, 1449911206, 1951762557, 545742785, 582866449,
            1379774336, 229242759, 1871227547, 752848413,
        ]);

        assert_eq!(MdsMatrixMersenne31.permute(state), expected);
    }

    /// Fixed input/output vector for the width-64 permutation.
    #[test]
    fn mersenne64() {
        let state: [Mersenne31; 64] = Mersenne31::new_array([
            837269696, 1509031194, 413915480, 1889329185, 315502822, 1529162228, 1454661012,
            1015826742, 973381409, 1414676304, 1449029961, 1968715566, 2027226497, 1721820509,
            434042616, 1436005045, 1680352863, 651591867, 260585272, 1078022153, 703990572,
            269504423, 1776357592, 1174979337, 1142666094, 1897872960, 1387995838, 250774418,
            776134750, 73930096, 194742451, 1860060380, 666407744, 669566398, 963802147,
            2063418105, 1772573581, 998923482, 701912753, 1716548204, 860820931, 1680395948,
            949886256, 1811558161, 501734557, 1671977429, 463135040, 1911493108, 207754409,
            608714758, 1553060084, 1558941605, 980281686, 2014426559, 650527801, 53015148,
            1521176057, 720530872, 713593252, 88228433, 1194162313, 1922416934, 1075145779,
            344403794,
        ]);

        let expected: [Mersenne31; 64] = Mersenne31::new_array([
            1599981950, 252630853, 1171557270, 116468420, 1269245345, 666203050, 46155642,
            1701131520, 530845775, 508460407, 630407239, 1731628135, 1199144768, 295132047,
            77536342, 1472377703, 30752443, 1300339617, 18647556, 1267774380, 1194573079,
            1624665024, 646848056, 1667216490, 1184843555, 1250329476, 254171597, 1902035936,
            1706882202, 964921003, 952266538, 1215696284, 539510504, 1056507562, 1393151480,
            733644883, 1663330816, 1100715048, 991108703, 1671345065, 1376431774, 408310416,
            313176996, 743567676, 304660642, 1842695838, 958201635, 1650792218, 541570244,
            968523062, 1958918704, 1866282698, 849808680, 1193306222, 794153281, 822835360,
            135282913, 1149868448, 2068162123, 1474283743, 2039088058, 720305835, 746036736,
            671006610,
        ]);

        assert_eq!(MdsMatrixMersenne31.permute(state), expected);
    }
}