1use p3_field::PrimeCharacteristicRing;
10use p3_field::integers::QuotientMap;
11use p3_mds::MdsPermutation;
12use p3_mds::karatsuba_convolution::Convolve;
13use p3_mds::util::{dot_product, first_row_to_first_col};
14use p3_symmetric::Permutation;
15
16use crate::Mersenne31;
17
/// Circulant MDS matrices over [`Mersenne31`].
///
/// This is a stateless marker type: the concrete matrices (widths 8, 12, 16,
/// 32 and 64) are selected by which `Permutation<[Mersenne31; N]>` impl is
/// invoked on it.
#[derive(Debug, Default, Clone)]
pub struct MdsMatrixMersenne31;
20
21struct SmallConvolveMersenne31;
26impl Convolve<Mersenne31, i64, i64, i64> for SmallConvolveMersenne31 {
27 #[inline(always)]
31 fn read(input: Mersenne31) -> i64 {
32 input.value as i64
33 }
34
35 #[inline(always)]
40 fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
41 dot_product(u, v)
42 }
43
44 #[inline(always)]
51 fn reduce(z: i64) -> Mersenne31 {
52 debug_assert!(z >= 0);
53 Mersenne31::from_u64(z as u64)
54 }
55}
56
/// Convolution instantiation over [`Mersenne31`] used by the width-32 and
/// width-64 matrices in this module, whose coefficients are full-size
/// (31-bit) constants.
///
/// Products are widened to `i128` inside `parity_dot` and then partially
/// reduced back into `i64`; `reduce` finishes the reduction into the field.
struct LargeConvolveMersenne31;
impl Convolve<Mersenne31, i64, i64, i64> for LargeConvolveMersenne31 {
    /// Lifts a field element into `i64` using its stored representative.
    #[inline(always)]
    fn read(input: Mersenne31) -> i64 {
        input.value as i64
    }

    /// Dot product of `u` and `v` with a partial reduction modulo 2^31 - 1.
    ///
    /// The result is congruent to the exact dot product modulo 2^31 - 1 but
    /// is generally not reduced into `[0, 2^31 - 1)`.
    #[inline]
    fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
        // Widen to i128: the product of two i64 operands can exceed 64 bits.
        let mut dp = 0i128;
        for i in 0..N {
            dp += u[i] as i128 * v[i] as i128;
        }

