p3_mersenne_31/
mds.rs

1//! MDS matrices over the Mersenne31 field, and permutations defined by them.
2//!
3//! NB: Not all sizes have fast implementations of their permutations.
4//! Supported sizes: 8, 12, 16, 32, 64.
5//! Sizes 8 and 12 are from Plonky2, size 16 was found as part of concurrent
6//! work by Angus Gruen and Hamish Ivey-Law. Other sizes are from Ulrich Haböck's
7//! database.
8
9use p3_field::PrimeCharacteristicRing;
10use p3_field::integers::QuotientMap;
11use p3_mds::MdsPermutation;
12use p3_mds::karatsuba_convolution::Convolve;
13use p3_mds::util::{dot_product, first_row_to_first_col};
14use p3_symmetric::Permutation;
15
16use crate::Mersenne31;
17
/// Marker type carrying the circulant MDS matrices over [`Mersenne31`].
///
/// It holds no data: each supported width (8, 12, 16, 32, 64) gets its own
/// `Permutation` implementation that multiplies the input by a fixed
/// circulant matrix.
#[derive(Default, Debug, Clone)]
pub struct MdsMatrixMersenne31;
20
21/// Instantiate convolution for "small" RHS vectors over Mersenne31.
22///
23/// Here "small" means N = len(rhs) <= 16 and sum(r for r in rhs) <
24/// 2^24 (roughly), though in practice the sum will be less than 2^9.
25struct SmallConvolveMersenne31;
26impl Convolve<Mersenne31, i64, i64, i64> for SmallConvolveMersenne31 {
27    /// Return the lift of an (almost) reduced Mersenne31 element.
28    /// The Mersenne31 implementation guarantees that
29    /// 0 <= input.value <= P < 2^31.
30    #[inline(always)]
31    fn read(input: Mersenne31) -> i64 {
32        input.value as i64
33    }
34
35    /// FIXME: Refactor the dot product
36    /// For a convolution of size N, |x| < N * 2^31 and (as per the
37    /// assumption above), |y| < 2^24. So the product is at most N * 2^55
38    /// which will not overflow for N <= 16.
39    #[inline(always)]
40    fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
41        dot_product(u, v)
42    }
43
44    /// The assumptions above mean z < N^2 * 2^55, which is at most
45    /// 2^63 when N <= 16.
46    ///
47    /// NB: Even though intermediate values could be negative, the
48    /// output must be non-negative since the inputs were
49    /// non-negative.
50    #[inline(always)]
51    fn reduce(z: i64) -> Mersenne31 {
52        debug_assert!(z >= 0);
53        Mersenne31::from_u64(z as u64)
54    }
55}
56
/// Instantiate convolution for "large" RHS vectors over Mersenne31.
///
/// Here "large" means the elements can be as big as the field
/// characteristic, and the size N of the RHS is <= 64. Dot products are
/// accumulated in `i128`, then partially reduced back into an `i64` that
/// stays congruent to the exact value modulo both P and 2^11. `reduce`
/// performs the final reduction into the field.
struct LargeConvolveMersenne31;
impl Convolve<Mersenne31, i64, i64, i64> for LargeConvolveMersenne31 {
    /// Return the lift of an (almost) reduced Mersenne31 element.
    /// The Mersenne31 implementation guarantees that
    /// 0 <= input.value <= P < 2^31.
    #[inline(always)]
    fn read(input: Mersenne31) -> i64 {
        input.value as i64
    }

    /// Dot product of `u` and `v`, partially reduced to fit in an `i64`.
    ///
    /// The returned value is congruent to the exact dot product modulo
    /// P = 2^31 - 1 AND modulo 2^11. Keeping the low 11 bits intact matters
    /// because the recombination steps described in `reduce` halve sums of
    /// these values. The halving is only exact if the parity of the true
    /// value is preserved.
    #[inline]
    fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
        // For a convolution of size N, |x|, |y| < N * 2^31, so the product
        // could be as much as N^2 * 2^62. This will overflow an i64, so
        // we first widen to i128.

