1#![cfg_attr(docsrs, feature(doc_cfg))]
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use quote::quote;
6use syn::{parse_macro_input, parse_quote, Data, DeriveInput, Fields, Type};
7
8fn generate_encoding_impl(input: &DeriveInput) -> TokenStream2 {
9 let name = &input.ident;
10
11 let encoding_impl = match &input.data {
12 Data::Struct(data) => {
13 let mut encoding_bounds = Vec::new();
14 let field_encodings = match &data.fields {
15 Fields::Named(fields) => fields
16 .named
17 .iter()
18 .filter_map(|f| {
19 if has_skip_attribute(&f.attrs) {
20 return None;
21 }
22 let field_name = &f.ident;
23 encoding_bounds.push(f.ty.clone());
24 Some(quote! {
25 output.extend_from_slice(self.#field_name.encode().as_ref());
26 })
27 })
28 .collect::<Vec<_>>(),
29 Fields::Unnamed(fields) => fields
30 .unnamed
31 .iter()
32 .enumerate()
33 .filter_map(|(i, f)| {
34 if has_skip_attribute(&f.attrs) {
35 return None;
36 }
37 let index = syn::Index::from(i);
38 encoding_bounds.push(f.ty.clone());
39 Some(quote! {
40 output.extend_from_slice(self.#index.encode().as_ref());
41 })
42 })
43 .collect::<Vec<_>>(),
44 Fields::Unit => vec![],
45 };
46
47 let bound = quote!(::spongefish::Encoding<[u8]>);
48 let generics =
49 add_trait_bounds_for_fields(input.generics.clone(), &encoding_bounds, &bound);
50 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
51
52 quote! {
53 impl #impl_generics ::spongefish::Encoding<[u8]> for #name #ty_generics #where_clause {
54 fn encode(&self) -> impl AsRef<[u8]> {
55 let mut output = ::std::vec::Vec::new();
56 #(#field_encodings)*
57 output
58 }
59 }
60 }
61 }
62 _ => panic!("Encoding can only be derived for structs"),
63 };
64
65 encoding_impl
66}
67
68fn generate_decoding_impl(input: &DeriveInput) -> TokenStream2 {
69 let name = &input.ident;
70
71 let decoding_impl = match &input.data {
72 Data::Struct(data) => {
73 let mut decoding_bounds = Vec::new();
74 let (size_calc, field_decodings) = match &data.fields {
75 Fields::Named(fields) => {
76 let mut offset = quote!(0usize);
77 let mut field_decodings = vec![];
78 let mut size_components = vec![];
79
80 for field in fields.named.iter() {
81 if has_skip_attribute(&field.attrs) {
82 let field_name = &field.ident;
83 field_decodings.push(quote! {
84 #field_name: Default::default(),
85 });
86 continue;
87 }
88
89 let field_name = &field.ident;
90 let field_type = &field.ty;
91 decoding_bounds.push(field_type.clone());
92
93 size_components.push(quote! {
94 ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>()
95 });
96
97 let current_offset = offset.clone();
98 field_decodings.push(quote! {
99 #field_name: {
100 let field_size = ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>();
101 let start = #current_offset;
102 let end = start + field_size;
103 let mut field_buf = <#field_type as spongefish::Decoding<[u8]>>::Repr::default();
104 field_buf.as_mut().copy_from_slice(&buf.as_ref()[start..end]);
105 <#field_type as spongefish::Decoding<[u8]>>::decode(field_buf)
106 },
107 });
108
109 offset = quote! {
110 #offset + <#field_type as spongefish::Decoding<[u8]>>::Repr::default().as_mut().len()
111 };
112 }
113
114 let size_calc = if size_components.is_empty() {
115 quote!(0usize)
116 } else {
117 quote!(#(#size_components)+*)
118 };
119
120 (
121 size_calc,
122 quote! {
123 Self {
124 #(#field_decodings)*
125 }
126 },
127 )
128 }
129 Fields::Unnamed(fields) => {
130 let mut offset = quote!(0usize);
131 let mut field_decodings = vec![];
132 let mut size_components = vec![];
133
134 for field in fields.unnamed.iter() {
135 if has_skip_attribute(&field.attrs) {
136 field_decodings.push(quote! {
137 Default::default(),
138 });
139 continue;
140 }
141
142 let field_type = &field.ty;
143 decoding_bounds.push(field_type.clone());
144
145 size_components.push(quote! {
146 ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>()
147 });
148
149 let current_offset = offset.clone();
150 field_decodings.push(quote! {
151 {
152 let field_size = ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>();
153 let start = #current_offset;
154 let end = start + field_size;
155 let mut field_buf = <#field_type as spongefish::Decoding<[u8]>>::Repr::default();
156 field_buf.as_mut().copy_from_slice(&buf.as_ref()[start..end]);
157 <#field_type as spongefish::Decoding<[u8]>>::decode(field_buf)
158 },
159 });
160
161 offset = quote! {
162 #offset + <#field_type as spongefish::Decoding<[u8]>>::Repr::default().as_mut().len()
163 };
164 }
165
166 let size_calc = if size_components.is_empty() {
167 quote!(0usize)
168 } else {
169 quote!(#(#size_components)+*)
170 };
171
172 (
173 size_calc,
174 quote! {
175 Self(#(#field_decodings)*)
176 },
177 )
178 }
179 Fields::Unit => (quote!(0usize), quote!(Self)),
180 };
181
182 let bound = quote!(::spongefish::Decoding<[u8]>);
183 let generics =
184 add_trait_bounds_for_fields(input.generics.clone(), &decoding_bounds, &bound);
185 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
186
187 quote! {
188 impl #impl_generics ::spongefish::Decoding<[u8]> for #name #ty_generics #where_clause {
189 type Repr = spongefish::ByteArray<{ #size_calc }>;
190
191 fn decode(buf: Self::Repr) -> Self {
192 #field_decodings
193 }
194 }
195 }
196 }
197 _ => panic!("Decoding can only be derived for structs"),
198 };
199
200 decoding_impl
201}
202
203fn generate_narg_deserialize_impl(input: &DeriveInput) -> TokenStream2 {
204 let name = &input.ident;
205
206 let deserialize_impl = match &input.data {
207 Data::Struct(data) => {
208 let mut deserialize_bounds = Vec::new();
209 let field_deserializations = match &data.fields {
210 Fields::Named(fields) => {
211 let field_inits = fields.named.iter().map(|f| {
212 let field_name = &f.ident;
213 let field_type = &f.ty;
214
215 if has_skip_attribute(&f.attrs) {
216 quote! {
217 #field_name: Default::default(),
218 }
219 } else {
220 deserialize_bounds.push(field_type.clone());
221 quote! {
222 #field_name: <#field_type as spongefish::NargDeserialize>::deserialize_from_narg(&mut rest)?,
223 }
224 }
225 });
226
227 quote! {
228 Ok(Self {
229 #(#field_inits)*
230 })
231 }
232 }
233 Fields::Unnamed(fields) => {
234 let field_inits = fields.unnamed.iter().map(|f| {
235 let field_type = &f.ty;
236
237 if has_skip_attribute(&f.attrs) {
238 quote! {
239 Default::default(),
240 }
241 } else {
242 deserialize_bounds.push(field_type.clone());
243 quote! {
244 <#field_type as spongefish::NargDeserialize>::deserialize_from_narg(&mut rest)?,
245 }
246 }
247 });
248
249 quote! {
250 Ok(Self(#(#field_inits)*))
251 }
252 }
253 Fields::Unit => quote! {
254 Ok(Self)
255 },
256 };
257
258 let bound = quote!(::spongefish::NargDeserialize);
259 let generics =
260 add_trait_bounds_for_fields(input.generics.clone(), &deserialize_bounds, &bound);
261 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
262
263 quote! {
264 impl #impl_generics ::spongefish::NargDeserialize for #name #ty_generics #where_clause {
265 fn deserialize_from_narg(buf: &mut &[u8]) -> spongefish::VerificationResult<Self> {
266 let mut rest = *buf;
267 let value = (|| -> spongefish::VerificationResult<Self> {
268 #field_deserializations
269 })()?;
270 *buf = rest;
271 Ok(value)
272 }
273 }
274 }
275 }
276 _ => panic!("NargDeserialize can only be derived for structs"),
277 };
278
279 deserialize_impl
280}
281
282#[proc_macro_derive(Encoding, attributes(spongefish))]
303pub fn derive_encoding(input: TokenStream) -> TokenStream {
304 let input = parse_macro_input!(input as DeriveInput);
305 TokenStream::from(generate_encoding_impl(&input))
306}
307
308#[proc_macro_derive(Decoding, attributes(spongefish))]
313pub fn derive_decoding(input: TokenStream) -> TokenStream {
314 let input = parse_macro_input!(input as DeriveInput);
315 TokenStream::from(generate_decoding_impl(&input))
316}
317
318#[proc_macro_derive(NargDeserialize, attributes(spongefish))]
323pub fn derive_narg_deserialize(input: TokenStream) -> TokenStream {
324 let input = parse_macro_input!(input as DeriveInput);
325 TokenStream::from(generate_narg_deserialize_impl(&input))
326}
327
328#[proc_macro_derive(Codec, attributes(spongefish))]
332pub fn derive_codec(input: TokenStream) -> TokenStream {
333 let input = parse_macro_input!(input as DeriveInput);
334 let encoding = generate_encoding_impl(&input);
335 let decoding = generate_decoding_impl(&input);
336 let deserialize = generate_narg_deserialize_impl(&input);
337
338 TokenStream::from(quote! {
339 #encoding
340 #decoding
341 #deserialize
342 })
343}
344
345#[proc_macro_derive(Unit, attributes(spongefish))]
364pub fn derive_unit(input: TokenStream) -> TokenStream {
365 let input = parse_macro_input!(input as DeriveInput);
366 let name = input.ident;
367 let mut generics = input.generics;
368
369 let (zero_expr, unit_bounds) = match input.data {
370 Data::Struct(data) => match data.fields {
371 Fields::Named(fields) => {
372 let mut zero_fields = Vec::new();
373 let mut unit_bounds = Vec::new();
374
375 for field in fields.named.iter() {
376 let field_name = &field.ident;
377
378 if has_skip_attribute(&field.attrs) {
379 zero_fields.push(quote! {
380 #field_name: ::core::default::Default::default(),
381 });
382 continue;
383 }
384
385 let ty: Type = field.ty.clone();
386 unit_bounds.push(ty.clone());
387 zero_fields.push(quote! {
388 #field_name: <#ty as ::spongefish::Unit>::ZERO,
389 });
390 }
391
392 (
393 quote! {
394 Self {
395 #(#zero_fields)*
396 }
397 },
398 unit_bounds,
399 )
400 }
401 Fields::Unnamed(fields) => {
402 let mut zero_fields = Vec::new();
403 let mut unit_bounds = Vec::new();
404
405 for field in fields.unnamed.iter() {
406 if has_skip_attribute(&field.attrs) {
407 zero_fields.push(quote! {
408 ::core::default::Default::default()
409 });
410 continue;
411 }
412
413 let ty: Type = field.ty.clone();
414 unit_bounds.push(ty.clone());
415 zero_fields.push(quote! {
416 <#ty as ::spongefish::Unit>::ZERO
417 });
418 }
419
420 (
421 quote! {
422 Self(#(#zero_fields),*)
423 },
424 unit_bounds,
425 )
426 }
427 Fields::Unit => (quote!(Self), Vec::new()),
428 },
429 _ => panic!("Unit can only be derived for structs"),
430 };
431
432 let where_clause = generics.make_where_clause();
433 for ty in unit_bounds {
434 where_clause
435 .predicates
436 .push(parse_quote!(#ty: ::spongefish::Unit));
437 }
438
439 let (impl_generics, ty_generics, where_generics) = generics.split_for_impl();
440
441 let expanded = quote! {
442 impl #impl_generics ::spongefish::Unit for #name #ty_generics #where_generics {
443 const ZERO: Self = #zero_expr;
444 }
445 };
446
447 TokenStream::from(expanded)
448}
449
450fn has_skip_attribute(attrs: &[syn::Attribute]) -> bool {
452 attrs.iter().any(|attr| {
453 if !attr.path().is_ident("spongefish") {
454 return false;
455 }
456
457 attr.parse_nested_meta(|meta| {
458 if meta.path.is_ident("skip") {
459 Ok(())
460 } else {
461 Err(meta.error("expected `skip`"))
462 }
463 })
464 .is_ok()
465 })
466}
467
468fn add_trait_bounds_for_fields(
469 mut generics: syn::Generics,
470 field_types: &[Type],
471 trait_bound: &TokenStream2,
472) -> syn::Generics {
473 if field_types.is_empty() {
474 return generics;
475 }
476
477 let where_clause = generics.make_where_clause();
478 for ty in field_types {
479 where_clause
480 .predicates
481 .push(parse_quote!(#ty: #trait_bound));
482 }
483
484 generics
485}