1use crate::{Cmov, CmovEq, Condition};
4use core::{
5 cmp,
6 num::{
7 NonZeroI8, NonZeroI16, NonZeroI32, NonZeroI64, NonZeroI128, NonZeroU8, NonZeroU16,
8 NonZeroU32, NonZeroU64, NonZeroU128,
9 },
10 ops::{BitOrAssign, Shl},
11 ptr, slice,
12};
13
/// Native machine word used to process slices several elements at a time
/// (64-bit targets).
#[cfg(target_pointer_width = "64")]
type Word = u64;
/// Native machine word used to process slices several elements at a time
/// (every target that is not 64-bit).
#[cfg(not(target_pointer_width = "64"))]
type Word = u32;

/// Width of [`Word`] in bytes.
const WORD_SIZE: usize = size_of::<Word>();
20
/// Panics unless a destination length equals a source length.
///
/// Arguments are `(destination_len, source_len)`. The panic message reports the
/// source length first, then the destination length.
macro_rules! assert_lengths_eq {
    ($dst_len:expr, $src_len:expr) => {
        assert_eq!(
            $dst_len,
            $src_len,
            "source slice length ({}) does not match destination slice length ({})",
            $src_len,
            $dst_len
        );
    };
}
31
32impl Cmov for [u8] {
39 #[inline]
40 #[track_caller]
41 fn cmovnz(&mut self, value: &Self, condition: Condition) {
42 assert_lengths_eq!(self.len(), value.len());
43
44 let (dst_chunks, dst_remainder) = slice_as_chunks_mut::<u8, WORD_SIZE>(self);
45 let (src_chunks, src_remainder) = slice_as_chunks::<u8, WORD_SIZE>(value);
46
47 for (dst_chunk, src_chunk) in dst_chunks.iter_mut().zip(src_chunks.iter()) {
48 let mut a = Word::from_ne_bytes(*dst_chunk);
49 let b = Word::from_ne_bytes(*src_chunk);
50 a.cmovnz(&b, condition);
51 dst_chunk.copy_from_slice(&a.to_ne_bytes());
52 }
53
54 cmovnz_remainder(dst_remainder, src_remainder, condition);
55 }
56}
57
#[cfg(not(target_pointer_width = "64"))]
#[cfg_attr(docsrs, doc(cfg(true)))]
impl Cmov for [u16] {
    /// Conditionally overwrites `self` with `value` when `condition` is non-zero.
    ///
    /// On targets where [`Word`] is 32 bits, two `u16` lanes are packed per word.
    /// A trailing odd element goes through `cmovnz_remainder`.
    ///
    /// # Panics
    ///
    /// Panics if `self` and `value` have different lengths.
    #[inline]
    #[track_caller]
    fn cmovnz(&mut self, value: &Self, condition: Condition) {
        assert_lengths_eq!(self.len(), value.len());

        /// Packs two lanes into one word, lane 0 in the low 16 bits.
        #[inline(always)]
        fn pack(lanes: &[u16; 2]) -> Word {
            Word::from(lanes[0]) | (Word::from(lanes[1]) << 16)
        }

        let (dst_pairs, dst_tail) = slice_as_chunks_mut::<u16, 2>(self);
        let (src_pairs, src_tail) = slice_as_chunks::<u16, 2>(value);

        for (dst, src) in dst_pairs.iter_mut().zip(src_pairs) {
            let mut word = pack(dst);
            word.cmovnz(&pack(src), condition);
            *dst = [(word & 0xFFFF) as u16, (word >> 16) as u16];
        }

        cmovnz_remainder(dst_tail, src_tail, condition);
    }
}
82
83#[cfg(target_pointer_width = "64")]
86#[cfg_attr(docsrs, doc(cfg(true)))]
87impl Cmov for [u16] {
88 #[inline]
89 #[track_caller]
90 fn cmovnz(&mut self, value: &Self, condition: Condition) {
91 assert_lengths_eq!(self.len(), value.len());
92
93 #[inline(always)]
94 fn u16x4_to_u64(input: &[u16; 4]) -> u64 {
95 Word::from(input[0])
96 | (Word::from(input[1]) << 16)
97 | (Word::from(input[2]) << 32)
98 | (Word::from(input[3]) << 48)
99 }
100
101 let (dst_chunks, dst_remainder) = slice_as_chunks_mut::<u16, 4>(self);
102 let (src_chunks, src_remainder) = slice_as_chunks::<u16, 4>(value);
103
104 for (dst_chunk, src_chunk) in dst_chunks.iter_mut().zip(src_chunks.iter()) {
105 let mut a = u16x4_to_u64(dst_chunk);
106 let b = u16x4_to_u64(src_chunk);
107 a.cmovnz(&b, condition);
108 dst_chunk[0] = (a & 0xFFFF) as u16;
109 dst_chunk[1] = ((a >> 16) & 0xFFFF) as u16;
110 dst_chunk[2] = ((a >> 32) & 0xFFFF) as u16;
111 dst_chunk[3] = ((a >> 48) & 0xFFFF) as u16;
112 }
113
114 cmovnz_remainder(dst_remainder, src_remainder, condition);
115 }
116}
117
118macro_rules! impl_cmov_with_loop {
120 ( $($int:ty),+ ) => {
121 $(
122 impl Cmov for [$int] {
123 #[inline]
124 #[track_caller]
125 fn cmovnz(&mut self, value: &Self, condition: Condition) {
126 assert_lengths_eq!(self.len(), value.len());
127 for (a, b) in self.iter_mut().zip(value.iter()) {
128 a.cmovnz(b, condition);
129 }
130 }
131 }
132 )+
133 };
134}
135
136impl_cmov_with_loop!(u32, u64, u128);
138
/// Compile-time check that two types have identical size and alignment.
///
/// Expands to an inline `const` block, so a mismatch fails the build instead of
/// panicking at runtime.
macro_rules! assert_size_and_alignment_eq {
    ($lhs:ty, $rhs:ty) => {
        const {
            assert!(
                size_of::<$lhs>() == size_of::<$rhs>(),
                "integers are of unequal size"
            );
            assert!(
                align_of::<$lhs>() == align_of::<$rhs>(),
                "integers have unequal alignment"
            );
        }
    };
}
155
156macro_rules! impl_cmov_with_cast {
158 ( $($src:ty => $dst:ty),+ ) => {
159 $(
160 impl Cmov for [$src] {
161 #[inline]
162 #[track_caller]
163 #[allow(unsafe_code)]
164 fn cmovnz(&mut self, value: &Self, condition: Condition) {
165 assert_size_and_alignment_eq!($src, $dst);
166
167 let self_unsigned = unsafe { cast_slice_mut::<$src, $dst>(self) };
171 let value_unsigned = unsafe { cast_slice::<$src, $dst>(value) };
172 self_unsigned.cmovnz(value_unsigned, condition);
173 }
174 }
175 )+
176 };
177}
178
179impl_cmov_with_cast!(
181 i8 => u8,
182 i16 => u16,
183 i32 => u32,
184 i64 => u64,
185 i128 => u128,
186 NonZeroI8 => i8,
187 NonZeroI16 => i16,
188 NonZeroI32 => i32,
189 NonZeroI64 => i64,
190 NonZeroI128 => i128,
191 NonZeroU8 => u8,
192 NonZeroU16 => u16,
193 NonZeroU32 => u32,
194 NonZeroU64 => u64,
195 NonZeroU128 => u128,
196 cmp::Ordering => i8 );
198
199impl CmovEq for [u8] {
206 #[inline]
207 fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
208 if self.len() != rhs.len() {
211 *output = input;
212 return;
213 }
214
215 let (self_chunks, self_remainder) = slice_as_chunks::<u8, WORD_SIZE>(self);
216 let (rhs_chunks, rhs_remainder) = slice_as_chunks::<u8, WORD_SIZE>(rhs);
217
218 for (self_chunk, rhs_chunk) in self_chunks.iter().zip(rhs_chunks.iter()) {
219 let a = Word::from_ne_bytes(*self_chunk);
220 let b = Word::from_ne_bytes(*rhs_chunk);
221 a.cmovne(&b, input, output);
222 }
223
224 cmovne_remainder(self_remainder, rhs_remainder, input, output);
225 }
226}
227
228macro_rules! impl_cmoveq_with_loop {
230 ( $($int:ty),+ ) => {
231 $(
232 impl CmovEq for [$int] {
233 #[inline]
234 fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
235 if self.len() != rhs.len() {
238 *output = input;
239 return;
240 }
241
242 for (a, b) in self.iter().zip(rhs.iter()) {
243 a.cmovne(b, input, output);
244 }
245 }
246 }
247 )+
248 };
249}
250
251impl_cmoveq_with_loop!(u16, u32, u64, u128);
253
254macro_rules! impl_cmoveq_with_cast {
256 ( $($src:ty => $dst:ty),+ ) => {
257 $(
258 impl CmovEq for [$src] {
259 #[inline]
260 #[allow(unsafe_code)]
261 fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
262 assert_size_and_alignment_eq!($src, $dst);
263
264 let self_unsigned = unsafe { cast_slice::<$src, $dst>(self) };
268 let rhs_unsigned = unsafe { cast_slice::<$src, $dst>(rhs) };
269 self_unsigned.cmovne(rhs_unsigned, input, output);
270 }
271 }
272 )+
273 };
274}
275
276impl_cmoveq_with_cast!(
278 i8 => u8,
279 i16 => u16,
280 i32 => u32,
281 i64 => u64,
282 i128 => u128,
283 NonZeroI8 => i8,
284 NonZeroI16 => i16,
285 NonZeroI32 => i32,
286 NonZeroI64 => i64,
287 NonZeroI128 => i128,
288 NonZeroU8 => u8,
289 NonZeroU16 => u16,
290 NonZeroU32 => u32,
291 NonZeroU64 => u64,
292 NonZeroU128 => u128,
293 cmp::Ordering => i8 );
295
/// Reinterprets a shared slice of `T` as a shared slice of `U`, keeping the same
/// length and lifetime.
///
/// `T` and `U` must have equal size and alignment; a mismatch is rejected at
/// compile time.
///
/// # Safety
///
/// Every element of `slice` must also be a valid value of type `U`.
unsafe fn cast_slice<T, U>(slice: &[T]) -> &[U] {
    const {
        assert!(size_of::<T>() == size_of::<U>(), "T/U size differs");
        assert!(align_of::<T>() == align_of::<U>(), "T/U alignment differs");
    }

    // An `as` cast between slice pointers keeps the element count. With equal
    // element sizes, the byte span is unchanged.
    let reinterpreted = ptr::from_ref::<[T]>(slice) as *const [U];

    // SAFETY: `reinterpreted` covers exactly the live, initialized and suitably
    // aligned memory borrowed by `slice`. The returned lifetime is tied to `slice`,
    // and the caller guarantees each element is a valid `U`.
    #[allow(unsafe_code)]
    unsafe {
        &*reinterpreted
    }
}
319
/// Reinterprets a mutable slice of `T` as a mutable slice of `U`, keeping the
/// same length and lifetime.
///
/// `T` and `U` must have equal size and alignment; a mismatch is rejected at
/// compile time.
///
/// # Safety
///
/// Every element of `slice` must be a valid value of type `U`. Every value
/// written through the returned slice must also be a valid `T`.
unsafe fn cast_slice_mut<T, U>(slice: &mut [T]) -> &mut [U] {
    const {
        assert!(size_of::<T>() == size_of::<U>(), "T/U size differs");
        assert!(align_of::<T>() == align_of::<U>(), "T/U alignment differs");
    }

    // An `as` cast between slice pointers keeps the element count. With equal
    // element sizes, the byte span is unchanged.
    let reinterpreted = ptr::from_mut::<[T]>(slice) as *mut [U];

    // SAFETY: `reinterpreted` is derived from the unique borrow `slice` and covers
    // exactly its memory. The returned lifetime is tied to that borrow, and the
    // caller upholds validity in both directions.
    #[allow(unsafe_code)]
    unsafe {
        &mut *reinterpreted
    }
}
339
340#[inline]
342fn cmovne_remainder<T>(
343 a_remainder: &[T],
344 b_remainder: &[T],
345 input: Condition,
346 output: &mut Condition,
347) where
348 T: Copy,
349 Word: From<T>,
350{
351 let a = slice_to_word(a_remainder);
352 let b = slice_to_word(b_remainder);
353 a.cmovne(&b, input, output);
354}
355
356#[inline]
359fn cmovnz_remainder<T>(dst_remainder: &mut [T], src_remainder: &[T], condition: Condition)
360where
361 T: BitOrAssign + Copy + From<u8> + Shl<usize, Output = T>,
362 Word: From<T>,
363{
364 let mut remainder = slice_to_word(dst_remainder);
365 remainder.cmovnz(&slice_to_word(src_remainder), condition);
366 word_to_slice(remainder, dst_remainder);
367}
368
369#[inline]
371fn slice_to_word<T>(slice: &[T]) -> Word
372where
373 T: Copy,
374 Word: From<T>,
375{
376 debug_assert!(size_of_val(slice) <= WORD_SIZE, "slice too large");
377 slice
378 .iter()
379 .rev()
380 .copied()
381 .fold(0, |acc, n| (acc << (size_of::<T>() * 8)) | Word::from(n))
382}
383
384#[inline]
386fn word_to_slice<T>(word: Word, out: &mut [T])
387where
388 T: BitOrAssign + Copy + From<u8> + Shl<usize, Output = T>,
389{
390 debug_assert!(size_of::<T>() > 0, "can't be used with ZSTs");
391 debug_assert!(out.len() <= WORD_SIZE, "slice too large");
392
393 let bytes = word.to_le_bytes();
394 for (o, chunk) in out.iter_mut().zip(bytes.chunks(size_of::<T>())) {
395 *o = T::from(0u8);
396 for (i, &byte) in chunk.iter().enumerate() {
397 *o |= T::from(byte) << (i * 8);
398 }
399 }
400}
401
402#[inline]
409#[track_caller]
410#[must_use]
411#[allow(clippy::integer_division_remainder_used)]
412fn slice_as_chunks<T, const N: usize>(slice: &[T]) -> (&[[T; N]], &[T]) {
413 assert!(N != 0, "chunk size must be non-zero");
414 let len_rounded_down = slice.len() / N * N;
415 let (multiple_of_n, remainder) = unsafe { slice.split_at_unchecked(len_rounded_down) };
418 let array_slice = unsafe { slice_as_chunks_unchecked(multiple_of_n) };
421 (array_slice, remainder)
422}
423
424#[inline]
427#[track_caller]
428#[must_use]
429#[allow(clippy::integer_division_remainder_used)]
430fn slice_as_chunks_mut<T, const N: usize>(slice: &mut [T]) -> (&mut [[T; N]], &mut [T]) {
431 assert!(N != 0, "chunk size must be non-zero");
432 let len_rounded_down = slice.len() / N * N;
433 let (multiple_of_n, remainder) = unsafe { slice.split_at_mut_unchecked(len_rounded_down) };
436 let array_slice = unsafe { slice_as_chunks_unchecked_mut(multiple_of_n) };
439 (array_slice, remainder)
440}
441
/// Reinterprets `slice` as a slice of `N`-element arrays without checking its
/// length.
///
/// # Safety
///
/// `N` must be non-zero and `slice.len()` must be an exact multiple of `N`.
/// Both conditions are checked only in debug builds.
#[inline]
#[must_use]
#[track_caller]
#[allow(clippy::integer_division_remainder_used)]
unsafe fn slice_as_chunks_unchecked<T, const N: usize>(slice: &[T]) -> &[[T; N]] {
    const { debug_assert!(N != 0) };
    debug_assert_eq!(slice.len() % N, 0);

    let chunk_count = slice.len() / N;
    let first_chunk = slice.as_ptr().cast::<[T; N]>();

    // SAFETY: `[T; N]` has the alignment of `T` and the size of `N` `T`s. Given
    // the caller's length guarantee, `chunk_count` arrays cover exactly the
    // initialized elements of `slice`, borrowed for its lifetime.
    unsafe { slice::from_raw_parts(first_chunk, chunk_count) }
}
458
/// Mutable counterpart of `slice_as_chunks_unchecked`: reinterprets `slice` as a
/// mutable slice of `N`-element arrays without checking its length.
///
/// # Safety
///
/// `N` must be non-zero and `slice.len()` must be an exact multiple of `N`.
/// Both conditions are checked only in debug builds.
#[inline]
#[must_use]
#[track_caller]
#[allow(clippy::integer_division_remainder_used)]
unsafe fn slice_as_chunks_unchecked_mut<T, const N: usize>(slice: &mut [T]) -> &mut [[T; N]] {
    const { debug_assert!(N != 0) };
    debug_assert_eq!(slice.len() % N, 0);

    let chunk_count = slice.len() / N;
    let first_chunk = slice.as_mut_ptr().cast::<[T; N]>();

    // SAFETY: `[T; N]` has the alignment of `T` and the size of `N` `T`s. Given
    // the caller's length guarantee, `chunk_count` arrays cover exactly the
    // elements of the unique borrow `slice`, for its lifetime.
    unsafe { slice::from_raw_parts_mut(first_chunk, chunk_count) }
}
475
#[cfg(test)]
mod tests {
    /// A zero condition must leave the destination untouched, and a non-zero
    /// condition must copy the source into it.
    #[test]
    fn cmovnz_remainder() {
        // The remainder has to fit in a single word, so the fixture length
        // depends on the target's pointer width.
        #[cfg(not(target_pointer_width = "64"))]
        const A_U16: [u16; 2] = [0xAAAA, 0xBBBB];
        #[cfg(target_pointer_width = "64")]
        const A_U16: [u16; 3] = [0xAAAA, 0xBBBB, 0xCCCC];

        #[cfg(not(target_pointer_width = "64"))]
        const B_U16: [u16; 2] = [0x10, 0xFFFF];
        #[cfg(target_pointer_width = "64")]
        const B_U16: [u16; 3] = [0x10, 0x10, 0xFFFF];

        let mut out = A_U16;
        for (condition, expected) in [(0, A_U16), (1, B_U16)] {
            super::cmovnz_remainder(&mut out, &B_U16, condition);
            assert_eq!(expected, out);
        }
    }