        let mut dp = 0i128;
        for i in 0..N {
            dp += u[i] as i128 * v[i] as i128;
        }

        // Selects bits 0..42 (the low 42 bits).
        const LOWMASK: i128 = (1 << 42) - 1;
        // Selects bits 42 and above, including the sign bits.
        const HIGHMASK: i128 = !LOWMASK;

        // 0 <= low_bits < 2^42.
        let low_bits = (dp & LOWMASK) as i64;
        // `>>` on i128 is an arithmetic shift, so the sign of dp is kept.
        // Writing 2^n for a bound on |dp|, |high_bits| < 2^(n - 31).
        let high_bits = ((dp & HIGHMASK) >> 31) as i64;

        // Proof that low_bits + high_bits is congruent to dp (mod P)
        // and congruent to dp (mod 2^11):
        //
        // dp & HIGHMASK = dp - low_bits is a multiple of 2^42, so the shift
        // by 31 discards only zero bits and dp = low_bits + 2^31 * high_bits
        // exactly. Hence
        //     low_bits + high_bits = dp - (2^31 - 1) * high_bits,
        // which is congruent to dp mod P.
        //
        // high_bits is a multiple of 2^(42 - 31) = 2^11, so
        //     low_bits + high_bits = low_bits = dp (mod 2^11),
        // where the last step uses that 2^11 divides 2^42.
        //
        // Finally, |low_bits + high_bits| < 2^42 + 2^(n - 31). For example,
        // n = 74 gives a result below 2^44 (see `reduce`).

