Skip to main content

spongefish_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{parse_macro_input, parse_quote, Data, DeriveInput, Fields, Type};
5
6fn generate_encoding_impl(input: &DeriveInput) -> TokenStream2 {
7    let name = &input.ident;
8
9    let encoding_impl = match &input.data {
10        Data::Struct(data) => {
11            let mut encoding_bounds = Vec::new();
12            let field_encodings = match &data.fields {
13                Fields::Named(fields) => fields
14                    .named
15                    .iter()
16                    .filter_map(|f| {
17                        if has_skip_attribute(&f.attrs) {
18                            return None;
19                        }
20                        let field_name = &f.ident;
21                        encoding_bounds.push(f.ty.clone());
22                        Some(quote! {
23                            output.extend_from_slice(self.#field_name.encode().as_ref());
24                        })
25                    })
26                    .collect::<Vec<_>>(),
27                Fields::Unnamed(fields) => fields
28                    .unnamed
29                    .iter()
30                    .enumerate()
31                    .filter_map(|(i, f)| {
32                        if has_skip_attribute(&f.attrs) {
33                            return None;
34                        }
35                        let index = syn::Index::from(i);
36                        encoding_bounds.push(f.ty.clone());
37                        Some(quote! {
38                            output.extend_from_slice(self.#index.encode().as_ref());
39                        })
40                    })
41                    .collect::<Vec<_>>(),
42                Fields::Unit => vec![],
43            };
44
45            let bound = quote!(::spongefish::Encoding<[u8]>);
46            let generics =
47                add_trait_bounds_for_fields(input.generics.clone(), &encoding_bounds, &bound);
48            let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
49
50            quote! {
51                impl #impl_generics ::spongefish::Encoding<[u8]> for #name #ty_generics #where_clause {
52                    fn encode(&self) -> impl AsRef<[u8]> {
53                        let mut output = ::std::vec::Vec::new();
54                        #(#field_encodings)*
55                        output
56                    }
57                }
58            }
59        }
60        _ => panic!("Encoding can only be derived for structs"),
61    };
62
63    encoding_impl
64}
65
66fn generate_decoding_impl(input: &DeriveInput) -> TokenStream2 {
67    let name = &input.ident;
68
69    let decoding_impl = match &input.data {
70        Data::Struct(data) => {
71            let mut decoding_bounds = Vec::new();
72            let (size_calc, field_decodings) = match &data.fields {
73                Fields::Named(fields) => {
74                    let mut offset = quote!(0usize);
75                    let mut field_decodings = vec![];
76                    let mut size_components = vec![];
77
78                    for field in fields.named.iter() {
79                        if has_skip_attribute(&field.attrs) {
80                            let field_name = &field.ident;
81                            field_decodings.push(quote! {
82                                #field_name: Default::default(),
83                            });
84                            continue;
85                        }
86
87                        let field_name = &field.ident;
88                        let field_type = &field.ty;
89                        decoding_bounds.push(field_type.clone());
90
91                        size_components.push(quote! {
92                            ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>()
93                        });
94
95                        let current_offset = offset.clone();
96                        field_decodings.push(quote! {
97                            #field_name: {
98                                let field_size = ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>();
99                                let start = #current_offset;
100                                let end = start + field_size;
101                                let mut field_buf = <#field_type as spongefish::Decoding<[u8]>>::Repr::default();
102                                field_buf.as_mut().copy_from_slice(&buf.as_ref()[start..end]);
103                                <#field_type as spongefish::Decoding<[u8]>>::decode(field_buf)
104                            },
105                        });
106
107                        offset = quote! {
108                            #offset + <#field_type as spongefish::Decoding<[u8]>>::Repr::default().as_mut().len()
109                        };
110                    }
111
112                    let size_calc = if size_components.is_empty() {
113                        quote!(0usize)
114                    } else {
115                        quote!(#(#size_components)+*)
116                    };
117
118                    (
119                        size_calc,
120                        quote! {
121                            Self {
122                                #(#field_decodings)*
123                            }
124                        },
125                    )
126                }
127                Fields::Unnamed(fields) => {
128                    let mut offset = quote!(0usize);
129                    let mut field_decodings = vec![];
130                    let mut size_components = vec![];
131
132                    for field in fields.unnamed.iter() {
133                        if has_skip_attribute(&field.attrs) {
134                            field_decodings.push(quote! {
135                                Default::default(),
136                            });
137                            continue;
138                        }
139
140                        let field_type = &field.ty;
141                        decoding_bounds.push(field_type.clone());
142
143                        size_components.push(quote! {
144                            ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>()
145                        });
146
147                        let current_offset = offset.clone();
148                        field_decodings.push(quote! {
149                            {
150                                let field_size = ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>();
151                                let start = #current_offset;
152                                let end = start + field_size;
153                                let mut field_buf = <#field_type as spongefish::Decoding<[u8]>>::Repr::default();
154                                field_buf.as_mut().copy_from_slice(&buf.as_ref()[start..end]);
155                                <#field_type as spongefish::Decoding<[u8]>>::decode(field_buf)
156                            },
157                        });
158
159                        offset = quote! {
160                            #offset + <#field_type as spongefish::Decoding<[u8]>>::Repr::default().as_mut().len()
161                        };
162                    }
163
164                    let size_calc = if size_components.is_empty() {
165                        quote!(0usize)
166                    } else {
167                        quote!(#(#size_components)+*)
168                    };
169
170                    (
171                        size_calc,
172                        quote! {
173                            Self(#(#field_decodings)*)
174                        },
175                    )
176                }
177                Fields::Unit => (quote!(0usize), quote!(Self)),
178            };
179
180            let bound = quote!(::spongefish::Decoding<[u8]>);
181            let generics =
182                add_trait_bounds_for_fields(input.generics.clone(), &decoding_bounds, &bound);
183            let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
184
185            quote! {
186                impl #impl_generics ::spongefish::Decoding<[u8]> for #name #ty_generics #where_clause {
187                    type Repr = spongefish::ByteArray<{ #size_calc }>;
188
189                    fn decode(buf: Self::Repr) -> Self {
190                        #field_decodings
191                    }
192                }
193            }
194        }
195        _ => panic!("Decoding can only be derived for structs"),
196    };
197
198    decoding_impl
199}
200
201fn generate_narg_deserialize_impl(input: &DeriveInput) -> TokenStream2 {
202    let name = &input.ident;
203
204    let deserialize_impl = match &input.data {
205        Data::Struct(data) => {
206            let mut deserialize_bounds = Vec::new();
207            let field_deserializations = match &data.fields {
208                Fields::Named(fields) => {
209                    let field_inits = fields.named.iter().map(|f| {
210                        let field_name = &f.ident;
211                        let field_type = &f.ty;
212
213                        if has_skip_attribute(&f.attrs) {
214                            quote! {
215                                #field_name: Default::default(),
216                            }
217                        } else {
218                            deserialize_bounds.push(field_type.clone());
219                            quote! {
220                                #field_name: <#field_type as spongefish::NargDeserialize>::deserialize_from_narg(&mut rest)?,
221                            }
222                        }
223                    });
224
225                    quote! {
226                        Ok(Self {
227                            #(#field_inits)*
228                        })
229                    }
230                }
231                Fields::Unnamed(fields) => {
232                    let field_inits = fields.unnamed.iter().map(|f| {
233                        let field_type = &f.ty;
234
235                        if has_skip_attribute(&f.attrs) {
236                            quote! {
237                                Default::default(),
238                            }
239                        } else {
240                            deserialize_bounds.push(field_type.clone());
241                            quote! {
242                                <#field_type as spongefish::NargDeserialize>::deserialize_from_narg(&mut rest)?,
243                            }
244                        }
245                    });
246
247                    quote! {
248                        Ok(Self(#(#field_inits)*))
249                    }
250                }
251                Fields::Unit => quote! {
252                    Ok(Self)
253                },
254            };
255
256            let bound = quote!(::spongefish::NargDeserialize);
257            let generics =
258                add_trait_bounds_for_fields(input.generics.clone(), &deserialize_bounds, &bound);
259            let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
260
261            quote! {
262                impl #impl_generics ::spongefish::NargDeserialize for #name #ty_generics #where_clause {
263                    fn deserialize_from_narg(buf: &mut &[u8]) -> spongefish::VerificationResult<Self> {
264                        let mut rest = *buf;
265                        let value = (|| -> spongefish::VerificationResult<Self> {
266                            #field_deserializations
267                        })()?;
268                        *buf = rest;
269                        Ok(value)
270                    }
271                }
272            }
273        }
274        _ => panic!("NargDeserialize can only be derived for structs"),
275    };
276
277    deserialize_impl
278}
279
280/// Derive [`Encoding`](https://docs.rs/spongefish/latest/spongefish/trait.Encoding.html) for structs.
281///
282/// Skipped fields fall back to `Default`.
283///
284/// ```
285/// use spongefish::Encoding;
286/// # use spongefish_derive::Encoding;
287///
288/// #[derive(Encoding)]
289/// struct Rgb {
290///     r: u8,
291///     g: u8,
292///     b: u8,
293/// }
294///
295/// let colors = Rgb { r: 1, g: 2, b: 3 };
296/// let data = colors.encode();
297/// assert_eq!(data.as_ref(), [1, 2, 3]);
298///
299/// ```
300#[proc_macro_derive(Encoding, attributes(spongefish))]
301pub fn derive_encoding(input: TokenStream) -> TokenStream {
302    let input = parse_macro_input!(input as DeriveInput);
303    TokenStream::from(generate_encoding_impl(&input))
304}
305
306/// Derive macro for the [`Decoding`](https://docs.rs/spongefish/latest/spongefish/trait.Decoding.html) trait.
307///
308/// Generates an implementation that decodes struct fields sequentially from a fixed-size buffer.
309/// Fields can be skipped using `#[spongefish(skip)]`.
310#[proc_macro_derive(Decoding, attributes(spongefish))]
311pub fn derive_decoding(input: TokenStream) -> TokenStream {
312    let input = parse_macro_input!(input as DeriveInput);
313    TokenStream::from(generate_decoding_impl(&input))
314}
315
316/// Derive macro for the [`NargDeserialize`](https://docs.rs/spongefish/latest/spongefish/trait.NargDeserialize.html) trait.
317///
318/// Generates an implementation that deserializes struct fields sequentially from a byte buffer.
319/// Fields can be skipped using `#[spongefish(skip)]`.
320#[proc_macro_derive(NargDeserialize, attributes(spongefish))]
321pub fn derive_narg_deserialize(input: TokenStream) -> TokenStream {
322    let input = parse_macro_input!(input as DeriveInput);
323    TokenStream::from(generate_narg_deserialize_impl(&input))
324}
325
326/// Derive macro that generates [`Encoding`](https://docs.rs/spongefish/latest/spongefish/trait.Encoding.html),
327/// [`Decoding`](https://docs.rs/spongefish/latest/spongefish/trait.Decoding.html), and
328/// [`NargDeserialize`](https://docs.rs/spongefish/latest/spongefish/trait.NargDeserialize.html) in one go.
329#[proc_macro_derive(Codec, attributes(spongefish))]
330pub fn derive_codec(input: TokenStream) -> TokenStream {
331    let input = parse_macro_input!(input as DeriveInput);
332    let encoding = generate_encoding_impl(&input);
333    let decoding = generate_decoding_impl(&input);
334    let deserialize = generate_narg_deserialize_impl(&input);
335
336    TokenStream::from(quote! {
337        #encoding
338        #decoding
339        #deserialize
340    })
341}
342
343/// Derive [`Unit`]s for structs.
344///
345/// ```
346/// use spongefish::Unit;
347/// # use spongefish_derive::Unit;
348///
349/// #[derive(Clone, Unit)]
350/// struct Rgb {
351///     r: u8,
352///     g: u8,
353///     b: u8,
354/// }
355///
356/// assert_eq!((Rgb::ZERO.r, Rgb::ZERO.g, Rgb::ZERO.b), (0, 0, 0));
357///
358/// ```
359///
360/// [Unit]: https://docs.rs/spongefish/latest/spongefish/trait.Unit.html
361#[proc_macro_derive(Unit, attributes(spongefish))]
362pub fn derive_unit(input: TokenStream) -> TokenStream {
363    let input = parse_macro_input!(input as DeriveInput);
364    let name = input.ident;
365    let mut generics = input.generics;
366
367    let (zero_expr, unit_bounds) = match input.data {
368        Data::Struct(data) => match data.fields {
369            Fields::Named(fields) => {
370                let mut zero_fields = Vec::new();
371                let mut unit_bounds = Vec::new();
372
373                for field in fields.named.iter() {
374                    let field_name = &field.ident;
375
376                    if has_skip_attribute(&field.attrs) {
377                        zero_fields.push(quote! {
378                            #field_name: ::core::default::Default::default(),
379                        });
380                        continue;
381                    }
382
383                    let ty: Type = field.ty.clone();
384                    unit_bounds.push(ty.clone());
385                    zero_fields.push(quote! {
386                        #field_name: <#ty as ::spongefish::Unit>::ZERO,
387                    });
388                }
389
390                (
391                    quote! {
392                        Self {
393                            #(#zero_fields)*
394                        }
395                    },
396                    unit_bounds,
397                )
398            }
399            Fields::Unnamed(fields) => {
400                let mut zero_fields = Vec::new();
401                let mut unit_bounds = Vec::new();
402
403                for field in fields.unnamed.iter() {
404                    if has_skip_attribute(&field.attrs) {
405                        zero_fields.push(quote! {
406                            ::core::default::Default::default()
407                        });
408                        continue;
409                    }
410
411                    let ty: Type = field.ty.clone();
412                    unit_bounds.push(ty.clone());
413                    zero_fields.push(quote! {
414                        <#ty as ::spongefish::Unit>::ZERO
415                    });
416                }
417
418                (
419                    quote! {
420                        Self(#(#zero_fields),*)
421                    },
422                    unit_bounds,
423                )
424            }
425            Fields::Unit => (quote!(Self), Vec::new()),
426        },
427        _ => panic!("Unit can only be derived for structs"),
428    };
429
430    let where_clause = generics.make_where_clause();
431    for ty in unit_bounds {
432        where_clause
433            .predicates
434            .push(parse_quote!(#ty: ::spongefish::Unit));
435    }
436
437    let (impl_generics, ty_generics, where_generics) = generics.split_for_impl();
438
439    let expanded = quote! {
440        impl #impl_generics ::spongefish::Unit for #name #ty_generics #where_generics {
441            const ZERO: Self = #zero_expr;
442        }
443    };
444
445    TokenStream::from(expanded)
446}
447
448/// Helper function to check if a field has the #[spongefish(skip)] attribute
449fn has_skip_attribute(attrs: &[syn::Attribute]) -> bool {
450    attrs.iter().any(|attr| {
451        if !attr.path().is_ident("spongefish") {
452            return false;
453        }
454
455        attr.parse_nested_meta(|meta| {
456            if meta.path.is_ident("skip") {
457                Ok(())
458            } else {
459                Err(meta.error("expected `skip`"))
460            }
461        })
462        .is_ok()
463    })
464}
465
466fn add_trait_bounds_for_fields(
467    mut generics: syn::Generics,
468    field_types: &[Type],
469    trait_bound: &TokenStream2,
470) -> syn::Generics {
471    if field_types.is_empty() {
472        return generics;
473    }
474
475    let where_clause = generics.make_where_clause();
476    for ty in field_types {
477        where_clause
478            .predicates
479            .push(parse_quote!(#ty: #trait_bound));
480    }
481
482    generics
483}