p3_mersenne_31/
mds.rs

1//! MDS matrices over the Mersenne31 field, and permutations defined by them.
2//!
3//! NB: Not all sizes have fast implementations of their permutations.
4//! Supported sizes: 8, 12, 16, 32, 64.
5//! Sizes 8 and 12 are from Plonky2, size 16 was found as part of concurrent
6//! work by Angus Gruen and Hamish Ivey-Law. Other sizes are from Ulrich Haböck's
7//! database.
8
9use p3_field::PrimeCharacteristicRing;
10use p3_field::integers::QuotientMap;
11use p3_mds::MdsPermutation;
12use p3_mds::karatsuba_convolution::Convolve;
13use p3_mds::util::{dot_product, first_row_to_first_col};
14use p3_symmetric::Permutation;
15
16use crate::Mersenne31;
17
/// Circulant MDS matrices over the Mersenne31 field.
///
/// This zero-sized type provides `Permutation`/`MdsPermutation` impls for
/// state widths 8, 12, 16, 32 and 64; each multiplies the state by a fixed
/// circulant matrix given by its first row.
#[derive(Debug, Default, Clone)]
pub struct MdsMatrixMersenne31;
20
21/// Instantiate convolution for "small" RHS vectors over Mersenne31.
22///
23/// Here "small" means N = len(rhs) <= 16 and sum(r for r in rhs) <
24/// 2^24 (roughly), though in practice the sum will be less than 2^9.
25struct SmallConvolveMersenne31;
26impl Convolve<Mersenne31, i64, i64, i64> for SmallConvolveMersenne31 {
27    /// Return the lift of an (almost) reduced Mersenne31 element.
28    /// The Mersenne31 implementation guarantees that
29    /// 0 <= input.value <= P < 2^31.
30    #[inline(always)]
31    fn read(input: Mersenne31) -> i64 {
32        input.value as i64
33    }
34
35    /// FIXME: Refactor the dot product
36    /// For a convolution of size N, |x| < N * 2^31 and (as per the
37    /// assumption above), |y| < 2^24. So the product is at most N * 2^55
38    /// which will not overflow for N <= 16.
39    #[inline(always)]
40    fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
41        dot_product(u, v)
42    }
43
44    /// The assumptions above mean z < N^2 * 2^55, which is at most
45    /// 2^63 when N <= 16.
46    ///
47    /// NB: Even though intermediate values could be negative, the
48    /// output must be non-negative since the inputs were
49    /// non-negative.
50    #[inline(always)]
51    fn reduce(z: i64) -> Mersenne31 {
52        debug_assert!(z >= 0);
53        Mersenne31::from_u64(z as u64)
54    }
55}
56
57/// Instantiate convolution for "large" RHS vectors over Mersenne31.
58///
59/// Here "large" means the elements can be as big as the field
60/// characteristic, and the size N of the RHS is <= 64.
61struct LargeConvolveMersenne31;
62impl Convolve<Mersenne31, i64, i64, i64> for LargeConvolveMersenne31 {
63    /// Return the lift of an (almost) reduced Mersenne31 element.
64    /// The Mersenne31 implementation guarantees that
65    /// 0 <= input.value <= P < 2^31.
66    #[inline(always)]
67    fn read(input: Mersenne31) -> i64 {
68        input.value as i64
69    }
70
71    #[inline]
72    fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
73        // For a convolution of size N, |x|, |y| < N * 2^31, so the product
74        // could be as much as N^2 * 2^62. This will overflow an i64, so
75        // we first widen to i128.
76
77        let mut dp = 0i128;
78        for i in 0..N {
79            dp += u[i] as i128 * v[i] as i128;
80        }
81
82        const LOWMASK: i128 = (1 << 42) - 1; // Gets the bits lower than 42.
83        const HIGHMASK: i128 = !LOWMASK; // Gets all bits higher than 42.
84
85        let low_bits = (dp & LOWMASK) as i64; // low_bits < 2**42
86        let high_bits = ((dp & HIGHMASK) >> 31) as i64; // |high_bits| < 2**(n - 31)
87
88        // Proof that low_bits + high_bits is congruent to dp (mod p)
89        // and congruent to dp (mod 2^11):
90        //
91        // The individual bounds clearly show that low_bits +
92        // high_bits < 2**(n - 30).
93        //
94        // Next observe that low_bits + high_bits = input - (2**31 -
95        // 1) * (high_bits) = input mod P.
96        //
97        // Finally note that 2**11 divides high_bits and so low_bits +
98        // high_bits = low_bits mod 2**11 = input mod 2**11.
99
100        low_bits + high_bits
101    }
102
103    #[inline]
104    fn reduce(z: i64) -> Mersenne31 {
105        // After the dot product, the maximal size is N^2 * 2^62 < 2^74
106        // as N = 64 is the biggest size. So, after the partial
107        // reduction, the output z of parity dot satisfies |z| < 2^44
108        // (Where 44 is 74 - 30).
109        //
110        // In the recombining steps, conv maps (wo, w1) -> ((wo + w1)/2,
111        // (wo + w1)/2) which has no effect on the maximal size. (Indeed,
112        // it makes sizes almost strictly smaller).
113        //
114        // On the other hand, negacyclic_conv (ignoring the re-index)
115        // recombines as: (w0, w1, w2) -> (w0 + w1, w2 - w0 - w1). Hence
116        // if the input is <= K, the output is <= 3K.
117        //
118        // Thus the values appearing at the end are bounded by 3^n 2^44
119        // where n is the maximal number of negacyclic_conv recombination
120        // steps. When N = 64, we need to recombine for singed_conv_32,
121        // singed_conv_16, singed_conv_8 so the overall bound will be 3^3
122        // 2^44 < 32 * 2^44 < 2^49.
123        debug_assert!(z > -(1i64 << 49));
124        debug_assert!(z < (1i64 << 49));
125
126        const MASK: i64 = (1 << 31) - 1;
127        // Morally, our value is a i62 not a i64 as the top 3 bits are
128        // guaranteed to be equal.
129        let low_bits = unsafe {
130            // This is safe as 0 <= z & MASK < 2^31
131            Mersenne31::from_canonical_unchecked((z & MASK) as u32)
132        };
133
134        let high_bits = ((z >> 31) & MASK) as i32;
135        let sign_bits = (z >> 62) as i32;
136
137        let high = unsafe {
138            // This is safe as high_bits + sign_bits > 0 as by assumption b[63] = b[61].
139            Mersenne31::from_canonical_unchecked((high_bits + sign_bits) as u32)
140        };
141        low_bits + high
142    }
143}
144
145const MATRIX_CIRC_MDS_8_SML_ROW: [i64; 8] = [7, 1, 3, 8, 8, 3, 4, 9];
146
147impl Permutation<[Mersenne31; 8]> for MdsMatrixMersenne31 {
148    fn permute(&self, input: [Mersenne31; 8]) -> [Mersenne31; 8] {
149        const MATRIX_CIRC_MDS_8_SML_COL: [i64; 8] =
150            first_row_to_first_col(&MATRIX_CIRC_MDS_8_SML_ROW);
151        SmallConvolveMersenne31::apply(
152            input,
153            MATRIX_CIRC_MDS_8_SML_COL,
154            SmallConvolveMersenne31::conv8,
155        )
156    }
157
158    fn permute_mut(&self, input: &mut [Mersenne31; 8]) {
159        *input = self.permute(*input);
160    }
161}
162impl MdsPermutation<Mersenne31, 8> for MdsMatrixMersenne31 {}
163
164const MATRIX_CIRC_MDS_12_SML_ROW: [i64; 12] = [1, 1, 2, 1, 8, 9, 10, 7, 5, 9, 4, 10];
165
166impl Permutation<[Mersenne31; 12]> for MdsMatrixMersenne31 {
167    fn permute(&self, input: [Mersenne31; 12]) -> [Mersenne31; 12] {
168        const MATRIX_CIRC_MDS_12_SML_COL: [i64; 12] =
169            first_row_to_first_col(&MATRIX_CIRC_MDS_12_SML_ROW);
170        SmallConvolveMersenne31::apply(
171            input,
172            MATRIX_CIRC_MDS_12_SML_COL,
173            SmallConvolveMersenne31::conv12,
174        )
175    }
176
177    fn permute_mut(&self, input: &mut [Mersenne31; 12]) {
178        *input = self.permute(*input);
179    }
180}
181impl MdsPermutation<Mersenne31, 12> for MdsMatrixMersenne31 {}
182
183const MATRIX_CIRC_MDS_16_SML_ROW: [i64; 16] =
184    [1, 1, 51, 1, 11, 17, 2, 1, 101, 63, 15, 2, 67, 22, 13, 3];
185
186impl Permutation<[Mersenne31; 16]> for MdsMatrixMersenne31 {
187    fn permute(&self, input: [Mersenne31; 16]) -> [Mersenne31; 16] {
188        const MATRIX_CIRC_MDS_16_SML_COL: [i64; 16] =
189            first_row_to_first_col(&MATRIX_CIRC_MDS_16_SML_ROW);
190        SmallConvolveMersenne31::apply(
191            input,
192            MATRIX_CIRC_MDS_16_SML_COL,
193            SmallConvolveMersenne31::conv16,
194        )
195    }
196
197    fn permute_mut(&self, input: &mut [Mersenne31; 16]) {
198        *input = self.permute(*input);
199    }
200}
201impl MdsPermutation<Mersenne31, 16> for MdsMatrixMersenne31 {}
202
203#[rustfmt::skip]
204const MATRIX_CIRC_MDS_32_MERSENNE31_ROW: [i64; 32] = [
205    0x1896DC78, 0x559D1E29, 0x04EBD732, 0x3FF449D7,
206    0x2DB0E2CE, 0x26776B85, 0x76018E57, 0x1025FA13,
207    0x06486BAB, 0x37706EBA, 0x25EB966B, 0x113C24E5,
208    0x2AE20EC4, 0x5A27507C, 0x0CD38CF1, 0x761C10E5,
209    0x19E3EF1A, 0x032C730F, 0x35D8AF83, 0x651DF13B,
210    0x7EC3DB1A, 0x6A146994, 0x588F9145, 0x09B79455,
211    0x7FDA05EC, 0x19FE71A8, 0x6988947A, 0x624F1D31,
212    0x500BB628, 0x0B1428CE, 0x3A62E1D6, 0x77692387
213];
214
215impl Permutation<[Mersenne31; 32]> for MdsMatrixMersenne31 {
216    fn permute(&self, input: [Mersenne31; 32]) -> [Mersenne31; 32] {
217        const MATRIX_CIRC_MDS_32_MERSENNE31_COL: [i64; 32] =
218            first_row_to_first_col(&MATRIX_CIRC_MDS_32_MERSENNE31_ROW);
219        LargeConvolveMersenne31::apply(
220            input,
221            MATRIX_CIRC_MDS_32_MERSENNE31_COL,
222            LargeConvolveMersenne31::conv32,
223        )
224    }
225
226    fn permute_mut(&self, input: &mut [Mersenne31; 32]) {
227        *input = self.permute(*input);
228    }
229}
230impl MdsPermutation<Mersenne31, 32> for MdsMatrixMersenne31 {}
231
232#[rustfmt::skip]
233const MATRIX_CIRC_MDS_64_MERSENNE31_ROW: [i64; 64] = [
234    0x570227A5, 0x3702983F, 0x4B7B3B0A, 0x74F13DE3,
235    0x485314B0, 0x0157E2EC, 0x1AD2E5DE, 0x721515E3,
236    0x5452ADA3, 0x0C74B6C1, 0x67DA9450, 0x33A48369,
237    0x3BDBEE06, 0x7C678D5E, 0x160F16D3, 0x54888B8C,
238    0x666C7AA6, 0x113B89E2, 0x2A403CE2, 0x18F9DF42,
239    0x2A685E84, 0x49EEFDE5, 0x5D044806, 0x560A41F8,
240    0x69EF1BD0, 0x2CD15786, 0x62E07766, 0x22A231E2,
241    0x3CFCF40C, 0x4E8F63D8, 0x69657A15, 0x466B4B2D,
242    0x4194B4D2, 0x1E9A85EA, 0x39709C27, 0x4B030BF3,
243    0x655DCE1D, 0x251F8899, 0x5B2EA879, 0x1E10E42F,
244    0x31F5BE07, 0x2AFBB7F9, 0x3E11021A, 0x5D97A17B,
245    0x6F0620BD, 0x5DBFC31D, 0x76C4761D, 0x21938559,
246    0x33777473, 0x71F0E92C, 0x0B9872A1, 0x4C2411F9,
247    0x545B7C96, 0x20256BAF, 0x7B8B493E, 0x33AD525C,
248    0x15EAEA1C, 0x6D2D1A21, 0x06A81D14, 0x3FACEB4F,
249    0x130EC21C, 0x3C84C4F5, 0x50FD67C0, 0x30FDD85A,
250];
251
252impl Permutation<[Mersenne31; 64]> for MdsMatrixMersenne31 {
253    fn permute(&self, input: [Mersenne31; 64]) -> [Mersenne31; 64] {
254        const MATRIX_CIRC_MDS_64_MERSENNE31_COL: [i64; 64] =
255            first_row_to_first_col(&MATRIX_CIRC_MDS_64_MERSENNE31_ROW);
256        LargeConvolveMersenne31::apply(
257            input,
258            MATRIX_CIRC_MDS_64_MERSENNE31_COL,
259            LargeConvolveMersenne31::conv64,
260        )
261    }
262
263    fn permute_mut(&self, input: &mut [Mersenne31; 64]) {
264        *input = self.permute(*input);
265    }
266}
267impl MdsPermutation<Mersenne31, 64> for MdsMatrixMersenne31 {}
268
#[cfg(test)]
mod tests {
    use p3_symmetric::Permutation;

