Skip to main content

p3_dft/
util.rs

1use core::borrow::BorrowMut;
2
3use p3_field::{Field, PrimeCharacteristicRing, scale_slice_in_place_single_core};
4use p3_matrix::Matrix;
5use p3_matrix::dense::{DenseMatrix, DenseStorage, RowMajorMatrix};
6use p3_maybe_rayon::prelude::*;
7use p3_util::log2_strict_usize;
8use tracing::instrument;
9
10/// Divide each coefficient of the given matrix by its height.
11///
12/// # Panics
13///
14/// Panics if the height of the matrix is not a power of two.
15#[instrument(skip_all, fields(dims = %mat.dimensions()))]
16pub fn divide_by_height<F: Field, S: DenseStorage<F> + BorrowMut<[F]>>(
17    mat: &mut DenseMatrix<F, S>,
18) {
19    let h = mat.height();
20    let log_h = log2_strict_usize(h);
21    // It's cheaper to use div_2exp_u64 as this usually avoids an inversion.
22    // It's also cheaper to work in the PrimeSubfield whenever possible.
23    let h_inv_subfield = F::PrimeSubfield::ONE.div_2exp_u64(log_h as u64);
24    let h_inv = F::from_prime_subfield(h_inv_subfield);
25    mat.scale(h_inv);
26}
27
28/// Multiply each element of row `i` of `mat` by `shift**i`.
29///
30/// Scales row chunks in parallel, computing one starting power per chunk and
31/// then advancing it row-by-row inside the chunk.
32pub(crate) fn coset_shift_cols<F: Field>(mat: &mut RowMajorMatrix<F>, shift: F) {
33    if shift == F::ONE {
34        return;
35    }
36
37    // Keep chunks large enough to amortize the starting-power exponentiation,
38    // but small enough to avoid allocating one weight per row.
39    const TARGET_CHUNK_VALUES: usize = 1 << 15;
40    let width = mat.width();
41    let chunk_rows = (TARGET_CHUNK_VALUES / width).max(1);
42
43    mat.par_row_chunks_mut(chunk_rows)
44        .enumerate()
45        .for_each(|(chunk_idx, mut rows)| {
46            let mut weight = shift.exp_u64((chunk_idx * chunk_rows) as u64);
47            for row in rows.rows_mut() {
48                scale_slice_in_place_single_core(row, weight);
49                weight *= shift;
50            }
51        });
52}
53
#[cfg(test)]
mod tests {
    use alloc::vec;

    use p3_baby_bear::BabyBear;
    use p3_matrix::dense::RowMajorMatrix;

    use super::*;

    type F = BabyBear;

    #[test]
    fn test_divide_by_height_2x2() {
        // Matrix:
        // [ 2, 4 ]
        // [ 6, 8 ]
        //
        // height = 2 => divide each element by 2
        let mut mat = RowMajorMatrix::new(
            vec![F::from_u8(2), F::from_u8(4), F::from_u8(6), F::from_u8(8)],
            2,
        );

        divide_by_height(&mut mat);

        // Compute: [2, 4, 6, 8] * 1/2 = [1, 2, 3, 4]
        let expected = vec![F::from_u8(1), F::from_u8(2), F::from_u8(3), F::from_u8(4)];

        assert_eq!(mat.values, expected);
    }

    #[test]
    fn test_divide_by_height_1x4() {
        // Matrix:
        // [ 10, 20, 30, 40 ]
        // height = 1 => no division (1⁻¹ = 1), matrix should remain unchanged
        let mut mat = RowMajorMatrix::new_row(vec![
            F::from_u8(10),
            F::from_u8(20),
            F::from_u8(30),
            F::from_u8(40),
        ]);

        divide_by_height(&mut mat);

        let expected = vec![
            F::from_u8(10),
            F::from_u8(20),
            F::from_u8(30),
            F::from_u8(40),
        ];

        assert_eq!(mat.values, expected);
    }

    #[test]
    #[should_panic]
    fn test_divide_by_height_non_power_of_two_height_should_panic() {
        // Matrix of height = 3 is not a power of two → should panic
        let mut mat = RowMajorMatrix::new(vec![F::from_u8(1), F::from_u8(2), F::from_u8(3)], 1);

        divide_by_height(&mut mat);
    }

    #[test]
    fn test_coset_shift_cols_3x2_shift_2() {
        // Input matrix:
        // [ 1, 2 ]
        // [ 3, 4 ]
        // [ 5, 6 ]
        //
        // shift = 2
        // Row 0: shift^0 = 1 → [1 * 1, 2 * 1] = [1, 2]
        // Row 1: shift^1 = 2 → [3 * 2, 4 * 2] = [6, 8]
        // Row 2: shift^2 = 4 → [5 * 4, 6 * 4] = [20, 24]

        let mut mat = RowMajorMatrix::new(
            vec![
                F::from_u8(1),
                F::from_u8(2),
                F::from_u8(3),
                F::from_u8(4),
                F::from_u8(5),
                F::from_u8(6),
            ],
            2,
        );

        coset_shift_cols(&mut mat, F::from_u8(2));

        let expected = vec![
            F::from_u8(1),
            F::from_u8(2),
            F::from_u8(6),
            F::from_u8(8),
            F::from_u8(20),
            F::from_u8(24),
        ];

        assert_eq!(mat.values, expected);
    }

    #[test]
    fn test_coset_shift_cols_early_return_for_one_shift() {
        // shift = 1 takes the early-return path and leaves the matrix unchanged.
        let mut mat = RowMajorMatrix::new(
            vec![F::from_u8(7), F::from_u8(8), F::from_u8(9), F::from_u8(10)],
            2,
        );
        let expected = mat.clone();

        coset_shift_cols(&mut mat, F::ONE);

        assert_eq!(mat, expected);
    }

    #[test]
    fn test_coset_shift_cols_matches_scalar_reference() {
        let mut actual = RowMajorMatrix::new((1u8..=24).map(F::from_u8).collect(), 3);
        let mut expected = actual.clone();
        let shift = F::from_u8(3);

        let mut weight = F::ONE;
        for row in expected.rows_mut() {
            for coeff in row.iter_mut() {
                *coeff *= weight;
            }
            weight *= shift;
        }

        coset_shift_cols(&mut actual, shift);

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_coset_shift_cols_multiple_chunks_matches_scalar_reference() {
        // The other tests fit in a single parallel chunk. With width = 2^14 and
        // the implementation's target of 2^15 values per chunk, each chunk holds
        // 2 rows, so height 5 splits into chunks of 2, 2 and 1 rows. This covers
        // the per-chunk starting power for chunk_idx > 0 and a partial last chunk.
        let width = 1 << 14;
        let height = 5;
        let mut actual = RowMajorMatrix::new(
            (0..width * height).map(F::from_usize).collect(),
            width,
        );
        let mut expected = actual.clone();
        let shift = F::from_u8(7);

        let mut weight = F::ONE;
        for row in expected.rows_mut() {
            for coeff in row.iter_mut() {
                *coeff *= weight;
            }
            weight *= shift;
        }

        coset_shift_cols(&mut actual, shift);

        assert_eq!(actual, expected);
    }
}