1use p3_field::PrimeCharacteristicRing;
10use p3_mds::MdsPermutation;
11use p3_mds::karatsuba_convolution::Convolve;
12use p3_mds::util::{dot_product, first_row_to_first_col};
13use p3_symmetric::Permutation;
14
15use crate::Mersenne31;
16
/// Fixed circulant matrices over `Mersenne31`.
///
/// This unit struct carries no data; it implements `Permutation` and
/// `MdsPermutation` for state widths 8, 12, 16, 32 and 64, each backed by its
/// own constant first row.
#[derive(Default, Debug, Clone)]
pub struct MdsMatrixMersenne31;
19
20struct SmallConvolveMersenne31;
25impl Convolve<Mersenne31, i64, i64> for SmallConvolveMersenne31 {
26 const T_ZERO: i64 = 0;
27 const U_ZERO: i64 = 0;
28
29 #[inline(always)]
30 fn halve(val: i64) -> i64 {
31 val >> 1
32 }
33
34 #[inline(always)]
38 fn read(input: Mersenne31) -> i64 {
39 input.value as i64
40 }
41
42 #[inline(always)]
46 fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
47 dot_product(u, v)
48 }
49
50 #[inline(always)]
57 fn reduce(z: i64) -> Mersenne31 {
58 debug_assert!(z >= 0);
59 Mersenne31::from_u64(z as u64)
60 }
61}
62
63struct LargeConvolveMersenne31;
68impl Convolve<Mersenne31, i64, i64> for LargeConvolveMersenne31 {
69 const T_ZERO: i64 = 0;
70 const U_ZERO: i64 = 0;
71
72 #[inline(always)]
73 fn halve(val: i64) -> i64 {
74 val >> 1
75 }
76
77 #[inline(always)]
81 fn read(input: Mersenne31) -> i64 {
82 input.value as i64
83 }
84
85 #[inline]
86 fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
87 let mut dp = 0i128;
92 for i in 0..N {
93 dp += u[i] as i128 * v[i] as i128;
94 }
95
96 const LOWMASK: i128 = (1 << 42) - 1; const HIGHMASK: i128 = !LOWMASK; let low_bits = (dp & LOWMASK) as i64; let high_bits = ((dp & HIGHMASK) >> 31) as i64; low_bits + high_bits
115 }
116
117 #[inline]
118 fn reduce(z: i64) -> Mersenne31 {
119 debug_assert!(z > -(1i64 << 49));
138 debug_assert!(z < (1i64 << 49));
139
140 const MASK: i64 = (1 << 31) - 1;
141 let low_bits = Mersenne31::new_reduced((z & MASK) as u32);
148
149 let high_bits = ((z >> 31) & MASK) as i32;
150 let sign_bits = (z >> 62) as i32;
151
152 let high = Mersenne31::new_reduced((high_bits + sign_bits) as u32);
158 low_bits + high
159 }
160}
161
162const MATRIX_CIRC_MDS_8_SML_ROW: [i64; 8] = [7, 1, 3, 8, 8, 3, 4, 9];
163
164impl Permutation<[Mersenne31; 8]> for MdsMatrixMersenne31 {
165 fn permute(&self, input: [Mersenne31; 8]) -> [Mersenne31; 8] {
166 const MATRIX_CIRC_MDS_8_SML_COL: [i64; 8] =
167 first_row_to_first_col(&MATRIX_CIRC_MDS_8_SML_ROW);
168 SmallConvolveMersenne31::apply(
169 input,
170 MATRIX_CIRC_MDS_8_SML_COL,
171 SmallConvolveMersenne31::conv8,
172 )
173 }
174}
175impl MdsPermutation<Mersenne31, 8> for MdsMatrixMersenne31 {}
176
177const MATRIX_CIRC_MDS_12_SML_ROW: [i64; 12] = [1, 1, 2, 1, 8, 9, 10, 7, 5, 9, 4, 10];
178
179impl Permutation<[Mersenne31; 12]> for MdsMatrixMersenne31 {
180 fn permute(&self, input: [Mersenne31; 12]) -> [Mersenne31; 12] {
181 const MATRIX_CIRC_MDS_12_SML_COL: [i64; 12] =
182 first_row_to_first_col(&MATRIX_CIRC_MDS_12_SML_ROW);
183 SmallConvolveMersenne31::apply(
184 input,
185 MATRIX_CIRC_MDS_12_SML_COL,
186 SmallConvolveMersenne31::conv12,
187 )
188 }
189}
190impl MdsPermutation<Mersenne31, 12> for MdsMatrixMersenne31 {}
191
192pub(crate) const MATRIX_CIRC_MDS_16_SML_ROW: [i64; 16] =
193 [1, 1, 51, 1, 11, 17, 2, 1, 101, 63, 15, 2, 67, 22, 13, 3];
194
195impl Permutation<[Mersenne31; 16]> for MdsMatrixMersenne31 {
196 fn permute(&self, input: [Mersenne31; 16]) -> [Mersenne31; 16] {
197 const MATRIX_CIRC_MDS_16_SML_COL: [i64; 16] =
198 first_row_to_first_col(&MATRIX_CIRC_MDS_16_SML_ROW);
199 SmallConvolveMersenne31::apply(
200 input,
201 MATRIX_CIRC_MDS_16_SML_COL,
202 SmallConvolveMersenne31::conv16,
203 )
204 }
205}
206impl MdsPermutation<Mersenne31, 16> for MdsMatrixMersenne31 {}
207
208#[rustfmt::skip]
209pub(crate) const MATRIX_CIRC_MDS_32_MERSENNE31_ROW: [i64; 32] = [
210 0x1896DC78, 0x559D1E29, 0x04EBD732, 0x3FF449D7,
211 0x2DB0E2CE, 0x26776B85, 0x76018E57, 0x1025FA13,
212 0x06486BAB, 0x37706EBA, 0x25EB966B, 0x113C24E5,
213 0x2AE20EC4, 0x5A27507C, 0x0CD38CF1, 0x761C10E5,
214 0x19E3EF1A, 0x032C730F, 0x35D8AF83, 0x651DF13B,
215 0x7EC3DB1A, 0x6A146994, 0x588F9145, 0x09B79455,
216 0x7FDA05EC, 0x19FE71A8, 0x6988947A, 0x624F1D31,
217 0x500BB628, 0x0B1428CE, 0x3A62E1D6, 0x77692387
218];
219
220impl Permutation<[Mersenne31; 32]> for MdsMatrixMersenne31 {
221 fn permute(&self, input: [Mersenne31; 32]) -> [Mersenne31; 32] {
222 const MATRIX_CIRC_MDS_32_MERSENNE31_COL: [i64; 32] =
223 first_row_to_first_col(&MATRIX_CIRC_MDS_32_MERSENNE31_ROW);
224 LargeConvolveMersenne31::apply(
225 input,
226 MATRIX_CIRC_MDS_32_MERSENNE31_COL,
227 LargeConvolveMersenne31::conv32,
228 )
229 }
230}
231impl MdsPermutation<Mersenne31, 32> for MdsMatrixMersenne31 {}
232
233#[rustfmt::skip]
234const MATRIX_CIRC_MDS_64_MERSENNE31_ROW: [i64; 64] = [
235 0x570227A5, 0x3702983F, 0x4B7B3B0A, 0x74F13DE3,
236 0x485314B0, 0x0157E2EC, 0x1AD2E5DE, 0x721515E3,
237 0x5452ADA3, 0x0C74B6C1, 0x67DA9450, 0x33A48369,
238 0x3BDBEE06, 0x7C678D5E, 0x160F16D3, 0x54888B8C,
239 0x666C7AA6, 0x113B89E2, 0x2A403CE2, 0x18F9DF42,
240 0x2A685E84, 0x49EEFDE5, 0x5D044806, 0x560A41F8,
241 0x69EF1BD0, 0x2CD15786, 0x62E07766, 0x22A231E2,
242 0x3CFCF40C, 0x4E8F63D8, 0x69657A15, 0x466B4B2D,
243 0x4194B4D2, 0x1E9A85EA, 0x39709C27, 0x4B030BF3,
244 0x655DCE1D, 0x251F8899, 0x5B2EA879, 0x1E10E42F,
245 0x31F5BE07, 0x2AFBB7F9, 0x3E11021A, 0x5D97A17B,
246 0x6F0620BD, 0x5DBFC31D, 0x76C4761D, 0x21938559,
247 0x33777473, 0x71F0E92C, 0x0B9872A1, 0x4C2411F9,
248 0x545B7C96, 0x20256BAF, 0x7B8B493E, 0x33AD525C,
249 0x15EAEA1C, 0x6D2D1A21, 0x06A81D14, 0x3FACEB4F,
250 0x130EC21C, 0x3C84C4F5, 0x50FD67C0, 0x30FDD85A,
251];
252
253impl Permutation<[Mersenne31; 64]> for MdsMatrixMersenne31 {
254 fn permute(&self, input: [Mersenne31; 64]) -> [Mersenne31; 64] {
255 const MATRIX_CIRC_MDS_64_MERSENNE31_COL: [i64; 64] =
256 first_row_to_first_col(&MATRIX_CIRC_MDS_64_MERSENNE31_ROW);
257 LargeConvolveMersenne31::apply(
258 input,
259 MATRIX_CIRC_MDS_64_MERSENNE31_COL,
260 LargeConvolveMersenne31::conv64,
261 )
262 }
263}
264impl MdsPermutation<Mersenne31, 64> for MdsMatrixMersenne31 {}
265
#[cfg(test)]
mod tests {
    use p3_field::PrimeCharacteristicRing;
    use p3_mds::karatsuba_convolution::Convolve;
    use p3_symmetric::Permutation;

