1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{parse_macro_input, parse_quote, Data, DeriveInput, Fields, Type};
5
6fn generate_encoding_impl(input: &DeriveInput) -> TokenStream2 {
7 let name = &input.ident;
8
9 let encoding_impl = match &input.data {
10 Data::Struct(data) => {
11 let mut encoding_bounds = Vec::new();
12 let field_encodings = match &data.fields {
13 Fields::Named(fields) => fields
14 .named
15 .iter()
16 .filter_map(|f| {
17 if has_skip_attribute(&f.attrs) {
18 return None;
19 }
20 let field_name = &f.ident;
21 encoding_bounds.push(f.ty.clone());
22 Some(quote! {
23 output.extend_from_slice(self.#field_name.encode().as_ref());
24 })
25 })
26 .collect::<Vec<_>>(),
27 Fields::Unnamed(fields) => fields
28 .unnamed
29 .iter()
30 .enumerate()
31 .filter_map(|(i, f)| {
32 if has_skip_attribute(&f.attrs) {
33 return None;
34 }
35 let index = syn::Index::from(i);
36 encoding_bounds.push(f.ty.clone());
37 Some(quote! {
38 output.extend_from_slice(self.#index.encode().as_ref());
39 })
40 })
41 .collect::<Vec<_>>(),
42 Fields::Unit => vec![],
43 };
44
45 let bound = quote!(::spongefish::Encoding<[u8]>);
46 let generics = add_trait_bounds_for_fields(input.generics.clone(), &encoding_bounds, &bound);
47 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
48
49 quote! {
50 impl #impl_generics ::spongefish::Encoding<[u8]> for #name #ty_generics #where_clause {
51 fn encode(&self) -> impl AsRef<[u8]> {
52 let mut output = ::std::vec::Vec::new();
53 #(#field_encodings)*
54 output
55 }
56 }
57 }
58 }
59 _ => panic!("Encoding can only be derived for structs"),
60 };
61
62 encoding_impl
63}
64
65fn generate_decoding_impl(input: &DeriveInput) -> TokenStream2 {
66 let name = &input.ident;
67
68 let decoding_impl = match &input.data {
69 Data::Struct(data) => {
70 let mut decoding_bounds = Vec::new();
71 let (size_calc, field_decodings) = match &data.fields {
72 Fields::Named(fields) => {
73 let mut offset = quote!(0usize);
74 let mut field_decodings = vec![];
75 let mut size_components = vec![];
76
77 for field in fields.named.iter() {
78 if has_skip_attribute(&field.attrs) {
79 let field_name = &field.ident;
80 field_decodings.push(quote! {
81 #field_name: Default::default(),
82 });
83 continue;
84 }
85
86 let field_name = &field.ident;
87 let field_type = &field.ty;
88 decoding_bounds.push(field_type.clone());
89
90 size_components.push(quote! {
91 ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>()
92 });
93
94 let current_offset = offset.clone();
95 field_decodings.push(quote! {
96 #field_name: {
97 let field_size = ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>();
98 let start = #current_offset;
99 let end = start + field_size;
100 let mut field_buf = <#field_type as spongefish::Decoding<[u8]>>::Repr::default();
101 field_buf.as_mut().copy_from_slice(&buf.as_ref()[start..end]);
102 <#field_type as spongefish::Decoding<[u8]>>::decode(field_buf)
103 },
104 });
105
106 offset = quote! {
107 #offset + <#field_type as spongefish::Decoding<[u8]>>::Repr::default().as_mut().len()
108 };
109 }
110
111 let size_calc = if size_components.is_empty() {
112 quote!(0usize)
113 } else {
114 quote!(#(#size_components)+*)
115 };
116
117 (
118 size_calc,
119 quote! {
120 Self {
121 #(#field_decodings)*
122 }
123 },
124 )
125 }
126 Fields::Unnamed(fields) => {
127 let mut offset = quote!(0usize);
128 let mut field_decodings = vec![];
129 let mut size_components = vec![];
130
131 for field in fields.unnamed.iter() {
132 if has_skip_attribute(&field.attrs) {
133 field_decodings.push(quote! {
134 Default::default(),
135 });
136 continue;
137 }
138
139 let field_type = &field.ty;
140 decoding_bounds.push(field_type.clone());
141
142 size_components.push(quote! {
143 ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>()
144 });
145
146 let current_offset = offset.clone();
147 field_decodings.push(quote! {
148 {
149 let field_size = ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>();
150 let start = #current_offset;
151 let end = start + field_size;
152 let mut field_buf = <#field_type as spongefish::Decoding<[u8]>>::Repr::default();
153 field_buf.as_mut().copy_from_slice(&buf.as_ref()[start..end]);
154 <#field_type as spongefish::Decoding<[u8]>>::decode(field_buf)
155 },
156 });
157
158 offset = quote! {
159 #offset + <#field_type as spongefish::Decoding<[u8]>>::Repr::default().as_mut().len()
160 };
161 }
162
163 let size_calc = if size_components.is_empty() {
164 quote!(0usize)
165 } else {
166 quote!(#(#size_components)+*)
167 };
168
169 (
170 size_calc,
171 quote! {
172 Self(#(#field_decodings)*)
173 },
174 )
175 }
176 Fields::Unit => (quote!(0usize), quote!(Self)),
177 };
178
179 let bound = quote!(::spongefish::Decoding<[u8]>);
180 let generics = add_trait_bounds_for_fields(input.generics.clone(), &decoding_bounds, &bound);
181 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
182
183 quote! {
184 impl #impl_generics ::spongefish::Decoding<[u8]> for #name #ty_generics #where_clause {
185 type Repr = spongefish::ByteArray<{ #size_calc }>;
186
187 fn decode(buf: Self::Repr) -> Self {
188 #field_decodings
189 }
190 }
191 }
192 }
193 _ => panic!("Decoding can only be derived for structs"),
194 };
195
196 decoding_impl
197}
198
199fn generate_narg_deserialize_impl(input: &DeriveInput) -> TokenStream2 {
200 let name = &input.ident;
201
202 let deserialize_impl = match &input.data {
203 Data::Struct(data) => {
204 let mut deserialize_bounds = Vec::new();
205 let field_deserializations = match &data.fields {
206 Fields::Named(fields) => {
207 let field_inits = fields.named.iter().map(|f| {
208 let field_name = &f.ident;
209 let field_type = &f.ty;
210
211 if has_skip_attribute(&f.attrs) {
212 quote! {
213 #field_name: Default::default(),
214 }
215 } else {
216 deserialize_bounds.push(field_type.clone());
217 quote! {
218 #field_name: <#field_type as spongefish::NargDeserialize>::deserialize_from_narg(buf)?,
219 }
220 }
221 });
222
223 quote! {
224 Ok(Self {
225 #(#field_inits)*
226 })
227 }
228 }
229 Fields::Unnamed(fields) => {
230 let field_inits = fields.unnamed.iter().map(|f| {
231 let field_type = &f.ty;
232
233 if has_skip_attribute(&f.attrs) {
234 quote! {
235 Default::default(),
236 }
237 } else {
238 deserialize_bounds.push(field_type.clone());
239 quote! {
240 <#field_type as spongefish::NargDeserialize>::deserialize_from_narg(buf)?,
241 }
242 }
243 });
244
245 quote! {
246 Ok(Self(#(#field_inits)*))
247 }
248 }
249 Fields::Unit => quote! {
250 Ok(Self)
251 },
252 };
253
254 let bound = quote!(::spongefish::NargDeserialize);
255 let generics = add_trait_bounds_for_fields(input.generics.clone(), &deserialize_bounds, &bound);
256 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
257
258 quote! {
259 impl #impl_generics ::spongefish::NargDeserialize for #name #ty_generics #where_clause {
260 fn deserialize_from_narg(buf: &mut &[u8]) -> spongefish::VerificationResult<Self> {
261 #field_deserializations
262 }
263 }
264 }
265 }
266 _ => panic!("NargDeserialize can only be derived for structs"),
267 };
268
269 deserialize_impl
270}
271
272#[proc_macro_derive(Encoding, attributes(spongefish))]
293pub fn derive_encoding(input: TokenStream) -> TokenStream {
294 let input = parse_macro_input!(input as DeriveInput);
295 TokenStream::from(generate_encoding_impl(&input))
296}
297
298#[proc_macro_derive(Decoding, attributes(spongefish))]
303pub fn derive_decoding(input: TokenStream) -> TokenStream {
304 let input = parse_macro_input!(input as DeriveInput);
305 TokenStream::from(generate_decoding_impl(&input))
306}
307
308#[proc_macro_derive(NargDeserialize, attributes(spongefish))]
313pub fn derive_narg_deserialize(input: TokenStream) -> TokenStream {
314 let input = parse_macro_input!(input as DeriveInput);
315 TokenStream::from(generate_narg_deserialize_impl(&input))
316}
317
318#[proc_macro_derive(Codec, attributes(spongefish))]
322pub fn derive_codec(input: TokenStream) -> TokenStream {
323 let input = parse_macro_input!(input as DeriveInput);
324 let encoding = generate_encoding_impl(&input);
325 let decoding = generate_decoding_impl(&input);
326 let deserialize = generate_narg_deserialize_impl(&input);
327
328 TokenStream::from(quote! {
329 #encoding
330 #decoding
331 #deserialize
332 })
333}
334
335#[proc_macro_derive(Unit, attributes(spongefish))]
354pub fn derive_unit(input: TokenStream) -> TokenStream {
355 let input = parse_macro_input!(input as DeriveInput);
356 let name = input.ident;
357 let mut generics = input.generics;
358
359 let (zero_expr, unit_bounds) = match input.data {
360 Data::Struct(data) => match data.fields {
361 Fields::Named(fields) => {
362 let mut zero_fields = Vec::new();
363 let mut unit_bounds = Vec::new();
364
365 for field in fields.named.iter() {
366 let field_name = &field.ident;
367
368 if has_skip_attribute(&field.attrs) {
369 zero_fields.push(quote! {
370 #field_name: ::core::default::Default::default(),
371 });
372 continue;
373 }
374
375 let ty: Type = field.ty.clone();
376 unit_bounds.push(ty.clone());
377 zero_fields.push(quote! {
378 #field_name: <#ty as ::spongefish::Unit>::ZERO,
379 });
380 }
381
382 (
383 quote! {
384 Self {
385 #(#zero_fields)*
386 }
387 },
388 unit_bounds,
389 )
390 }
391 Fields::Unnamed(fields) => {
392 let mut zero_fields = Vec::new();
393 let mut unit_bounds = Vec::new();
394
395 for field in fields.unnamed.iter() {
396 if has_skip_attribute(&field.attrs) {
397 zero_fields.push(quote! {
398 ::core::default::Default::default()
399 });
400 continue;
401 }
402
403 let ty: Type = field.ty.clone();
404 unit_bounds.push(ty.clone());
405 zero_fields.push(quote! {
406 <#ty as ::spongefish::Unit>::ZERO
407 });
408 }
409
410 (
411 quote! {
412 Self(#(#zero_fields),*)
413 },
414 unit_bounds,
415 )
416 }
417 Fields::Unit => (quote!(Self), Vec::new()),
418 },
419 _ => panic!("Unit can only be derived for structs"),
420 };
421
422 let where_clause = generics.make_where_clause();
423 for ty in unit_bounds {
424 where_clause
425 .predicates
426 .push(parse_quote!(#ty: ::spongefish::Unit));
427 }
428
429 let (impl_generics, ty_generics, where_generics) = generics.split_for_impl();
430
431 let expanded = quote! {
432 impl #impl_generics ::spongefish::Unit for #name #ty_generics #where_generics {
433 const ZERO: Self = #zero_expr;
434 }
435 };
436
437 TokenStream::from(expanded)
438}
439
440fn has_skip_attribute(attrs: &[syn::Attribute]) -> bool {
442 attrs.iter().any(|attr| {
443 if !attr.path().is_ident("spongefish") {
444 return false;
445 }
446
447 attr.parse_nested_meta(|meta| {
448 if meta.path.is_ident("skip") {
449 Ok(())
450 } else {
451 Err(meta.error("expected `skip`"))
452 }
453 })
454 .is_ok()
455 })
456}
457
458fn add_trait_bounds_for_fields(
459 mut generics: syn::Generics,
460 field_types: &[Type],
461 trait_bound: &TokenStream2,
462) -> syn::Generics {
463 if field_types.is_empty() {
464 return generics;
465 }
466
467 let where_clause = generics.make_where_clause();
468 for ty in field_types {
469 where_clause
470 .predicates
471 .push(parse_quote!(#ty: #trait_bound));
472 }
473
474 generics
475}