        // Split dp = lo + hi, where lo = bits 0..42 of dp and hi = dp with
        // those bits cleared. hi is a multiple of 2^42, so (hi >> 31) is an
        // exact division and hi = (hi >> 31) * 2^31. Since 2^31 ≡ 1
        // (mod 2^31 - 1), dp ≡ lo + (hi >> 31).
        const LOWMASK: i128 = (1 << 42) - 1;
        const HIGHMASK: i128 = !LOWMASK;
        // lo < 2^42, so this cast is lossless.
        let low_bits = (dp & LOWMASK) as i64;
        // NOTE(review): this cast is lossless only while |dp| < 2^94; the
        // callers' operand sizes are presumably bounded well below that.
        let high_bits = ((dp & HIGHMASK) >> 31) as i64;
        low_bits + high_bits
    }

    /// Reduces `z` into the field, assuming |z| < 2^49 (checked in debug
    /// builds only).
    ///
    /// Writes z = 2^31 * hi + lo with hi = z >> 31 (arithmetic shift) and
    /// 0 <= lo < 2^31, then uses 2^31 ≡ 1 (mod 2^31 - 1) so that z ≡ hi + lo.
    #[inline]
    fn reduce(z: i64) -> Mersenne31 {
        debug_assert!(z > -(1i64 << 49));
        debug_assert!(z < (1i64 << 49));

        const MASK: i64 = (1 << 31) - 1;
        let low_bits = unsafe {
            // SAFETY: `z & MASK` lies in [0, 2^31 - 1], so it fits in 31 bits.
            // NOTE(review): the value can equal 2^31 - 1 itself; this relies
            // on Mersenne31 tolerating that representative. Confirm against
            // the `from_canonical_unchecked` contract.
            Mersenne31::from_canonical_unchecked((z & MASK) as u32)
        };

        // Low 31 bits of hi = z >> 31. When z < 0, hi is negative, so this is
        // 2^31 + hi, and sign_bits = z >> 62 = -1. The sum below is then
        // (2^31 - 1) + hi ≡ hi. When z >= 0, sign_bits = 0 and high_bits = hi.
        let high_bits = ((z >> 31) & MASK) as i32;
        let sign_bits = (z >> 62) as i32;

        let high = unsafe {
            // SAFETY: with |z| < 2^49, hi lies in [-2^18, 2^18). For z >= 0
            // the sum is hi < 2^18. For z < 0 the sum is 2^31 - 1 + hi, which
            // lies in [2^31 - 1 - 2^18, 2^31 - 2]. Either way it is
            // non-negative and below 2^31 - 1.
            Mersenne31::from_canonical_unchecked((high_bits + sign_bits) as u32)
        };
        low_bits + high
    }
}
144
145const MATRIX_CIRC_MDS_8_SML_ROW: [i64; 8] = [7, 1, 3, 8, 8, 3, 4, 9];
146
147impl Permutation<[Mersenne31; 8]> for MdsMatrixMersenne31 {
148 fn permute(&self, input: [Mersenne31; 8]) -> [Mersenne31; 8] {
149 const MATRIX_CIRC_MDS_8_SML_COL: [i64; 8] =
150 first_row_to_first_col(&MATRIX_CIRC_MDS_8_SML_ROW);
151 SmallConvolveMersenne31::apply(
152 input,
153 MATRIX_CIRC_MDS_8_SML_COL,
154 SmallConvolveMersenne31::conv8,
155 )
156 }
157
158 fn permute_mut(&self, input: &mut [Mersenne31; 8]) {
159 *input = self.permute(*input);
160 }
161}
162impl MdsPermutation<Mersenne31, 8> for MdsMatrixMersenne31 {}
163
164const MATRIX_CIRC_MDS_12_SML_ROW: [i64; 12] = [1, 1, 2, 1, 8, 9, 10, 7, 5, 9, 4, 10];
165
166impl Permutation<[Mersenne31; 12]> for MdsMatrixMersenne31 {
167 fn permute(&self, input: [Mersenne31; 12]) -> [Mersenne31; 12] {
168 const MATRIX_CIRC_MDS_12_SML_COL: [i64; 12] =
169 first_row_to_first_col(&MATRIX_CIRC_MDS_12_SML_ROW);
170 SmallConvolveMersenne31::apply(
171 input,
172 MATRIX_CIRC_MDS_12_SML_COL,
173 SmallConvolveMersenne31::conv12,
174 )
175 }
176
177 fn permute_mut(&self, input: &mut [Mersenne31; 12]) {
178 *input = self.permute(*input);
179 }
180}
181impl MdsPermutation<Mersenne31, 12> for MdsMatrixMersenne31 {}
182
183const MATRIX_CIRC_MDS_16_SML_ROW: [i64; 16] =
184 [1, 1, 51, 1, 11, 17, 2, 1, 101, 63, 15, 2, 67, 22, 13, 3];
185
186impl Permutation<[Mersenne31; 16]> for MdsMatrixMersenne31 {
187 fn permute(&self, input: [Mersenne31; 16]) -> [Mersenne31; 16] {
188 const MATRIX_CIRC_MDS_16_SML_COL: [i64; 16] =
189 first_row_to_first_col(&MATRIX_CIRC_MDS_16_SML_ROW);
190 SmallConvolveMersenne31::apply(
191 input,
192 MATRIX_CIRC_MDS_16_SML_COL,
193 SmallConvolveMersenne31::conv16,
194 )
195 }
196
197 fn permute_mut(&self, input: &mut [Mersenne31; 16]) {
198 *input = self.permute(*input);
199 }
200}
201impl MdsPermutation<Mersenne31, 16> for MdsMatrixMersenne31 {}
202
203#[rustfmt::skip]
204const MATRIX_CIRC_MDS_32_MERSENNE31_ROW: [i64; 32] = [
205 0x1896DC78, 0x559D1E29, 0x04EBD732, 0x3FF449D7,
206 0x2DB0E2CE, 0x26776B85, 0x76018E57, 0x1025FA13,
207 0x06486BAB, 0x37706EBA, 0x25EB966B, 0x113C24E5,
208 0x2AE20EC4, 0x5A27507C, 0x0CD38CF1, 0x761C10E5,
209 0x19E3EF1A, 0x032C730F, 0x35D8AF83, 0x651DF13B,
210 0x7EC3DB1A, 0x6A146994, 0x588F9145, 0x09B79455,
211 0x7FDA05EC, 0x19FE71A8, 0x6988947A, 0x624F1D31,
212 0x500BB628, 0x0B1428CE, 0x3A62E1D6, 0x77692387
213];
214
215impl Permutation<[Mersenne31; 32]> for MdsMatrixMersenne31 {
216 fn permute(&self, input: [Mersenne31; 32]) -> [Mersenne31; 32] {
217 const MATRIX_CIRC_MDS_32_MERSENNE31_COL: [i64; 32] =
218 first_row_to_first_col(&MATRIX_CIRC_MDS_32_MERSENNE31_ROW);
219 LargeConvolveMersenne31::apply(
220 input,
221 MATRIX_CIRC_MDS_32_MERSENNE31_COL,
222 LargeConvolveMersenne31::conv32,
223 )
224 }
225
226 fn permute_mut(&self, input: &mut [Mersenne31; 32]) {
227 *input = self.permute(*input);
228 }
229}
230impl MdsPermutation<Mersenne31, 32> for MdsMatrixMersenne31 {}
231
232#[rustfmt::skip]
233const MATRIX_CIRC_MDS_64_MERSENNE31_ROW: [i64; 64] = [
234 0x570227A5, 0x3702983F, 0x4B7B3B0A, 0x74F13DE3,
235 0x485314B0, 0x0157E2EC, 0x1AD2E5DE, 0x721515E3,
236 0x5452ADA3, 0x0C74B6C1, 0x67DA9450, 0x33A48369,
237 0x3BDBEE06, 0x7C678D5E, 0x160F16D3, 0x54888B8C,
238 0x666C7AA6, 0x113B89E2, 0x2A403CE2, 0x18F9DF42,
239 0x2A685E84, 0x49EEFDE5, 0x5D044806, 0x560A41F8,
240 0x69EF1BD0, 0x2CD15786, 0x62E07766, 0x22A231E2,
241 0x3CFCF40C, 0x4E8F63D8, 0x69657A15, 0x466B4B2D,
242 0x4194B4D2, 0x1E9A85EA, 0x39709C27, 0x4B030BF3,
243 0x655DCE1D, 0x251F8899, 0x5B2EA879, 0x1E10E42F,
244 0x31F5BE07, 0x2AFBB7F9, 0x3E11021A, 0x5D97A17B,
245 0x6F0620BD, 0x5DBFC31D, 0x76C4761D, 0x21938559,
246 0x33777473, 0x71F0E92C, 0x0B9872A1, 0x4C2411F9,
247 0x545B7C96, 0x20256BAF, 0x7B8B493E, 0x33AD525C,
248 0x15EAEA1C, 0x6D2D1A21, 0x06A81D14, 0x3FACEB4F,
249 0x130EC21C, 0x3C84C4F5, 0x50FD67C0, 0x30FDD85A,
250];
251
252impl Permutation<[Mersenne31; 64]> for MdsMatrixMersenne31 {
253 fn permute(&self, input: [Mersenne31; 64]) -> [Mersenne31; 64] {
254 const MATRIX_CIRC_MDS_64_MERSENNE31_COL: [i64; 64] =
255 first_row_to_first_col(&MATRIX_CIRC_MDS_64_MERSENNE31_ROW);
256 LargeConvolveMersenne31::apply(
257 input,
258 MATRIX_CIRC_MDS_64_MERSENNE31_COL,
259 LargeConvolveMersenne31::conv64,
260 )
261 }
262
263 fn permute_mut(&self, input: &mut [Mersenne31; 64]) {
264 *input = self.permute(*input);
265 }
266}
267impl MdsPermutation<Mersenne31, 64> for MdsMatrixMersenne31 {}
268
#[cfg(test)]
mod tests {
    use p3_symmetric::Permutation;