        low_bits + high_bits
    }

    /// Final reduction of a recombined value `z` into the field.
    ///
    /// Expects |z| < 2^49 (checked in debug builds; see the derivation below).
    #[inline]
    fn reduce(z: i64) -> Mersenne31 {
        // After the dot product, the maximal size is N^2 * 2^62 < 2^74
        // as N = 64 is the biggest size. So, after the partial
        // reduction, the output z of parity_dot satisfies
        // |z| < 2^42 + 2^43 < 2^44.
        //
        // In the recombining steps, conv maps (w0, w1) to a pair of halved
        // sums/differences (presumably ((w0 + w1)/2, (w0 - w1)/2)), which
        // does not increase the maximal size. Indeed, it makes sizes almost
        // strictly smaller.
        //
        // On the other hand, negacyclic_conv (ignoring the re-index)
        // recombines as: (w0, w1, w2) -> (w0 + w1, w2 - w0 - w1). Hence
        // if the input is <= K, the output is <= 3K.
        //
        // Thus the values appearing at the end are bounded by 3^n * 2^44,
        // where n is the maximal number of negacyclic_conv recombination
        // steps. When N = 64, we need to recombine for the negacyclic
        // convolutions of sizes 32, 16 and 8, so the overall bound is
        // 3^3 * 2^44 = 27 * 2^44 < 32 * 2^44 = 2^49.
        debug_assert!(z > -(1i64 << 49));
        debug_assert!(z < (1i64 << 49));

        const MASK: i64 = (1 << 31) - 1;
        // Decompose z = low + 2^31 * high_bits + 2^62 * sign_bits, where low
        // is bits 0..=30, high_bits is bits 31..=61 and sign_bits = z >> 62
        // (arithmetic). This identity holds for every i64. Because
        // 2^31 = 2^62 = 1 (mod P), it gives
        //     z = low + high_bits + sign_bits (mod P).
        // The bound |z| < 2^49 means bits 49..=63 all equal the sign bit.
        // In particular sign_bits is 0 or -1.
        let low_bits = unsafe {
            // SAFETY: 0 <= z & MASK <= 2^31 - 1, so the value fits in 31 bits.
            // NOTE(review): the value can equal P itself. This relies on
            // Mersenne31 tolerating the representative P, as `read`'s
            // documented range 0 <= value <= P suggests.
            Mersenne31::from_canonical_unchecked((z & MASK) as u32)
        };

        let high_bits = ((z >> 31) & MASK) as i32;
        let sign_bits = (z >> 62) as i32;

        let high = unsafe {
            // SAFETY: 0 <= high_bits + sign_bits <= 2^31 - 1.
            // - If z >= 0: sign_bits = 0 and 0 <= high_bits < 2^18. The sum
            //   can be exactly 0 (when z < 2^31), so the bound is ">= 0".
            // - If z < 0: sign_bits = -1 and, since z > -2^49, bits 49..=61
            //   of z are all set, so high_bits >= 2^18. The sum is
            //   therefore positive and at most 2^31 - 2.
            Mersenne31::from_canonical_unchecked((high_bits + sign_bits) as u32)
        };
        low_bits + high
    }
}
144
/// First row of the 8x8 circulant MDS matrix (from Plonky2, per the module
/// docs). The entries are tiny, so the small-entry convolution applies.
const MATRIX_CIRC_MDS_8_SML_ROW: [i64; 8] = [7, 1, 3, 8, 8, 3, 4, 9];
146
147impl Permutation<[Mersenne31; 8]> for MdsMatrixMersenne31 {
148    fn permute(&self, input: [Mersenne31; 8]) -> [Mersenne31; 8] {
149        const MATRIX_CIRC_MDS_8_SML_COL: [i64; 8] =
150            first_row_to_first_col(&MATRIX_CIRC_MDS_8_SML_ROW);
151        SmallConvolveMersenne31::apply(
152            input,
153            MATRIX_CIRC_MDS_8_SML_COL,
154            SmallConvolveMersenne31::conv8,
155        )
156    }
157}
158impl MdsPermutation<Mersenne31, 8> for MdsMatrixMersenne31 {}
159
/// First row of the 12x12 circulant MDS matrix (from Plonky2, per the module
/// docs). The entries are tiny, so the small-entry convolution applies.
const MATRIX_CIRC_MDS_12_SML_ROW: [i64; 12] = [1, 1, 2, 1, 8, 9, 10, 7, 5, 9, 4, 10];
161
162impl Permutation<[Mersenne31; 12]> for MdsMatrixMersenne31 {
163    fn permute(&self, input: [Mersenne31; 12]) -> [Mersenne31; 12] {
164        const MATRIX_CIRC_MDS_12_SML_COL: [i64; 12] =
165            first_row_to_first_col(&MATRIX_CIRC_MDS_12_SML_ROW);
166        SmallConvolveMersenne31::apply(
167            input,
168            MATRIX_CIRC_MDS_12_SML_COL,
169            SmallConvolveMersenne31::conv12,
170        )
171    }
172}
173impl MdsPermutation<Mersenne31, 12> for MdsMatrixMersenne31 {}
174
/// First row of the 16x16 circulant MDS matrix (found by Angus Gruen and
/// Hamish Ivey-Law, per the module docs). The entries sum to 371 < 2^9, well
/// inside the small-entry convolution's bounds.
const MATRIX_CIRC_MDS_16_SML_ROW: [i64; 16] =
    [1, 1, 51, 1, 11, 17, 2, 1, 101, 63, 15, 2, 67, 22, 13, 3];
