Skip to main content

p3_mersenne_31/
mds.rs

1//! MDS matrices over the Mersenne31 field, and permutations defined by them.
2//!
3//! NB: Not all sizes have fast implementations of their permutations.
4//! Supported sizes: 8, 12, 16, 32, 64.
5//! Sizes 8 and 12 are from Plonky2, size 16 was found as part of concurrent
6//! work by Angus Gruen and Hamish Ivey-Law. Other sizes are from Ulrich Haböck's
7//! database.
8
9use p3_field::PrimeCharacteristicRing;
10use p3_field::integers::QuotientMap;
11use p3_mds::MdsPermutation;
12use p3_mds::karatsuba_convolution::Convolve;
13use p3_mds::util::{dot_product, first_row_to_first_col};
14use p3_symmetric::Permutation;
15
16use crate::Mersenne31;
17
/// Circulant MDS matrices over [`Mersenne31`].
///
/// This zero-sized type implements `Permutation` (and `MdsPermutation`) for
/// state arrays of width 8, 12, 16, 32 and 64; each width multiplies the
/// input by a fixed circulant matrix using a Karatsuba-style convolution.
#[derive(Debug, Default, Clone)]
pub struct MdsMatrixMersenne31;
20
21/// Instantiate convolution for "small" RHS vectors over Mersenne31.
22///
23/// Here "small" means N = len(rhs) <= 16 and sum(r for r in rhs) <
24/// 2^24 (roughly), though in practice the sum will be less than 2^9.
25struct SmallConvolveMersenne31;
26impl Convolve<Mersenne31, i64, i64> for SmallConvolveMersenne31 {
27    const T_ZERO: i64 = 0;
28    const U_ZERO: i64 = 0;
29
30    #[inline(always)]
31    fn halve(val: i64) -> i64 {
32        val >> 1
33    }
34
35    /// Return the lift of an (almost) reduced Mersenne31 element.
36    /// The Mersenne31 implementation guarantees that
37    /// 0 <= input.value <= P < 2^31.
38    #[inline(always)]
39    fn read(input: Mersenne31) -> i64 {
40        input.value as i64
41    }
42
43    /// For a convolution of size N, |x| < N * 2^31 and (as per the
44    /// assumption above), |y| < 2^24. So the product is at most N * 2^55
45    /// which will not overflow for N <= 16.
46    #[inline(always)]
47    fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
48        dot_product(u, v)
49    }
50
51    /// The assumptions above mean z < N^2 * 2^55, which is at most
52    /// 2^63 when N <= 16.
53    ///
54    /// NB: Even though intermediate values could be negative, the
55    /// output must be non-negative since the inputs were
56    /// non-negative.
57    #[inline(always)]
58    fn reduce(z: i64) -> Mersenne31 {
59        debug_assert!(z >= 0);
60        Mersenne31::from_u64(z as u64)
61    }
62}
63
64/// Instantiate convolution for "large" RHS vectors over Mersenne31.
65///
66/// Here "large" means the elements can be as big as the field
67/// characteristic, and the size N of the RHS is <= 64.
68struct LargeConvolveMersenne31;
69impl Convolve<Mersenne31, i64, i64> for LargeConvolveMersenne31 {
70    const T_ZERO: i64 = 0;
71    const U_ZERO: i64 = 0;
72
73    #[inline(always)]
74    fn halve(val: i64) -> i64 {
75        val >> 1
76    }
77
78    /// Return the lift of an (almost) reduced Mersenne31 element.
79    /// The Mersenne31 implementation guarantees that
80    /// 0 <= input.value <= P < 2^31.
81    #[inline(always)]
82    fn read(input: Mersenne31) -> i64 {
83        input.value as i64
84    }
85
86    #[inline]
87    fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
88        // For a convolution of size N, |x|, |y| < N * 2^31, so the product
89        // could be as much as N^2 * 2^62. This will overflow an i64, so
90        // we first widen to i128.
91
92        let mut dp = 0i128;
93        for i in 0..N {
94            dp += u[i] as i128 * v[i] as i128;
95        }
96
97        const LOWMASK: i128 = (1 << 42) - 1; // Gets the bits lower than 42.
98        const HIGHMASK: i128 = !LOWMASK; // Gets all bits higher than 42.
99
100        let low_bits = (dp & LOWMASK) as i64; // low_bits < 2**42
101        let high_bits = ((dp & HIGHMASK) >> 31) as i64; // |high_bits| < 2**(n - 31)
102
103        // Proof that low_bits + high_bits is congruent to dp (mod p)
104        // and congruent to dp (mod 2^11):
105        //
106        // The individual bounds clearly show that low_bits +
107        // high_bits < 2**(n - 30).
108        //
109        // Next observe that low_bits + high_bits = input - (2**31 -
110        // 1) * (high_bits) = input mod P.
111        //
112        // Finally note that 2**11 divides high_bits and so low_bits +
113        // high_bits = low_bits mod 2**11 = input mod 2**11.
114
115        low_bits + high_bits
116    }
117
118    #[inline]
119    fn reduce(z: i64) -> Mersenne31 {
120        // After the dot product, the maximal size is N^2 * 2^62 < 2^74
121        // as N = 64 is the biggest size. So, after the partial
122        // reduction, the output z of parity dot satisfies |z| < 2^44
123        // (Where 44 is 74 - 30).
124        //
125        // In the recombining steps, conv maps (wo, w1) -> ((wo + w1)/2,
126        // (wo + w1)/2) which has no effect on the maximal size. (Indeed,
127        // it makes sizes almost strictly smaller).
128        //
129        // On the other hand, negacyclic_conv (ignoring the re-index)
130        // recombines as: (w0, w1, w2) -> (w0 + w1, w2 - w0 - w1). Hence
131        // if the input is <= K, the output is <= 3K.
132        //
133        // Thus the values appearing at the end are bounded by 3^n 2^44
134        // where n is the maximal number of negacyclic_conv recombination
135        // steps. When N = 64, we need to recombine for singed_conv_32,
136        // singed_conv_16, singed_conv_8 so the overall bound will be 3^3
137        // 2^44 < 32 * 2^44 < 2^49.
138        debug_assert!(z > -(1i64 << 49));
139        debug_assert!(z < (1i64 << 49));
140
141        const MASK: i64 = (1 << 31) - 1;
142        // Morally, our value is a i62 not a i64 as the top 3 bits are
143        // guaranteed to be equal.
144        let low_bits = unsafe {
145            // This is safe as 0 <= z & MASK < 2^31
146            Mersenne31::from_canonical_unchecked((z & MASK) as u32)
147        };
148
149        let high_bits = ((z >> 31) & MASK) as i32;
150        let sign_bits = (z >> 62) as i32;
151
152        let high = unsafe {
153            // This is safe as high_bits + sign_bits > 0 as by assumption b[63] = b[61].
154            Mersenne31::from_canonical_unchecked((high_bits + sign_bits) as u32)
155        };
156        low_bits + high
157    }
158}
159
/// First row of the width-8 circulant MDS matrix (small coefficients,
/// taken from Plonky2).
#[rustfmt::skip]
const MATRIX_CIRC_MDS_8_SML_ROW: [i64; 8] = [
    7, 1, 3, 8,
    8, 3, 4, 9,
];
161
162impl Permutation<[Mersenne31; 8]> for MdsMatrixMersenne31 {
163    fn permute(&self, input: [Mersenne31; 8]) -> [Mersenne31; 8] {
164        const MATRIX_CIRC_MDS_8_SML_COL: [i64; 8] =
165            first_row_to_first_col(&MATRIX_CIRC_MDS_8_SML_ROW);
166        SmallConvolveMersenne31::apply(
167            input,
168            MATRIX_CIRC_MDS_8_SML_COL,
169            SmallConvolveMersenne31::conv8,
170        )
171    }
172}
173impl MdsPermutation<Mersenne31, 8> for MdsMatrixMersenne31 {}
174
/// First row of the width-12 circulant MDS matrix (small coefficients,
/// taken from Plonky2).
#[rustfmt::skip]
const MATRIX_CIRC_MDS_12_SML_ROW: [i64; 12] = [
    1, 1, 2, 1, 8, 9,
    10, 7, 5, 9, 4, 10,
];
176
177impl Permutation<[Mersenne31; 12]> for MdsMatrixMersenne31 {
178    fn permute(&self, input: [Mersenne31; 12]) -> [Mersenne31; 12] {
179        const MATRIX_CIRC_MDS_12_SML_COL: [i64; 12] =
180            first_row_to_first_col(&MATRIX_CIRC_MDS_12_SML_ROW);
181        SmallConvolveMersenne31::apply(
182            input,
183            MATRIX_CIRC_MDS_12_SML_COL,
184            SmallConvolveMersenne31::conv12,
185        )
186    }
187}
188impl MdsPermutation<Mersenne31, 12> for MdsMatrixMersenne31 {}
189
/// First row of the width-16 circulant MDS matrix (small coefficients, found
/// by Angus Gruen and Hamish Ivey-Law). The entries sum to 371 < 2^9.
#[rustfmt::skip]
const MATRIX_CIRC_MDS_16_SML_ROW: [i64; 16] = [
    1, 1, 51, 1, 11, 17, 2, 1,
    101, 63, 15, 2, 67, 22, 13, 3,
];
192
193impl Permutation<[Mersenne31; 16]> for MdsMatrixMersenne31 {
194    fn permute(&self, input: [Mersenne31; 16]) -> [Mersenne31; 16] {
195        const MATRIX_CIRC_MDS_16_SML_COL: [i64; 16] =
196            first_row_to_first_col(&MATRIX_CIRC_MDS_16_SML_ROW);
197        SmallConvolveMersenne31::apply(
198            input,
199            MATRIX_CIRC_MDS_16_SML_COL,
200            SmallConvolveMersenne31::conv16,
201        )
202    }
203}
204impl MdsPermutation<Mersenne31, 16> for MdsMatrixMersenne31 {}
205
/// First row of the width-32 circulant MDS matrix (from Ulrich Haböck's
/// database). The entries are full-size Mersenne31 values below 2^31, so this
/// matrix uses the "large" convolution.
#[rustfmt::skip]
const MATRIX_CIRC_MDS_32_MERSENNE31_ROW: [i64; 32] = [
    0x1896DC78, 0x559D1E29, 0x04EBD732, 0x3FF449D7, 0x2DB0E2CE, 0x26776B85, 0x76018E57, 0x1025FA13,
    0x06486BAB, 0x37706EBA, 0x25EB966B, 0x113C24E5, 0x2AE20EC4, 0x5A27507C, 0x0CD38CF1, 0x761C10E5,
    0x19E3EF1A, 0x032C730F, 0x35D8AF83, 0x651DF13B, 0x7EC3DB1A, 0x6A146994, 0x588F9145, 0x09B79455,
    0x7FDA05EC, 0x19FE71A8, 0x6988947A, 0x624F1D31, 0x500BB628, 0x0B1428CE, 0x3A62E1D6, 0x77692387,
];
217
218impl Permutation<[Mersenne31; 32]> for MdsMatrixMersenne31 {
219    fn permute(&self, input: [Mersenne31; 32]) -> [Mersenne31; 32] {
220        const MATRIX_CIRC_MDS_32_MERSENNE31_COL: [i64; 32] =
221            first_row_to_first_col(&MATRIX_CIRC_MDS_32_MERSENNE31_ROW);
222        LargeConvolveMersenne31::apply(
223            input,
224            MATRIX_CIRC_MDS_32_MERSENNE31_COL,
225            LargeConvolveMersenne31::conv32,
226        )
227    }
228}
229impl MdsPermutation<Mersenne31, 32> for MdsMatrixMersenne31 {}
230
/// First row of the width-64 circulant MDS matrix (from Ulrich Haböck's
/// database). The entries are full-size Mersenne31 values below 2^31, so this
/// matrix uses the "large" convolution.
#[rustfmt::skip]
const MATRIX_CIRC_MDS_64_MERSENNE31_ROW: [i64; 64] = [
    0x570227A5, 0x3702983F, 0x4B7B3B0A, 0x74F13DE3, 0x485314B0, 0x0157E2EC, 0x1AD2E5DE, 0x721515E3,
    0x5452ADA3, 0x0C74B6C1, 0x67DA9450, 0x33A48369, 0x3BDBEE06, 0x7C678D5E, 0x160F16D3, 0x54888B8C,
    0x666C7AA6, 0x113B89E2, 0x2A403CE2, 0x18F9DF42, 0x2A685E84, 0x49EEFDE5, 0x5D044806, 0x560A41F8,
    0x69EF1BD0, 0x2CD15786, 0x62E07766, 0x22A231E2, 0x3CFCF40C, 0x4E8F63D8, 0x69657A15, 0x466B4B2D,
    0x4194B4D2, 0x1E9A85EA, 0x39709C27, 0x4B030BF3, 0x655DCE1D, 0x251F8899, 0x5B2EA879, 0x1E10E42F,
    0x31F5BE07, 0x2AFBB7F9, 0x3E11021A, 0x5D97A17B, 0x6F0620BD, 0x5DBFC31D, 0x76C4761D, 0x21938559,
    0x33777473, 0x71F0E92C, 0x0B9872A1, 0x4C2411F9, 0x545B7C96, 0x20256BAF, 0x7B8B493E, 0x33AD525C,
    0x15EAEA1C, 0x6D2D1A21, 0x06A81D14, 0x3FACEB4F, 0x130EC21C, 0x3C84C4F5, 0x50FD67C0, 0x30FDD85A,
];
250
251impl Permutation<[Mersenne31; 64]> for MdsMatrixMersenne31 {
252    fn permute(&self, input: [Mersenne31; 64]) -> [Mersenne31; 64] {
253        const MATRIX_CIRC_MDS_64_MERSENNE31_COL: [i64; 64] =
254            first_row_to_first_col(&MATRIX_CIRC_MDS_64_MERSENNE31_ROW);
255        LargeConvolveMersenne31::apply(
256            input,
257            MATRIX_CIRC_MDS_64_MERSENNE31_COL,
258            LargeConvolveMersenne31::conv64,
259        )
260    }
261}
262impl MdsPermutation<Mersenne31, 64> for MdsMatrixMersenne31 {}
263
// Known-answer tests for `MdsMatrixMersenne31`. Each test feeds a fixed input
// vector through the permutation of the matching width. It then compares the
// result against a hard-coded expected output vector.
#[cfg(test)]
mod tests {
    use p3_symmetric::Permutation;

