Skip to main content

p3_symmetric/
hash.rs

1use alloc::vec;
2use alloc::vec::{IntoIter, Vec};
3use core::borrow::Borrow;
4use core::marker::PhantomData;
5
6use p3_util::log2_strict_usize;
7use serde::{Deserialize, Serialize};
8
9/// A wrapper around an array digest, with a phantom type parameter to ensure that the digest is
10/// associated with a particular field.
11#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
12#[serde(bound(serialize = "[W; DIGEST_ELEMS]: Serialize"))]
13#[serde(bound(deserialize = "[W; DIGEST_ELEMS]: Deserialize<'de>"))]
14pub struct Hash<F, W, const DIGEST_ELEMS: usize> {
15    value: [W; DIGEST_ELEMS],
16    _marker: PhantomData<F>,
17}
18
19/// The Merkle cap of height `h` of a Merkle tree is the `h`-th layer (from the root) of the tree.
20/// It can be used in place of the root to verify Merkle paths, which are `h` elements shorter.
21///
22/// A cap of height 0 contains a single element (the root), while a cap of height `h` contains
23/// `2^h` elements. The `Digest` type is the full digest (e.g. `[W; DIGEST_ELEMS]`).
24#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
25#[serde(bound(serialize = "Digest: Serialize"))]
26#[serde(bound(deserialize = "Digest: Deserialize<'de>"))]
27pub struct MerkleCap<F, Digest> {
28    cap: Vec<Digest>,
29    _marker: PhantomData<F>,
30}
31
32impl<F, Digest> MerkleCap<F, Digest> {
33    /// Create a new `MerkleCap` from a vector of digests.
34    pub fn new(cap: Vec<Digest>) -> Self {
35        assert!(cap.len().is_power_of_two());
36        Self {
37            cap,
38            _marker: PhantomData,
39        }
40    }
41
42    /// Returns the number of digests in the cap.
43    #[must_use]
44    pub const fn num_roots(&self) -> usize {
45        self.cap.len()
46    }
47
48    /// Returns the height of the cap (log2 of the number of elements).
49    /// A cap with 1 element has height 0, a cap with 2 elements has height 1, etc.
50    #[must_use]
51    pub const fn height(&self) -> usize {
52        log2_strict_usize(self.num_roots())
53    }
54
55    /// Returns a reference to the underlying slice of digests.
56    #[must_use]
57    pub fn roots(&self) -> &[Digest] {
58        &self.cap
59    }
60
61    /// Flattens the cap into a single vector of digest words.
62    pub fn into_roots(self) -> Vec<Digest> {
63        self.cap.into_iter().collect()
64    }
65}
66
67impl<F, Digest> From<Vec<Digest>> for MerkleCap<F, Digest> {
68    fn from(cap: Vec<Digest>) -> Self {
69        Self::new(cap)
70    }
71}
72
73impl<F, W, const N: usize> From<Hash<F, W, N>> for MerkleCap<F, [W; N]> {
74    fn from(hash: Hash<F, W, N>) -> Self {
75        Self::new(vec![hash.into()])
76    }
77}
78
79impl<F, Digest> Borrow<[Digest]> for MerkleCap<F, Digest> {
80    fn borrow(&self) -> &[Digest] {
81        &self.cap
82    }
83}
84
85impl<F, Digest> AsRef<[Digest]> for MerkleCap<F, Digest> {
86    fn as_ref(&self) -> &[Digest] {
87        &self.cap
88    }
89}
90
91impl<F, Digest> core::ops::Index<usize> for MerkleCap<F, Digest> {
92    type Output = Digest;
93
94    fn index(&self, index: usize) -> &Self::Output {
95        &self.cap[index]
96    }
97}
98
99impl<F, Digest> IntoIterator for MerkleCap<F, Digest> {
100    type Item = Digest;
101    type IntoIter = IntoIter<Digest>;
102
103    fn into_iter(self) -> Self::IntoIter {
104        self.cap.into_iter()
105    }
106}
107
108impl<F, W, const DIGEST_ELEMS: usize> From<[W; DIGEST_ELEMS]> for Hash<F, W, DIGEST_ELEMS> {
109    fn from(value: [W; DIGEST_ELEMS]) -> Self {
110        Self {
111            value,
112            _marker: PhantomData,
113        }
114    }
115}
116
117impl<F, W, const DIGEST_ELEMS: usize> From<Hash<F, W, DIGEST_ELEMS>> for [W; DIGEST_ELEMS] {
118    fn from(value: Hash<F, W, DIGEST_ELEMS>) -> [W; DIGEST_ELEMS] {
119        value.value
120    }
121}
122
123impl<F, W: PartialEq, const DIGEST_ELEMS: usize> PartialEq<[W; DIGEST_ELEMS]>
124    for Hash<F, W, DIGEST_ELEMS>
125{
126    fn eq(&self, other: &[W; DIGEST_ELEMS]) -> bool {
127        self.value == *other
128    }
129}
130
131impl<F, W, const DIGEST_ELEMS: usize> IntoIterator for Hash<F, W, DIGEST_ELEMS> {
132    type Item = W;
133    type IntoIter = core::array::IntoIter<W, DIGEST_ELEMS>;
134
135    fn into_iter(self) -> Self::IntoIter {
136        self.value.into_iter()
137    }
138}
139
140impl<F, W, const DIGEST_ELEMS: usize> Borrow<[W; DIGEST_ELEMS]> for Hash<F, W, DIGEST_ELEMS> {
141    fn borrow(&self) -> &[W; DIGEST_ELEMS] {
142        &self.value
143    }
144}
145
146impl<F, W, const DIGEST_ELEMS: usize> AsRef<[W; DIGEST_ELEMS]> for Hash<F, W, DIGEST_ELEMS> {
147    fn as_ref(&self) -> &[W; DIGEST_ELEMS] {
148        &self.value
149    }
150}
151
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use p3_goldilocks::Goldilocks;

    use super::*;

    type F = Goldilocks;
    type Digest = [u8; 4];

    #[test]
    fn test_merkle_cap_new_with_power_of_two_sizes() {
        let cap = MerkleCap::<F, Digest>::new(vec![[7u8; 4]]);
        assert_eq!(cap.num_roots(), 1);
        assert_eq!(cap.height(), 0);

        let cap = MerkleCap::<F, Digest>::new(vec![[0u8; 4]; 2]);
        assert_eq!(cap.num_roots(), 2);
        assert_eq!(cap.height(), 1);

        let cap = MerkleCap::<F, Digest>::new(vec![[0u8; 4]; 8]);
        assert_eq!(cap.num_roots(), 8);
        assert_eq!(cap.height(), 3);
    }

    #[test]
    #[should_panic]
    fn test_merkle_cap_new_panics_on_empty() {
        let _ = MerkleCap::<F, Digest>::new(vec![]);
    }

    #[test]
    #[should_panic]
    fn test_merkle_cap_new_panics_on_three() {
        let _ = MerkleCap::<F, Digest>::new(vec![[0u8; 4]; 3]);
    }

    #[test]
    #[should_panic]
    fn test_merkle_cap_from_vec_panics_on_empty() {
        let _: MerkleCap<F, Digest> = vec![].into();
    }

    #[test]
    fn test_merkle_cap_from_hash() {
        let hash = Hash::<F, u8, 4>::from([1u8, 2, 3, 4]);
        let cap: MerkleCap<F, [u8; 4]> = hash.into();
        assert_eq!(cap.num_roots(), 1);
        assert_eq!(cap.height(), 0);
    }

    /// `roots`, indexing, `AsRef`, `IntoIterator` and `into_roots` must all expose the same
    /// digests in the same order they were supplied to `new`.
    #[test]
    fn test_merkle_cap_accessors_preserve_order() {
        let digests: Vec<Digest> = vec![[1u8; 4], [2u8; 4]];
        let cap = MerkleCap::<F, Digest>::new(digests.clone());

        assert_eq!(cap.roots(), digests.as_slice());
        assert_eq!(cap[0], [1u8; 4]);
        assert_eq!(cap[1], [2u8; 4]);

        let viewed: &[Digest] = cap.as_ref();
        assert_eq!(viewed, digests.as_slice());

        assert_eq!(cap.clone().into_iter().collect::<Vec<_>>(), digests);
        assert_eq!(cap.into_roots(), digests);
    }

    #[test]
    #[should_panic]
    fn test_merkle_cap_index_out_of_bounds_panics() {
        let cap = MerkleCap::<F, Digest>::new(vec![[0u8; 4]; 2]);
        let _digest: Digest = cap[2];
    }

    /// Converting an array into a `Hash` and back must be lossless, and the wrapper must compare
    /// equal to (and iterate like) the array it wraps.
    #[test]
    fn test_hash_array_round_trip() {
        let words = [1u8, 2, 3, 4];
        let hash = Hash::<F, u8, 4>::from(words);

        assert_eq!(hash, words);
        assert_eq!(hash.as_ref(), &words);
        assert_eq!(hash.into_iter().collect::<Vec<_>>(), vec![1u8, 2, 3, 4]);

        let back: [u8; 4] = hash.into();
        assert_eq!(back, words);
    }
}