177
178impl Permutation<[Mersenne31; 16]> for MdsMatrixMersenne31 {
179    fn permute(&self, input: [Mersenne31; 16]) -> [Mersenne31; 16] {
180        const MATRIX_CIRC_MDS_16_SML_COL: [i64; 16] =
181            first_row_to_first_col(&MATRIX_CIRC_MDS_16_SML_ROW);
182        SmallConvolveMersenne31::apply(
183            input,
184            MATRIX_CIRC_MDS_16_SML_COL,
185            SmallConvolveMersenne31::conv16,
186        )
187    }
188}
189impl MdsPermutation<Mersenne31, 16> for MdsMatrixMersenne31 {}
190
/// First row of the 32x32 circulant MDS matrix (from Ulrich Haböck's
/// database, per the module docs). Entries range up to just below 2^31, so
/// this matrix needs the large-entry convolution.
#[rustfmt::skip]
const MATRIX_CIRC_MDS_32_MERSENNE31_ROW: [i64; 32] = [
    0x1896DC78, 0x559D1E29, 0x04EBD732, 0x3FF449D7,
    0x2DB0E2CE, 0x26776B85, 0x76018E57, 0x1025FA13,
    0x06486BAB, 0x37706EBA, 0x25EB966B, 0x113C24E5,
    0x2AE20EC4, 0x5A27507C, 0x0CD38CF1, 0x761C10E5,
    0x19E3EF1A, 0x032C730F, 0x35D8AF83, 0x651DF13B,
    0x7EC3DB1A, 0x6A146994, 0x588F9145, 0x09B79455,
    0x7FDA05EC, 0x19FE71A8, 0x6988947A, 0x624F1D31,
    0x500BB628, 0x0B1428CE, 0x3A62E1D6, 0x77692387
];
202
203impl Permutation<[Mersenne31; 32]> for MdsMatrixMersenne31 {
204    fn permute(&self, input: [Mersenne31; 32]) -> [Mersenne31; 32] {
205        const MATRIX_CIRC_MDS_32_MERSENNE31_COL: [i64; 32] =
206            first_row_to_first_col(&MATRIX_CIRC_MDS_32_MERSENNE31_ROW);
207        LargeConvolveMersenne31::apply(
208            input,
209            MATRIX_CIRC_MDS_32_MERSENNE31_COL,
210            LargeConvolveMersenne31::conv32,
211        )
212    }
213}
214impl MdsPermutation<Mersenne31, 32> for MdsMatrixMersenne31 {}
215
/// First row of the 64x64 circulant MDS matrix (from Ulrich Haböck's
/// database, per the module docs). Entries range up to just below 2^31, so
/// this matrix needs the large-entry convolution.
#[rustfmt::skip]
const MATRIX_CIRC_MDS_64_MERSENNE31_ROW: [i64; 64] = [
    0x570227A5, 0x3702983F, 0x4B7B3B0A, 0x74F13DE3,
    0x485314B0, 0x0157E2EC, 0x1AD2E5DE, 0x721515E3,
    0x5452ADA3, 0x0C74B6C1, 0x67DA9450, 0x33A48369,
    0x3BDBEE06, 0x7C678D5E, 0x160F16D3, 0x54888B8C,
    0x666C7AA6, 0x113B89E2, 0x2A403CE2, 0x18F9DF42,
    0x2A685E84, 0x49EEFDE5, 0x5D044806, 0x560A41F8,
    0x69EF1BD0, 0x2CD15786, 0x62E07766, 0x22A231E2,
    0x3CFCF40C, 0x4E8F63D8, 0x69657A15, 0x466B4B2D,
    0x4194B4D2, 0x1E9A85EA, 0x39709C27, 0x4B030BF3,
    0x655DCE1D, 0x251F8899, 0x5B2EA879, 0x1E10E42F,
    0x31F5BE07, 0x2AFBB7F9, 0x3E11021A, 0x5D97A17B,
    0x6F0620BD, 0x5DBFC31D, 0x76C4761D, 0x21938559,
    0x33777473, 0x71F0E92C, 0x0B9872A1, 0x4C2411F9,
    0x545B7C96, 0x20256BAF, 0x7B8B493E, 0x33AD525C,
    0x15EAEA1C, 0x6D2D1A21, 0x06A81D14, 0x3FACEB4F,
    0x130EC21C, 0x3C84C4F5, 0x50FD67C0, 0x30FDD85A,
];
235
236impl Permutation<[Mersenne31; 64]> for MdsMatrixMersenne31 {
237    fn permute(&self, input: [Mersenne31; 64]) -> [Mersenne31; 64] {
238        const MATRIX_CIRC_MDS_64_MERSENNE31_COL: [i64; 64] =
239            first_row_to_first_col(&MATRIX_CIRC_MDS_64_MERSENNE31_ROW);
240        LargeConvolveMersenne31::apply(
241            input,
242            MATRIX_CIRC_MDS_64_MERSENNE31_COL,
243            LargeConvolveMersenne31::conv64,
244        )
245    }
246}
247impl MdsPermutation<Mersenne31, 64> for MdsMatrixMersenne31 {}
248
#[cfg(test)]
mod tests {
    use p3_symmetric::Permutation;