    use super::{MdsMatrixMersenne31, Mersenne31};

    /// Runs the width-`N` MDS permutation on `input` (raw `u32` values) and
    /// asserts that the output equals `expected` (also raw `u32` values).
    fn check_permutation<const N: usize>(input: [u32; N], expected: [u32; N])
    where
        MdsMatrixMersenne31: Permutation<[Mersenne31; N]>,
    {
        let output = MdsMatrixMersenne31.permute(Mersenne31::new_array(input));
        assert_eq!(output, Mersenne31::new_array(expected));
    }

    #[test]
    fn mersenne8() {
        check_permutation(
            [
                1741044457, 327154658, 318297696, 1528828225, 468360260, 1271368222, 1906288587,
                1521884224,
            ],
            [
                895992680, 1343855369, 2107796831, 266468728, 846686506, 252887121, 205223309,
                260248790,
            ],
        );
    }

    #[test]
    fn mersenne12() {
        check_permutation(
            [
                1232740094, 661555540, 11024822, 1620264994, 471137070, 276755041, 1316882747,
                1023679816, 1675266989, 743211887, 44774582, 1990989306,
            ],
            [
                860812289, 399778981, 1228500858, 798196553, 673507779, 1116345060, 829764188,
                138346433, 578243475, 553581995, 578183208, 1527769050,
            ],
        );
    }

