p3_util/
lib.rs

1//! Various simple utilities.
2
3#![no_std]
4
5extern crate alloc;
6
7use alloc::slice;
8use alloc::string::String;
9use alloc::vec::Vec;
10use core::any::type_name;
11use core::hint::unreachable_unchecked;
12use core::mem::{ManuallyDrop, MaybeUninit};
13use core::{iter, mem};
14
15use crate::transpose::transpose_in_place_square;
16
17pub mod array_serialization;
18pub mod linear_map;
19pub mod transpose;
20pub mod zip_eq;
21
/// Computes `ceil(log_2(n))`.
///
/// Returns `0` for both `n = 0` and `n = 1`.
#[must_use]
pub const fn log2_ceil_usize(n: usize) -> usize {
    // For n >= 2, ceil(log2(n)) == floor(log2(n - 1)) + 1.
    match n {
        0 | 1 => 0,
        _ => (n - 1).ilog2() as usize + 1,
    }
}
27
/// Computes `ceil(log_2(n))`.
///
/// Returns `0` for both `n = 0` and `n = 1`.
///
/// Made `const` for consistency with [`log2_ceil_usize`]; this is a
/// backward-compatible strengthening of the signature.
#[must_use]
pub const fn log2_ceil_u64(n: u64) -> u64 {
    // `saturating_sub` makes `n = 0` behave like `n = 1` (both yield 0).
    // The cast replaces `.into()`, which is not callable in a const fn.
    (u64::BITS - n.saturating_sub(1).leading_zeros()) as u64
}
32
/// Computes `log_2(n)`
///
/// # Panics
/// Panics if `n` is not a power of two.
#[must_use]
#[inline]
pub fn log2_strict_usize(n: usize) -> usize {
    let bits = n.trailing_zeros();
    // A power of two reduces to exactly 1 when its trailing zeros are shifted out.
    // (`wrapping_shr` covers `n == 0`, where `bits == usize::BITS`.)
    assert_eq!(n.wrapping_shr(bits), 1, "Not a power of two: {n}");
    // Tell the optimizer that `n` and `1 << bits` are interchangeable.
    // SAFETY: the assertion above established exactly this equality.
    unsafe {
        if n != 1 << bits {
            unreachable_unchecked();
        }
    }
    bits as usize
}
49
/// Returns `[0, ..., N - 1]`.
#[must_use]
pub const fn indices_arr<const N: usize>() -> [usize; N] {
    let mut out = [0; N];
    // `for` loops are not yet usable in const fns, so count manually.
    let mut idx = 0;
    loop {
        if idx == N {
            break;
        }
        out[idx] = idx;
        idx += 1;
    }
    out
}
61
#[inline]
pub const fn reverse_bits(x: usize, n: usize) -> usize {
    // `n` must be a power of two; its trailing-zero count is then the index bit width.
    debug_assert!(n.is_power_of_two());
    let bit_len = n.trailing_zeros();
    // `overflowing_shr` (rather than `>>`) tolerates the degenerate case `n == 1`,
    // where the shift amount equals the word size (only `x == 0` is meaningful there).
    x.reverse_bits().overflowing_shr(usize::BITS - bit_len).0
}
68
#[inline]
pub const fn reverse_bits_len(x: usize, bit_len: usize) -> usize {
    // Reverse the whole word, then bring the interesting bits down to the bottom.
    // NB: `wrapping_shr` (instead of plain `>>`) accommodates the case
    // `bit_len == 0`, where the shift amount would be `usize::BITS`: Rust
    // treats any shift of a full word as overflow, even though the argument
    // is zero in that situation. Its mask-the-shift-amount semantics match
    // `overflowing_shr().0` exactly.
    x.reverse_bits()
        .wrapping_shr(usize::BITS - bit_len as u32)
}
79
// Lookup table of 6-bit reverses.
// NB: 2^6=64 bytes is a cache line. A smaller table wastes cache space.
// Written in octal so the structure is visible: the entry at octal index `XY`
// is `rev3(Y)rev3(X)` — the two 3-bit halves swapped and individually reversed.
#[cfg(not(target_arch = "aarch64"))]
#[rustfmt::skip]
const BIT_REVERSE_6BIT: &[u8] = &[
    0o00, 0o40, 0o20, 0o60, 0o10, 0o50, 0o30, 0o70,
    0o04, 0o44, 0o24, 0o64, 0o14, 0o54, 0o34, 0o74,
    0o02, 0o42, 0o22, 0o62, 0o12, 0o52, 0o32, 0o72,
    0o06, 0o46, 0o26, 0o66, 0o16, 0o56, 0o36, 0o76,
    0o01, 0o41, 0o21, 0o61, 0o11, 0o51, 0o31, 0o71,
    0o05, 0o45, 0o25, 0o65, 0o15, 0o55, 0o35, 0o75,
    0o03, 0o43, 0o23, 0o63, 0o13, 0o53, 0o33, 0o73,
    0o07, 0o47, 0o27, 0o67, 0o17, 0o57, 0o37, 0o77,
];

