1use alloc::collections::BTreeMap;
2use alloc::slice;
3use alloc::sync::Arc;
4use alloc::vec::Vec;
5use core::mem::{MaybeUninit, transmute};
6
7use itertools::{Itertools, izip};
8use p3_field::integers::QuotientMap;
9use p3_field::{Field, Powers, TwoAdicField};
10use p3_matrix::Matrix;
11use p3_matrix::bitrev::{BitReversalPerm, BitReversedMatrixView, BitReversibleMatrix};
12use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView, RowMajorMatrixViewMut};
13use p3_matrix::util::reverse_matrix_index_bits;
14use p3_maybe_rayon::prelude::*;
15use p3_util::{log2_strict_usize, reverse_bits_len, reverse_slice_index_bits};
16use spin::RwLock;
17use tracing::{debug_span, instrument};
18
19use crate::butterflies::{Butterfly, DitButterfly, ScaledDitButterfly, TwiddleFreeButterfly};
20use crate::{Layout, TwoAdicSubgroupDft};
21
22#[derive(Default, Clone, Debug)]
30pub struct Radix2DitParallel<F> {
31 twiddles: Arc<RwLock<BTreeMap<usize, Arc<VectorPair<F>>>>>,
33
34 #[allow(clippy::type_complexity)]
36 coset_twiddles: Arc<RwLock<BTreeMap<(usize, F), Arc<[Vec<F>]>>>>,
37
38 inverse_twiddles: Arc<RwLock<BTreeMap<usize, Arc<VectorPair<F>>>>>,
40}
41
/// One twiddle table stored in two index orderings.
#[derive(Clone, Debug, Default)]
struct VectorPair<F> {
    /// Successive powers of a root, in natural order.
    twiddles: Vec<F>,
    /// The same values as `twiddles`, permuted by bit-reversing every index.
    bitrev_twiddles: Vec<F>,
}
48
49impl<F> Radix2DitParallel<F>
50where
51 F: TwoAdicField + Ord,
52{
53 fn get_or_compute_twiddles(&self, log_h: usize) -> Arc<VectorPair<F>> {
54 if let Some(pair) = self.twiddles.read().get(&log_h) {
56 return pair.clone();
57 }
58
59 let mut w_lock = self.twiddles.write();
61
62 w_lock
64 .entry(log_h)
65 .or_insert_with(|| {
66 let half_h = (1 << log_h) >> 1;
67 let root = F::two_adic_generator(log_h);
68 let twiddles = root.powers().collect_n(half_h);
69 let mut bitrev_twiddles = twiddles.clone();
70 reverse_slice_index_bits(&mut bitrev_twiddles);
71
72 Arc::new(VectorPair {
73 twiddles,
74 bitrev_twiddles,
75 })
76 })
77 .clone()
78 }
79
80 fn get_or_compute_coset_twiddles(&self, (log_h, shift): (usize, F)) -> Arc<[Vec<F>]> {
81 let key = (log_h, shift);
82 if let Some(twiddles) = self.coset_twiddles.read().get(&key) {
84 return twiddles.clone();
85 }
86 let mut w_lock = self.coset_twiddles.write();
89 w_lock
92 .entry(key)
93 .or_insert_with(|| {
94 let mid = log_h.div_ceil(2);
95 let h = 1 << log_h;
96 let root = F::two_adic_generator(log_h);
97 (0..log_h)
98 .map(|layer| {
99 let shift_power = shift.exp_power_of_2(layer);
100 let powers = Powers {
101 base: root.exp_power_of_2(layer),
102 current: shift_power,
103 };
104 let mut twiddles = powers.collect_n(h >> (layer + 1));
105 let layer_rev = log_h - 1 - layer;
106 if layer_rev >= mid {
107 reverse_slice_index_bits(&mut twiddles);
108 }
109 twiddles
110 })
111 .collect::<Vec<_>>()
112 .into()
113 })
114 .clone()
115 }
116
117 fn get_or_compute_inverse_twiddles(&self, log_h: usize) -> Arc<VectorPair<F>> {
118 if let Some(pair) = self.inverse_twiddles.read().get(&log_h) {
120 return pair.clone();
121 }
122 let mut w_lock = self.inverse_twiddles.write();
124 w_lock
127 .entry(log_h)
128 .or_insert_with(|| {
129 let half_h = (1 << log_h) >> 1;
131 let root_inv = F::two_adic_generator(log_h).inverse();
132 let twiddles = root_inv.powers().collect_n(half_h);
133 let mut bitrev_twiddles = twiddles.clone();
134 reverse_slice_index_bits(&mut bitrev_twiddles);
135
136 Arc::new(VectorPair {
137 twiddles,
138 bitrev_twiddles,
139 })
140 })
141 .clone()
142 }
143}
144
145impl<F: TwoAdicField + Ord> TwoAdicSubgroupDft<F> for Radix2DitParallel<F> {
146 type Evaluations = BitReversedMatrixView<RowMajorMatrix<F>>;
147
148 fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> Self::Evaluations {
149 let h = mat.height();
150 let log_h = log2_strict_usize(h);
151
152 let twiddles = self.get_or_compute_twiddles(log_h);
154
155 let mid = log_h.div_ceil(2);
156
157 reverse_matrix_index_bits(&mut mat);
159 first_half(&mut mat, mid, &twiddles.twiddles);
160
161 reverse_matrix_index_bits(&mut mat);
163 second_half(&mut mat, mid, &twiddles.bitrev_twiddles, None);
164
165 mat.bit_reverse_rows()
166 }
167
168 fn coset_dft_batch(&self, mut mat: RowMajorMatrix<F>, shift: F) -> Self::Evaluations {
169 reverse_matrix_index_bits(&mut mat);
170 coset_dft(self, &mut mat.as_view_mut(), shift);
171 BitReversalPerm::new_view(mat)
172 }
173
174 fn coset_idft_batch(&self, mat: RowMajorMatrix<F>, shift: F) -> RowMajorMatrix<F> {
175 let mut coeffs = self.idft_batch(mat);
176 crate::util::coset_shift_cols(&mut coeffs, shift.inverse());
177 coeffs
178 }
179
180 #[instrument(skip_all, level = "debug", fields(dims = %mat.dimensions(), added_bits = added_bits))]
181 fn coset_lde_batch_with_transform<T>(
182 &self,
183 mut mat: RowMajorMatrix<F>,
184 added_bits: usize,
185 shift: F,
186 transform: T,
187 ) -> Self::Evaluations
188 where
189 T: FnOnce(&mut RowMajorMatrixViewMut<'_, F>, Layout),
190 {
191 let w = mat.width;
192 let h = mat.height();
193 let log_h = log2_strict_usize(h);
194 let mid = log_h.div_ceil(2);
195
196 let inverse_twiddles = self.get_or_compute_inverse_twiddles(log_h);
197
198 reverse_matrix_index_bits(&mut mat);
200 first_half(&mut mat, mid, &inverse_twiddles.twiddles);
201
202 reverse_matrix_index_bits(&mut mat);
204 let h_inv_subfield = F::PrimeSubfield::from_int(h).try_inverse();
208 let scale = h_inv_subfield.map(F::from_prime_subfield);
209 second_half(&mut mat, mid, &inverse_twiddles.bitrev_twiddles, scale);
210 transform(&mut mat.as_view_mut(), Layout::BitReversed);
213
214 let lde_elems = w * (h << added_bits);
215 let elems_to_add = lde_elems - w * h;
216 debug_span!("reserve_exact").in_scope(|| mat.values.reserve_exact(elems_to_add));
217
218 let g_big = F::two_adic_generator(log_h + added_bits);
219
220 let mat_ptr = mat.values.as_mut_ptr();
221 let rest_ptr = unsafe { (mat_ptr as *mut MaybeUninit<F>).add(w * h) };
222 let first_slice: &mut [F] = unsafe { slice::from_raw_parts_mut(mat_ptr, w * h) };
223 let rest_slice: &mut [MaybeUninit<F>] =
224 unsafe { slice::from_raw_parts_mut(rest_ptr, lde_elems - w * h) };
225 let mut first_coset_mat = RowMajorMatrixViewMut::new(first_slice, w);
226 let mut rest_cosets_mat = rest_slice
227 .chunks_exact_mut(w * h)
228 .map(|slice| RowMajorMatrixViewMut::new(slice, w))
229 .collect_vec();
230
231 for coset_idx in 1..(1 << added_bits) {
232 let total_shift = g_big.exp_u64(coset_idx as u64) * shift;
233 let coset_idx = reverse_bits_len(coset_idx, added_bits);
234 let dest = &mut rest_cosets_mat[coset_idx - 1]; coset_dft_oop(self, &first_coset_mat.as_view(), dest, total_shift);
236 }
237
238 coset_dft(self, &mut first_coset_mat.as_view_mut(), shift);
240
241 unsafe {
243 mat.values.set_len(lde_elems);
244 }
245 BitReversalPerm::new_view(mat)
246 }
247}
248
249#[instrument(level = "debug", skip_all)]
250fn coset_dft<F: TwoAdicField + Ord>(
251 dft: &Radix2DitParallel<F>,
252 mat: &mut RowMajorMatrixViewMut<'_, F>,
253 shift: F,
254) {
255 let log_h = log2_strict_usize(mat.height());
256 let mid = log_h.div_ceil(2);
257
258 let twiddles = dft.get_or_compute_coset_twiddles((log_h, shift));
259
260 first_half_general(mat, mid, &twiddles);
262
263 reverse_matrix_index_bits(mat);
265
266 second_half_general(mat, mid, &twiddles);
267}
268
269#[instrument(level = "debug", skip_all)]
271fn coset_dft_oop<F: TwoAdicField + Ord>(
272 dft: &Radix2DitParallel<F>,
273 src: &RowMajorMatrixView<'_, F>,
274 dst_maybe: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
275 shift: F,
276) {
277 assert_eq!(src.dimensions(), dst_maybe.dimensions());
278
279 let log_h = log2_strict_usize(dst_maybe.height());
280
281 if log_h == 0 {
282 let src_maybe = unsafe {
285 transmute::<&RowMajorMatrixView<'_, F>, &RowMajorMatrixView<'_, MaybeUninit<F>>>(src)
286 };
287 dst_maybe.copy_from(src_maybe);
288 return;
289 }
290
291 let mid = log_h.div_ceil(2);
292
293 let twiddles = dft.get_or_compute_coset_twiddles((log_h, shift));
294
295 first_half_general_oop(src, dst_maybe, mid, &twiddles);
297
298 let dst = unsafe {
300 transmute::<&mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>, &mut RowMajorMatrixViewMut<'_, F>>(
301 dst_maybe,
302 )
303 };
304
305 reverse_matrix_index_bits(dst);
307
308 second_half_general(dst, mid, &twiddles);
309}
310
311#[instrument(level = "debug", skip_all)]
319fn first_half<F: Field>(mat: &mut RowMajorMatrix<F>, mid: usize, twiddles: &[F]) {
320 let log_h = log2_strict_usize(mat.height());
321
322 mat.par_row_chunks_exact_mut(1 << mid)
324 .for_each(|mut submat| {
325 let mut backwards = false;
326 for layer in 0..mid {
327 if layer == 0 {
328 dit_layer_twiddle_free(&mut submat, backwards);
332 } else {
333 let layer_rev = log_h - 1 - layer;
334 let layer_pow = 1 << layer_rev;
335 dit_layer_first_one(
339 &mut submat,
340 layer,
341 twiddles.iter().step_by(layer_pow),
342 backwards,
343 );
344 }
345 backwards = !backwards;
346 }
347 });
348}
349
350#[instrument(level = "debug", skip_all)]
353fn first_half_general<F: Field>(
354 mat: &mut RowMajorMatrixViewMut<'_, F>,
355 mid: usize,
356 twiddles: &[Vec<F>],
357) {
358 let log_h = log2_strict_usize(mat.height());
359 mat.par_row_chunks_exact_mut(1 << mid)
360 .for_each(|mut submat| {
361 let mut backwards = false;
362 for layer in 0..mid {
363 let layer_rev = log_h - 1 - layer;
364 dit_layer(&mut submat, layer, twiddles[layer_rev].iter(), backwards);
365 backwards = !backwards;
366 }
367 });
368}
369
370#[instrument(level = "debug", skip_all)]
377fn first_half_general_oop<F: Field>(
378 src: &RowMajorMatrixView<'_, F>,
379 dst_maybe: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
380 mid: usize,
381 twiddles: &[Vec<F>],
382) {
383 let log_h = log2_strict_usize(src.height());
384 src.par_row_chunks_exact(1 << mid)
385 .zip(dst_maybe.par_row_chunks_exact_mut(1 << mid))
386 .for_each(|(src_submat, mut dst_submat_maybe)| {
387 debug_assert_eq!(src_submat.dimensions(), dst_submat_maybe.dimensions());
388
389 let layer_rev = log_h - 1;
392 dit_layer_oop(
393 &src_submat,
394 &mut dst_submat_maybe,
395 0,
396 twiddles[layer_rev].iter(),
397 );
398
399 let mut dst_submat = unsafe {
401 transmute::<RowMajorMatrixViewMut<'_, MaybeUninit<F>>, RowMajorMatrixViewMut<'_, F>>(
402 dst_submat_maybe,
403 )
404 };
405
406 let mut backwards = true;
408 for layer in 1..mid {
409 let layer_rev = log_h - 1 - layer;
410 dit_layer(
411 &mut dst_submat,
412 layer,
413 twiddles[layer_rev].iter(),
414 backwards,
415 );
416 backwards = !backwards;
417 }
418 });
419}
420
421#[instrument(level = "debug", skip_all)]
427#[inline(always)] fn second_half<F: Field>(
429 mat: &mut RowMajorMatrix<F>,
430 mid: usize,
431 twiddles_rev: &[F],
432 scale: Option<F>,
433) {
434 let log_h = log2_strict_usize(mat.height());
435
436 mat.par_row_chunks_exact_mut(1 << (log_h - mid))
438 .enumerate()
439 .for_each(|(thread, mut submat)| {
440 let mut backwards = false;
441 if let Some(scale) = scale {
442 let mut scale_applied = false;
446 for layer in mid..log_h {
447 let first_block = thread << (layer - mid);
448 if !scale_applied {
449 scale_applied = true;
450 dit_layer_rev_scaled(
451 &mut submat,
452 log_h,
453 layer,
454 twiddles_rev[first_block..].iter().copied(),
455 backwards,
456 Some(scale),
457 );
458 } else {
459 dit_layer_rev(
460 &mut submat,
461 log_h,
462 layer,
463 twiddles_rev[first_block..].iter().copied(),
464 backwards,
465 );
466 }
467 backwards = !backwards;
468 }
469 if !scale_applied {
471 submat.scale(scale);
472 }
473 } else {
474 for layer in mid..log_h {
475 let first_block = thread << (layer - mid);
476 dit_layer_rev(
477 &mut submat,
478 log_h,
479 layer,
480 twiddles_rev[first_block..].iter().copied(),
481 backwards,
482 );
483 backwards = !backwards;
484 }
485 }
486 });
487}
488
489#[instrument(level = "debug", skip_all)]
492fn second_half_general<F: Field>(
493 mat: &mut RowMajorMatrixViewMut<'_, F>,
494 mid: usize,
495 twiddles_rev: &[Vec<F>],
496) {
497 let log_h = log2_strict_usize(mat.height());
498 mat.par_row_chunks_exact_mut(1 << (log_h - mid))
499 .enumerate()
500 .for_each(|(thread, mut submat)| {
501 let mut backwards = false;
502 for layer in mid..log_h {
503 let layer_rev = log_h - 1 - layer;
504 let first_block = thread << (layer - mid);
505 dit_layer_rev(
506 &mut submat,
507 log_h,
508 layer,
509 twiddles_rev[layer_rev][first_block..].iter().copied(),
510 backwards,
511 );
512 backwards = !backwards;
513 }
514 });
515}
516
517fn dit_layer_twiddle_free<F: Field>(submat: &mut RowMajorMatrixViewMut<'_, F>, backwards: bool) {
526 let width = submat.width();
528 debug_assert!(submat.height() >= 2);
529
530 let process_block = move |block: &mut [F]| {
531 let (lo, hi) = block.split_at_mut(width);
533 TwiddleFreeButterfly.apply_to_rows(lo, hi);
534 };
535
536 let blocks = submat.values.chunks_mut(2 * width);
537 if backwards {
538 for block in blocks.rev() {
539 process_block(block);
540 }
541 } else {
542 for block in blocks {
543 process_block(block);
544 }
545 }
546}
547
548fn dit_layer_first_one<'a, F: Field>(
557 submat: &mut RowMajorMatrixViewMut<'_, F>,
558 layer: usize,
559 twiddles: impl Iterator<Item = &'a F> + Clone,
560 backwards: bool,
561) {
562 let half_block_size = 1 << layer;
563 let block_size = half_block_size * 2;
564 let width = submat.width();
565 debug_assert!(submat.height() >= block_size);
566 debug_assert!(
567 half_block_size >= 2,
568 "layer must be >= 1 for dit_layer_first_one"
569 );
570
571 let process_block = move |block: &mut [F]| {
572 let (lows, highs) = block.split_at_mut(half_block_size * width);
573 let mut tw_iter = twiddles.clone();
574 let _ = tw_iter.next(); let (lo0, lo_rest) = lows.split_at_mut(width);
577 let (hi0, hi_rest) = highs.split_at_mut(width);
578 TwiddleFreeButterfly.apply_to_rows(lo0, hi0);
579 for (lo, hi, twiddle) in izip!(
581 lo_rest.chunks_mut(width),
582 hi_rest.chunks_mut(width),
583 tw_iter
584 ) {
585 DitButterfly(*twiddle).apply_to_rows(lo, hi);
586 }
587 };
588
589 let blocks = submat.values.chunks_mut(block_size * width);
590 if backwards {
591 for block in blocks.rev() {
592 process_block(block);
593 }
594 } else {
595 for block in blocks {
596 process_block(block);
597 }
598 }
599}
600
601fn dit_layer<'a, F: Field>(
603 submat: &mut RowMajorMatrixViewMut<'_, F>,
604 layer: usize,
605 twiddles: impl Iterator<Item = &'a F> + Clone,
606 backwards: bool,
607) {
608 let half_block_size = 1 << layer;
609 let block_size = half_block_size * 2;
610 let width = submat.width();
611 debug_assert!(submat.height() >= block_size);
612
613 let process_block = move |block: &mut [F]| {
614 let (lows, highs) = block.split_at_mut(half_block_size * width);
615 for (lo, hi, twiddle) in izip!(
616 lows.chunks_mut(width),
617 highs.chunks_mut(width),
618 twiddles.clone()
619 ) {
620 DitButterfly(*twiddle).apply_to_rows(lo, hi);
621 }
622 };
623
624 let blocks = submat.values.chunks_mut(block_size * width);
625 if backwards {
626 for block in blocks.rev() {
627 process_block(block);
628 }
629 } else {
630 for block in blocks {
631 process_block(block);
632 }
633 }
634}
635
636fn dit_layer_oop<'a, F: Field>(
638 src: &RowMajorMatrixView<'_, F>,
639 dst: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
640 layer: usize,
641 twiddles: impl Iterator<Item = &'a F> + Clone,
642) {
643 debug_assert_eq!(src.dimensions(), dst.dimensions());
644 let half_block_size = 1 << layer;
645 let block_size = half_block_size * 2;
646 let width = dst.width();
647 debug_assert!(dst.height() >= block_size);
648
649 let process_blocks = move |src_block: &[F], dst_block: &mut [MaybeUninit<F>]| {
650 let (src_lows, src_highs) = src_block.split_at(half_block_size * width);
651 let (dst_lows, dst_highs) = dst_block.split_at_mut(half_block_size * width);
652
653 for (src_lo, dst_lo, src_hi, dst_hi, twiddle) in izip!(
654 src_lows.chunks(width),
655 dst_lows.chunks_mut(width),
656 src_highs.chunks(width),
657 dst_highs.chunks_mut(width),
658 twiddles.clone()
659 ) {
660 DitButterfly(*twiddle).apply_to_rows_oop(src_lo, dst_lo, src_hi, dst_hi);
661 }
662 };
663
664 let src_chunks = src.values.chunks(block_size * width);
665 let dst_chunks = dst.values.chunks_mut(block_size * width);
666
667 for (src_block, dst_block) in src_chunks.zip(dst_chunks) {
668 process_blocks(src_block, dst_block);
669 }
670}
671
672fn dit_layer_rev_scaled<F: Field>(
680 submat: &mut RowMajorMatrixViewMut<'_, F>,
681 log_h: usize,
682 layer: usize,
683 twiddles_rev: impl DoubleEndedIterator<Item = F> + ExactSizeIterator,
684 backwards: bool,
685 scale: Option<F>,
686) {
687 let layer_rev = log_h - 1 - layer;
688
689 let half_block_size = 1 << layer_rev;
690 let block_size = half_block_size * 2;
691 let width = submat.width();
692 debug_assert!(submat.height() >= block_size);
693
694 match scale {
695 None => {
696 let blocks_and_twiddles = submat
698 .values
699 .chunks_mut(block_size * width)
700 .zip(twiddles_rev);
701 if backwards {
702 for (block, twiddle) in blocks_and_twiddles.rev() {
703 let (lo, hi) = block.split_at_mut(half_block_size * width);
704 DitButterfly(twiddle).apply_to_rows(lo, hi);
705 }
706 } else {
707 for (block, twiddle) in blocks_and_twiddles {
708 let (lo, hi) = block.split_at_mut(half_block_size * width);
709 DitButterfly(twiddle).apply_to_rows(lo, hi);
710 }
711 }
712 }
713 Some(s) => {
714 let blocks_and_twiddles = submat
718 .values
719 .chunks_mut(block_size * width)
720 .zip(twiddles_rev);
721 if backwards {
722 for (block, twiddle) in blocks_and_twiddles.rev() {
723 let (lo, hi) = block.split_at_mut(half_block_size * width);
724 ScaledDitButterfly::new(twiddle, s).apply_to_rows(lo, hi);
725 }
726 } else {
727 for (block, twiddle) in blocks_and_twiddles {
728 let (lo, hi) = block.split_at_mut(half_block_size * width);
729 ScaledDitButterfly::new(twiddle, s).apply_to_rows(lo, hi);
730 }
731 }
732 }
733 }
734}
735
736fn dit_layer_rev<F: Field>(
739 submat: &mut RowMajorMatrixViewMut<'_, F>,
740 log_h: usize,
741 layer: usize,
742 twiddles_rev: impl DoubleEndedIterator<Item = F> + ExactSizeIterator,
743 backwards: bool,
744) {
745 let layer_rev = log_h - 1 - layer;
746
747 let half_block_size = 1 << layer_rev;
748 let block_size = half_block_size * 2;
749 let width = submat.width();
750 debug_assert!(submat.height() >= block_size);
751
752 let blocks_and_twiddles = submat
753 .values
754 .chunks_mut(block_size * width)
755 .zip(twiddles_rev);
756 if backwards {
757 for (block, twiddle) in blocks_and_twiddles.rev() {
758 let (lo, hi) = block.split_at_mut(half_block_size * width);
759 DitButterfly(twiddle).apply_to_rows(lo, hi);
760 }
761 } else {
762 for (block, twiddle) in blocks_and_twiddles {
763 let (lo, hi) = block.split_at_mut(half_block_size * width);
764 DitButterfly(twiddle).apply_to_rows(lo, hi);
765 }
766 }
767}
768
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::TwoAdicField;
    use p3_matrix::Matrix;
    use p3_matrix::dense::RowMajorMatrix;
    use rand::SeedableRng;
    use rand::rngs::SmallRng;

    use super::*;

    type F = BabyBear;

    /// A coset DFT followed by the matching coset inverse DFT must return the input.
    #[test]
    fn coset_dft_idft_roundtrip() {
        let mut rng = SmallRng::seed_from_u64(42);
        let original = RowMajorMatrix::<F>::rand(&mut rng, 16, 3);

        let dft = Radix2DitParallel::<F>::default();
        let shift = F::GENERATOR;

        let evaluations = dft
            .coset_dft_batch(original.clone(), shift)
            .to_row_major_matrix();
        let recovered = dft.coset_idft_batch(evaluations, shift);

        assert_eq!(original, recovered);
    }

    /// The specialised `coset_dft_batch` must agree with the textbook definition:
    /// scale the columns by powers of the shift, then run a plain DFT.
    #[test]
    fn coset_dft_matches_default_trait() {
        let mut rng = SmallRng::seed_from_u64(7);
        let input = RowMajorMatrix::<F>::rand(&mut rng, 16, 4);

        let dft = Radix2DitParallel::<F>::default();
        let shift = F::two_adic_generator(4) * F::GENERATOR;

        let override_result = dft
            .coset_dft_batch(input.clone(), shift)
            .to_row_major_matrix();

        // Reference path: explicit column shift + ordinary DFT.
        let mut shifted = input;
        crate::util::coset_shift_cols(&mut shifted, shift);
        let default_result = dft.dft_batch(shifted).to_row_major_matrix();

        assert_eq!(override_result, default_result);
    }
}