    /// Elements are packed with the first element in the least significant bits.
    #[test]
    fn slice_to_word() {
        use super::slice_to_word as pack;

        assert_eq!(pack(&[0xCCu8, 0xBB, 0xAA]), 0xAABBCC);
        assert_eq!(pack(&[0xBBBBu16, 0xAAAA]), 0xAAAABBBB);

        #[cfg(target_pointer_width = "64")]
        assert_eq!(pack(&[0xCCCCu16, 0xBBBB, 0xAAAA]), 0xAAAABBBBCCCC);
    }

    /// Unpacking is the inverse of `slice_to_word`.
    #[test]
    fn word_to_slice() {
        use super::word_to_slice as unpack;

        let mut bytes = [0u8; 3];
        unpack(0xAABBCC, &mut bytes);
        assert_eq!(bytes, [0xCC, 0xBB, 0xAA]);

        let mut halves = [0u16; 2];
        unpack(0xAAAABBBB, &mut halves);
        assert_eq!(halves, [0xBBBB, 0xAAAA]);

        #[cfg(target_pointer_width = "64")]
        {
            let mut triple = [0u16; 3];
            unpack(0xAAAABBBBCCCC, &mut triple);
            assert_eq!(triple, [0xCCCC, 0xBBBB, 0xAAAA]);
        }
    }
}