Skip to main content

ctutils/traits/
ct_eq.rs

1use crate::Choice;
2use cmov::CmovEq;
3use core::{
4    cmp,
5    num::{
6        NonZeroI8, NonZeroI16, NonZeroI32, NonZeroI64, NonZeroI128, NonZeroU8, NonZeroU16,
7        NonZeroU32, NonZeroU64, NonZeroU128,
8    },
9};
10
11#[cfg(feature = "subtle")]
12use crate::CtOption;
13
14/// Constant-time equality: like `(Partial)Eq` with [`Choice`] instead of [`bool`].
15///
16/// Impl'd for: [`u8`], [`u16`], [`u32`], [`u64`], [`u128`], [`usize`], [`cmp::Ordering`],
17/// [`Choice`], and arrays/slices of any type which also impls [`CtEq`].
18///
19/// This crate provides built-in implementations for the following types:
20/// - [`i8`], [`i16`], [`i32`], [`i64`], [`i128`], [`isize`]
21/// - [`u8`], [`u16`], [`u32`], [`u64`], [`u128`], [`usize`]
22/// - [`NonZeroI8`], [`NonZeroI16`], [`NonZeroI32`], [`NonZeroI64`], [`NonZeroI128`]
23/// - [`NonZeroU8`], [`NonZeroU16`], [`NonZeroU32`], [`NonZeroU64`], [`NonZeroU128`]
24/// - [`cmp::Ordering`]
25/// - [`Choice`]
26/// - `[T]` and `[T; N]` where `T` impls [`CtEqSlice`], which the previously mentioned types all do.
27pub trait CtEq<Rhs = Self>
28where
29    Rhs: ?Sized,
30{
31    /// Determine if `self` is equal to `other` in constant-time.
32    #[must_use]
33    fn ct_eq(&self, other: &Rhs) -> Choice;
34
35    /// Determine if `self` is NOT equal to `other` in constant-time.
36    #[must_use]
37    fn ct_ne(&self, other: &Rhs) -> Choice {
38        !self.ct_eq(other)
39    }
40}
41
42/// Implementing this trait enables use of the [`CtEq`] trait for `[T]` where `T` is the
43/// `Self` type implementing the trait, via a blanket impl.
44///
45/// It needs to be a separate trait from [`CtEq`] because we need to be able to impl
46/// [`CtEq`] for `[T]` which is `?Sized`.
47pub trait CtEqSlice: CtEq + Sized {
48    /// Determine if `a` is equal to `b` in constant-time.
49    #[must_use]
50    fn ct_eq_slice(a: &[Self], b: &[Self]) -> Choice {
51        let mut ret = a.len().ct_eq(&b.len());
52        for (a, b) in a.iter().zip(b.iter()) {
53            ret &= a.ct_eq(b);
54        }
55        ret
56    }
57
58    /// Determine if `a` is NOT equal to `b` in constant-time.
59    #[must_use]
60    fn ct_ne_slice(a: &[Self], b: &[Self]) -> Choice {
61        !Self::ct_eq_slice(a, b)
62    }
63}
64
65impl<T: CtEqSlice> CtEq for [T] {
66    fn ct_eq(&self, other: &Self) -> Choice {
67        T::ct_eq_slice(self, other)
68    }
69
70    fn ct_ne(&self, other: &Self) -> Choice {
71        T::ct_ne_slice(self, other)
72    }
73}
74
/// Implements [`CtEq`] for each listed type by delegating to `cmov::CmovEq::cmoveq`.
///
/// The result starts as [`Choice::FALSE`]; `cmoveq` writes `1` into its inner byte when
/// the two values are equal.
macro_rules! impl_ct_eq_with_cmov_eq {
    ( $($ty:ty),+ ) => {
        $(
            impl CtEq for $ty {
                #[inline]
                fn ct_eq(&self, other: &Self) -> Choice {
                    let mut is_eq = Choice::FALSE;
                    CmovEq::cmoveq(self, other, 1, &mut is_eq.0);
                    is_eq
                }
            }
        )+
    };
}
90
/// Implements both [`CtEq`] and [`CtEqSlice`] for each listed type via `cmov::CmovEq`.
///
/// The scalar impl comes from `impl_ct_eq_with_cmov_eq!`. The slice impl hands the whole
/// pair of slices to `cmoveq` in one call instead of comparing element by element.
macro_rules! impl_ct_eq_slice_with_cmov_eq {
    ( $($ty:ty),+ ) => {
        $(
            impl_ct_eq_with_cmov_eq!($ty);

            impl CtEqSlice for $ty {
                #[inline]
                fn ct_eq_slice(a: &[Self], b: &[Self]) -> Choice {
                    let mut is_eq = Choice::FALSE;
                    CmovEq::cmoveq(a, b, 1, &mut is_eq.0);
                    is_eq
                }
            }
        )+
    };
}
108
// Fixed-width integers: both the scalar and the slice comparisons go through `cmov::CmovEq`.
impl_ct_eq_slice_with_cmov_eq!(i8, i16, i32, i64, i128, u8, u16, u32, u64, u128);

