p3_mds/
util.rs

1use core::ops::{AddAssign, Mul};
2
3use p3_dft::TwoAdicSubgroupDft;
4use p3_field::{PrimeCharacteristicRing, TwoAdicField};
5
6// NB: These are all MDS for M31, BabyBear and Goldilocks
7// const MATRIX_CIRC_MDS_8_2EXP: [u64; 8] = [1, 1, 2, 1, 8, 32, 4, 256];
8// const MATRIX_CIRC_MDS_8_SML: [u64; 8] = [4, 1, 2, 9, 10, 5, 1, 1];
// Much smaller: [1, 1, -1, 2, 3, 8, 2, -3], but the negative entries would need special handling.
10
11// const MATRIX_CIRC_MDS_12_2EXP: [u64; 12] = [1, 1, 2, 1, 8, 32, 2, 256, 4096, 8, 65536, 1024];
12// const MATRIX_CIRC_MDS_12_SML: [u64; 12] = [9, 7, 4, 1, 16, 2, 256, 128, 3, 32, 1, 1];
13// const MATRIX_CIRC_MDS_12_SML: [u64; 12] = [1, 1, 2, 1, 8, 9, 10, 7, 5, 9, 4, 10];
14
// Trying to maximise the number of 1s in the vector.
// It's not clear exactly what we should be optimising for here, but that seems reasonable.
17// const MATRIX_CIRC_MDS_16_SML: [u64; 16] =
18//   [1, 1, 51, 1, 11, 17, 2, 1, 101, 63, 15, 2, 67, 22, 13, 3];
19// 1, 1, 51, 52, 11, 63, 1, 2, 1, 2, 15, 67, 2, 22, 13, 3
20// [1, 1, 2, 1, 8, 32, 2, 65, 77, 8, 91, 31, 3, 65, 32, 7];
21
/// Return the dot product `u[0] * v[0] + u[1] * v[1] + ... + u[N - 1] * v[N - 1]`.
///
/// The running sum is seeded with the first product rather than with a zero
/// element, which is why `T` only needs `Copy + AddAssign + Mul`.
///
/// # Panics
///
/// Panics if `N == 0`, since there is no first product to seed the sum with.
/// It's hard to imagine this case coming up in practice.
#[inline(always)]
pub fn dot_product<T, const N: usize>(u: [T; N], v: [T; N]) -> T
where
    T: Copy + AddAssign + Mul<Output = T>,
{
    debug_assert_ne!(N, 0);
    let seed = u[0] * v[0];
    // Accumulate the remaining products in index order.
    u.iter()
        .zip(v.iter())
        .skip(1)
        .fold(seed, |mut acc, (&a, &b)| {
            acc += a * b;
            acc
        })
}
35
36/// Given the first row `circ_matrix` of an NxN circulant matrix, say
37/// C, return the product `C*input`.
38///
39/// NB: This function is a naive implementation of the n²
40/// evaluation. It is a placeholder until we have FFT implementations
41/// for all combinations of field and size.
42pub fn apply_circulant<R: PrimeCharacteristicRing, const N: usize>(
43    circ_matrix: &[u64; N],
44    input: &[R; N],
45) -> [R; N] {
46    let mut matrix = circ_matrix.map(R::from_u64);
47
48    let mut output = [R::ZERO; N];
49    for out_i in output.iter_mut().take(N - 1) {
50        *out_i = R::dot_product(&matrix, input);
51        matrix.rotate_right(1);
52    }
53    output[N - 1] = R::dot_product(&matrix, input);
54    output
55}
56
/// Convert the first row of a circulant matrix into its first column.
///
/// Element 0 is unchanged and elements `1..N` appear in reverse order. For
/// example, `v = [0, 1, 2, 3, 4, 5]` becomes `[0, 5, 4, 3, 2, 1]`.
///
/// FFT-based circulant multiplication expects the first column of the matrix,
/// whereas we normally store the first row; this bridges the two.
///
/// NB: The loop is deliberately simple so that the function can be `const`,
/// which is the intended context for use.
pub const fn first_row_to_first_col<const N: usize, T: Copy>(v: &[T; N]) -> [T; N] {
    let mut col = *v;
    // Walk the source indices N-1, N-2, ..., 1, mirroring each into place;
    // index 0 keeps its copied value.
    let mut src = N;
    while src > 1 {
        src -= 1;
        col[N - src] = v[src];
    }
    col
}
79
80/// Use the convolution theorem to calculate the product of the given
81/// circulant matrix and the given vector.
82///
83/// The circulant matrix must be specified by its first *column*, not its first row. If you have
84/// the row as an array, you can obtain the column with `first_row_to_first_col()`.
85#[inline]
86pub fn apply_circulant_fft<F: TwoAdicField, const N: usize, FFT: TwoAdicSubgroupDft<F>>(
87    fft: &FFT,
88    column: [u64; N],
89    input: &[F; N],
90) -> [F; N] {
91    let column = column.map(F::from_u64).to_vec();
92    let matrix = fft.dft(column);
93    let input = fft.dft(input.to_vec());
94
95    // point-wise product
96    let product = matrix.iter().zip(input).map(|(&x, y)| x * y).collect();
97
98    let output = fft.idft(product);
99    output.try_into().unwrap()
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_first_row_to_first_col_even_length() {
        let input = [0, 1, 2, 3, 4, 5];
        let output = [0, 5, 4, 3, 2, 1];

        assert_eq!(first_row_to_first_col(&input), output);
    }

    #[test]
    fn test_first_row_to_first_col_odd_length() {
        let input = [10, 20, 30, 40, 50];
        let output = [10, 50, 40, 30, 20];

        assert_eq!(first_row_to_first_col(&input), output);
    }

    #[test]
    fn test_first_row_to_first_col_single_element() {
        let input = [42];
        let output = [42];

        assert_eq!(first_row_to_first_col(&input), output);
    }

    #[test]
    fn test_first_row_to_first_col_empty() {
        let input: [i32; 0] = [];

        assert_eq!(first_row_to_first_col(&input), input);
    }

    #[test]
    fn test_first_row_to_first_col_all_zeros() {
        let input = [0; 6];
        let output = [0; 6];

        assert_eq!(first_row_to_first_col(&input), output);
    }

    #[test]
    fn test_first_row_to_first_col_negative_numbers() {
        let input = [-1, -2, -3, -4];
        let output = [-1, -4, -3, -2];

        assert_eq!(first_row_to_first_col(&input), output);
    }

    #[test]
    fn test_first_row_to_first_col_large_numbers() {
        let input = [1_000_000, 2_000_000, 3_000_000, 4_000_000];
        let output = [1_000_000, 4_000_000, 3_000_000, 2_000_000];

        assert_eq!(first_row_to_first_col(&input), output);
    }

    #[test]
    fn test_first_row_to_first_col_const_context() {
        // The function is documented as intended for const evaluation, so
        // make sure it actually works there.
        const COL: [u64; 4] = first_row_to_first_col(&[1, 2, 3, 4]);

        assert_eq!(COL, [1, 4, 3, 2]);
    }

    #[test]
    fn test_first_row_to_first_col_is_involution() {
        // Reversing elements 1..N twice restores the original row.
        let input = [3, 1, 4, 1, 5, 9, 2];

        assert_eq!(
            first_row_to_first_col(&first_row_to_first_col(&input)),
            input
        );
    }

    #[test]
    fn test_dot_product_basic() {
        let u = [1, 2, 3];
        let v = [4, 5, 6];
        assert_eq!(dot_product(u, v), 4 + 2 * 5 + 3 * 6);
    }

    #[test]
    fn test_dot_product_single_element() {
        let u = [7];
        let v = [8];
        assert_eq!(dot_product(u, v), 7 * 8);
    }

    #[test]
    fn test_dot_product_all_zeros() {
        let u = [0; 4];
        let v = [0; 4];
        assert_eq!(dot_product(u, v), 0);
    }

    #[test]
    fn test_dot_product_negative_numbers() {
        let u = [-1, -2, -3];
        let v = [-4, -5, -6];
        assert_eq!(dot_product(u, v), (-1) * (-4) + (-2) * (-5) + (-3) * (-6));
    }

    #[test]
    fn test_dot_product_large_numbers() {
        let u = [1_000_000, 2_000_000, 3_000_000];
        let v = [4, 5, 6];
        assert_eq!(
            dot_product(u, v),
            1_000_000 * 4 + 2_000_000 * 5 + 3_000_000 * 6
        );
    }
}