1use p3_field::PrimeCharacteristicRing;
10use p3_field::integers::QuotientMap;
11use p3_mds::MdsPermutation;
12use p3_mds::karatsuba_convolution::Convolve;
13use p3_mds::util::{dot_product, first_row_to_first_col};
14use p3_symmetric::Permutation;
15
16use crate::Mersenne31;
17
/// Circulant MDS matrices over [`Mersenne31`].
///
/// This zero-sized type implements `Permutation` / `MdsPermutation` for state
/// widths 8, 12, 16, 32 and 64; each width uses its own fixed first row.
#[derive(Default, Debug, Clone)]
pub struct MdsMatrixMersenne31;
20
21struct SmallConvolveMersenne31;
26impl Convolve<Mersenne31, i64, i64, i64> for SmallConvolveMersenne31 {
27 #[inline(always)]
31 fn read(input: Mersenne31) -> i64 {
32 input.value as i64
33 }
34
35 #[inline(always)]
40 fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
41 dot_product(u, v)
42 }
43
44 #[inline(always)]
51 fn reduce(z: i64) -> Mersenne31 {
52 debug_assert!(z >= 0);
53 Mersenne31::from_u64(z as u64)
54 }
55}
56
57struct LargeConvolveMersenne31;
62impl Convolve<Mersenne31, i64, i64, i64> for LargeConvolveMersenne31 {
63 #[inline(always)]
67 fn read(input: Mersenne31) -> i64 {
68 input.value as i64
69 }
70
71 #[inline]
72 fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
73 let mut dp = 0i128;
78 for i in 0..N {
79 dp += u[i] as i128 * v[i] as i128;
80 }
81
82 const LOWMASK: i128 = (1 << 42) - 1; const HIGHMASK: i128 = !LOWMASK; let low_bits = (dp & LOWMASK) as i64; let high_bits = ((dp & HIGHMASK) >> 31) as i64; low_bits + high_bits
101 }
102
103 #[inline]
104 fn reduce(z: i64) -> Mersenne31 {
105 debug_assert!(z > -(1i64 << 49));
124 debug_assert!(z < (1i64 << 49));
125
126 const MASK: i64 = (1 << 31) - 1;
127 let low_bits = unsafe {
130 Mersenne31::from_canonical_unchecked((z & MASK) as u32)
132 };
133
134 let high_bits = ((z >> 31) & MASK) as i32;
135 let sign_bits = (z >> 62) as i32;
136
137 let high = unsafe {
138 Mersenne31::from_canonical_unchecked((high_bits + sign_bits) as u32)
140 };
141 low_bits + high
142 }
143}
144
145const MATRIX_CIRC_MDS_8_SML_ROW: [i64; 8] = [7, 1, 3, 8, 8, 3, 4, 9];
146
147impl Permutation<[Mersenne31; 8]> for MdsMatrixMersenne31 {
148 fn permute(&self, input: [Mersenne31; 8]) -> [Mersenne31; 8] {
149 const MATRIX_CIRC_MDS_8_SML_COL: [i64; 8] =
150 first_row_to_first_col(&MATRIX_CIRC_MDS_8_SML_ROW);
151 SmallConvolveMersenne31::apply(
152 input,
153 MATRIX_CIRC_MDS_8_SML_COL,
154 SmallConvolveMersenne31::conv8,
155 )
156 }
157}
158impl MdsPermutation<Mersenne31, 8> for MdsMatrixMersenne31 {}
159
160const MATRIX_CIRC_MDS_12_SML_ROW: [i64; 12] = [1, 1, 2, 1, 8, 9, 10, 7, 5, 9, 4, 10];
161
162impl Permutation<[Mersenne31; 12]> for MdsMatrixMersenne31 {
163 fn permute(&self, input: [Mersenne31; 12]) -> [Mersenne31; 12] {
164 const MATRIX_CIRC_MDS_12_SML_COL: [i64; 12] =
165 first_row_to_first_col(&MATRIX_CIRC_MDS_12_SML_ROW);
166 SmallConvolveMersenne31::apply(
167 input,
168 MATRIX_CIRC_MDS_12_SML_COL,
169 SmallConvolveMersenne31::conv12,
170 )
171 }
172}
173impl MdsPermutation<Mersenne31, 12> for MdsMatrixMersenne31 {}
174
175const MATRIX_CIRC_MDS_16_SML_ROW: [i64; 16] =
176 [1, 1, 51, 1, 11, 17, 2, 1, 101, 63, 15, 2, 67, 22, 13, 3];
177
178impl Permutation<[Mersenne31; 16]> for MdsMatrixMersenne31 {
179 fn permute(&self, input: [Mersenne31; 16]) -> [Mersenne31; 16] {
180 const MATRIX_CIRC_MDS_16_SML_COL: [i64; 16] =
181 first_row_to_first_col(&MATRIX_CIRC_MDS_16_SML_ROW);
182 SmallConvolveMersenne31::apply(
183 input,
184 MATRIX_CIRC_MDS_16_SML_COL,
185 SmallConvolveMersenne31::conv16,
186 )
187 }
188}
189impl MdsPermutation<Mersenne31, 16> for MdsMatrixMersenne31 {}
190
191#[rustfmt::skip]
192const MATRIX_CIRC_MDS_32_MERSENNE31_ROW: [i64; 32] = [
193 0x1896DC78, 0x559D1E29, 0x04EBD732, 0x3FF449D7,
194 0x2DB0E2CE, 0x26776B85, 0x76018E57, 0x1025FA13,
195 0x06486BAB, 0x37706EBA, 0x25EB966B, 0x113C24E5,
196 0x2AE20EC4, 0x5A27507C, 0x0CD38CF1, 0x761C10E5,
197 0x19E3EF1A, 0x032C730F, 0x35D8AF83, 0x651DF13B,
198 0x7EC3DB1A, 0x6A146994, 0x588F9145, 0x09B79455,
199 0x7FDA05EC, 0x19FE71A8, 0x6988947A, 0x624F1D31,
200 0x500BB628, 0x0B1428CE, 0x3A62E1D6, 0x77692387
201];
202
203impl Permutation<[Mersenne31; 32]> for MdsMatrixMersenne31 {
204 fn permute(&self, input: [Mersenne31; 32]) -> [Mersenne31; 32] {
205 const MATRIX_CIRC_MDS_32_MERSENNE31_COL: [i64; 32] =
206 first_row_to_first_col(&MATRIX_CIRC_MDS_32_MERSENNE31_ROW);
207 LargeConvolveMersenne31::apply(
208 input,
209 MATRIX_CIRC_MDS_32_MERSENNE31_COL,
210 LargeConvolveMersenne31::conv32,
211 )
212 }
213}
214impl MdsPermutation<Mersenne31, 32> for MdsMatrixMersenne31 {}
215
216#[rustfmt::skip]
217const MATRIX_CIRC_MDS_64_MERSENNE31_ROW: [i64; 64] = [
218 0x570227A5, 0x3702983F, 0x4B7B3B0A, 0x74F13DE3,
219 0x485314B0, 0x0157E2EC, 0x1AD2E5DE, 0x721515E3,
220 0x5452ADA3, 0x0C74B6C1, 0x67DA9450, 0x33A48369,
221 0x3BDBEE06, 0x7C678D5E, 0x160F16D3, 0x54888B8C,
222 0x666C7AA6, 0x113B89E2, 0x2A403CE2, 0x18F9DF42,
223 0x2A685E84, 0x49EEFDE5, 0x5D044806, 0x560A41F8,
224 0x69EF1BD0, 0x2CD15786, 0x62E07766, 0x22A231E2,
225 0x3CFCF40C, 0x4E8F63D8, 0x69657A15, 0x466B4B2D,
226 0x4194B4D2, 0x1E9A85EA, 0x39709C27, 0x4B030BF3,
227 0x655DCE1D, 0x251F8899, 0x5B2EA879, 0x1E10E42F,
228 0x31F5BE07, 0x2AFBB7F9, 0x3E11021A, 0x5D97A17B,
229 0x6F0620BD, 0x5DBFC31D, 0x76C4761D, 0x21938559,
230 0x33777473, 0x71F0E92C, 0x0B9872A1, 0x4C2411F9,
231 0x545B7C96, 0x20256BAF, 0x7B8B493E, 0x33AD525C,
232 0x15EAEA1C, 0x6D2D1A21, 0x06A81D14, 0x3FACEB4F,
233 0x130EC21C, 0x3C84C4F5, 0x50FD67C0, 0x30FDD85A,
234];
235
236impl Permutation<[Mersenne31; 64]> for MdsMatrixMersenne31 {
237 fn permute(&self, input: [Mersenne31; 64]) -> [Mersenne31; 64] {
238 const MATRIX_CIRC_MDS_64_MERSENNE31_COL: [i64; 64] =
239 first_row_to_first_col(&MATRIX_CIRC_MDS_64_MERSENNE31_ROW);
240 LargeConvolveMersenne31::apply(
241 input,
242 MATRIX_CIRC_MDS_64_MERSENNE31_COL,
243 LargeConvolveMersenne31::conv64,
244 )
245 }
246}
247impl MdsPermutation<Mersenne31, 64> for MdsMatrixMersenne31 {}
248
#[cfg(test)]
mod tests {
    use p3_symmetric::Permutation;