// Size thresholds used by `reverse_slice_index_bits` to choose between the
// naive swap loop and the transpose-based, cache-friendly algorithm.
// Ensure that SMALL_ARR_SIZE >= 4 * BIG_T_SIZE.
const BIG_T_SIZE: usize = 1 << 14;
const SMALL_ARR_SIZE: usize = 1 << 16;
98
/// Permutes `arr` such that each index is mapped to its reverse in binary.
///
/// If the whole array fits in fast cache, then the trivial algorithm is cache friendly. Also, if
/// `T` is really big, then the trivial algorithm is cache-friendly, no matter the size of the array.
///
/// # Panics
///
/// Panics if `vals.len()` is neither zero nor a power of two (enforced by `log2_strict_usize`).
pub fn reverse_slice_index_bits<F>(vals: &mut [F])
where
    F: Copy + Send + Sync,
{
    let n = vals.len();
    if n == 0 {
        return;
    }
    let log_n = log2_strict_usize(n);

    // If the whole array fits in fast cache, then the trivial algorithm is cache friendly. Also, if
    // `T` is really big, then the trivial algorithm is cache-friendly, no matter the size of the array.
    if core::mem::size_of::<F>() << log_n <= SMALL_ARR_SIZE
        || core::mem::size_of::<F>() >= BIG_T_SIZE
    {
        reverse_slice_index_bits_small(vals, log_n);
    } else {
        debug_assert!(n >= 4); // By our choice of `BIG_T_SIZE` and `SMALL_ARR_SIZE`.

        // Algorithm:
        //
        // Treat `arr` as a `sqrt(n)` by `sqrt(n)` row-major matrix. (Assume for now that `lb_n` is
        // even, i.e., `n` is a square number.) To perform bit-order reversal we:
        //  1. Bit-reverse the order of the rows. (They are contiguous in memory, so this is
        //     basically a series of large `memcpy`s.)
        //  2. Transpose the matrix.
        //  3. Bit-reverse the order of the rows.
        //
        // This is equivalent to, for every index `0 <= i < n`:
        //  1. bit-reversing `i[lb_n / 2..lb_n]`,
        //  2. swapping `i[0..lb_n / 2]` and `i[lb_n / 2..lb_n]`,
        //  3. bit-reversing `i[lb_n / 2..lb_n]`.
        //
        // If `lb_n` is odd, i.e., `n` is not a square number, then the above procedure requires
        // slight modification. At steps 1 and 3 we bit-reverse bits `ceil(lb_n / 2)..lb_n`, of the
        // index (shuffling `floor(lb_n / 2)` chunks of length `ceil(lb_n / 2)`). At step 2, we
        // perform _two_ transposes. We treat `arr` as two matrices, one where the middle bit of the
        // index is `0` and another, where the middle bit is `1`; we transpose each individually.

        let lb_num_chunks = log_n >> 1;
        let lb_chunk_size = log_n - lb_num_chunks;
        // SAFETY: `vals.len() == 1 << (lb_num_chunks + lb_chunk_size)` by construction, which is
        // exactly the length contract required by `reverse_slice_index_bits_chunks`.
        unsafe {
            reverse_slice_index_bits_chunks(vals, lb_num_chunks, lb_chunk_size);
            transpose_in_place_square(vals, lb_chunk_size, lb_num_chunks, 0);
            if lb_num_chunks != lb_chunk_size {
                // `arr` cannot be interpreted as a square matrix. We instead interpret it as a
                // `1 << lb_num_chunks` by `2` by `1 << lb_num_chunks` tensor, in row-major order.
                // The above transpose acted on `tensor[..., 0, ...]` (all indices with middle bit
                // `0`). We still need to transpose `tensor[..., 1, ...]`. To do so, we advance
                // arr by `1 << lb_num_chunks` effectively, adding that to every index.
                let vals_with_offset = &mut vals[1 << lb_num_chunks..];
                transpose_in_place_square(vals_with_offset, lb_chunk_size, lb_num_chunks, 0);
            }
            reverse_slice_index_bits_chunks(vals, lb_num_chunks, lb_chunk_size);
        }
    }
}
160
161// Both functions below are semantically equivalent to:
162//     for i in 0..n {
163//         result.push(arr[reverse_bits(i, n_power)]);
164//     }
165// where reverse_bits(i, n_power) computes the n_power-bit reverse. The complications are there
166// to guide the compiler to generate optimal assembly.
167
// Naive bit-reversal permutation: swap each index with its `lb_n`-bit reverse,
// using a 6-bit lookup table to avoid recomputing full-word bit reverses.
#[cfg(not(target_arch = "aarch64"))]
fn reverse_slice_index_bits_small<F>(vals: &mut [F], lb_n: usize) {
    if lb_n <= 6 {
        // BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them lb_n-bit reverses.
        let dst_shr_amt = 6 - lb_n as u32;
        #[allow(clippy::needless_range_loop)]
        for src in 0..vals.len() {
            let dst = (BIT_REVERSE_6BIT[src] as usize).wrapping_shr(dst_shr_amt);
            // Swap only when `src < dst`, so every pair is exchanged exactly once.
            if src < dst {
                vals.swap(src, dst);
            }
        }
    } else {
        // LLVM does not know that it does not need to reverse src at each iteration (which is
        // expensive on x86). We take advantage of the fact that the low bits of dst change rarely and the high
        // bits of dst are dependent only on the low bits of src.
        let dst_lo_shr_amt = usize::BITS - (lb_n - 6) as u32;
        let dst_hi_shl_amt = lb_n - 6;
        // Walk the array in 64-element chunks: `src_chunk` supplies the high bits of
        // `src` and, once reversed, the (rarely changing) low bits of `dst`.
        for src_chunk in 0..(vals.len() >> 6) {
            let src_hi = src_chunk << 6;
            let dst_lo = src_chunk.reverse_bits().wrapping_shr(dst_lo_shr_amt);
            #[allow(clippy::needless_range_loop)]
            for src_lo in 0..(1 << 6) {
                // Table lookup gives the 6-bit reverse of `src_lo`, which becomes
                // the high bits of `dst`.
                let dst_hi = (BIT_REVERSE_6BIT[src_lo] as usize) << dst_hi_shl_amt;
                let src = src_hi + src_lo;
                let dst = dst_hi + dst_lo;
                if src < dst {
                    vals.swap(src, dst);
                }
            }
        }
    }
}
201
#[cfg(target_arch = "aarch64")]
fn reverse_slice_index_bits_small<F>(vals: &mut [F], lb_n: usize) {
    // Aarch64 can reverse bits in one instruction, so the trivial version works best.
    for idx in 0..vals.len() {
        // `wrapping_shr` masks the shift amount, covering the `lb_n == 0` edge.
        let rev = idx.reverse_bits().wrapping_shr(usize::BITS - lb_n as u32);
        // Visit each pair once.
        if idx < rev {
            vals.swap(idx, rev);
        }
    }
}
212
/// Split `arr` into chunks and bit-reverse the order of the chunks. There are
/// `1 << lb_num_chunks` chunks, each of length `1 << lb_chunk_size`.
///
/// # Safety
/// Ensure that `arr.len() == 1 << lb_num_chunks + lb_chunk_size`.
unsafe fn reverse_slice_index_bits_chunks<F>(
    vals: &mut [F],
    lb_num_chunks: usize,
    lb_chunk_size: usize,
) {
    for i in 0..1usize << lb_num_chunks {
        // `wrapping_shr` handles the silly case when `lb_num_chunks == 0`.
        let j = i
            .reverse_bits()
            .wrapping_shr(usize::BITS - lb_num_chunks as u32);
        if i < j {
            // SAFETY: `i != j` and both are below `1 << lb_num_chunks`, so the two runs of
            // `1 << lb_chunk_size` elements starting at `i << lb_chunk_size` and
            // `j << lb_chunk_size` are disjoint and in bounds (by the caller's length contract).
            unsafe {
                core::ptr::swap_nonoverlapping(
                    vals.get_unchecked_mut(i << lb_chunk_size),
                    vals.get_unchecked_mut(j << lb_chunk_size),
                    1 << lb_chunk_size,
                );
            }
        }
    }
}
237
/// Allow the compiler to assume that the given predicate `p` is always `true`.
///
/// # Safety
///
/// Callers must ensure that `p` is true. If this is not the case, the behavior is undefined.
#[inline(always)]
pub unsafe fn assume(p: bool) {
    // In debug builds, catch contract violations loudly rather than invoking UB.
    debug_assert!(p);
    if !p {
        // SAFETY: the caller guarantees `p` holds, so this branch can never be taken.
        unsafe { unreachable_unchecked() }
    }
}
252
/// Try to force Rust to emit a branch. Example:
///
/// ```no_run
/// let x = 100;
/// if x > 20 {
///     println!("x is big!");
///     p3_util::branch_hint();
/// } else {
///     println!("x is small!");
/// }
/// ```
///
/// This function has no semantics. It is a hint only.
/// On unlisted architectures it compiles to nothing at all.
#[inline(always)]
pub fn branch_hint() {
    // NOTE: These are the currently supported assembly architectures. See the
    // [nightly reference](https://doc.rust-lang.org/nightly/reference/inline-assembly.html) for
    // the most up-to-date list.
    #[cfg(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "riscv32",
        target_arch = "riscv64",
        target_arch = "x86",
        target_arch = "x86_64",
    ))]
    unsafe {
        // SAFETY: an empty asm block with no operands performs no operation; it only acts
        // as an opaque barrier the optimizer cannot see through.
        core::arch::asm!("", options(nomem, nostack, preserves_flags));
    }
}
283
/// Return a String containing the name of T but with all the crate
/// and module prefixes removed.
pub fn pretty_name<T>() -> String {
    // Split on generic-syntax boundaries (keeping the delimiter attached), then
    // keep only the final `::` path segment of each piece.
    type_name::<T>()
        .split_inclusive(&['<', '>', ','])
        .map(|segment| segment.rsplit("::").next().unwrap())
        .collect()
}
294
/// A C-style buffered input reader, similar to
/// `core::iter::Iterator::next_chunk()` from nightly.
///
/// Returns an array of `MaybeUninit<T>` and the number of items in the
/// array which have been correctly initialized (always `BUFLEN` unless the
/// iterator ran out first).
#[inline]
fn iter_next_chunk_erased<const BUFLEN: usize, I: Iterator>(
    iter: &mut I,
) -> ([MaybeUninit<I::Item>; BUFLEN], usize)
where
    I::Item: Copy,
{
    let mut buf = [const { MaybeUninit::<I::Item>::uninit() }; BUFLEN];
    let mut filled = 0;

    // Pair each buffer slot with an item. `zip` stops as soon as either side is
    // exhausted, so at most `BUFLEN` items are pulled from the iterator.
    for (slot, item) in buf.iter_mut().zip(&mut *iter) {
        slot.write(item);
        filled += 1;
    }

    (buf, filled)
}
324
/// Gets a shared reference to the contained value.
///
/// # Safety
///
/// Calling this when the content is not yet fully initialized causes undefined
/// behavior: it is up to the caller to guarantee that every `MaybeUninit<T>` in
/// the slice really is in an initialized state.
///
/// Copied from:
/// https://doc.rust-lang.org/std/primitive.slice.html#method.assume_init_ref
/// Once that is stabilized, this should be removed.
#[inline(always)]
pub const unsafe fn assume_init_ref<T>(slice: &[MaybeUninit<T>]) -> &[T] {
    // SAFETY: casting `slice` to a `*const [T]` is safe since the caller guarantees that
    // `slice` is initialized, and `MaybeUninit` is guaranteed to have the same layout as `T`.
    // The pointer obtained is valid since it refers to memory owned by `slice` which is a
    // reference and thus guaranteed to be valid for reads.
    unsafe { &*(slice as *const [MaybeUninit<T>] as *const [T]) }
}
344
345/// Split an iterator into small arrays and apply `func` to each.
346///
347/// Repeatedly read `BUFLEN` elements from `input` into an array and
348/// pass the array to `func` as a slice. If less than `BUFLEN`
349/// elements are remaining, that smaller slice is passed to `func` (if
350/// it is non-empty) and the function returns.
351#[inline]
352pub fn apply_to_chunks<const BUFLEN: usize, I, H>(input: I, mut func: H)
353where
354    I: IntoIterator<Item = u8>,
355    H: FnMut(&[I::Item]),
356{
357    let mut iter = input.into_iter();
358    loop {
359        let (buf, n) = iter_next_chunk_erased::<BUFLEN, _>(&mut iter);
360        if n == 0 {
361            break;
362        }
363        func(unsafe { assume_init_ref(buf.get_unchecked(..n)) });
364    }
365}
366
/// Pulls `N` items from `iter` and returns them as an array. If the iterator
/// yields fewer than `N` items (but more than `0`), pads by the given default value.
///
/// Since the iterator is passed as a mutable reference and this function calls
/// `next` at most `N` times, the iterator can still be used afterwards to
/// retrieve the remaining items.
///
/// If `iter.next()` panics, all items already yielded by the iterator are
/// dropped.
///
/// Returns `None` if the iterator was already empty.
#[inline]
fn iter_next_chunk_padded<T: Copy, const N: usize>(
    iter: &mut impl Iterator<Item = T>,
    default: T, // Needed due to [T; M] not always implementing Default. Can probably be dropped if const generics stabilize.
) -> Option<[T; N]> {
    // `n` counts how many leading slots of `arr` were initialized by the iterator.
    let (mut arr, n) = iter_next_chunk_erased::<N, _>(iter);
    (n != 0).then(|| {
        // Fill the rest of the array with default values.
        arr[n..].fill(MaybeUninit::new(default));
        // SAFETY(review): all `N` slots are initialized at this point (first `n` by the
        // iterator, the rest by `default`), and `[MaybeUninit<T>; N]` has the same layout
        // as `[T; N]`, so this `transmute_copy` reads a fully-initialized array.
        unsafe { mem::transmute_copy::<_, [T; N]>(&arr) }
    })
}
388
389/// Returns an iterator over `N` elements of the iterator at a time.
390///
391/// The chunks do not overlap. If `N` does not divide the length of the
392/// iterator, then the last `N-1` elements will be padded with the given default value.
393///
394/// This is essentially a copy pasted version of the nightly `array_chunks` function.
395/// https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.array_chunks
396/// Once that is stabilized this and the functions above it should be removed.
397#[inline]
398pub fn iter_array_chunks_padded<T: Copy, const N: usize>(
399    iter: impl IntoIterator<Item = T>,
400    default: T, // Needed due to [T; M] not always implementing Default. Can probably be dropped if const generics stabilize.
401) -> impl Iterator<Item = [T; N]> {
402    let mut iter = iter.into_iter();
403    iter::from_fn(move || iter_next_chunk_padded(&mut iter, default))
404}
405
/// Reinterpret a slice of `BaseArray` elements as a slice of `Base` elements
///
/// This is useful to convert `&[F; N]` to `&[F]` or `&[A]` to `&[F]` where
/// `A` has the same size, alignment and memory layout as `[F; N]` for some `N`.
///
/// # Safety
///
/// This is assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that arrays elements are contiguous in memory and the alignment of
/// the array is the same as the alignment of its elements, this means that `BaseArray`
/// must have the same alignment as `Base`.
///
/// # Panics
///
/// This panics if the size of `BaseArray` is not a multiple of the size of `Base`.
#[inline]
pub const unsafe fn as_base_slice<Base, BaseArray>(buf: &[BaseArray]) -> &[Base] {
    // Layout compatibility checks, evaluated once per instantiation inside a const block.
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    // Number of `Base` values packed into each `BaseArray`.
    let d = size_of::<BaseArray>() / size_of::<Base>();

    let buf_ptr = buf.as_ptr().cast::<Base>();
    let n = buf.len() * d;
    // SAFETY: by the caller's contract, `BaseArray` is layout-compatible with `[Base; d]`,
    // so `buf` covers exactly `buf.len() * d` contiguous, initialized `Base` values.
    unsafe { slice::from_raw_parts(buf_ptr, n) }
}
434
/// Reinterpret a mutable slice of `BaseArray` elements as a slice of `Base` elements
///
/// This is useful to convert `&[F; N]` to `&[F]` or `&[A]` to `&[F]` where
/// `A` has the same size, alignment and memory layout as `[F; N]` for some `N`.
///
/// # Safety
///
/// This is assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that arrays elements are contiguous in memory and the alignment of
/// the array is the same as the alignment of its elements, this means that `BaseArray`
/// must have the same alignment as `Base`.
///
/// # Panics
///
/// This panics if the size of `BaseArray` is not a multiple of the size of `Base`.
#[inline]
pub const unsafe fn as_base_slice_mut<Base, BaseArray>(buf: &mut [BaseArray]) -> &mut [Base] {
    // Layout compatibility checks, evaluated once per instantiation inside a const block.
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    // Number of `Base` values packed into each `BaseArray`.
    let d = size_of::<BaseArray>() / size_of::<Base>();

    let buf_ptr = buf.as_mut_ptr().cast::<Base>();
    let n = buf.len() * d;
    // SAFETY: by the caller's contract, `BaseArray` is layout-compatible with `[Base; d]`,
    // so `buf` covers exactly `buf.len() * d` contiguous, initialized `Base` values, and we
    // hold the unique mutable borrow of them.
    unsafe { slice::from_raw_parts_mut(buf_ptr, n) }
}
463
/// Convert a vector of `BaseArray` elements to a vector of `Base` elements without any
/// reallocations.
///
/// This is useful to convert `Vec<[F; N]>` to `Vec<F>` or `Vec<A>` to `Vec<F>` where
/// `A` has the same size, alignment and memory layout as `[F; N]` for some `N`. It can also,
/// be used to safely convert `Vec<u32>` to `Vec<F>` if `F` is a `32` bit field
/// or `Vec<u64>` to `Vec<F>` if `F` is a `64` bit field.
///
/// # Safety
///
/// This is assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that arrays elements are contiguous in memory and the alignment of
/// the array is the same as the alignment of its elements, this means that `BaseArray`
/// must have the same alignment as `Base`.
///
/// # Panics
///
/// This panics if the size of `BaseArray` is not a multiple of the size of `Base`.
#[inline]
pub unsafe fn flatten_to_base<Base, BaseArray>(vec: Vec<BaseArray>) -> Vec<Base> {
    // Layout compatibility checks, evaluated at compile time per instantiation.
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    // Number of `Base` values packed into each `BaseArray`.
    let d = size_of::<BaseArray>() / size_of::<Base>();
    // Prevent running `vec`'s destructor so we are in complete control
    // of the allocation.
    let mut values = ManuallyDrop::new(vec);

    // Each `Self` is an array of `d` elements, so the length and capacity of
    // the new vector will be multiplied by `d`.
    let new_len = values.len() * d;
    let new_cap = values.capacity() * d;

    // Safe as BaseArray and Base have the same alignment.
    let ptr = values.as_mut_ptr() as *mut Base;

    unsafe {
        // Safety:
        // - BaseArray and Base have the same alignment.
        // - As size_of::<BaseArray>() == size_of::<Base>() * d:
        //      -- The capacity of the new vector is equal to the capacity of the old vector.
        //      -- The first new_len elements of the new vector correspond to the first
        //         len elements of the old vector and so are properly initialized.
        Vec::from_raw_parts(ptr, new_len, new_cap)
    }
}
512
/// Convert a vector of `Base` elements to a vector of `BaseArray` elements ideally without any
/// reallocations.
///
/// This is an inverse of `flatten_to_base`. Unfortunately, unlike `flatten_to_base`, it may not be
/// possible to avoid allocations. The issue is that there is no way to guarantee that the capacity
/// of the vector is a multiple of `d`.
///
/// # Safety
///
/// This is assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that arrays elements are contiguous in memory and the alignment of
/// the array is the same as the alignment of its elements, this means that `BaseArray`
/// must have the same alignment as `Base`.
///
/// # Panics
///
/// This panics if the size of `BaseArray` is not a multiple of the size of `Base`.
/// This panics if the length of the vector is not a multiple of the ratio of the sizes.
#[inline]
pub unsafe fn reconstitute_from_base<Base, BaseArray: Clone>(mut vec: Vec<Base>) -> Vec<BaseArray> {
    // Layout compatibility checks, evaluated at compile time per instantiation.
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    // Number of `Base` values packed into each `BaseArray`.
    let d = size_of::<BaseArray>() / size_of::<Base>();

    assert!(
        vec.len().is_multiple_of(d),
        "Vector length (got {}) must be a multiple of the extension field dimension ({}).",
        vec.len(),
        d
    );

    let new_len = vec.len() / d;

    // We could call vec.shrink_to_fit() here to try and increase the probability that
    // the capacity is a multiple of d. That might cause a reallocation though which
    // would defeat the whole purpose.
    let cap = vec.capacity();

    // The assumption is that basically all callers of `reconstitute_from_base_vec` will be calling it
    // with a vector constructed from `flatten_to_base` and so the capacity should be a multiple of `d`.
    // But capacities can do strange things so we need to support both possibilities.
    // Note that the `else` branch would also work if the capacity is a multiple of `d` but it is slower.
    if cap.is_multiple_of(d) {
        // Prevent running `vec`'s destructor so we are in complete control
        // of the allocation.
        let mut values = ManuallyDrop::new(vec);

        // If we are on this branch then the capacity is a multiple of `d`.
        let new_cap = cap / d;

        // Safe as BaseArray and Base have the same alignment.
        let ptr = values.as_mut_ptr() as *mut BaseArray;

        unsafe {
            // Safety:
            // - BaseArray and Base have the same alignment.
            // - As size_of::<Base>() == size_of::<BaseArray>() / d:
            //      -- If we have reached this point, the length and capacity are both divisible by `d`.
            //      -- The capacity of the new vector is equal to the capacity of the old vector.
            //      -- The first new_len elements of the new vector correspond to the first
            //         len elements of the old vector and so are properly initialized.
            Vec::from_raw_parts(ptr, new_len, new_cap)
        }
    } else {
        // If the capacity is not a multiple of `D`, we go via slices.

        let buf_ptr = vec.as_mut_ptr().cast::<BaseArray>();
        let slice = unsafe {
            // Safety:
            // - BaseArray and Base have the same alignment.
            // - As size_of::<Base>() == size_of::<BaseArray>() / D:
            //      -- If we have reached this point, the length is divisible by `D`.
            //      -- The first new_len elements of the slice correspond to the first
            //         len elements of the old slice and so are properly initialized.
            slice::from_raw_parts(buf_ptr, new_len)
        };

        // Ideally the compiler could optimize this away to avoid the copy but it appears not to.
        slice.to_vec()
    }
}
597
/// Returns `true` iff `gcd(u, v) == 1`, via a subtract-and-shift (binary GCD style) loop.
///
/// Note: inputs of `0` always return `false`, including `gcd(0, 1)`.
#[inline(always)]
pub const fn relatively_prime_u64(mut u: u64, mut v: u64) -> bool {
    // Zero inputs are rejected outright.
    if u == 0 || v == 0 {
        return false;
    }

    // If both are even, 2 divides the gcd.
    if (u | v) & 1 == 0 {
        return false;
    }

    // Strip factors of 2 from `u`; if only a power of two remains, the gcd is 1.
    u >>= u.trailing_zeros();
    if u == 1 {
        return true;
    }

    loop {
        if v == 0 {
            // Neither value ever collapsed to 1, so the gcd exceeds 1.
            return false;
        }
        v >>= v.trailing_zeros();
        if v == 1 {
            return true;
        }

        // Keep u <= v so the subtraction below cannot underflow.
        if u > v {
            let tmp = u;
            u = v;
            v = tmp;
        }

        // This looks inefficient for v >> u, but because trailing zeros of `v`
        // are stripped every iteration, it performs much better than it appears.
        v -= u;
    }
}
636
/// Inner loop of the deferred GCD algorithm.
///
/// See: https://eprint.iacr.org/2020/972.pdf for more information.
///
/// This is basically a mini GCD algorithm which builds up a transformation to apply to the larger
/// numbers in the main loop. The key point is that this small loop only uses u64s, subtractions and
/// bit shifts, which are very fast operations.
///
/// The bottom `NUM_ROUNDS` bits of `a` and `b` should match the bottom `NUM_ROUNDS` bits of
/// the corresponding big-ints and the top `NUM_ROUNDS + 2` should match the top bits including
/// zeroes if the original numbers have different sizes.
#[inline]
pub fn gcd_inner<const NUM_ROUNDS: usize>(a: &mut u64, b: &mut u64) -> (i64, i64, i64, i64) {
    // Update factors, initialised to the identity transformation.
    // At the start of round 0: -1 < f0, g0, f1, g1 <= 1.
    let mut f0: i64 = 1;
    let mut g0: i64 = 0;
    let mut f1: i64 = 0;
    let mut g1: i64 = 1;

    // Invariant: if -2^i < f0, g0, f1, g1 <= 2^i at the start of a round,
    // then -2^{i + 1} < f0, g0, f1, g1 <= 2^{i + 1} at its end.
    for _ in 0..NUM_ROUNDS {
        if *a & 1 == 0 {
            // `a` even: just halve it.
            *a >>= 1;
        } else {
            // `a` odd: first ensure `a >= b` (mirroring any swap in the
            // update factors), then replace `a` by `(a - b) / 2`.
            if a < b {
                core::mem::swap(a, b);
                core::mem::swap(&mut f0, &mut f1);
                core::mem::swap(&mut g0, &mut g1);
            }
            *a -= *b;
            *a >>= 1;
            f0 -= f1;
            g0 -= g1;
        }
        // Compensate for halving `a` by doubling the row tracking `b`.
        f1 <<= 1;
        g1 <<= 1;
    }

    // -2^NUM_ROUNDS < f0, g0, f1, g1 <= 2^NUM_ROUNDS
    // Hence provided NUM_ROUNDS <= 62, we will not get any overflow.
    // Additionally, if NUM_ROUNDS <= 63, then the only source of overflow will be
    // if a variable is meant to equal 2^{63} in which case it will overflow to -2^{63}.
    (f0, g0, f1, g1)
}
680
/// Inverts elements inside the prime field `F_P` with `P < 2^FIELD_BITS`.
///
/// Arguments:
///  - a: The value we want to invert. It must be < P.
///  - b: The value of the prime `P > 2`.
///
/// Output:
/// - A `64-bit` signed integer `v` equal to `2^{2 * FIELD_BITS - 2} a^{-1} mod P` with
///   size `|v| < 2^{2 * FIELD_BITS - 2}`.
///
/// It is up to the user to ensure that `b` is an odd prime with at most `FIELD_BITS` bits and
/// `a < b`. If either of these assumptions break, the output is undefined.
#[inline]
pub fn gcd_inversion_prime_field_32<const FIELD_BITS: u32>(mut a: u32, mut b: u32) -> i64 {
    const {
        assert!(FIELD_BITS <= 32);
    }
    debug_assert!(((1_u64 << FIELD_BITS) - 1) >= b as u64);

    // Bezout-style coefficients. Writing `a0` and `P` for the initial values of `a`
    // and `b`, the loop maintains, at the start of round `i`:
    //   (1) |coeff_a|, |coeff_b| <= 2^i
    //   (2) 2^i * a = coeff_a * a0 (mod P)
    //   (3) 2^i * b = coeff_b * a0 (mod P)
    //   (4) gcd(a, b) = 1, with `b` odd
    //   (5) len(a) + len(b) <= max(2 * FIELD_BITS - i, 1)
    let mut coeff_a = 1_i64;
    let mut coeff_b = 0_i64;

    // Each round shrinks the combined bit-length by one (until it bottoms out at 1),
    // so after 2 * FIELD_BITS - 2 rounds we must have `b == 1`.
    for _ in 0..(2 * FIELD_BITS - 2) {
        if a & 1 != 0 {
            // Keep `a >= b` so the subtraction cannot underflow; any swap of the
            // values must be mirrored on the coefficients.
            if a < b {
                core::mem::swap(&mut a, &mut b);
                core::mem::swap(&mut coeff_a, &mut coeff_b);
            }
            // `a - b` is even (both odd), and |coeff_a - coeff_b| <= 2^{i + 1}.
            a -= b;
            coeff_a -= coeff_b;
        }
        // `b` stays odd, so `a` is even here: halve `a`, and double `coeff_b` to
        // keep congruence (3) in step with the growing power of two.
        a >>= 1;
        coeff_b <<= 1;
    }

    // At this point `b == 1`, hence
    // coeff_b = 2^{2 * FIELD_BITS - 2} * a0^{-1} mod P, and
    // |coeff_b| <= 2^{2 * FIELD_BITS - 2}, which fits an i64 for FIELD_BITS <= 32.
    coeff_b
}
753
754#[cfg(test)]
755mod tests {
756    use alloc::vec;
757    use alloc::vec::Vec;
758
759    use rand::rngs::SmallRng;
760    use rand::{Rng, SeedableRng};
761
762    use super::*;
763
    #[test]
    fn test_reverse_bits_len() {
        // 10-bit window: 0 is a fixed point; the lowest and highest bits trade places.
        assert_eq!(reverse_bits_len(0b0000000000, 10), 0b0000000000);
        assert_eq!(reverse_bits_len(0b0000000001, 10), 0b1000000000);
        assert_eq!(reverse_bits_len(0b1000000000, 10), 0b0000000001);
        // 5-bit window, including a multi-bit pattern.
        assert_eq!(reverse_bits_len(0b00000, 5), 0b00000);
        assert_eq!(reverse_bits_len(0b01011, 5), 0b11010);
    }
772
    #[test]
    fn test_reverse_index_bits() {
        // Length 4: indices 1 (0b01) and 2 (0b10) trade places; 0 and 3 are fixed points.
        let mut arg = vec![10, 20, 30, 40];
        reverse_slice_index_bits(&mut arg);
        assert_eq!(arg, vec![10, 30, 20, 40]);

        // Length 256: compare against a precomputed table of 8-bit reverses.
        let mut input256: Vec<u64> = (0..256).collect();
        #[rustfmt::skip]
        let output256: Vec<u64> = vec![
            0x00, 0x80, 0x40, 0xc0, 0x20, 0xa0, 0x60, 0xe0, 0x10, 0x90, 0x50, 0xd0, 0x30, 0xb0, 0x70, 0xf0,
            0x08, 0x88, 0x48, 0xc8, 0x28, 0xa8, 0x68, 0xe8, 0x18, 0x98, 0x58, 0xd8, 0x38, 0xb8, 0x78, 0xf8,
            0x04, 0x84, 0x44, 0xc4, 0x24, 0xa4, 0x64, 0xe4, 0x14, 0x94, 0x54, 0xd4, 0x34, 0xb4, 0x74, 0xf4,
            0x0c, 0x8c, 0x4c, 0xcc, 0x2c, 0xac, 0x6c, 0xec, 0x1c, 0x9c, 0x5c, 0xdc, 0x3c, 0xbc, 0x7c, 0xfc,
            0x02, 0x82, 0x42, 0xc2, 0x22, 0xa2, 0x62, 0xe2, 0x12, 0x92, 0x52, 0xd2, 0x32, 0xb2, 0x72, 0xf2,
            0x0a, 0x8a, 0x4a, 0xca, 0x2a, 0xaa, 0x6a, 0xea, 0x1a, 0x9a, 0x5a, 0xda, 0x3a, 0xba, 0x7a, 0xfa,
            0x06, 0x86, 0x46, 0xc6, 0x26, 0xa6, 0x66, 0xe6, 0x16, 0x96, 0x56, 0xd6, 0x36, 0xb6, 0x76, 0xf6,
            0x0e, 0x8e, 0x4e, 0xce, 0x2e, 0xae, 0x6e, 0xee, 0x1e, 0x9e, 0x5e, 0xde, 0x3e, 0xbe, 0x7e, 0xfe,
            0x01, 0x81, 0x41, 0xc1, 0x21, 0xa1, 0x61, 0xe1, 0x11, 0x91, 0x51, 0xd1, 0x31, 0xb1, 0x71, 0xf1,
            0x09, 0x89, 0x49, 0xc9, 0x29, 0xa9, 0x69, 0xe9, 0x19, 0x99, 0x59, 0xd9, 0x39, 0xb9, 0x79, 0xf9,
            0x05, 0x85, 0x45, 0xc5, 0x25, 0xa5, 0x65, 0xe5, 0x15, 0x95, 0x55, 0xd5, 0x35, 0xb5, 0x75, 0xf5,
            0x0d, 0x8d, 0x4d, 0xcd, 0x2d, 0xad, 0x6d, 0xed, 0x1d, 0x9d, 0x5d, 0xdd, 0x3d, 0xbd, 0x7d, 0xfd,
            0x03, 0x83, 0x43, 0xc3, 0x23, 0xa3, 0x63, 0xe3, 0x13, 0x93, 0x53, 0xd3, 0x33, 0xb3, 0x73, 0xf3,
            0x0b, 0x8b, 0x4b, 0xcb, 0x2b, 0xab, 0x6b, 0xeb, 0x1b, 0x9b, 0x5b, 0xdb, 0x3b, 0xbb, 0x7b, 0xfb,
            0x07, 0x87, 0x47, 0xc7, 0x27, 0xa7, 0x67, 0xe7, 0x17, 0x97, 0x57, 0xd7, 0x37, 0xb7, 0x77, 0xf7,
            0x0f, 0x8f, 0x4f, 0xcf, 0x2f, 0xaf, 0x6f, 0xef, 0x1f, 0x9f, 0x5f, 0xdf, 0x3f, 0xbf, 0x7f, 0xff,
        ];
        reverse_slice_index_bits(&mut input256[..]);
        assert_eq!(input256, output256);
    }
802
803    #[test]
804    fn test_apply_to_chunks_exact_fit() {
805        const CHUNK_SIZE: usize = 4;
806        let input: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8];
807        let mut results: Vec<Vec<u8>> = Vec::new();
808
809        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
810            results.push(chunk.to_vec());
811        });
812
813        assert_eq!(results, vec![vec![1, 2, 3, 4], vec![5, 6, 7, 8]]);
814    }
815
816    #[test]
817    fn test_apply_to_chunks_with_remainder() {
818        const CHUNK_SIZE: usize = 3;
819        let input: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7];
820        let mut results: Vec<Vec<u8>> = Vec::new();
821
822        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
823            results.push(chunk.to_vec());
824        });
825
826        assert_eq!(results, vec![vec![1, 2, 3], vec![4, 5, 6], vec![7]]);
827    }
828
829    #[test]
830    fn test_apply_to_chunks_empty_input() {
831        const CHUNK_SIZE: usize = 4;
832        let input: Vec<u8> = vec![];
833        let mut results: Vec<Vec<u8>> = Vec::new();
834
835        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
836            results.push(chunk.to_vec());
837        });
838
839        assert!(results.is_empty());
840    }
841
842    #[test]
843    fn test_apply_to_chunks_single_chunk() {
844        const CHUNK_SIZE: usize = 10;
845        let input: Vec<u8> = vec![1, 2, 3, 4, 5];
846        let mut results: Vec<Vec<u8>> = Vec::new();
847
848        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
849            results.push(chunk.to_vec());
850        });
851
852        assert_eq!(results, vec![vec![1, 2, 3, 4, 5]]);
853    }
854
855    #[test]
856    fn test_apply_to_chunks_large_chunk_size() {
857        const CHUNK_SIZE: usize = 100;
858        let input: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8];
859        let mut results: Vec<Vec<u8>> = Vec::new();
860
861        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
862            results.push(chunk.to_vec());
863        });
864
865        assert_eq!(results, vec![vec![1, 2, 3, 4, 5, 6, 7, 8]]);
866    }
867
868    #[test]
869    fn test_apply_to_chunks_large_input() {
870        const CHUNK_SIZE: usize = 5;
871        let input: Vec<u8> = (1..=20).collect();
872        let mut results: Vec<Vec<u8>> = Vec::new();
873
874        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
875            results.push(chunk.to_vec());
876        });
877
878        assert_eq!(
879            results,
880            vec![
881                vec![1, 2, 3, 4, 5],
882                vec![6, 7, 8, 9, 10],
883                vec![11, 12, 13, 14, 15],
884                vec![16, 17, 18, 19, 20]
885            ]
886        );
887    }
888
889    #[test]
890    fn test_reverse_slice_index_bits_random() {
891        let lengths = [32, 128, 1 << 16];
892        let mut rng = SmallRng::seed_from_u64(1);
893        for _ in 0..32 {
894            for &length in &lengths {
895                let mut rand_list: Vec<u32> = Vec::with_capacity(length);
896                rand_list.resize_with(length, || rng.random());
897                let expect = reverse_index_bits_naive(&rand_list);
898
899                let mut actual = rand_list.clone();
900                reverse_slice_index_bits(&mut actual);
901
902                assert_eq!(actual, expect);
903            }
904        }
905    }
906
907    #[test]
908    fn test_log2_strict_usize_edge_cases() {
909        assert_eq!(log2_strict_usize(1), 0);
910        assert_eq!(log2_strict_usize(2), 1);
911        assert_eq!(log2_strict_usize(1 << 18), 18);
912        assert_eq!(log2_strict_usize(1 << 31), 31);
913        assert_eq!(
914            log2_strict_usize(1 << (usize::BITS - 1)),
915            usize::BITS as usize - 1
916        );
917    }
918
919    #[test]
920    #[should_panic]
921    fn test_log2_strict_usize_zero() {
922        let _ = log2_strict_usize(0);
923    }
924
925    #[test]
926    #[should_panic]
927    fn test_log2_strict_usize_nonpower_2() {
928        let _ = log2_strict_usize(0x78c341c65ae6d262);
929    }
930
931    #[test]
932    #[should_panic]
933    fn test_log2_strict_usize_max() {
934        let _ = log2_strict_usize(usize::MAX);
935    }
936
937    #[test]
938    fn test_log2_ceil_usize_comprehensive() {
939        // Powers of 2
940        assert_eq!(log2_ceil_usize(0), 0);
941        assert_eq!(log2_ceil_usize(1), 0);
942        assert_eq!(log2_ceil_usize(2), 1);
943        assert_eq!(log2_ceil_usize(1 << 18), 18);
944        assert_eq!(log2_ceil_usize(1 << 31), 31);
945        assert_eq!(
946            log2_ceil_usize(1 << (usize::BITS - 1)),
947            usize::BITS as usize - 1
948        );
949
950        // Nonpowers; want to round up
951        assert_eq!(log2_ceil_usize(3), 2);
952        assert_eq!(log2_ceil_usize(0x14fe901b), 29);
953        assert_eq!(
954            log2_ceil_usize((1 << (usize::BITS - 1)) + 1),
955            usize::BITS as usize
956        );
957        assert_eq!(log2_ceil_usize(usize::MAX - 1), usize::BITS as usize);
958        assert_eq!(log2_ceil_usize(usize::MAX), usize::BITS as usize);
959    }
960
961    fn reverse_index_bits_naive<T: Copy>(arr: &[T]) -> Vec<T> {
962        let n = arr.len();
963        let n_power = log2_strict_usize(n);
964
965        let mut out = vec![None; n];
966        for (i, v) in arr.iter().enumerate() {
967            let dst = i.reverse_bits() >> (usize::BITS - n_power as u32);
968            out[dst] = Some(*v);
969        }
970
971        out.into_iter().map(|x| x.unwrap()).collect()
972    }
973
974    #[test]
975    fn test_relatively_prime_u64() {
976        // Zero cases (should always return false)
977        assert!(!relatively_prime_u64(0, 0));
978        assert!(!relatively_prime_u64(10, 0));
979        assert!(!relatively_prime_u64(0, 10));
980        assert!(!relatively_prime_u64(0, 123456789));
981
982        // Number with itself (if greater than 1, not relatively prime)
983        assert!(relatively_prime_u64(1, 1));
984        assert!(!relatively_prime_u64(10, 10));
985        assert!(!relatively_prime_u64(99999, 99999));
986
987        // Powers of 2 (always false since they share factor 2)
988        assert!(!relatively_prime_u64(2, 4));
989        assert!(!relatively_prime_u64(16, 32));
990        assert!(!relatively_prime_u64(64, 128));
991        assert!(!relatively_prime_u64(1024, 4096));
992        assert!(!relatively_prime_u64(u64::MAX, u64::MAX));
993
994        // One number is a multiple of the other (always false)
995        assert!(!relatively_prime_u64(5, 10));
996        assert!(!relatively_prime_u64(12, 36));
997        assert!(!relatively_prime_u64(15, 45));
998        assert!(!relatively_prime_u64(100, 500));
999
1000        // Co-prime numbers (should be true)
1001        assert!(relatively_prime_u64(17, 31));
1002        assert!(relatively_prime_u64(97, 43));
1003        assert!(relatively_prime_u64(7919, 65537));
1004        assert!(relatively_prime_u64(15485863, 32452843));
1005
1006        // Small prime numbers (should be true)
1007        assert!(relatively_prime_u64(13, 17));
1008        assert!(relatively_prime_u64(101, 103));
1009        assert!(relatively_prime_u64(1009, 1013));
1010
1011        // Large numbers (some cases where they are relatively prime or not)
1012        assert!(!relatively_prime_u64(
1013            190266297176832000,
1014            10430732356495263744
1015        ));
1016        assert!(!relatively_prime_u64(
1017            2040134905096275968,
1018            5701159354248194048
1019        ));
1020        assert!(!relatively_prime_u64(
1021            16611311494648745984,
1022            7514969329383038976
1023        ));
1024        assert!(!relatively_prime_u64(
1025            14863931409971066880,
1026            7911906750992527360
1027        ));
1028
1029        // Max values
1030        assert!(relatively_prime_u64(u64::MAX, 1));
1031        assert!(relatively_prime_u64(u64::MAX, u64::MAX - 1));
1032        assert!(!relatively_prime_u64(u64::MAX, u64::MAX));
1033    }
1034}