p3_mds/
karatsuba_convolution.rs

1//! Calculate the convolution of two vectors using a Karatsuba-style
2//! decomposition and the CRT.
3//!
4//! This is not a new idea, but we did have the pleasure of
5//! reinventing it independently. Some references:
6//! - `<https://cr.yp.to/lineartime/multapps-20080515.pdf>`
7//! - `<https://2π.com/23/convolution/>`
8//!
9//! Given a vector v \in F^N, let v(x) \in F[x] denote the polynomial
10//! v_0 + v_1 x + ... + v_{N - 1} x^{N - 1}.  Then w is equal to the
11//! convolution v * u if and only if w(x) = v(x)u(x) mod x^N - 1.
12//! Additionally, define the negacyclic convolution by w(x) = v(x)u(x)
13//! mod x^N + 1.  Using the Chinese remainder theorem we can compute
14//! w(x) as
15//!     w(x) = 1/2 (w_0(x) + w_1(x)) + x^{N/2}/2 (w_0(x) - w_1(x))
16//! where
17//!     w_0 = v(x)u(x) mod x^{N/2} - 1
18//!     w_1 = v(x)u(x) mod x^{N/2} + 1
19//!
20//! To compute w_0 and w_1 we first compute
21//!                  v_0(x) = v(x) mod x^{N/2} - 1
22//!                  v_1(x) = v(x) mod x^{N/2} + 1
23//!                  u_0(x) = u(x) mod x^{N/2} - 1
24//!                  u_1(x) = u(x) mod x^{N/2} + 1
25//!
26//! Now w_0 is the convolution of v_0 and u_0 which we can compute
27//! recursively.  For w_1 we compute the negacyclic convolution
28//! v_1(x)u_1(x) mod x^{N/2} + 1 using Karatsuba.
29//!
//! There are two possible ways to apply Karatsuba, mirroring the DIT
//! vs DIF approaches to FFTs: the left/right decomposition or the
//! even/odd decomposition. The latter seems to need fewer operations,
//! so it is the one implemented below, though it does require a bit
//! more data manipulation. It works as follows:
35//!
36//! Define the even v_e and odd v_o parts so that v(x) = (v_e(x^2) + x v_o(x^2)).
37//! Then v(x)u(x)
38//!    = (v_e(x^2)u_e(x^2) + x^2 v_o(x^2)u_o(x^2))
39//!      + x ((v_e(x^2) + v_o(x^2))(u_e(x^2) + u_o(x^2))
40//!            - (v_e(x^2)u_e(x^2) + v_o(x^2)u_o(x^2)))
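//! Writing y = x^2, we have x^N + 1 = y^{N/2} + 1, so each of the three
//! products v_e(y)u_e(y), v_o(y)u_o(y) and (v_e(y) + v_o(y))(u_e(y) + u_o(y))
//! only needs to be known modulo y^{N/2} + 1, and the factor x^2 = y in
//! front of v_o u_o is a negacyclic shift by one position. This reduces
//! the problem to 3 negacyclic convolutions of size N/2 which can be
//! computed recursively.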
43//!
44//! Of course, for small sizes we just explicitly write out the O(n^2)
45//! approach.
46
47use core::ops::{Add, AddAssign, Neg, ShrAssign, Sub, SubAssign};
48
/// Arithmetic required of the internal element types used by `Convolve`.
///
/// Elements must be cheap to copy and have a default value (used to
/// initialise scratch arrays). They must support addition, subtraction and
/// negation, both by value and in place. They must also support an in-place
/// right shift by a `u32`, which the cyclic convolutions use to halve
/// `w_0 + w_1` when recombining the two CRT halves.
///
/// TODO: Think of a better name for this.
pub trait RngElt:
    Copy
    + Default
    + Add<Output = Self>
    + Sub<Output = Self>
    + Neg<Output = Self>
    + AddAssign
    + SubAssign
    + ShrAssign<u32>
{
}

impl RngElt for i64 {}
impl RngElt for i128 {}
66
67/// Template function to perform convolution of vectors.
68///
69/// Roughly speaking, for a convolution of size `N`, it should be
70/// possible to add `N` elements of type `T` without overflowing, and
71/// similarly for `U`. Then multiplication via `Self::mul` should
72/// produce an element of type `V` which will not overflow after about
73/// `N` additions (this is an over-estimate).
74///
75/// For example usage, see `{mersenne-31,baby-bear,goldilocks}/src/mds.rs`.
76///
77/// NB: In practice, one of the parameters to the convolution will be
78/// constant (the MDS matrix). After inspecting Godbolt output, it
79/// seems that the compiler does indeed generate single constants as
80/// inputs to the multiplication, rather than doing all that
81/// arithmetic on the constant values every time. Note however that,
82/// for MDS matrices with large entries (N >= 24), these compile-time
83/// generated constants will be about N times bigger than they need to
84/// be in principle, which could be a potential avenue for some minor
85/// improvements.
86///
87/// NB: If primitive multiplications are still the bottleneck, a
88/// further possibility would be to find an MDS matrix some of whose
89/// entries are powers of 2. Then the multiplication can be replaced
90/// with a shift, which on most architectures has better throughput
91/// and latency, and is issued on different ports (1*p06) to
92/// multiplication (1*p1).
93pub trait Convolve<F, T: RngElt, U: RngElt, V: RngElt> {
94    /// Given an input element, retrieve the corresponding internal
95    /// element that will be used in calculations.
96    fn read(input: F) -> T;
97
98    /// Given input vectors `lhs` and `rhs`, calculate their dot
99    /// product. The result can be reduced with respect to the modulus
100    /// (of `F`), but it must have the same lower 10 bits as the dot
101    /// product if all inputs are considered integers. See
102    /// `monty-31/src/mds.rs::barrett_red_monty31()` for an example
103    /// of how this can be implemented in practice.
104    fn parity_dot<const N: usize>(lhs: [T; N], rhs: [U; N]) -> V;
105
106    /// Convert an internal element of type `V` back into an external
107    /// element.
108    fn reduce(z: V) -> F;
109
110    /// Convolve `lhs` and `rhs`.
111    ///
112    /// The parameter `conv` should be the function in this trait that
113    /// corresponds to length `N`.
114    #[inline(always)]
115    fn apply<const N: usize, C: Fn([T; N], [U; N], &mut [V])>(
116        lhs: [F; N],
117        rhs: [U; N],
118        conv: C,
119    ) -> [F; N] {
120        let lhs = lhs.map(Self::read);
121        let mut output = [V::default(); N];
122        conv(lhs, rhs, &mut output);
123        output.map(Self::reduce)
124    }
125
126    #[inline(always)]
127    fn conv3(lhs: [T; 3], rhs: [U; 3], output: &mut [V]) {
128        output[0] = Self::parity_dot(lhs, [rhs[0], rhs[2], rhs[1]]);
129        output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], rhs[2]]);
130        output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0]]);
131    }
132
133    #[inline(always)]
134    fn negacyclic_conv3(lhs: [T; 3], rhs: [U; 3], output: &mut [V]) {
135        output[0] = Self::parity_dot(lhs, [rhs[0], -rhs[2], -rhs[1]]);
136        output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], -rhs[2]]);
137        output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0]]);
138    }
139
140    #[inline(always)]
141    fn conv4(lhs: [T; 4], rhs: [U; 4], output: &mut [V]) {
142        // NB: This is just explicitly implementing
143        // conv_n_recursive::<4, 2, _, _>(lhs, rhs, output, Self::conv2, Self::negacyclic_conv2)
144        let u_p = [lhs[0] + lhs[2], lhs[1] + lhs[3]];
145        let u_m = [lhs[0] - lhs[2], lhs[1] - lhs[3]];
146        let v_p = [rhs[0] + rhs[2], rhs[1] + rhs[3]];
147        let v_m = [rhs[0] - rhs[2], rhs[1] - rhs[3]];
148
149        output[0] = Self::parity_dot(u_m, [v_m[0], -v_m[1]]);
150        output[1] = Self::parity_dot(u_m, [v_m[1], v_m[0]]);
151        output[2] = Self::parity_dot(u_p, v_p);
152        output[3] = Self::parity_dot(u_p, [v_p[1], v_p[0]]);
153
154        output[0] += output[2];
155        output[1] += output[3];
156
157        output[0] >>= 1;
158        output[1] >>= 1;
159
160        output[2] -= output[0];
161        output[3] -= output[1];
162    }
163
164    #[inline(always)]
165    fn negacyclic_conv4(lhs: [T; 4], rhs: [U; 4], output: &mut [V]) {
166        output[0] = Self::parity_dot(lhs, [rhs[0], -rhs[3], -rhs[2], -rhs[1]]);
167        output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], -rhs[3], -rhs[2]]);
168        output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0], -rhs[3]]);
169        output[3] = Self::parity_dot(lhs, [rhs[3], rhs[2], rhs[1], rhs[0]]);
170    }
171
172    #[inline(always)]
173    fn conv6(lhs: [T; 6], rhs: [U; 6], output: &mut [V]) {
174        conv_n_recursive(lhs, rhs, output, Self::conv3, Self::negacyclic_conv3);
175    }
176
177    #[inline(always)]
178    fn negacyclic_conv6(lhs: [T; 6], rhs: [U; 6], output: &mut [V]) {
179        negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv3);
180    }
181
182    #[inline(always)]
183    fn conv8(lhs: [T; 8], rhs: [U; 8], output: &mut [V]) {
184        conv_n_recursive(lhs, rhs, output, Self::conv4, Self::negacyclic_conv4);
185    }
186
187    #[inline(always)]
188    fn negacyclic_conv8(lhs: [T; 8], rhs: [U; 8], output: &mut [V]) {
189        negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv4);
190    }
191
192    #[inline(always)]
193    fn conv12(lhs: [T; 12], rhs: [U; 12], output: &mut [V]) {
194        conv_n_recursive(lhs, rhs, output, Self::conv6, Self::negacyclic_conv6);
195    }
196
197    #[inline(always)]
198    fn negacyclic_conv12(lhs: [T; 12], rhs: [U; 12], output: &mut [V]) {
199        negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv6);
200    }
201
202    #[inline(always)]
203    fn conv16(lhs: [T; 16], rhs: [U; 16], output: &mut [V]) {
204        conv_n_recursive(lhs, rhs, output, Self::conv8, Self::negacyclic_conv8);
205    }
206
207    #[inline(always)]
208    fn negacyclic_conv16(lhs: [T; 16], rhs: [U; 16], output: &mut [V]) {
209        negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv8);
210    }
211
212    #[inline(always)]
213    fn conv24(lhs: [T; 24], rhs: [U; 24], output: &mut [V]) {
214        conv_n_recursive(lhs, rhs, output, Self::conv12, Self::negacyclic_conv12);
215    }
216
217    #[inline(always)]
218    fn conv32(lhs: [T; 32], rhs: [U; 32], output: &mut [V]) {
219        conv_n_recursive(lhs, rhs, output, Self::conv16, Self::negacyclic_conv16);
220    }
221
222    #[inline(always)]
223    fn negacyclic_conv32(lhs: [T; 32], rhs: [U; 32], output: &mut [V]) {
224        negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv16);
225    }
226
227    #[inline(always)]
228    fn conv64(lhs: [T; 64], rhs: [U; 64], output: &mut [V]) {
229        conv_n_recursive(lhs, rhs, output, Self::conv32, Self::negacyclic_conv32);
230    }
231}
232
233/// Compute output(x) = lhs(x)rhs(x) mod x^N - 1.
234/// Do this recursively using a convolution and negacyclic convolution of size HALF_N = N/2.
235#[inline(always)]
236fn conv_n_recursive<const N: usize, const HALF_N: usize, T, U, V, C, NC>(
237    lhs: [T; N],
238    rhs: [U; N],
239    output: &mut [V],
240    inner_conv: C,
241    inner_negacyclic_conv: NC,
242) where
243    T: RngElt,
244    U: RngElt,
245    V: RngElt,
246    C: Fn([T; HALF_N], [U; HALF_N], &mut [V]),
247    NC: Fn([T; HALF_N], [U; HALF_N], &mut [V]),
248{
249    debug_assert_eq!(2 * HALF_N, N);
250    // NB: The compiler is smart enough not to initialise these arrays.
251    let mut lhs_pos = [T::default(); HALF_N]; // lhs_pos = lhs(x) mod x^{N/2} - 1
252    let mut lhs_neg = [T::default(); HALF_N]; // lhs_neg = lhs(x) mod x^{N/2} + 1
253    let mut rhs_pos = [U::default(); HALF_N]; // rhs_pos = rhs(x) mod x^{N/2} - 1
254    let mut rhs_neg = [U::default(); HALF_N]; // rhs_neg = rhs(x) mod x^{N/2} + 1
255
256    for i in 0..HALF_N {
257        let s = lhs[i];
258        let t = lhs[i + HALF_N];
259
260        lhs_pos[i] = s + t;
261        lhs_neg[i] = s - t;
262
263        let s = rhs[i];
264        let t = rhs[i + HALF_N];
265
266        rhs_pos[i] = s + t;
267        rhs_neg[i] = s - t;
268    }
269
270    let (left, right) = output.split_at_mut(HALF_N);
271
272    // left = w1 = lhs(x)rhs(x) mod x^{N/2} + 1
273    inner_negacyclic_conv(lhs_neg, rhs_neg, left);
274
275    // right = w0 = lhs(x)rhs(x) mod x^{N/2} - 1
276    inner_conv(lhs_pos, rhs_pos, right);
277
278    for i in 0..HALF_N {
279        left[i] += right[i]; // w_0 + w_1
280        left[i] >>= 1; // (w_0 + w_1)/2
281        right[i] -= left[i]; // (w_0 - w_1)/2
282    }
283}
284
285/// Compute output(x) = lhs(x)rhs(x) mod x^N + 1.
286/// Do this recursively using three negacyclic convolutions of size HALF_N = N/2.
287#[inline(always)]
288fn negacyclic_conv_n_recursive<const N: usize, const HALF_N: usize, T, U, V, NC>(
289    lhs: [T; N],
290    rhs: [U; N],
291    output: &mut [V],
292    inner_negacyclic_conv: NC,
293) where
294    T: RngElt,
295    U: RngElt,
296    V: RngElt,
297    NC: Fn([T; HALF_N], [U; HALF_N], &mut [V]),
298{
299    debug_assert_eq!(2 * HALF_N, N);
300    // NB: The compiler is smart enough not to initialise these arrays.
301    let mut lhs_even = [T::default(); HALF_N];
302    let mut lhs_odd = [T::default(); HALF_N];
303    let mut lhs_sum = [T::default(); HALF_N];
304    let mut rhs_even = [U::default(); HALF_N];
305    let mut rhs_odd = [U::default(); HALF_N];
306    let mut rhs_sum = [U::default(); HALF_N];
307
308    for i in 0..HALF_N {
309        let s = lhs[2 * i];
310        let t = lhs[2 * i + 1];
311        lhs_even[i] = s;
312        lhs_odd[i] = t;
313        lhs_sum[i] = s + t;
314
315        let s = rhs[2 * i];
316        let t = rhs[2 * i + 1];
317        rhs_even[i] = s;
318        rhs_odd[i] = t;
319        rhs_sum[i] = s + t;
320    }
321
322    let mut even_s_conv = [V::default(); HALF_N];
323    let (left, right) = output.split_at_mut(HALF_N);
324
325    // Recursively compute the size N/2 negacyclic convolutions of
326    // the even parts, odd parts, and sums.
327    inner_negacyclic_conv(lhs_even, rhs_even, &mut even_s_conv);
328    inner_negacyclic_conv(lhs_odd, rhs_odd, left);
329    inner_negacyclic_conv(lhs_sum, rhs_sum, right);
330
331    // Adjust so that the correct values are in right and
332    // even_s_conv respectively:
333    right[0] -= even_s_conv[0] + left[0];
334    even_s_conv[0] -= left[HALF_N - 1];
335
336    for i in 1..HALF_N {
337        right[i] -= even_s_conv[i] + left[i];
338        even_s_conv[i] += left[i - 1];
339    }
340
341    // Interleave even_s_conv and right in the output:
342    for i in 0..HALF_N {
343        output[2 * i] = even_s_conv[i];
344        output[2 * i + 1] = output[i + HALF_N];
345    }
346}