Skip to main content

spongefish_derive/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use quote::quote;
6use syn::{parse_macro_input, parse_quote, Data, DeriveInput, Fields, Type};
7
8fn generate_encoding_impl(input: &DeriveInput) -> TokenStream2 {
9    let name = &input.ident;
10
11    let encoding_impl = match &input.data {
12        Data::Struct(data) => {
13            let mut encoding_bounds = Vec::new();
14            let field_encodings = match &data.fields {
15                Fields::Named(fields) => fields
16                    .named
17                    .iter()
18                    .filter_map(|f| {
19                        if has_skip_attribute(&f.attrs) {
20                            return None;
21                        }
22                        let field_name = &f.ident;
23                        encoding_bounds.push(f.ty.clone());
24                        Some(quote! {
25                            output.extend_from_slice(self.#field_name.encode().as_ref());
26                        })
27                    })
28                    .collect::<Vec<_>>(),
29                Fields::Unnamed(fields) => fields
30                    .unnamed
31                    .iter()
32                    .enumerate()
33                    .filter_map(|(i, f)| {
34                        if has_skip_attribute(&f.attrs) {
35                            return None;
36                        }
37                        let index = syn::Index::from(i);
38                        encoding_bounds.push(f.ty.clone());
39                        Some(quote! {
40                            output.extend_from_slice(self.#index.encode().as_ref());
41                        })
42                    })
43                    .collect::<Vec<_>>(),
44                Fields::Unit => vec![],
45            };
46
47            let bound = quote!(::spongefish::Encoding<[u8]>);
48            let generics =
49                add_trait_bounds_for_fields(input.generics.clone(), &encoding_bounds, &bound);
50            let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
51
52            quote! {
53                impl #impl_generics ::spongefish::Encoding<[u8]> for #name #ty_generics #where_clause {
54                    fn encode(&self) -> impl AsRef<[u8]> {
55                        let mut output = ::std::vec::Vec::new();
56                        #(#field_encodings)*
57                        output
58                    }
59                }
60            }
61        }
62        _ => panic!("Encoding can only be derived for structs"),
63    };
64
65    encoding_impl
66}
67
68fn generate_decoding_impl(input: &DeriveInput) -> TokenStream2 {
69    let name = &input.ident;
70
71    let decoding_impl = match &input.data {
72        Data::Struct(data) => {
73            let mut decoding_bounds = Vec::new();
74            let (size_calc, field_decodings) = match &data.fields {
75                Fields::Named(fields) => {
76                    let mut offset = quote!(0usize);
77                    let mut field_decodings = vec![];
78                    let mut size_components = vec![];
79
80                    for field in fields.named.iter() {
81                        if has_skip_attribute(&field.attrs) {
82                            let field_name = &field.ident;
83                            field_decodings.push(quote! {
84                                #field_name: Default::default(),
85                            });
86                            continue;
87                        }
88
89                        let field_name = &field.ident;
90                        let field_type = &field.ty;
91                        decoding_bounds.push(field_type.clone());
92
93                        size_components.push(quote! {
94                            ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>()
95                        });
96
97                        let current_offset = offset.clone();
98                        field_decodings.push(quote! {
99                            #field_name: {
100                                let field_size = ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>();
101                                let start = #current_offset;
102                                let end = start + field_size;
103                                let mut field_buf = <#field_type as spongefish::Decoding<[u8]>>::Repr::default();
104                                field_buf.as_mut().copy_from_slice(&buf.as_ref()[start..end]);
105                                <#field_type as spongefish::Decoding<[u8]>>::decode(field_buf)
106                            },
107                        });
108
109                        offset = quote! {
110                            #offset + <#field_type as spongefish::Decoding<[u8]>>::Repr::default().as_mut().len()
111                        };
112                    }
113
114                    let size_calc = if size_components.is_empty() {
115                        quote!(0usize)
116                    } else {
117                        quote!(#(#size_components)+*)
118                    };
119
120                    (
121                        size_calc,
122                        quote! {
123                            Self {
124                                #(#field_decodings)*
125                            }
126                        },
127                    )
128                }
129                Fields::Unnamed(fields) => {
130                    let mut offset = quote!(0usize);
131                    let mut field_decodings = vec![];
132                    let mut size_components = vec![];
133
134                    for field in fields.unnamed.iter() {
135                        if has_skip_attribute(&field.attrs) {
136                            field_decodings.push(quote! {
137                                Default::default(),
138                            });
139                            continue;
140                        }
141
142                        let field_type = &field.ty;
143                        decoding_bounds.push(field_type.clone());
144
145                        size_components.push(quote! {
146                            ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>()
147                        });
148
149                        let current_offset = offset.clone();
150                        field_decodings.push(quote! {
151                            {
152                                let field_size = ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>();
153                                let start = #current_offset;
154                                let end = start + field_size;
155                                let mut field_buf = <#field_type as spongefish::Decoding<[u8]>>::Repr::default();
156                                field_buf.as_mut().copy_from_slice(&buf.as_ref()[start..end]);
157                                <#field_type as spongefish::Decoding<[u8]>>::decode(field_buf)
158                            },
159                        });
160
161                        offset = quote! {
162                            #offset + <#field_type as spongefish::Decoding<[u8]>>::Repr::default().as_mut().len()
163                        };
164                    }
165
166                    let size_calc = if size_components.is_empty() {
167                        quote!(0usize)
168                    } else {
169                        quote!(#(#size_components)+*)
170                    };
171
172                    (
173                        size_calc,
174                        quote! {
175                            Self(#(#field_decodings)*)
176                        },
177                    )
178                }
179                Fields::Unit => (quote!(0usize), quote!(Self)),
180            };
181
182            let bound = quote!(::spongefish::Decoding<[u8]>);
183            let generics =
184                add_trait_bounds_for_fields(input.generics.clone(), &decoding_bounds, &bound);
185            let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
186
187            quote! {
188                impl #impl_generics ::spongefish::Decoding<[u8]> for #name #ty_generics #where_clause {
189                    type Repr = spongefish::ByteArray<{ #size_calc }>;
190
191                    fn decode(buf: Self::Repr) -> Self {
192                        #field_decodings
193                    }
194                }
195            }
196        }
197        _ => panic!("Decoding can only be derived for structs"),
198    };
199
200    decoding_impl
201}
202
203fn generate_narg_deserialize_impl(input: &DeriveInput) -> TokenStream2 {
204    let name = &input.ident;
205
206    let deserialize_impl = match &input.data {
207        Data::Struct(data) => {
208            let mut deserialize_bounds = Vec::new();
209            let field_deserializations = match &data.fields {
210                Fields::Named(fields) => {
211                    let field_inits = fields.named.iter().map(|f| {
212                        let field_name = &f.ident;
213                        let field_type = &f.ty;
214
215                        if has_skip_attribute(&f.attrs) {
216                            quote! {
217                                #field_name: Default::default(),
218                            }
219                        } else {
220                            deserialize_bounds.push(field_type.clone());
221                            quote! {
222                                #field_name: <#field_type as spongefish::NargDeserialize>::deserialize_from_narg(&mut rest)?,
223                            }
224                        }
225                    });
226
227                    quote! {
228                        Ok(Self {
229                            #(#field_inits)*
230                        })
231                    }
232                }
233                Fields::Unnamed(fields) => {
234                    let field_inits = fields.unnamed.iter().map(|f| {
235                        let field_type = &f.ty;
236
237                        if has_skip_attribute(&f.attrs) {
238                            quote! {
239                                Default::default(),
240                            }
241                        } else {
242                            deserialize_bounds.push(field_type.clone());
243                            quote! {
244                                <#field_type as spongefish::NargDeserialize>::deserialize_from_narg(&mut rest)?,
245                            }
246                        }
247                    });
248
249                    quote! {
250                        Ok(Self(#(#field_inits)*))
251                    }
252                }
253                Fields::Unit => quote! {
254                    Ok(Self)
255                },
256            };
257
258            let bound = quote!(::spongefish::NargDeserialize);
259            let generics =
260                add_trait_bounds_for_fields(input.generics.clone(), &deserialize_bounds, &bound);
261            let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
262
263            quote! {
264                impl #impl_generics ::spongefish::NargDeserialize for #name #ty_generics #where_clause {
265                    fn deserialize_from_narg(buf: &mut &[u8]) -> spongefish::VerificationResult<Self> {
266                        let mut rest = *buf;
267                        let value = (|| -> spongefish::VerificationResult<Self> {
268                            #field_deserializations
269                        })()?;
270                        *buf = rest;
271                        Ok(value)
272                    }
273                }
274            }
275        }
276        _ => panic!("NargDeserialize can only be derived for structs"),
277    };
278
279    deserialize_impl
280}
281
282/// Derive [`Encoding`](https://docs.rs/spongefish/latest/spongefish/trait.Encoding.html) for structs.
283///
284/// Skipped fields fall back to `Default`.
285///
286/// ```
287/// use spongefish::Encoding;
288/// # use spongefish_derive::Encoding;
289///
290/// #[derive(Encoding)]
291/// struct Rgb {
292///     r: u8,
293///     g: u8,
294///     b: u8,
295/// }
296///
297/// let colors = Rgb { r: 1, g: 2, b: 3 };
298/// let data = colors.encode();
299/// assert_eq!(data.as_ref(), [1, 2, 3]);
300///
301/// ```
302#[proc_macro_derive(Encoding, attributes(spongefish))]
303pub fn derive_encoding(input: TokenStream) -> TokenStream {
304    let input = parse_macro_input!(input as DeriveInput);
305    TokenStream::from(generate_encoding_impl(&input))
306}
307
308/// Derive macro for the [`Decoding`](https://docs.rs/spongefish/latest/spongefish/trait.Decoding.html) trait.
309///
310/// Generates an implementation that decodes struct fields sequentially from a fixed-size buffer.
311/// Fields can be skipped using `#[spongefish(skip)]`.
312#[proc_macro_derive(Decoding, attributes(spongefish))]
313pub fn derive_decoding(input: TokenStream) -> TokenStream {
314    let input = parse_macro_input!(input as DeriveInput);
315    TokenStream::from(generate_decoding_impl(&input))
316}
317
318/// Derive macro for the [`NargDeserialize`](https://docs.rs/spongefish/latest/spongefish/trait.NargDeserialize.html) trait.
319///
320/// Generates an implementation that deserializes struct fields sequentially from a byte buffer.
321/// Fields can be skipped using `#[spongefish(skip)]`.
322#[proc_macro_derive(NargDeserialize, attributes(spongefish))]
323pub fn derive_narg_deserialize(input: TokenStream) -> TokenStream {
324    let input = parse_macro_input!(input as DeriveInput);
325    TokenStream::from(generate_narg_deserialize_impl(&input))
326}
327
328/// Derive macro that generates [`Encoding`](https://docs.rs/spongefish/latest/spongefish/trait.Encoding.html),
329/// [`Decoding`](https://docs.rs/spongefish/latest/spongefish/trait.Decoding.html), and
330/// [`NargDeserialize`](https://docs.rs/spongefish/latest/spongefish/trait.NargDeserialize.html) in one go.
331#[proc_macro_derive(Codec, attributes(spongefish))]
332pub fn derive_codec(input: TokenStream) -> TokenStream {
333    let input = parse_macro_input!(input as DeriveInput);
334    let encoding = generate_encoding_impl(&input);
335    let decoding = generate_decoding_impl(&input);
336    let deserialize = generate_narg_deserialize_impl(&input);
337
338    TokenStream::from(quote! {
339        #encoding
340        #decoding
341        #deserialize
342    })
343}
344
345/// Derive [`Unit`]s for structs.
346///
347/// ```
348/// use spongefish::Unit;
349/// # use spongefish_derive::Unit;
350///
351/// #[derive(Clone, Unit)]
352/// struct Rgb {
353///     r: u8,
354///     g: u8,
355///     b: u8,
356/// }
357///
358/// assert_eq!((Rgb::ZERO.r, Rgb::ZERO.g, Rgb::ZERO.b), (0, 0, 0));
359///
360/// ```
361///
362/// [Unit]: https://docs.rs/spongefish/latest/spongefish/trait.Unit.html
363#[proc_macro_derive(Unit, attributes(spongefish))]
364pub fn derive_unit(input: TokenStream) -> TokenStream {
365    let input = parse_macro_input!(input as DeriveInput);
366    let name = input.ident;
367    let mut generics = input.generics;
368
369    let (zero_expr, unit_bounds) = match input.data {
370        Data::Struct(data) => match data.fields {
371            Fields::Named(fields) => {
372                let mut zero_fields = Vec::new();
373                let mut unit_bounds = Vec::new();
374
375                for field in fields.named.iter() {
376                    let field_name = &field.ident;
377
378                    if has_skip_attribute(&field.attrs) {
379                        zero_fields.push(quote! {
380                            #field_name: ::core::default::Default::default(),
381                        });
382                        continue;
383                    }
384
385                    let ty: Type = field.ty.clone();
386                    unit_bounds.push(ty.clone());
387                    zero_fields.push(quote! {
388                        #field_name: <#ty as ::spongefish::Unit>::ZERO,
389                    });
390                }
391
392                (
393                    quote! {
394                        Self {
395                            #(#zero_fields)*
396                        }
397                    },
398                    unit_bounds,
399                )
400            }
401            Fields::Unnamed(fields) => {
402                let mut zero_fields = Vec::new();
403                let mut unit_bounds = Vec::new();
404
405                for field in fields.unnamed.iter() {
406                    if has_skip_attribute(&field.attrs) {
407                        zero_fields.push(quote! {
408                            ::core::default::Default::default()
409                        });
410                        continue;
411                    }
412
413                    let ty: Type = field.ty.clone();
414                    unit_bounds.push(ty.clone());
415                    zero_fields.push(quote! {
416                        <#ty as ::spongefish::Unit>::ZERO
417                    });
418                }
419
420                (
421                    quote! {
422                        Self(#(#zero_fields),*)
423                    },
424                    unit_bounds,
425                )
426            }
427            Fields::Unit => (quote!(Self), Vec::new()),
428        },
429        _ => panic!("Unit can only be derived for structs"),
430    };
431
432    let where_clause = generics.make_where_clause();
433    for ty in unit_bounds {
434        where_clause
435            .predicates
436            .push(parse_quote!(#ty: ::spongefish::Unit));
437    }
438
439    let (impl_generics, ty_generics, where_generics) = generics.split_for_impl();
440
441    let expanded = quote! {
442        impl #impl_generics ::spongefish::Unit for #name #ty_generics #where_generics {
443            const ZERO: Self = #zero_expr;
444        }
445    };
446
447    TokenStream::from(expanded)
448}
449
450/// Helper function to check if a field has the #[spongefish(skip)] attribute
451fn has_skip_attribute(attrs: &[syn::Attribute]) -> bool {
452    attrs.iter().any(|attr| {
453        if !attr.path().is_ident("spongefish") {
454            return false;
455        }
456
457        attr.parse_nested_meta(|meta| {
458            if meta.path.is_ident("skip") {
459                Ok(())
460            } else {
461                Err(meta.error("expected `skip`"))
462            }
463        })
464        .is_ok()
465    })
466}
467
468fn add_trait_bounds_for_fields(
469    mut generics: syn::Generics,
470    field_types: &[Type],
471    trait_bound: &TokenStream2,
472) -> syn::Generics {
473    if field_types.is_empty() {
474        return generics;
475    }
476
477    let where_clause = generics.make_where_clause();
478    for ty in field_types {
479        where_clause
480            .predicates
481            .push(parse_quote!(#ty: #trait_bound));
482    }
483
484    generics
485}