// ctutils/traits/ct_eq.rs

1use crate::Choice;
2use cmov::CmovEq;
3use core::{
4    cmp,
5    num::{
6        NonZeroI8, NonZeroI16, NonZeroI32, NonZeroI64, NonZeroI128, NonZeroU8, NonZeroU16,
7        NonZeroU32, NonZeroU64, NonZeroU128,
8    },
9};
10
11#[cfg(feature = "subtle")]
12use crate::CtOption;
13
14/// Constant-time equality: like `(Partial)Eq` with [`Choice`] instead of [`bool`].
15///
16/// Impl'd for: [`u8`], [`u16`], [`u32`], [`u64`], [`u128`], [`usize`], [`cmp::Ordering`],
17/// [`Choice`], and arrays/slices of any type which also impls [`CtEq`].
18///
19/// This crate provides built-in implementations for the following types:
20/// - [`i8`], [`i16`], [`i32`], [`i64`], [`i128`], [`isize`]
21/// - [`u8`], [`u16`], [`u32`], [`u64`], [`u128`], [`usize`]
22/// - [`NonZeroI8`], [`NonZeroI16`], [`NonZeroI32`], [`NonZeroI64`], [`NonZeroI128`]
23/// - [`NonZeroU8`], [`NonZeroU16`], [`NonZeroU32`], [`NonZeroU64`], [`NonZeroU128`]
24/// - [`cmp::Ordering`]
25/// - [`Choice`]
26/// - `[T]` and `[T; N]` where `T` impls [`CtEqSlice`], which the previously mentioned types all do.
27pub trait CtEq<Rhs = Self>
28where
29    Rhs: ?Sized,
30{
31    /// Determine if `self` is equal to `other` in constant-time.
32    #[must_use]
33    fn ct_eq(&self, other: &Rhs) -> Choice;
34
35    /// Determine if `self` is NOT equal to `other` in constant-time.
36    #[must_use]
37    fn ct_ne(&self, other: &Rhs) -> Choice {
38        !self.ct_eq(other)
39    }
40}
41
42/// Implementing this trait enables use of the [`CtEq`] trait for `[T]` where `T` is the
43/// `Self` type implementing the trait, via a blanket impl.
44///
45/// It needs to be a separate trait from [`CtEq`] because we need to be able to impl
46/// [`CtEq`] for `[T]` which is `?Sized`.
47pub trait CtEqSlice: CtEq + Sized {
48    /// Determine if `a` is equal to `b` in constant-time.
49    #[must_use]
50    fn ct_eq_slice(a: &[Self], b: &[Self]) -> Choice {
51        let mut ret = a.len().ct_eq(&b.len());
52        for (a, b) in a.iter().zip(b.iter()) {
53            ret &= a.ct_eq(b);
54        }
55        ret
56    }
57
58    /// Determine if `a` is NOT equal to `b` in constant-time.
59    #[must_use]
60    fn ct_ne_slice(a: &[Self], b: &[Self]) -> Choice {
61        !Self::ct_eq_slice(a, b)
62    }
63}
64
65impl<T: CtEqSlice> CtEq for [T] {
66    fn ct_eq(&self, other: &Self) -> Choice {
67        T::ct_eq_slice(self, other)
68    }
69
70    fn ct_ne(&self, other: &Self) -> Choice {
71        T::ct_ne_slice(self, other)
72    }
73}
74
/// Impl `CtEq` for each listed type by delegating to its `cmov::CmovEq` impl.
///
/// The result starts as `Choice::FALSE`; `cmoveq` conditionally moves the input condition
/// (`1`) into it when the two operands are equal.
macro_rules! impl_ct_eq_with_cmov_eq {
    ( $($int:ty),+ ) => {
        $(
            impl CtEq for $int {
                #[inline]
                fn ct_eq(&self, other: &Self) -> Choice {
                    let mut is_eq = Choice::FALSE;
                    CmovEq::cmoveq(self, other, 1, &mut is_eq.0);
                    is_eq
                }
            }
        )+
    };
}
90
/// Impl both `CtEq` and `CtEqSlice` for each listed type using `cmov::CmovEq`.
///
/// The scalar impl comes from `impl_ct_eq_with_cmov_eq!`. The slice impl hands the whole
/// slices to `cmov`'s own slice comparison instead of using the default per-element loop.
macro_rules! impl_ct_eq_slice_with_cmov_eq {
    ( $($int:ty),+ ) => {
        $(
            impl_ct_eq_with_cmov_eq!($int);

            impl CtEqSlice for $int {
                #[inline]
                fn ct_eq_slice(a: &[Self], b: &[Self]) -> Choice {
                    let mut all_eq = Choice::FALSE;
                    CmovEq::cmoveq(a, b, 1, &mut all_eq.0);
                    all_eq
                }
            }
        )+
    };
}
108
109impl_ct_eq_slice_with_cmov_eq!(i8, i16, i32, i64, i128, u8, u16, u32, u64, u128);
110impl_ct_eq_with_cmov_eq!(isize, usize);
111impl CtEqSlice for isize {}
112impl CtEqSlice for usize {}
113
114/// Impl `CtEq` for `NonZero<T>` by calling `NonZero::get`.
115macro_rules! impl_ct_eq_for_nonzero_integer {
116    ( $($ty:ty),+ ) => {
117        $(
118            impl CtEq for $ty {
119                #[inline]
120                fn ct_eq(&self, other: &Self) -> Choice {
121                    self.get().ct_eq(&other.get())
122                }
123            }
124
125            impl CtEqSlice for $ty {}
126        )+
127    };
128}
129
130impl_ct_eq_for_nonzero_integer!(
131    NonZeroI8,
132    NonZeroI16,
133    NonZeroI32,
134    NonZeroI64,
135    NonZeroI128,
136    NonZeroU8,
137    NonZeroU16,
138    NonZeroU32,
139    NonZeroU64,
140    NonZeroU128
141);
142
143impl CtEq for cmp::Ordering {
144    #[inline]
145    fn ct_eq(&self, other: &Self) -> Choice {
146        // `Ordering` is `repr(i8)`, which has a `CtEq` impl
147        (*self as i8).ct_eq(&(*other as i8))
148    }
149}
150
151impl CtEqSlice for cmp::Ordering {}
152
153impl<T, const N: usize> CtEq for [T; N]
154where
155    T: CtEqSlice,
156{
157    #[inline]
158    fn ct_eq(&self, other: &[T; N]) -> Choice {
159        self.as_slice().ct_eq(other.as_slice())
160    }
161}
162
163impl<T, const N: usize> CtEqSlice for [T; N] where T: CtEqSlice {}
164
#[cfg(feature = "subtle")]
impl CtEq for subtle::Choice {
    #[inline]
    fn ct_eq(&self, other: &Self) -> Choice {
        // Compare the underlying `u8` representations with the constant-time integer impl.
        let (lhs, rhs) = (self.unwrap_u8(), other.unwrap_u8());
        lhs.ct_eq(&rhs)
    }
}
172
#[cfg(feature = "subtle")]
impl<T> CtEq for subtle::CtOption<T>
where
    T: CtEq + Default + subtle::ConditionallySelectable,
{
    #[inline]
    fn ct_eq(&self, other: &Self) -> Choice {
        // `subtle::CtOption<T>` is `Copy` here (`ConditionallySelectable: Copy`), so copies of
        // both operands are converted into this crate's `CtOption` and compared there.
        let lhs = CtOption::from(*self);
        let rhs = CtOption::from(*other);
        lhs.ct_eq(&rhs)
    }
}
183
#[cfg(feature = "alloc")]
mod alloc {
    //! `CtEq` impls for heap-allocated containers, delegating to the pointee or to the
    //! underlying slice.