    use super::{MdsMatrixMersenne31, Mersenne31};

    // Width 8: exercises the small-coefficient (`SmallConvolveMersenne31`) path.
    #[test]
    fn mersenne8() {
        let input: [Mersenne31; 8] = Mersenne31::new_array([
            1741044457, 327154658, 318297696, 1528828225, 468360260, 1271368222, 1906288587,
            1521884224,
        ]);

        let output = MdsMatrixMersenne31.permute(input);

        let expected: [Mersenne31; 8] = Mersenne31::new_array([
            895992680, 1343855369, 2107796831, 266468728, 846686506, 252887121, 205223309,
            260248790,
        ]);

        assert_eq!(output, expected);
    }

    // Width 12: exercises the small-coefficient (`SmallConvolveMersenne31`) path.
    #[test]
    fn mersenne12() {
        let input: [Mersenne31; 12] = Mersenne31::new_array([
            1232740094, 661555540, 11024822, 1620264994, 471137070, 276755041, 1316882747,
            1023679816, 1675266989, 743211887, 44774582, 1990989306,
        ]);

        let output = MdsMatrixMersenne31.permute(input);

        let expected: [Mersenne31; 12] = Mersenne31::new_array([
            860812289, 399778981, 1228500858, 798196553, 673507779, 1116345060, 829764188,
            138346433, 578243475, 553581995, 578183208, 1527769050,
        ]);

        assert_eq!(output, expected);
    }