    use super::{LargeConvolveMersenne31, MdsMatrixMersenne31, Mersenne31};

    /// Permute `input` with `MdsMatrixMersenne31` and assert the result equals `expected`.
    fn check_permutation<const N: usize>(input: [u32; N], expected: [u32; N])
    where
        MdsMatrixMersenne31: Permutation<[Mersenne31; N]>,
    {
        let state: [Mersenne31; N] = Mersenne31::new_array(input);
        let want: [Mersenne31; N] = Mersenne31::new_array(expected);
        assert_eq!(MdsMatrixMersenne31.permute(state), want);
    }

    #[test]
    fn large_convolve_reduce_accepts_p_representation() {
        // -1 must land on P - 1, i.e. 0 - 1 in the field.
        assert_eq!(
            LargeConvolveMersenne31::reduce(-1),
            Mersenne31::ZERO - Mersenne31::ONE
        );
        // P itself is congruent to zero.
        assert_eq!(
            LargeConvolveMersenne31::reduce((1i64 << 31) - 1),
            Mersenne31::ZERO
        );
    }

    #[test]
    fn mersenne8() {
        check_permutation(
            [
                1741044457, 327154658, 318297696, 1528828225, 468360260, 1271368222, 1906288587,
                1521884224,
            ],
            [
                895992680, 1343855369, 2107796831, 266468728, 846686506, 252887121, 205223309,
                260248790,
            ],
        );
    }

