// educe/trait_handlers/default/default_enum.rs
1use quote::quote;
2use syn::{spanned::Spanned, Data, DeriveInput, Fields, Meta, Type, Variant};
3
4use super::{
5 models::{FieldAttributeBuilder, TypeAttributeBuilder},
6 TraitHandler,
7};
8use crate::Trait;
9
/// Stateless code generator for `impl Default` on enums.
///
/// All of its work happens in its `TraitHandler` implementation; the type itself carries no data.
pub(crate) struct DefaultEnumHandler;
11
12impl TraitHandler for DefaultEnumHandler {
13 fn trait_meta_handler(
14 ast: &DeriveInput,
15 token_stream: &mut proc_macro2::TokenStream,
16 traits: &[Trait],
17 meta: &Meta,
18 ) -> syn::Result<()> {
19 let type_attribute = TypeAttributeBuilder {
20 enable_flag: true,
21 enable_new: true,
22 enable_expression: true,
23 enable_bound: true,
24 }
25 .build_from_default_meta(meta)?;
26
27 let mut default_types: Vec<&Type> = Vec::new();
28
29 let mut default_token_stream = proc_macro2::TokenStream::new();
30
31 if let Data::Enum(data) = &ast.data {
32 if let Some(expression) = type_attribute.expression {
33 for variant in data.variants.iter() {
34 let _ = TypeAttributeBuilder {
35 enable_flag: false,
36 enable_new: false,
37 enable_expression: false,
38 enable_bound: false,
39 }
40 .build_from_attributes(&variant.attrs, traits)?;
41
42 ensure_fields_no_attribute(&variant.fields, traits)?;
43 }
44
45 default_token_stream.extend(quote!(#expression));
46 } else {
47 let variant = {
48 let variants = &data.variants;
49
50 if variants.len() == 1 {
51 let variant = &variants[0];
52
53 let _ = TypeAttributeBuilder {
54 enable_flag: true,
55 enable_new: false,
56 enable_expression: false,
57 enable_bound: false,
58 }
59 .build_from_attributes(&variant.attrs, traits)?;
60
61 variant
62 } else {
63 let mut default_variant: Option<&Variant> = None;
64
65 for variant in variants {
66 let type_attribute = TypeAttributeBuilder {
67 enable_flag: true,
68 enable_new: false,
69 enable_expression: false,
70 enable_bound: false,
71 }
72 .build_from_attributes(&variant.attrs, traits)?;
73
74 if type_attribute.flag {
75 if default_variant.is_some() {
76 return Err(super::panic::multiple_default_variants(
77 type_attribute.span,
78 ));
79 }
80
81 default_variant = Some(variant);
82 } else {
83 ensure_fields_no_attribute(&variant.fields, traits)?;
84 }
85 }
86
87 if let Some(default_variant) = default_variant {
88 default_variant
89 } else {
90 return Err(super::panic::no_default_variant(meta.span()));
91 }
92 }
93 };
94
95 let variant_ident = &variant.ident;
96
97 match &variant.fields {
98 Fields::Unit => {
99 default_token_stream.extend(quote!(Self::#variant_ident));
100 },
101 Fields::Named(_) => {
102 let mut fields_token_stream = proc_macro2::TokenStream::new();
103
104 for field in variant.fields.iter() {
105 let field_attribute = FieldAttributeBuilder {
106 enable_flag: false,
107 enable_expression: true,
108 }
109 .build_from_attributes(&field.attrs, traits, &field.ty)?;
110
111 let field_name = field.ident.as_ref().unwrap();
112
113 if let Some(expression) = field_attribute.expression {
114 fields_token_stream.extend(quote! {
115 #field_name: #expression,
116 });
117 } else {
118 let ty = &field.ty;
119
120 default_types.push(ty);
121
122 fields_token_stream.extend(quote! {
123 #field_name: <#ty as ::core::default::Default>::default(),
124 });
125 }
126 }
127
128 default_token_stream.extend(quote! {
129 Self::#variant_ident {
130 #fields_token_stream
131 }
132 });
133 },
134 Fields::Unnamed(_) => {
135 let mut fields_token_stream = proc_macro2::TokenStream::new();
136
137 for field in variant.fields.iter() {
138 let field_attribute = FieldAttributeBuilder {
139 enable_flag: false,
140 enable_expression: true,
141 }
142 .build_from_attributes(&field.attrs, traits, &field.ty)?;
143
144 if let Some(expression) = field_attribute.expression {
145 fields_token_stream.extend(quote!(#expression,));
146 } else {
147 let ty = &field.ty;
148
149 default_types.push(ty);
150
151 fields_token_stream
152 .extend(quote!(<#ty as ::core::default::Default>::default(),));
153 }
154 }
155
156 default_token_stream
157 .extend(quote!(Self::#variant_ident ( #fields_token_stream )));
158 },
159 }
160 }
161 }
162
163 let ident = &ast.ident;
164
165 let bound = type_attribute.bound.into_where_predicates_by_generic_parameters_check_types(
166 &ast.generics.params,
167 &syn::parse2(quote!(::core::default::Default)).unwrap(),
168 &default_types,
169 &[],
170 );
171
172 let mut generics = ast.generics.clone();
173 let where_clause = generics.make_where_clause();
174
175 for where_predicate in bound {
176 where_clause.predicates.push(where_predicate);
177 }
178
179 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
180
181 token_stream.extend(quote! {
182 impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
183 #[inline]
184 fn default() -> Self {
185 #default_token_stream
186 }
187 }
188 });
189
190 if type_attribute.new {
191 token_stream.extend(quote! {
192 impl #impl_generics #ident #ty_generics #where_clause {
193 #[inline]
195 pub fn new() -> Self {
196 <Self as ::core::default::Default>::default()
197 }
198 }
199 });
200 }
201
202 Ok(())
203 }
204}
205
206fn ensure_fields_no_attribute(fields: &Fields, traits: &[Trait]) -> syn::Result<()> {
207 match fields {
208 Fields::Unit => (),
209 Fields::Named(fields) => {
210 for field in fields.named.iter() {
211 let _ = FieldAttributeBuilder {
212 enable_flag: false,
213 enable_expression: false,
214 }
215 .build_from_attributes(&field.attrs, traits, &field.ty)?;
216 }
217 },
218 Fields::Unnamed(fields) => {
219 for field in fields.unnamed.iter() {
220 let _ = FieldAttributeBuilder {
221 enable_flag: false,
222 enable_expression: false,
223 }
224 .build_from_attributes(&field.attrs, traits, &field.ty)?;
225 }
226 },
227 }
228
229 Ok(())
230}