1use core::{fmt, marker::PhantomData};
2use digest::{
3 HashMarker, Output,
4 array::ArraySize,
5 block_api::{
6 AlgorithmName, Block, BlockSizeUser, Buffer, BufferKindUser, Eager, ExtendableOutputCore,
7 FixedOutputCore, OutputSizeUser, Reset, UpdateCore, XofReaderCore,
8 },
9 block_buffer::BlockSizes,
10 common::hazmat::{DeserializeStateError, SerializableState, SerializedState},
11 typenum::{IsLessOrEqual, True, U0, U200},
12};
13use keccak::{F1600_ROUNDS, Keccak, State1600};
14
15pub use crate::cshake::{CShake128Core, CShake256Core};
16
/// Core hasher state for a Keccak-based sponge with a fixed or extendable output.
///
/// Type parameters:
/// - `Rate`: number of bytes absorbed per block; it is used as the block size.
/// - `OutputSize`: digest length in bytes. With `OutputSize = U0` the type also
///   implements `ExtendableOutputCore` (XOF mode).
/// - `PAD`: padding byte written right after the last buffered message byte when
///   finalizing; the top bit of the block's last byte is additionally set.
/// - `ROUNDS`: round count passed to `Keccak::with_p1600` (defaults to `F1600_ROUNDS`).
#[derive(Clone)]
pub struct Sha3HasherCore<Rate, OutputSize, const PAD: u8, const ROUNDS: usize = F1600_ROUNDS>
where
    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
{
    /// Sponge state that message blocks are XORed into before each permutation.
    state: State1600,
    /// Handle used to run the permutation via `with_p1600`
    /// (NOTE(review): presumably selects the permutation backend; confirm in `keccak`).
    keccak: Keccak,
    /// Ties the type-level `Rate`/`OutputSize` parameters to the struct without storing them.
    _pd: PhantomData<(Rate, OutputSize)>,
}
28
29impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> HashMarker
30 for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
31where
32 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
33 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
34{
35}
36
37impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> BlockSizeUser
38 for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
39where
40 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
41 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
42{
43 type BlockSize = Rate;
44}
45
46impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> BufferKindUser
47 for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
48where
49 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
50 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
51{
52 type BufferKind = Eager;
53}
54
55impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> OutputSizeUser
56 for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
57where
58 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
59 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
60{
61 type OutputSize = OutputSize;
62}
63
64impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> UpdateCore
65 for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
66where
67 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
68 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
69{
70 #[inline]
71 fn update_blocks(&mut self, blocks: &[Block<Self>]) {
72 self.keccak.with_p1600::<ROUNDS>(|p1600| {
73 for block in blocks {
74 xor_block(&mut self.state, block);
75 p1600(&mut self.state);
76 }
77 });
78 }
79}
80
81impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> FixedOutputCore
82 for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
83where
84 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
85 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
86{
87 #[inline]
88 fn finalize_fixed_core(&mut self, buffer: &mut Buffer<Self>, out: &mut Output<Self>) {
89 let pos = buffer.get_pos();
90 let mut block = buffer.pad_with_zeros();
91 block[pos] = PAD;
92 let n = block.len();
93 block[n - 1] |= 0x80;
94
95 self.keccak.with_p1600::<ROUNDS>(|p1600| {
96 xor_block(&mut self.state, &block);
97 p1600(&mut self.state);
98
99 for (o, s) in out.chunks_mut(8).zip(self.state.as_mut().iter()) {
100 o.copy_from_slice(&s.to_le_bytes()[..o.len()]);
101 }
102 });
103 }
104}
105
106impl<Rate, const PAD: u8, const ROUNDS: usize> ExtendableOutputCore
107 for Sha3HasherCore<Rate, U0, PAD, ROUNDS>
108where
109 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
110{
111 type ReaderCore = Sha3ReaderCore<Rate, ROUNDS>;
112
113 #[inline]
114 fn finalize_xof_core(&mut self, buffer: &mut Buffer<Self>) -> Self::ReaderCore {
115 let pos = buffer.get_pos();
116 let mut block = buffer.pad_with_zeros();
117 block[pos] = PAD;
118 let n = block.len();
119 block[n - 1] |= 0x80;
120
121 self.keccak.with_p1600::<ROUNDS>(|p1600| {
122 xor_block(&mut self.state, &block);
123 p1600(&mut self.state);
124 });
125
126 Sha3ReaderCore::new(&self.state, self.keccak)
127 }
128}
129
130impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> Default
131 for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
132where
133 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
134 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
135{
136 #[inline]
137 fn default() -> Self {
138 Self {
139 state: Default::default(),
140 keccak: Keccak::new(),
141 _pd: PhantomData,
142 }
143 }
144}
145
146impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> Reset
147 for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
148where
149 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
150 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
151{
152 #[inline]
153 fn reset(&mut self) {
154 *self = Default::default();
155 }
156}
157
158impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> AlgorithmName
159 for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
160where
161 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
162 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
163{
164 fn write_alg_name(f: &mut fmt::Formatter<'_>) -> fmt::Result {
165 f.write_str("Sha3") }
167}
168
169impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> fmt::Debug
170 for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
171where
172 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
173 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
174{
175 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176 f.write_str("Sha3FixedCore<Rate, OutputSize, PAD, ROUNDS> { ... }")
177 }
178}
179
180impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> Drop
181 for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
182where
183 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
184 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
185{
186 fn drop(&mut self) {
187 #[cfg(feature = "zeroize")]
188 {
189 use digest::zeroize::Zeroize;
190 self.state.as_mut().zeroize();
191 }
192 }
193}
194
// Advertises that `Drop` above wipes the state (only meaningful with `zeroize`).
#[cfg(feature = "zeroize")]
impl<
    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
    const PAD: u8,
    const ROUNDS: usize,
