spongefish_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{parse_macro_input, parse_quote, Data, DeriveInput, Fields, Type};
5
6fn generate_encoding_impl(input: &DeriveInput) -> TokenStream2 {
7    let name = &input.ident;
8
9    let encoding_impl = match &input.data {
10        Data::Struct(data) => {
11            let mut encoding_bounds = Vec::new();
12            let field_encodings = match &data.fields {
13                Fields::Named(fields) => fields
14                    .named
15                    .iter()
16                    .filter_map(|f| {
17                        if has_skip_attribute(&f.attrs) {
18                            return None;
19                        }
20                        let field_name = &f.ident;
21                        encoding_bounds.push(f.ty.clone());
22                        Some(quote! {
23                            output.extend_from_slice(self.#field_name.encode().as_ref());
24                        })
25                    })
26                    .collect::<Vec<_>>(),
27                Fields::Unnamed(fields) => fields
28                    .unnamed
29                    .iter()
30                    .enumerate()
31                    .filter_map(|(i, f)| {
32                        if has_skip_attribute(&f.attrs) {
33                            return None;
34                        }
35                        let index = syn::Index::from(i);
36                        encoding_bounds.push(f.ty.clone());
37                        Some(quote! {
38                            output.extend_from_slice(self.#index.encode().as_ref());
39                        })
40                    })
41                    .collect::<Vec<_>>(),
42                Fields::Unit => vec![],
43            };
44
45            let bound = quote!(::spongefish::Encoding<[u8]>);
46            let generics = add_trait_bounds_for_fields(input.generics.clone(), &encoding_bounds, &bound);
47            let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
48
49            quote! {
50                impl #impl_generics ::spongefish::Encoding<[u8]> for #name #ty_generics #where_clause {
51                    fn encode(&self) -> impl AsRef<[u8]> {
52                        let mut output = ::std::vec::Vec::new();
53                        #(#field_encodings)*
54                        output
55                    }
56                }
57            }
58        }
59        _ => panic!("Encoding can only be derived for structs"),
60    };
61
62    encoding_impl
63}
64
65fn generate_decoding_impl(input: &DeriveInput) -> TokenStream2 {
66    let name = &input.ident;
67
68    let decoding_impl = match &input.data {
69        Data::Struct(data) => {
70            let mut decoding_bounds = Vec::new();
71            let (size_calc, field_decodings) = match &data.fields {
72                Fields::Named(fields) => {
73                    let mut offset = quote!(0usize);
74                    let mut field_decodings = vec![];
75                    let mut size_components = vec![];
76
77                    for field in fields.named.iter() {
78                        if has_skip_attribute(&field.attrs) {
79                            let field_name = &field.ident;
80                            field_decodings.push(quote! {
81                                #field_name: Default::default(),
82                            });
83                            continue;
84                        }
85
86                        let field_name = &field.ident;
87                        let field_type = &field.ty;
88                        decoding_bounds.push(field_type.clone());
89
90                        size_components.push(quote! {
91                            ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>()
92                        });
93
94                        let current_offset = offset.clone();
95                        field_decodings.push(quote! {
96                            #field_name: {
97                                let field_size = ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>();
98                                let start = #current_offset;
99                                let end = start + field_size;
100                                let mut field_buf = <#field_type as spongefish::Decoding<[u8]>>::Repr::default();
101                                field_buf.as_mut().copy_from_slice(&buf.as_ref()[start..end]);
102                                <#field_type as spongefish::Decoding<[u8]>>::decode(field_buf)
103                            },
104                        });
105
106                        offset = quote! {
107                            #offset + <#field_type as spongefish::Decoding<[u8]>>::Repr::default().as_mut().len()
108                        };
109                    }
110
111                    let size_calc = if size_components.is_empty() {
112                        quote!(0usize)
113                    } else {
114                        quote!(#(#size_components)+*)
115                    };
116
117                    (
118                        size_calc,
119                        quote! {
120                            Self {
121                                #(#field_decodings)*
122                            }
123                        },
124                    )
125                }
126                Fields::Unnamed(fields) => {
127                    let mut offset = quote!(0usize);
128                    let mut field_decodings = vec![];
129                    let mut size_components = vec![];
130
131                    for field in fields.unnamed.iter() {
132                        if has_skip_attribute(&field.attrs) {
133                            field_decodings.push(quote! {
134                                Default::default(),
135                            });
136                            continue;
137                        }
138
139                        let field_type = &field.ty;
140                        decoding_bounds.push(field_type.clone());
141
142                        size_components.push(quote! {
143                            ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>()
144                        });
145
146                        let current_offset = offset.clone();
147                        field_decodings.push(quote! {
148                            {
149                                let field_size = ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>();
150                                let start = #current_offset;
151                                let end = start + field_size;
152                                let mut field_buf = <#field_type as spongefish::Decoding<[u8]>>::Repr::default();
153                                field_buf.as_mut().copy_from_slice(&buf.as_ref()[start..end]);
154                                <#field_type as spongefish::Decoding<[u8]>>::decode(field_buf)
155                            },
156                        });
157
158                        offset = quote! {
159                            #offset + <#field_type as spongefish::Decoding<[u8]>>::Repr::default().as_mut().len()
160                        };
161                    }
162
163                    let size_calc = if size_components.is_empty() {
164                        quote!(0usize)
165                    } else {
166                        quote!(#(#size_components)+*)
167                    };
168
169                    (
170                        size_calc,
171                        quote! {
172                            Self(#(#field_decodings)*)
173                        },
174                    )
175                }
176                Fields::Unit => (quote!(0usize), quote!(Self)),
177            };
178
179            let bound = quote!(::spongefish::Decoding<[u8]>);
180            let generics = add_trait_bounds_for_fields(input.generics.clone(), &decoding_bounds, &bound);
181            let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
182
183            quote! {
184                impl #impl_generics ::spongefish::Decoding<[u8]> for #name #ty_generics #where_clause {
185                    type Repr = spongefish::ByteArray<{ #size_calc }>;
186
187                    fn decode(buf: Self::Repr) -> Self {
188                        #field_decodings
189                    }
190                }
191            }
192        }
193        _ => panic!("Decoding can only be derived for structs"),
194    };
195
196    decoding_impl
197}
198
199fn generate_narg_deserialize_impl(input: &DeriveInput) -> TokenStream2 {
200    let name = &input.ident;
201
202    let deserialize_impl = match &input.data {
203        Data::Struct(data) => {
204            let mut deserialize_bounds = Vec::new();
205            let field_deserializations = match &data.fields {
206                Fields::Named(fields) => {
207                    let field_inits = fields.named.iter().map(|f| {
208                        let field_name = &f.ident;
209                        let field_type = &f.ty;
210
211                        if has_skip_attribute(&f.attrs) {
212                            quote! {
213                                #field_name: Default::default(),
214                            }
215                        } else {
216                            deserialize_bounds.push(field_type.clone());
217                            quote! {
218                                #field_name: <#field_type as spongefish::NargDeserialize>::deserialize_from_narg(buf)?,
219                            }
220                        }
221                    });
222
223                    quote! {
224                        Ok(Self {
225                            #(#field_inits)*
226                        })
227                    }
228                }
229                Fields::Unnamed(fields) => {
230                    let field_inits = fields.unnamed.iter().map(|f| {
231                        let field_type = &f.ty;
232
233                        if has_skip_attribute(&f.attrs) {
234                            quote! {
235                                Default::default(),
236                            }
237                        } else {
238                            deserialize_bounds.push(field_type.clone());
239                            quote! {
240                                <#field_type as spongefish::NargDeserialize>::deserialize_from_narg(buf)?,
241                            }
242                        }
243                    });
244
245                    quote! {
246                        Ok(Self(#(#field_inits)*))
247                    }
248                }
249                Fields::Unit => quote! {
250                    Ok(Self)
251                },
252            };
253
254            let bound = quote!(::spongefish::NargDeserialize);
255            let generics = add_trait_bounds_for_fields(input.generics.clone(), &deserialize_bounds, &bound);
256            let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
257
258            quote! {
259                impl #impl_generics ::spongefish::NargDeserialize for #name #ty_generics #where_clause {
260                    fn deserialize_from_narg(buf: &mut &[u8]) -> spongefish::VerificationResult<Self> {
261                        #field_deserializations
262                    }
263                }
264            }
265        }
266        _ => panic!("NargDeserialize can only be derived for structs"),
267    };
268
269    deserialize_impl
270}
271
272/// Derive [`Encoding`](https://docs.rs/spongefish/latest/spongefish/trait.Encoding.html) for structs.
273///
274/// Skipped fields fall back to `Default`.
275///
276/// ```
277/// use spongefish::Encoding;
278/// # use spongefish_derive::Encoding;
279///
280/// #[derive(Encoding)]
281/// struct Rgb {
282///     r: u8,
283///     g: u8,
284///     b: u8,
285/// }
286///
287/// let colors = Rgb { r: 1, g: 2, b: 3 };
288/// let data = colors.encode();
289/// assert_eq!(data.as_ref(), [1, 2, 3]);
290///
291/// ```
292#[proc_macro_derive(Encoding, attributes(spongefish))]
293pub fn derive_encoding(input: TokenStream) -> TokenStream {
294    let input = parse_macro_input!(input as DeriveInput);
295    TokenStream::from(generate_encoding_impl(&input))
296}
297
298/// Derive macro for the [`Decoding`](https://docs.rs/spongefish/latest/spongefish/trait.Decoding.html) trait.
299///
300/// Generates an implementation that decodes struct fields sequentially from a fixed-size buffer.
301/// Fields can be skipped using `#[spongefish(skip)]`.
302#[proc_macro_derive(Decoding, attributes(spongefish))]
303pub fn derive_decoding(input: TokenStream) -> TokenStream {
304    let input = parse_macro_input!(input as DeriveInput);
305    TokenStream::from(generate_decoding_impl(&input))
306}
307
308/// Derive macro for the [`NargDeserialize`](https://docs.rs/spongefish/latest/spongefish/trait.NargDeserialize.html) trait.
309///
310/// Generates an implementation that deserializes struct fields sequentially from a byte buffer.
311/// Fields can be skipped using `#[spongefish(skip)]`.
312#[proc_macro_derive(NargDeserialize, attributes(spongefish))]
313pub fn derive_narg_deserialize(input: TokenStream) -> TokenStream {
314    let input = parse_macro_input!(input as DeriveInput);
315    TokenStream::from(generate_narg_deserialize_impl(&input))
316}
317
318/// Derive macro that generates [`Encoding`](https://docs.rs/spongefish/latest/spongefish/trait.Encoding.html),
319/// [`Decoding`](https://docs.rs/spongefish/latest/spongefish/trait.Decoding.html), and
320/// [`NargDeserialize`](https://docs.rs/spongefish/latest/spongefish/trait.NargDeserialize.html) in one go.
321#[proc_macro_derive(Codec, attributes(spongefish))]
322pub fn derive_codec(input: TokenStream) -> TokenStream {
323    let input = parse_macro_input!(input as DeriveInput);
324    let encoding = generate_encoding_impl(&input);
325    let decoding = generate_decoding_impl(&input);
326    let deserialize = generate_narg_deserialize_impl(&input);
327
328    TokenStream::from(quote! {
329        #encoding
330        #decoding
331        #deserialize
332    })
333}
334
335/// Derive [`Unit`]s for structs.
336///
337/// ```
338/// use spongefish::Unit;
339/// # use spongefish_derive::Unit;
340///
341/// #[derive(Clone, Unit)]
342/// struct Rgb {
343///     r: u8,
344///     g: u8,
345///     b: u8,
346/// }
347///
348/// assert_eq!((Rgb::ZERO.r, Rgb::ZERO.g, Rgb::ZERO.b), (0, 0, 0));
349///
350/// ```
351///
352/// [Unit]: https://docs.rs/spongefish/latest/spongefish/trait.Unit.html
353#[proc_macro_derive(Unit, attributes(spongefish))]
354pub fn derive_unit(input: TokenStream) -> TokenStream {
355    let input = parse_macro_input!(input as DeriveInput);
356    let name = input.ident;
357    let mut generics = input.generics;
358
359    let (zero_expr, unit_bounds) = match input.data {
360        Data::Struct(data) => match data.fields {
361            Fields::Named(fields) => {
362                let mut zero_fields = Vec::new();
363                let mut unit_bounds = Vec::new();
364
365                for field in fields.named.iter() {
366                    let field_name = &field.ident;
367
368                    if has_skip_attribute(&field.attrs) {
369                        zero_fields.push(quote! {
370                            #field_name: ::core::default::Default::default(),
371                        });
372                        continue;
373                    }
374
375                    let ty: Type = field.ty.clone();
376                    unit_bounds.push(ty.clone());
377                    zero_fields.push(quote! {
378                        #field_name: <#ty as ::spongefish::Unit>::ZERO,
379                    });
380                }
381
382                (
383                    quote! {
384                        Self {
385                            #(#zero_fields)*
386                        }
387                    },
388                    unit_bounds,
389                )
390            }
391            Fields::Unnamed(fields) => {
392                let mut zero_fields = Vec::new();
393                let mut unit_bounds = Vec::new();
394
395                for field in fields.unnamed.iter() {
396                    if has_skip_attribute(&field.attrs) {
397                        zero_fields.push(quote! {
398                            ::core::default::Default::default()
399                        });
400                        continue;
401                    }
402
403                    let ty: Type = field.ty.clone();
404                    unit_bounds.push(ty.clone());
405                    zero_fields.push(quote! {
406                        <#ty as ::spongefish::Unit>::ZERO
407                    });
408                }
409
410                (
411                    quote! {
412                        Self(#(#zero_fields),*)
413                    },
414                    unit_bounds,
415                )
416            }
417            Fields::Unit => (quote!(Self), Vec::new()),
418        },
419        _ => panic!("Unit can only be derived for structs"),
420    };
421
422    let where_clause = generics.make_where_clause();
423    for ty in unit_bounds {
424        where_clause
425            .predicates
426            .push(parse_quote!(#ty: ::spongefish::Unit));
427    }
428
429    let (impl_generics, ty_generics, where_generics) = generics.split_for_impl();
430
431    let expanded = quote! {
432        impl #impl_generics ::spongefish::Unit for #name #ty_generics #where_generics {
433            const ZERO: Self = #zero_expr;
434        }
435    };
436
437    TokenStream::from(expanded)
438}
439
440/// Helper function to check if a field has the #[spongefish(skip)] attribute
441fn has_skip_attribute(attrs: &[syn::Attribute]) -> bool {
442    attrs.iter().any(|attr| {
443        if !attr.path().is_ident("spongefish") {
444            return false;
445        }
446
447        attr.parse_nested_meta(|meta| {
448            if meta.path.is_ident("skip") {
449                Ok(())
450            } else {
451                Err(meta.error("expected `skip`"))
452            }
453        })
454        .is_ok()
455    })
456}
457
458fn add_trait_bounds_for_fields(
459    mut generics: syn::Generics,
460    field_types: &[Type],
461    trait_bound: &TokenStream2,
462) -> syn::Generics {
463    if field_types.is_empty() {
464        return generics;
465    }
466
467    let where_clause = generics.make_where_clause();
468    for ty in field_types {
469        where_clause
470            .predicates
471            .push(parse_quote!(#ty: #trait_bound));
472    }
473
474    generics
475}