p3_dft/
util.rs

1use core::borrow::BorrowMut;
2
3use p3_field::{Field, PrimeCharacteristicRing};
4use p3_matrix::Matrix;
5use p3_matrix::dense::{DenseMatrix, DenseStorage, RowMajorMatrix};
6use p3_util::log2_strict_usize;
7use tracing::instrument;
8
9/// Divide each coefficient of the given matrix by its height.
10///
11/// # Panics
12///
13/// Panics if the height of the matrix is not a power of two.
14#[instrument(skip_all, fields(dims = %mat.dimensions()))]
15pub fn divide_by_height<F: Field, S: DenseStorage<F> + BorrowMut<[F]>>(
16    mat: &mut DenseMatrix<F, S>,
17) {
18    let h = mat.height();
19    let log_h = log2_strict_usize(h);
20    // It's cheaper to use div_2exp_u64 as this usually avoids an inversion.
21    // It's also cheaper to work in the PrimeSubfield whenever possible.
22    let h_inv_subfield = F::PrimeSubfield::ONE.div_2exp_u64(log_h as u64);
23    let h_inv = F::from_prime_subfield(h_inv_subfield);
24    mat.scale(h_inv)
25}
26
27/// Multiply each element of row `i` of `mat` by `shift**i`.
28pub(crate) fn coset_shift_cols<F: Field>(mat: &mut RowMajorMatrix<F>, shift: F) {
29    mat.rows_mut()
30        .zip(shift.powers())
31        .for_each(|(row, weight)| {
32            row.iter_mut().for_each(|coeff| {
33                *coeff *= weight;
34            });
35        });
36}
37
#[cfg(test)]
mod tests {
    use alloc::vec;

    use p3_baby_bear::BabyBear;
    use p3_matrix::dense::RowMajorMatrix;

    use super::*;

    type F = BabyBear;

    #[test]
    fn test_divide_by_height_2x2() {
        // Matrix:
        // [ 2, 4 ]
        // [ 6, 8 ]
        //
        // height = 2 => divide each element by 2
        let mut mat = RowMajorMatrix::new(
            vec![F::from_u8(2), F::from_u8(4), F::from_u8(6), F::from_u8(8)],
            2,
        );

        divide_by_height(&mut mat);

        // Compute: [2, 4, 6, 8] * 1/2 = [1, 2, 3, 4]
        let expected = vec![F::from_u8(1), F::from_u8(2), F::from_u8(3), F::from_u8(4)];

        assert_eq!(mat.values, expected);
    }

    #[test]
    fn test_divide_by_height_1x4() {
        // Matrix:
        // [ 10, 20, 30, 40 ]
        // height = 1 => no division (1⁻¹ = 1), matrix should remain unchanged
        let mut mat = RowMajorMatrix::new_row(vec![
            F::from_u8(10),
            F::from_u8(20),
            F::from_u8(30),
            F::from_u8(40),
        ]);

        divide_by_height(&mut mat);

        let expected = vec![
            F::from_u8(10),
            F::from_u8(20),
            F::from_u8(30),
            F::from_u8(40),
        ];

        assert_eq!(mat.values, expected);
    }

    #[test]
    fn test_divide_by_height_4x1_non_integral_quotients() {
        // Matrix (single column):
        // [ 1 ]
        // [ 2 ]
        // [ 3 ]
        // [ 4 ]
        //
        // height = 4 => every element is multiplied by 4⁻¹. Since 1, 2 and 3 are not
        // multiples of 4, this exercises a genuine field inverse rather than an exact
        // integer division.
        let mut mat = RowMajorMatrix::new(
            vec![F::from_u8(1), F::from_u8(2), F::from_u8(3), F::from_u8(4)],
            1,
        );
        let original = mat.values.clone();

        divide_by_height(&mut mat);

        // Multiplying back by the height must recover every original entry.
        let height = F::from_u8(4);
        for (scaled, orig) in mat.values.iter().zip(&original) {
            assert_eq!(*scaled * height, *orig);
        }
        // 4 / 4 = 1 exactly.
        assert_eq!(mat.values[3], F::ONE);
    }

    #[test]
    #[should_panic]
    fn test_divide_by_height_non_power_of_two_height_should_panic() {
        // Matrix of height = 3 is not a power of two → should panic
        let mut mat = RowMajorMatrix::new(vec![F::from_u8(1), F::from_u8(2), F::from_u8(3)], 1);

        divide_by_height(&mut mat);
    }

    #[test]
    fn test_coset_shift_cols_3x2_shift_2() {
        // Input matrix:
        // [ 1, 2 ]
        // [ 3, 4 ]
        // [ 5, 6 ]
        //
        // shift = 2
        // Row 0: shift^0 = 1 → [1 * 1, 2 * 1] = [1, 2]
        // Row 1: shift^1 = 2 → [3 * 2, 4 * 2] = [6, 8]
        // Row 2: shift^2 = 4 → [5 * 4, 6 * 4] = [20, 24]

        let mut mat = RowMajorMatrix::new(
            vec![
                F::from_u8(1),
                F::from_u8(2),
                F::from_u8(3),
                F::from_u8(4),
                F::from_u8(5),
                F::from_u8(6),
            ],
            2,
        );

        coset_shift_cols(&mut mat, F::from_u8(2));

        let expected = vec![
            F::from_u8(1),
            F::from_u8(2),
            F::from_u8(6),
            F::from_u8(8),
            F::from_u8(20),
            F::from_u8(24),
        ];

        assert_eq!(mat.values, expected);
    }

    #[test]
    fn test_coset_shift_cols_identity_shift() {
        // shift = 1 → all weights = 1 → matrix should remain unchanged
        let mut mat = RowMajorMatrix::new(
            vec![F::from_u8(7), F::from_u8(8), F::from_u8(9), F::from_u8(10)],
            2,
        );

        coset_shift_cols(&mut mat, F::from_u8(1));

        let expected = vec![F::from_u8(7), F::from_u8(8), F::from_u8(9), F::from_u8(10)];

        assert_eq!(mat.values, expected);
    }

    #[test]
    fn test_coset_shift_cols_zero_shift() {
        // shift = 0 → weights are 0^0 = 1, then 0, 0, ...
        // Row 0 is left unchanged and every later row is zeroed.
        let mut mat = RowMajorMatrix::new(
            vec![
                F::from_u8(1),
                F::from_u8(2),
                F::from_u8(3),
                F::from_u8(4),
                F::from_u8(5),
                F::from_u8(6),
            ],
            2,
        );

        coset_shift_cols(&mut mat, F::ZERO);

        let expected = vec![F::from_u8(1), F::from_u8(2), F::ZERO, F::ZERO, F::ZERO, F::ZERO];

        assert_eq!(mat.values, expected);
    }
}