// cmov/slice.rs
1//! Trait impls for core slices.
2
3use crate::{Cmov, CmovEq, Condition};
4use core::{
5    cmp,
6    num::{
7        NonZeroI8, NonZeroI16, NonZeroI32, NonZeroI64, NonZeroI128, NonZeroU8, NonZeroU16,
8        NonZeroU32, NonZeroU64, NonZeroU128,
9    },
10    ops::{BitOrAssign, Shl},
11    ptr, slice,
12};
13
// Uses 64-bit words on 64-bit targets, 32-bit everywhere else.
// Slices are coalesced into `Word`-sized chunks below so conditional moves can
// operate per-word instead of per-element.
#[cfg(not(target_pointer_width = "64"))]
type Word = u32;
#[cfg(target_pointer_width = "64")]
type Word = u64;
/// Size of a [`Word`] in bytes.
const WORD_SIZE: usize = size_of::<Word>();
20
21/// Assert the lengths of the two slices are equal.
/// Assert the lengths of the two slices are equal.
///
/// Called as `assert_lengths_eq!(dst.len(), src.len())`: `$a` is the
/// destination length and `$b` the source length, printed in
/// source-then-destination order to match the message text.
macro_rules! assert_lengths_eq {
    ($a:expr, $b:expr) => {
        assert_eq!(
            $a, $b,
            "source slice length ({}) does not match destination slice length ({})",
            $b, $a
        );
    };
}
31
32//
33// `Cmov` trait impls
34//
35
36// Optimized implementation for byte slices which coalesces them into word-sized chunks first,
37// then performs [`Cmov`] at the word-level to cut down on the total number of instructions.
38impl Cmov for [u8] {
39    #[inline]
40    #[track_caller]
41    fn cmovnz(&mut self, value: &Self, condition: Condition) {
42        assert_lengths_eq!(self.len(), value.len());
43
44        let (dst_chunks, dst_remainder) = slice_as_chunks_mut::<u8, WORD_SIZE>(self);
45        let (src_chunks, src_remainder) = slice_as_chunks::<u8, WORD_SIZE>(value);
46
47        for (dst_chunk, src_chunk) in dst_chunks.iter_mut().zip(src_chunks.iter()) {
48            let mut a = Word::from_ne_bytes(*dst_chunk);
49            let b = Word::from_ne_bytes(*src_chunk);
50            a.cmovnz(&b, condition);
51            dst_chunk.copy_from_slice(&a.to_ne_bytes());
52        }
53
54        cmovnz_remainder(dst_remainder, src_remainder, condition);
55    }
56}
57
// Optimized implementation for slices of `u16` which coalesces them into word-sized chunks first,
// then performs [`Cmov`] at the word-level to cut down on the total number of instructions.
//
// On non-64-bit targets a `Word` (`u32`) holds exactly two `u16` values.
#[cfg(not(target_pointer_width = "64"))]
#[cfg_attr(docsrs, doc(cfg(true)))]
impl Cmov for [u16] {
    #[inline]
    #[track_caller]
    fn cmovnz(&mut self, value: &Self, condition: Condition) {
        assert_lengths_eq!(self.len(), value.len());

        let (dst_chunks, dst_remainder) = slice_as_chunks_mut::<u16, 2>(self);
        let (src_chunks, src_remainder) = slice_as_chunks::<u16, 2>(value);

        for (dst_chunk, src_chunk) in dst_chunks.iter_mut().zip(src_chunks.iter()) {
            // Pack two `u16`s into one `Word` (element 0 in the low half),
            // conditionally move at word granularity, then unpack.
            let mut a = Word::from(dst_chunk[0]) | (Word::from(dst_chunk[1]) << 16);
            let b = Word::from(src_chunk[0]) | (Word::from(src_chunk[1]) << 16);
            a.cmovnz(&b, condition);
            dst_chunk[0] = (a & 0xFFFF) as u16;
            dst_chunk[1] = (a >> 16) as u16;
        }

        // Handle a trailing odd element (if any) via a partial-word load/store.
        cmovnz_remainder(dst_remainder, src_remainder, condition);
    }
}
82
83// Optimized implementation for slices of `u16` which coalesces them into word-sized chunks first,
84// then performs [`Cmov`] at the word-level to cut down on the total number of instructions.
85#[cfg(target_pointer_width = "64")]
86#[cfg_attr(docsrs, doc(cfg(true)))]
87impl Cmov for [u16] {
88    #[inline]
89    #[track_caller]
90    fn cmovnz(&mut self, value: &Self, condition: Condition) {
91        assert_lengths_eq!(self.len(), value.len());
92
93        #[inline(always)]
94        fn u16x4_to_u64(input: &[u16; 4]) -> u64 {
95            Word::from(input[0])
96                | (Word::from(input[1]) << 16)
97                | (Word::from(input[2]) << 32)
98                | (Word::from(input[3]) << 48)
99        }
100
101        let (dst_chunks, dst_remainder) = slice_as_chunks_mut::<u16, 4>(self);
102        let (src_chunks, src_remainder) = slice_as_chunks::<u16, 4>(value);
103
104        for (dst_chunk, src_chunk) in dst_chunks.iter_mut().zip(src_chunks.iter()) {
105            let mut a = u16x4_to_u64(dst_chunk);
106            let b = u16x4_to_u64(src_chunk);
107            a.cmovnz(&b, condition);
108            dst_chunk[0] = (a & 0xFFFF) as u16;
109            dst_chunk[1] = ((a >> 16) & 0xFFFF) as u16;
110            dst_chunk[2] = ((a >> 32) & 0xFFFF) as u16;
111            dst_chunk[3] = ((a >> 48) & 0xFFFF) as u16;
112        }
113
114        cmovnz_remainder(dst_remainder, src_remainder, condition);
115    }
116}
117
/// Implement [`Cmov`] using a simple loop.
///
/// Emits an `impl Cmov for [$int]` whose `cmovnz` asserts equal lengths and
/// then conditionally moves each element pair individually.
macro_rules! impl_cmov_with_loop {
    ( $($int:ty),+ ) => {
        $(
            impl Cmov for [$int] {
                #[inline]
                #[track_caller]
                fn cmovnz(&mut self, value: &Self, condition: Condition) {
                    assert_lengths_eq!(self.len(), value.len());
                    for (a, b) in self.iter_mut().zip(value.iter()) {
                        a.cmovnz(b, condition);
                    }
                }
            }
        )+
    };
}
135
// These types are large enough we don't need to use anything more complex than a simple loop:
// each element's own `cmovnz` already operates on at least 32 bits at a time.
impl_cmov_with_loop!(u32, u64, u128);
138
/// Ensure the two provided types have the same size and alignment.
///
/// Expands to a `const` block, so a mismatch fails at compile time rather
/// than at runtime.
macro_rules! assert_size_and_alignment_eq {
    ($int:ty, $uint:ty) => {
        const {
            assert!(
                size_of::<$int>() == size_of::<$uint>(),
                "integers are of unequal size"
            );

            assert!(
                align_of::<$int>() == align_of::<$uint>(),
                "integers have unequal alignment"
            );
        }
    };
}
155
/// Implement the [`Cmov`] trait by casting to a different type that impls the trait.
///
/// NOTE(review): this macro emits only a [`Cmov`] impl; the matching [`CmovEq`]
/// impls come from `impl_cmoveq_with_cast!`.
macro_rules! impl_cmov_with_cast {
    ( $($src:ty => $dst:ty),+ ) => {
        $(
            impl Cmov for [$src] {
                #[inline]
                #[track_caller]
                #[allow(unsafe_code)]
                fn cmovnz(&mut self, value: &Self, condition: Condition) {
                    assert_size_and_alignment_eq!($src, $dst);

                    // SAFETY:
                    // - Slices being constructed are of same-sized integers as asserted above.
                    // - We source the slice length directly from the other valid slice.
                    let self_unsigned = unsafe { cast_slice_mut::<$src, $dst>(self) };
                    let value_unsigned = unsafe { cast_slice::<$src, $dst>(value) };
                    self_unsigned.cmovnz(value_unsigned, condition);
                }
            }
        )+
    };
}
178
// These types are all safe to cast between each other: each source type has the same
// size/alignment as its cast target, and every source bit pattern is a valid target value.
impl_cmov_with_cast!(
    i8 => u8,
    i16 => u16,
    i32 => u32,
    i64 => u64,
    i128 => u128,
    NonZeroI8 => i8,
    NonZeroI16 => i16,
    NonZeroI32 => i32,
    NonZeroI64 => i64,
    NonZeroI128 => i128,
    NonZeroU8 => u8,
    NonZeroU16 => u16,
    NonZeroU32 => u32,
    NonZeroU64 => u64,
    NonZeroU128 => u128,
    cmp::Ordering => i8 // #[repr(i8)]
);
198
199//
200// `CmovEq` impls
201//
202
203// Optimized implementation for byte slices which coalesces them into word-sized chunks first,
204// then performs [`CmovEq`] at the word-level to cut down on the total number of instructions.
205impl CmovEq for [u8] {
206    #[inline]
207    fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
208        // Short-circuit the comparison if the slices are of different lengths, and set the output
209        // condition to the input condition.
210        if self.len() != rhs.len() {
211            *output = input;
212            return;
213        }
214
215        let (self_chunks, self_remainder) = slice_as_chunks::<u8, WORD_SIZE>(self);
216        let (rhs_chunks, rhs_remainder) = slice_as_chunks::<u8, WORD_SIZE>(rhs);
217
218        for (self_chunk, rhs_chunk) in self_chunks.iter().zip(rhs_chunks.iter()) {
219            let a = Word::from_ne_bytes(*self_chunk);
220            let b = Word::from_ne_bytes(*rhs_chunk);
221            a.cmovne(&b, input, output);
222        }
223
224        cmovne_remainder(self_remainder, rhs_remainder, input, output);
225    }
226}
227
/// Implement [`CmovEq`] using a simple loop.
///
/// Emits an `impl CmovEq for [$int]` whose `cmovne` first short-circuits on a
/// length mismatch, then compares each element pair individually.
macro_rules! impl_cmoveq_with_loop {
    ( $($int:ty),+ ) => {
        $(
            impl CmovEq for [$int] {
                #[inline]
                fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
                    // Short-circuit the comparison if the slices are of different lengths, and set the output
                    // condition to the input condition.
                    if self.len() != rhs.len() {
                        *output = input;
                        return;
                    }

                    for (a, b) in self.iter().zip(rhs.iter()) {
                        a.cmovne(b, input, output);
                    }
                }
            }
        )+
    };
}
250
// TODO(tarcieri): investigate word-coalescing impls
// (note `u16` is loop-based here, unlike its word-coalesced `Cmov` counterpart above)
impl_cmoveq_with_loop!(u16, u32, u64, u128);
253
/// Implement the [`CmovEq`] trait by casting to a different type that impls the trait.
macro_rules! impl_cmoveq_with_cast {
    ( $($src:ty => $dst:ty),+ ) => {
        $(
            impl CmovEq for [$src] {
                #[inline]
                #[allow(unsafe_code)]
                fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
                    assert_size_and_alignment_eq!($src, $dst);

                    // SAFETY:
                    // - Slices being constructed are of same-sized types as asserted above.
                    // - We source the slice length directly from the other valid slice.
                    let self_unsigned = unsafe { cast_slice::<$src, $dst>(self) };
                    let rhs_unsigned = unsafe { cast_slice::<$src, $dst>(rhs) };
                    self_unsigned.cmovne(rhs_unsigned, input, output);
                }
            }
        )+
    };
}
275
// These types are all safe to cast between each other: each source type has the same
// size/alignment as its cast target, and every source bit pattern is a valid target value.
impl_cmoveq_with_cast!(
    i8 => u8,
    i16 => u16,
    i32 => u32,
    i64 => u64,
    i128 => u128,
    NonZeroI8 => i8,
    NonZeroI16 => i16,
    NonZeroI32 => i32,
    NonZeroI64 => i64,
    NonZeroI128 => i128,
    NonZeroU8 => u8,
    NonZeroU16 => u16,
    NonZeroU32 => u32,
    NonZeroU64 => u64,
    NonZeroU128 => u128,
    cmp::Ordering => i8 // #[repr(i8)]
);
295
296//
297// Helper functions
298//
299
/// Performs an unsafe pointer cast from one slice type to the other.
///
/// # Compile-time panics
/// - If `T` and `U` differ in size
/// - If `T` and `U` differ in alignment
unsafe fn cast_slice<T, U>(slice: &[T]) -> &[U] {
    const {
        assert!(size_of::<T>() == size_of::<U>(), "T/U size differs");
        assert!(align_of::<T>() == align_of::<U>(), "T/U alignment differs");
    }

    // SAFETY:
    // - `T` and `U` have identical size/alignment per the const assertions above,
    //   so the element count and total allocation size carry over unchanged.
    // - It's up to the caller to ensure the pointer cast from `T` to `U` itself is valid.
    #[allow(unsafe_code)]
    unsafe {
        slice::from_raw_parts(slice.as_ptr().cast::<U>(), slice.len())
    }
}
319
/// Performs an unsafe pointer cast from one mutable slice type to the other.
///
/// # Compile-time panics
/// - If `T` and `U` differ in size
/// - If `T` and `U` differ in alignment
unsafe fn cast_slice_mut<T, U>(slice: &mut [T]) -> &mut [U] {
    const {
        assert!(size_of::<T>() == size_of::<U>(), "T/U size differs");
        assert!(align_of::<T>() == align_of::<U>(), "T/U alignment differs");
    }

    // Capture the length before reborrowing as a raw pointer.
    let len = slice.len();

    // SAFETY:
    // - `T` and `U` have identical size/alignment per the const assertions above,
    //   so the element count and total allocation size carry over unchanged.
    // - It's up to the caller to ensure the pointer cast from `T` to `U` itself is valid.
    #[allow(unsafe_code)]
    unsafe {
        slice::from_raw_parts_mut(slice.as_mut_ptr().cast::<U>(), len)
    }
}
339
340/// Compare the two remainder slices by loading a `Word` then performing `cmovne`.
341#[inline]
342fn cmovne_remainder<T>(
343    a_remainder: &[T],
344    b_remainder: &[T],
345    input: Condition,
346    output: &mut Condition,
347) where
348    T: Copy,
349    Word: From<T>,
350{
351    let a = slice_to_word(a_remainder);
352    let b = slice_to_word(b_remainder);
353    a.cmovne(&b, input, output);
354}
355
356/// Load the remainder from chunking the slice into a single `Word`, perform `cmovnz`, then write
357/// the result back out to `dst_remainder`.
358#[inline]
359fn cmovnz_remainder<T>(dst_remainder: &mut [T], src_remainder: &[T], condition: Condition)
360where
361    T: BitOrAssign + Copy + From<u8> + Shl<usize, Output = T>,
362    Word: From<T>,
363{
364    let mut remainder = slice_to_word(dst_remainder);
365    remainder.cmovnz(&slice_to_word(src_remainder), condition);
366    word_to_slice(remainder, dst_remainder);
367}
368
369/// Create a [`Word`] from the given input slice.
370#[inline]
371fn slice_to_word<T>(slice: &[T]) -> Word
372where
373    T: Copy,
374    Word: From<T>,
375{
376    debug_assert!(size_of_val(slice) <= WORD_SIZE, "slice too large");
377    slice
378        .iter()
379        .rev()
380        .copied()
381        .fold(0, |acc, n| (acc << (size_of::<T>() * 8)) | Word::from(n))
382}
383
384/// Serialize [`Word`] as bytes using the same byte ordering as `slice_to_word`.
385#[inline]
386fn word_to_slice<T>(word: Word, out: &mut [T])
387where
388    T: BitOrAssign + Copy + From<u8> + Shl<usize, Output = T>,
389{
390    debug_assert!(size_of::<T>() > 0, "can't be used with ZSTs");
391    debug_assert!(out.len() <= WORD_SIZE, "slice too large");
392
393    let bytes = word.to_le_bytes();
394    for (o, chunk) in out.iter_mut().zip(bytes.chunks(size_of::<T>())) {
395        *o = T::from(0u8);
396        for (i, &byte) in chunk.iter().enumerate() {
397            *o |= T::from(byte) << (i * 8);
398        }
399    }
400}
401
402//
403// Vendored `core` functions to allow a 1.85 MSRV
404//
405
406/// Rust core `[T]::as_chunks` vendored because of its 1.88 MSRV.
407/// TODO(tarcieri): use upstream function when we bump MSRV
408#[inline]
409#[track_caller]
410#[must_use]
411#[allow(clippy::integer_division_remainder_used)]
412fn slice_as_chunks<T, const N: usize>(slice: &[T]) -> (&[[T; N]], &[T]) {
413    assert!(N != 0, "chunk size must be non-zero");
414    let len_rounded_down = slice.len() / N * N;
415    // SAFETY: The rounded-down value is always the same or smaller than the
416    // original length, and thus must be in-bounds of the slice.
417    let (multiple_of_n, remainder) = unsafe { slice.split_at_unchecked(len_rounded_down) };
418    // SAFETY: We already panicked for zero, and ensured by construction
419    // that the length of the subslice is a multiple of N.
420    let array_slice = unsafe { slice_as_chunks_unchecked(multiple_of_n) };
421    (array_slice, remainder)
422}
423
424/// Rust core `[T]::as_chunks_mut` vendored because of its 1.88 MSRV.
425/// TODO(tarcieri): use upstream function when we bump MSRV
426#[inline]
427#[track_caller]
428#[must_use]
429#[allow(clippy::integer_division_remainder_used)]
430fn slice_as_chunks_mut<T, const N: usize>(slice: &mut [T]) -> (&mut [[T; N]], &mut [T]) {
431    assert!(N != 0, "chunk size must be non-zero");
432    let len_rounded_down = slice.len() / N * N;
433    // SAFETY: The rounded-down value is always the same or smaller than the
434    // original length, and thus must be in-bounds of the slice.
435    let (multiple_of_n, remainder) = unsafe { slice.split_at_mut_unchecked(len_rounded_down) };
436    // SAFETY: We already panicked for zero, and ensured by construction
437    // that the length of the subslice is a multiple of N.
438    let array_slice = unsafe { slice_as_chunks_unchecked_mut(multiple_of_n) };
439    (array_slice, remainder)
440}
441
/// Rust core `[T]::as_chunks_unchecked` vendored because of its 1.88 MSRV.
/// TODO(tarcieri): use upstream function when we bump MSRV
///
/// # Safety
/// `N` must be non-zero and must exactly divide `slice.len()`.
#[inline]
#[must_use]
#[track_caller]
#[allow(clippy::integer_division_remainder_used)]
unsafe fn slice_as_chunks_unchecked<T, const N: usize>(slice: &[T]) -> &[[T; N]] {
    // Debug-only checks of the caller's contract (see `# Safety` above).
    const { debug_assert!(N != 0) };
    debug_assert_eq!(slice.len() % N, 0);
    let chunk_count = slice.len() / N;

    // SAFETY: the input holds exactly `chunk_count` back-to-back `[T; N]`
    // arrays, so reinterpreting the pointer with the divided length is sound.
    unsafe { slice::from_raw_parts(slice.as_ptr().cast(), chunk_count) }
}
458
/// Rust core `[T]::as_chunks_unchecked_mut` vendored because of its 1.88 MSRV.
/// TODO(tarcieri): use upstream function when we bump MSRV
///
/// # Safety
/// `N` must be non-zero and must exactly divide `slice.len()`.
#[inline]
#[must_use]
#[track_caller]
#[allow(clippy::integer_division_remainder_used)]
unsafe fn slice_as_chunks_unchecked_mut<T, const N: usize>(slice: &mut [T]) -> &mut [[T; N]] {
    // Debug-only checks of the caller's contract (see `# Safety` above).
    const { debug_assert!(N != 0) };
    debug_assert_eq!(slice.len() % N, 0);
    let chunk_count = slice.len() / N;

    // SAFETY: the input holds exactly `chunk_count` back-to-back `[T; N]`
    // arrays, so reinterpreting the pointer with the divided length is sound.
    unsafe { slice::from_raw_parts_mut(slice.as_mut_ptr().cast(), chunk_count) }
}
475
#[cfg(test)]
mod tests {
    #[test]
    fn cmovnz_remainder() {
        // - Test endianness handling on non-64-bit platforms
        // - Test handling of odd length slices on 64-bit platforms
        #[cfg(not(target_pointer_width = "64"))]
        const A_U16: [u16; 2] = [0xAAAA, 0xBBBB];
        #[cfg(target_pointer_width = "64")]
        const A_U16: [u16; 3] = [0xAAAA, 0xBBBB, 0xCCCC];

        #[cfg(not(target_pointer_width = "64"))]
        const B_U16: [u16; 2] = [0x10, 0xFFFF];
        #[cfg(target_pointer_width = "64")]
        const B_U16: [u16; 3] = [0x10, 0x10, 0xFFFF];

        let mut out = A_U16;

        // Condition of 0: destination must be left untouched
        super::cmovnz_remainder(&mut out, &B_U16, 0);
        assert_eq!(A_U16, out);

        // Condition of 1 (non-zero): destination must take the source values
        super::cmovnz_remainder(&mut out, &B_U16, 1);
        assert_eq!(B_U16, out);
    }

