1use core::ops::{AddAssign, Mul};
2
3use p3_dft::TwoAdicSubgroupDft;
4use p3_field::{Algebra, PrimeCharacteristicRing, TwoAdicField};
5
/// Returns the dot product `sum_i u[i] * v[i]` of two equal-length arrays.
///
/// The sum is seeded with the first product, so `T` does not need an additive
/// identity. Terms are accumulated left to right.
///
/// # Panics
///
/// Panics when `N == 0`. Debug builds fail the `debug_assert_ne!`; release
/// builds fail on the out-of-bounds access to index 0.
#[inline(always)]
pub fn dot_product<T, const N: usize>(u: [T; N], v: [T; N]) -> T
where
    T: Copy + AddAssign + Mul<Output = T>,
{
    debug_assert_ne!(N, 0);
    let mut sum = u[0] * v[0];
    for (&a, &b) in u.iter().zip(v.iter()).skip(1) {
        sum += a * b;
    }
    sum
}
19
20pub fn apply_circulant<R: PrimeCharacteristicRing, const N: usize>(
28 circ_matrix: &[u64; N],
29 input: &[R; N],
30) -> [R; N] {
31 let matrix = circ_matrix.map(R::from_u64);
32
33 core::array::from_fn(|row| {
34 let rotated: [R; N] = core::array::from_fn(|col| matrix[(N + col - row) % N].clone());
36 R::dot_product(&rotated, input)
37 })
38}
39
/// Converts the first row of a circulant matrix into its first column.
///
/// Index 0 is shared by the first row and the first column. For `i >= 1` the
/// column entry is `row[N - i]`, so the tail of the row appears reversed.
/// Applying the function twice returns the original array.
///
/// This is a `const fn`, so it can be evaluated at compile time.
pub const fn first_row_to_first_col<const N: usize, T: Copy>(v: &[T; N]) -> [T; N] {
    // Start from a copy so position 0 (and the whole array when N <= 1) is already right.
    let mut col = *v;
    // Read the row from its last entry backwards, writing the column from index 1 upwards.
    let mut j = N;
    while j > 1 {
        j -= 1;
        col[N - j] = v[j];
    }
    col
}
63
64#[inline]
70pub fn apply_circulant_fft<F: TwoAdicField, const N: usize, FFT: TwoAdicSubgroupDft<F>>(
71 fft: &FFT,
72 column: [u64; N],
73 input: &[F; N],
74) -> [F; N] {
75 let column = column.map(F::from_u64).to_vec();
77 let matrix = fft.dft(column);
78
79 let input = fft.dft(input.to_vec());
81
82 let product = matrix.iter().zip(input).map(|(&x, y)| x * y).collect();
84
85 let output = fft.idft(product);
87 output.try_into().unwrap()
88}
89
90#[inline]
108pub fn mds_multiply<F, A, const WIDTH: usize>(state: &mut [A; WIDTH], matrix: &[[F; WIDTH]; WIDTH])
109where
110 F: PrimeCharacteristicRing,
111 A: Algebra<F>,
112{
113 let input = state.clone();
115
116 for (out, row) in state.iter_mut().zip(matrix.iter()) {
118 *out = A::mixed_dot_product(&input, row);
119 }
120}
121
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_dft::NaiveDft;
    use p3_field::PrimeCharacteristicRing;
    use proptest::prelude::*;

    use super::*;

    type F = BabyBear;

    /// Strategy producing arbitrary BabyBear elements from arbitrary `u32`s.
    fn arb_f() -> impl Strategy<Value = F> {
        prop::num::u32::ANY.prop_map(F::from_u32)
    }

    #[test]
    fn dot_product_integers() {
        assert_eq!(dot_product([1u64, 2, 3], [4, 5, 6]), 32);
    }

    #[test]
    fn dot_product_single_element() {
        assert_eq!(dot_product([7u64], [6]), 42);
    }

    #[test]
    fn dot_product_field_elements() {
        let u: [F; 3] = [1, 2, 3].map(F::from_u32);
        let v: [F; 3] = [4, 5, 6].map(F::from_u32);
        assert_eq!(dot_product(u, v), F::from_u32(32));
    }

    #[test]
    fn first_row_to_first_col_even_length() {
        let input = [0, 1, 2, 3, 4, 5];
        assert_eq!(first_row_to_first_col(&input), [0, 5, 4, 3, 2, 1]);
    }

    #[test]
    fn first_row_to_first_col_odd_length() {
        let input = [10, 20, 30, 40, 50];
        assert_eq!(first_row_to_first_col(&input), [10, 50, 40, 30, 20]);
    }

    #[test]
    fn first_row_to_first_col_single_element() {
        assert_eq!(first_row_to_first_col(&[42]), [42]);
    }

    #[test]
    fn first_row_to_first_col_two_elements() {
        assert_eq!(first_row_to_first_col(&[1, 2]), [1, 2]);
    }

    #[test]
    fn apply_circulant_identity() {
        let identity_row: [u64; 4] = [1, 0, 0, 0];
        let input: [F; 4] = [5, 10, 15, 20].map(F::from_u32);
        assert_eq!(apply_circulant(&identity_row, &input), input);
    }

    #[test]
    fn apply_circulant_all_ones() {
        let ones: [u64; 4] = [1, 1, 1, 1];
        let input: [F; 4] = [1, 2, 3, 4].map(F::from_u32);
        let sum = F::from_u32(10);
        assert_eq!(apply_circulant(&ones, &input), [sum; 4]);
    }

    #[test]
    fn apply_circulant_scalar() {
        let row: [u64; 4] = [7, 0, 0, 0];
        let input: [F; 4] = [1, 2, 3, 4].map(F::from_u32);
        let expected: [F; 4] = [7, 14, 21, 28].map(F::from_u32);
        assert_eq!(apply_circulant(&row, &input), expected);
    }

    #[test]
    fn apply_circulant_size_1() {
        let row: [u64; 1] = [5];
        let input: [F; 1] = [F::from_u32(3)];
        assert_eq!(apply_circulant(&row, &input), [F::from_u32(15)]);
    }

    #[test]
    fn apply_circulant_fft_matches_naive_4() {
        let row: [u64; 4] = [2, 3, 5, 7];
        let col = first_row_to_first_col(&row);
        let input: [F; 4] = [1, 2, 3, 4].map(F::from_u32);

        let naive = apply_circulant(&row, &input);
        let fft_result = apply_circulant_fft(&NaiveDft, col, &input);
        assert_eq!(naive, fft_result);
    }

    #[test]
    fn apply_circulant_fft_identity() {
        let row: [u64; 4] = [1, 0, 0, 0];
        let col = first_row_to_first_col(&row);
        let input: [F; 4] = [5, 10, 15, 20].map(F::from_u32);
        assert_eq!(apply_circulant_fft(&NaiveDft, col, &input), input);
    }

    #[test]
    fn mds_multiply_identity() {
        let matrix: [[F; 3]; 3] = core::array::from_fn(|r| {
            core::array::from_fn(|c| if r == c { F::ONE } else { F::ZERO })
        });
        let original: [F; 3] = [3, 1, 4].map(F::from_u32);
        let mut state = original;
        mds_multiply(&mut state, &matrix);
        assert_eq!(state, original);
    }

    #[test]
    fn mds_multiply_small_matrix() {
        // Every output must be computed from the *original* state. An in-place
        // update would make row 1 read the already-updated state[0].
        let matrix: [[F; 3]; 3] =
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]].map(|row| row.map(F::from_u32));
        let mut state: [F; 3] = [1, 1, 2].map(F::from_u32);
        mds_multiply(&mut state, &matrix);
        assert_eq!(state, [9, 21, 33].map(F::from_u32));
    }

    proptest! {
        #[test]
        fn first_row_to_first_col_involution(v in prop::array::uniform4(0u64..1000)) {
            let col = first_row_to_first_col(&v);
            let back = first_row_to_first_col(&col);
            prop_assert_eq!(back, v);
        }

        #[test]
        fn apply_circulant_fft_matches_naive(
            row in prop::array::uniform4(0u64..1000),
            input in prop::array::uniform4(arb_f()),
        ) {
            let col = first_row_to_first_col(&row);
            let naive = apply_circulant(&row, &input);
            let fft_result = apply_circulant_fft(&NaiveDft, col, &input);
            prop_assert_eq!(naive, fft_result);
        }

        #[test]
        fn apply_circulant_linearity(
            row in prop::array::uniform4(0u64..100),
            a in prop::array::uniform4(arb_f()),
            b in prop::array::uniform4(arb_f()),
        ) {
            let sum_input: [F; 4] = core::array::from_fn(|i| a[i] + b[i]);
            let ca = apply_circulant(&row, &a);
            let cb = apply_circulant(&row, &b);
            let c_sum = apply_circulant(&row, &sum_input);
            for i in 0..4 {
                prop_assert_eq!(c_sum[i], ca[i] + cb[i]);
            }
        }

        #[test]
        fn apply_circulant_zero_matrix(input in prop::array::uniform4(arb_f())) {
            let zeros: [u64; 4] = [0; 4];
            let result = apply_circulant(&zeros, &input);
            prop_assert_eq!(result, [F::ZERO; 4]);
        }

        #[test]
        fn mds_multiply_matches_apply_circulant(
            row in prop::array::uniform4(0u64..1000),
            input in prop::array::uniform4(arb_f()),
        ) {
            // Build the explicit circulant matrix: entry (r, c) = row[(c - r) mod 4].
            let matrix: [[F; 4]; 4] = core::array::from_fn(|r| {
                core::array::from_fn(|c| F::from_u64(row[(4 + c - r) % 4]))
            });
            let mut state = input;
            mds_multiply(&mut state, &matrix);
            prop_assert_eq!(state, apply_circulant(&row, &input));
        }
    }
}