// Pointer-sized integers: the scalar comparison uses `cmov::CmovEq`, but slices of them use
// the default element-wise `CtEqSlice::ct_eq_slice` (the empty impls below).
// NOTE(review): presumably because `cmov` has no `CmovEq` impl for `[isize]`/`[usize]`.
// Confirm before switching these to `impl_ct_eq_slice_with_cmov_eq!`.
impl_ct_eq_with_cmov_eq!(isize, usize);
impl CtEqSlice for isize {}
impl CtEqSlice for usize {}
113
/// Implements [`CtEq`] for each listed `NonZero*` type by comparing the primitive integers
/// returned by `get()`.
///
/// Also adds an empty [`CtEqSlice`] impl, so slices use the default element-wise comparison.
macro_rules! impl_ct_eq_for_nonzero_integer {
    ( $($ty:ty),+ ) => {
        $(
            impl CtEq for $ty {
                #[inline]
                fn ct_eq(&self, other: &Self) -> Choice {
                    let (lhs, rhs) = (self.get(), other.get());
                    lhs.ct_eq(&rhs)
                }
            }

            impl CtEqSlice for $ty {}
        )+
    };
}
129
130impl_ct_eq_for_nonzero_integer!(
131    NonZeroI8,
132    NonZeroI16,
133    NonZeroI32,
134    NonZeroI64,
135    NonZeroI128,
136    NonZeroU8,
137    NonZeroU16,
138    NonZeroU32,
139    NonZeroU64,
140    NonZeroU128
141);
142
143impl CtEq for cmp::Ordering {
144    #[inline]
145    fn ct_eq(&self, other: &Self) -> Choice {
146        // `Ordering` is `repr(i8)`, which has a `CtEq` impl
147        (*self as i8).ct_eq(&(*other as i8))
148    }
149}
150
151impl CtEqSlice for cmp::Ordering {}
152
153impl<T, const N: usize> CtEq for [T; N]
154where
155    T: CtEqSlice,
156{
157    #[inline]
158    fn ct_eq(&self, other: &[T; N]) -> Choice {
159        self.as_slice().ct_eq(other.as_slice())
160    }
161}
162
163impl<T, const N: usize> CtEqSlice for [T; N] where T: CtEqSlice {}
164
#[cfg(feature = "subtle")]
impl CtEq for subtle::Choice {
    #[inline]
    fn ct_eq(&self, other: &Self) -> Choice {
        // Compare the underlying bytes via the `u8` impl.
        let lhs: u8 = self.unwrap_u8();
        let rhs: u8 = other.unwrap_u8();
        lhs.ct_eq(&rhs)
    }
}
172
#[cfg(feature = "subtle")]
impl<T> CtEq for subtle::CtOption<T>
where
    T: CtEq + Default + subtle::ConditionallySelectable,
{
    #[inline]
    fn ct_eq(&self, other: &Self) -> Choice {
        // Convert both sides into this crate's `CtOption` and reuse its `CtEq` impl.
        let lhs = CtOption::from(*self);
        let rhs = CtOption::from(*other);
        lhs.ct_eq(&rhs)
    }
}
183
#[cfg(feature = "alloc")]
mod alloc {
    use super::{Choice, CtEq, CtEqSlice};
    use ::alloc::{boxed::Box, vec::Vec};

