1use core::marker::PhantomData;
48use core::ops::{Add, AddAssign, Neg, Sub, SubAssign};
49
50use p3_field::{Algebra, Field};
51
/// Arithmetic required of the working element type `T` used by [`Convolve`].
///
/// This is the type of the left operand and of every output slot. The
/// convolution routines copy values freely, add, subtract and negate them, and
/// update output slots in place with `+=` / `-=`.
///
/// The trait has no items of its own; it only names this set of bounds.
pub trait ConvolutionElt:
    Copy + Add<Output = Self> + Sub<Output = Self> + Neg<Output = Self> + AddAssign + SubAssign
{
}
59
60impl<T> ConvolutionElt for T where
61 T: Add<Output = T> + AddAssign + Copy + Neg<Output = T> + Sub<Output = T> + SubAssign
62{
63}
64
/// Arithmetic required of the right-operand type `U` used by [`Convolve`].
///
/// Right-hand values are only copied, added, subtracted and negated; unlike
/// [`ConvolutionElt`] they are never updated in place, so the compound
/// assignment operators are not required.
///
/// The trait has no items of its own; it only names this set of bounds.
pub trait ConvolutionRhs:
    Copy + Add<Output = Self> + Sub<Output = Self> + Neg<Output = Self>
{
}
72
73impl<T> ConvolutionRhs for T where T: Add<Output = T> + Copy + Neg<Output = T> + Sub<Output = T> {}
74
75pub trait Convolve<F, T: ConvolutionElt, U: ConvolutionRhs> {
96 const T_ZERO: T;
101
102 const U_ZERO: U;
107
108 fn halve(val: T) -> T;
113
114 fn read(input: F) -> T;
117
118 fn parity_dot<const N: usize>(lhs: [T; N], rhs: [U; N]) -> T;
125
126 fn reduce(z: T) -> F;
129
130 #[inline(always)]
135 fn apply<const N: usize, C: Fn([T; N], [U; N], &mut [T])>(
136 lhs: [F; N],
137 rhs: [U; N],
138 conv: C,
139 ) -> [F; N] {
140 let lhs = lhs.map(Self::read);
141 let mut output = [Self::T_ZERO; N];
142 conv(lhs, rhs, &mut output);
143 output.map(Self::reduce)
144 }
145
146 #[inline(always)]
147 fn conv3(lhs: [T; 3], rhs: [U; 3], output: &mut [T]) {
148 output[0] = Self::parity_dot(lhs, [rhs[0], rhs[2], rhs[1]]);
149 output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], rhs[2]]);
150 output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0]]);
151 }
152
153 #[inline(always)]
154 fn negacyclic_conv3(lhs: [T; 3], rhs: [U; 3], output: &mut [T]) {
155 output[0] = Self::parity_dot(lhs, [rhs[0], -rhs[2], -rhs[1]]);
156 output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], -rhs[2]]);
157 output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0]]);
158 }
159
160 #[inline(always)]
161 fn conv4(lhs: [T; 4], rhs: [U; 4], output: &mut [T]) {
162 let u_p = [lhs[0] + lhs[2], lhs[1] + lhs[3]];
165 let u_m = [lhs[0] - lhs[2], lhs[1] - lhs[3]];
166 let v_p = [rhs[0] + rhs[2], rhs[1] + rhs[3]];
167 let v_m = [rhs[0] - rhs[2], rhs[1] - rhs[3]];
168
169 output[0] = Self::parity_dot(u_m, [v_m[0], -v_m[1]]);
170 output[1] = Self::parity_dot(u_m, [v_m[1], v_m[0]]);
171 output[2] = Self::parity_dot(u_p, v_p);
172 output[3] = Self::parity_dot(u_p, [v_p[1], v_p[0]]);
173
174 output[0] += output[2];
175 output[1] += output[3];
176
177 output[0] = Self::halve(output[0]);
178 output[1] = Self::halve(output[1]);
179
180 output[2] -= output[0];
181 output[3] -= output[1];
182 }
183
184 #[inline(always)]
185 fn negacyclic_conv4(lhs: [T; 4], rhs: [U; 4], output: &mut [T]) {
186 output[0] = Self::parity_dot(lhs, [rhs[0], -rhs[3], -rhs[2], -rhs[1]]);
187 output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], -rhs[3], -rhs[2]]);
188 output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0], -rhs[3]]);
189 output[3] = Self::parity_dot(lhs, [rhs[3], rhs[2], rhs[1], rhs[0]]);
190 }
191
192 #[inline(always)]
195 fn conv_n_recursive<const N: usize, const HALF_N: usize, C, NC>(
196 lhs: [T; N],
197 rhs: [U; N],
198 output: &mut [T],
199 inner_conv: C,
200 inner_negacyclic_conv: NC,
201 ) where
202 C: Fn([T; HALF_N], [U; HALF_N], &mut [T]),
203 NC: Fn([T; HALF_N], [U; HALF_N], &mut [T]),
204 {
205 debug_assert_eq!(2 * HALF_N, N);
206 let mut lhs_pos = [Self::T_ZERO; HALF_N]; let mut lhs_neg = [Self::T_ZERO; HALF_N]; let mut rhs_pos = [Self::U_ZERO; HALF_N]; let mut rhs_neg = [Self::U_ZERO; HALF_N]; for i in 0..HALF_N {
212 let s = lhs[i];
213 let t = lhs[i + HALF_N];
214
215 lhs_pos[i] = s + t;
216 lhs_neg[i] = s - t;
217
218 let s = rhs[i];
219 let t = rhs[i + HALF_N];
220
221 rhs_pos[i] = s + t;
222 rhs_neg[i] = s - t;
223 }
224
225 let (left, right) = output.split_at_mut(HALF_N);
226
227 inner_negacyclic_conv(lhs_neg, rhs_neg, left);
229
230 inner_conv(lhs_pos, rhs_pos, right);
232
233 for i in 0..HALF_N {
234 left[i] += right[i]; left[i] = Self::halve(left[i]); right[i] -= left[i]; }
238 }
239
240 #[inline(always)]
243 fn negacyclic_conv_n_recursive<const N: usize, const HALF_N: usize, NC>(
244 lhs: [T; N],
245 rhs: [U; N],
246 output: &mut [T],
247 inner_negacyclic_conv: NC,
248 ) where
249 NC: Fn([T; HALF_N], [U; HALF_N], &mut [T]),
250 {
251 debug_assert_eq!(2 * HALF_N, N);
252 let mut lhs_even = [Self::T_ZERO; HALF_N];
253 let mut lhs_odd = [Self::T_ZERO; HALF_N];
254 let mut lhs_sum = [Self::T_ZERO; HALF_N];
255 let mut rhs_even = [Self::U_ZERO; HALF_N];
256 let mut rhs_odd = [Self::U_ZERO; HALF_N];
257 let mut rhs_sum = [Self::U_ZERO; HALF_N];
258
259 for i in 0..HALF_N {
260 let s = lhs[2 * i];
261 let t = lhs[2 * i + 1];
262 lhs_even[i] = s;
263 lhs_odd[i] = t;
264 lhs_sum[i] = s + t;
265
266 let s = rhs[2 * i];
267 let t = rhs[2 * i + 1];
268 rhs_even[i] = s;
269 rhs_odd[i] = t;
270 rhs_sum[i] = s + t;
271 }
272
273 let mut even_s_conv = [Self::T_ZERO; HALF_N];
274 let (left, right) = output.split_at_mut(HALF_N);
275
276 inner_negacyclic_conv(lhs_even, rhs_even, &mut even_s_conv);
279 inner_negacyclic_conv(lhs_odd, rhs_odd, left);
280 inner_negacyclic_conv(lhs_sum, rhs_sum, right);
281
282 right[0] -= even_s_conv[0] + left[0];
285 even_s_conv[0] -= left[HALF_N - 1];
286
287 for i in 1..HALF_N {
288 right[i] -= even_s_conv[i] + left[i];
289 even_s_conv[i] += left[i - 1];
290 }
291
292 for i in 0..HALF_N {
294 output[2 * i] = even_s_conv[i];
295 output[2 * i + 1] = output[i + HALF_N];
296 }
297 }
298
299 #[inline(always)]
300 fn conv6(lhs: [T; 6], rhs: [U; 6], output: &mut [T]) {
301 Self::conv_n_recursive(lhs, rhs, output, Self::conv3, Self::negacyclic_conv3);
302 }
303
304 #[inline(always)]
305 fn negacyclic_conv6(lhs: [T; 6], rhs: [U; 6], output: &mut [T]) {
306 Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv3);
307 }
308
309 #[inline(always)]
310 fn conv8(lhs: [T; 8], rhs: [U; 8], output: &mut [T]) {
311 Self::conv_n_recursive(lhs, rhs, output, Self::conv4, Self::negacyclic_conv4);
312 }
313
314 #[inline(always)]
315 fn negacyclic_conv8(lhs: [T; 8], rhs: [U; 8], output: &mut [T]) {
316 Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv4);
317 }
318
319 #[inline(always)]
320 fn conv12(lhs: [T; 12], rhs: [U; 12], output: &mut [T]) {
321 Self::conv_n_recursive(lhs, rhs, output, Self::conv6, Self::negacyclic_conv6);
322 }
323
324 #[inline(always)]
325 fn negacyclic_conv12(lhs: [T; 12], rhs: [U; 12], output: &mut [T]) {
326 Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv6);
327 }
328
329 #[inline(always)]
330 fn conv16(lhs: [T; 16], rhs: [U; 16], output: &mut [T]) {
331 Self::conv_n_recursive(lhs, rhs, output, Self::conv8, Self::negacyclic_conv8);
332 }
333
334 #[inline(always)]
335 fn negacyclic_conv16(lhs: [T; 16], rhs: [U; 16], output: &mut [T]) {
336 Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv8);
337 }
338
339 #[inline(always)]
340 fn conv24(lhs: [T; 24], rhs: [U; 24], output: &mut [T]) {
341 Self::conv_n_recursive(lhs, rhs, output, Self::conv12, Self::negacyclic_conv12);
342 }
343
344 #[inline(always)]
345 fn conv32(lhs: [T; 32], rhs: [U; 32], output: &mut [T]) {
346 Self::conv_n_recursive(lhs, rhs, output, Self::conv16, Self::negacyclic_conv16);
347 }
348
349 #[inline(always)]
350 fn negacyclic_conv32(lhs: [T; 32], rhs: [U; 32], output: &mut [T]) {
351 Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv16);
352 }
353
354 #[inline(always)]
355 fn conv64(lhs: [T; 64], rhs: [U; 64], output: &mut [T]) {
356 Self::conv_n_recursive(lhs, rhs, output, Self::conv32, Self::negacyclic_conv32);
357 }
358}
359
360struct FieldConvolve<F, A>(PhantomData<(F, A)>);
365
366impl<F: Field, A: Algebra<F> + Copy> Convolve<A, A, F> for FieldConvolve<F, A> {
367 const T_ZERO: A = A::ZERO;
368 const U_ZERO: F = F::ZERO;
369
370 #[inline(always)]
371 fn halve(val: A) -> A {
372 val.halve()
373 }
374
375 #[inline(always)]
376 fn read(input: A) -> A {
377 input
378 }
379
380 #[inline(always)]
381 fn parity_dot<const N: usize>(lhs: [A; N], rhs: [F; N]) -> A {
382 A::mixed_dot_product(&lhs, &rhs)
383 }
384
385 #[inline(always)]
386 fn reduce(z: A) -> A {
387 z
388 }
389}
390
391#[inline]
393pub fn mds_circulant_karatsuba_16<F: Field, A: Algebra<F> + Copy>(
394 state: &mut [A; 16],
395 col: &[F; 16],
396) {
397 let input = *state;
398 FieldConvolve::<F, A>::conv16(input, *col, state.as_mut_slice());
399}
400
401#[inline]
403pub fn mds_circulant_karatsuba_24<F: Field, A: Algebra<F> + Copy>(
404 state: &mut [A; 24],
405 col: &[F; 24],
406) {
407 let input = *state;
408 FieldConvolve::<F, A>::conv24(input, *col, state.as_mut_slice());
409}
410
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::PrimeCharacteristicRing;
    use proptest::prelude::*;

    use super::*;

    type F = BabyBear;

    /// Strategy that maps an arbitrary `u32` to a field element.
    fn arb_f() -> impl Strategy<Value = F> {
        prop::num::u32::ANY.prop_map(|raw| F::from_u32(raw))
    }

    /// Schoolbook cyclic convolution: `out[i] = sum_j lhs[j] * rhs[(i - j) mod N]`.
    fn naive_cyclic_conv<const N: usize>(lhs: [F; N], rhs: [F; N]) -> [F; N] {
        let mut out = [F::ZERO; N];
        for (i, slot) in out.iter_mut().enumerate() {
            for (j, &l) in lhs.iter().enumerate() {
                // `N + i - j` keeps the index non-negative before the reduction.
                *slot += l * rhs[(N + i - j) % N];
            }
        }
        out
    }

    /// Schoolbook negacyclic convolution: products whose index wraps past
    /// `N` are subtracted instead of added.
    fn naive_negacyclic_conv<const N: usize>(lhs: [F; N], rhs: [F; N]) -> [F; N] {
        let mut out = [F::ZERO; N];
        for (i, &l) in lhs.iter().enumerate() {
            for (j, &r) in rhs.iter().enumerate() {
                let term = l * r;
                match (i + j).checked_sub(N) {
                    // No wrap-around: contributes positively at `i + j`.
                    None => out[i + j] += term,
                    // Wrapped: contributes negatively at `i + j - N`.
                    Some(wrapped) => out[wrapped] -= term,
                }
            }
        }
        out
    }

    /// Runs `conv_fn` into a zeroed buffer and compares it with `naive_fn`.
    fn check_conv<const N: usize>(
        lhs: [F; N],
        rhs: [F; N],
        conv_fn: fn([F; N], [F; N], &mut [F]),
        naive_fn: fn([F; N], [F; N]) -> [F; N],
    ) {
        let mut output = [F::ZERO; N];
        conv_fn(lhs, rhs, &mut output);
        let expected = naive_fn(lhs, rhs);
        assert_eq!(output, expected, "convolution mismatch");
    }

    /// Generates one property test per `name => (method, reference)` entry,
    /// all for arrays of length `$len` drawn with `prop::array::$strategy`.
    macro_rules! conv_tests {
        (
            $len:literal, $strategy:ident :
            $($test_name:ident => ($conv:ident, $naive:ident)),+ $(,)?
        ) => {
            proptest! {
                $(
                    #[test]
                    fn $test_name(
                        lhs in prop::array::$strategy(arb_f()),
                        rhs in prop::array::$strategy(arb_f()),
                    ) {
                        check_conv::<$len>(lhs, rhs, FieldConvolve::<F, F>::$conv, $naive);
                    }
                )+
            }
        };
    }

    conv_tests!(3, uniform3:
        conv3_matches_naive => (conv3, naive_cyclic_conv),
        negacyclic_conv3_matches_naive => (negacyclic_conv3, naive_negacyclic_conv),
    );

    conv_tests!(4, uniform4:
        conv4_matches_naive => (conv4, naive_cyclic_conv),
        negacyclic_conv4_matches_naive => (negacyclic_conv4, naive_negacyclic_conv),
    );

    conv_tests!(6, uniform6:
        conv6_matches_naive => (conv6, naive_cyclic_conv),
        negacyclic_conv6_matches_naive => (negacyclic_conv6, naive_negacyclic_conv),
    );

    conv_tests!(8, uniform8:
        conv8_matches_naive => (conv8, naive_cyclic_conv),
        negacyclic_conv8_matches_naive => (negacyclic_conv8, naive_negacyclic_conv),
    );

    conv_tests!(12, uniform12:
        conv12_matches_naive => (conv12, naive_cyclic_conv),
        negacyclic_conv12_matches_naive => (negacyclic_conv12, naive_negacyclic_conv),
    );

    conv_tests!(16, uniform16:
        conv16_matches_naive => (conv16, naive_cyclic_conv),
        negacyclic_conv16_matches_naive => (negacyclic_conv16, naive_negacyclic_conv),
    );

    conv_tests!(24, uniform24:
        conv24_matches_naive => (conv24, naive_cyclic_conv),
    );

    conv_tests!(32, uniform32:
        conv32_matches_naive => (conv32, naive_cyclic_conv),
        negacyclic_conv32_matches_naive => (negacyclic_conv32, naive_negacyclic_conv),
    );

    /// Size 64 is checked on fixed inputs rather than through a proptest strategy.
    #[test]
    fn conv64_matches_naive_fixed() {
        let ascending: [F; 64] = core::array::from_fn(|i| F::from_u32(i as u32 + 1));
        let descending: [F; 64] = core::array::from_fn(|i| F::from_u32(64 - i as u32));
        check_conv::<64>(
            ascending,
            descending,
            FieldConvolve::<F, F>::conv64,
            naive_cyclic_conv,
        );
    }

    #[test]
    fn conv64_all_ones() {
        let ones = [F::ONE; 64];
        let mut output = [F::ZERO; 64];
        FieldConvolve::<F, F>::conv64(ones, ones, &mut output);
        let expected = naive_cyclic_conv(ones, ones);
        assert_eq!(output, expected);
    }

    // The public entry points must agree with the schoolbook convolution.
    proptest! {
        #[test]
        fn karatsuba_16_matches_naive(
            col in prop::array::uniform16(arb_f()),
            state in prop::array::uniform16(arb_f()),
        ) {
            let mut actual = state;
            mds_circulant_karatsuba_16(&mut actual, &col);
            let expected = naive_cyclic_conv(state, col);
            prop_assert_eq!(actual, expected);
        }

        #[test]
        fn karatsuba_24_matches_naive(
            col in prop::array::uniform24(arb_f()),
            state in prop::array::uniform24(arb_f()),
        ) {
            let mut actual = state;
            mds_circulant_karatsuba_24(&mut actual, &col);
            let expected = naive_cyclic_conv(state, col);
            prop_assert_eq!(actual, expected);
        }
    }

    // Algebraic properties of the size-8 cyclic convolution.
    proptest! {
        #[test]
        fn conv8_commutative(
            a in prop::array::uniform8(arb_f()),
            b in prop::array::uniform8(arb_f()),
        ) {
            let (mut ab, mut ba) = ([F::ZERO; 8], [F::ZERO; 8]);
            FieldConvolve::<F, F>::conv8(a, b, &mut ab);
            FieldConvolve::<F, F>::conv8(b, a, &mut ba);
            prop_assert_eq!(ab, ba);
        }

        #[test]
        fn conv8_identity(a in prop::array::uniform8(arb_f())) {
            // The unit impulse `[1, 0, ..., 0]` is the convolution identity.
            let impulse: [F; 8] =
                core::array::from_fn(|i| if i == 0 { F::ONE } else { F::ZERO });
            let mut out = [F::ZERO; 8];
            FieldConvolve::<F, F>::conv8(a, impulse, &mut out);
            prop_assert_eq!(out, a);
        }

        #[test]
        fn conv8_zero(a in prop::array::uniform8(arb_f())) {
            let zeros = [F::ZERO; 8];
            let mut out = [F::ZERO; 8];
            FieldConvolve::<F, F>::conv8(a, zeros, &mut out);
            prop_assert_eq!(out, zeros);
        }
    }
}