p3_monty_31/mds.rs
1use core::marker::PhantomData;
2
3use p3_mds::MdsPermutation;
4use p3_mds::karatsuba_convolution::Convolve;
5use p3_mds::util::dot_product;
6use p3_symmetric::Permutation;
7
8use crate::{BarrettParameters, MontyField31, MontyParameters};
9
/// A collection of circulant MDS matrices, one per supported width.
///
/// A circulant matrix is fully determined by a single column, so each matrix
/// is stored as its left-most column only. The entries are plain `i64`
/// integers.
pub trait MDSUtils: Clone + Sync {
    /// Left-most column of the 8 x 8 circulant MDS matrix.
    const MATRIX_CIRC_MDS_8_COL: [i64; 8];
    /// Left-most column of the 12 x 12 circulant MDS matrix.
    const MATRIX_CIRC_MDS_12_COL: [i64; 12];
    /// Left-most column of the 16 x 16 circulant MDS matrix.
    const MATRIX_CIRC_MDS_16_COL: [i64; 16];
    /// Left-most column of the 24 x 24 circulant MDS matrix.
    const MATRIX_CIRC_MDS_24_COL: [i64; 24];
    /// Left-most column of the 32 x 32 circulant MDS matrix.
    const MATRIX_CIRC_MDS_32_COL: [i64; 32];
    /// Left-most column of the 64 x 64 circulant MDS matrix.
    const MATRIX_CIRC_MDS_64_COL: [i64; 64];
}
19
20#[derive(Clone, Debug, Default)]
21pub struct MdsMatrixMontyField31<MU: MDSUtils> {
22 _phantom: PhantomData<MU>,
23}
24
/// Convolution strategy for "small" RHS vectors over a 31-bit Monty field.
///
/// "Small" means that the RHS has length N <= 16 and that the sum of its
/// entries is (roughly) below 2^24; in practice the sum stays below 2^9.
/// Under these bounds every intermediate value fits in an `i64`.
struct SmallConvolveMontyField31;
30
31impl<FP: MontyParameters> Convolve<MontyField31<FP>, i64, i64> for SmallConvolveMontyField31 {
32 const T_ZERO: i64 = 0;
33 const U_ZERO: i64 = 0;
34
35 #[inline(always)]
36 fn halve(val: i64) -> i64 {
37 val >> 1
38 }
39
40 /// Return the lift of a Monty31 element, satisfying 0 <=
41 /// input.value < P < 2^31. Note that Monty31 elements are
42 /// represented in Monty form.
43 #[inline(always)]
44 fn read(input: MontyField31<FP>) -> i64 {
45 input.value as i64
46 }
47
48 /// For a convolution of size N, |x| < N * 2^31 and (as per the
49 /// assumption above), |y| < 2^24. So the product is at most N * 2^55
50 /// which will not overflow for N <= 16.
51 ///
52 /// Note that the LHS element is in Monty form, while the RHS
53 /// element is a "plain integer". This informs the implementation
54 /// of `reduce()` below.
55 #[inline(always)]
56 fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
57 dot_product(u, v)
58 }
59
60 /// The assumptions above mean z < N^2 * 2^55, which is at most
61 /// 2^63 when N <= 16.
62 ///
63 /// Because the LHS elements were in Monty form and the RHS
64 /// elements were plain integers, reduction is simply the usual
65 /// reduction modulo P, rather than "Monty reduction".
66 ///
67 /// NB: Even though intermediate values could be negative, the
68 /// output must be non-negative since the inputs were
69 /// non-negative.
70 #[inline(always)]
71 fn reduce(z: i64) -> MontyField31<FP> {
72 debug_assert!(z >= 0);
73
74 MontyField31::new_monty((z as u64 % FP::PRIME as u64) as u32)
75 }
76}
77
78/// Given |x| < 2^80 compute x' such that:
79/// |x'| < 2**50
80/// x' = x mod p
81/// x' = x mod 2^10
82/// See Thm 1 (Below function) for a proof that this function is correct.
83#[inline(always)]
84const fn barrett_red_monty31<BP: BarrettParameters>(input: i128) -> i64 {
85 // input = input_low + beta*input_high
86 // So input_high < 2**63 and fits in an i64.
87 let input_high = (input >> BP::N) as i64; // input_high < input / beta < 2**{80 - N}
88
89 // I, input_high are i64's so this multiplication can't overflow.
90 let quot = (((input_high as i128) * (BP::PSEUDO_INV as i128)) >> BP::N) as i64;
91
92 // Replace quot by a close value which is divisible by 2^10.
93 let quot_2adic = quot & BP::MASK;
94
95 // quot_2adic, P are i64's so this can't overflow.
96 // sub is by construction divisible by both P and 2^10.
97 let sub = (quot_2adic as i128) * BP::PRIME_I128;
98
99 (input - sub) as i64
100}
101
102// Theorem 1:
// Given |x| < 2^80, barrett_red_monty31(x) computes an x' such that:
104// x' = x mod p
105// x' = x mod 2^10
106// |x'| < 2**50.
107///////////////////////////////////////////////////////////////////////////////////////
108// PROOF:
109// By construction P, 2**10 | sub and so we immediately see that
110// x' = x mod p
111// x' = x mod 2^10.
112//
113// It remains to prove that |x'| < 2**50.
114//
115// We start by introducing some simple inequalities and relations between our variables:
116//
117// First consider the relationship between bit-shift and division.
118// It's easy to check that for all x:
119// 1: (x >> N) <= x / 2**N <= 1 + (x >> N)
120//
121// Similarly, as our mask just 0's the last 10 bits,
122// 2: x + 1 - 2^10 <= x & mask <= x
123//
124// Now if x, y are positive integers then
125// (x / y) - 1 <= x // y <= x / y
126// Where // denotes integer division.
127//
128// From this last inequality we immediately derive:
129// 3: (2**{2N} / P) - 1 <= I <= (2**{2N} / P)
130// 3a: 2**{2N} - P <= PI
131//
132// Finally, note that by definition:
133// input = input_high*(2**N) + input_low
134// Hence a simple rearrangement gets us
135// 4: input_high*(2**N) = input - input_low
136//
137//
138// We now need to split into cases depending on the sign of input.
139// Note that if x = 0 then x' = 0 so that case is trivial.
140///////////////////////////////////////////////////////////////////////////
141// CASE 1: input > 0
142//
143// If input > 0 then:
144// sub = Q*P = ((((input >> N) * I) >> N) & mask) * P <= P * (input / 2**{N}) * (2**{2N} / P) / 2**{N} = input
145// So input - sub >= 0.
146//
147// We need to improve our bound on Q. Observe that:
148// Q = (((input_high * I) >> N) & mask)
// --(2) => Q + (2^10 - 1) >= ((input_high * I) >> N)
// --(1) => Q + 2^10 >= (I*input_high)/(2**N)
//       => (2**N)*Q + 2^10*(2**N) >= I*input_high
152//
153// Hence we find that:
154// (2**N)*Q*P + 2^10*(2**N)*P >= input_high*I*P
155// --(3a) >= input_high*2**{2N} - P*input_high
156// --(4) >= (2**N)*input - (2**N)*input_low - (2**N)*input_high (Assuming P < 2**N)
157//
158// Dividing by 2**N we get
159// Q*P + 2^{10}*P >= input - input_low - input_high
160// which rearranges to
161// x' = input - Q*P <= 2^{10}*P + input_low + input_high
162//
// Picking N = 40 we see that 2^{10}*P < 2**41 (as P < 2**31), input_low < 2**N = 2**40
// and input_high < 2**{80 - N} = 2**40.
// Hence x' < 2**41 + 2**40 + 2**40 = 2**42 < 2**50 as desired.
165//
166//
167//
168///////////////////////////////////////////////////////////////////////////
169// CASE 2: input < 0
170//
171// This case will be similar but all our inequalities will change slightly as negatives complicate things.
172// First observe that:
173// (input >> N) * I >= (input >> N) * 2**(2N) / P
174// >= (1 + (input / 2**N)) * 2**(2N) / P
175// >= (2**N + input) * 2**N / P
176//
177// Thus:
178// Q = ((input >> N) * I) >> N >= ((2**N + input) * 2**N / P) >> N
179// >= ((2**N + input) / P) - 1
180//
181// And so sub = Q*P >= 2**N - P + input.
182// Hence input - sub < 2**N - P.
183//
184// Thus if input - sub > 0 then |input - sub| < 2**50.
185// Thus we are left with bounding -(input - sub) = (sub - input).
186// Again we will proceed by improving our bound on Q.
187//
188// Q = (((input_high * I) >> N) & mask)
// --(2) => Q <= ((input_high * I) >> N)
// --(1) => Q <= (I*input_high)/(2**N)
//       => (2**N)*Q <= I*input_high
192//
193// Hence we find that:
194// (2**N)*Q*P <= input_high*I*P
195// --(3a) <= input_high*2**{2N} - P*input_high
196// --(4) <= (2**N)*input - (2**N)*input_low - (2**N)*input_high (Assuming P < 2**N)
197//
198// Dividing by 2**N we get
199// Q*P <= input - input_low - input_high
200// which rearranges to
201// -x' = -input + Q*P <= -input_high - input_low < 2**50
202//
203// This completes the proof.
204
/// Convolution strategy for "large" RHS vectors over a 31-bit Monty field.
///
/// "Large" means that the RHS entries may be as big as the field
/// characteristic, and that the RHS has length N <= 64.
#[derive(Clone, Debug, Default)]
struct LargeConvolveMontyField31;
211
212impl<FP> Convolve<MontyField31<FP>, i64, i64> for LargeConvolveMontyField31
213where
214 FP: BarrettParameters,
215{
216 const T_ZERO: i64 = 0;
217 const U_ZERO: i64 = 0;
218
219 #[inline(always)]
220 fn halve(val: i64) -> i64 {
221 val >> 1
222 }
223
224 /// Return the lift of a MontyField31 element, satisfying
225 /// 0 <= input.value < P < 2^31.
226 /// Note that MontyField31 elements are represented in Monty form.
227 #[inline(always)]
228 fn read(input: MontyField31<FP>) -> i64 {
229 input.value as i64
230 }
231
232 #[inline(always)]
233 fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
234 // For a convolution of size N, |x|, |y| < N * 2^31, so the
235 // product could be as much as N^2 * 2^62. This will overflow an
236 // i64, so we first widen to i128. Note that N^2 * 2^62 < 2^80
237 // for N <= 64, as required by `barrett_red_monty31()`.
238
239 let mut dp = 0i128;
240 for i in 0..N {
241 dp += u[i] as i128 * v[i] as i128;
242 }
243 barrett_red_monty31::<FP>(dp)
244 }
245
246 #[inline(always)]
247 fn reduce(z: i64) -> MontyField31<FP> {
248 // After the barrett reduction method, the output z of parity
249 // dot satisfies |z| < 2^50 (See Thm 1 above).
250 //
251 // In the recombining steps, conv_n maps (wo, w1) ->
252 // ((wo + w1)/2, (wo + w1)/2) which has no effect on the maximal
253 // size. (Indeed, it makes sizes almost strictly smaller).
254 //
255 // On the other hand, negacyclic_conv_n (ignoring the re-index)
256 // recombines as: (w0, w1, w2) -> (w0 + w1, w2 - w0 - w1).
257 // Hence if the input is <= K, the output is <= 3K.
258 //
259 // Thus the values appearing at the end are bounded by 3^n 2^50
260 // where n is the maximal number of negacyclic_conv
261 // recombination steps. When N = 64, we need to recombine for
262 // signed_conv_32, signed_conv_16, signed_conv_8 so the
263 // overall bound will be 3^3 2^50 < 32 * 2^50 < 2^55.
264 debug_assert!(z > -(1i64 << 55));
265 debug_assert!(z < (1i64 << 55));
266
267 // Note we do NOT move it into MONTY form. We assume it is already
268 // in this form.
269 let red = (z % (FP::PRIME as i64)) as u32;
270
271 // If z >= 0: 0 <= red < P is the correct value and P + red will
272 // not overflow.
273 // If z < 0: -P < red < 0 and the value we want is P + red.
274 // On bits, + acts identically for i32 and u32. Hence we can use
275 // u32's and just check for overflow.
276
277 let (corr, over) = red.overflowing_add(FP::PRIME);
278 let value = if over { corr } else { red };
279 MontyField31::new_monty(value)
280 }
281}
282
283impl<FP: MontyParameters, MU: MDSUtils> Permutation<[MontyField31<FP>; 8]>
284 for MdsMatrixMontyField31<MU>
285{
286 fn permute(&self, input: [MontyField31<FP>; 8]) -> [MontyField31<FP>; 8] {
287 SmallConvolveMontyField31::apply(
288 input,
289 MU::MATRIX_CIRC_MDS_8_COL,
290 <SmallConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64>>::conv8,
291 )
292 }
293}
294impl<FP: MontyParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 8>
295 for MdsMatrixMontyField31<MU>
296{
297}
298
299impl<FP: MontyParameters, MU: MDSUtils> Permutation<[MontyField31<FP>; 12]>
300 for MdsMatrixMontyField31<MU>
301{
302 fn permute(&self, input: [MontyField31<FP>; 12]) -> [MontyField31<FP>; 12] {
303 SmallConvolveMontyField31::apply(
304 input,
305 MU::MATRIX_CIRC_MDS_12_COL,
306 <SmallConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64>>::conv12,
307 )
308 }
309}
310impl<FP: MontyParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 12>
311 for MdsMatrixMontyField31<MU>
312{
313}
314
315impl<FP: MontyParameters, MU: MDSUtils> Permutation<[MontyField31<FP>; 16]>
316 for MdsMatrixMontyField31<MU>
317{
318 fn permute(&self, input: [MontyField31<FP>; 16]) -> [MontyField31<FP>; 16] {
319 SmallConvolveMontyField31::apply(
320 input,
321 MU::MATRIX_CIRC_MDS_16_COL,
322 <SmallConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64>>::conv16,
323 )
324 }
325}
326impl<FP: MontyParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 16>
327 for MdsMatrixMontyField31<MU>
328{
329}
330
331impl<FP, MU: MDSUtils> Permutation<[MontyField31<FP>; 24]> for MdsMatrixMontyField31<MU>
332where
333 FP: BarrettParameters,
334{
335 fn permute(&self, input: [MontyField31<FP>; 24]) -> [MontyField31<FP>; 24] {
336 LargeConvolveMontyField31::apply(
337 input,
338 MU::MATRIX_CIRC_MDS_24_COL,
339 <LargeConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64>>::conv24,
340 )
341 }
342}
343impl<FP: BarrettParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 24>
344 for MdsMatrixMontyField31<MU>
345{
346}
347
348impl<FP: BarrettParameters, MU: MDSUtils> Permutation<[MontyField31<FP>; 32]>
349 for MdsMatrixMontyField31<MU>
350{
351 fn permute(&self, input: [MontyField31<FP>; 32]) -> [MontyField31<FP>; 32] {
352 LargeConvolveMontyField31::apply(
353 input,
354 MU::MATRIX_CIRC_MDS_32_COL,
355 <LargeConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64>>::conv32,
356 )
357 }
358}
359impl<FP: BarrettParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 32>
360 for MdsMatrixMontyField31<MU>
361{
362}
363
364impl<FP: BarrettParameters, MU: MDSUtils> Permutation<[MontyField31<FP>; 64]>
365 for MdsMatrixMontyField31<MU>
366{
367 fn permute(&self, input: [MontyField31<FP>; 64]) -> [MontyField31<FP>; 64] {
368 LargeConvolveMontyField31::apply(
369 input,
370 MU::MATRIX_CIRC_MDS_64_COL,
371 <LargeConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64>>::conv64,
372 )
373 }
374}
375impl<FP: BarrettParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 64>
376 for MdsMatrixMontyField31<MU>
377{
378}