// p3_mds/karatsuba_convolution.rs

//! Calculate the convolution of two vectors using a Karatsuba-style
//! decomposition and the CRT.
//!
//! This is not a new idea, but we did have the pleasure of
//! reinventing it independently. Some references:
//! - `<https://cr.yp.to/lineartime/multapps-20080515.pdf>`
//! - `<https://2π.com/23/convolution/>`
//!
//! Given a vector v \in F^N, let v(x) \in F[x] denote the polynomial
//! v_0 + v_1 x + ... + v_{N - 1} x^{N - 1}.  Then w is equal to the
//! (cyclic) convolution v * u if and only if w(x) = v(x)u(x) mod x^N - 1.
//! Additionally, define the negacyclic convolution by w(x) = v(x)u(x)
//! mod x^N + 1.  Using the Chinese remainder theorem we can compute
//! w(x) as
//!     w(x) = 1/2 (w_0(x) + w_1(x)) + x^{N/2}/2 (w_0(x) - w_1(x))
//! where
//!     w_0 = v(x)u(x) mod x^{N/2} - 1
//!     w_1 = v(x)u(x) mod x^{N/2} + 1
//!
//! To compute w_0 and w_1 we first compute
//!                  v_0(x) = v(x) mod x^{N/2} - 1
//!                  v_1(x) = v(x) mod x^{N/2} + 1
//!                  u_0(x) = u(x) mod x^{N/2} - 1
//!                  u_1(x) = u(x) mod x^{N/2} + 1
//!
//! Now w_0 is the convolution of v_0 and u_0 which we can compute
//! recursively.  For w_1 we compute the negacyclic convolution
//! v_1(x)u_1(x) mod x^{N/2} + 1 using Karatsuba.
//!
//! There are 2 possible approaches to applying Karatsuba which mirror
//! the DIT vs DIF approaches to FFTs: the left/right decomposition
//! or the even/odd decomposition. The latter seems to have fewer
//! operations and so it is the one implemented below, though it does
//! require a bit more data manipulation. It works as follows:
//!
//! Define the even v_e and odd v_o parts so that v(x) = v_e(x^2) + x v_o(x^2).
//! Then v(x)u(x)
//!    = (v_e(x^2)u_e(x^2) + x^2 v_o(x^2)u_o(x^2))
//!      + x ((v_e(x^2) + v_o(x^2))(u_e(x^2) + u_o(x^2))
//!            - (v_e(x^2)u_e(x^2) + v_o(x^2)u_o(x^2)))
//! Writing y = x^2, the modulus x^N + 1 becomes y^{N/2} + 1, so each of
//! the three products above is a negacyclic convolution of size N/2 in y.
//! This reduces the problem to 3 negacyclic convolutions of size N/2 which
//! can be computed recursively.
//!
//! Of course, for small sizes (3 and 4) the base cases are written out
//! explicitly instead of recursing, mostly as the O(n^2) schoolbook product.
46
47use core::marker::PhantomData;
48use core::ops::{Add, AddAssign, Neg, Sub, SubAssign};
49
50use p3_field::{Algebra, Field};
51
/// Bound alias for the wide operand type (used for both lhs and output).
///
/// Any `Copy` type whose `+`, binary `-`, unary `-`, `+=` and `-=`
/// operators are all closed over `Self` satisfies this bound automatically
/// through the blanket implementation that follows.
pub trait ConvolutionElt
where
    Self: Copy
        + Neg<Output = Self>
        + Add<Output = Self>
        + Sub<Output = Self>
        + AddAssign
        + SubAssign,
{
}

impl<T> ConvolutionElt for T where
    T: Copy + Neg<Output = T> + Add<Output = T> + Sub<Output = T> + AddAssign + SubAssign
{
}
64
/// Bound alias for the narrow operand type (rhs only).
///
/// Only `Copy` and the by-value operators `+`, binary `-` and unary `-`
/// are needed; unlike the wide type, no in-place assignment operators are
/// required. Every qualifying type gets the blanket implementation below.
pub trait ConvolutionRhs
where
    Self: Copy + Neg<Output = Self> + Add<Output = Self> + Sub<Output = Self>,
{
}