    #[test]
    fn slice_to_word() {
        // Element 0 always lands in the word's least significant position
        assert_eq!(0xAABBCC, super::slice_to_word(&[0xCCu8, 0xBB, 0xAA]));
        assert_eq!(0xAAAABBBB, super::slice_to_word(&[0xBBBBu16, 0xAAAA]));

        // Three `u16`s only fit in a `Word` on 64-bit targets
        #[cfg(target_pointer_width = "64")]
        assert_eq!(
            0xAAAABBBBCCCC,
            super::slice_to_word(&[0xCCCCu16, 0xBBBB, 0xAAAA])
        );
    }

    #[test]
    fn word_to_slice() {
        // Inverse of the `slice_to_word` layout: low bits -> element 0
        let mut out = [0u8; 3];
        super::word_to_slice(0xAABBCC, &mut out);
        assert_eq!(&[0xCC, 0xBB, 0xAA], &out);

        let mut out = [0u16; 2];
        super::word_to_slice(0xAAAABBBB, &mut out);
        assert_eq!(&[0xBBBB, 0xAAAA], &out);

        #[cfg(target_pointer_width = "64")]
        {
            let mut out = [0u16; 3];
            super::word_to_slice(0xAAAABBBBCCCC, &mut out);
            assert_eq!(&[0xCCCC, 0xBBBB, 0xAAAA], &out);
        }
    }
}