    use super::{MdsMatrixMersenne31, Mersenne31};

    /// Apply the width-`N` MDS permutation to `input` and assert that the
    /// result equals `expected`.
    #[track_caller]
    fn assert_permutes_to<const N: usize>(input: [Mersenne31; N], expected: [Mersenne31; N])
    where
        MdsMatrixMersenne31: Permutation<[Mersenne31; N]>,
    {
        let actual = MdsMatrixMersenne31.permute(input);
        assert_eq!(actual, expected);
    }

    #[test]
    fn mersenne8() {
        assert_permutes_to(
            Mersenne31::new_array([
                1741044457, 327154658, 318297696, 1528828225, 468360260, 1271368222, 1906288587,
                1521884224,
            ]),
            Mersenne31::new_array([
                895992680, 1343855369, 2107796831, 266468728, 846686506, 252887121, 205223309,
                260248790,
            ]),
        );
    }

    #[test]
    fn mersenne12() {
        assert_permutes_to(
            Mersenne31::new_array([
                1232740094, 661555540, 11024822, 1620264994, 471137070, 276755041, 1316882747,
                1023679816, 1675266989, 743211887, 44774582, 1990989306,
            ]),
            Mersenne31::new_array([
                860812289, 399778981, 1228500858, 798196553, 673507779, 1116345060, 829764188,
                138346433, 578243475, 553581995, 578183208, 1527769050,
            ]),
        );
    }