    /// Boxed values compare by their contents.
    impl<T> CtEq for Box<T>
    where
        T: CtEq,
    {
        #[inline]
        #[track_caller]
        fn ct_eq(&self, rhs: &Self) -> Choice {
            T::ct_eq(&**self, &**rhs)
        }
    }

    /// Boxed slices compare like the slices they own.
    impl<T> CtEq for Box<[T]>
    where
        T: CtEqSlice,
    {
        #[inline]
        #[track_caller]
        fn ct_eq(&self, rhs: &Self) -> Choice {
            <[T] as CtEq>::ct_eq(&**self, &**rhs)
        }
    }

    /// A boxed slice can be compared directly against a borrowed slice.
    impl<T> CtEq<[T]> for Box<[T]>
    where
        T: CtEqSlice,
    {
        #[inline]
        #[track_caller]
        fn ct_eq(&self, rhs: &[T]) -> Choice {
            <[T] as CtEq>::ct_eq(&**self, rhs)
        }
    }

    /// Vectors compare like their backing slices.
    impl<T> CtEq for Vec<T>
    where
        T: CtEqSlice,
    {
        #[inline]
        #[track_caller]
        fn ct_eq(&self, rhs: &Self) -> Choice {
            <[T] as CtEq>::ct_eq(self.as_slice(), rhs.as_slice())
        }
    }

    /// A vector can be compared directly against a borrowed slice.
    impl<T> CtEq<[T]> for Vec<T>
    where
        T: CtEqSlice,
    {
        #[inline]
        #[track_caller]
        fn ct_eq(&self, rhs: &[T]) -> Choice {
            <[T] as CtEq>::ct_eq(self.as_slice(), rhs)
        }
    }
}
244
#[cfg(test)]
mod tests {
    use super::CtEq;
    use core::cmp::Ordering;

    /// Asserts that `ct_eq`/`ct_ne` report `$a == $b`, `$a != $c`, and `$b != $c`.
    macro_rules! truth_table {
        ($a:expr, $b:expr, $c:expr) => {
            assert!($a.ct_eq(&$b).to_bool());
            assert!(!$a.ct_eq(&$c).to_bool());
            assert!(!$b.ct_eq(&$c).to_bool());

            assert!(!$a.ct_ne(&$b).to_bool());
            assert!($a.ct_ne(&$c).to_bool());
            assert!($b.ct_ne(&$c).to_bool());
        };
    }

    /// Generates a `#[test]` comparing the primitive integers `42`, `42`, and `1` of type `$ty`.
    macro_rules! ct_eq_test {
        ($ty:ty, $name:ident) => {
            #[test]
            fn $name() {
                let a: $ty = 42;
                let b: $ty = 42;
                let c: $ty = 1;
                truth_table!(a, b, c);
            }
        };
    }

    ct_eq_test!(i8, i8_ct_eq);
    ct_eq_test!(i16, i16_ct_eq);
    ct_eq_test!(i32, i32_ct_eq);
    ct_eq_test!(i64, i64_ct_eq);
    ct_eq_test!(i128, i128_ct_eq);
    ct_eq_test!(isize, isize_ct_eq);
    ct_eq_test!(u8, u8_ct_eq);
    ct_eq_test!(u16, u16_ct_eq);
    ct_eq_test!(u32, u32_ct_eq);
    ct_eq_test!(u64, u64_ct_eq);
    ct_eq_test!(u128, u128_ct_eq);
    ct_eq_test!(usize, usize_ct_eq);

