1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{parse_macro_input, parse_quote, Data, DeriveInput, Fields, Type};
5
6fn generate_encoding_impl(input: &DeriveInput) -> TokenStream2 {
7 let name = &input.ident;
8
9 let encoding_impl = match &input.data {
10 Data::Struct(data) => {
11 let mut encoding_bounds = Vec::new();
12 let field_encodings = match &data.fields {
13 Fields::Named(fields) => fields
14 .named
15 .iter()
16 .filter_map(|f| {
17 if has_skip_attribute(&f.attrs) {
18 return None;
19 }
20 let field_name = &f.ident;
21 encoding_bounds.push(f.ty.clone());
22 Some(quote! {
23 output.extend_from_slice(self.#field_name.encode().as_ref());
24 })
25 })
26 .collect::<Vec<_>>(),
27 Fields::Unnamed(fields) => fields
28 .unnamed
29 .iter()
30 .enumerate()
31 .filter_map(|(i, f)| {
32 if has_skip_attribute(&f.attrs) {
33 return None;
34 }
35 let index = syn::Index::from(i);
36 encoding_bounds.push(f.ty.clone());
37 Some(quote! {
38 output.extend_from_slice(self.#index.encode().as_ref());
39 })
40 })
41 .collect::<Vec<_>>(),
42 Fields::Unit => vec![],
43 };
44
45 let bound = quote!(::spongefish::Encoding<[u8]>);
46 let generics =
47 add_trait_bounds_for_fields(input.generics.clone(), &encoding_bounds, &bound);
48 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
49
50 quote! {
51 impl #impl_generics ::spongefish::Encoding<[u8]> for #name #ty_generics #where_clause {
52 fn encode(&self) -> impl AsRef<[u8]> {
53 let mut output = ::std::vec::Vec::new();
54 #(#field_encodings)*
55 output
56 }
57 }
58 }
59 }
60 _ => panic!("Encoding can only be derived for structs"),
61 };
62
63 encoding_impl
64}
65
66fn generate_decoding_impl(input: &DeriveInput) -> TokenStream2 {
67 let name = &input.ident;
68
69 let decoding_impl = match &input.data {
70 Data::Struct(data) => {
71 let mut decoding_bounds = Vec::new();
72 let (size_calc, field_decodings) = match &data.fields {
73 Fields::Named(fields) => {
74 let mut offset = quote!(0usize);
75 let mut field_decodings = vec![];
76 let mut size_components = vec![];
77
78 for field in fields.named.iter() {
79 if has_skip_attribute(&field.attrs) {
80 let field_name = &field.ident;
81 field_decodings.push(quote! {
82 #field_name: Default::default(),
83 });
84 continue;
85 }
86
87 let field_name = &field.ident;
88 let field_type = &field.ty;
89 decoding_bounds.push(field_type.clone());
90
91 size_components.push(quote! {
92 ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>()
93 });
94
95 let current_offset = offset.clone();
96 field_decodings.push(quote! {
97 #field_name: {
98 let field_size = ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>();
99 let start = #current_offset;
100 let end = start + field_size;
101 let mut field_buf = <#field_type as spongefish::Decoding<[u8]>>::Repr::default();
102 field_buf.as_mut().copy_from_slice(&buf.as_ref()[start..end]);
103 <#field_type as spongefish::Decoding<[u8]>>::decode(field_buf)
104 },
105 });
106
107 offset = quote! {
108 #offset + <#field_type as spongefish::Decoding<[u8]>>::Repr::default().as_mut().len()
109 };
110 }
111
112 let size_calc = if size_components.is_empty() {
113 quote!(0usize)
114 } else {
115 quote!(#(#size_components)+*)
116 };
117
118 (
119 size_calc,
120 quote! {
121 Self {
122 #(#field_decodings)*
123 }
124 },
125 )
126 }
127 Fields::Unnamed(fields) => {
128 let mut offset = quote!(0usize);
129 let mut field_decodings = vec![];
130 let mut size_components = vec![];
131
132 for field in fields.unnamed.iter() {
133 if has_skip_attribute(&field.attrs) {
134 field_decodings.push(quote! {
135 Default::default(),
136 });
137 continue;
138 }
139
140 let field_type = &field.ty;
141 decoding_bounds.push(field_type.clone());
142
143 size_components.push(quote! {
144 ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>()
145 });
146
147 let current_offset = offset.clone();
148 field_decodings.push(quote! {
149 {
150 let field_size = ::core::mem::size_of::<<#field_type as spongefish::Decoding<[u8]>>::Repr>();
151 let start = #current_offset;
152 let end = start + field_size;
153 let mut field_buf = <#field_type as spongefish::Decoding<[u8]>>::Repr::default();
154 field_buf.as_mut().copy_from_slice(&buf.as_ref()[start..end]);
155 <#field_type as spongefish::Decoding<[u8]>>::decode(field_buf)
156 },
157 });
158
159 offset = quote! {
160 #offset + <#field_type as spongefish::Decoding<[u8]>>::Repr::default().as_mut().len()
161 };
162 }
163
164 let size_calc = if size_components.is_empty() {
165 quote!(0usize)
166 } else {
167 quote!(#(#size_components)+*)
168 };
169
170 (
171 size_calc,
172 quote! {
173 Self(#(#field_decodings)*)
174 },
175 )
176 }
177 Fields::Unit => (quote!(0usize), quote!(Self)),
178 };
179
180 let bound = quote!(::spongefish::Decoding<[u8]>);
181 let generics =
182 add_trait_bounds_for_fields(input.generics.clone(), &decoding_bounds, &bound);
183 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
184
185 quote! {
186 impl #impl_generics ::spongefish::Decoding<[u8]> for #name #ty_generics #where_clause {
187 type Repr = spongefish::ByteArray<{ #size_calc }>;
188
189 fn decode(buf: Self::Repr) -> Self {
190 #field_decodings
191 }
192 }
193 }
194 }
195 _ => panic!("Decoding can only be derived for structs"),
196 };
197
198 decoding_impl
199}
200
201fn generate_narg_deserialize_impl(input: &DeriveInput) -> TokenStream2 {
202 let name = &input.ident;
203
204 let deserialize_impl = match &input.data {
205 Data::Struct(data) => {
206 let mut deserialize_bounds = Vec::new();
207 let field_deserializations = match &data.fields {
208 Fields::Named(fields) => {
209 let field_inits = fields.named.iter().map(|f| {
210 let field_name = &f.ident;
211 let field_type = &f.ty;
212
213 if has_skip_attribute(&f.attrs) {
214 quote! {
215 #field_name: Default::default(),
216 }
217 } else {
218 deserialize_bounds.push(field_type.clone());
219 quote! {
220 #field_name: <#field_type as spongefish::NargDeserialize>::deserialize_from_narg(&mut rest)?,
221 }
222 }
223 });
224
225 quote! {
226 Ok(Self {
227 #(#field_inits)*
228 })
229 }
230 }
231 Fields::Unnamed(fields) => {
232 let field_inits = fields.unnamed.iter().map(|f| {
233 let field_type = &f.ty;
234
235 if has_skip_attribute(&f.attrs) {
236 quote! {
237 Default::default(),
238 }
239 } else {
240 deserialize_bounds.push(field_type.clone());
241 quote! {
242 <#field_type as spongefish::NargDeserialize>::deserialize_from_narg(&mut rest)?,
243 }
244 }
245 });
246
247 quote! {
248 Ok(Self(#(#field_inits)*))
249 }
250 }
251 Fields::Unit => quote! {
252 Ok(Self)
253 },
254 };
255
256 let bound = quote!(::spongefish::NargDeserialize);
257 let generics =
258 add_trait_bounds_for_fields(input.generics.clone(), &deserialize_bounds, &bound);
259 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
260
261 quote! {
262 impl #impl_generics ::spongefish::NargDeserialize for #name #ty_generics #where_clause {
263 fn deserialize_from_narg(buf: &mut &[u8]) -> spongefish::VerificationResult<Self> {
264 let mut rest = *buf;
265 let value = (|| -> spongefish::VerificationResult<Self> {
266 #field_deserializations
267 })()?;
268 *buf = rest;
269 Ok(value)
270 }
271 }
272 }
273 }
274 _ => panic!("NargDeserialize can only be derived for structs"),
275 };
276
277 deserialize_impl
278}
279
280#[proc_macro_derive(Encoding, attributes(spongefish))]
301pub fn derive_encoding(input: TokenStream) -> TokenStream {
302 let input = parse_macro_input!(input as DeriveInput);
303 TokenStream::from(generate_encoding_impl(&input))
304}
305
306#[proc_macro_derive(Decoding, attributes(spongefish))]
311pub fn derive_decoding(input: TokenStream) -> TokenStream {
312 let input = parse_macro_input!(input as DeriveInput);
313 TokenStream::from(generate_decoding_impl(&input))
314}
315
316#[proc_macro_derive(NargDeserialize, attributes(spongefish))]
321pub fn derive_narg_deserialize(input: TokenStream) -> TokenStream {
322 let input = parse_macro_input!(input as DeriveInput);
323 TokenStream::from(generate_narg_deserialize_impl(&input))
324}
325
326#[proc_macro_derive(Codec, attributes(spongefish))]
330pub fn derive_codec(input: TokenStream) -> TokenStream {
331 let input = parse_macro_input!(input as DeriveInput);
332 let encoding = generate_encoding_impl(&input);
333 let decoding = generate_decoding_impl(&input);
334 let deserialize = generate_narg_deserialize_impl(&input);
335
336 TokenStream::from(quote! {
337 #encoding
338 #decoding
339 #deserialize
340 })
341}
342
343#[proc_macro_derive(Unit, attributes(spongefish))]
362pub fn derive_unit(input: TokenStream) -> TokenStream {
363 let input = parse_macro_input!(input as DeriveInput);
364 let name = input.ident;
365 let mut generics = input.generics;
366
367 let (zero_expr, unit_bounds) = match input.data {
368 Data::Struct(data) => match data.fields {
369 Fields::Named(fields) => {
370 let mut zero_fields = Vec::new();
371 let mut unit_bounds = Vec::new();
372
373 for field in fields.named.iter() {
374 let field_name = &field.ident;
375
376 if has_skip_attribute(&field.attrs) {
377 zero_fields.push(quote! {
378 #field_name: ::core::default::Default::default(),
379 });
380 continue;
381 }
382
383 let ty: Type = field.ty.clone();
384 unit_bounds.push(ty.clone());
385 zero_fields.push(quote! {
386 #field_name: <#ty as ::spongefish::Unit>::ZERO,
387 });
388 }
389
390 (
391 quote! {
392 Self {
393 #(#zero_fields)*
394 }
395 },
396 unit_bounds,
397 )
398 }
399 Fields::Unnamed(fields) => {
400 let mut zero_fields = Vec::new();
401 let mut unit_bounds = Vec::new();
402
403 for field in fields.unnamed.iter() {
404 if has_skip_attribute(&field.attrs) {
405 zero_fields.push(quote! {
406 ::core::default::Default::default()
407 });
408 continue;
409 }
410
411 let ty: Type = field.ty.clone();
412 unit_bounds.push(ty.clone());
413 zero_fields.push(quote! {
414 <#ty as ::spongefish::Unit>::ZERO
415 });
416 }
417
418 (
419 quote! {
420 Self(#(#zero_fields),*)
421 },
422 unit_bounds,
423 )
424 }
425 Fields::Unit => (quote!(Self), Vec::new()),
426 },
427 _ => panic!("Unit can only be derived for structs"),
428 };
429
430 let where_clause = generics.make_where_clause();
431 for ty in unit_bounds {
432 where_clause
433 .predicates
434 .push(parse_quote!(#ty: ::spongefish::Unit));
435 }
436
437 let (impl_generics, ty_generics, where_generics) = generics.split_for_impl();
438
439 let expanded = quote! {
440 impl #impl_generics ::spongefish::Unit for #name #ty_generics #where_generics {
441 const ZERO: Self = #zero_expr;
442 }
443 };
444
445 TokenStream::from(expanded)
446}
447
448fn has_skip_attribute(attrs: &[syn::Attribute]) -> bool {
450 attrs.iter().any(|attr| {
451 if !attr.path().is_ident("spongefish") {
452 return false;
453 }
454
455 attr.parse_nested_meta(|meta| {
456 if meta.path.is_ident("skip") {
457 Ok(())
458 } else {
459 Err(meta.error("expected `skip`"))
460 }
461 })
462 .is_ok()
463 })
464}
465
466fn add_trait_bounds_for_fields(
467 mut generics: syn::Generics,
468 field_types: &[Type],
469 trait_bound: &TokenStream2,
470) -> syn::Generics {
471 if field_types.is_empty() {
472 return generics;
473 }
474
475 let where_clause = generics.make_where_clause();
476 for ty in field_types {
477 where_clause
478 .predicates
479 .push(parse_quote!(#ty: #trait_bound));
480 }
481
482 generics
483}