1use core::ops::{Add, AddAssign, Neg, ShrAssign, Sub, SubAssign};
48
/// Element type used for the intermediate arithmetic of the convolutions in
/// this module.
///
/// Implementors need the ring operations (`+`, `-`, unary `-` and the
/// assigning forms `+=` / `-=`). They also need an in-place right shift by a
/// `u32`, which the cyclic convolutions use to halve recombined sums
/// (`>>= 1`). Finally, they must be `Copy` and `Default`, so that fixed-size
/// scratch arrays such as `[T::default(); HALF_N]` can be built from them.
pub trait RngElt
where
    Self: Copy
        + Default
        + Add<Output = Self>
        + Sub<Output = Self>
        + Neg<Output = Self>
        + AddAssign
        + SubAssign
        + ShrAssign<u32>,
{
}

// The signed machine integers already satisfy every bound listed above.
impl RngElt for i128 {}
impl RngElt for i64 {}
66
67pub trait Convolve<F, T: RngElt, U: RngElt, V: RngElt> {
94 fn read(input: F) -> T;
97
98 fn parity_dot<const N: usize>(lhs: [T; N], rhs: [U; N]) -> V;
105
106 fn reduce(z: V) -> F;
109
110 #[inline(always)]
115 fn apply<const N: usize, C: Fn([T; N], [U; N], &mut [V])>(
116 lhs: [F; N],
117 rhs: [U; N],
118 conv: C,
119 ) -> [F; N] {
120 let lhs = lhs.map(Self::read);
121 let mut output = [V::default(); N];
122 conv(lhs, rhs, &mut output);
123 output.map(Self::reduce)
124 }
125
126 #[inline(always)]
127 fn conv3(lhs: [T; 3], rhs: [U; 3], output: &mut [V]) {
128 output[0] = Self::parity_dot(lhs, [rhs[0], rhs[2], rhs[1]]);
129 output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], rhs[2]]);
130 output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0]]);
131 }
132
133 #[inline(always)]
134 fn negacyclic_conv3(lhs: [T; 3], rhs: [U; 3], output: &mut [V]) {
135 output[0] = Self::parity_dot(lhs, [rhs[0], -rhs[2], -rhs[1]]);
136 output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], -rhs[2]]);
137 output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0]]);
138 }
139
140 #[inline(always)]
141 fn conv4(lhs: [T; 4], rhs: [U; 4], output: &mut [V]) {
142 let u_p = [lhs[0] + lhs[2], lhs[1] + lhs[3]];
145 let u_m = [lhs[0] - lhs[2], lhs[1] - lhs[3]];
146 let v_p = [rhs[0] + rhs[2], rhs[1] + rhs[3]];
147 let v_m = [rhs[0] - rhs[2], rhs[1] - rhs[3]];
148
149 output[0] = Self::parity_dot(u_m, [v_m[0], -v_m[1]]);
150 output[1] = Self::parity_dot(u_m, [v_m[1], v_m[0]]);
151 output[2] = Self::parity_dot(u_p, v_p);
152 output[3] = Self::parity_dot(u_p, [v_p[1], v_p[0]]);
153
154 output[0] += output[2];
155 output[1] += output[3];
156
157 output[0] >>= 1;
158 output[1] >>= 1;
159
160 output[2] -= output[0];
161 output[3] -= output[1];
162 }
163
164 #[inline(always)]
165 fn negacyclic_conv4(lhs: [T; 4], rhs: [U; 4], output: &mut [V]) {
166 output[0] = Self::parity_dot(lhs, [rhs[0], -rhs[3], -rhs[2], -rhs[1]]);
167 output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], -rhs[3], -rhs[2]]);
168 output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0], -rhs[3]]);
169 output[3] = Self::parity_dot(lhs, [rhs[3], rhs[2], rhs[1], rhs[0]]);
170 }
171
172 #[inline(always)]
173 fn conv6(lhs: [T; 6], rhs: [U; 6], output: &mut [V]) {
174 conv_n_recursive(lhs, rhs, output, Self::conv3, Self::negacyclic_conv3);
175 }
176
177 #[inline(always)]
178 fn negacyclic_conv6(lhs: [T; 6], rhs: [U; 6], output: &mut [V]) {
179 negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv3);
180 }
181
182 #[inline(always)]
183 fn conv8(lhs: [T; 8], rhs: [U; 8], output: &mut [V]) {
184 conv_n_recursive(lhs, rhs, output, Self::conv4, Self::negacyclic_conv4);
185 }
186
187 #[inline(always)]
188 fn negacyclic_conv8(lhs: [T; 8], rhs: [U; 8], output: &mut [V]) {
189 negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv4);
190 }
191
192 #[inline(always)]
193 fn conv12(lhs: [T; 12], rhs: [U; 12], output: &mut [V]) {
194 conv_n_recursive(lhs, rhs, output, Self::conv6, Self::negacyclic_conv6);
195 }
196
197 #[inline(always)]
198 fn negacyclic_conv12(lhs: [T; 12], rhs: [U; 12], output: &mut [V]) {
199 negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv6);
200 }
201
202 #[inline(always)]
203 fn conv16(lhs: [T; 16], rhs: [U; 16], output: &mut [V]) {
204 conv_n_recursive(lhs, rhs, output, Self::conv8, Self::negacyclic_conv8);
205 }
206
207 #[inline(always)]
208 fn negacyclic_conv16(lhs: [T; 16], rhs: [U; 16], output: &mut [V]) {
209 negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv8);
210 }
211
212 #[inline(always)]
213 fn conv24(lhs: [T; 24], rhs: [U; 24], output: &mut [V]) {
214 conv_n_recursive(lhs, rhs, output, Self::conv12, Self::negacyclic_conv12);
215 }
216
217 #[inline(always)]
218 fn conv32(lhs: [T; 32], rhs: [U; 32], output: &mut [V]) {
219 conv_n_recursive(lhs, rhs, output, Self::conv16, Self::negacyclic_conv16);
220 }
221
222 #[inline(always)]
223 fn negacyclic_conv32(lhs: [T; 32], rhs: [U; 32], output: &mut [V]) {
224 negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv16);
225 }
226
227 #[inline(always)]
228 fn conv64(lhs: [T; 64], rhs: [U; 64], output: &mut [V]) {
229 conv_n_recursive(lhs, rhs, output, Self::conv32, Self::negacyclic_conv32);
230 }
231}
232
233#[inline(always)]
236fn conv_n_recursive<const N: usize, const HALF_N: usize, T, U, V, C, NC>(
237 lhs: [T; N],
238 rhs: [U; N],
239 output: &mut [V],
240 inner_conv: C,
241 inner_negacyclic_conv: NC,
242) where
243 T: RngElt,
244 U: RngElt,
245 V: RngElt,
246 C: Fn([T; HALF_N], [U; HALF_N], &mut [V]),
247 NC: Fn([T; HALF_N], [U; HALF_N], &mut [V]),
248{
249 debug_assert_eq!(2 * HALF_N, N);
250 let mut lhs_pos = [T::default(); HALF_N]; let mut lhs_neg = [T::default(); HALF_N]; let mut rhs_pos = [U::default(); HALF_N]; let mut rhs_neg = [U::default(); HALF_N]; for i in 0..HALF_N {
257 let s = lhs[i];
258 let t = lhs[i + HALF_N];
259
260 lhs_pos[i] = s + t;
261 lhs_neg[i] = s - t;
262
263 let s = rhs[i];
264 let t = rhs[i + HALF_N];
265
266 rhs_pos[i] = s + t;
267 rhs_neg[i] = s - t;
268 }
269
270 let (left, right) = output.split_at_mut(HALF_N);
271
272 inner_negacyclic_conv(lhs_neg, rhs_neg, left);
274
275 inner_conv(lhs_pos, rhs_pos, right);
277
278 for i in 0..HALF_N {
279 left[i] += right[i]; left[i] >>= 1; right[i] -= left[i]; }
283}
284
285#[inline(always)]
288fn negacyclic_conv_n_recursive<const N: usize, const HALF_N: usize, T, U, V, NC>(
289 lhs: [T; N],
290 rhs: [U; N],
291 output: &mut [V],
292 inner_negacyclic_conv: NC,
293) where
294 T: RngElt,
295 U: RngElt,
296 V: RngElt,
297 NC: Fn([T; HALF_N], [U; HALF_N], &mut [V]),
298{
299 debug_assert_eq!(2 * HALF_N, N);
300 let mut lhs_even = [T::default(); HALF_N];
302 let mut lhs_odd = [T::default(); HALF_N];
303 let mut lhs_sum = [T::default(); HALF_N];
304 let mut rhs_even = [U::default(); HALF_N];
305 let mut rhs_odd = [U::default(); HALF_N];
306 let mut rhs_sum = [U::default(); HALF_N];
307
308 for i in 0..HALF_N {
309 let s = lhs[2 * i];
310 let t = lhs[2 * i + 1];
311 lhs_even[i] = s;
312 lhs_odd[i] = t;
313 lhs_sum[i] = s + t;
314
315 let s = rhs[2 * i];
316 let t = rhs[2 * i + 1];
317 rhs_even[i] = s;
318 rhs_odd[i] = t;
319 rhs_sum[i] = s + t;
320 }
321
322 let mut even_s_conv = [V::default(); HALF_N];
323 let (left, right) = output.split_at_mut(HALF_N);
324
325 inner_negacyclic_conv(lhs_even, rhs_even, &mut even_s_conv);
328 inner_negacyclic_conv(lhs_odd, rhs_odd, left);
329 inner_negacyclic_conv(lhs_sum, rhs_sum, right);
330
331 right[0] -= even_s_conv[0] + left[0];
334 even_s_conv[0] -= left[HALF_N - 1];
335
336 for i in 1..HALF_N {
337 right[i] -= even_s_conv[i] + left[i];
338 even_s_conv[i] += left[i - 1];
339 }
340
341 for i in 0..HALF_N {
343 output[2 * i] = even_s_conv[i];
344 output[2 * i + 1] = output[i + HALF_N];
345 }
346}