    #[test]
    fn mersenne16() {
        assert_permutes_to(
            Mersenne31::new_array([
                1431168444, 963811518, 88067321, 381314132, 908628282, 1260098295, 980207659,
                150070493, 357706876, 2014609375, 387876458, 1621671571, 183146044, 107201572,
                166536524, 2078440788,
            ]),
            Mersenne31::new_array([
                1858869691, 1607793806, 1200396641, 1400502985, 1511630695, 187938132, 1332411488,
                2041577083, 2014246632, 802022141, 796807132, 1647212930, 813167618, 1867105010,
                508596277, 1457551581,
            ]),
        );
    }

    #[test]
    fn mersenne32() {
        assert_permutes_to(
            Mersenne31::new_array([
                873912014, 1112497426, 300405095, 4255553, 1234979949, 156402357, 1952135954,
                718195399, 1041748465, 683604342, 184275751, 1184118518, 214257054, 1293941921,
                64085758, 710448062, 1133100009, 350114887, 1091675272, 671421879, 1226105999,
                546430131, 1298443967, 1787169653, 2129310791, 1560307302, 471771931, 1191484402,
                1550203198, 1541319048, 229197040, 839673789,
            ]),
            Mersenne31::new_array([
                1439049928, 890642852, 694402307, 713403244, 553213342, 1049445650, 321709533,
                1195683415, 2118492257, 623077773, 96734062, 990488164, 1674607608, 749155000,
                353377854, 966432998, 1114654884, 1370359248, 1624965859, 685087760, 1631836645,
                1615931812, 2061986317, 1773551151, 1449911206, 1951762557, 545742785, 582866449,
                1379774336, 229242759, 1871227547, 752848413,
            ]),
        );
    }

