p3_field/
batch_inverse.rs

1use alloc::vec::Vec;
2
3use p3_maybe_rayon::prelude::*;
4use tracing::instrument;
5
6use crate::field::Field;
7use crate::{FieldArray, PackedValue, PrimeCharacteristicRing};
8
9/// Batch multiplicative inverses with Montgomery's trick
10/// This is Montgomery's trick. At a high level, we invert the product of the given field
11/// elements, then derive the individual inverses from that via multiplication.
12///
13/// The usual Montgomery trick involves calculating an array of cumulative products,
14/// resulting in a long dependency chain. To increase instruction-level parallelism, we
15/// compute WIDTH separate cumulative product arrays that only meet at the end.
16///
17/// # Panics
18/// This will panic if any of the inputs is zero.
19#[instrument(level = "debug", skip_all)]
20#[must_use]
21pub fn batch_multiplicative_inverse<F: Field>(x: &[F]) -> Vec<F> {
22    // How many elements to invert in one thread.
23    const CHUNK_SIZE: usize = 1024;
24
25    let n = x.len();
26    let mut result = F::zero_vec(n);
27
28    x.par_chunks(CHUNK_SIZE)
29        .zip(result.par_chunks_mut(CHUNK_SIZE))
30        .for_each(|(x, result)| {
31            batch_multiplicative_inverse_helper(x, result);
32        });
33
34    result
35}
36
37/// Like `batch_multiplicative_inverse`, but writes the result to the given output buffer.
38fn batch_multiplicative_inverse_helper<F: Field>(x: &[F], result: &mut [F]) {
39    // Higher WIDTH increases instruction-level parallelism, but too high a value will cause us
40    // to run out of registers.
41    const WIDTH: usize = 4;
42
43    let n = x.len();
44    assert_eq!(result.len(), n);
45    if !n.is_multiple_of(WIDTH) {
46        // There isn't a very clean way to do this with FieldArray; for now just do it in serial.
47        // Another simple (though suboptimal) workaround would be to make two separate calls, one
48        // for the packed part and one for the remainder.
49        return batch_multiplicative_inverse_general(x, result, |x| x.inverse());
50    }
51
52    let x_packed = FieldArray::<F, 4>::pack_slice(x);
53    let result_packed = FieldArray::<F, 4>::pack_slice_mut(result);
54
55    batch_multiplicative_inverse_general(x_packed, result_packed, |x_packed| x_packed.inverse());
56}
57
58/// A simple single-threaded implementation of Montgomery's trick. Since not all `PrimeCharacteristicRing`s
59/// support inversion, this takes a custom inversion function.
60pub(crate) fn batch_multiplicative_inverse_general<F, Inv>(x: &[F], result: &mut [F], inv: Inv)
61where
62    F: PrimeCharacteristicRing + Copy,
63    Inv: Fn(F) -> F,
64{
65    let n = x.len();
66    assert_eq!(result.len(), n);
67    if n == 0 {
68        return;
69    }
70
71    result[0] = F::ONE;
72    for i in 1..n {
73        result[i] = result[i - 1] * x[i - 1];
74    }
75
76    let product = result[n - 1] * x[n - 1];
77    let mut inv = inv(product);
78
79    for i in (0..n).rev() {
80        result[i] *= inv;
81        inv *= x[i];
82    }
83}