use alloc::vec::Vec;

use crate::core::{
    digest::Digest,
    hash::sha::{Block, BLOCK_WORDS, SHA256_INIT},
};
use risc0_zkvm_platform::{
    align_up,
    syscall::{sys_sha_buffer, sys_sha_compress},
    WORD_SIZE,
};

const END_MARKER: u8 = 0x80;

/// Allocate uninitialized, word-aligned heap memory big enough to hold a
/// [Digest]. The caller must initialize it before handing out a reference.
fn alloc_uninit_digest() -> *mut Digest {
    extern crate alloc;
    use core::alloc::Layout;
    unsafe { alloc::alloc::alloc(Layout::from_size_align(64, 4).unwrap()).cast() }
}

fn compress(
    out_state: *mut Digest,
    in_state: *const Digest,
    block_half1: &Digest,
    block_half2: &Digest,
) {
    unsafe {
        // SAFETY: out_state and in_state must point to valid,
        // word-aligned digests (they may alias), and the block halves
        // are ordinary references, so the syscall's pointer
        // requirements are satisfied.
        sys_sha_compress(
            out_state.cast(),
            in_state.cast(),
            block_half1.as_ref(),
            block_half2.as_ref(),
        );
    }
}
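
// Illustrative sketch (an assumption about typical use, not something this
// file establishes): chaining two digests into one, as in a 2-to-1
// Merkle-tree hash, could be written as
//     compress(out, &SHA256_INIT, &left, &right);
// Note this applies the raw compression function, with no padding.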

fn compress_slice(out_state: *mut Digest, in_state: *const Digest, blocks: &[Block]) {
    unsafe {
        // SAFETY: out_state and in_state must point to valid,
        // word-aligned digests, and blocks is a valid slice, so the
        // syscall reads and writes only memory we own.
        sys_sha_buffer(
            out_state.cast(),
            in_state.cast(),
            bytemuck::cast_slice(blocks).as_ptr(),
            blocks.len() as u32,
        );
    }
}
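
// Sketch (hedged): compress_slice runs the raw compression function over
// each 64-byte block in turn, chaining the state across blocks and writing
// the final state to out_state; like compress, it performs no padding.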

#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub(crate) enum Trailer {
    /// Finalize the hash: append the end marker and the big-endian
    /// bit-length trailer after the data.
    WithTrailer { total_bits: u32 },

    /// Zero-pad to a block boundary without finalizing.
    WithoutTrailer,
}
pub(crate) use Trailer::*;
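
// Illustrative layout (not part of the API): for the 3-byte message
// b"abc", `WithTrailer { total_bits: 24 }` produces a single 64-byte
// block:
//
//     61 62 63 80 00 ... 00 | 00 00 00 00 | 00 00 00 18
//     data + end marker     | zero word   | bit length (BE)
//
// For messages shorter than 2^32 bits this matches the 64-bit FIPS 180-4
// length field, since its high word is zero.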

const fn compute_u32s_needed(len_bytes: usize, trailer: Trailer) -> usize {
    match trailer {
        WithoutTrailer => align_up(len_bytes, WORD_SIZE * BLOCK_WORDS) / WORD_SIZE,
        WithTrailer { total_bits: _ } => {
            // One extra byte for the end marker, rounded up to whole words.
            let nwords = align_up(len_bytes + 1, WORD_SIZE) / WORD_SIZE;
            // Two more words for the 64-bit length trailer.
            let nwords = nwords + 2;

            // Round up to a whole number of blocks.
            align_up(nwords, BLOCK_WORDS)
        }
    }
}
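
// Worked example (WORD_SIZE = 4, BLOCK_WORDS = 16): for len_bytes = 3
// with a trailer, align_up(3 + 1, 4) / 4 = 1 word holds the data and end
// marker; adding the 2 length words gives 3, and align_up(3, 16) = 16
// words, i.e. exactly one 64-byte block.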

fn copy_and_update(
    out_state: *mut Digest,
    mut in_state: *const Digest,
    bytes: &[u8],
    trailer: Trailer,
) {
    let padlen = compute_u32s_needed(bytes.len(), trailer);
    let mut padbuf: Vec<u32> = Vec::with_capacity(padlen);
    assert!(bytes.len() <= padlen * WORD_SIZE);
    unsafe {
        // SAFETY: padbuf has capacity for padlen words and the assert
        // above guarantees the copy fits, so every write below stays
        // inside the allocation; set_len is only called once all padlen
        // words have been initialized.
        let padbuf_u8: *mut u8 = padbuf.as_mut_ptr().cast();
        padbuf_u8.copy_from_nonoverlapping(bytes.as_ptr(), bytes.len());

        match trailer {
            WithTrailer { total_bits: _ } => {
                // Write the end marker, then zero everything after it
                // (the bit length is filled in once the Vec is usable).
                padbuf_u8.add(bytes.len()).write(END_MARKER);
                padbuf_u8
                    .add(bytes.len() + 1)
                    .write_bytes(0, padlen * WORD_SIZE - bytes.len() - 1)
            }
            WithoutTrailer => {
                // Zero-fill up to the block boundary.
                padbuf_u8
                    .add(bytes.len())
                    .write_bytes(0, padlen * WORD_SIZE - bytes.len());
            }
        }
        padbuf.set_len(padlen);
    }

    if let WithTrailer { total_bits } = trailer {
        // Store the bit count, big-endian, in the last word; the word
        // before it is already zero, completing the 64-bit length field.
        assert_eq!(padbuf[padlen - 1], 0);
        padbuf[padlen - 1] = total_bits.to_be();
    }

    match bytemuck::pod_align_to::<u32, Block>(padbuf.as_slice()) {
        (&[], blocks @ &[..], &[]) => {
            if !blocks.is_empty() {
                // The padded buffer is a whole number of blocks; hash
                // them all in one syscall and chain the state.
                unsafe {
                    sys_sha_buffer(
                        out_state.cast(),
                        in_state.cast(),
                        bytemuck::cast_slice(blocks).as_ptr(),
                        blocks.len() as u32,
                    );
                }
                in_state = out_state;
            }
        }
        _ => panic!("padbuf should already be aligned and padded"),
    }

    if in_state != out_state {
        // No blocks were hashed; pass the input state through unchanged.
        unsafe { out_state.write(*in_state) }
    }
}
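
// Note: padbuf is a Vec<u32>, so its data is word-aligned, and padlen is a
// whole number of blocks by construction; pod_align_to above therefore
// always yields empty prefix and suffix slices.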

pub(crate) fn update_u32(
    out_state: *mut Digest,
    mut in_state: *const Digest,
    words: &[u32],
    trailer: Trailer,
) {
    match bytemuck::pod_align_to::<u32, Block>(words) {
        (&[], blocks @ &[..], rest @ &[..]) => {
            if !blocks.is_empty() {
                // Hash the whole blocks straight out of the caller's
                // buffer, then chain the resulting state into the
                // padded pass over the remainder.
                unsafe {
                    sys_sha_buffer(
                        out_state.cast(),
                        in_state.cast(),
                        bytemuck::cast_slice(blocks).as_ptr(),
                        blocks.len() as u32,
                    );
                }
                in_state = out_state;
            }

            copy_and_update(out_state, in_state, bytemuck::cast_slice(rest), trailer);
        }
        _ => unreachable!(
            "words should always have sufficient alignment to start on a block boundary"
        ),
    }
}
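
// Worked example (hedged): a 20-word input splits into one whole 16-word
// Block, hashed directly by the syscall, and a 4-word rest, which
// copy_and_update pads out to a block and hashes.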

fn update_u8(out_state: *mut Digest, mut in_state: *const Digest, bytes: &[u8], trailer: Trailer) {
    match bytemuck::pod_align_to::<u8, Block>(bytes) {
        (&[], blocks @ &[..], rest @ &[..]) => {
            if !blocks.is_empty() {
                // The bytes happen to start block-aligned: hash the
                // whole blocks in place and leave only the tail for
                // copy_and_update.
                unsafe {
                    sys_sha_buffer(
                        out_state.cast(),
                        in_state.cast(),
                        bytemuck::cast_slice(blocks).as_ptr(),
                        blocks.len() as u32,
                    );
                }
                in_state = out_state;
            }
            copy_and_update(out_state, in_state, rest, trailer);
        }
        _ => {
            // Unaligned input: copy everything through the padding buffer.
            copy_and_update(out_state, in_state, bytes, trailer)
        }
    }
}

/// A guest-side implementation of the `Sha256` trait, backed by the
/// zkVM's SHA-256 accelerator syscalls.
#[derive(Debug, Clone, Default)]
pub struct Impl {}

impl crate::core::hash::sha::Sha256 for Impl {
    type DigestPtr = &'static mut Digest;

    fn hash_bytes(bytes: &[u8]) -> Self::DigestPtr {
        let digest = alloc_uninit_digest();
        update_u8(
            digest,
            &SHA256_INIT,
            bytes,
            WithTrailer {
                total_bits: bytes.len() as u32 * 8,
            },
        );
        // SAFETY: update_u8 fully initializes the digest, and the
        // allocation is never freed, so a 'static reference is sound.
        unsafe { &mut *digest }
    }

    fn hash_words(words: &[u32]) -> Self::DigestPtr {
        let digest = alloc_uninit_digest();
        update_u32(
            digest,
            &SHA256_INIT,
            words,
            WithTrailer {
                total_bits: words.len() as u32 * 32,
            },
        );
        // SAFETY: as in hash_bytes.
        unsafe { &mut *digest }
    }

    fn hash_raw_data_slice<T: bytemuck::NoUninit>(data: &[T]) -> Self::DigestPtr {
        let digest = alloc_uninit_digest();
        let words: &[u32] = bytemuck::cast_slice(data);
        update_u32(digest, &SHA256_INIT, words, WithoutTrailer);
        // SAFETY: as in hash_bytes.
        unsafe { &mut *digest }
    }

    fn compress(state: &Digest, block_half1: &Digest, block_half2: &Digest) -> Self::DigestPtr {
        let digest = alloc_uninit_digest();
        compress(digest, state, block_half1, block_half2);
        // SAFETY: as in hash_bytes.
        unsafe { &mut *digest }
    }

    fn compress_slice(state: &Digest, blocks: &[Block]) -> Self::DigestPtr {
        let digest = alloc_uninit_digest();
        compress_slice(digest, state, blocks);
        // SAFETY: as in hash_bytes.
        unsafe { &mut *digest }
    }
}
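
// Usage sketch (hedged; runs only inside the zkVM guest, where the SHA
// syscalls exist). With the trait in scope:
//
//     use crate::core::hash::sha::Sha256;
//     let digest: &'static mut Digest = Impl::hash_bytes(b"hello world");
//
// hash_bytes and hash_words pad and finalize per SHA-256, while
// hash_raw_data_slice hashes raw words with no trailer, so its output is
// not the standard SHA-256 of the input bytes.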