    // Width 16: exercises the small-coefficient (`SmallConvolveMersenne31`) path.
    #[test]
    fn mersenne16() {
        let input: [Mersenne31; 16] = Mersenne31::new_array([
            1431168444, 963811518, 88067321, 381314132, 908628282, 1260098295, 980207659,
            150070493, 357706876, 2014609375, 387876458, 1621671571, 183146044, 107201572,
            166536524, 2078440788,
        ]);

        let output = MdsMatrixMersenne31.permute(input);

        let expected: [Mersenne31; 16] = Mersenne31::new_array([
            1858869691, 1607793806, 1200396641, 1400502985, 1511630695, 187938132, 1332411488,
            2041577083, 2014246632, 802022141, 796807132, 1647212930, 813167618, 1867105010,
            508596277, 1457551581,
        ]);

        assert_eq!(output, expected);
    }

    // Width 32: exercises the large-coefficient (`LargeConvolveMersenne31`) path.
    #[test]
    fn mersenne32() {
        let input: [Mersenne31; 32] = Mersenne31::new_array([
            873912014, 1112497426, 300405095, 4255553, 1234979949, 156402357, 1952135954,
            718195399, 1041748465, 683604342, 184275751, 1184118518, 214257054, 1293941921,
            64085758, 710448062, 1133100009, 350114887, 1091675272, 671421879, 1226105999,
            546430131, 1298443967, 1787169653, 2129310791, 1560307302, 471771931, 1191484402,
            1550203198, 1541319048, 229197040, 839673789,
        ]);

        let output = MdsMatrixMersenne31.permute(input);

        let expected: [Mersenne31; 32] = Mersenne31::new_array([
            1439049928, 890642852, 694402307, 713403244, 553213342, 1049445650, 321709533,
            1195683415, 2118492257, 623077773, 96734062, 990488164, 1674607608, 749155000,
            353377854, 966432998, 1114654884, 1370359248, 1624965859, 685087760, 1631836645,
            1615931812, 2061986317, 1773551151, 1449911206, 1951762557, 545742785, 582866449,
            1379774336, 229242759, 1871227547, 752848413,
        ]);

        assert_eq!(output, expected);
    }