impl<T> ConvolutionRhs for T where T: Copy + Neg<Output = T> + Add<Output = T> + Sub<Output = T> {}
74
75/// Trait for computing cyclic and negacyclic convolutions.
76///
77/// Implementors choose how to lift field elements into a wider type,
78/// compute dot products, and reduce back.
79/// This allows integer-lifted arithmetic (e.g. i64) to avoid modular
80/// reductions inside the inner loops.
81///
82/// # Overflow contract
83///
84/// For a convolution of size N, it must be possible to add N elements
85/// of type T without overflow, and similarly for U.
86/// The product of one T and one U element must not overflow T after
87/// about N further additions.
88///
89/// # Performance notes
90///
91/// In practice one operand is a compile-time constant (the MDS matrix).
92/// The compiler folds the constant arithmetic at compile time.
93/// For large matrices (N >= 24), the compile-time-generated constants
94/// are about N times bigger than strictly necessary.
95pub trait Convolve<F, T: ConvolutionElt, U: ConvolutionRhs> {
96    /// Additive identity for the wide operand type `T`.
97    ///
98    /// Used to initialize output and scratch arrays before the convolution
99    /// fills them with computed values.
100    const T_ZERO: T;
101
102    /// Additive identity for the narrow operand type `U`.
103    ///
104    /// Used to initialize temporary arrays for the RHS decomposition
105    /// in the recursive CRT / Karatsuba steps.
106    const U_ZERO: U;
107
108    /// Divide an element of `T` by 2.
109    ///
110    /// - For integers (`i64`, `i128`): arithmetic right shift by 1.
111    /// - For field elements: multiplication by the multiplicative inverse of 2.
112    fn halve(val: T) -> T;
113
114    /// Given an input element, retrieve the corresponding internal
115    /// element that will be used in calculations.
116    fn read(input: F) -> T;
117
118    /// Given input vectors `lhs` and `rhs`, calculate their dot
119    /// product. The result can be reduced with respect to the modulus
120    /// (of `F`), but it must have the same lower 10 bits as the dot
121    /// product if all inputs are considered integers. See
122    /// `monty-31/src/mds.rs::barrett_red_monty31()` for an example
123    /// of how this can be implemented in practice.
124    fn parity_dot<const N: usize>(lhs: [T; N], rhs: [U; N]) -> T;
125
126    /// Convert an internal element of type `T` back into an external
127    /// element.
128    fn reduce(z: T) -> F;
129
130    /// Convolve `lhs` and `rhs`.
131    ///
132    /// The parameter `conv` should be the function in this trait that
133    /// corresponds to length `N`.
134    #[inline(always)]
135    fn apply<const N: usize, C: Fn([T; N], [U; N], &mut [T])>(
136        lhs: [F; N],
137        rhs: [U; N],
138        conv: C,
139    ) -> [F; N] {
140        let lhs = lhs.map(Self::read);
141        let mut output = [Self::T_ZERO; N];
142        conv(lhs, rhs, &mut output);
143        output.map(Self::reduce)
144    }
145
146    #[inline(always)]
147    fn conv3(lhs: [T; 3], rhs: [U; 3], output: &mut [T]) {
148        output[0] = Self::parity_dot(lhs, [rhs[0], rhs[2], rhs[1]]);
149        output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], rhs[2]]);
150        output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0]]);
151    }
152
153    #[inline(always)]
154    fn negacyclic_conv3(lhs: [T; 3], rhs: [U; 3], output: &mut [T]) {
155        output[0] = Self::parity_dot(lhs, [rhs[0], -rhs[2], -rhs[1]]);
156        output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], -rhs[2]]);
157        output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0]]);
158    }
159
160    #[inline(always)]
161    fn conv4(lhs: [T; 4], rhs: [U; 4], output: &mut [T]) {
162        // NB: This is just explicitly implementing
163        // conv_n_recursive::<4, 2, _, _>(lhs, rhs, output, Self::conv2, Self::negacyclic_conv2)
164        let u_p = [lhs[0] + lhs[2], lhs[1] + lhs[3]];
165        let u_m = [lhs[0] - lhs[2], lhs[1] - lhs[3]];
166        let v_p = [rhs[0] + rhs[2], rhs[1] + rhs[3]];
167        let v_m = [rhs[0] - rhs[2], rhs[1] - rhs[3]];
168
169        output[0] = Self::parity_dot(u_m, [v_m[0], -v_m[1]]);
170        output[1] = Self::parity_dot(u_m, [v_m[1], v_m[0]]);
171        output[2] = Self::parity_dot(u_p, v_p);
172        output[3] = Self::parity_dot(u_p, [v_p[1], v_p[0]]);
173
174        output[0] += output[2];
175        output[1] += output[3];
176
177        output[0] = Self::halve(output[0]);
178        output[1] = Self::halve(output[1]);
179
180        output[2] -= output[0];
181        output[3] -= output[1];
182    }
183
184    #[inline(always)]
185    fn negacyclic_conv4(lhs: [T; 4], rhs: [U; 4], output: &mut [T]) {
186        output[0] = Self::parity_dot(lhs, [rhs[0], -rhs[3], -rhs[2], -rhs[1]]);
187        output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], -rhs[3], -rhs[2]]);
188        output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0], -rhs[3]]);
189        output[3] = Self::parity_dot(lhs, [rhs[3], rhs[2], rhs[1], rhs[0]]);
190    }
191
192    /// Compute output(x) = lhs(x)rhs(x) mod x^N - 1 recursively using
193    /// a convolution and negacyclic convolution of size HALF_N = N/2.
194    #[inline(always)]
195    fn conv_n_recursive<const N: usize, const HALF_N: usize, C, NC>(
196        lhs: [T; N],
197        rhs: [U; N],
198        output: &mut [T],
199        inner_conv: C,
200        inner_negacyclic_conv: NC,
201    ) where
202        C: Fn([T; HALF_N], [U; HALF_N], &mut [T]),
203        NC: Fn([T; HALF_N], [U; HALF_N], &mut [T]),
204    {
205        debug_assert_eq!(2 * HALF_N, N);
206        let mut lhs_pos = [Self::T_ZERO; HALF_N]; // lhs_pos = lhs(x) mod x^{N/2} - 1
207        let mut lhs_neg = [Self::T_ZERO; HALF_N]; // lhs_neg = lhs(x) mod x^{N/2} + 1
208        let mut rhs_pos = [Self::U_ZERO; HALF_N]; // rhs_pos = rhs(x) mod x^{N/2} - 1
209        let mut rhs_neg = [Self::U_ZERO; HALF_N]; // rhs_neg = rhs(x) mod x^{N/2} + 1
210
211        for i in 0..HALF_N {
212            let s = lhs[i];
213            let t = lhs[i + HALF_N];
214
215            lhs_pos[i] = s + t;
216            lhs_neg[i] = s - t;
217
218            let s = rhs[i];
219            let t = rhs[i + HALF_N];
220
221            rhs_pos[i] = s + t;
222            rhs_neg[i] = s - t;
223        }
224
225        let (left, right) = output.split_at_mut(HALF_N);
226
227        // left = w1 = lhs(x)rhs(x) mod x^{N/2} + 1
228        inner_negacyclic_conv(lhs_neg, rhs_neg, left);
229
230        // right = w0 = lhs(x)rhs(x) mod x^{N/2} - 1
231        inner_conv(lhs_pos, rhs_pos, right);
232
233        for i in 0..HALF_N {
234            left[i] += right[i]; // w_0 + w_1
235            left[i] = Self::halve(left[i]); // (w_0 + w_1)/2
236            right[i] -= left[i]; // (w_0 - w_1)/2
237        }
238    }
239
240    /// Compute output(x) = lhs(x)rhs(x) mod x^N + 1 recursively using
241    /// three negacyclic convolutions of size HALF_N = N/2.
242    #[inline(always)]
243    fn negacyclic_conv_n_recursive<const N: usize, const HALF_N: usize, NC>(
244        lhs: [T; N],
245        rhs: [U; N],
246        output: &mut [T],
247        inner_negacyclic_conv: NC,
248    ) where
249        NC: Fn([T; HALF_N], [U; HALF_N], &mut [T]),
250    {
251        debug_assert_eq!(2 * HALF_N, N);
252        let mut lhs_even = [Self::T_ZERO; HALF_N];
253        let mut lhs_odd = [Self::T_ZERO; HALF_N];
254        let mut lhs_sum = [Self::T_ZERO; HALF_N];
255        let mut rhs_even = [Self::U_ZERO; HALF_N];
256        let mut rhs_odd = [Self::U_ZERO; HALF_N];
257        let mut rhs_sum = [Self::U_ZERO; HALF_N];
258
259        for i in 0..HALF_N {
260            let s = lhs[2 * i];
261            let t = lhs[2 * i + 1];
262            lhs_even[i] = s;
263            lhs_odd[i] = t;
264            lhs_sum[i] = s + t;
265
266            let s = rhs[2 * i];
267            let t = rhs[2 * i + 1];
268            rhs_even[i] = s;
269            rhs_odd[i] = t;
270            rhs_sum[i] = s + t;
271        }
272
273        let mut even_s_conv = [Self::T_ZERO; HALF_N];
274        let (left, right) = output.split_at_mut(HALF_N);
275
276        // Recursively compute the size N/2 negacyclic convolutions of
277        // the even parts, odd parts, and sums.
278        inner_negacyclic_conv(lhs_even, rhs_even, &mut even_s_conv);
279        inner_negacyclic_conv(lhs_odd, rhs_odd, left);
280        inner_negacyclic_conv(lhs_sum, rhs_sum, right);
281
282        // Adjust so that the correct values are in right and
283        // even_s_conv respectively:
284        right[0] -= even_s_conv[0] + left[0];
285        even_s_conv[0] -= left[HALF_N - 1];
286
287        for i in 1..HALF_N {
288            right[i] -= even_s_conv[i] + left[i];
289            even_s_conv[i] += left[i - 1];
290        }
291
292        // Interleave even_s_conv and right in the output:
293        for i in 0..HALF_N {
294            output[2 * i] = even_s_conv[i];
295            output[2 * i + 1] = output[i + HALF_N];
296        }
297    }
298
299    #[inline(always)]
300    fn conv6(lhs: [T; 6], rhs: [U; 6], output: &mut [T]) {
301        Self::conv_n_recursive(lhs, rhs, output, Self::conv3, Self::negacyclic_conv3);
302    }
303
304    #[inline(always)]
305    fn negacyclic_conv6(lhs: [T; 6], rhs: [U; 6], output: &mut [T]) {
306        Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv3);
307    }
308
309    #[inline(always)]
310    fn conv8(lhs: [T; 8], rhs: [U; 8], output: &mut [T]) {
311        Self::conv_n_recursive(lhs, rhs, output, Self::conv4, Self::negacyclic_conv4);
312    }
313
314    #[inline(always)]
315    fn negacyclic_conv8(lhs: [T; 8], rhs: [U; 8], output: &mut [T]) {
316        Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv4);
317    }
318
319    #[inline(always)]
320    fn conv12(lhs: [T; 12], rhs: [U; 12], output: &mut [T]) {
321        Self::conv_n_recursive(lhs, rhs, output, Self::conv6, Self::negacyclic_conv6);
322    }
323
324    #[inline(always)]
325    fn negacyclic_conv12(lhs: [T; 12], rhs: [U; 12], output: &mut [T]) {
326        Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv6);
327    }
328
329    #[inline(always)]
330    fn conv16(lhs: [T; 16], rhs: [U; 16], output: &mut [T]) {
331        Self::conv_n_recursive(lhs, rhs, output, Self::conv8, Self::negacyclic_conv8);
332    }
333
334    #[inline(always)]
335    fn negacyclic_conv16(lhs: [T; 16], rhs: [U; 16], output: &mut [T]) {
336        Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv8);
337    }
338
339    #[inline(always)]
340    fn conv24(lhs: [T; 24], rhs: [U; 24], output: &mut [T]) {
341        Self::conv_n_recursive(lhs, rhs, output, Self::conv12, Self::negacyclic_conv12);
342    }
343
344    #[inline(always)]
345    fn conv32(lhs: [T; 32], rhs: [U; 32], output: &mut [T]) {
346        Self::conv_n_recursive(lhs, rhs, output, Self::conv16, Self::negacyclic_conv16);
347    }
348
349    #[inline(always)]
350    fn negacyclic_conv32(lhs: [T; 32], rhs: [U; 32], output: &mut [T]) {
351        Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv16);
352    }
353
354    #[inline(always)]
355    fn conv64(lhs: [T; 64], rhs: [U; 64], output: &mut [T]) {
356        Self::conv_n_recursive(lhs, rhs, output, Self::conv32, Self::negacyclic_conv32);
357    }
358}
359
360/// Convolution implementor that stays entirely within the field.
361///
362/// No integer lifting — all operations are native field arithmetic.
363/// Used by the public Karatsuba entry points for generic field/algebra pairs.
364struct FieldConvolve<F, A>(PhantomData<(F, A)>);
365
366impl<F: Field, A: Algebra<F> + Copy> Convolve<A, A, F> for FieldConvolve<F, A> {
367    const T_ZERO: A = A::ZERO;
368    const U_ZERO: F = F::ZERO;
369
370    #[inline(always)]
371    fn halve(val: A) -> A {
372        val.halve()
373    }
374
375    #[inline(always)]
376    fn read(input: A) -> A {
377        input
378    }
379
380    #[inline(always)]
381    fn parity_dot<const N: usize>(lhs: [A; N], rhs: [F; N]) -> A {
382        A::mixed_dot_product(&lhs, &rhs)
383    }
384
385    #[inline(always)]
386    fn reduce(z: A) -> A {
387        z
388    }
389}
390
391/// Circulant matrix-vector multiply for width 16 via Karatsuba convolution.
392#[inline]
393pub fn mds_circulant_karatsuba_16<F: Field, A: Algebra<F> + Copy>(
394    state: &mut [A; 16],
395    col: &[F; 16],
396) {
397    let input = *state;
398    FieldConvolve::<F, A>::conv16(input, *col, state.as_mut_slice());
399}
400
401/// Circulant matrix-vector multiply for width 24 via Karatsuba convolution.
402#[inline]
403pub fn mds_circulant_karatsuba_24<F: Field, A: Algebra<F> + Copy>(
404    state: &mut [A; 24],
405    col: &[F; 24],
406) {
407    let input = *state;
408    FieldConvolve::<F, A>::conv24(input, *col, state.as_mut_slice());
409}
410
#[cfg(test)]
mod tests {
    // Every fast convolution is checked against an O(N^2) reference over
    // BabyBear, plus a few algebraic sanity properties and the public
    // circulant-MDS entry points.
    use p3_baby_bear::BabyBear;
    use p3_field::PrimeCharacteristicRing;
    use proptest::prelude::*;