    #[test]
    fn mersenne12() {
        check_permutation(
            [
                1232740094, 661555540, 11024822, 1620264994, 471137070, 276755041, 1316882747,
                1023679816, 1675266989, 743211887, 44774582, 1990989306,
            ],
            [
                860812289, 399778981, 1228500858, 798196553, 673507779, 1116345060, 829764188,
                138346433, 578243475, 553581995, 578183208, 1527769050,
            ],
        );
    }

    #[test]
    fn mersenne16() {
        check_permutation(
            [
                1431168444, 963811518, 88067321, 381314132, 908628282, 1260098295, 980207659,
                150070493, 357706876, 2014609375, 387876458, 1621671571, 183146044, 107201572,
                166536524, 2078440788,
            ],
            [
                1858869691, 1607793806, 1200396641, 1400502985, 1511630695, 187938132, 1332411488,
                2041577083, 2014246632, 802022141, 796807132, 1647212930, 813167618, 1867105010,
                508596277, 1457551581,
            ],
        );
    }

    #[test]
    fn mersenne32() {
        check_permutation(
            [
                873912014, 1112497426, 300405095, 4255553, 1234979949, 156402357, 1952135954,
                718195399, 1041748465, 683604342, 184275751, 1184118518, 214257054, 1293941921,
                64085758, 710448062, 1133100009, 350114887, 1091675272, 671421879, 1226105999,
                546430131, 1298443967, 1787169653, 2129310791, 1560307302, 471771931, 1191484402,
                1550203198, 1541319048, 229197040, 839673789,
            ],
            [
                1439049928, 890642852, 694402307, 713403244, 553213342, 1049445650, 321709533,
                1195683415, 2118492257, 623077773, 96734062, 990488164, 1674607608, 749155000,
                353377854, 966432998, 1114654884, 1370359248, 1624965859, 685087760, 1631836645,
                1615931812, 2061986317, 1773551151, 1449911206, 1951762557, 545742785, 582866449,
                1379774336, 229242759, 1871227547, 752848413,
            ],
        );
    }

    #[test]
    fn mersenne64() {
        check_permutation(
            [
                837269696, 1509031194, 413915480, 1889329185, 315502822, 1529162228, 1454661012,
                1015826742, 973381409, 1414676304, 1449029961, 1968715566, 2027226497, 1721820509,
                434042616, 1436005045, 1680352863, 651591867, 260585272, 1078022153, 703990572,
                269504423, 1776357592, 1174979337, 1142666094, 1897872960, 1387995838, 250774418,
                776134750, 73930096, 194742451, 1860060380, 666407744, 669566398, 963802147,
                2063418105, 1772573581, 998923482, 701912753, 1716548204, 860820931, 1680395948,
                949886256, 1811558161, 501734557, 1671977429, 463135040, 1911493108, 207754409,
                608714758, 1553060084, 1558941605, 980281686, 2014426559, 650527801, 53015148,
                1521176057, 720530872, 713593252, 88228433, 1194162313, 1922416934, 1075145779,
                344403794,
            ],
            [
                1599981950, 252630853, 1171557270, 116468420, 1269245345, 666203050, 46155642,
                1701131520, 530845775, 508460407, 630407239, 1731628135, 1199144768, 295132047,
                77536342, 1472377703, 30752443, 1300339617, 18647556, 1267774380, 1194573079,
                1624665024, 646848056, 1667216490, 1184843555, 1250329476, 254171597, 1902035936,
                1706882202, 964921003, 952266538, 1215696284, 539510504, 1056507562, 1393151480,
                733644883, 1663330816, 1100715048, 991108703, 1671345065, 1376431774, 408310416,
                313176996, 743567676, 304660642, 1842695838, 958201635, 1650792218, 541570244,
                968523062, 1958918704, 1866282698, 849808680, 1193306222, 794153281, 822835360,
                135282913, 1149868448, 2068162123, 1474283743, 2039088058, 720305835, 746036736,
                671006610,
            ],
        );
    }
}