// cmov/slice.rs
1//! Trait impls for core slices.
2
3use crate::{Cmov, CmovEq, Condition};
4use core::{
5    cmp,
6    num::{
7        NonZeroI8, NonZeroI16, NonZeroI32, NonZeroI64, NonZeroI128, NonZeroIsize, NonZeroU8,
8        NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU128, NonZeroUsize,
9    },
10    ops::{BitOrAssign, Shl},
11    ptr, slice,
12};
13
// Uses 64-bit words on 64-bit targets, 32-bit everywhere else
//
// The slice impls below coalesce adjacent elements into `Word`-sized chunks so
// the conditional move is performed once per word instead of once per element.
#[cfg(not(target_pointer_width = "64"))]
type Word = u32;
#[cfg(target_pointer_width = "64")]
type Word = u64;
/// Size of [`Word`] in bytes (4 or 8, depending on target).
const WORD_SIZE: usize = size_of::<Word>();
20
/// Assert the lengths of the two slices are equal.
///
/// `$a` is the destination slice's length and `$b` the source slice's length;
/// panics via [`assert_eq!`] with a descriptive message when they differ.
macro_rules! assert_lengths_eq {
    ($a:expr, $b:expr) => {
        assert_eq!(
            $a, $b,
            "source slice length ({}) does not match destination slice length ({})",
            $b, $a
        );
    };
}
31
32//
33// `Cmov` trait impls
34//
35
36// Optimized implementation for byte slices which coalesces them into word-sized chunks first,
37// then performs [`Cmov`] at the word-level to cut down on the total number of instructions.
38impl Cmov for [u8] {
39    #[inline]
40    #[track_caller]
41    fn cmovnz(&mut self, value: &Self, condition: Condition) {
42        assert_lengths_eq!(self.len(), value.len());
43
44        let (dst_chunks, dst_remainder) = slice_as_chunks_mut::<u8, WORD_SIZE>(self);
45        let (src_chunks, src_remainder) = slice_as_chunks::<u8, WORD_SIZE>(value);
46
47        for (dst_chunk, src_chunk) in dst_chunks.iter_mut().zip(src_chunks.iter()) {
48            let mut a = Word::from_ne_bytes(*dst_chunk);
49            let b = Word::from_ne_bytes(*src_chunk);
50            a.cmovnz(&b, condition);
51            dst_chunk.copy_from_slice(&a.to_ne_bytes());
52        }
53
54        cmovnz_remainder(dst_remainder, src_remainder, condition);
55    }
56}
57
// Optimized implementation for slices of `u16` which coalesces them into word-sized chunks first,
// then performs [`Cmov`] at the word-level to cut down on the total number of instructions.
#[cfg(not(target_pointer_width = "64"))]
#[cfg_attr(docsrs, doc(cfg(true)))]
impl Cmov for [u16] {
    #[inline]
    #[track_caller]
    fn cmovnz(&mut self, value: &Self, condition: Condition) {
        assert_lengths_eq!(self.len(), value.len());

        // Two `u16` lanes fit in a single 32-bit `Word` on these targets.
        let (dst_pairs, dst_tail) = slice_as_chunks_mut::<u16, 2>(self);
        let (src_pairs, src_tail) = slice_as_chunks::<u16, 2>(value);

        for (dst, src) in dst_pairs.iter_mut().zip(src_pairs) {
            // Pack lane 0 into the low half and lane 1 into the high half.
            let mut a = Word::from(dst[0]) | (Word::from(dst[1]) << 16);
            let b = Word::from(src[0]) | (Word::from(src[1]) << 16);
            a.cmovnz(&b, condition);
            // Unpack the (possibly replaced) word back into the two lanes.
            *dst = [(a & 0xFFFF) as u16, (a >> 16) as u16];
        }

        // Any odd trailing element is handled separately.
        cmovnz_remainder(dst_tail, src_tail, condition);
    }
}
82
83// Optimized implementation for slices of `u16` which coalesces them into word-sized chunks first,
84// then performs [`Cmov`] at the word-level to cut down on the total number of instructions.
85#[cfg(target_pointer_width = "64")]
86#[cfg_attr(docsrs, doc(cfg(true)))]
87impl Cmov for [u16] {
88    #[inline]
89    #[track_caller]
90    fn cmovnz(&mut self, value: &Self, condition: Condition) {
91        assert_lengths_eq!(self.len(), value.len());
92
93        #[inline(always)]
94        fn u16x4_to_u64(input: &[u16; 4]) -> u64 {
95            Word::from(input[0])
96                | (Word::from(input[1]) << 16)
97                | (Word::from(input[2]) << 32)
98                | (Word::from(input[3]) << 48)
99        }
100
101        let (dst_chunks, dst_remainder) = slice_as_chunks_mut::<u16, 4>(self);
102        let (src_chunks, src_remainder) = slice_as_chunks::<u16, 4>(value);
103
104        for (dst_chunk, src_chunk) in dst_chunks.iter_mut().zip(src_chunks.iter()) {
105            let mut a = u16x4_to_u64(dst_chunk);
106            let b = u16x4_to_u64(src_chunk);
107            a.cmovnz(&b, condition);
108            dst_chunk[0] = (a & 0xFFFF) as u16;
109            dst_chunk[1] = ((a >> 16) & 0xFFFF) as u16;
110            dst_chunk[2] = ((a >> 32) & 0xFFFF) as u16;
111            dst_chunk[3] = ((a >> 48) & 0xFFFF) as u16;
112        }
113
114        cmovnz_remainder(dst_remainder, src_remainder, condition);
115    }
116}
117
/// Implement [`Cmov`] using a simple loop.
///
/// Used for element types large enough that coalescing adjacent elements into
/// `Word`-sized chunks (as the `[u8]`/`[u16]` impls do) isn't worthwhile.
macro_rules! impl_cmov_with_loop {
    ( $($int:ty),+ ) => {
        $(
            impl Cmov for [$int] {
                #[inline]
                #[track_caller]
                fn cmovnz(&mut self, value: &Self, condition: Condition) {
                    assert_lengths_eq!(self.len(), value.len());
                    // One conditional move per element.
                    for (a, b) in self.iter_mut().zip(value.iter()) {
                        a.cmovnz(b, condition);
                    }
                }
            }
        )+
    };
}