> digest::zeroize::ZeroizeOnDrop for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
{
}
203
204impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> SerializableState
205 for Sha3HasherCore<Rate, OutputSize, PAD, ROUNDS>
206where
207 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
208 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
209{
210 type SerializedStateSize = U200;
211
212 fn serialize(&self) -> SerializedState<Self> {
213 let mut serialized_state = SerializedState::<Self>::default();
214 let chunks = serialized_state.chunks_exact_mut(8);
215 for (val, chunk) in self.state.as_ref().iter().zip(chunks) {
216 chunk.copy_from_slice(&val.to_le_bytes());
217 }
218
219 serialized_state
220 }
221
222 fn deserialize(
223 serialized_state: &SerializedState<Self>,
224 ) -> Result<Self, DeserializeStateError> {
225 let mut state = State1600::default();
226 let chunks = serialized_state.chunks_exact(8);
227 for (val, chunk) in state.iter_mut().zip(chunks) {
228 *val = u64::from_le_bytes(chunk.try_into().unwrap());
229 }
230
231 Ok(Self {
232 state,
233 keccak: Keccak::new(),
234 _pd: PhantomData,
235 })
236 }
237}
238
/// Core XOF reader created by `Sha3HasherCore::finalize_xof_core`.
///
/// Each `read_block` call emits the leading `Rate` bytes of the state
/// (little-endian lanes) and then permutes the state with `ROUNDS` rounds.
#[derive(Clone)]
pub struct Sha3ReaderCore<Rate, const ROUNDS: usize = F1600_ROUNDS>
where
    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
{
    /// Sponge state squeezed by `read_block`.
    state: State1600,
    /// Handle used to run the permutation via `with_p1600`.
    keccak: Keccak,
    /// Ties the type-level `Rate` parameter to the struct without storing it.
    _pd: PhantomData<Rate>,
}
249
250impl<Rate, const ROUNDS: usize> Sha3ReaderCore<Rate, ROUNDS>
251where
252 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
253{
254 pub(crate) fn new(&state: &State1600, keccak: Keccak) -> Self {
255 let _pd = PhantomData;
256 Self { state, keccak, _pd }
257 }
258}
259
260impl<Rate, const ROUNDS: usize> BlockSizeUser for Sha3ReaderCore<Rate, ROUNDS>
261where
262 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
263{
264 type BlockSize = Rate;
265}
266
267impl<Rate, const ROUNDS: usize> XofReaderCore for Sha3ReaderCore<Rate, ROUNDS>
268where
269 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
270{
271 #[inline]
272 fn read_block(&mut self) -> Block<Self> {
273 let mut block = Block::<Self>::default();
274 for (src, dst) in self.state.iter().zip(block.chunks_mut(8)) {
275 dst.copy_from_slice(&src.to_le_bytes()[..dst.len()]);
276 }
277 self.keccak
278 .with_p1600::<ROUNDS>(|p1600| p1600(&mut self.state));
279 block
280 }
281}
282
283impl<Rate, const ROUNDS: usize> Drop for Sha3ReaderCore<Rate, ROUNDS>
284where
285 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
286{
287 fn drop(&mut self) {
288 #[cfg(feature = "zeroize")]
289 {
290 use digest::zeroize::Zeroize;
291 self.state.zeroize();
292 }
293 }
294}
295
296impl<Rate, const ROUNDS: usize> fmt::Debug for Sha3ReaderCore<Rate, ROUNDS>
297where
298 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
299{
300 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
301 f.write_str("Sha3ReaderCore { ... }")
302 }
303}
304
// Advertises that the reader's `Drop` wipes its state (only with `zeroize`).
#[cfg(feature = "zeroize")]
impl<Rate: BlockSizes + IsLessOrEqual<U200, Output = True>, const ROUNDS: usize>
    digest::zeroize::ZeroizeOnDrop for Sha3ReaderCore<Rate, ROUNDS>
{
}
309}
310
311pub(crate) fn xor_block(state: &mut State1600, block: &[u8]) {
312 assert!(size_of_val(block) < size_of_val(state));
313
314 let mut chunks = block.chunks_exact(8);
315 for (s, chunk) in state.iter_mut().zip(&mut chunks) {
316 *s ^= u64::from_le_bytes(chunk.try_into().unwrap());
317 }
318
319 let rem = chunks.remainder();
320 if !rem.is_empty() {
321 let mut buf = [0u8; 8];
322 buf[..rem.len()].copy_from_slice(rem);
323 let n = block.len() / 8;
324 state[n] ^= u64::from_le_bytes(buf);
325 }
326}