    use super::*;

    type F = BabyBear;

    fn arb_f() -> impl Strategy<Value = F> {
        prop::num::u32::ANY.prop_map(F::from_u32)
    }

    fn naive_cyclic_conv<const N: usize>(lhs: [F; N], rhs: [F; N]) -> [F; N] {
        // O(N^2) reference: w[i] = sum_j lhs[j] * rhs[(i - j) mod N].
        core::array::from_fn(|i| {
            let mut acc = F::ZERO;
            for j in 0..N {
                acc += lhs[j] * rhs[(N + i - j) % N];
            }
            acc
        })
    }

    fn naive_negacyclic_conv<const N: usize>(lhs: [F; N], rhs: [F; N]) -> [F; N] {
        // O(N^2) reference: w(x) = lhs(x) * rhs(x) mod (x^N + 1).
        // Coefficients that wrap past degree N-1 are subtracted (negacyclic).
        let mut out = [F::ZERO; N];
        for (i, &l) in lhs.iter().enumerate() {
            for (j, &r) in rhs.iter().enumerate() {
                let k = i + j;
                if k < N {
                    out[k] += l * r;
                } else {
                    out[k - N] -= l * r;
                }
            }
        }
        out
    }

    fn check_conv<const N: usize>(
        lhs: [F; N],
        rhs: [F; N],
        conv_fn: fn([F; N], [F; N], &mut [F]),
        naive_fn: fn([F; N], [F; N]) -> [F; N],
    ) {
        let expected = naive_fn(lhs, rhs);
        let mut output = [F::ZERO; N];
        conv_fn(lhs, rhs, &mut output);
        assert_eq!(output, expected, "convolution mismatch");
    }