    #[test]
    fn mersenne16() {
        check_permutation(
            [
                1431168444, 963811518, 88067321, 381314132, 908628282, 1260098295, 980207659,
                150070493, 357706876, 2014609375, 387876458, 1621671571, 183146044, 107201572,
                166536524, 2078440788,
            ],
            [
                1858869691, 1607793806, 1200396641, 1400502985, 1511630695, 187938132, 1332411488,
                2041577083, 2014246632, 802022141, 796807132, 1647212930, 813167618, 1867105010,
                508596277, 1457551581,
            ],
        );
    }

    #[test]
    fn mersenne32() {
        check_permutation(
            [
                873912014, 1112497426, 300405095, 4255553, 1234979949, 156402357, 1952135954,
                718195399, 1041748465, 683604342, 184275751, 1184118518, 214257054, 1293941921,
                64085758, 710448062, 1133100009, 350114887, 1091675272, 671421879, 1226105999,
                546430131, 1298443967, 1787169653, 2129310791, 1560307302, 471771931, 1191484402,
                1550203198, 1541319048, 229197040, 839673789,
            ],
            [
                1439049928, 890642852, 694402307, 713403244, 553213342, 1049445650, 321709533,
                1195683415, 2118492257, 623077773, 96734062, 990488164, 1674607608, 749155000,
                353377854, 966432998, 1114654884, 1370359248, 1624965859, 685087760, 1631836645,
                1615931812, 2061986317, 1773551151, 1449911206, 1951762557, 545742785, 582866449,
                1379774336, 229242759, 1871227547, 752848413,
            ],
        );
    }

    #[test]
    fn mersenne64() {
        check_permutation(
            [
                837269696, 1509031194, 413915480, 1889329185, 315502822, 1529162228, 1454661012,
                1015826742, 973381409, 1414676304, 1449029961, 1968715566, 2027226497, 1721820509,
                434042616, 1436005045, 1680352863, 651591867, 260585272, 1078022153, 703990572,
                269504423, 1776357592, 1174979337, 1142666094, 1897872960, 1387995838, 250774418,
                776134750, 73930096, 194742451, 1860060380, 666407744, 669566398, 963802147,
                2063418105, 1772573581, 998923482, 701912753, 1716548204, 860820931, 1680395948,
                949886256, 1811558161, 501734557, 1671977429, 463135040, 1911493108, 207754409,
                608714758, 1553060084, 1558941605, 980281686, 2014426559, 650527801, 53015148,
                1521176057, 720530872, 713593252, 88228433, 1194162313, 1922416934, 1075145779,
                344403794,
            ],
            [
                1599981950, 252630853, 1171557270, 116468420, 1269245345, 666203050, 46155642,
                1701131520, 530845775, 508460407, 630407239, 1731628135, 1199144768, 295132047,
                77536342, 1472377703, 30752443, 1300339617, 18647556, 1267774380, 1194573079,
                1624665024, 646848056, 1667216490, 1184843555, 1250329476, 254171597, 1902035936,
                1706882202, 964921003, 952266538, 1215696284, 539510504, 1056507562, 1393151480,
                733644883, 1663330816, 1100715048, 991108703, 1671345065, 1376431774, 408310416,
                313176996, 743567676, 304660642, 1842695838, 958201635, 1650792218, 541570244,
                968523062, 1958918704, 1866282698, 849808680, 1193306222, 794153281, 822835360,
                135282913, 1149868448, 2068162123, 1474283743, 2039088058, 720305835, 746036736,
                671006610,
            ],
        );
    }
}