Skip to main content

p3_monty_31/
mds.rs

1use core::marker::PhantomData;
2
3use p3_mds::MdsPermutation;
4use p3_mds::karatsuba_convolution::Convolve;
5use p3_mds::util::dot_product;
6use p3_symmetric::Permutation;
7
8use crate::{BarrettParameters, MontyField31, MontyParameters};
9
/// A collection of circulant MDS matrices saved using their left most column.
///
/// Each constant stores one matrix as an array of plain `i64` entries whose
/// length equals the matrix width. Implementors pick the concrete values.
pub trait MDSUtils: Clone + Sync {
    /// Left-most column of the width-8 circulant MDS matrix.
    const MATRIX_CIRC_MDS_8_COL: [i64; 8];
    /// Left-most column of the width-12 circulant MDS matrix.
    const MATRIX_CIRC_MDS_12_COL: [i64; 12];
    /// Left-most column of the width-16 circulant MDS matrix.
    const MATRIX_CIRC_MDS_16_COL: [i64; 16];
    /// Left-most column of the width-24 circulant MDS matrix.
    const MATRIX_CIRC_MDS_24_COL: [i64; 24];
    /// Left-most column of the width-32 circulant MDS matrix.
    const MATRIX_CIRC_MDS_32_COL: [i64; 32];
    /// Left-most column of the width-64 circulant MDS matrix.
    const MATRIX_CIRC_MDS_64_COL: [i64; 64];
}
19
20#[derive(Clone, Debug, Default)]
21pub struct MdsMatrixMontyField31<MU: MDSUtils> {
22    _phantom: PhantomData<MU>,
23}
24
/// Instantiate convolution for "small" RHS vectors over a 31-bit MONTY_FIELD.
///
/// Here "small" means N = len(rhs) <= 16 and sum(r for r in rhs) <
/// 2^24 (roughly), though in practice the sum will be less than 2^9.
///
/// This is a unit marker type: it only exists to carry the `Convolve`
/// implementation. It derives the same traits as `LargeConvolveMontyField31`
/// so the two convolution markers can be used interchangeably.
#[derive(Debug, Clone, Default)]
struct SmallConvolveMontyField31;
30
31impl<FP: MontyParameters> Convolve<MontyField31<FP>, i64, i64> for SmallConvolveMontyField31 {
32    const T_ZERO: i64 = 0;
33    const U_ZERO: i64 = 0;
34
35    #[inline(always)]
36    fn halve(val: i64) -> i64 {
37        val >> 1
38    }
39
40    /// Return the lift of a Monty31 element, satisfying 0 <=
41    /// input.value < P < 2^31. Note that Monty31 elements are
42    /// represented in Monty form.
43    #[inline(always)]
44    fn read(input: MontyField31<FP>) -> i64 {
45        input.value as i64
46    }
47
48    /// For a convolution of size N, |x| < N * 2^31 and (as per the
49    /// assumption above), |y| < 2^24. So the product is at most N * 2^55
50    /// which will not overflow for N <= 16.
51    ///
52    /// Note that the LHS element is in Monty form, while the RHS
53    /// element is a "plain integer". This informs the implementation
54    /// of `reduce()` below.
55    #[inline(always)]
56    fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
57        dot_product(u, v)
58    }
59
60    /// The assumptions above mean z < N^2 * 2^55, which is at most
61    /// 2^63 when N <= 16.
62    ///
63    /// Because the LHS elements were in Monty form and the RHS
64    /// elements were plain integers, reduction is simply the usual
65    /// reduction modulo P, rather than "Monty reduction".
66    ///
67    /// NB: Even though intermediate values could be negative, the
68    /// output must be non-negative since the inputs were
69    /// non-negative.
70    #[inline(always)]
71    fn reduce(z: i64) -> MontyField31<FP> {
72        debug_assert!(z >= 0);
73
74        MontyField31::new_monty((z as u64 % FP::PRIME as u64) as u32)
75    }
76}
77
78/// Given |x| < 2^80 compute x' such that:
79/// |x'| < 2**50
80/// x' = x mod p
81/// x' = x mod 2^10
82/// See Thm 1 (Below function) for a proof that this function is correct.
83#[inline(always)]
84const fn barrett_red_monty31<BP: BarrettParameters>(input: i128) -> i64 {
85    // input = input_low + beta*input_high
86    // So input_high < 2**63 and fits in an i64.
87    let input_high = (input >> BP::N) as i64; // input_high < input / beta < 2**{80 - N}
88
89    // I, input_high are i64's so this multiplication can't overflow.
90    let quot = (((input_high as i128) * (BP::PSEUDO_INV as i128)) >> BP::N) as i64;
91
92    // Replace quot by a close value which is divisible by 2^10.
93    let quot_2adic = quot & BP::MASK;
94
95    // quot_2adic, P are i64's so this can't overflow.
96    // sub is by construction divisible by both P and 2^10.
97    let sub = (quot_2adic as i128) * BP::PRIME_I128;
98
99    (input - sub) as i64
100}
101
102// Theorem 1:
// Given |x| < 2^80, barrett_red_monty31(x) computes an x' such that:
104//       x' = x mod p
105//       x' = x mod 2^10
106//       |x'| < 2**50.
107///////////////////////////////////////////////////////////////////////////////////////
108// PROOF:
109// By construction P, 2**10 | sub and so we immediately see that
110// x' = x mod p
111// x' = x mod 2^10.
112//
113// It remains to prove that |x'| < 2**50.
114//
115// We start by introducing some simple inequalities and relations between our variables:
116//
117// First consider the relationship between bit-shift and division.
118// It's easy to check that for all x:
119// 1: (x >> N) <= x / 2**N <= 1 + (x >> N)
120//
121// Similarly, as our mask just 0's the last 10 bits,
122// 2: x + 1 - 2^10 <= x & mask <= x
123//
124// Now if x, y are positive integers then
125// (x / y) - 1 <= x // y <= x / y
126// Where // denotes integer division.
127//
128// From this last inequality we immediately derive:
129// 3: (2**{2N} / P) - 1 <= I <= (2**{2N} / P)
130// 3a: 2**{2N} - P <= PI
131//
132// Finally, note that by definition:
133// input = input_high*(2**N) + input_low
134// Hence a simple rearrangement gets us
135// 4: input_high*(2**N) = input - input_low
136//
137//
138// We now need to split into cases depending on the sign of input.
139// Note that if x = 0 then x' = 0 so that case is trivial.
140///////////////////////////////////////////////////////////////////////////
141// CASE 1: input > 0
142//
143// If input > 0 then:
144// sub = Q*P = ((((input >> N) * I) >> N) & mask) * P <= P * (input / 2**{N}) * (2**{2N} / P) / 2**{N} = input
145// So input - sub >= 0.
146//
147// We need to improve our bound on Q. Observe that:
148// Q = (((input_high * I) >> N) & mask)
// --(2)   => Q + (2^10 - 1) >= ((input_high * I) >> N)
// --(1)   => Q + 2^10 >= (I*input_high)/(2**N)
//         => (2**N)*Q + 2^10*(2**N) >= I*input_high
152//
153// Hence we find that:
154// (2**N)*Q*P + 2^10*(2**N)*P >= input_high*I*P
155// --(3a)                     >= input_high*2**{2N} - P*input_high
156// --(4)                      >= (2**N)*input - (2**N)*input_low - (2**N)*input_high   (Assuming P < 2**N)
157//
158// Dividing by 2**N we get
159// Q*P + 2^{10}*P >= input - input_low - input_high
160// which rearranges to
161// x' = input - Q*P <= 2^{10}*P + input_low + input_high
162//
// Picking N = 40 we see that 2^{10}*P < 2**41 (as P < 2**31) while input_low, input_high < 2**40.
// Hence x' < 2**41 + 2**40 + 2**40 = 2**42 < 2**50 as desired.
165//
166//
167//
168///////////////////////////////////////////////////////////////////////////
169// CASE 2: input < 0
170//
171// This case will be similar but all our inequalities will change slightly as negatives complicate things.
172// First observe that:
// (input >> N) * I   >= (input >> N) * 2**(2N) / P            (as (input >> N) < 0 and, by (3), I <= 2**(2N) / P)
//                    >= ((input / 2**N) - 1) * 2**(2N) / P    (by (1))
//                    =  (input - 2**N) * 2**N / P
//
// Thus, using (1) for the shift and (2) for the mask:
// Q = ((((input >> N) * I) >> N) & mask) >= ((input - 2**N) / P) - 1 + (1 - 2^10)
//                                        =  ((input - 2**N) / P) - 2^10
//
// And so sub = Q*P >= input - 2**N - 2^10*P.
// Hence input - sub <= 2**N + 2^10*P < 2**40 + 2**41 (for N = 40).
183//
184// Thus if input - sub > 0 then |input - sub| < 2**50.
185// Thus we are left with bounding -(input - sub) = (sub - input).
186// Again we will proceed by improving our bound on Q.
187//
188// Q = (((input_high * I) >> N) & mask)
// --(2)   => Q <= ((input_high * I) >> N)
// --(1)   => Q <= (I*input_high)/(2**N)
//         => (2**N)*Q <= I*input_high
192//
193// Hence we find that:
194// (2**N)*Q*P <= input_high*I*P
195// --(3a)     <= input_high*2**{2N} - P*input_high
196// --(4)      <= (2**N)*input - (2**N)*input_low - (2**N)*input_high   (Assuming P < 2**N)
197//
198// Dividing by 2**N we get
199// Q*P <= input - input_low - input_high
200// which rearranges to
201// -x' = -input + Q*P <= -input_high - input_low < 2**50
202//
203// This completes the proof.
204
/// Instantiate convolution for "large" RHS vectors over a 31-bit Monty field.
///
/// Here "large" means the elements can be as big as the field
/// characteristic, and the size N of the RHS is <= 64.
///
/// The `Convolve` impl for this type is generic over any
/// `FP: BarrettParameters`, so it is not tied to one specific prime.
#[derive(Debug, Clone, Default)]
struct LargeConvolveMontyField31;
211
212impl<FP> Convolve<MontyField31<FP>, i64, i64> for LargeConvolveMontyField31
213where
214    FP: BarrettParameters,
215{
216    const T_ZERO: i64 = 0;
217    const U_ZERO: i64 = 0;
218
219    #[inline(always)]
220    fn halve(val: i64) -> i64 {
221        val >> 1
222    }
223
224    /// Return the lift of a MontyField31 element, satisfying
225    /// 0 <= input.value < P < 2^31.
226    /// Note that MontyField31 elements are represented in Monty form.
227    #[inline(always)]
228    fn read(input: MontyField31<FP>) -> i64 {
229        input.value as i64
230    }
231
232    #[inline(always)]
233    fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
234        // For a convolution of size N, |x|, |y| < N * 2^31, so the
235        // product could be as much as N^2 * 2^62. This will overflow an
236        // i64, so we first widen to i128. Note that N^2 * 2^62 < 2^80
237        // for N <= 64, as required by `barrett_red_monty31()`.
238
239        let mut dp = 0i128;
240        for i in 0..N {
241            dp += u[i] as i128 * v[i] as i128;
242        }
243        barrett_red_monty31::<FP>(dp)
244    }
245
246    #[inline(always)]
247    fn reduce(z: i64) -> MontyField31<FP> {
248        // After the barrett reduction method, the output z of parity
249        // dot satisfies |z| < 2^50 (See Thm 1 above).
250        //
251        // In the recombining steps, conv_n maps (wo, w1) ->
252        // ((wo + w1)/2, (wo + w1)/2) which has no effect on the maximal
253        // size. (Indeed, it makes sizes almost strictly smaller).
254        //
255        // On the other hand, negacyclic_conv_n (ignoring the re-index)
256        // recombines as: (w0, w1, w2) -> (w0 + w1, w2 - w0 - w1).
257        // Hence if the input is <= K, the output is <= 3K.
258        //
259        // Thus the values appearing at the end are bounded by 3^n 2^50
260        // where n is the maximal number of negacyclic_conv
261        // recombination steps. When N = 64, we need to recombine for
262        // signed_conv_32, signed_conv_16, signed_conv_8 so the
263        // overall bound will be 3^3 2^50 < 32 * 2^50 < 2^55.
264        debug_assert!(z > -(1i64 << 55));
265        debug_assert!(z < (1i64 << 55));
266
267        // Note we do NOT move it into MONTY form. We assume it is already
268        // in this form.
269        let red = (z % (FP::PRIME as i64)) as u32;
270
271        // If z >= 0: 0 <= red < P is the correct value and P + red will
272        // not overflow.
273        // If z < 0: -P < red < 0 and the value we want is P + red.
274        // On bits, + acts identically for i32 and u32. Hence we can use
275        // u32's and just check for overflow.
276
277        let (corr, over) = red.overflowing_add(FP::PRIME);
278        let value = if over { corr } else { red };
279        MontyField31::new_monty(value)
280    }
281}
282
283impl<FP: MontyParameters, MU: MDSUtils> Permutation<[MontyField31<FP>; 8]>
284    for MdsMatrixMontyField31<MU>
285{
286    fn permute(&self, input: [MontyField31<FP>; 8]) -> [MontyField31<FP>; 8] {
287        SmallConvolveMontyField31::apply(
288            input,
289            MU::MATRIX_CIRC_MDS_8_COL,
290            <SmallConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64>>::conv8,
291        )
292    }
293}
294impl<FP: MontyParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 8>
295    for MdsMatrixMontyField31<MU>
296{
297}
298
299impl<FP: MontyParameters, MU: MDSUtils> Permutation<[MontyField31<FP>; 12]>
300    for MdsMatrixMontyField31<MU>
301{
302    fn permute(&self, input: [MontyField31<FP>; 12]) -> [MontyField31<FP>; 12] {
303        SmallConvolveMontyField31::apply(
304            input,
305            MU::MATRIX_CIRC_MDS_12_COL,
306            <SmallConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64>>::conv12,
307        )
308    }
309}
310impl<FP: MontyParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 12>
311    for MdsMatrixMontyField31<MU>
312{
313}
314
315impl<FP: MontyParameters, MU: MDSUtils> Permutation<[MontyField31<FP>; 16]>
316    for MdsMatrixMontyField31<MU>
317{
318    fn permute(&self, input: [MontyField31<FP>; 16]) -> [MontyField31<FP>; 16] {
319        SmallConvolveMontyField31::apply(
320            input,
321            MU::MATRIX_CIRC_MDS_16_COL,
322            <SmallConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64>>::conv16,
323        )
324    }
325}
326impl<FP: MontyParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 16>
327    for MdsMatrixMontyField31<MU>
328{
329}
330
331impl<FP, MU: MDSUtils> Permutation<[MontyField31<FP>; 24]> for MdsMatrixMontyField31<MU>
332where
333    FP: BarrettParameters,
334{
335    fn permute(&self, input: [MontyField31<FP>; 24]) -> [MontyField31<FP>; 24] {
336        LargeConvolveMontyField31::apply(
337            input,
338            MU::MATRIX_CIRC_MDS_24_COL,
339            <LargeConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64>>::conv24,
340        )
341    }
342}
343impl<FP: BarrettParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 24>
344    for MdsMatrixMontyField31<MU>
345{
346}
347
348impl<FP: BarrettParameters, MU: MDSUtils> Permutation<[MontyField31<FP>; 32]>
349    for MdsMatrixMontyField31<MU>
350{
351    fn permute(&self, input: [MontyField31<FP>; 32]) -> [MontyField31<FP>; 32] {
352        LargeConvolveMontyField31::apply(
353            input,
354            MU::MATRIX_CIRC_MDS_32_COL,
355            <LargeConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64>>::conv32,
356        )
357    }
358}
359impl<FP: BarrettParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 32>
360    for MdsMatrixMontyField31<MU>
361{
362}
363
364impl<FP: BarrettParameters, MU: MDSUtils> Permutation<[MontyField31<FP>; 64]>
365    for MdsMatrixMontyField31<MU>
366{
367    fn permute(&self, input: [MontyField31<FP>; 64]) -> [MontyField31<FP>; 64] {
368        LargeConvolveMontyField31::apply(
369            input,
370            MU::MATRIX_CIRC_MDS_64_COL,
371            <LargeConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64>>::conv64,
372        )
373    }
374}
375impl<FP: BarrettParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 64>
376    for MdsMatrixMontyField31<MU>
377{
378}