spongefish_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{parse_macro_input, parse_quote, Data, DeriveInput, Fields, Type};
5
6fn generate_encoding_impl(input: &DeriveInput) -> TokenStream2 {
7    let name = &input.ident;
8
9    let encoding_impl = match &input.data {
10        Data::Struct(data) => {
11            let mut encoding_bounds = Vec::new();
12            let field_encodings = match &data.fields {
13                Fields::Named(fields) => fields
14                    .named
15                    .iter()
16                    .filter_map(|f| {
17                        if has_skip_attribute(&f.attrs) {
18                            return None;
19                        }
20                        let field_name = &f.ident;
21                        encoding_bounds.push(f.ty.clone());
22                        Some(quote! {
23                            output.extend_from_slice(self.#field_name.encode().as_ref());
24                        })
25                    })
26                    .collect::<Vec<_>>(),
27                Fields::Unnamed(fields) => fields
28                    .unnamed
29                    .iter()
30                    .enumerate()
31                    .filter_map(|(i, f)| {
32                        if has_skip_attribute(&f.attrs) {
33                            return None;
34                        }
35                        let index = syn::Index::from(i);
36                        encoding_bounds.push(f.ty.clone());
37                        Some(quote! {
38                            output.extend_from_slice(self.#index.encode().as_ref());
39                        })
40                    })
41                    .collect::<Vec<_>>(),
42                Fields::Unit => vec![],
43            };
44
45            let bound = quote!(::spongefish::Encoding<[u8]>);
46            let generics =
47                add_trait_bounds_for_fields(input.generics.clone(), &encoding_bounds, &bound);
48            let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
49
50            quote! {
51                impl #impl_generics ::spongefish::Encoding<[u8]> for #name #ty_generics #where_clause {
52                    fn encode(&self) -> impl AsRef<[u8]> {
53                        let mut output = ::std::vec::Vec::new();
54                        #(#field_encodings)*
55                        output
56                    }
57                }
58            }
59        }
60        _ => panic!("Encoding can only be derived for structs"),
61    };
62
63    encoding_impl
64}
65
66fn generate_decoding_impl(input: &DeriveInput) -> TokenStream2 {
67    let name = &input.ident;
68
69    let decoding_impl = match &input.data {
70        Data::Struct(data) => {
71            let mut decoding_bounds = Vec::new();
72            let (size_calc, field_decodings) = match &data.fields {
73                Fields::Named(fields) => {
74                    let mut offset = quote!(0usize);
75                    let mut field_decodings = vec![];
76                    let mut size_components = vec![];
77
78                    for field in fields.named.iter() {
79                        if has_skip_attribute(&field.attrs) {
80                            let field_name = &field.ident;
81                            field_decodings.push(quote! {
82                                #field_name: Default::default(),
83                            });
84                            continue;
85                        }
86
87                        let field_name = &field.ident;
88                        let field_type = &field.ty;
89                        decoding_bounds.push(field_type.clone());
90
91                        size_components.push(quote! {
92                            ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>()
93                        });
94
95                        let current_offset = offset.clone();
96                        field_decodings.push(quote! {
97                            #field_name: {
98                                let field_size = ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>();
99                                let start = #current_offset;
100                                let end = start + field_size;
101                                let mut field_buf = <#field_type as spongefish::Decoding<[u8]>>::Repr::default();
102                                field_buf.as_mut().copy_from_slice(&buf.as_ref()[start..end]);
103                                <#field_type as spongefish::Decoding<[u8]>>::decode(field_buf)
104                            },
105                        });
106
107                        offset = quote! {
108                            #offset + <#field_type as spongefish::Decoding<[u8]>>::Repr::default().as_mut().len()
109                        };
110                    }
111
112                    let size_calc = if size_components.is_empty() {
113                        quote!(0usize)
114                    } else {
115                        quote!(#(#size_components)+*)
116                    };
117
118                    (
119                        size_calc,
120                        quote! {
121                            Self {
122                                #(#field_decodings)*
123                            }
124                        },
125                    )
126                }
127                Fields::Unnamed(fields) => {
128                    let mut offset = quote!(0usize);
129                    let mut field_decodings = vec![];
130                    let mut size_components = vec![];
131
132                    for field in fields.unnamed.iter() {
133                        if has_skip_attribute(&field.attrs) {
134                            field_decodings.push(quote! {
135                                Default::default(),
136                            });
137                            continue;
138                        }
139
140                        let field_type = &field.ty;
141                        decoding_bounds.push(field_type.clone());
142
143                        size_components.push(quote! {
144                            ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>()
145                        });
146
147                        let current_offset = offset.clone();
148                        field_decodings.push(quote! {
149                            {
150                                let field_size = ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>();
151                                let start = #current_offset;
152                                let end = start + field_size;
153                                let mut field_buf = <#field_type as spongefish::Decoding<[u8]>>::Repr::default();
154                                field_buf.as_mut().copy_from_slice(&buf.as_ref()[start..end]);
155                                <#field_type as spongefish::Decoding<[u8]>>::decode(field_buf)
156                            },
157                        });
158
159                        offset = quote! {
160                            #offset + <#field_type as spongefish::Decoding<[u8]>>::Repr::default().as_mut().len()
161                        };
162                    }
163
164                    let size_calc = if size_components.is_empty() {
165                        quote!(0usize)
166                    } else {
167                        quote!(#(#size_components)+*)
168                    };
169
170                    (
171                        size_calc,
172                        quote! {
173                            Self(#(#field_decodings)*)
174                        },
175                    )
176                }
177                Fields::Unit => (quote!(0usize), quote!(Self)),
178            };
179
180            let bound = quote!(::spongefish::Decoding<[u8]>);
181            let generics =
182                add_trait_bounds_for_fields(input.generics.clone(), &decoding_bounds, &bound);
183            let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
184
185            quote! {
186                impl #impl_generics ::spongefish::Decoding<[u8]> for #name #ty_generics #where_clause {
187                    type Repr = spongefish::ByteArray<{ #size_calc }>;
188
189                    fn decode(buf: Self::Repr) -> Self {
190                        #field_decodings
191                    }
192                }
193            }
194        }
195        _ => panic!("Decoding can only be derived for structs"),
196    };
197
198    decoding_impl
199}
200
201fn generate_narg_deserialize_impl(input: &DeriveInput) -> TokenStream2 {
202    let name = &input.ident;
203
204    let deserialize_impl = match &input.data {
205        Data::Struct(data) => {
206            let mut deserialize_bounds = Vec::new();
207            let field_deserializations = match &data.fields {
208                Fields::Named(fields) => {
209                    let field_inits = fields.named.iter().map(|f| {
210                        let field_name = &f.ident;
211                        let field_type = &f.ty;
212
213                        if has_skip_attribute(&f.attrs) {
214                            quote! {
215                                #field_name: Default::default(),
216                            }
217                        } else {
218                            deserialize_bounds.push(field_type.clone());
219                            quote! {
220                                #field_name: <#field_type as spongefish::NargDeserialize>::deserialize_from_narg(buf)?,
221                            }
222                        }
223                    });
224
225                    quote! {
226                        Ok(Self {
227                            #(#field_inits)*
228                        })
229                    }
230                }
231                Fields::Unnamed(fields) => {
232                    let field_inits = fields.unnamed.iter().map(|f| {
233                        let field_type = &f.ty;
234
235                        if has_skip_attribute(&f.attrs) {
236                            quote! {
237                                Default::default(),
238                            }
239                        } else {
240                            deserialize_bounds.push(field_type.clone());
241                            quote! {
242                                <#field_type as spongefish::NargDeserialize>::deserialize_from_narg(buf)?,
243                            }
244                        }
245                    });
246
247                    quote! {
248                        Ok(Self(#(#field_inits)*))
249                    }
250                }
251                Fields::Unit => quote! {
252                    Ok(Self)
253                },
254            };
255
256            let bound = quote!(::spongefish::NargDeserialize);
257            let generics =
258                add_trait_bounds_for_fields(input.generics.clone(), &deserialize_bounds, &bound);
259            let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
260
261            quote! {
262                impl #impl_generics ::spongefish::NargDeserialize for #name #ty_generics #where_clause {
263                    fn deserialize_from_narg(buf: &mut &[u8]) -> spongefish::VerificationResult<Self> {
264                        #field_deserializations
265                    }
266                }
267            }
268        }
269        _ => panic!("NargDeserialize can only be derived for structs"),
270    };
271
272    deserialize_impl
273}
274
275/// Derive [`Encoding`](https://docs.rs/spongefish/latest/spongefish/trait.Encoding.html) for structs.
276///
277/// Skipped fields fall back to `Default`.
278///
279/// ```
280/// use spongefish::Encoding;
281/// # use spongefish_derive::Encoding;
282///
283/// #[derive(Encoding)]
284/// struct Rgb {
285///     r: u8,
286///     g: u8,
287///     b: u8,
288/// }
289///
290/// let colors = Rgb { r: 1, g: 2, b: 3 };
291/// let data = colors.encode();
292/// assert_eq!(data.as_ref(), [1, 2, 3]);
293///
294/// ```
295#[proc_macro_derive(Encoding, attributes(spongefish))]
296pub fn derive_encoding(input: TokenStream) -> TokenStream {
297    let input = parse_macro_input!(input as DeriveInput);
298    TokenStream::from(generate_encoding_impl(&input))
299}
300
301/// Derive macro for the [`Decoding`](https://docs.rs/spongefish/latest/spongefish/trait.Decoding.html) trait.
302///
303/// Generates an implementation that decodes struct fields sequentially from a fixed-size buffer.
304/// Fields can be skipped using `#[spongefish(skip)]`.
305#[proc_macro_derive(Decoding, attributes(spongefish))]
306pub fn derive_decoding(input: TokenStream) -> TokenStream {
307    let input = parse_macro_input!(input as DeriveInput);
308    TokenStream::from(generate_decoding_impl(&input))
309}
310
311/// Derive macro for the [`NargDeserialize`](https://docs.rs/spongefish/latest/spongefish/trait.NargDeserialize.html) trait.
312///
313/// Generates an implementation that deserializes struct fields sequentially from a byte buffer.
314/// Fields can be skipped using `#[spongefish(skip)]`.
315#[proc_macro_derive(NargDeserialize, attributes(spongefish))]
316pub fn derive_narg_deserialize(input: TokenStream) -> TokenStream {
317    let input = parse_macro_input!(input as DeriveInput);
318    TokenStream::from(generate_narg_deserialize_impl(&input))
319}
320
321/// Derive macro that generates [`Encoding`](https://docs.rs/spongefish/latest/spongefish/trait.Encoding.html),
322/// [`Decoding`](https://docs.rs/spongefish/latest/spongefish/trait.Decoding.html), and
323/// [`NargDeserialize`](https://docs.rs/spongefish/latest/spongefish/trait.NargDeserialize.html) in one go.
324#[proc_macro_derive(Codec, attributes(spongefish))]
325pub fn derive_codec(input: TokenStream) -> TokenStream {
326    let input = parse_macro_input!(input as DeriveInput);
327    let encoding = generate_encoding_impl(&input);
328    let decoding = generate_decoding_impl(&input);
329    let deserialize = generate_narg_deserialize_impl(&input);
330
331    TokenStream::from(quote! {
332        #encoding
333        #decoding
334        #deserialize
335    })
336}
337
338/// Derive [`Unit`]s for structs.
339///
340/// ```
341/// use spongefish::Unit;
342/// # use spongefish_derive::Unit;
343///
344/// #[derive(Clone, Unit)]
345/// struct Rgb {
346///     r: u8,
347///     g: u8,
348///     b: u8,
349/// }
350///
351/// assert_eq!((Rgb::ZERO.r, Rgb::ZERO.g, Rgb::ZERO.b), (0, 0, 0));
352///
353/// ```
354///
355/// [Unit]: https://docs.rs/spongefish/latest/spongefish/trait.Unit.html
356#[proc_macro_derive(Unit, attributes(spongefish))]
357pub fn derive_unit(input: TokenStream) -> TokenStream {
358    let input = parse_macro_input!(input as DeriveInput);
359    let name = input.ident;
360    let mut generics = input.generics;
361
362    let (zero_expr, unit_bounds) = match input.data {
363        Data::Struct(data) => match data.fields {
364            Fields::Named(fields) => {
365                let mut zero_fields = Vec::new();
366                let mut unit_bounds = Vec::new();
367
368                for field in fields.named.iter() {
369                    let field_name = &field.ident;
370
371                    if has_skip_attribute(&field.attrs) {
372                        zero_fields.push(quote! {
373                            #field_name: ::core::default::Default::default(),
374                        });
375                        continue;
376                    }
377
378                    let ty: Type = field.ty.clone();
379                    unit_bounds.push(ty.clone());
380                    zero_fields.push(quote! {
381                        #field_name: <#ty as ::spongefish::Unit>::ZERO,
382                    });
383                }
384
385                (
386                    quote! {
387                        Self {
388                            #(#zero_fields)*
389                        }
390                    },
391                    unit_bounds,
392                )
393            }
394            Fields::Unnamed(fields) => {
395                let mut zero_fields = Vec::new();
396                let mut unit_bounds = Vec::new();
397
398                for field in fields.unnamed.iter() {
399                    if has_skip_attribute(&field.attrs) {
400                        zero_fields.push(quote! {
401                            ::core::default::Default::default()
402                        });
403                        continue;
404                    }
405
406                    let ty: Type = field.ty.clone();
407                    unit_bounds.push(ty.clone());
408                    zero_fields.push(quote! {
409                        <#ty as ::spongefish::Unit>::ZERO
410                    });
411                }
412
413                (
414                    quote! {
415                        Self(#(#zero_fields),*)
416                    },
417                    unit_bounds,
418                )
419            }
420            Fields::Unit => (quote!(Self), Vec::new()),
421        },
422        _ => panic!("Unit can only be derived for structs"),
423    };
424
425    let where_clause = generics.make_where_clause();
426    for ty in unit_bounds {
427        where_clause
428            .predicates
429            .push(parse_quote!(#ty: ::spongefish::Unit));
430    }
431
432    let (impl_generics, ty_generics, where_generics) = generics.split_for_impl();
433
434    let expanded = quote! {
435        impl #impl_generics ::spongefish::Unit for #name #ty_generics #where_generics {
436            const ZERO: Self = #zero_expr;
437        }
438    };
439
440    TokenStream::from(expanded)
441}
442
443/// Helper function to check if a field has the #[spongefish(skip)] attribute
444fn has_skip_attribute(attrs: &[syn::Attribute]) -> bool {
445    attrs.iter().any(|attr| {
446        if !attr.path().is_ident("spongefish") {
447            return false;
448        }
449
450        attr.parse_nested_meta(|meta| {
451            if meta.path.is_ident("skip") {
452                Ok(())
453            } else {
454                Err(meta.error("expected `skip`"))
455            }
456        })
457        .is_ok()
458    })
459}
460
461fn add_trait_bounds_for_fields(
462    mut generics: syn::Generics,
463    field_types: &[Type],
464    trait_bound: &TokenStream2,
465) -> syn::Generics {
466    if field_types.is_empty() {
467        return generics;
468    }
469
470    let where_clause = generics.make_where_clause();
471    for ty in field_types {
472        where_clause
473            .predicates
474            .push(parse_quote!(#ty: #trait_bound));
475    }
476
477    generics
478}