1use p3_field::PrimeCharacteristicRing;
10use p3_field::integers::QuotientMap;
11use p3_mds::MdsPermutation;
12use p3_mds::karatsuba_convolution::Convolve;
13use p3_mds::util::{dot_product, first_row_to_first_col};
14use p3_symmetric::Permutation;
15
16use crate::Mersenne31;
17
/// Circulant MDS matrices over [`Mersenne31`].
///
/// This file provides [`Permutation`] impls for states of width 8, 12, 16, 32 and 64.
/// Widths 8/12/16 use small integer entries. Widths 32/64 use full-size field constants.
#[derive(Debug, Clone, Default)]
pub struct MdsMatrixMersenne31;
20
21struct SmallConvolveMersenne31;
26impl Convolve<Mersenne31, i64, i64> for SmallConvolveMersenne31 {
27 const T_ZERO: i64 = 0;
28 const U_ZERO: i64 = 0;
29
30 #[inline(always)]
31 fn halve(val: i64) -> i64 {
32 val >> 1
33 }
34
35 #[inline(always)]
39 fn read(input: Mersenne31) -> i64 {
40 input.value as i64
41 }
42
43 #[inline(always)]
47 fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
48 dot_product(u, v)
49 }
50
51 #[inline(always)]
58 fn reduce(z: i64) -> Mersenne31 {
59 debug_assert!(z >= 0);
60 Mersenne31::from_u64(z as u64)
61 }
62}
63
64struct LargeConvolveMersenne31;
69impl Convolve<Mersenne31, i64, i64> for LargeConvolveMersenne31 {
70 const T_ZERO: i64 = 0;
71 const U_ZERO: i64 = 0;
72
73 #[inline(always)]
74 fn halve(val: i64) -> i64 {
75 val >> 1
76 }
77
78 #[inline(always)]
82 fn read(input: Mersenne31) -> i64 {
83 input.value as i64
84 }
85
86 #[inline]
87 fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
88 let mut dp = 0i128;
93 for i in 0..N {
94 dp += u[i] as i128 * v[i] as i128;
95 }
96
97 const LOWMASK: i128 = (1 << 42) - 1; const HIGHMASK: i128 = !LOWMASK; let low_bits = (dp & LOWMASK) as i64; let high_bits = ((dp & HIGHMASK) >> 31) as i64; low_bits + high_bits
116 }
117
118 #[inline]
119 fn reduce(z: i64) -> Mersenne31 {
120 debug_assert!(z > -(1i64 << 49));
139 debug_assert!(z < (1i64 << 49));
140
141 const MASK: i64 = (1 << 31) - 1;
142 let low_bits = unsafe {
145 Mersenne31::from_canonical_unchecked((z & MASK) as u32)
147 };
148
149 let high_bits = ((z >> 31) & MASK) as i32;
150 let sign_bits = (z >> 62) as i32;
151
152 let high = unsafe {
153 Mersenne31::from_canonical_unchecked((high_bits + sign_bits) as u32)
155 };
156 low_bits + high
157 }
158}
159
160const MATRIX_CIRC_MDS_8_SML_ROW: [i64; 8] = [7, 1, 3, 8, 8, 3, 4, 9];
161
162impl Permutation<[Mersenne31; 8]> for MdsMatrixMersenne31 {
163 fn permute(&self, input: [Mersenne31; 8]) -> [Mersenne31; 8] {
164 const MATRIX_CIRC_MDS_8_SML_COL: [i64; 8] =
165 first_row_to_first_col(&MATRIX_CIRC_MDS_8_SML_ROW);
166 SmallConvolveMersenne31::apply(
167 input,
168 MATRIX_CIRC_MDS_8_SML_COL,
169 SmallConvolveMersenne31::conv8,
170 )
171 }
172}
173impl MdsPermutation<Mersenne31, 8> for MdsMatrixMersenne31 {}
174
175const MATRIX_CIRC_MDS_12_SML_ROW: [i64; 12] = [1, 1, 2, 1, 8, 9, 10, 7, 5, 9, 4, 10];
176
177impl Permutation<[Mersenne31; 12]> for MdsMatrixMersenne31 {
178 fn permute(&self, input: [Mersenne31; 12]) -> [Mersenne31; 12] {
179 const MATRIX_CIRC_MDS_12_SML_COL: [i64; 12] =
180 first_row_to_first_col(&MATRIX_CIRC_MDS_12_SML_ROW);
181 SmallConvolveMersenne31::apply(
182 input,
183 MATRIX_CIRC_MDS_12_SML_COL,
184 SmallConvolveMersenne31::conv12,
185 )
186 }
187}
188impl MdsPermutation<Mersenne31, 12> for MdsMatrixMersenne31 {}
189
190const MATRIX_CIRC_MDS_16_SML_ROW: [i64; 16] =
191 [1, 1, 51, 1, 11, 17, 2, 1, 101, 63, 15, 2, 67, 22, 13, 3];
192
193impl Permutation<[Mersenne31; 16]> for MdsMatrixMersenne31 {
194 fn permute(&self, input: [Mersenne31; 16]) -> [Mersenne31; 16] {
195 const MATRIX_CIRC_MDS_16_SML_COL: [i64; 16] =
196 first_row_to_first_col(&MATRIX_CIRC_MDS_16_SML_ROW);
197 SmallConvolveMersenne31::apply(
198 input,
199 MATRIX_CIRC_MDS_16_SML_COL,
200 SmallConvolveMersenne31::conv16,
201 )
202 }
203}
204impl MdsPermutation<Mersenne31, 16> for MdsMatrixMersenne31 {}
205
206#[rustfmt::skip]
207const MATRIX_CIRC_MDS_32_MERSENNE31_ROW: [i64; 32] = [
208 0x1896DC78, 0x559D1E29, 0x04EBD732, 0x3FF449D7,
209 0x2DB0E2CE, 0x26776B85, 0x76018E57, 0x1025FA13,
210 0x06486BAB, 0x37706EBA, 0x25EB966B, 0x113C24E5,
211 0x2AE20EC4, 0x5A27507C, 0x0CD38CF1, 0x761C10E5,
212 0x19E3EF1A, 0x032C730F, 0x35D8AF83, 0x651DF13B,
213 0x7EC3DB1A, 0x6A146994, 0x588F9145, 0x09B79455,
214 0x7FDA05EC, 0x19FE71A8, 0x6988947A, 0x624F1D31,
215 0x500BB628, 0x0B1428CE, 0x3A62E1D6, 0x77692387
216];
217
218impl Permutation<[Mersenne31; 32]> for MdsMatrixMersenne31 {
219 fn permute(&self, input: [Mersenne31; 32]) -> [Mersenne31; 32] {
220 const MATRIX_CIRC_MDS_32_MERSENNE31_COL: [i64; 32] =
221 first_row_to_first_col(&MATRIX_CIRC_MDS_32_MERSENNE31_ROW);
222 LargeConvolveMersenne31::apply(
223 input,
224 MATRIX_CIRC_MDS_32_MERSENNE31_COL,
225 LargeConvolveMersenne31::conv32,
226 )
227 }
228}
229impl MdsPermutation<Mersenne31, 32> for MdsMatrixMersenne31 {}
230
231#[rustfmt::skip]
232const MATRIX_CIRC_MDS_64_MERSENNE31_ROW: [i64; 64] = [
233 0x570227A5, 0x3702983F, 0x4B7B3B0A, 0x74F13DE3,
234 0x485314B0, 0x0157E2EC, 0x1AD2E5DE, 0x721515E3,
235 0x5452ADA3, 0x0C74B6C1, 0x67DA9450, 0x33A48369,
236 0x3BDBEE06, 0x7C678D5E, 0x160F16D3, 0x54888B8C,
237 0x666C7AA6, 0x113B89E2, 0x2A403CE2, 0x18F9DF42,
238 0x2A685E84, 0x49EEFDE5, 0x5D044806, 0x560A41F8,
239 0x69EF1BD0, 0x2CD15786, 0x62E07766, 0x22A231E2,
240 0x3CFCF40C, 0x4E8F63D8, 0x69657A15, 0x466B4B2D,
241 0x4194B4D2, 0x1E9A85EA, 0x39709C27, 0x4B030BF3,
242 0x655DCE1D, 0x251F8899, 0x5B2EA879, 0x1E10E42F,
243 0x31F5BE07, 0x2AFBB7F9, 0x3E11021A, 0x5D97A17B,
244 0x6F0620BD, 0x5DBFC31D, 0x76C4761D, 0x21938559,
245 0x33777473, 0x71F0E92C, 0x0B9872A1, 0x4C2411F9,
246 0x545B7C96, 0x20256BAF, 0x7B8B493E, 0x33AD525C,
247 0x15EAEA1C, 0x6D2D1A21, 0x06A81D14, 0x3FACEB4F,
248 0x130EC21C, 0x3C84C4F5, 0x50FD67C0, 0x30FDD85A,
249];
250
251impl Permutation<[Mersenne31; 64]> for MdsMatrixMersenne31 {
252 fn permute(&self, input: [Mersenne31; 64]) -> [Mersenne31; 64] {
253 const MATRIX_CIRC_MDS_64_MERSENNE31_COL: [i64; 64] =
254 first_row_to_first_col(&MATRIX_CIRC_MDS_64_MERSENNE31_ROW);
255 LargeConvolveMersenne31::apply(
256 input,
257 MATRIX_CIRC_MDS_64_MERSENNE31_COL,
258 LargeConvolveMersenne31::conv64,
259 )
260 }
261}
262impl MdsPermutation<Mersenne31, 64> for MdsMatrixMersenne31 {}
263
#[cfg(test)]
mod tests {
    use p3_symmetric::Permutation;

