1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{parse_macro_input, parse_quote, Data, DeriveInput, Fields, Type};
5
6fn generate_encoding_impl(input: &DeriveInput) -> TokenStream2 {
7 let name = &input.ident;
8
9 let encoding_impl = match &input.data {
10 Data::Struct(data) => {
11 let mut encoding_bounds = Vec::new();
12 let field_encodings = match &data.fields {
13 Fields::Named(fields) => fields
14 .named
15 .iter()
16 .filter_map(|f| {
17 if has_skip_attribute(&f.attrs) {
18 return None;
19 }
20 let field_name = &f.ident;
21 encoding_bounds.push(f.ty.clone());
22 Some(quote! {
23 output.extend_from_slice(self.#field_name.encode().as_ref());
24 })
25 })
26 .collect::<Vec<_>>(),
27 Fields::Unnamed(fields) => fields
28 .unnamed
29 .iter()
30 .enumerate()
31 .filter_map(|(i, f)| {
32 if has_skip_attribute(&f.attrs) {
33 return None;
34 }
35 let index = syn::Index::from(i);
36 encoding_bounds.push(f.ty.clone());
37 Some(quote! {
38 output.extend_from_slice(self.#index.encode().as_ref());
39 })
40 })
41 .collect::<Vec<_>>(),
42 Fields::Unit => vec![],
43 };
44
45 let bound = quote!(::spongefish::Encoding<[u8]>);
46 let generics =
47 add_trait_bounds_for_fields(input.generics.clone(), &encoding_bounds, &bound);
48 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
49
50 quote! {
51 impl #impl_generics ::spongefish::Encoding<[u8]> for #name #ty_generics #where_clause {
52 fn encode(&self) -> impl AsRef<[u8]> {
53 let mut output = ::std::vec::Vec::new();
54 #(#field_encodings)*
55 output
56 }
57 }
58 }
59 }
60 _ => panic!("Encoding can only be derived for structs"),
61 };
62
63 encoding_impl
64}
65
66fn generate_decoding_impl(input: &DeriveInput) -> TokenStream2 {
67 let name = &input.ident;
68
69 let decoding_impl = match &input.data {
70 Data::Struct(data) => {
71 let mut decoding_bounds = Vec::new();
72 let (size_calc, field_decodings) = match &data.fields {
73 Fields::Named(fields) => {
74 let mut offset = quote!(0usize);
75 let mut field_decodings = vec![];
76 let mut size_components = vec![];
77
78 for field in fields.named.iter() {
79 if has_skip_attribute(&field.attrs) {
80 let field_name = &field.ident;
81 field_decodings.push(quote! {
82 #field_name: Default::default(),
83 });
84 continue;
85 }
86
87 let field_name = &field.ident;
88 let field_type = &field.ty;
89 decoding_bounds.push(field_type.clone());
90
91 size_components.push(quote! {
92 ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>()
93 });
94
95 let current_offset = offset.clone();
96 field_decodings.push(quote! {
97 #field_name: {
98 let field_size = ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>();
99 let start = #current_offset;
100 let end = start + field_size;
101 let mut field_buf = <#field_type as spongefish::Decoding<[u8]>>::Repr::default();
102 field_buf.as_mut().copy_from_slice(&buf.as_ref()[start..end]);
103 <#field_type as spongefish::Decoding<[u8]>>::decode(field_buf)
104 },
105 });
106
107 offset = quote! {
108 #offset + <#field_type as spongefish::Decoding<[u8]>>::Repr::default().as_mut().len()
109 };
110 }
111
112 let size_calc = if size_components.is_empty() {
113 quote!(0usize)
114 } else {
115 quote!(#(#size_components)+*)
116 };
117
118 (
119 size_calc,
120 quote! {
121 Self {
122 #(#field_decodings)*
123 }
124 },
125 )
126 }
127 Fields::Unnamed(fields) => {
128 let mut offset = quote!(0usize);
129 let mut field_decodings = vec![];
130 let mut size_components = vec![];
131
132 for field in fields.unnamed.iter() {
133 if has_skip_attribute(&field.attrs) {
134 field_decodings.push(quote! {
135 Default::default(),
136 });
137 continue;
138 }
139
140 let field_type = &field.ty;
141 decoding_bounds.push(field_type.clone());
142
143 size_components.push(quote! {
144 ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>()
145 });
146
147 let current_offset = offset.clone();
148 field_decodings.push(quote! {
149 {
150 let field_size = ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>();
151 let start = #current_offset;
152 let end = start + field_size;
153 let mut field_buf = <#field_type as spongefish::Decoding<[u8]>>::Repr::default();
154 field_buf.as_mut().copy_from_slice(&buf.as_ref()[start..end]);
155 <#field_type as spongefish::Decoding<[u8]>>::decode(field_buf)
156 },
157 });
158
159 offset = quote! {
160 #offset + <#field_type as spongefish::Decoding<[u8]>>::Repr::default().as_mut().len()
161 };
162 }
163
164 let size_calc = if size_components.is_empty() {
165 quote!(0usize)
166 } else {
167 quote!(#(#size_components)+*)
168 };
169
170 (
171 size_calc,
172 quote! {
173 Self(#(#field_decodings)*)
174 },
175 )
176 }
177 Fields::Unit => (quote!(0usize), quote!(Self)),
178 };
179
180 let bound = quote!(::spongefish::Decoding<[u8]>);
181 let generics =
182 add_trait_bounds_for_fields(input.generics.clone(), &decoding_bounds, &bound);
183 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
184
185 quote! {
186 impl #impl_generics ::spongefish::Decoding<[u8]> for #name #ty_generics #where_clause {
187 type Repr = spongefish::ByteArray<{ #size_calc }>;
188
189 fn decode(buf: Self::Repr) -> Self {
190 #field_decodings
191 }
192 }
193 }
194 }
195 _ => panic!("Decoding can only be derived for structs"),
196 };
197
198 decoding_impl
199}
200
201fn generate_narg_deserialize_impl(input: &DeriveInput) -> TokenStream2 {
202 let name = &input.ident;
203
204 let deserialize_impl = match &input.data {
205 Data::Struct(data) => {
206 let mut deserialize_bounds = Vec::new();
207 let field_deserializations = match &data.fields {
208 Fields::Named(fields) => {
209 let field_inits = fields.named.iter().map(|f| {
210 let field_name = &f.ident;
211 let field_type = &f.ty;
212
213 if has_skip_attribute(&f.attrs) {
214 quote! {
215 #field_name: Default::default(),
216 }
217 } else {
218 deserialize_bounds.push(field_type.clone());
219 quote! {
220 #field_name: <#field_type as spongefish::NargDeserialize>::deserialize_from_narg(buf)?,
221 }
222 }
223 });
224
225 quote! {
226 Ok(Self {
227 #(#field_inits)*
228 })
229 }
230 }
231 Fields::Unnamed(fields) => {
232 let field_inits = fields.unnamed.iter().map(|f| {
233 let field_type = &f.ty;
234
235 if has_skip_attribute(&f.attrs) {
236 quote! {
237 Default::default(),
238 }
239 } else {
240 deserialize_bounds.push(field_type.clone());
241 quote! {
242 <#field_type as spongefish::NargDeserialize>::deserialize_from_narg(buf)?,
243 }
244 }
245 });
246
247 quote! {
248 Ok(Self(#(#field_inits)*))
249 }
250 }
251 Fields::Unit => quote! {
252 Ok(Self)
253 },
254 };
255
256 let bound = quote!(::spongefish::NargDeserialize);
257 let generics =
258 add_trait_bounds_for_fields(input.generics.clone(), &deserialize_bounds, &bound);
259 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
260
261 quote! {
262 impl #impl_generics ::spongefish::NargDeserialize for #name #ty_generics #where_clause {
263 fn deserialize_from_narg(buf: &mut &[u8]) -> spongefish::VerificationResult<Self> {
264 #field_deserializations
265 }
266 }
267 }
268 }
269 _ => panic!("NargDeserialize can only be derived for structs"),
270 };
271
272 deserialize_impl
273}
274
275#[proc_macro_derive(Encoding, attributes(spongefish))]
296pub fn derive_encoding(input: TokenStream) -> TokenStream {
297 let input = parse_macro_input!(input as DeriveInput);
298 TokenStream::from(generate_encoding_impl(&input))
299}
300
301#[proc_macro_derive(Decoding, attributes(spongefish))]
306pub fn derive_decoding(input: TokenStream) -> TokenStream {
307 let input = parse_macro_input!(input as DeriveInput);
308 TokenStream::from(generate_decoding_impl(&input))
309}
310
311#[proc_macro_derive(NargDeserialize, attributes(spongefish))]
316pub fn derive_narg_deserialize(input: TokenStream) -> TokenStream {
317 let input = parse_macro_input!(input as DeriveInput);
318 TokenStream::from(generate_narg_deserialize_impl(&input))
319}
320
321#[proc_macro_derive(Codec, attributes(spongefish))]
325pub fn derive_codec(input: TokenStream) -> TokenStream {
326 let input = parse_macro_input!(input as DeriveInput);
327 let encoding = generate_encoding_impl(&input);
328 let decoding = generate_decoding_impl(&input);
329 let deserialize = generate_narg_deserialize_impl(&input);
330
331 TokenStream::from(quote! {
332 #encoding
333 #decoding
334 #deserialize
335 })
336}
337
338#[proc_macro_derive(Unit, attributes(spongefish))]
357pub fn derive_unit(input: TokenStream) -> TokenStream {
358 let input = parse_macro_input!(input as DeriveInput);
359 let name = input.ident;
360 let mut generics = input.generics;
361
362 let (zero_expr, unit_bounds) = match input.data {
363 Data::Struct(data) => match data.fields {
364 Fields::Named(fields) => {
365 let mut zero_fields = Vec::new();
366 let mut unit_bounds = Vec::new();
367
368 for field in fields.named.iter() {
369 let field_name = &field.ident;
370
371 if has_skip_attribute(&field.attrs) {
372 zero_fields.push(quote! {
373 #field_name: ::core::default::Default::default(),
374 });
375 continue;
376 }
377
378 let ty: Type = field.ty.clone();
379 unit_bounds.push(ty.clone());
380 zero_fields.push(quote! {
381 #field_name: <#ty as ::spongefish::Unit>::ZERO,
382 });
383 }
384
385 (
386 quote! {
387 Self {
388 #(#zero_fields)*
389 }
390 },
391 unit_bounds,
392 )
393 }
394 Fields::Unnamed(fields) => {
395 let mut zero_fields = Vec::new();
396 let mut unit_bounds = Vec::new();
397
398 for field in fields.unnamed.iter() {
399 if has_skip_attribute(&field.attrs) {
400 zero_fields.push(quote! {
401 ::core::default::Default::default()
402 });
403 continue;
404 }
405
406 let ty: Type = field.ty.clone();
407 unit_bounds.push(ty.clone());
408 zero_fields.push(quote! {
409 <#ty as ::spongefish::Unit>::ZERO
410 });
411 }
412
413 (
414 quote! {
415 Self(#(#zero_fields),*)
416 },
417 unit_bounds,
418 )
419 }
420 Fields::Unit => (quote!(Self), Vec::new()),
421 },
422 _ => panic!("Unit can only be derived for structs"),
423 };
424
425 let where_clause = generics.make_where_clause();
426 for ty in unit_bounds {
427 where_clause
428 .predicates
429 .push(parse_quote!(#ty: ::spongefish::Unit));
430 }
431
432 let (impl_generics, ty_generics, where_generics) = generics.split_for_impl();
433
434 let expanded = quote! {
435 impl #impl_generics ::spongefish::Unit for #name #ty_generics #where_generics {
436 const ZERO: Self = #zero_expr;
437 }
438 };
439
440 TokenStream::from(expanded)
441}
442
443fn has_skip_attribute(attrs: &[syn::Attribute]) -> bool {
445 attrs.iter().any(|attr| {
446 if !attr.path().is_ident("spongefish") {
447 return false;
448 }
449
450 attr.parse_nested_meta(|meta| {
451 if meta.path.is_ident("skip") {
452 Ok(())
453 } else {
454 Err(meta.error("expected `skip`"))
455 }
456 })
457 .is_ok()
458 })
459}
460
461fn add_trait_bounds_for_fields(
462 mut generics: syn::Generics,
463 field_types: &[Type],
464 trait_bound: &TokenStream2,
465) -> syn::Generics {
466 if field_types.is_empty() {
467 return generics;
468 }
469
470 let where_clause = generics.make_where_clause();
471 for ty in field_types {
472 where_clause
473 .predicates
474 .push(parse_quote!(#ty: #trait_bound));
475 }
476
477 generics
478}