// sha3/block_api.rs

1use core::{fmt, marker::PhantomData};
2use digest::{
3    HashMarker, Output,
4    array::ArraySize,
5    block_api::{
6        AlgorithmName, Block, BlockSizeUser, Buffer, BufferKindUser, Eager, ExtendableOutputCore,
7        FixedOutputCore, OutputSizeUser, Reset, UpdateCore, XofReaderCore,
8    },
9    block_buffer::BlockSizes,
10    common::hazmat::{DeserializeStateError, SerializableState, SerializedState},
11    typenum::{IsLessOrEqual, True, U0, U200},
12};
13use keccak::{F1600_ROUNDS, Keccak, State1600};
14
15pub use crate::cshake::{CShake128Core, CShake256Core};
16
17/// Core Sha3 fixed output hasher state.
18#[derive(Clone)]
19pub struct Sha3HasherCore<Rate, OutputSize, const PAD: u8, const ROUNDS: usize = F1600_ROUNDS>
20where
21    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
22    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
23{
24    state: State1600,
25    keccak: Keccak,
26    _pd: PhantomData<(Rate, OutputSize)>,
27}
28
29impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> HashMarker
30    for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
31where
32    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
33    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
34{
35}
36
37impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> BlockSizeUser
38    for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
39where
40    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
41    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
42{
43    type BlockSize = Rate;
44}
45
46impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> BufferKindUser
47    for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
48where
49    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
50    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
51{
52    type BufferKind = Eager;
53}
54
55impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> OutputSizeUser
56    for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
57where
58    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
59    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
60{
61    type OutputSize = OutputSize;
62}
63
64impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> UpdateCore
65    for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
66where
67    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
68    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
69{
70    #[inline]
71    fn update_blocks(&mut self, blocks: &[Block<Self>]) {
72        self.keccak.with_p1600::<ROUNDS>(|p1600| {
73            for block in blocks {
74                xor_block(&mut self.state, block);
75                p1600(&mut self.state);
76            }
77        });
78    }
79}
80
81impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> FixedOutputCore
82    for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
83where
84    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
85    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
86{
87    #[inline]
88    fn finalize_fixed_core(&mut self, buffer: &mut Buffer<Self>, out: &mut Output<Self>) {
89        let pos = buffer.get_pos();
90        let mut block = buffer.pad_with_zeros();
91        block[pos] = PAD;
92        let n = block.len();
93        block[n - 1] |= 0x80;
94
95        self.keccak.with_p1600::<ROUNDS>(|p1600| {
96            xor_block(&mut self.state, &block);
97            p1600(&mut self.state);
98
99            for (o, s) in out.chunks_mut(8).zip(self.state.as_mut().iter()) {
100                o.copy_from_slice(&s.to_le_bytes()[..o.len()]);
101            }
102        });
103    }
104}
105
106impl<Rate, const PAD: u8, const ROUNDS: usize> ExtendableOutputCore
107    for Sha3HasherCore<Rate, U0, PAD, ROUNDS>
108where
109    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
110{
111    type ReaderCore = Sha3ReaderCore<Rate, ROUNDS>;
112
113    #[inline]
114    fn finalize_xof_core(&mut self, buffer: &mut Buffer<Self>) -> Self::ReaderCore {
115        let pos = buffer.get_pos();
116        let mut block = buffer.pad_with_zeros();
117        block[pos] = PAD;
118        let n = block.len();
119        block[n - 1] |= 0x80;
120
121        self.keccak.with_p1600::<ROUNDS>(|p1600| {
122            xor_block(&mut self.state, &block);
123            p1600(&mut self.state);
124        });
125
126        Sha3ReaderCore::new(&self.state, self.keccak)
127    }
128}
129
130impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> Default
131    for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
132where
133    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
134    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
135{
136    #[inline]
137    fn default() -> Self {
138        Self {
139            state: Default::default(),
140            keccak: Keccak::new(),
141            _pd: PhantomData,
142        }
143    }
144}
145
146impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> Reset
147    for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
148where
149    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
150    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
151{
152    #[inline]
153    fn reset(&mut self) {
154        *self = Default::default();
155    }
156}
157
158impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> AlgorithmName
159    for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
160where
161    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
162    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
163{
164    fn write_alg_name(f: &mut fmt::Formatter<'_>) -> fmt::Result {
165        f.write_str("Sha3") // TODO
166    }
167}
168
169impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> fmt::Debug
170    for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
171where
172    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
173    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
174{
175    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176        f.write_str("Sha3FixedCore<Rate, OutputSize, PAD, ROUNDS> { ... }")
177    }
178}
179
180impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> Drop
181    for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
182where
183    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
184    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
185{
186    fn drop(&mut self) {
187        #[cfg(feature = "zeroize")]
188        {
189            use digest::zeroize::Zeroize;
190            self.state.as_mut().zeroize();
191        }
192    }
193}
194
// With the `zeroize` feature on, this type's `Drop` impl clears the sponge state, so
// the hasher core advertises `ZeroizeOnDrop`.
#[cfg(feature = "zeroize")]
impl<
    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
    const PAD: u8,
    const ROUNDS: usize,