// These types are large enough we don't need to use anything more complex than a simple loop
impl_cmov_with_loop!(u32, u64, u128, usize);
138
/// Ensure the two provided types have the same size and alignment.
///
/// Evaluated inside a `const` block, so a mismatch is a compile-time error
/// rather than a runtime panic.
macro_rules! assert_size_and_alignment_eq {
    ($int:ty, $uint:ty) => {
        const {
            assert!(
                size_of::<$int>() == size_of::<$uint>(),
                "integers are of unequal size"
            );

            assert!(
                align_of::<$int>() == align_of::<$uint>(),
                "integers have unequal alignment"
            );
        }
    };
}
155
/// Implement [`Cmov`] and [`CmovEq`] traits by casting to a different type that impls the traits.
macro_rules! impl_cmov_with_cast {
    ( $($src:ty => $dst:ty),+ ) => {
        $(
            impl Cmov for [$src] {
                #[inline]
                #[track_caller]
                #[allow(unsafe_code)]
                fn cmovnz(&mut self, value: &Self, condition: Condition) {
                    // Compile-time guard that the cast below is layout-compatible.
                    assert_size_and_alignment_eq!($src, $dst);

                    // SAFETY:
                    // - Slices being constructed are of same-sized integers as asserted above.
                    // - We source the slice length directly from the other valid slice.
                    let self_unsigned = unsafe { cast_slice_mut::<$src, $dst>(self) };
                    let value_unsigned = unsafe { cast_slice::<$src, $dst>(value) };
                    // Delegate to the target type's (possibly word-coalescing) impl.
                    self_unsigned.cmovnz(value_unsigned, condition);
                }
            }
        )+
    };
}

// These types are all safe to cast between each other
impl_cmov_with_cast!(
    i8 => u8,
    i16 => u16,
    i32 => u32,
    i64 => u64,
    i128 => u128,
    isize => usize,
    NonZeroI8 => i8,
    NonZeroI16 => i16,
    NonZeroI32 => i32,
    NonZeroI64 => i64,
    NonZeroI128 => i128,
    NonZeroIsize => isize,
    NonZeroU8 => u8,
    NonZeroU16 => u16,
    NonZeroU32 => u32,
    NonZeroU64 => u64,
    NonZeroU128 => u128,
    NonZeroUsize => usize,
    cmp::Ordering => i8 // #[repr(i8)]
);
201
202//
203// `CmovEq` impls
204//
205
206// Optimized implementation for byte slices which coalesces them into word-sized chunks first,
207// then performs [`CmovEq`] at the word-level to cut down on the total number of instructions.
208impl CmovEq for [u8] {
209    #[inline]
210    fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
211        // Short-circuit the comparison if the slices are of different lengths, and set the output
212        // condition to the input condition.
213        if self.len() != rhs.len() {
214            *output = input;
215            return;
216        }
217
218        let (self_chunks, self_remainder) = slice_as_chunks::<u8, WORD_SIZE>(self);
219        let (rhs_chunks, rhs_remainder) = slice_as_chunks::<u8, WORD_SIZE>(rhs);
220
221        for (self_chunk, rhs_chunk) in self_chunks.iter().zip(rhs_chunks.iter()) {
222            let a = Word::from_ne_bytes(*self_chunk);
223            let b = Word::from_ne_bytes(*rhs_chunk);
224            a.cmovne(&b, input, output);
225        }
226
227        cmovne_remainder(self_remainder, rhs_remainder, input, output);
228    }
229}
230
/// Implement [`CmovEq`] using a simple loop.
macro_rules! impl_cmoveq_with_loop {
    ( $($int:ty),+ ) => {
        $(
            impl CmovEq for [$int] {
                #[inline]
                fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
                    // Short-circuit the comparison if the slices are of different lengths, and set the output
                    // condition to the input condition.
                    if self.len() != rhs.len() {
                        *output = input;
                        return;
                    }

                    // Compare element-by-element, folding results into `output`.
                    for (a, b) in self.iter().zip(rhs.iter()) {
                        a.cmovne(b, input, output);
                    }
                }
            }
        )+
    };
}

