1use crate::{Cmov, CmovEq, Condition};
4use core::{
5 cmp,
6 num::{
7 NonZeroI8, NonZeroI16, NonZeroI32, NonZeroI64, NonZeroI128, NonZeroIsize, NonZeroU8,
8 NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU128, NonZeroUsize,
9 },
10 ops::{BitOrAssign, Shl},
11 ptr, slice,
12};
13
/// Machine word that slice contents are coalesced into before a word-level
/// operation: `u64` on targets with 64-bit pointers.
#[cfg(target_pointer_width = "64")]
type Word = u64;
/// Machine word that slice contents are coalesced into before a word-level
/// operation: `u32` on every target whose pointers are not 64 bits wide.
#[cfg(not(target_pointer_width = "64"))]
type Word = u32;

/// Size of a [`Word`] in bytes.
const WORD_SIZE: usize = size_of::<Word>();
20
/// Asserts that a destination length and a source length are equal.
///
/// The first argument is the destination length, the second the source
/// length. On mismatch the panic message reports the source length first and
/// the destination length second.
macro_rules! assert_lengths_eq {
    ($dst_len:expr, $src_len:expr) => {
        assert_eq!(
            $dst_len, $src_len,
            "source slice length ({}) does not match destination slice length ({})",
            $src_len, $dst_len
        );
    };
}
31
32impl Cmov for [u8] {
39 #[inline]
40 #[track_caller]
41 fn cmovnz(&mut self, value: &Self, condition: Condition) {
42 assert_lengths_eq!(self.len(), value.len());
43
44 let (dst_chunks, dst_remainder) = slice_as_chunks_mut::<u8, WORD_SIZE>(self);
45 let (src_chunks, src_remainder) = slice_as_chunks::<u8, WORD_SIZE>(value);
46
47 for (dst_chunk, src_chunk) in dst_chunks.iter_mut().zip(src_chunks.iter()) {
48 let mut a = Word::from_ne_bytes(*dst_chunk);
49 let b = Word::from_ne_bytes(*src_chunk);
50 a.cmovnz(&b, condition);
51 dst_chunk.copy_from_slice(&a.to_ne_bytes());
52 }
53
54 cmovnz_remainder(dst_remainder, src_remainder, condition);
55 }
56}
57
/// [`Cmov`] for `u16` slices on targets whose pointers are not 64 bits wide.
///
/// Elements are processed in pairs: two `u16` lanes are packed into one
/// [`Word`], passed through the word-level `cmovnz`, and unpacked again. A
/// leftover element is handled by `cmovnz_remainder`.
#[cfg(not(target_pointer_width = "64"))]
#[cfg_attr(docsrs, doc(cfg(true)))]
impl Cmov for [u16] {
    /// Applies `cmovnz` from `value` onto `self` under `condition`.
    ///
    /// # Panics
    /// If `self` and `value` have different lengths.
    #[inline]
    #[track_caller]
    fn cmovnz(&mut self, value: &Self, condition: Condition) {
        /// Packs two lanes into a word, lane 0 in the low 16 bits.
        #[inline(always)]
        fn pack_pair(lanes: &[u16; 2]) -> Word {
            let [low, high] = *lanes;
            Word::from(low) | (Word::from(high) << 16)
        }

        assert_lengths_eq!(self.len(), value.len());

        let (dst_pairs, dst_tail) = slice_as_chunks_mut::<u16, 2>(self);
        let (src_pairs, src_tail) = slice_as_chunks::<u16, 2>(value);

        for (dst_pair, src_pair) in dst_pairs.iter_mut().zip(src_pairs) {
            let mut word = pack_pair(dst_pair);
            word.cmovnz(&pack_pair(src_pair), condition);
            *dst_pair = [(word & 0xFFFF) as u16, (word >> 16) as u16];
        }

        cmovnz_remainder(dst_tail, src_tail, condition);
    }
}
82
83#[cfg(target_pointer_width = "64")]
86#[cfg_attr(docsrs, doc(cfg(true)))]
87impl Cmov for [u16] {
88 #[inline]
89 #[track_caller]
90 fn cmovnz(&mut self, value: &Self, condition: Condition) {
91 assert_lengths_eq!(self.len(), value.len());
92
93 #[inline(always)]
94 fn u16x4_to_u64(input: &[u16; 4]) -> u64 {
95 Word::from(input[0])
96 | (Word::from(input[1]) << 16)
97 | (Word::from(input[2]) << 32)
98 | (Word::from(input[3]) << 48)
99 }
100
101 let (dst_chunks, dst_remainder) = slice_as_chunks_mut::<u16, 4>(self);
102 let (src_chunks, src_remainder) = slice_as_chunks::<u16, 4>(value);
103
104 for (dst_chunk, src_chunk) in dst_chunks.iter_mut().zip(src_chunks.iter()) {
105 let mut a = u16x4_to_u64(dst_chunk);
106 let b = u16x4_to_u64(src_chunk);
107 a.cmovnz(&b, condition);
108 dst_chunk[0] = (a & 0xFFFF) as u16;
109 dst_chunk[1] = ((a >> 16) & 0xFFFF) as u16;
110 dst_chunk[2] = ((a >> 32) & 0xFFFF) as u16;
111 dst_chunk[3] = ((a >> 48) & 0xFFFF) as u16;
112 }
113
114 cmovnz_remainder(dst_remainder, src_remainder, condition);
115 }
116}
117
118macro_rules! impl_cmov_with_loop {
120 ( $($int:ty),+ ) => {
121 $(
122 impl Cmov for [$int] {
123 #[inline]
124 #[track_caller]
125 fn cmovnz(&mut self, value: &Self, condition: Condition) {
126 assert_lengths_eq!(self.len(), value.len());
127 for (a, b) in self.iter_mut().zip(value.iter()) {
128 a.cmovnz(b, condition);
129 }
130 }
131 }
132 )+
133 };
134}
135
136impl_cmov_with_loop!(u32, u64, u128, usize);
138
/// Compile-time check that two types have the same size and the same
/// alignment.
///
/// Expands to an inline `const` block, so a mismatch is reported while
/// compiling the enclosing function rather than at runtime.
macro_rules! assert_size_and_alignment_eq {
    ($lhs:ty, $rhs:ty) => {
        const {
            if size_of::<$lhs>() != size_of::<$rhs>() {
                panic!("integers are of unequal size");
            }

            if align_of::<$lhs>() != align_of::<$rhs>() {
                panic!("integers have unequal alignment");
            }
        }
    };
}
155
156macro_rules! impl_cmov_with_cast {
158 ( $($src:ty => $dst:ty),+ ) => {
159 $(
160 impl Cmov for [$src] {
161 #[inline]
162 #[track_caller]
163 #[allow(unsafe_code)]
164 fn cmovnz(&mut self, value: &Self, condition: Condition) {
165 assert_size_and_alignment_eq!($src, $dst);
166
167 let self_unsigned = unsafe { cast_slice_mut::<$src, $dst>(self) };
171 let value_unsigned = unsafe { cast_slice::<$src, $dst>(value) };
172 self_unsigned.cmovnz(value_unsigned, condition);
173 }
174 }
175 )+
176 };
177}
178
179impl_cmov_with_cast!(
181 i8 => u8,
182 i16 => u16,
183 i32 => u32,
184 i64 => u64,
185 i128 => u128,
186 isize => usize,
187 NonZeroI8 => i8,
188 NonZeroI16 => i16,
189 NonZeroI32 => i32,
190 NonZeroI64 => i64,
191 NonZeroI128 => i128,
192 NonZeroIsize => isize,
193 NonZeroU8 => u8,
194 NonZeroU16 => u16,
195 NonZeroU32 => u32,
196 NonZeroU64 => u64,
197 NonZeroU128 => u128,
198 NonZeroUsize => usize,
199 cmp::Ordering => i8 );
201
202impl CmovEq for [u8] {
209 #[inline]
210 fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
211 if self.len() != rhs.len() {
214 *output = input;
215 return;
216 }
217
218 let (self_chunks, self_remainder) = slice_as_chunks::<u8, WORD_SIZE>(self);
219 let (rhs_chunks, rhs_remainder) = slice_as_chunks::<u8, WORD_SIZE>(rhs);
220
221 for (self_chunk, rhs_chunk) in self_chunks.iter().zip(rhs_chunks.iter()) {
222 let a = Word::from_ne_bytes(*self_chunk);
223 let b = Word::from_ne_bytes(*rhs_chunk);
224 a.cmovne(&b, input, output);
225 }
226
227 cmovne_remainder(self_remainder, rhs_remainder, input, output);
228 }
229}
230
231macro_rules! impl_cmoveq_with_loop {
233 ( $($int:ty),+ ) => {
234 $(
235 impl CmovEq for [$int] {
236 #[inline]
237 fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
238 if self.len() != rhs.len() {
241 *output = input;
242 return;
243 }
244
245 for (a, b) in self.iter().zip(rhs.iter()) {
246 a.cmovne(b, input, output);
247 }
248 }
249 }
250 )+
251 };
252}
253
254impl_cmoveq_with_loop!(u16, u32, u64, u128, usize);
256
257macro_rules! impl_cmoveq_with_cast {
259 ( $($src:ty => $dst:ty),+ ) => {
260 $(
261 impl CmovEq for [$src] {
262 #[inline]
263 #[allow(unsafe_code)]
264 fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
265 assert_size_and_alignment_eq!($src, $dst);
266
267 let self_unsigned = unsafe { cast_slice::<$src, $dst>(self) };
271 let rhs_unsigned = unsafe { cast_slice::<$src, $dst>(rhs) };
272 self_unsigned.cmovne(rhs_unsigned, input, output);
273 }
274 }
275 )+
276 };
277}
278
279impl_cmoveq_with_cast!(
281 i8 => u8,
282 i16 => u16,
283 i32 => u32,
284 i64 => u64,
285 i128 => u128,
286 isize => usize,
287 NonZeroI8 => i8,
288 NonZeroI16 => i16,
289 NonZeroI32 => i32,
290 NonZeroI64 => i64,
291 NonZeroI128 => i128,
292 NonZeroIsize => isize,
293 NonZeroU8 => u8,
294 NonZeroU16 => u16,
295 NonZeroU32 => u32,
296 NonZeroU64 => u64,
297 NonZeroU128 => u128,
298 NonZeroUsize => usize,
299 cmp::Ordering => i8 );
301
/// Reinterprets a shared slice of `T` as a shared slice of `U` with the same
/// number of elements.
///
/// Equality of size and alignment between `T` and `U` is enforced at compile
/// time.
///
/// # Safety
/// Every element of `slice` must also be a valid value of type `U`.
unsafe fn cast_slice<T, U>(slice: &[T]) -> &[U] {
    const {
        assert!(size_of::<T>() == size_of::<U>(), "T/U size differs");
        assert!(align_of::<T>() == align_of::<U>(), "T/U alignment differs");
    }

    let reinterpreted = ptr::slice_from_raw_parts(slice.as_ptr().cast::<U>(), slice.len());

    // SAFETY: the pointer and length come from a live `&[T]`; `U` has the same
    // size and alignment as `T` (checked above), so the same memory holds
    // exactly `slice.len()` properly aligned `U`s. Their validity as `U` is the
    // caller's obligation.
    #[allow(unsafe_code)]
    unsafe {
        &*reinterpreted
    }
}
325
/// Reinterprets a mutable slice of `T` as a mutable slice of `U` with the same
/// number of elements.
///
/// Equality of size and alignment between `T` and `U` is enforced at compile
/// time.
///
/// # Safety
/// Every element of `slice` must also be a valid value of type `U`, and every
/// value written through the returned slice must be a valid `T`.
unsafe fn cast_slice_mut<T, U>(slice: &mut [T]) -> &mut [U] {
    const {
        assert!(size_of::<T>() == size_of::<U>(), "T/U size differs");
        assert!(align_of::<T>() == align_of::<U>(), "T/U alignment differs");
    }

    let len = slice.len();
    let reinterpreted = ptr::slice_from_raw_parts_mut(slice.as_mut_ptr().cast::<U>(), len);

    // SAFETY: the pointer and length come from a live, exclusive `&mut [T]`;
    // `U` has the same size and alignment as `T` (checked above), so the same
    // memory holds exactly `len` properly aligned `U`s. Value validity in both
    // directions is the caller's obligation.
    #[allow(unsafe_code)]
    unsafe {
        &mut *reinterpreted
    }
}
345
346#[inline]
348fn cmovne_remainder<T>(
349 a_remainder: &[T],
350 b_remainder: &[T],
351 input: Condition,
352 output: &mut Condition,
353) where
354 T: Copy,
355 Word: From<T>,
356{
357 let a = slice_to_word(a_remainder);
358 let b = slice_to_word(b_remainder);
359 a.cmovne(&b, input, output);
360}
361
362#[inline]
365fn cmovnz_remainder<T>(dst_remainder: &mut [T], src_remainder: &[T], condition: Condition)
366where
367 T: BitOrAssign + Copy + From<u8> + Shl<usize, Output = T>,
368 Word: From<T>,
369{
370 let mut remainder = slice_to_word(dst_remainder);
371 remainder.cmovnz(&slice_to_word(src_remainder), condition);
372 word_to_slice(remainder, dst_remainder);
373}
374
375#[inline]
377fn slice_to_word<T>(slice: &[T]) -> Word
378where
379 T: Copy,
380 Word: From<T>,
381{
382 debug_assert!(size_of_val(slice) <= WORD_SIZE, "slice too large");
383 slice.iter().rev().copied().fold(0, |acc, n| {
384 (acc << (const { size_of::<T>() * 8 })) | Word::from(n)
385 })
386}
387
388#[inline]
390#[allow(clippy::arithmetic_side_effects)]
391fn word_to_slice<T>(word: Word, out: &mut [T])
392where
393 T: BitOrAssign + Copy + From<u8> + Shl<usize, Output = T>,
394{
395 debug_assert!(size_of::<T>() > 0, "can't be used with ZSTs");
396 debug_assert!(out.len() <= WORD_SIZE, "slice too large");
397
398 let bytes = word.to_le_bytes();
399 for (o, chunk) in out.iter_mut().zip(bytes.chunks(size_of::<T>())) {
400 *o = T::from(0u8);
401 for (i, &byte) in chunk.iter().enumerate() {
402 *o |= T::from(byte) << (i * 8);
403 }
404 }
405}
406
407#[inline]
414#[track_caller]
415#[must_use]
416#[allow(
417 clippy::arithmetic_side_effects,
418 clippy::integer_division_remainder_used
419)]
420fn slice_as_chunks<T, const N: usize>(slice: &[T]) -> (&[[T; N]], &[T]) {
421 assert!(N != 0, "chunk size must be non-zero");
422 let len_rounded_down = slice.len() / N * N;
423 let (multiple_of_n, remainder) = unsafe { slice.split_at_unchecked(len_rounded_down) };
426 let array_slice = unsafe { slice_as_chunks_unchecked(multiple_of_n) };
429 (array_slice, remainder)
430}
431
432#[inline]
435#[track_caller]
436#[must_use]
437#[allow(
438 clippy::arithmetic_side_effects,
439 clippy::integer_division_remainder_used
440)]
441fn slice_as_chunks_mut<T, const N: usize>(slice: &mut [T]) -> (&mut [[T; N]], &mut [T]) {
442 assert!(N != 0, "chunk size must be non-zero");
443 let len_rounded_down = slice.len() / N * N;
444 let (multiple_of_n, remainder) = unsafe { slice.split_at_mut_unchecked(len_rounded_down) };
447 let array_slice = unsafe { slice_as_chunks_unchecked_mut(multiple_of_n) };
450 (array_slice, remainder)
451}
452
/// Reinterprets `slice` as a slice of `N`-element arrays without checking, in
/// release builds, that the length divides evenly.
///
/// In debug builds, asserts that `N` is non-zero and that the length is a
/// multiple of `N`.
///
/// # Safety
/// `N` must be non-zero and `slice.len()` must be an exact multiple of `N`.
#[inline]
#[must_use]
#[track_caller]
#[allow(
    clippy::arithmetic_side_effects,
    clippy::integer_division_remainder_used
)]
unsafe fn slice_as_chunks_unchecked<T, const N: usize>(slice: &[T]) -> &[[T; N]] {
    const { debug_assert!(N != 0) };
    debug_assert_eq!(slice.len() % N, 0);

    let chunk_count = slice.len() / N;
    let first_chunk = slice.as_ptr().cast::<[T; N]>();

    // SAFETY: the caller guarantees the length is an exact multiple of `N`, so
    // `chunk_count` arrays of `N` elements cover exactly the memory of `slice`;
    // `[T; N]` has the same alignment as `T`.
    unsafe { slice::from_raw_parts(first_chunk, chunk_count) }
}
472
/// Reinterprets `slice` as a mutable slice of `N`-element arrays without
/// checking, in release builds, that the length divides evenly.
///
/// In debug builds, asserts that `N` is non-zero and that the length is a
/// multiple of `N`.
///
/// # Safety
/// `N` must be non-zero and `slice.len()` must be an exact multiple of `N`.
#[inline]
#[must_use]
#[track_caller]
#[allow(
    clippy::arithmetic_side_effects,
    clippy::integer_division_remainder_used
)]
unsafe fn slice_as_chunks_unchecked_mut<T, const N: usize>(slice: &mut [T]) -> &mut [[T; N]] {
    const { debug_assert!(N != 0) };
    debug_assert_eq!(slice.len() % N, 0);

    let chunk_count = slice.len() / N;
    let first_chunk = slice.as_mut_ptr().cast::<[T; N]>();

    // SAFETY: the caller guarantees the length is an exact multiple of `N`, so
    // `chunk_count` arrays of `N` elements cover exactly the memory of `slice`;
    // `[T; N]` has the same alignment as `T`, and the exclusive borrow of
    // `slice` is carried over to the returned slice.
    unsafe { slice::from_raw_parts_mut(first_chunk, chunk_count) }
}
492
#[cfg(test)]
mod tests {
    /// A zero condition leaves the destination untouched; a condition of one
    /// replaces it with the source. The inputs are as long as the largest
    /// `u16` remainder used on the target.
    #[test]
    fn cmovnz_remainder() {
        #[cfg(target_pointer_width = "64")]
        let (original, replacement): ([u16; 3], [u16; 3]) =
            ([0xAAAA, 0xBBBB, 0xCCCC], [0x10, 0x10, 0xFFFF]);
        #[cfg(not(target_pointer_width = "64"))]
        let (original, replacement): ([u16; 2], [u16; 2]) = ([0xAAAA, 0xBBBB], [0x10, 0xFFFF]);

        let mut out = original;

        super::cmovnz_remainder(&mut out, &replacement, 0);
        assert_eq!(original, out);

        super::cmovnz_remainder(&mut out, &replacement, 1);
        assert_eq!(replacement, out);
    }