    macro_rules! conv_test {
        ($name:ident, $n:expr, $conv:expr, $naive:expr, $arr:ident) => {
            proptest! {
                #[test]
                fn $name(
                    lhs in prop::array::$arr(arb_f()),
                    rhs in prop::array::$arr(arb_f()),
                ) {
                    check_conv::<$n>(lhs, rhs, $conv, $naive);
                }
            }
        };
    }

    // Width 3
    conv_test!(
        conv3_matches_naive,
        3,
        FieldConvolve::<F, F>::conv3,
        naive_cyclic_conv,
        uniform3
    );
    conv_test!(
        negacyclic_conv3_matches_naive,
        3,
        FieldConvolve::<F, F>::negacyclic_conv3,
        naive_negacyclic_conv,
        uniform3
    );

    // Width 4
    conv_test!(
        conv4_matches_naive,
        4,
        FieldConvolve::<F, F>::conv4,
        naive_cyclic_conv,
        uniform4
    );
    conv_test!(
        negacyclic_conv4_matches_naive,
        4,
        FieldConvolve::<F, F>::negacyclic_conv4,
        naive_negacyclic_conv,
        uniform4
    );

    // Width 6
    conv_test!(
        conv6_matches_naive,
        6,
        FieldConvolve::<F, F>::conv6,
        naive_cyclic_conv,
        uniform6
    );
    conv_test!(
        negacyclic_conv6_matches_naive,
        6,
        FieldConvolve::<F, F>::negacyclic_conv6,
        naive_negacyclic_conv,
        uniform6
    );