    use super::{MdsMatrixMersenne31, Mersenne31};

    /// Applies `MdsMatrixMersenne31` to `input` and asserts the output equals `expected`.
    fn check_permutation<const N: usize>(input: [Mersenne31; N], expected: [Mersenne31; N])
    where
        MdsMatrixMersenne31: Permutation<[Mersenne31; N]>,
    {
        let output = MdsMatrixMersenne31.permute(input);
        assert_eq!(output, expected);
    }

    #[test]
    fn mersenne8() {
        check_permutation(
            Mersenne31::new_array([
                1741044457, 327154658, 318297696, 1528828225, 468360260, 1271368222, 1906288587,
                1521884224,
            ]),
            Mersenne31::new_array([
                895992680, 1343855369, 2107796831, 266468728, 846686506, 252887121, 205223309,
                260248790,
            ]),
        );
    }

    #[test]
    fn mersenne12() {
        check_permutation(
            Mersenne31::new_array([
                1232740094, 661555540, 11024822, 1620264994, 471137070, 276755041, 1316882747,
                1023679816, 1675266989, 743211887, 44774582, 1990989306,
            ]),
            Mersenne31::new_array([
                860812289, 399778981, 1228500858, 798196553, 673507779, 1116345060, 829764188,
                138346433, 578243475, 553581995, 578183208, 1527769050,
            ]),
        );
    }