    /// Generates a `#[test]` comparing `NonZero*` values wrapping `42`, `42`, and `1`.
    macro_rules! ct_eq_nonzero_test {
        ($ty:ty, $name:ident) => {
            #[test]
            fn $name() {
                let a = <$ty>::new(42).unwrap();
                let b = <$ty>::new(42).unwrap();
                let c = <$ty>::new(1).unwrap();
                truth_table!(a, b, c);
            }
        };
    }

    ct_eq_nonzero_test!(core::num::NonZeroI8, nonzero_i8_ct_eq);
    ct_eq_nonzero_test!(core::num::NonZeroI16, nonzero_i16_ct_eq);
    ct_eq_nonzero_test!(core::num::NonZeroI32, nonzero_i32_ct_eq);
    ct_eq_nonzero_test!(core::num::NonZeroI64, nonzero_i64_ct_eq);
    ct_eq_nonzero_test!(core::num::NonZeroI128, nonzero_i128_ct_eq);
    ct_eq_nonzero_test!(core::num::NonZeroU8, nonzero_u8_ct_eq);
    ct_eq_nonzero_test!(core::num::NonZeroU16, nonzero_u16_ct_eq);
    ct_eq_nonzero_test!(core::num::NonZeroU32, nonzero_u32_ct_eq);
    ct_eq_nonzero_test!(core::num::NonZeroU64, nonzero_u64_ct_eq);
    ct_eq_nonzero_test!(core::num::NonZeroU128, nonzero_u128_ct_eq);

    #[test]
    fn signed_negative_ct_eq() {
        let a: i64 = -1;
        let b: i64 = -1;
        let c: i64 = 1;
        truth_table!(a, b, c);
    }

    #[test]
    fn array_ct_eq() {
        let a = [1u64, 2, 3];
        let b = [1u64, 2, 3];
        let c = [1u64, 2, 4];
        truth_table!(a, b, c);
    }

    #[test]
    fn nested_array_ct_eq() {
        // `[T; N]: CtEqSlice` makes arrays of arrays comparable.
        let a = [[1u8, 2], [3, 4]];
        let b = [[1u8, 2], [3, 4]];
        let c = [[1u8, 2], [3, 5]];
        truth_table!(a, b, c);
    }

    #[test]
    fn ordering_ct_eq() {
        let a = Ordering::Greater;
        let b = Ordering::Greater;
        let c = Ordering::Less;
        truth_table!(a, b, c);

        assert!(Ordering::Equal.ct_eq(&Ordering::Equal).to_bool());
        assert!(Ordering::Equal.ct_ne(&Ordering::Less).to_bool());
    }

    #[test]
    fn slice_ct_eq() {
        let a: &[u64] = &[1, 2, 3];
        let b: &[u64] = &[1, 2, 3];
        let c: &[u64] = &[1, 2, 4];
        truth_table!(a, b, c);

        // Length mismatches
        assert!(a.ct_ne(&[]).to_bool());
        assert!(a.ct_ne(&[1, 2]).to_bool());
    }

    #[test]
    fn signed_slice_ct_eq() {
        let a: &[i32] = &[1, -2, 3];
        let b: &[i32] = &[1, -2, 3];
        let c: &[i32] = &[1, -2, 4];
        truth_table!(a, b, c);

        // Length mismatches
        assert!(a.ct_ne(&[]).to_bool());
        assert!(a.ct_ne(&[1, -2]).to_bool());
    }

    #[test]
    fn usize_slice_ct_eq() {
        // `usize` slices use the default element-wise `CtEqSlice` impl rather than `cmov`.
        let a: &[usize] = &[1, 2, 3];
        let b: &[usize] = &[1, 2, 3];
        let c: &[usize] = &[1, 2, 4];
        truth_table!(a, b, c);

        // Length mismatches
        assert!(a.ct_ne(&[]).to_bool());
        assert!(a.ct_ne(&[1, 2]).to_bool());

        // Two empty slices are equal.
        let empty: &[usize] = &[];
        assert!(empty.ct_eq(&[]).to_bool());
    }
}