    // Width 8
    conv_test!(
        conv8_matches_naive,
        8,
        FieldConvolve::<F, F>::conv8,
        naive_cyclic_conv,
        uniform8
    );
    conv_test!(
        negacyclic_conv8_matches_naive,
        8,
        FieldConvolve::<F, F>::negacyclic_conv8,
        naive_negacyclic_conv,
        uniform8
    );

    // Width 12
    conv_test!(
        conv12_matches_naive,
        12,
        FieldConvolve::<F, F>::conv12,
        naive_cyclic_conv,
        uniform12
    );
    conv_test!(
        negacyclic_conv12_matches_naive,
        12,
        FieldConvolve::<F, F>::negacyclic_conv12,
        naive_negacyclic_conv,
        uniform12
    );

    // Width 16
    conv_test!(
        conv16_matches_naive,
        16,
        FieldConvolve::<F, F>::conv16,
        naive_cyclic_conv,
        uniform16
    );
    conv_test!(
        negacyclic_conv16_matches_naive,
        16,
        FieldConvolve::<F, F>::negacyclic_conv16,
        naive_negacyclic_conv,
        uniform16
    );

    // Width 24
    conv_test!(
        conv24_matches_naive,
        24,
        FieldConvolve::<F, F>::conv24,
        naive_cyclic_conv,
        uniform24
    );