    use super::{Choice, CtEq, CtEqSlice};
    use ::alloc::{boxed::Box, vec::Vec};

    impl<T> CtEq for Box<T>
    where
        T: CtEq,
    {
        #[inline]
        #[track_caller]
        fn ct_eq(&self, rhs: &Self) -> Choice {
            // Compare the boxed values themselves.
            (**self).ct_eq(&**rhs)
        }
    }

    impl<T> CtEq for Box<[T]>
    where
        T: CtEqSlice,
    {
        #[inline]
        #[track_caller]
        fn ct_eq(&self, rhs: &Self) -> Choice {
            // Compare the boxed slices directly via the `[T]` impl.
            (**self).ct_eq(&**rhs)
        }
    }

    impl<T> CtEq<[T]> for Box<[T]>
    where
        T: CtEqSlice,
    {
        #[inline]
        #[track_caller]
        fn ct_eq(&self, rhs: &[T]) -> Choice {
            let lhs: &[T] = self;
            lhs.ct_eq(rhs)
        }
    }

    impl<T> CtEq for Vec<T>
    where
        T: CtEqSlice,
    {
        #[inline]
        #[track_caller]
        fn ct_eq(&self, rhs: &Self) -> Choice {
            // Two vectors are equal iff their backing slices are.
            self.as_slice().ct_eq(rhs.as_slice())
        }
    }

    impl<T> CtEq<[T]> for Vec<T>
    where
        T: CtEqSlice,
    {
        #[inline]
        #[track_caller]
        fn ct_eq(&self, rhs: &[T]) -> Choice {
            (**self).ct_eq(rhs)
        }
    }
}
244
#[cfg(test)]
mod tests {
    use super::CtEq;
    use core::{
        cmp::Ordering,
        num::{NonZeroI64, NonZeroU32},
    };