    use super::{MdsMatrixMersenne31, Mersenne31};

    /// Applies the width-`N` MDS matrix to `input` and checks the result against
    /// the known-answer vector `expected` (both given as raw `u32` values).
    fn check_mds<const N: usize>(input: [u32; N], expected: [u32; N])
    where
        MdsMatrixMersenne31: Permutation<[Mersenne31; N]>,
    {
        let state: [Mersenne31; N] = Mersenne31::new_array(input);
        let want: [Mersenne31; N] = Mersenne31::new_array(expected);
        assert_eq!(MdsMatrixMersenne31.permute(state), want);
    }

    #[test]
    fn mersenne8() {
        check_mds(
            [
                1741044457, 327154658, 318297696, 1528828225, 468360260, 1271368222, 1906288587,
                1521884224,
            ],
            [
                895992680, 1343855369, 2107796831, 266468728, 846686506, 252887121, 205223309,
                260248790,
            ],
        );
    }

    #[test]
    fn mersenne12() {
        check_mds(
            [
                1232740094, 661555540, 11024822, 1620264994, 471137070, 276755041, 1316882747,
                1023679816, 1675266989, 743211887, 44774582, 1990989306,
            ],
            [
                860812289, 399778981, 1228500858, 798196553, 673507779, 1116345060, 829764188,
                138346433, 578243475, 553581995, 578183208, 1527769050,
            ],
        );
    }