// TODO(tarcieri): investigate word-coalescing impls
impl_cmoveq_with_loop!(u16, u32, u64, u128, usize);
256
/// Implement [`CmovEq`] traits by casting to a different type that impls the traits.
macro_rules! impl_cmoveq_with_cast {
    ( $($src:ty => $dst:ty),+ ) => {
        $(
            impl CmovEq for [$src] {
                #[inline]
                #[allow(unsafe_code)]
                fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
                    // Compile-time guard that the cast below is layout-compatible.
                    assert_size_and_alignment_eq!($src, $dst);

                    // SAFETY:
                    // - Slices being constructed are of same-sized types as asserted above.
                    // - We source the slice length directly from the other valid slice.
                    let self_unsigned = unsafe { cast_slice::<$src, $dst>(self) };
                    let rhs_unsigned = unsafe { cast_slice::<$src, $dst>(rhs) };
                    // Delegate to the target type's (possibly word-coalescing) impl.
                    self_unsigned.cmovne(rhs_unsigned, input, output);
                }
            }
        )+
    };
}

// These types are all safe to cast between each other
impl_cmoveq_with_cast!(
    i8 => u8,
    i16 => u16,
    i32 => u32,
    i64 => u64,
    i128 => u128,
    isize => usize,
    NonZeroI8 => i8,
    NonZeroI16 => i16,
    NonZeroI32 => i32,
    NonZeroI64 => i64,
    NonZeroI128 => i128,
    NonZeroIsize => isize,
    NonZeroU8 => u8,
    NonZeroU16 => u16,
    NonZeroU32 => u32,
    NonZeroU64 => u64,
    NonZeroU128 => u128,
    NonZeroUsize => usize,
    cmp::Ordering => i8 // #[repr(i8)]
);
301
302//
303// Helper functions
304//
305
/// Performs an unsafe pointer cast from one slice type to the other.
///
/// # Compile-time panics
/// - If `T` and `U` differ in size
/// - If `T` and `U` differ in alignment
unsafe fn cast_slice<T, U>(slice: &[T]) -> &[U] {
    const {
        assert!(size_of::<T>() == size_of::<U>(), "T/U size differs");
        assert!(align_of::<T>() == align_of::<U>(), "T/U alignment differs");
    }

    // SAFETY:
    // - `T` and `U` have identical size and alignment as asserted above, so the
    //   same data pointer and element count describe a valid `[U]`.
    // - It's up to the caller to ensure the pointer cast from `T` to `U` itself is valid.
    #[allow(unsafe_code)]
    unsafe {
        slice::from_raw_parts(slice.as_ptr().cast::<U>(), slice.len())
    }
}
325
/// Performs an unsafe pointer cast from one mutable slice type to the other.
///
/// # Compile-time panics
/// - If `T` and `U` differ in size
/// - If `T` and `U` differ in alignment
unsafe fn cast_slice_mut<T, U>(slice: &mut [T]) -> &mut [U] {
    const {
        assert!(size_of::<T>() == size_of::<U>(), "T/U size differs");
        assert!(align_of::<T>() == align_of::<U>(), "T/U alignment differs");
    }

    // SAFETY:
    // - `T` and `U` have identical size and alignment as asserted above, so the
    //   same data pointer and element count describe a valid `[U]`.
    // - It's up to the caller to ensure the pointer cast from `T` to `U` itself is valid.
    #[allow(unsafe_code)]
    unsafe {
        let len = slice.len();
        slice::from_raw_parts_mut(slice.as_mut_ptr().cast::<U>(), len)
    }
}
345
346/// Compare the two remainder slices by loading a `Word` then performing `cmovne`.
347#[inline]
348fn cmovne_remainder<T>(
349    a_remainder: &[T],
350    b_remainder: &[T],
351    input: Condition,
352    output: &mut Condition,
353) where
354    T: Copy,
355    Word: From<T>,
356{
357    let a = slice_to_word(a_remainder);
358    let b = slice_to_word(b_remainder);
359    a.cmovne(&b, input, output);
360}
361
362/// Load the remainder from chunking the slice into a single `Word`, perform `cmovnz`, then write
363/// the result back out to `dst_remainder`.
364#[inline]
365fn cmovnz_remainder<T>(dst_remainder: &mut [T], src_remainder: &[T], condition: Condition)
366where
367    T: BitOrAssign + Copy + From<u8> + Shl<usize, Output = T>,
368    Word: From<T>,
369{
370    let mut remainder = slice_to_word(dst_remainder);
371    remainder.cmovnz(&slice_to_word(src_remainder), condition);
372    word_to_slice(remainder, dst_remainder);
373}
374
375/// Create a [`Word`] from the given input slice.
376#[inline]
377fn slice_to_word<T>(slice: &[T]) -> Word
378where
379    T: Copy,
380    Word: From<T>,
381{
382    debug_assert!(size_of_val(slice) <= WORD_SIZE, "slice too large");
383    slice.iter().rev().copied().fold(0, |acc, n| {
384        (acc << (const { size_of::<T>() * 8 })) | Word::from(n)
385    })
386}
387
388/// Serialize [`Word`] as bytes using the same byte ordering as `slice_to_word`.
389#[inline]
390#[allow(clippy::arithmetic_side_effects)]
391fn word_to_slice<T>(word: Word, out: &mut [T])
392where
393    T: BitOrAssign + Copy + From<u8> + Shl<usize, Output = T>,
394{
395    debug_assert!(size_of::<T>() > 0, "can't be used with ZSTs");
396    debug_assert!(out.len() <= WORD_SIZE, "slice too large");
397
398    let bytes = word.to_le_bytes();
399    for (o, chunk) in out.iter_mut().zip(bytes.chunks(size_of::<T>())) {
400        *o = T::from(0u8);
401        for (i, &byte) in chunk.iter().enumerate() {
402            *o |= T::from(byte) << (i * 8);
403        }
404    }
405}
406
407//
408// Vendored `core` functions to allow a 1.85 MSRV
409//
410
411/// Rust core `[T]::as_chunks` vendored because of its 1.88 MSRV.
412/// TODO(tarcieri): use upstream function when we bump MSRV
413#[inline]
414#[track_caller]
415#[must_use]
416#[allow(
417    clippy::arithmetic_side_effects,
418    clippy::integer_division_remainder_used
419)]
420fn slice_as_chunks<T, const N: usize>(slice: &[T]) -> (&[[T; N]], &[T]) {
421    assert!(N != 0, "chunk size must be non-zero");
422    let len_rounded_down = slice.len() / N * N;
423    // SAFETY: The rounded-down value is always the same or smaller than the
424    // original length, and thus must be in-bounds of the slice.
425    let (multiple_of_n, remainder) = unsafe { slice.split_at_unchecked(len_rounded_down) };
426    // SAFETY: We already panicked for zero, and ensured by construction
427    // that the length of the subslice is a multiple of N.
428    let array_slice = unsafe { slice_as_chunks_unchecked(multiple_of_n) };
429    (array_slice, remainder)
430}
431
432/// Rust core `[T]::as_chunks_mut` vendored because of its 1.88 MSRV.
433/// TODO(tarcieri): use upstream function when we bump MSRV
434#[inline]
435#[track_caller]
436#[must_use]
437#[allow(
438    clippy::arithmetic_side_effects,
439    clippy::integer_division_remainder_used
440)]
441fn slice_as_chunks_mut<T, const N: usize>(slice: &mut [T]) -> (&mut [[T; N]], &mut [T]) {
442    assert!(N != 0, "chunk size must be non-zero");
443    let len_rounded_down = slice.len() / N * N;
444    // SAFETY: The rounded-down value is always the same or smaller than the
445    // original length, and thus must be in-bounds of the slice.
446    let (multiple_of_n, remainder) = unsafe { slice.split_at_mut_unchecked(len_rounded_down) };
447    // SAFETY: We already panicked for zero, and ensured by construction
448    // that the length of the subslice is a multiple of N.
449    let array_slice = unsafe { slice_as_chunks_unchecked_mut(multiple_of_n) };
450    (array_slice, remainder)
451}
452
/// Rust core `[T]::as_chunks_unchecked` vendored because of its 1.88 MSRV.
/// TODO(tarcieri): use upstream function when we bump MSRV
#[inline]
#[must_use]
#[track_caller]
#[allow(
    clippy::arithmetic_side_effects,
    clippy::integer_division_remainder_used
)]
unsafe fn slice_as_chunks_unchecked<T, const N: usize>(slice: &[T]) -> &[[T; N]] {
    // Caller must guarantee that `N` is nonzero and exactly divides the slice length
    const { debug_assert!(N != 0) };
    debug_assert_eq!(slice.len() % N, 0);
    let chunk_count = slice.len() / N;

    // SAFETY: The `chunk_count * N` elements of the input are reinterpreted as
    // `chunk_count` consecutive `[T; N]` arrays: same memory, same total length.
    unsafe { slice::from_raw_parts(slice.as_ptr().cast::<[T; N]>(), chunk_count) }
}
472
/// Rust core `[T]::as_chunks_unchecked_mut` vendored because of its 1.88 MSRV.
/// TODO(tarcieri): use upstream function when we bump MSRV
#[inline]
#[must_use]
#[track_caller]
#[allow(
    clippy::arithmetic_side_effects,
    clippy::integer_division_remainder_used
)]
unsafe fn slice_as_chunks_unchecked_mut<T, const N: usize>(slice: &mut [T]) -> &mut [[T; N]] {
    // Caller must guarantee that `N` is nonzero and exactly divides the slice length
    const { debug_assert!(N != 0) };
    debug_assert_eq!(slice.len() % N, 0);
    let chunk_count = slice.len() / N;

    // SAFETY: The `chunk_count * N` elements of the input are reinterpreted as
    // `chunk_count` consecutive `[T; N]` arrays: same memory, same total length.
    unsafe { slice::from_raw_parts_mut(slice.as_mut_ptr().cast::<[T; N]>(), chunk_count) }
}
492
#[cfg(test)]
mod tests {
    #[test]
    fn cmovnz_remainder() {
        // - Test endianness handling on non-64-bit platforms
        // - Test handling of odd length slices on 64-bit platforms
        #[cfg(not(target_pointer_width = "64"))]
        const A_U16: [u16; 2] = [0xAAAA, 0xBBBB];
        #[cfg(target_pointer_width = "64")]
        const A_U16: [u16; 3] = [0xAAAA, 0xBBBB, 0xCCCC];

        #[cfg(not(target_pointer_width = "64"))]
        const B_U16: [u16; 2] = [0x10, 0xFFFF];
        #[cfg(target_pointer_width = "64")]
        const B_U16: [u16; 3] = [0x10, 0x10, 0xFFFF];

        let mut out = A_U16;

        // Zero condition: destination must be left untouched.
        super::cmovnz_remainder(&mut out, &B_U16, 0);
        assert_eq!(A_U16, out);

        // Non-zero condition: destination must be replaced by the source.
        super::cmovnz_remainder(&mut out, &B_U16, 1);
        assert_eq!(B_U16, out);
    }