    // Width 64: exercises the large-coefficient (`LargeConvolveMersenne31`)
    // path at the largest supported size.
    #[test]
    fn mersenne64() {
        let input: [Mersenne31; 64] = Mersenne31::new_array([
            837269696, 1509031194, 413915480, 1889329185, 315502822, 1529162228, 1454661012,
            1015826742, 973381409, 1414676304, 1449029961, 1968715566, 2027226497, 1721820509,
            434042616, 1436005045, 1680352863, 651591867, 260585272, 1078022153, 703990572,
            269504423, 1776357592, 1174979337, 1142666094, 1897872960, 1387995838, 250774418,
            776134750, 73930096, 194742451, 1860060380, 666407744, 669566398, 963802147,
            2063418105, 1772573581, 998923482, 701912753, 1716548204, 860820931, 1680395948,
            949886256, 1811558161, 501734557, 1671977429, 463135040, 1911493108, 207754409,
            608714758, 1553060084, 1558941605, 980281686, 2014426559, 650527801, 53015148,
            1521176057, 720530872, 713593252, 88228433, 1194162313, 1922416934, 1075145779,
            344403794,
        ]);

        let output = MdsMatrixMersenne31.permute(input);

        let expected: [Mersenne31; 64] = Mersenne31::new_array([
            1599981950, 252630853, 1171557270, 116468420, 1269245345, 666203050, 46155642,
            1701131520, 530845775, 508460407, 630407239, 1731628135, 1199144768, 295132047,
            77536342, 1472377703, 30752443, 1300339617, 18647556, 1267774380, 1194573079,
            1624665024, 646848056, 1667216490, 1184843555, 1250329476, 254171597, 1902035936,
            1706882202, 964921003, 952266538, 1215696284, 539510504, 1056507562, 1393151480,
            733644883, 1663330816, 1100715048, 991108703, 1671345065, 1376431774, 408310416,
            313176996, 743567676, 304660642, 1842695838, 958201635, 1650792218, 541570244,
            968523062, 1958918704, 1866282698, 849808680, 1193306222, 794153281, 822835360,
            135282913, 1149868448, 2068162123, 1474283743, 2039088058, 720305835, 746036736,
            671006610,
        ]);

        assert_eq!(output, expected);
    }
}