Skip to main content

p3_mds/
util.rs

1use core::ops::{AddAssign, Mul};
2
3use p3_dft::TwoAdicSubgroupDft;
4use p3_field::{Algebra, PrimeCharacteristicRing, TwoAdicField};
5
/// Compute the dot product `u[0]*v[0] + ... + u[N-1]*v[N-1]` of two fixed-size arrays.
///
/// The accumulator is seeded with the first product, so no additive
/// identity is required of `T`.
///
/// # Panics
///
/// Panics if `N == 0`: via the debug assertion in debug builds, and via the
/// out-of-bounds access to element 0 otherwise. It is hard to imagine this
/// case coming up in practice.
#[inline(always)]
pub fn dot_product<T, const N: usize>(u: [T; N], v: [T; N]) -> T
where
    T: Copy + AddAssign + Mul<Output = T>,
{
    debug_assert_ne!(N, 0);

    // Seed with the first term, then fold the remaining pairs in index order.
    let seed = u[0] * v[0];
    u.iter()
        .zip(v.iter())
        .skip(1)
        .fold(seed, |mut acc, (&a, &b)| {
            acc += a * b;
            acc
        })
}
19
20/// Given the first row `circ_matrix` of an NxN circulant matrix, say
21/// C, return the product `C*input`.
22///
23/// NB: This is a naive O(N^2) implementation. It serves as a fallback
24/// for cases where faster paths (Karatsuba convolution or FFT) do not
25/// apply — e.g. non-power-of-two widths, non-two-adic fields, or
26/// packed types without a specialised implementation.
27pub fn apply_circulant<R: PrimeCharacteristicRing, const N: usize>(
28    circ_matrix: &[u64; N],
29    input: &[R; N],
30) -> [R; N] {
31    let matrix = circ_matrix.map(R::from_u64);
32
33    core::array::from_fn(|row| {
34        // Build the circulant row: C[row][col] = first_row[(N + col - row) % N].
35        let rotated: [R; N] = core::array::from_fn(|col| matrix[(N + col - row) % N].clone());
36        R::dot_product(&rotated, input)
37    })
38}
39
/// Convert the first row of a circulant matrix into its first column.
///
/// For example, if `v = [0, 1, 2, 3, 4, 5]` then `output = [0, 5, 4, 3, 2, 1]`:
/// the leading element stays where it is and the rest appear in reverse order.
///
/// FFT-based routines expect a circulant matrix to be described by its
/// first column, whereas we normally store the first row; this helper
/// bridges the two representations.
///
/// NB: The loop is deliberately simple (and not especially efficient) so
/// that the function can be `const`, which is the intended context for use.
pub const fn first_row_to_first_col<const N: usize, T: Copy>(v: &[T; N]) -> [T; N] {
    // Copying the whole row already puts the shared first element in place.
    let mut column = *v;

    // Walk the source index downwards from N - 1 to 1, writing each element
    // to its mirrored position: column[N - src] = v[src].
    let mut src = N;
    while src > 1 {
        src -= 1;
        column[N - src] = v[src];
    }
    column
}
63
64/// Use the convolution theorem to calculate the product of the given
65/// circulant matrix and the given vector.
66///
67/// The circulant matrix must be specified by its first *column*, not its first row. If you have
68/// the row as an array, you can obtain the column with `first_row_to_first_col()`.
69#[inline]
70pub fn apply_circulant_fft<F: TwoAdicField, const N: usize, FFT: TwoAdicSubgroupDft<F>>(
71    fft: &FFT,
72    column: [u64; N],
73    input: &[F; N],
74) -> [F; N] {
75    // Transform the circulant column to the frequency domain.
76    let column = column.map(F::from_u64).to_vec();
77    let matrix = fft.dft(column);
78
79    // Transform the input vector to the frequency domain.
80    let input = fft.dft(input.to_vec());
81
82    // Convolution theorem: point-wise multiply in frequency domain.
83    let product = matrix.iter().zip(input).map(|(&x, y)| x * y).collect();
84
85    // Transform back to the time domain to get the circulant product.
86    let output = fft.idft(product);
87    output.try_into().unwrap()
88}
89
90/// Dense matrix-vector product, applied in place to a fixed-width state vector.
91///
92/// # Overview
93///
94/// - Generic O(t^2) fallback for any dense square matrix.
95/// - Circulant matrices have faster paths in this module (Karatsuba, FFT).
96/// - Sparse or diagonal layouts can skip full-row scans entirely.
97///
98/// # Arguments
99///
100/// - The state vector, overwritten with the product on return.
101/// - The matrix, indexed row-first as `m[row][col]`.
102///
103/// # Performance
104///
105/// - Runtime: O(t^2) ring operations for a width-t state.
106/// - Allocations: one stack snapshot of the input state.
107#[inline]
108pub fn mds_multiply<F, A, const WIDTH: usize>(state: &mut [A; WIDTH], matrix: &[[F; WIDTH]; WIDTH])
109where
110    F: PrimeCharacteristicRing,
111    A: Algebra<F>,
112{
113    // Snapshot inputs so in-place writes don't corrupt later row reads.
114    let input = state.clone();
115
116    //     output[i] = sum_{j=0..t} matrix[i][j] * snapshot[j]
117    for (out, row) in state.iter_mut().zip(matrix.iter()) {
118        *out = A::mixed_dot_product(&input, row);
119    }
120}
121
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_dft::NaiveDft;
    use p3_field::PrimeCharacteristicRing;
    use proptest::prelude::*;

    use super::*;

    type F = BabyBear;

    /// Proptest strategy producing a field element from an arbitrary `u32`.
    fn arb_f() -> impl Strategy<Value = F> {
        prop::num::u32::ANY.prop_map(F::from_u32)
    }

    /// Lift an array of small integers into the field, element by element.
    fn felts<const N: usize>(raw: [u32; N]) -> [F; N] {
        raw.map(F::from_u32)
    }

    #[test]
    fn first_row_to_first_col_even_length() {
        let column = first_row_to_first_col(&[0, 1, 2, 3, 4, 5]);
        assert_eq!(column, [0, 5, 4, 3, 2, 1]);
    }

    #[test]
    fn first_row_to_first_col_odd_length() {
        let column = first_row_to_first_col(&[10, 20, 30, 40, 50]);
        assert_eq!(column, [10, 50, 40, 30, 20]);
    }

    #[test]
    fn first_row_to_first_col_single_element() {
        let column = first_row_to_first_col(&[42]);
        assert_eq!(column, [42]);
    }

    #[test]
    fn first_row_to_first_col_two_elements() {
        let column = first_row_to_first_col(&[1, 2]);
        assert_eq!(column, [1, 2]);
    }

    #[test]
    fn apply_circulant_identity() {
        // Multiplying by the identity circulant [1, 0, 0, ...] is a no-op.
        let identity_row: [u64; 4] = [1, 0, 0, 0];
        let vector = felts([5, 10, 15, 20]);
        let product = apply_circulant(&identity_row, &vector);
        assert_eq!(product, vector);
    }

    #[test]
    fn apply_circulant_all_ones() {
        // With an all-ones circulant every output slot holds the sum of the inputs.
        let ones: [u64; 4] = [1; 4];
        let vector = felts([1, 2, 3, 4]);
        let total = F::from_u32(10);
        let product = apply_circulant(&ones, &vector);
        assert_eq!(product, [total; 4]);
    }

    #[test]
    fn apply_circulant_scalar() {
        // The circulant [k, 0, 0, ...] scales every element by k.
        let row: [u64; 4] = [7, 0, 0, 0];
        let vector = felts([1, 2, 3, 4]);
        let scaled = felts([7, 14, 21, 28]);
        let product = apply_circulant(&row, &vector);
        assert_eq!(product, scaled);
    }

    #[test]
    fn apply_circulant_size_1() {
        // For a 1x1 circulant the product is plain scalar multiplication.
        let row: [u64; 1] = [5];
        let vector = felts([3]);
        let product = apply_circulant(&row, &vector);
        assert_eq!(product, felts([15]));
    }

    #[test]
    fn apply_circulant_fft_matches_naive_4() {
        // The FFT route and the naive O(N^2) route must produce the same result.
        let row: [u64; 4] = [2, 3, 5, 7];
        let vector = felts([1, 2, 3, 4]);

        let via_naive = apply_circulant(&row, &vector);
        let via_fft = apply_circulant_fft(&NaiveDft, first_row_to_first_col(&row), &vector);
        assert_eq!(via_naive, via_fft);
    }

    #[test]
    fn apply_circulant_fft_identity() {
        // The identity circulant must be a no-op on the FFT route as well.
        let row: [u64; 4] = [1, 0, 0, 0];
        let vector = felts([5, 10, 15, 20]);
        let product = apply_circulant_fft(&NaiveDft, first_row_to_first_col(&row), &vector);
        assert_eq!(product, vector);
    }

    proptest! {
        #[test]
        fn first_row_to_first_col_involution(row in prop::array::uniform4(0u64..1000)) {
            // Converting row -> column -> row must give back the original.
            let round_trip = first_row_to_first_col(&first_row_to_first_col(&row));
            prop_assert_eq!(round_trip, row);
        }

        #[test]
        fn apply_circulant_fft_matches_naive(
            row in prop::array::uniform4(0u64..1000),
            vector in prop::array::uniform4(arb_f()),
        ) {
            let column = first_row_to_first_col(&row);
            let via_naive = apply_circulant(&row, &vector);
            let via_fft = apply_circulant_fft(&NaiveDft, column, &vector);
            prop_assert_eq!(via_naive, via_fft);
        }

        #[test]
        fn apply_circulant_linearity(
            row in prop::array::uniform4(0u64..100),
            lhs in prop::array::uniform4(arb_f()),
            rhs in prop::array::uniform4(arb_f()),
        ) {
            // C*(a + b) must equal C*a + C*b in every coordinate.
            let combined: [F; 4] = core::array::from_fn(|k| lhs[k] + rhs[k]);
            let image_lhs = apply_circulant(&row, &lhs);
            let image_rhs = apply_circulant(&row, &rhs);
            let image_combined = apply_circulant(&row, &combined);
            for (total, (x, y)) in image_combined.iter().zip(image_lhs.iter().zip(image_rhs.iter())) {
                prop_assert_eq!(*total, *x + *y);
            }
        }

        #[test]
        fn apply_circulant_zero_matrix(vector in prop::array::uniform4(arb_f())) {
            // The zero circulant sends every input to the zero vector.
            let zero_row: [u64; 4] = [0; 4];
            let product = apply_circulant(&zero_row, &vector);
            prop_assert_eq!(product, [F::ZERO; 4]);
        }
    }
}