    #[test]
    fn mersenne16() {
        check_permutation(
            Mersenne31::new_array([
                1431168444, 963811518, 88067321, 381314132, 908628282, 1260098295, 980207659,
                150070493, 357706876, 2014609375, 387876458, 1621671571, 183146044, 107201572,
                166536524, 2078440788,
            ]),
            Mersenne31::new_array([
                1858869691, 1607793806, 1200396641, 1400502985, 1511630695, 187938132, 1332411488,
                2041577083, 2014246632, 802022141, 796807132, 1647212930, 813167618, 1867105010,
                508596277, 1457551581,
            ]),
        );
    }

    #[test]
    fn mersenne32() {
        check_permutation(
            Mersenne31::new_array([
                873912014, 1112497426, 300405095, 4255553, 1234979949, 156402357, 1952135954,
                718195399, 1041748465, 683604342, 184275751, 1184118518, 214257054, 1293941921,
                64085758, 710448062, 1133100009, 350114887, 1091675272, 671421879, 1226105999,
                546430131, 1298443967, 1787169653, 2129310791, 1560307302, 471771931, 1191484402,
                1550203198, 1541319048, 229197040, 839673789,
            ]),
            Mersenne31::new_array([
                1439049928, 890642852, 694402307, 713403244, 553213342, 1049445650, 321709533,
                1195683415, 2118492257, 623077773, 96734062, 990488164, 1674607608, 749155000,
                353377854, 966432998, 1114654884, 1370359248, 1624965859, 685087760, 1631836645,
                1615931812, 2061986317, 1773551151, 1449911206, 1951762557, 545742785, 582866449,
                1379774336, 229242759, 1871227547, 752848413,
            ]),
        );
    }

    #[test]
    fn mersenne64() {
        check_permutation(
            Mersenne31::new_array([
                837269696, 1509031194, 413915480, 1889329185, 315502822, 1529162228, 1454661012,
                1015826742, 973381409, 1414676304, 1449029961, 1968715566, 2027226497, 1721820509,
                434042616, 1436005045, 1680352863, 651591867, 260585272, 1078022153, 703990572,
                269504423, 1776357592, 1174979337, 1142666094, 1897872960, 1387995838, 250774418,
                776134750, 73930096, 194742451, 1860060380, 666407744, 669566398, 963802147,
                2063418105, 1772573581, 998923482, 701912753, 1716548204, 860820931, 1680395948,
                949886256, 1811558161, 501734557, 1671977429, 463135040, 1911493108, 207754409,
                608714758, 1553060084, 1558941605, 980281686, 2014426559, 650527801, 53015148,
                1521176057, 720530872, 713593252, 88228433, 1194162313, 1922416934, 1075145779,
                344403794,
            ]),
            Mersenne31::new_array([
                1599981950, 252630853, 1171557270, 116468420, 1269245345, 666203050, 46155642,
                1701131520, 530845775, 508460407, 630407239, 1731628135, 1199144768, 295132047,
                77536342, 1472377703, 30752443, 1300339617, 18647556, 1267774380, 1194573079,
                1624665024, 646848056, 1667216490, 1184843555, 1250329476, 254171597, 1902035936,
                1706882202, 964921003, 952266538, 1215696284, 539510504, 1056507562, 1393151480,
                733644883, 1663330816, 1100715048, 991108703, 1671345065, 1376431774, 408310416,
                313176996, 743567676, 304660642, 1842695838, 958201635, 1650792218, 541570244,
                968523062, 1958918704, 1866282698, 849808680, 1193306222, 794153281, 822835360,
                135282913, 1149868448, 2068162123, 1474283743, 2039088058, 720305835, 746036736,
                671006610,
            ]),
        );
    }
}