> digest::zeroize::ZeroizeOnDrop for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
{
}
203
204impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> SerializableState
205    for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
206where
207    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
208    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
209{
210    type SerializedStateSize = U200;
211
212    fn serialize(&self) -> SerializedState<Self> {
213        let mut serialized_state = SerializedState::<Self>::default();
214        let chunks = serialized_state.chunks_exact_mut(8);
215        for (val, chunk) in self.state.as_ref().iter().zip(chunks) {
216            chunk.copy_from_slice(&val.to_le_bytes());
217        }
218
219        serialized_state
220    }
221
222    fn deserialize(
223        serialized_state: &SerializedState<Self>,
224    ) -> Result<Self, DeserializeStateError> {
225        let mut state = State1600::default();
226        let chunks = serialized_state.chunks_exact(8);
227        for (val, chunk) in state.iter_mut().zip(chunks) {
228            *val = u64::from_le_bytes(chunk.try_into().unwrap());
229        }
230
231        Ok(Self {
232            state,
233            keccak: Keccak::new(),
234            _pd: PhantomData,
235        })
236    }
237}
238
239/// Core Sha3 XOF reader.
240#[derive(Clone)]
241pub struct Sha3ReaderCore<Rate, const ROUNDS: usize = F1600_ROUNDS>
242where
243    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
244{
245    state: State1600,
246    keccak: Keccak,
247    _pd: PhantomData<Rate>,
248}
249
250impl<Rate, const ROUNDS: usize> Sha3ReaderCore<Rate, ROUNDS>
251where
252    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
253{
254    pub(crate) fn new(&state: &State1600, keccak: Keccak) -> Self {
255        let _pd = PhantomData;
256        Self { state, keccak, _pd }
257    }
258}
259
260impl<Rate, const ROUNDS: usize> BlockSizeUser for Sha3ReaderCore<Rate, ROUNDS>
261where
262    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
263{
264    type BlockSize = Rate;
265}
266
267impl<Rate, const ROUNDS: usize> XofReaderCore for Sha3ReaderCore<Rate, ROUNDS>
268where
269    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
270{
271    #[inline]
272    fn read_block(&mut self) -> Block<Self> {
273        let mut block = Block::<Self>::default();
274        for (src, dst) in self.state.iter().zip(block.chunks_mut(8)) {
275            dst.copy_from_slice(&src.to_le_bytes()[..dst.len()]);
276        }
277        self.keccak
278            .with_p1600::<ROUNDS>(|p1600| p1600(&mut self.state));
279        block
280    }
281}
282
283impl<Rate, const ROUNDS: usize> Drop for Sha3ReaderCore<Rate, ROUNDS>
284where
285    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
286{
287    fn drop(&mut self) {
288        #[cfg(feature = "zeroize")]
289        {
290            use digest::zeroize::Zeroize;
291            self.state.zeroize();
292        }
293    }
294}
295
296impl<Rate, const ROUNDS: usize> fmt::Debug for Sha3ReaderCore<Rate, ROUNDS>
297where
298    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
299{
300    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
301        f.write_str("Sha3ReaderCore { ... }")
302    }
303}
304
// With the `zeroize` feature on, this type's `Drop` impl clears the reader state, so
// the reader advertises `ZeroizeOnDrop`.
#[cfg(feature = "zeroize")]
impl<Rate: BlockSizes + IsLessOrEqual<U200, Output = True>, const ROUNDS: usize>
    digest::zeroize::ZeroizeOnDrop for Sha3ReaderCore<Rate, ROUNDS>
{
}
310
311pub(crate) fn xor_block(state: &mut State1600, block: &[u8]) {
312    assert!(size_of_val(block) < size_of_val(state));
313
314    let mut chunks = block.chunks_exact(8);
315    for (s, chunk) in state.iter_mut().zip(&mut chunks) {
316        *s ^= u64::from_le_bytes(chunk.try_into().unwrap());
317    }
318
319    let rem = chunks.remainder();
320    if !rem.is_empty() {
321        let mut buf = [0u8; 8];
322        buf[..rem.len()].copy_from_slice(rem);
323        let n = block.len() / 8;
324        state[n] ^= u64::from_le_bytes(buf);
325    }
326}