    // Width 32
    conv_test!(
        conv32_matches_naive,
        32,
        FieldConvolve::<F, F>::conv32,
        naive_cyclic_conv,
        uniform32
    );
    conv_test!(
        negacyclic_conv32_matches_naive,
        32,
        FieldConvolve::<F, F>::negacyclic_conv32,
        naive_negacyclic_conv,
        uniform32
    );

    #[test]
    fn conv64_matches_naive_fixed() {
        let lhs: [F; 64] = core::array::from_fn(|i| F::from_u32(i as u32 + 1));
        let rhs: [F; 64] = core::array::from_fn(|i| F::from_u32(64 - i as u32));
        check_conv::<64>(lhs, rhs, FieldConvolve::<F, F>::conv64, naive_cyclic_conv);
    }

    #[test]
    fn conv64_all_ones() {
        let ones = [F::ONE; 64];
        let expected = naive_cyclic_conv(ones, ones);
        let mut output = [F::ZERO; 64];
        FieldConvolve::<F, F>::conv64(ones, ones, &mut output);
        assert_eq!(output, expected);
    }

    #[test]
    fn conv16_only_writes_first_n_entries() {
        // A longer output buffer is allowed; entries past N must be untouched.
        let lhs: [F; 16] = core::array::from_fn(|i| F::from_u32(i as u32 + 2));
        let rhs: [F; 16] = core::array::from_fn(|i| F::from_u32(5 * i as u32 + 1));
        let sentinel = F::from_u32(12345);
        let mut output = [sentinel; 20];
        FieldConvolve::<F, F>::conv16(lhs, rhs, &mut output);
        assert_eq!(output[..16], naive_cyclic_conv(lhs, rhs));
        assert!(output[16..].iter().all(|&x| x == sentinel));
    }