    /// Asserts `$a == $b`, `$a != $c` and `$b != $c` through both `ct_eq` and `ct_ne`.
    macro_rules! truth_table {
        ($a:expr, $b:expr, $c:expr) => {
            assert!($a.ct_eq(&$b).to_bool());
            assert!(!$a.ct_eq(&$c).to_bool());
            assert!(!$b.ct_eq(&$c).to_bool());

            assert!(!$a.ct_ne(&$b).to_bool());
            assert!($a.ct_ne(&$c).to_bool());
            assert!($b.ct_ne(&$c).to_bool());
        };
    }

    /// Generates a test checking `MAX == MAX` and `MAX != MIN` for a primitive integer type.
    ///
    /// `MAX != MIN` holds for signed and unsigned integers alike, so one macro covers both
    /// (previously two byte-identical macros existed).
    macro_rules! ct_eq_test_int {
        ($ty:ty, $name:ident) => {
            #[test]
            fn $name() {
                let a = <$ty>::MAX;
                let b = <$ty>::MAX;
                let c = <$ty>::MIN;
                truth_table!(a, b, c);
            }
        };
    }

    ct_eq_test_int!(u8, u8_ct_eq);
    ct_eq_test_int!(u16, u16_ct_eq);
    ct_eq_test_int!(u32, u32_ct_eq);
    ct_eq_test_int!(u64, u64_ct_eq);
    ct_eq_test_int!(u128, u128_ct_eq);
    ct_eq_test_int!(usize, usize_ct_eq);

    ct_eq_test_int!(i8, i8_ct_eq);
    ct_eq_test_int!(i16, i16_ct_eq);
    ct_eq_test_int!(i32, i32_ct_eq);
    ct_eq_test_int!(i64, i64_ct_eq);
    ct_eq_test_int!(i128, i128_ct_eq);
    ct_eq_test_int!(isize, isize_ct_eq);

    #[test]
    fn nonzero_ct_eq() {
        let a = NonZeroU32::new(42).unwrap();
        let b = NonZeroU32::new(42).unwrap();
        let c = NonZeroU32::new(7).unwrap();
        truth_table!(a, b, c);

        let a = NonZeroI64::new(-1).unwrap();
        let b = NonZeroI64::new(-1).unwrap();
        let c = NonZeroI64::new(1).unwrap();
        truth_table!(a, b, c);
    }

    #[test]
    fn array_ct_eq() {
        let a = [1u64, 2, 3];
        let b = [1u64, 2, 3];
        let c = [1u64, 2, 4];
        truth_table!(a, b, c);
    }

    #[test]
    fn nested_array_ct_eq() {
        // Exercises `CtEqSlice for [T; N]` (slices of arrays).
        let a = [[1u8, 2], [3, 4]];
        let b = [[1u8, 2], [3, 4]];
        let c = [[1u8, 2], [3, 5]];
        truth_table!(a, b, c);
    }

    #[test]
    fn ordering_ct_eq() {
        let a = Ordering::Greater;
        let b = Ordering::Greater;
        let c = Ordering::Less;
        truth_table!(a, b, c);

        let a = Ordering::Less;
        let b = Ordering::Less;
        let c = Ordering::Equal;
        truth_table!(a, b, c);
    }

    #[test]
    fn slice_ct_eq() {
        let a: &[u64] = &[1, 2, 3];
        let b: &[u64] = &[1, 2, 3];
        let c: &[u64] = &[1, 2, 4];
        truth_table!(a, b, c);

        // Length mismatches, including a longer slice sharing an equal prefix
        assert!(a.ct_ne(&[]).to_bool());
        assert!(a.ct_ne(&[1, 2]).to_bool());
        assert!(a.ct_ne(&[1, 2, 3, 4]).to_bool());
    }

    #[test]
    fn usize_slice_ct_eq() {
        // `usize` uses the default element-by-element `CtEqSlice::ct_eq_slice`.
        let a: &[usize] = &[1, 2, 3];
        let b: &[usize] = &[1, 2, 3];
        let c: &[usize] = &[1, 2, 4];
        truth_table!(a, b, c);

        // Length mismatches, including a longer slice sharing an equal prefix
        assert!(a.ct_ne(&[]).to_bool());
        assert!(a.ct_ne(&[1, 2]).to_bool());
        assert!(a.ct_ne(&[1, 2, 3, 4]).to_bool());
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn vec_ct_eq() {
        use ::alloc::{vec, vec::Vec};

        let a: Vec<u8> = vec![1, 2, 3];
        let b: Vec<u8> = vec![1, 2, 3];
        let c: Vec<u8> = vec![1, 2, 4];
        truth_table!(a, b, c);

        // `Vec<T>` vs `[T]`
        assert!(a.ct_eq(&[1u8, 2, 3][..]).to_bool());
        assert!(a.ct_ne(&[1u8, 2][..]).to_bool());
    }
}