Skip to main content

sha3/
block_api.rs

1use core::{fmt, marker::PhantomData};
2use digest::{
3    HashMarker, Output,
4    array::ArraySize,
5    block_api::{
6        AlgorithmName, Block, BlockSizeUser, Buffer, BufferKindUser, Eager, ExtendableOutputCore,
7        FixedOutputCore, OutputSizeUser, Reset, UpdateCore, XofReaderCore,
8    },
9    block_buffer::BlockSizes,
10    common::hazmat::{DeserializeStateError, SerializableState, SerializedState},
11    typenum::{IsLessOrEqual, True, U0, U200},
12};
13use keccak::{Keccak, State1600};
14
15/// Core SHA-3 hasher state.
16#[derive(Clone)]
17pub struct Sha3HasherCore<Rate, OutputSize, const PAD: u8>
18where
19    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
20    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
21{
22    state: State1600,
23    keccak: Keccak,
24    _pd: PhantomData<(Rate, OutputSize)>,
25}
26
27impl<Rate, OutputSize, const PAD: u8> HashMarker for Sha3HasherCore<Rate, OutputSize, PAD>
28where
29    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
30    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
31{
32}
33
34impl<Rate, OutputSize, const PAD: u8> BlockSizeUser for Sha3HasherCore<Rate, OutputSize, PAD>
35where
36    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
37    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
38{
39    type BlockSize = Rate;
40}
41
42impl<Rate, OutputSize, const PAD: u8> BufferKindUser for Sha3HasherCore<Rate, OutputSize, PAD>
43where
44    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
45    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
46{
47    type BufferKind = Eager;
48}
49
50impl<Rate, OutputSize, const PAD: u8> OutputSizeUser for Sha3HasherCore<Rate, OutputSize, PAD>
51where
52    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
53    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
54{
55    type OutputSize = OutputSize;
56}
57
58impl<Rate, OutputSize, const PAD: u8> UpdateCore for Sha3HasherCore<Rate, OutputSize, PAD>
59where
60    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
61    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
62{
63    #[inline]
64    fn update_blocks(&mut self, blocks: &[Block<Self>]) {
65        self.keccak.with_f1600(|f1600| {
66            for block in blocks {
67                xor_block(&mut self.state, block);
68                f1600(&mut self.state);
69            }
70        });
71    }
72}
73
74impl<Rate, OutputSize, const PAD: u8> FixedOutputCore for Sha3HasherCore<Rate, OutputSize, PAD>
75where
76    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
77    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
78{
79    #[inline]
80    fn finalize_fixed_core(&mut self, buffer: &mut Buffer<Self>, out: &mut Output<Self>) {
81        let pos = buffer.get_pos();
82        let mut block = buffer.pad_with_zeros();
83        block[pos] = PAD;
84        let n = block.len();
85        block[n - 1] |= 0x80;
86
87        self.keccak.with_f1600(|f1600| {
88            xor_block(&mut self.state, &block);
89            f1600(&mut self.state);
90
91            for (o, s) in out.chunks_mut(8).zip(self.state.as_mut().iter()) {
92                o.copy_from_slice(&s.to_le_bytes()[..o.len()]);
93            }
94        });
95    }
96}
97
98impl<Rate, const PAD: u8> ExtendableOutputCore for Sha3HasherCore<Rate, U0, PAD>
99where
100    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
101{
102    type ReaderCore = Sha3ReaderCore<Rate>;
103
104    #[inline]
105    fn finalize_xof_core(&mut self, buffer: &mut Buffer<Self>) -> Self::ReaderCore {
106        let pos = buffer.get_pos();
107        let mut block = buffer.pad_with_zeros();
108        block[pos] = PAD;
109        let n = block.len();
110        block[n - 1] |= 0x80;
111
112        self.keccak.with_f1600(|f1600| {
113            xor_block(&mut self.state, &block);
114            f1600(&mut self.state);
115        });
116
117        Sha3ReaderCore::new(&self.state, self.keccak)
118    }
119}
120
121impl<Rate, OutputSize, const PAD: u8> Default for Sha3HasherCore<Rate, OutputSize, PAD>
122where
123    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
124    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
125{
126    #[inline]
127    fn default() -> Self {
128        Self {
129            state: Default::default(),
130            keccak: Keccak::new(),
131            _pd: PhantomData,
132        }
133    }
134}
135
136impl<Rate, OutputSize, const PAD: u8> Reset for Sha3HasherCore<Rate, OutputSize, PAD>
137where
138    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
139    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
140{
141    #[inline]
142    fn reset(&mut self) {
143        *self = Default::default();
144    }
145}
146
147impl<Rate, OutputSize, const PAD: u8> AlgorithmName for Sha3HasherCore<Rate, OutputSize, PAD>
148where
149    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
150    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
151{
152    fn write_alg_name(f: &mut fmt::Formatter<'_>) -> fmt::Result {
153        // TODO: change algorithm name depending on the generic parameters
154        f.write_str("Sha3")
155    }
156}
157
158impl<Rate, OutputSize, const PAD: u8> fmt::Debug for Sha3HasherCore<Rate, OutputSize, PAD>
159where
160    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
161    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
162{
163    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164        f.write_str("Sha3FixedCore<Rate, OutputSize, PAD, ROUNDS> { ... }")
165    }
166}
167
168impl<Rate, OutputSize, const PAD: u8> Drop for Sha3HasherCore<Rate, OutputSize, PAD>
169where
170    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
171    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
172{
173    fn drop(&mut self) {
174        #[cfg(feature = "zeroize")]
175        {
176            use digest::zeroize::Zeroize;
177            self.state.as_mut().zeroize();
178        }
179    }
180}
181
// The `Drop` impl above zeroizes `state` when this feature is enabled.
#[cfg(feature = "zeroize")]
impl<Rate, OutputSize, const PAD: u8> digest::zeroize::ZeroizeOnDrop
    for Sha3HasherCore<Rate, OutputSize, PAD>