    // Brings in the matrix rows, `MdsMatrixMersenne31`, `Mersenne31` and the
    // field traits imported by the parent module.
    use super::*;

    /// Reference implementation: multiply `input` by the circulant matrix
    /// whose first row is `row`, one output entry at a time.
    ///
    /// Row i of a circulant matrix is the first row rotated right i times, so
    /// entry i of the product is sum_k row[k] * input[(i + k) mod N].
    fn naive_circulant<const N: usize>(
        row: &[i64; N],
        input: [Mersenne31; N],
    ) -> [Mersenne31; N] {
        core::array::from_fn(|i| {
            row.iter().enumerate().fold(Mersenne31::ZERO, |acc, (k, &r)| {
                acc + Mersenne31::from_u64(r as u64) * input[(i + k) % N]
            })
        })
    }

    /// Checks the fast permutation against `naive_circulant` on two inputs.
    ///
    /// The first input has every entry equal to P - 1, the largest value and
    /// the worst case for the overflow bounds and debug assertions in the
    /// convolution backends. The second is a non-constant mix.
    fn check_against_naive<const N: usize>(row: &[i64; N])
    where
        MdsMatrixMersenne31: Permutation<[Mersenne31; N]>,
    {
        let all_max = [Mersenne31::NEG_ONE; N];
        assert_eq!(
            MdsMatrixMersenne31.permute(all_max),
            naive_circulant(row, all_max)
        );

        let mixed: [Mersenne31; N] = core::array::from_fn(|i| {
            if i % 2 == 0 {
                Mersenne31::NEG_ONE
            } else {
                Mersenne31::from_u64((i as u64) * 0x0001_2345_6789)
            }
        });
        assert_eq!(
            MdsMatrixMersenne31.permute(mixed),
            naive_circulant(row, mixed)
        );
    }

    #[test]
    fn matches_naive_circulant_product() {
        check_against_naive(&MATRIX_CIRC_MDS_8_SML_ROW);
        check_against_naive(&MATRIX_CIRC_MDS_12_SML_ROW);
        check_against_naive(&MATRIX_CIRC_MDS_16_SML_ROW);
        check_against_naive(&MATRIX_CIRC_MDS_32_MERSENNE31_ROW);
        check_against_naive(&MATRIX_CIRC_MDS_64_MERSENNE31_ROW);
    }

    #[test]
    fn mersenne8() {
        let input: [Mersenne31; 8] = Mersenne31::new_array([
            1741044457, 327154658, 318297696, 1528828225, 468360260, 1271368222, 1906288587,
            1521884224,
        ]);

        let output = MdsMatrixMersenne31.permute(input);

        let expected: [Mersenne31; 8] = Mersenne31::new_array([
            895992680, 1343855369, 2107796831, 266468728, 846686506, 252887121, 205223309,
            260248790,
        ]);

        assert_eq!(output, expected);
    }

    #[test]
    fn mersenne12() {
        let input: [Mersenne31; 12] = Mersenne31::new_array([
            1232740094, 661555540, 11024822, 1620264994, 471137070, 276755041, 1316882747,
            1023679816, 1675266989, 743211887, 44774582, 1990989306,
        ]);

        let output = MdsMatrixMersenne31.permute(input);

        let expected: [Mersenne31; 12] = Mersenne31::new_array([
            860812289, 399778981, 1228500858, 798196553, 673507779, 1116345060, 829764188,
            138346433, 578243475, 553581995, 578183208, 1527769050,
        ]);

        assert_eq!(output, expected);
    }

