1use core::ops::{AddAssign, Mul};
2
3use p3_dft::TwoAdicSubgroupDft;
4use p3_field::{PrimeCharacteristicRing, TwoAdicField};
5
/// Computes the inner product `sum_i u[i] * v[i]` of two fixed-length arrays.
///
/// The sum is accumulated left to right, starting from `u[0] * v[0]`, so no additive
/// identity is required of `T`.
///
/// # Panics
///
/// Panics when `N == 0` (debug builds hit the `debug_assert_ne!`; release builds fail on
/// the `u[0]` index).
#[inline(always)]
pub fn dot_product<T, const N: usize>(u: [T; N], v: [T; N]) -> T
where
    T: Copy + AddAssign + Mul<Output = T>,
{
    debug_assert_ne!(N, 0);
    // Seed with the first product so `T` never needs a zero element.
    let mut acc = u[0] * v[0];
    for (&a, &b) in u.iter().zip(v.iter()).skip(1) {
        acc += a * b;
    }
    acc
}
35
36pub fn apply_circulant<R: PrimeCharacteristicRing, const N: usize>(
43 circ_matrix: &[u64; N],
44 input: &[R; N],
45) -> [R; N] {
46 let mut matrix = circ_matrix.map(R::from_u64);
47
48 let mut output = [R::ZERO; N];
49 for out_i in output.iter_mut().take(N - 1) {
50 *out_i = R::dot_product(&matrix, input);
51 matrix.rotate_right(1);
52 }
53 output[N - 1] = R::dot_product(&matrix, input);
54 output
55}
56
/// Converts the first row of a circulant matrix into its first column.
///
/// Entry `0` is kept in place and the remaining entries are reversed, i.e. the result is
/// `[v[0], v[N-1], v[N-2], ..., v[1]]`. Usable in `const` contexts.
pub const fn first_row_to_first_col<const N: usize, T: Copy>(v: &[T; N]) -> [T; N] {
    let mut col = *v;
    // Walk source indices N-1 down to 1; each lands at the mirrored position N - k.
    // `while` is used because `for` loops are not permitted in a `const fn`.
    let mut k = N;
    while k > 1 {
        k -= 1;
        col[N - k] = v[k];
    }
    col
}
79
80#[inline]
86pub fn apply_circulant_fft<F: TwoAdicField, const N: usize, FFT: TwoAdicSubgroupDft<F>>(
87 fft: &FFT,
88 column: [u64; N],
89 input: &[F; N],
90) -> [F; N] {
91 let column = column.map(F::from_u64).to_vec();
92 let matrix = fft.dft(column);
93 let input = fft.dft(input.to_vec());
94
95 let product = matrix.iter().zip(input).map(|(&x, y)| x * y).collect();
97
98 let output = fft.idft(product);
99 output.try_into().unwrap()
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    // first_row_to_first_col: entry 0 stays put, the remaining entries are reversed.

    #[test]
    fn test_first_row_to_first_col_even_length() {
        assert_eq!(first_row_to_first_col(&[0, 1, 2, 3, 4, 5]), [0, 5, 4, 3, 2, 1]);
    }

    #[test]
    fn test_first_row_to_first_col_odd_length() {
        assert_eq!(first_row_to_first_col(&[10, 20, 30, 40, 50]), [10, 50, 40, 30, 20]);
    }

    #[test]
    fn test_first_row_to_first_col_single_element() {
        assert_eq!(first_row_to_first_col(&[42]), [42]);
    }

    #[test]
    fn test_first_row_to_first_col_all_zeros() {
        assert_eq!(first_row_to_first_col(&[0; 6]), [0; 6]);
    }

    #[test]
    fn test_first_row_to_first_col_negative_numbers() {
        assert_eq!(first_row_to_first_col(&[-1, -2, -3, -4]), [-1, -4, -3, -2]);
    }

    #[test]
    fn test_first_row_to_first_col_large_numbers() {
        assert_eq!(
            first_row_to_first_col(&[1_000_000, 2_000_000, 3_000_000, 4_000_000]),
            [1_000_000, 4_000_000, 3_000_000, 2_000_000]
        );
    }

    // dot_product: expected values are written out as literals, with the sum in a comment.

    #[test]
    fn test_basic_dot_product() {
        // 1*4 + 2*5 + 3*6
        assert_eq!(dot_product([1, 2, 3], [4, 5, 6]), 32);
    }

    #[test]
    fn test_single_element() {
        // 7*8
        assert_eq!(dot_product([7], [8]), 56);
    }

    #[test]
    fn test_all_zeros() {
        assert_eq!(dot_product([0; 4], [0; 4]), 0);
    }

    #[test]
    fn test_negative_numbers() {
        // (-1)(-4) + (-2)(-5) + (-3)(-6) = 4 + 10 + 18
        assert_eq!(dot_product([-1, -2, -3], [-4, -5, -6]), 32);
    }

    #[test]
    fn test_large_numbers() {
        // 4_000_000 + 10_000_000 + 18_000_000
        assert_eq!(
            dot_product([1_000_000, 2_000_000, 3_000_000], [4, 5, 6]),
            32_000_000
        );
    }
}