    #[test]
    fn slice_to_word() {
        // Element 0 lands in the low bits of the resulting word.
        assert_eq!(0xAABBCC, super::slice_to_word(&[0xCCu8, 0xBB, 0xAA]));
        assert_eq!(0xAAAABBBB, super::slice_to_word(&[0xBBBBu16, 0xAAAA]));

        // Three u16 values only fit in a `Word` on 64-bit targets.
        #[cfg(target_pointer_width = "64")]
        assert_eq!(
            0xAAAABBBBCCCC,
            super::slice_to_word(&[0xCCCCu16, 0xBBBB, 0xAAAA])
        );
    }

    #[test]
    fn word_to_slice() {
        // Inverse of `slice_to_word`: low bits populate element 0.
        let mut out = [0u8; 3];
        super::word_to_slice(0xAABBCC, &mut out);
        assert_eq!(&[0xCC, 0xBB, 0xAA], &out);

        let mut out = [0u16; 2];
        super::word_to_slice(0xAAAABBBB, &mut out);
        assert_eq!(&[0xBBBB, 0xAAAA], &out);

        // Three u16 values only fit in a `Word` on 64-bit targets.
        #[cfg(target_pointer_width = "64")]
        {
            let mut out = [0u16; 3];
            super::word_to_slice(0xAAAABBBBCCCC, &mut out);
            assert_eq!(&[0xCCCC, 0xBBBB, 0xAAAA], &out);
        }
    }
}