p3_util/lib.rs
//! Various simple utilities.

#![no_std]

extern crate alloc;

use alloc::slice;
use alloc::string::String;
use alloc::vec::Vec;
use core::any::type_name;
use core::hint::unreachable_unchecked;
use core::mem::{ManuallyDrop, MaybeUninit};
use core::{iter, mem};

use crate::transpose::transpose_in_place_square;

pub mod array_serialization;
pub mod linear_map;
pub mod transpose;
pub mod zip_eq;

/// Computes `ceil(log_2(n))`.
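///
/// Illustrative example (values chosen for demonstration; they mirror the unit tests below):
/// ```
/// assert_eq!(p3_util::log2_ceil_usize(6), 3);
/// assert_eq!(p3_util::log2_ceil_usize(8), 3);
/// ```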
#[must_use]
pub const fn log2_ceil_usize(n: usize) -> usize {
    (usize::BITS - n.saturating_sub(1).leading_zeros()) as usize
}

#[must_use]
pub fn log2_ceil_u64(n: u64) -> u64 {
    (u64::BITS - n.saturating_sub(1).leading_zeros()).into()
}

/// Computes `log_2(n)`.
///
/// # Panics
/// Panics if `n` is not a power of two.
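///
/// Illustrative example (values chosen for demonstration):
/// ```
/// assert_eq!(p3_util::log2_strict_usize(16), 4);
/// ```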
#[must_use]
#[inline]
pub fn log2_strict_usize(n: usize) -> usize {
    let res = n.trailing_zeros();
    assert_eq!(n.wrapping_shr(res), 1, "Not a power of two: {n}");
    // Tell the optimizer about the semantics of `log2_strict`. i.e. it can replace `n` with
    // `1 << res` and vice versa.
    unsafe {
        assume(n == 1 << res);
    }
    res as usize
}

/// Returns `[0, ..., N - 1]`.
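///
/// Illustrative example (values chosen for demonstration):
/// ```
/// assert_eq!(p3_util::indices_arr::<4>(), [0, 1, 2, 3]);
/// ```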
#[must_use]
pub const fn indices_arr<const N: usize>() -> [usize; N] {
    let mut indices_arr = [0; N];
    let mut i = 0;
    while i < N {
        indices_arr[i] = i;
        i += 1;
    }
    indices_arr
}

#[inline]
pub const fn reverse_bits(x: usize, n: usize) -> usize {
    // Assert that n is a power of 2
    debug_assert!(n.is_power_of_two());
    reverse_bits_len(x, n.trailing_zeros() as usize)
}

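/// Reverses the lowest `bit_len` bits of `x` (any higher bits are discarded).
///
/// Illustrative example (doc added for clarity; the values mirror the unit tests below):
/// ```
/// assert_eq!(p3_util::reverse_bits_len(0b01011, 5), 0b11010);
/// ```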
#[inline]
pub const fn reverse_bits_len(x: usize, bit_len: usize) -> usize {
    // NB: The only reason we need overflowing_shr() here as opposed
    // to plain '>>' is to accommodate the case x == bit_len == 0,
    // which would become `0 >> 64`. Rust thinks that any shift of 64
    // bits causes overflow, even when the argument is zero.
    x.reverse_bits()
        .overflowing_shr(usize::BITS - bit_len as u32)
        .0
}

// Lookup table of 6-bit reverses.
// NB: 2^6=64 bytes is a cache line. A smaller table wastes cache space.
#[cfg(not(target_arch = "aarch64"))]
#[rustfmt::skip]
const BIT_REVERSE_6BIT: &[u8] = &[
    0o00, 0o40, 0o20, 0o60, 0o10, 0o50, 0o30, 0o70,
    0o04, 0o44, 0o24, 0o64, 0o14, 0o54, 0o34, 0o74,
    0o02, 0o42, 0o22, 0o62, 0o12, 0o52, 0o32, 0o72,
    0o06, 0o46, 0o26, 0o66, 0o16, 0o56, 0o36, 0o76,
    0o01, 0o41, 0o21, 0o61, 0o11, 0o51, 0o31, 0o71,
    0o05, 0o45, 0o25, 0o65, 0o15, 0o55, 0o35, 0o75,
    0o03, 0o43, 0o23, 0o63, 0o13, 0o53, 0o33, 0o73,
    0o07, 0o47, 0o27, 0o67, 0o17, 0o57, 0o37, 0o77,
];

// Ensure that SMALL_ARR_SIZE >= 4 * BIG_T_SIZE.
const BIG_T_SIZE: usize = 1 << 14;
const SMALL_ARR_SIZE: usize = 1 << 16;

/// Permutes `vals` such that each index is mapped to its reverse in binary.
///
/// If the whole array fits in fast cache, then the trivial algorithm is cache friendly. Also, if
/// `T` is really big, then the trivial algorithm is cache-friendly, no matter the size of the array.
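///
/// Illustrative example (values chosen for demonstration):
/// ```
/// let mut vals = [0, 1, 2, 3, 4, 5, 6, 7];
/// p3_util::reverse_slice_index_bits(&mut vals);
/// // Index `i` is swapped with its 3-bit reversal, e.g. 1 = 0b001 <-> 0b100 = 4.
/// assert_eq!(vals, [0, 4, 2, 6, 1, 5, 3, 7]);
/// ```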
pub fn reverse_slice_index_bits<F>(vals: &mut [F])
where
    F: Copy + Send + Sync,
{
    let n = vals.len();
    if n == 0 {
        return;
    }
    let log_n = log2_strict_usize(n);

    // If the whole array fits in fast cache, then the trivial algorithm is cache friendly. Also, if
    // `T` is really big, then the trivial algorithm is cache-friendly, no matter the size of the array.
    if core::mem::size_of::<F>() << log_n <= SMALL_ARR_SIZE
        || core::mem::size_of::<F>() >= BIG_T_SIZE
    {
        reverse_slice_index_bits_small(vals, log_n);
    } else {
        debug_assert!(n >= 4); // By our choice of `BIG_T_SIZE` and `SMALL_ARR_SIZE`.

        // Algorithm:
        //
        // Treat `vals` as a `sqrt(n)` by `sqrt(n)` row-major matrix. (Assume for now that `lb_n` is
        // even, i.e., `n` is a square number.) To perform bit-order reversal we:
        //   1. Bit-reverse the order of the rows. (They are contiguous in memory, so this is
        //      basically a series of large `memcpy`s.)
        //   2. Transpose the matrix.
        //   3. Bit-reverse the order of the rows.
        //
        // This is equivalent to, for every index `0 <= i < n`:
        //   1. bit-reversing `i[lb_n / 2..lb_n]`,
        //   2. swapping `i[0..lb_n / 2]` and `i[lb_n / 2..lb_n]`,
        //   3. bit-reversing `i[lb_n / 2..lb_n]`.
        //
        // If `lb_n` is odd, i.e., `n` is not a square number, then the above procedure requires
        // slight modification. At steps 1 and 3 we bit-reverse bits `ceil(lb_n / 2)..lb_n` of the
        // index (shuffling `floor(lb_n / 2)` chunks of length `ceil(lb_n / 2)`). At step 2, we
        // perform _two_ transposes. We treat `vals` as two matrices, one where the middle bit of
        // the index is `0` and another where the middle bit is `1`; we transpose each individually.

        let lb_num_chunks = log_n >> 1;
        let lb_chunk_size = log_n - lb_num_chunks;
        unsafe {
            reverse_slice_index_bits_chunks(vals, lb_num_chunks, lb_chunk_size);
            transpose_in_place_square(vals, lb_chunk_size, lb_num_chunks, 0);
            if lb_num_chunks != lb_chunk_size {
                // `vals` cannot be interpreted as a square matrix. We instead interpret it as a
                // `1 << lb_num_chunks` by `2` by `1 << lb_num_chunks` tensor, in row-major order.
                // The above transpose acted on `tensor[..., 0, ...]` (all indices with middle bit
                // `0`). We still need to transpose `tensor[..., 1, ...]`. To do so, we effectively
                // advance `vals` by `1 << lb_num_chunks`, adding that offset to every index.
                let vals_with_offset = &mut vals[1 << lb_num_chunks..];
                transpose_in_place_square(vals_with_offset, lb_chunk_size, lb_num_chunks, 0);
            }
            reverse_slice_index_bits_chunks(vals, lb_num_chunks, lb_chunk_size);
        }
    }
}

// Both functions below are semantically equivalent to:
//     for i in 0..n {
//         result.push(arr[reverse_bits(i, n_power)]);
//     }
// where reverse_bits(i, n_power) computes the n_power-bit reverse. The complications are there
// to guide the compiler to generate optimal assembly.

#[cfg(not(target_arch = "aarch64"))]
fn reverse_slice_index_bits_small<F>(vals: &mut [F], lb_n: usize) {
    if lb_n <= 6 {
        // BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them lb_n-bit reverses.
        let dst_shr_amt = 6 - lb_n as u32;
        #[allow(clippy::needless_range_loop)]
        for src in 0..vals.len() {
            let dst = (BIT_REVERSE_6BIT[src] as usize).wrapping_shr(dst_shr_amt);
            if src < dst {
                vals.swap(src, dst);
            }
        }
    } else {
        // LLVM does not know that it does not need to reverse src at each iteration (which is
        // expensive on x86). We take advantage of the fact that the low bits of dst change rarely
        // and the high bits of dst are dependent only on the low bits of src.
        let dst_lo_shr_amt = usize::BITS - (lb_n - 6) as u32;
        let dst_hi_shl_amt = lb_n - 6;
        for src_chunk in 0..(vals.len() >> 6) {
            let src_hi = src_chunk << 6;
            let dst_lo = src_chunk.reverse_bits().wrapping_shr(dst_lo_shr_amt);
            #[allow(clippy::needless_range_loop)]
            for src_lo in 0..(1 << 6) {
                let dst_hi = (BIT_REVERSE_6BIT[src_lo] as usize) << dst_hi_shl_amt;
                let src = src_hi + src_lo;
                let dst = dst_hi + dst_lo;
                if src < dst {
                    vals.swap(src, dst);
                }
            }
        }
    }
}

#[cfg(target_arch = "aarch64")]
fn reverse_slice_index_bits_small<F>(vals: &mut [F], lb_n: usize) {
    // Aarch64 can reverse bits in one instruction, so the trivial version works best.
    for src in 0..vals.len() {
        let dst = src.reverse_bits().wrapping_shr(usize::BITS - lb_n as u32);
        if src < dst {
            vals.swap(src, dst);
        }
    }
}

/// Splits `vals` into chunks and bit-reverses the order of the chunks. There are
/// `1 << lb_num_chunks` chunks, each of length `1 << lb_chunk_size`.
/// SAFETY: ensure that `vals.len() == 1 << (lb_num_chunks + lb_chunk_size)`.
unsafe fn reverse_slice_index_bits_chunks<F>(
    vals: &mut [F],
    lb_num_chunks: usize,
    lb_chunk_size: usize,
) {
    for i in 0..1usize << lb_num_chunks {
        // `wrapping_shr` handles the silly case when `lb_num_chunks == 0`.
        let j = i
            .reverse_bits()
            .wrapping_shr(usize::BITS - lb_num_chunks as u32);
        if i < j {
            unsafe {
                core::ptr::swap_nonoverlapping(
                    vals.get_unchecked_mut(i << lb_chunk_size),
                    vals.get_unchecked_mut(j << lb_chunk_size),
                    1 << lb_chunk_size,
                );
            }
        }
    }
}

/// Allow the compiler to assume that the given predicate `p` is always `true`.
///
/// # Safety
///
/// Callers must ensure that `p` is true. If this is not the case, the behavior is undefined.
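///
/// Illustrative example (values chosen for demonstration):
/// ```
/// let x = 4;
/// // SAFETY: `x` was just set to 4, so the predicate holds.
/// unsafe { p3_util::assume(x % 2 == 0) };
/// ```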
#[inline(always)]
pub unsafe fn assume(p: bool) {
    debug_assert!(p);
    if !p {
        unsafe {
            unreachable_unchecked();
        }
    }
}

/// Try to force Rust to emit a branch. Example:
///
/// ```no_run
/// let x = 100;
/// if x > 20 {
///     println!("x is big!");
///     p3_util::branch_hint();
/// } else {
///     println!("x is small!");
/// }
/// ```
///
/// This function has no semantics. It is a hint only.
#[inline(always)]
pub fn branch_hint() {
    // NOTE: These are the currently supported assembly architectures. See the
    // [nightly reference](https://doc.rust-lang.org/nightly/reference/inline-assembly.html) for
    // the most up-to-date list.
    #[cfg(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "riscv32",
        target_arch = "riscv64",
        target_arch = "x86",
        target_arch = "x86_64",
    ))]
    unsafe {
        core::arch::asm!("", options(nomem, nostack, preserves_flags));
    }
}

/// Return a String containing the name of T but with all the crate
/// and module prefixes removed.
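///
/// Illustrative example (values chosen for demonstration):
/// ```
/// assert_eq!(p3_util::pretty_name::<Vec<u32>>(), "Vec<u32>");
/// ```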
pub fn pretty_name<T>() -> String {
    let name = type_name::<T>();
    let mut result = String::new();
    for qual in name.split_inclusive(&['<', '>', ',']) {
        result.push_str(qual.split("::").last().unwrap());
    }
    result
}

/// A C-style buffered input reader, similar to
/// `core::iter::Iterator::next_chunk()` from nightly.
///
/// Returns an array of `MaybeUninit<T>` and the number of items in the
/// array which have been correctly initialized.
#[inline]
fn iter_next_chunk_erased<const BUFLEN: usize, I: Iterator>(
    iter: &mut I,
) -> ([MaybeUninit<I::Item>; BUFLEN], usize)
where
    I::Item: Copy,
{
    let mut buf = [const { MaybeUninit::<I::Item>::uninit() }; BUFLEN];
    let mut i = 0;

    while i < BUFLEN {
        if let Some(c) = iter.next() {
            // Copy the next Item into `buf`.
            unsafe {
                buf.get_unchecked_mut(i).write(c);
                i = i.unchecked_add(1);
            }
        } else {
            // No more items in the iterator.
            break;
        }
    }
    (buf, i)
}

/// Gets a shared reference to the contained value.
///
/// # Safety
///
/// Calling this when the content is not yet fully initialized causes undefined
/// behavior: it is up to the caller to guarantee that every `MaybeUninit<T>` in
/// the slice really is in an initialized state.
///
/// Copied from:
/// https://doc.rust-lang.org/std/primitive.slice.html#method.assume_init_ref
/// Once that is stabilized, this should be removed.
#[inline(always)]
pub const unsafe fn assume_init_ref<T>(slice: &[MaybeUninit<T>]) -> &[T] {
    // SAFETY: casting `slice` to a `*const [T]` is safe since the caller guarantees that
    // `slice` is initialized, and `MaybeUninit` is guaranteed to have the same layout as `T`.
    // The pointer obtained is valid since it refers to memory owned by `slice` which is a
    // reference and thus guaranteed to be valid for reads.
    unsafe { &*(slice as *const [MaybeUninit<T>] as *const [T]) }
}

/// Split an iterator into small arrays and apply `func` to each.
///
/// Repeatedly read `BUFLEN` elements from `input` into an array and
/// pass the array to `func` as a slice. If fewer than `BUFLEN`
/// elements remain, that smaller slice is passed to `func` (if
/// it is non-empty) and the function returns.
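///
/// Illustrative example (values chosen for demonstration; they mirror the unit tests below):
/// ```
/// let mut chunks = Vec::new();
/// p3_util::apply_to_chunks::<3, _, _>(vec![1u8, 2, 3, 4, 5], |chunk| chunks.push(chunk.to_vec()));
/// assert_eq!(chunks, vec![vec![1, 2, 3], vec![4, 5]]);
/// ```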
#[inline]
pub fn apply_to_chunks<const BUFLEN: usize, I, H>(input: I, mut func: H)
where
    I: IntoIterator<Item = u8>,
    H: FnMut(&[I::Item]),
{
    let mut iter = input.into_iter();
    loop {
        let (buf, n) = iter_next_chunk_erased::<BUFLEN, _>(&mut iter);
        if n == 0 {
            break;
        }
        func(unsafe { assume_init_ref(buf.get_unchecked(..n)) });
    }
}

/// Pulls `N` items from `iter` and returns them as an array. If the iterator
/// yields fewer than `N` items (but more than `0`), pads with the given default value.
///
/// Since the iterator is passed as a mutable reference and this function calls
/// `next` at most `N` times, the iterator can still be used afterwards to
/// retrieve the remaining items.
///
/// If `iter.next()` panics, all items already yielded by the iterator are
/// dropped.
#[inline]
fn iter_next_chunk_padded<T: Copy, const N: usize>(
    iter: &mut impl Iterator<Item = T>,
    default: T, // Needed due to [T; M] not always implementing Default. Can probably be dropped if const generics stabilize.
) -> Option<[T; N]> {
    let (mut arr, n) = iter_next_chunk_erased::<N, _>(iter);
    (n != 0).then(|| {
        // Fill the rest of the array with default values.
        arr[n..].fill(MaybeUninit::new(default));
        unsafe { mem::transmute_copy::<_, [T; N]>(&arr) }
    })
}

/// Returns an iterator over `N` elements of the iterator at a time.
///
/// The chunks do not overlap. If `N` does not divide the length of the
/// iterator, the final chunk is padded up to length `N` with the given default value.
///
/// This is essentially a copy-pasted version of the nightly `array_chunks` function.
/// https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.array_chunks
/// Once that is stabilized this and the functions above it should be removed.
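///
/// Illustrative example (values chosen for demonstration):
/// ```
/// let chunks: Vec<[u8; 3]> = p3_util::iter_array_chunks_padded([1u8, 2, 3, 4], 0).collect();
/// assert_eq!(chunks, vec![[1, 2, 3], [4, 0, 0]]);
/// ```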
#[inline]
pub fn iter_array_chunks_padded<T: Copy, const N: usize>(
    iter: impl IntoIterator<Item = T>,
    default: T, // Needed due to [T; M] not always implementing Default. Can probably be dropped if const generics stabilize.
) -> impl Iterator<Item = [T; N]> {
    let mut iter = iter.into_iter();
    iter::from_fn(move || iter_next_chunk_padded(&mut iter, default))
}

/// Reinterpret a slice of `BaseArray` elements as a slice of `Base` elements.
///
/// This is useful to convert `&[F; N]` to `&[F]` or `&[A]` to `&[F]` where
/// `A` has the same size, alignment and memory layout as `[F; N]` for some `N`.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and the alignment of
/// the array is the same as the alignment of its elements, this means that `BaseArray`
/// must have the same alignment as `Base`.
///
/// # Panics
///
/// This panics if the size of `BaseArray` is not a multiple of the size of `Base`.
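///
/// Illustrative example (values chosen for demonstration):
/// ```
/// let arrays = [[1u32, 2], [3, 4]];
/// let flat: &[u32] = unsafe { p3_util::as_base_slice(&arrays) };
/// assert_eq!(flat, &[1, 2, 3, 4]);
/// ```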
#[inline]
pub const unsafe fn as_base_slice<Base, BaseArray>(buf: &[BaseArray]) -> &[Base] {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    let d = size_of::<BaseArray>() / size_of::<Base>();

    let buf_ptr = buf.as_ptr().cast::<Base>();
    let n = buf.len() * d;
    unsafe { slice::from_raw_parts(buf_ptr, n) }
}

/// Reinterpret a mutable slice of `BaseArray` elements as a slice of `Base` elements.
///
/// This is useful to convert `&[F; N]` to `&[F]` or `&[A]` to `&[F]` where
/// `A` has the same size, alignment and memory layout as `[F; N]` for some `N`.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and the alignment of
/// the array is the same as the alignment of its elements, this means that `BaseArray`
/// must have the same alignment as `Base`.
///
/// # Panics
///
/// This panics if the size of `BaseArray` is not a multiple of the size of `Base`.
#[inline]
pub const unsafe fn as_base_slice_mut<Base, BaseArray>(buf: &mut [BaseArray]) -> &mut [Base] {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    let d = size_of::<BaseArray>() / size_of::<Base>();

    let buf_ptr = buf.as_mut_ptr().cast::<Base>();
    let n = buf.len() * d;
    unsafe { slice::from_raw_parts_mut(buf_ptr, n) }
}

/// Convert a vector of `BaseArray` elements to a vector of `Base` elements without any
/// reallocations.
///
/// This is useful to convert `Vec<[F; N]>` to `Vec<F>` or `Vec<A>` to `Vec<F>` where
/// `A` has the same size, alignment and memory layout as `[F; N]` for some `N`. It can also
/// be used to safely convert `Vec<u32>` to `Vec<F>` if `F` is a `32` bit field
/// or `Vec<u64>` to `Vec<F>` if `F` is a `64` bit field.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and the alignment of
/// the array is the same as the alignment of its elements, this means that `BaseArray`
/// must have the same alignment as `Base`.
///
/// # Panics
///
/// This panics if the size of `BaseArray` is not a multiple of the size of `Base`.
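///
/// Illustrative example (values chosen for demonstration):
/// ```
/// let vec_of_arrays: Vec<[u64; 2]> = vec![[1, 2], [3, 4]];
/// let flat: Vec<u64> = unsafe { p3_util::flatten_to_base(vec_of_arrays) };
/// assert_eq!(flat, vec![1, 2, 3, 4]);
/// ```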
#[inline]
pub unsafe fn flatten_to_base<Base, BaseArray>(vec: Vec<BaseArray>) -> Vec<Base> {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    let d = size_of::<BaseArray>() / size_of::<Base>();
    // Prevent running `vec`'s destructor so we are in complete control
    // of the allocation.
    let mut values = ManuallyDrop::new(vec);

    // Each `BaseArray` is an array of `d` elements, so the length and capacity of
    // the new vector will be multiplied by `d`.
    let new_len = values.len() * d;
    let new_cap = values.capacity() * d;

    // Safe as BaseArray and Base have the same alignment.
    let ptr = values.as_mut_ptr() as *mut Base;

    unsafe {
        // Safety:
        // - BaseArray and Base have the same alignment.
        // - As size_of::<BaseArray>() == size_of::<Base>() * d:
        //   -- The capacity of the new vector is equal to the capacity of the old vector.
        //   -- The first new_len elements of the new vector correspond to the first
        //      len elements of the old vector and so are properly initialized.
        Vec::from_raw_parts(ptr, new_len, new_cap)
    }
}

/// Convert a vector of `Base` elements to a vector of `BaseArray` elements, ideally without any
/// reallocations.
///
/// This is an inverse of `flatten_to_base`. Unfortunately, unlike `flatten_to_base`, it may not be
/// possible to avoid allocations. The issue is that there is no way to guarantee that the capacity
/// of the vector is a multiple of `d`.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and the alignment of
/// the array is the same as the alignment of its elements, this means that `BaseArray`
/// must have the same alignment as `Base`.
///
/// # Panics
///
/// This panics if the size of `BaseArray` is not a multiple of the size of `Base`.
/// This panics if the length of the vector is not a multiple of the ratio of the sizes.
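///
/// Illustrative example (values chosen for demonstration):
/// ```
/// let flat: Vec<u64> = vec![1, 2, 3, 4, 5, 6];
/// let arrays: Vec<[u64; 3]> = unsafe { p3_util::reconstitute_from_base(flat) };
/// assert_eq!(arrays, vec![[1, 2, 3], [4, 5, 6]]);
/// ```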
#[inline]
pub unsafe fn reconstitute_from_base<Base, BaseArray: Clone>(mut vec: Vec<Base>) -> Vec<BaseArray> {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    let d = size_of::<BaseArray>() / size_of::<Base>();

    assert!(
        vec.len().is_multiple_of(d),
        "Vector length (got {}) must be a multiple of the extension field dimension ({}).",
        vec.len(),
        d
    );

    let new_len = vec.len() / d;

    // We could call vec.shrink_to_fit() here to try and increase the probability that
    // the capacity is a multiple of d. That might cause a reallocation though which
    // would defeat the whole purpose.
    let cap = vec.capacity();

    // The assumption is that basically all callers of `reconstitute_from_base` will be calling it
    // with a vector constructed from `flatten_to_base` and so the capacity should be a multiple of `d`.
    // But capacities can do strange things so we need to support both possibilities.
    // Note that the `else` branch would also work if the capacity is a multiple of `d` but it is slower.
    if cap.is_multiple_of(d) {
        // Prevent running `vec`'s destructor so we are in complete control
        // of the allocation.
        let mut values = ManuallyDrop::new(vec);

        // If we are on this branch then the capacity is a multiple of `d`.
        let new_cap = cap / d;

        // Safe as BaseArray and Base have the same alignment.
        let ptr = values.as_mut_ptr() as *mut BaseArray;

        unsafe {
            // Safety:
            // - BaseArray and Base have the same alignment.
            // - As size_of::<Base>() == size_of::<BaseArray>() / d:
            //   -- If we have reached this point, the length and capacity are both divisible by `d`.
            //   -- The capacity of the new vector is equal to the capacity of the old vector.
            //   -- The first new_len elements of the new vector correspond to the first
            //      len elements of the old vector and so are properly initialized.
            Vec::from_raw_parts(ptr, new_len, new_cap)
        }
    } else {
        // If the capacity is not a multiple of `d`, we go via slices.

        let buf_ptr = vec.as_mut_ptr().cast::<BaseArray>();
        let slice = unsafe {
            // Safety:
            // - BaseArray and Base have the same alignment.
            // - As size_of::<Base>() == size_of::<BaseArray>() / d:
            //   -- If we have reached this point, the length is divisible by `d`.
            //   -- The first new_len elements of the slice correspond to the first
            //      len elements of the old slice and so are properly initialized.
            slice::from_raw_parts(buf_ptr, new_len)
        };

        // Ideally the compiler could optimize this away to avoid the copy but it appears not to.
        slice.to_vec()
    }
}

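/// Returns `true` if `u` and `v` are relatively prime, i.e. `gcd(u, v) == 1`, and `false`
/// otherwise (in particular, `false` if either input is `0`).
///
/// Illustrative example (doc added for clarity; values mirror the unit tests below):
/// ```
/// assert!(p3_util::relatively_prime_u64(17, 31));
/// assert!(!p3_util::relatively_prime_u64(5, 10));
/// ```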
#[inline(always)]
pub const fn relatively_prime_u64(mut u: u64, mut v: u64) -> bool {
    // Check that neither input is 0.
    if u == 0 || v == 0 {
        return false;
    }

    // Check divisibility by 2.
    if (u | v) & 1 == 0 {
        return false;
    }

    // Remove factors of 2 from `u`.
    u >>= u.trailing_zeros();
    if u == 1 {
        return true;
    }

    while v != 0 {
        v >>= v.trailing_zeros();
        if v == 1 {
            return true;
        }

        // Ensure u <= v
        if u > v {
            core::mem::swap(&mut u, &mut v);
        }

        // This looks inefficient for v >> u but, thanks to the fact that we remove
        // trailing_zeros of v in every iteration, it ends up much more performant
        // than first glance implies.
        v -= u;
    }
    // If we made it through the loop, at no point is u or v equal to 1 and so the gcd
    // must be greater than 1.
    false
}

/// Inner loop of the deferred GCD algorithm.
///
/// See: https://eprint.iacr.org/2020/972.pdf for more information.
///
/// This is basically a mini GCD algorithm which builds up a transformation to apply to the larger
/// numbers in the main loop. The key point is that this small loop only uses u64s, subtractions and
/// bit shifts, which are very fast operations.
///
/// The bottom `NUM_ROUNDS` bits of `a` and `b` should match the bottom `NUM_ROUNDS` bits of
/// the corresponding big-ints, and the top `NUM_ROUNDS + 2` bits should match the top bits,
/// including zeroes if the original numbers have different sizes.
#[inline]
pub fn gcd_inner<const NUM_ROUNDS: usize>(a: &mut u64, b: &mut u64) -> (i64, i64, i64, i64) {
    // Initialise update factors.
    // At the start of round 0: -1 < f0, g0, f1, g1 <= 1
    let (mut f0, mut g0, mut f1, mut g1) = (1, 0, 0, 1);

    // If at the start of a round: -2^i < f0, g0, f1, g1 <= 2^i
    // Then, at the end of the round: -2^{i + 1} < f0, g0, f1, g1 <= 2^{i + 1}
    for _ in 0..NUM_ROUNDS {
        if *a & 1 == 0 {
            *a >>= 1;
        } else {
            if a < b {
                core::mem::swap(a, b);
                (f0, f1) = (f1, f0);
                (g0, g1) = (g1, g0);
            }
            *a -= *b;
            *a >>= 1;
            f0 -= f1;
            g0 -= g1;
        }
        f1 <<= 1;
        g1 <<= 1;
    }

    // -2^NUM_ROUNDS < f0, g0, f1, g1 <= 2^NUM_ROUNDS
    // Hence provided NUM_ROUNDS <= 62, we will not get any overflow.
    // Additionally, if NUM_ROUNDS <= 63, then the only source of overflow will be
    // if a variable is meant to equal 2^{63} in which case it will overflow to -2^{63}.
    (f0, g0, f1, g1)
}

/// Inverts elements inside the prime field `F_P` with `P < 2^FIELD_BITS`.
///
/// Arguments:
/// - a: The value we want to invert. It must be < P.
/// - b: The value of the prime `P > 2`.
///
/// Output:
/// - A `64-bit` signed integer `v` equal to `2^{2 * FIELD_BITS - 2} a^{-1} mod P` with
///   size `|v| < 2^{2 * FIELD_BITS - 2}`.
///
/// It is up to the user to ensure that `b` is an odd prime with at most `FIELD_BITS` bits and
/// `a < b`. If either of these assumptions breaks, the output is undefined.
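///
/// Illustrative sanity check (values chosen for demonstration; it only relies on the contract
/// stated above):
/// ```
/// // Invert a = 3 in F_7 with FIELD_BITS = 3: the result v should satisfy
/// // v * 3 = 2^(2 * 3 - 2) = 16 modulo 7.
/// let v = p3_util::gcd_inversion_prime_field_32::<3>(3, 7);
/// assert_eq!((v.rem_euclid(7) * 3) % 7, 16 % 7);
/// ```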
#[inline]
pub fn gcd_inversion_prime_field_32<const FIELD_BITS: u32>(mut a: u32, mut b: u32) -> i64 {
    const {
        assert!(FIELD_BITS <= 32);
    }
    debug_assert!(((1_u64 << FIELD_BITS) - 1) >= b as u64);

    // Initialise u, v. Note that |u|, |v| <= 2^0
    let (mut u, mut v) = (1_i64, 0_i64);

    // Let a0 and P denote the initial values of a and b. Observe:
    // `a = u * a0 mod P`
    // `b = v * a0 mod P`
    // `len(a) + len(b) <= 2 * len(P) <= 2 * FIELD_BITS`

    for _ in 0..(2 * FIELD_BITS - 2) {
        // Assume at the start of loop i:
        // (1) `|u|, |v| <= 2^{i}`
        // (2) `2^i * a = u * a0 mod P`
        // (3) `2^i * b = v * a0 mod P`
        // (4) `gcd(a, b) = 1`
        // (5) `b` is odd.
        // (6) `len(a) + len(b) <= max(n - i, 1)`

        if a & 1 != 0 {
            if a < b {
                (a, b) = (b, a);
                (u, v) = (v, u);
            }
            // As b < a, this subtraction cannot increase `len(a) + len(b)`
            a -= b;
            // Observe |u'| = |u - v| <= |u| + |v| <= 2^{i + 1}
            u -= v;

            // As (2) and (3) hold, we have
            // `2^i a' = 2^i * (a - b) = (u - v) * a0 mod P = u' * a0 mod P`
        }
        // As b is odd, a must now be even.
        // This reduces `len(a) + len(b)` by 1 (unless `a = 0`, in which case `b = 1` and the sum of the lengths is always 1)
        a >>= 1;

        // Observe |v'| = 2|v| <= 2^{i + 1}
        v <<= 1;

        // Thus at the end of loop i:
        // (1) `|u|, |v| <= 2^{i + 1}`
        // (2) `2^{i + 1} * a = u * a0 mod P` (As we have halved a)
        // (3) `2^{i + 1} * b = v * a0 mod P` (As we have doubled v)
        // (4) `gcd(a, b) = 1`
        // (5) `b` is odd.
        // (6) `len(a) + len(b) <= max(n - i - 1, 1)`
    }

    // After the loop, we see that:
    // |u|, |v| <= 2^{2 * FIELD_BITS - 2}: Hence for FIELD_BITS <= 32 we will not overflow an i64.
    // `2^{2 * FIELD_BITS - 2} * b = v * a0 mod P`
    // `len(a) + len(b) <= 2` with `gcd(a, b) = 1` and `b` odd.
    // This implies that `b` must be `1` and so `v = 2^{2 * FIELD_BITS - 2} a0^{-1} mod P` as desired.
    v
}

#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};

    use super::*;

    #[test]
    fn test_reverse_bits_len() {
        assert_eq!(reverse_bits_len(0b0000000000, 10), 0b0000000000);
        assert_eq!(reverse_bits_len(0b0000000001, 10), 0b1000000000);
        assert_eq!(reverse_bits_len(0b1000000000, 10), 0b0000000001);
        assert_eq!(reverse_bits_len(0b00000, 5), 0b00000);
        assert_eq!(reverse_bits_len(0b01011, 5), 0b11010);
    }

    #[test]
    fn test_reverse_index_bits() {
        let mut arg = vec![10, 20, 30, 40];
        reverse_slice_index_bits(&mut arg);
        assert_eq!(arg, vec![10, 30, 20, 40]);

        let mut input256: Vec<u64> = (0..256).collect();
        #[rustfmt::skip]
        let output256: Vec<u64> = vec![
            0x00, 0x80, 0x40, 0xc0, 0x20, 0xa0, 0x60, 0xe0, 0x10, 0x90, 0x50, 0xd0, 0x30, 0xb0, 0x70, 0xf0,
            0x08, 0x88, 0x48, 0xc8, 0x28, 0xa8, 0x68, 0xe8, 0x18, 0x98, 0x58, 0xd8, 0x38, 0xb8, 0x78, 0xf8,
            0x04, 0x84, 0x44, 0xc4, 0x24, 0xa4, 0x64, 0xe4, 0x14, 0x94, 0x54, 0xd4, 0x34, 0xb4, 0x74, 0xf4,
            0x0c, 0x8c, 0x4c, 0xcc, 0x2c, 0xac, 0x6c, 0xec, 0x1c, 0x9c, 0x5c, 0xdc, 0x3c, 0xbc, 0x7c, 0xfc,
            0x02, 0x82, 0x42, 0xc2, 0x22, 0xa2, 0x62, 0xe2, 0x12, 0x92, 0x52, 0xd2, 0x32, 0xb2, 0x72, 0xf2,
            0x0a, 0x8a, 0x4a, 0xca, 0x2a, 0xaa, 0x6a, 0xea, 0x1a, 0x9a, 0x5a, 0xda, 0x3a, 0xba, 0x7a, 0xfa,
            0x06, 0x86, 0x46, 0xc6, 0x26, 0xa6, 0x66, 0xe6, 0x16, 0x96, 0x56, 0xd6, 0x36, 0xb6, 0x76, 0xf6,
            0x0e, 0x8e, 0x4e, 0xce, 0x2e, 0xae, 0x6e, 0xee, 0x1e, 0x9e, 0x5e, 0xde, 0x3e, 0xbe, 0x7e, 0xfe,
            0x01, 0x81, 0x41, 0xc1, 0x21, 0xa1, 0x61, 0xe1, 0x11, 0x91, 0x51, 0xd1, 0x31, 0xb1, 0x71, 0xf1,
            0x09, 0x89, 0x49, 0xc9, 0x29, 0xa9, 0x69, 0xe9, 0x19, 0x99, 0x59, 0xd9, 0x39, 0xb9, 0x79, 0xf9,
            0x05, 0x85, 0x45, 0xc5, 0x25, 0xa5, 0x65, 0xe5, 0x15, 0x95, 0x55, 0xd5, 0x35, 0xb5, 0x75, 0xf5,
            0x0d, 0x8d, 0x4d, 0xcd, 0x2d, 0xad, 0x6d, 0xed, 0x1d, 0x9d, 0x5d, 0xdd, 0x3d, 0xbd, 0x7d, 0xfd,
            0x03, 0x83, 0x43, 0xc3, 0x23, 0xa3, 0x63, 0xe3, 0x13, 0x93, 0x53, 0xd3, 0x33, 0xb3, 0x73, 0xf3,
            0x0b, 0x8b, 0x4b, 0xcb, 0x2b, 0xab, 0x6b, 0xeb, 0x1b, 0x9b, 0x5b, 0xdb, 0x3b, 0xbb, 0x7b, 0xfb,
            0x07, 0x87, 0x47, 0xc7, 0x27, 0xa7, 0x67, 0xe7, 0x17, 0x97, 0x57, 0xd7, 0x37, 0xb7, 0x77, 0xf7,
            0x0f, 0x8f, 0x4f, 0xcf, 0x2f, 0xaf, 0x6f, 0xef, 0x1f, 0x9f, 0x5f, 0xdf, 0x3f, 0xbf, 0x7f, 0xff,
        ];
        reverse_slice_index_bits(&mut input256[..]);
        assert_eq!(input256, output256);
    }

    #[test]
    fn test_apply_to_chunks_exact_fit() {
        const CHUNK_SIZE: usize = 4;
        let input: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert_eq!(results, vec![vec![1, 2, 3, 4], vec![5, 6, 7, 8]]);
    }

    #[test]
    fn test_apply_to_chunks_with_remainder() {
        const CHUNK_SIZE: usize = 3;
        let input: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7];
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert_eq!(results, vec![vec![1, 2, 3], vec![4, 5, 6], vec![7]]);
    }

    #[test]
    fn test_apply_to_chunks_empty_input() {
        const CHUNK_SIZE: usize = 4;
        let input: Vec<u8> = vec![];
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert!(results.is_empty());
    }

    #[test]
    fn test_apply_to_chunks_single_chunk() {
        const CHUNK_SIZE: usize = 10;
        let input: Vec<u8> = vec![1, 2, 3, 4, 5];
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert_eq!(results, vec![vec![1, 2, 3, 4, 5]]);
    }

    #[test]
    fn test_apply_to_chunks_large_chunk_size() {
        const CHUNK_SIZE: usize = 100;
        let input: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert_eq!(results, vec![vec![1, 2, 3, 4, 5, 6, 7, 8]]);
    }

    #[test]
    fn test_apply_to_chunks_large_input() {
        const CHUNK_SIZE: usize = 5;
        let input: Vec<u8> = (1..=20).collect();
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert_eq!(
            results,
            vec![
                vec![1, 2, 3, 4, 5],
                vec![6, 7, 8, 9, 10],
                vec![11, 12, 13, 14, 15],
                vec![16, 17, 18, 19, 20]
            ]
        );
    }

    #[test]
    fn test_reverse_slice_index_bits_random() {
        let lengths = [32, 128, 1 << 16];
        let mut rng = SmallRng::seed_from_u64(1);
        for _ in 0..32 {
            for &length in &lengths {
                let mut rand_list: Vec<u32> = Vec::with_capacity(length);
                rand_list.resize_with(length, || rng.random());
                let expect = reverse_index_bits_naive(&rand_list);

                let mut actual = rand_list.clone();
                reverse_slice_index_bits(&mut actual);

                assert_eq!(actual, expect);
            }
        }
    }

    #[test]
    fn test_log2_strict_usize_edge_cases() {
        assert_eq!(log2_strict_usize(1), 0);
        assert_eq!(log2_strict_usize(2), 1);
        assert_eq!(log2_strict_usize(1 << 18), 18);
        assert_eq!(log2_strict_usize(1 << 31), 31);
        assert_eq!(
            log2_strict_usize(1 << (usize::BITS - 1)),
            usize::BITS as usize - 1
        );
    }

    #[test]
    #[should_panic]
    fn test_log2_strict_usize_zero() {
        let _ = log2_strict_usize(0);
    }

    #[test]
    #[should_panic]
    fn test_log2_strict_usize_nonpower_2() {
        let _ = log2_strict_usize(0x78c341c65ae6d262);
    }

    #[test]
    #[should_panic]
    fn test_log2_strict_usize_max() {
        let _ = log2_strict_usize(usize::MAX);
    }

    #[test]
    fn test_log2_ceil_usize_comprehensive() {
        // Powers of 2
        assert_eq!(log2_ceil_usize(0), 0);
        assert_eq!(log2_ceil_usize(1), 0);
        assert_eq!(log2_ceil_usize(2), 1);
        assert_eq!(log2_ceil_usize(1 << 18), 18);
        assert_eq!(log2_ceil_usize(1 << 31), 31);
        assert_eq!(
            log2_ceil_usize(1 << (usize::BITS - 1)),
            usize::BITS as usize - 1
        );

        // Nonpowers; want to round up
        assert_eq!(log2_ceil_usize(3), 2);
        assert_eq!(log2_ceil_usize(0x14fe901b), 29);
        assert_eq!(
            log2_ceil_usize((1 << (usize::BITS - 1)) + 1),
            usize::BITS as usize
        );
        assert_eq!(log2_ceil_usize(usize::MAX - 1), usize::BITS as usize);
        assert_eq!(log2_ceil_usize(usize::MAX), usize::BITS as usize);
    }

    fn reverse_index_bits_naive<T: Copy>(arr: &[T]) -> Vec<T> {
        let n = arr.len();
        let n_power = log2_strict_usize(n);

        let mut out = vec![None; n];
        for (i, v) in arr.iter().enumerate() {
            let dst = i.reverse_bits() >> (usize::BITS - n_power as u32);
            out[dst] = Some(*v);
        }

        out.into_iter().map(|x| x.unwrap()).collect()
    }

    #[test]
    fn test_relatively_prime_u64() {
        // Zero cases (should always return false)
        assert!(!relatively_prime_u64(0, 0));
        assert!(!relatively_prime_u64(10, 0));
        assert!(!relatively_prime_u64(0, 10));
        assert!(!relatively_prime_u64(0, 123456789));

        // Number with itself (if greater than 1, not relatively prime)
        assert!(relatively_prime_u64(1, 1));
        assert!(!relatively_prime_u64(10, 10));
        assert!(!relatively_prime_u64(99999, 99999));

        // Powers of 2 (always false since they share factor 2)
        assert!(!relatively_prime_u64(2, 4));
        assert!(!relatively_prime_u64(16, 32));
        assert!(!relatively_prime_u64(64, 128));
        assert!(!relatively_prime_u64(1024, 4096));
        assert!(!relatively_prime_u64(u64::MAX, u64::MAX));

        // One number is a multiple of the other (always false)
        assert!(!relatively_prime_u64(5, 10));
        assert!(!relatively_prime_u64(12, 36));
        assert!(!relatively_prime_u64(15, 45));
        assert!(!relatively_prime_u64(100, 500));

        // Co-prime numbers (should be true)
        assert!(relatively_prime_u64(17, 31));
        assert!(relatively_prime_u64(97, 43));
        assert!(relatively_prime_u64(7919, 65537));
        assert!(relatively_prime_u64(15485863, 32452843));

        // Small prime numbers (should be true)
        assert!(relatively_prime_u64(13, 17));
        assert!(relatively_prime_u64(101, 103));
        assert!(relatively_prime_u64(1009, 1013));

        // Large numbers (some cases where they are relatively prime or not)
        assert!(!relatively_prime_u64(
            190266297176832000,
            10430732356495263744
        ));
        assert!(!relatively_prime_u64(
            2040134905096275968,
            5701159354248194048
        ));
        assert!(!relatively_prime_u64(
            16611311494648745984,
            7514969329383038976
        ));
        assert!(!relatively_prime_u64(
            14863931409971066880,
            7911906750992527360
        ));

        // Max values
        assert!(relatively_prime_u64(u64::MAX, 1));
        assert!(relatively_prime_u64(u64::MAX, u64::MAX - 1));
        assert!(!relatively_prime_u64(u64::MAX, u64::MAX));
    }
}