    /// Element 0 must land in the least significant bits of the word.
    #[test]
    fn slice_to_word() {
        let from_bytes = super::slice_to_word(&[0xCCu8, 0xBB, 0xAA]);
        assert_eq!(0xAABBCC, from_bytes);

        let from_pair = super::slice_to_word(&[0xBBBBu16, 0xAAAA]);
        assert_eq!(0xAAAABBBB, from_pair);

        #[cfg(target_pointer_width = "64")]
        {
            let from_triple = super::slice_to_word(&[0xCCCCu16, 0xBBBB, 0xAAAA]);
            assert_eq!(0xAAAABBBBCCCC, from_triple);
        }
    }

    /// The least significant bits of the word must land in element 0.
    #[test]
    fn word_to_slice() {
        let mut bytes = [0u8; 3];
        super::word_to_slice(0xAABBCC, &mut bytes);
        assert_eq!(&[0xCC, 0xBB, 0xAA], &bytes);

        let mut pair = [0u16; 2];
        super::word_to_slice(0xAAAABBBB, &mut pair);
        assert_eq!(&[0xBBBB, 0xAAAA], &pair);

        #[cfg(target_pointer_width = "64")]
        {
            let mut triple = [0u16; 3];
            super::word_to_slice(0xAAAABBBBCCCC, &mut triple);
            assert_eq!(&[0xCCCC, 0xBBBB, 0xAAAA], &triple);
        }
    }
}