    #[test]
    fn mersenne64() {
        assert_permutes_to(
            Mersenne31::new_array([
                837269696, 1509031194, 413915480, 1889329185, 315502822, 1529162228, 1454661012,
                1015826742, 973381409, 1414676304, 1449029961, 1968715566, 2027226497, 1721820509,
                434042616, 1436005045, 1680352863, 651591867, 260585272, 1078022153, 703990572,
                269504423, 1776357592, 1174979337, 1142666094, 1897872960, 1387995838, 250774418,
                776134750, 73930096, 194742451, 1860060380, 666407744, 669566398, 963802147,
                2063418105, 1772573581, 998923482, 701912753, 1716548204, 860820931, 1680395948,
                949886256, 1811558161, 501734557, 1671977429, 463135040, 1911493108, 207754409,
                608714758, 1553060084, 1558941605, 980281686, 2014426559, 650527801, 53015148,
                1521176057, 720530872, 713593252, 88228433, 1194162313, 1922416934, 1075145779,
                344403794,
            ]),
            Mersenne31::new_array([
                1599981950, 252630853, 1171557270, 116468420, 1269245345, 666203050, 46155642,
                1701131520, 530845775, 508460407, 630407239, 1731628135, 1199144768, 295132047,
                77536342, 1472377703, 30752443, 1300339617, 18647556, 1267774380, 1194573079,
                1624665024, 646848056, 1667216490, 1184843555, 1250329476, 254171597, 1902035936,
                1706882202, 964921003, 952266538, 1215696284, 539510504, 1056507562, 1393151480,
                733644883, 1663330816, 1100715048, 991108703, 1671345065, 1376431774, 408310416,
                313176996, 743567676, 304660642, 1842695838, 958201635, 1650792218, 541570244,
                968523062, 1958918704, 1866282698, 849808680, 1193306222, 794153281, 822835360,
                135282913, 1149868448, 2068162123, 1474283743, 2039088058, 720305835, 746036736,
                671006610,
            ]),
        );
    }
}