    #[test]
    fn mersenne16() {
        check_mds(
            [
                1431168444, 963811518, 88067321, 381314132, 908628282, 1260098295, 980207659,
                150070493, 357706876, 2014609375, 387876458, 1621671571, 183146044, 107201572,
                166536524, 2078440788,
            ],
            [
                1858869691, 1607793806, 1200396641, 1400502985, 1511630695, 187938132, 1332411488,
                2041577083, 2014246632, 802022141, 796807132, 1647212930, 813167618, 1867105010,
                508596277, 1457551581,
            ],
        );
    }

    #[test]
    fn mersenne32() {
        check_mds(
            [
                873912014, 1112497426, 300405095, 4255553, 1234979949, 156402357, 1952135954,
                718195399, 1041748465, 683604342, 184275751, 1184118518, 214257054, 1293941921,
                64085758, 710448062, 1133100009, 350114887, 1091675272, 671421879, 1226105999,
                546430131, 1298443967, 1787169653, 2129310791, 1560307302, 471771931, 1191484402,
                1550203198, 1541319048, 229197040, 839673789,
            ],
            [
                1439049928, 890642852, 694402307, 713403244, 553213342, 1049445650, 321709533,
                1195683415, 2118492257, 623077773, 96734062, 990488164, 1674607608, 749155000,
                353377854, 966432998, 1114654884, 1370359248, 1624965859, 685087760, 1631836645,
                1615931812, 2061986317, 1773551151, 1449911206, 1951762557, 545742785, 582866449,
                1379774336, 229242759, 1871227547, 752848413,
            ],
        );
    }

    #[test]
    fn mersenne64() {
        check_mds(
            [
                837269696, 1509031194, 413915480, 1889329185, 315502822, 1529162228, 1454661012,
                1015826742, 973381409, 1414676304, 1449029961, 1968715566, 2027226497, 1721820509,
                434042616, 1436005045, 1680352863, 651591867, 260585272, 1078022153, 703990572,
                269504423, 1776357592, 1174979337, 1142666094, 1897872960, 1387995838, 250774418,
                776134750, 73930096, 194742451, 1860060380, 666407744, 669566398, 963802147,
                2063418105, 1772573581, 998923482, 701912753, 1716548204, 860820931, 1680395948,
                949886256, 1811558161, 501734557, 1671977429, 463135040, 1911493108, 207754409,
                608714758, 1553060084, 1558941605, 980281686, 2014426559, 650527801, 53015148,
                1521176057, 720530872, 713593252, 88228433, 1194162313, 1922416934, 1075145779,
                344403794,
            ],
            [
                1599981950, 252630853, 1171557270, 116468420, 1269245345, 666203050, 46155642,
                1701131520, 530845775, 508460407, 630407239, 1731628135, 1199144768, 295132047,
                77536342, 1472377703, 30752443, 1300339617, 18647556, 1267774380, 1194573079,
                1624665024, 646848056, 1667216490, 1184843555, 1250329476, 254171597, 1902035936,
                1706882202, 964921003, 952266538, 1215696284, 539510504, 1056507562, 1393151480,
                733644883, 1663330816, 1100715048, 991108703, 1671345065, 1376431774, 408310416,
                313176996, 743567676, 304660642, 1842695838, 958201635, 1650792218, 541570244,
                968523062, 1958918704, 1866282698, 849808680, 1193306222, 794153281, 822835360,
                135282913, 1149868448, 2068162123, 1474283743, 2039088058, 720305835, 746036736,
                671006610,
            ],
        );
    }
}