    proptest! {
        #[test]
        fn karatsuba_16_matches_naive(
            col in prop::array::uniform16(arb_f()),
            state in prop::array::uniform16(arb_f()),
        ) {
            let expected = naive_cyclic_conv(state, col);
            let mut actual = state;
            mds_circulant_karatsuba_16(&mut actual, &col);
            prop_assert_eq!(actual, expected);
        }

        #[test]
        fn karatsuba_24_matches_naive(
            col in prop::array::uniform24(arb_f()),
            state in prop::array::uniform24(arb_f()),
        ) {
            let expected = naive_cyclic_conv(state, col);
            let mut actual = state;
            mds_circulant_karatsuba_24(&mut actual, &col);
            prop_assert_eq!(actual, expected);
        }
    }

    proptest! {
        #[test]
        fn conv8_commutative(
            a in prop::array::uniform8(arb_f()),
            b in prop::array::uniform8(arb_f()),
        ) {
            // Cyclic convolution is commutative: a * b = b * a.
            let mut ab = [F::ZERO; 8];
            let mut ba = [F::ZERO; 8];
            FieldConvolve::<F, F>::conv8(a, b, &mut ab);
            FieldConvolve::<F, F>::conv8(b, a, &mut ba);
            prop_assert_eq!(ab, ba);
        }

        #[test]
        fn conv8_identity(a in prop::array::uniform8(arb_f())) {
            // The delta impulse [1, 0, 0, ...] is the convolution identity.
            let mut id = [F::ZERO; 8];
            id[0] = F::ONE;
            let mut out = [F::ZERO; 8];
            FieldConvolve::<F, F>::conv8(a, id, &mut out);
            prop_assert_eq!(out, a);
        }

        #[test]
        fn conv8_zero(a in prop::array::uniform8(arb_f())) {
            // Convolving with the zero vector must produce all zeros.
            let zeros = [F::ZERO; 8];
            let mut out = [F::ZERO; 8];
            FieldConvolve::<F, F>::conv8(a, zeros, &mut out);
            prop_assert_eq!(out, zeros);
        }

        #[test]
        fn apply_conv8_matches_naive(
            lhs in prop::array::uniform8(arb_f()),
            rhs in prop::array::uniform8(arb_f()),
        ) {
            // `apply` = read -> conv -> reduce; for FieldConvolve, read and
            // reduce are identities, so the result must equal the plain convolution.
            let out = FieldConvolve::<F, F>::apply(lhs, rhs, FieldConvolve::<F, F>::conv8);
            prop_assert_eq!(out, naive_cyclic_conv(lhs, rhs));
        }
    }
}
669            FieldConvolve::<F, F>::conv8(a, zeros, &mut out);
670            prop_assert_eq!(out, zeros);
671        }
672    }
673}