where
    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
{
}
190
191impl<Rate, OutputSize, const PAD: u8> SerializableState for Sha3HasherCore<Rate, OutputSize, PAD>
192where
193    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
194    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
195{
196    type SerializedStateSize = U200;
197
198    fn serialize(&self) -> SerializedState<Self> {
199        let mut serialized_state = SerializedState::<Self>::default();
200        let chunks = serialized_state.chunks_exact_mut(8);
201        for (val, chunk) in self.state.as_ref().iter().zip(chunks) {
202            chunk.copy_from_slice(&val.to_le_bytes());
203        }
204
205        serialized_state
206    }
207
208    fn deserialize(
209        serialized_state: &SerializedState<Self>,
210    ) -> Result<Self, DeserializeStateError> {
211        let mut state = State1600::default();
212        let chunks = serialized_state.chunks_exact(8);
213        for (val, chunk) in state.iter_mut().zip(chunks) {
214            *val = u64::from_le_bytes(chunk.try_into().unwrap());
215        }
216
217        Ok(Self {
218            state,
219            keccak: Keccak::new(),
220            _pd: PhantomData,
221        })
222    }
223}
224
225/// Core SHA-3 XOF reader.
226#[derive(Clone)]
227pub struct Sha3ReaderCore<Rate>
228where
229    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
230{
231    state: State1600,
232    keccak: Keccak,
233    _pd: PhantomData<Rate>,
234}
235
236impl<Rate> Sha3ReaderCore<Rate>
237where
238    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
239{
240    pub(crate) fn new(&state: &State1600, keccak: Keccak) -> Self {
241        let _pd = PhantomData;
242        Self { state, keccak, _pd }
243    }
244}
245
246impl<Rate> BlockSizeUser for Sha3ReaderCore<Rate>
247where
248    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
249{
250    type BlockSize = Rate;
251}
252
253impl<Rate> XofReaderCore for Sha3ReaderCore<Rate>
254where
255    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
256{
257    #[inline]
258    fn read_block(&mut self) -> Block<Self> {
259        let mut block = Block::<Self>::default();
260        for (src, dst) in self.state.iter().zip(block.chunks_mut(8)) {
261            dst.copy_from_slice(&src.to_le_bytes()[..dst.len()]);
262        }
263        self.keccak.with_f1600(|f1600| f1600(&mut self.state));
264        block
265    }
266}
267
268impl<Rate> Drop for Sha3ReaderCore<Rate>
269where
270    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
271{
272    fn drop(&mut self) {
273        #[cfg(feature = "zeroize")]
274        {
275            use digest::zeroize::Zeroize;
276            self.state.zeroize();
277        }
278    }
279}
280
281impl<Rate> fmt::Debug for Sha3ReaderCore<Rate>
282where
283    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
284{
285    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
286        f.write_str("Sha3ReaderCore { ... }")
287    }
288}
289
// The reader's `Drop` impl zeroizes `state` when this feature is enabled.
#[cfg(feature = "zeroize")]
impl<Rate> digest::zeroize::ZeroizeOnDrop for Sha3ReaderCore<Rate>
where
    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
{
}
295
296pub(crate) fn xor_block(state: &mut State1600, block: &[u8]) {
297    assert!(size_of_val(block) < size_of_val(state));
298
299    let mut chunks = block.chunks_exact(8);
300    for (s, chunk) in state.iter_mut().zip(&mut chunks) {
301        *s ^= u64::from_le_bytes(chunk.try_into().unwrap());
302    }
303
304    let rem = chunks.remainder();
305    if !rem.is_empty() {
306        let mut buf = [0u8; 8];
307        buf[..rem.len()].copy_from_slice(rem);
308        let n = block.len() / 8;
309        state[n] ^= u64::from_le_bytes(buf);
310    }
311}