    #[test]
    fn mersenne16() {
        let input: [Mersenne31; 16] = Mersenne31::new_array([
            1431168444, 963811518, 88067321, 381314132, 908628282, 1260098295, 980207659,
            150070493, 357706876, 2014609375, 387876458, 1621671571, 183146044, 107201572,
            166536524, 2078440788,
        ]);

        let output = MdsMatrixMersenne31.permute(input);

        let expected: [Mersenne31; 16] = Mersenne31::new_array([
            1858869691, 1607793806, 1200396641, 1400502985, 1511630695, 187938132, 1332411488,
            2041577083, 2014246632, 802022141, 796807132, 1647212930, 813167618, 1867105010,
            508596277, 1457551581,
        ]);

        assert_eq!(output, expected);
    }

    #[test]
    fn mersenne32() {
        let input: [Mersenne31; 32] = Mersenne31::new_array([
            873912014, 1112497426, 300405095, 4255553, 1234979949, 156402357, 1952135954,
            718195399, 1041748465, 683604342, 184275751, 1184118518, 214257054, 1293941921,
            64085758, 710448062, 1133100009, 350114887, 1091675272, 671421879, 1226105999,
            546430131, 1298443967, 1787169653, 2129310791, 1560307302, 471771931, 1191484402,
            1550203198, 1541319048, 229197040, 839673789,
        ]);

        let output = MdsMatrixMersenne31.permute(input);

        let expected: [Mersenne31; 32] = Mersenne31::new_array([
            1439049928, 890642852, 694402307, 713403244, 553213342, 1049445650, 321709533,
            1195683415, 2118492257, 623077773, 96734062, 990488164, 1674607608, 749155000,
            353377854, 966432998, 1114654884, 1370359248, 1624965859, 685087760, 1631836645,
            1615931812, 2061986317, 1773551151, 1449911206, 1951762557, 545742785, 582866449,
            1379774336, 229242759, 1871227547, 752848413,
        ]);

        assert_eq!(output, expected);
    }

    #[test]
    fn mersenne64() {
        let input: [Mersenne31; 64] = Mersenne31::new_array([
            837269696, 1509031194, 413915480, 1889329185, 315502822, 1529162228, 1454661012,
            1015826742, 973381409, 1414676304, 1449029961, 1968715566, 2027226497, 1721820509,
            434042616, 1436005045, 1680352863, 651591867, 260585272, 1078022153, 703990572,
            269504423, 1776357592, 1174979337, 1142666094, 1897872960, 1387995838, 250774418,
            776134750, 73930096, 194742451, 1860060380, 666407744, 669566398, 963802147,
            2063418105, 1772573581, 998923482, 701912753, 1716548204, 860820931, 1680395948,
            949886256, 1811558161, 501734557, 1671977429, 463135040, 1911493108, 207754409,
            608714758, 1553060084, 1558941605, 980281686, 2014426559, 650527801, 53015148,
            1521176057, 720530872, 713593252, 88228433, 1194162313, 1922416934, 1075145779,
            344403794,
        ]);

        let output = MdsMatrixMersenne31.permute(input);

        let expected: [Mersenne31; 64] = Mersenne31::new_array([
            1599981950, 252630853, 1171557270, 116468420, 1269245345, 666203050, 46155642,
            1701131520, 530845775, 508460407, 630407239, 1731628135, 1199144768, 295132047,
            77536342, 1472377703, 30752443, 1300339617, 18647556, 1267774380, 1194573079,
            1624665024, 646848056, 1667216490, 1184843555, 1250329476, 254171597, 1902035936,
            1706882202, 964921003, 952266538, 1215696284, 539510504, 1056507562, 1393151480,
            733644883, 1663330816, 1100715048, 991108703, 1671345065, 1376431774, 408310416,
            313176996, 743567676, 304660642, 1842695838, 958201635, 1650792218, 541570244,
            968523062, 1958918704, 1866282698, 849808680, 1193306222, 794153281, 822835360,
            135282913, 1149868448, 2068162123, 1474283743, 2039088058, 720305835, 746036736,
            671006610,
        ]);

        assert_eq!(output, expected);
    }
}