ark_serialize_derive/
deserialize.rs

1use crate::serialize::IdentOrIndex;
2use proc_macro2::TokenStream;
3use quote::quote;
4use syn::{Data, Index, Type};
5
6fn impl_valid_field(
7    check_body: &mut Vec<TokenStream>,
8    batch_check_body: &mut Vec<TokenStream>,
9    idents: &mut Vec<IdentOrIndex>,
10    ty: &Type,
11) {
12    // Check if type is a tuple.
13    match ty {
14        Type::Tuple(tuple) => {
15            for (i, elem_ty) in tuple.elems.iter().enumerate() {
16                let index = Index::from(i);
17                idents.push(IdentOrIndex::Index(index));
18                impl_valid_field(check_body, batch_check_body, idents, elem_ty);
19                idents.pop();
20            }
21        },
22        _ => {
23            check_body.push(quote! { ark_serialize::Valid::check(&self.#(#idents).*)?; });
24            batch_check_body
25                .push(quote! { ark_serialize::Valid::batch_check(batch.iter().map(|v| &v.#(#idents).*))?; });
26        },
27    }
28}
29
30fn impl_valid(ast: &syn::DeriveInput) -> TokenStream {
31    let name = &ast.ident;
32
33    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
34
35    let len = if let Data::Struct(ref data_struct) = ast.data {
36        data_struct.fields.len()
37    } else {
38        panic!(
39            "`Valid` can only be derived for structs, {} is not a struct",
40            name
41        );
42    };
43
44    let mut check_body = Vec::<TokenStream>::with_capacity(len);
45    let mut batch_body = Vec::<TokenStream>::with_capacity(len);
46
47    match ast.data {
48        Data::Struct(ref data_struct) => {
49            let mut idents = Vec::<IdentOrIndex>::new();
50
51            for (i, field) in data_struct.fields.iter().enumerate() {
52                match field.ident {
53                    None => {
54                        let index = Index::from(i);
55                        idents.push(IdentOrIndex::Index(index));
56                    },
57                    Some(ref ident) => {
58                        idents.push(IdentOrIndex::Ident(ident.clone()));
59                    },
60                }
61
62                impl_valid_field(&mut check_body, &mut batch_body, &mut idents, &field.ty);
63
64                idents.clear();
65            }
66        },
67        _ => panic!(
68            "`Valid` can only be derived for structs, {} is not a struct",
69            name
70        ),
71    };
72
73    let gen = quote! {
74        impl #impl_generics ark_serialize::Valid for #name #ty_generics #where_clause {
75            #[allow(unused_mut, unused_variables)]
76            fn check(&self) -> Result<(), ark_serialize::SerializationError> {
77                #(#check_body)*
78                Ok(())
79            }
80            #[allow(unused_mut, unused_variables)]
81            fn batch_check<'a>(batch: impl Iterator<Item = &'a Self> + Send) -> Result<(), ark_serialize::SerializationError>
82                where
83            Self: 'a
84            {
85
86                let batch: Vec<_> = batch.collect();
87                #(#batch_body)*
88                Ok(())
89            }
90        }
91    };
92    gen
93}
94
95/// Returns a `TokenStream` for `deserialize_with_mode`.
96/// uncompressed.
97fn impl_deserialize_field(ty: &Type) -> TokenStream {
98    // Check if type is a tuple.
99    match ty {
100        Type::Tuple(tuple) => {
101            let compressed_fields: Vec<_> =
102                tuple.elems.iter().map(impl_deserialize_field).collect();
103            quote! { (#(#compressed_fields)*), }
104        },
105        _ => {
106            quote! { CanonicalDeserialize::deserialize_with_mode(&mut reader, compress, validate)?, }
107        },
108    }
109}
110
111pub(super) fn impl_canonical_deserialize(ast: &syn::DeriveInput) -> TokenStream {
112    let valid_impl = impl_valid(ast);
113    let name = &ast.ident;
114
115    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
116
117    let deserialize_body;
118
119    match ast.data {
120        Data::Struct(ref data_struct) => {
121            let mut field_cases = Vec::<TokenStream>::with_capacity(data_struct.fields.len());
122            let mut tuple = false;
123            for field in data_struct.fields.iter() {
124                match &field.ident {
125                    None => {
126                        tuple = true;
127                        let compressed = impl_deserialize_field(&field.ty);
128                        field_cases.push(compressed);
129                    },
130                    // struct field without len_type
131                    Some(ident) => {
132                        let compressed = impl_deserialize_field(&field.ty);
133                        field_cases.push(quote! { #ident: #compressed });
134                    },
135                }
136            }
137
138            deserialize_body = if tuple {
139                quote!({
140                    Ok(#name (
141                        #(#field_cases)*
142                     ))
143                })
144            } else {
145                quote!({
146                    Ok(#name {
147                        #(#field_cases)*
148                    })
149                })
150            };
151        },
152        _ => panic!(
153            "`CanonicalDeserialize` can only be derived for structs, {} is not a Struct",
154            name
155        ),
156    };
157
158    let mut gen = quote! {
159        impl #impl_generics CanonicalDeserialize for #name #ty_generics #where_clause {
160            #[allow(unused_mut,unused_variables)]
161            fn deserialize_with_mode<R: ark_serialize::Read>(
162                mut reader: R,
163                compress: ark_serialize::Compress,
164                validate: ark_serialize::Validate,
165            ) -> Result<Self, ark_serialize::SerializationError> {
166                #deserialize_body
167            }
168        }
169    };
170    gen.extend(valid_impl);
171    gen
172}