// educe/trait_handlers/default/default_enum.rs

1use quote::quote;
2use syn::{spanned::Spanned, Data, DeriveInput, Fields, Meta, Type, Variant};
3
4use super::{
5    models::{FieldAttributeBuilder, TypeAttributeBuilder},
6    TraitHandler,
7};
8use crate::Trait;
9
10pub(crate) struct DefaultEnumHandler;
11
12impl TraitHandler for DefaultEnumHandler {
13    fn trait_meta_handler(
14        ast: &DeriveInput,
15        token_stream: &mut proc_macro2::TokenStream,
16        traits: &[Trait],
17        meta: &Meta,
18    ) -> syn::Result<()> {
19        let type_attribute = TypeAttributeBuilder {
20            enable_flag:       true,
21            enable_new:        true,
22            enable_expression: true,
23            enable_bound:      true,
24        }
25        .build_from_default_meta(meta)?;
26
27        let mut default_types: Vec<&Type> = Vec::new();
28
29        let mut default_token_stream = proc_macro2::TokenStream::new();
30
31        if let Data::Enum(data) = &ast.data {
32            if let Some(expression) = type_attribute.expression {
33                for variant in data.variants.iter() {
34                    let _ = TypeAttributeBuilder {
35                        enable_flag:       false,
36                        enable_new:        false,
37                        enable_expression: false,
38                        enable_bound:      false,
39                    }
40                    .build_from_attributes(&variant.attrs, traits)?;
41
42                    ensure_fields_no_attribute(&variant.fields, traits)?;
43                }
44
45                default_token_stream.extend(quote!(#expression));
46            } else {
47                let variant = {
48                    let variants = &data.variants;
49
50                    if variants.len() == 1 {
51                        let variant = &variants[0];
52
53                        let _ = TypeAttributeBuilder {
54                            enable_flag:       true,
55                            enable_new:        false,
56                            enable_expression: false,
57                            enable_bound:      false,
58                        }
59                        .build_from_attributes(&variant.attrs, traits)?;
60
61                        variant
62                    } else {
63                        let mut default_variant: Option<&Variant> = None;
64
65                        for variant in variants {
66                            let type_attribute = TypeAttributeBuilder {
67                                enable_flag:       true,
68                                enable_new:        false,
69                                enable_expression: false,
70                                enable_bound:      false,
71                            }
72                            .build_from_attributes(&variant.attrs, traits)?;
73
74                            if type_attribute.flag {
75                                if default_variant.is_some() {
76                                    return Err(super::panic::multiple_default_variants(
77                                        type_attribute.span,
78                                    ));
79                                }
80
81                                default_variant = Some(variant);
82                            } else {
83                                ensure_fields_no_attribute(&variant.fields, traits)?;
84                            }
85                        }
86
87                        if let Some(default_variant) = default_variant {
88                            default_variant
89                        } else {
90                            return Err(super::panic::no_default_variant(meta.span()));
91                        }
92                    }
93                };
94
95                let variant_ident = &variant.ident;
96
97                match &variant.fields {
98                    Fields::Unit => {
99                        default_token_stream.extend(quote!(Self::#variant_ident));
100                    },
101                    Fields::Named(_) => {
102                        let mut fields_token_stream = proc_macro2::TokenStream::new();
103
104                        for field in variant.fields.iter() {
105                            let field_attribute = FieldAttributeBuilder {
106                                enable_flag:       false,
107                                enable_expression: true,
108                            }
109                            .build_from_attributes(&field.attrs, traits, &field.ty)?;
110
111                            let field_name = field.ident.as_ref().unwrap();
112
113                            if let Some(expression) = field_attribute.expression {
114                                fields_token_stream.extend(quote! {
115                                    #field_name: #expression,
116                                });
117                            } else {
118                                let ty = &field.ty;
119
120                                default_types.push(ty);
121
122                                fields_token_stream.extend(quote! {
123                                    #field_name: <#ty as ::core::default::Default>::default(),
124                                });
125                            }
126                        }
127
128                        default_token_stream.extend(quote! {
129                            Self::#variant_ident {
130                                #fields_token_stream
131                            }
132                        });
133                    },
134                    Fields::Unnamed(_) => {
135                        let mut fields_token_stream = proc_macro2::TokenStream::new();
136
137                        for field in variant.fields.iter() {
138                            let field_attribute = FieldAttributeBuilder {
139                                enable_flag:       false,
140                                enable_expression: true,
141                            }
142                            .build_from_attributes(&field.attrs, traits, &field.ty)?;
143
144                            if let Some(expression) = field_attribute.expression {
145                                fields_token_stream.extend(quote!(#expression,));
146                            } else {
147                                let ty = &field.ty;
148
149                                default_types.push(ty);
150
151                                fields_token_stream
152                                    .extend(quote!(<#ty as ::core::default::Default>::default(),));
153                            }
154                        }
155
156                        default_token_stream
157                            .extend(quote!(Self::#variant_ident ( #fields_token_stream )));
158                    },
159                }
160            }
161        }
162
163        let ident = &ast.ident;
164
165        let bound = type_attribute.bound.into_where_predicates_by_generic_parameters_check_types(
166            &ast.generics.params,
167            &syn::parse2(quote!(::core::default::Default)).unwrap(),
168            &default_types,
169            &[],
170        );
171
172        let mut generics = ast.generics.clone();
173        let where_clause = generics.make_where_clause();
174
175        for where_predicate in bound {
176            where_clause.predicates.push(where_predicate);
177        }
178
179        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
180
181        token_stream.extend(quote! {
182            impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
183                #[inline]
184                fn default() -> Self {
185                    #default_token_stream
186                }
187            }
188        });
189
190        if type_attribute.new {
191            token_stream.extend(quote! {
192                impl #impl_generics #ident #ty_generics #where_clause {
193                    /// Returns the "default value" for a type.
194                    #[inline]
195                    pub fn new() -> Self {
196                        <Self as ::core::default::Default>::default()
197                    }
198                }
199            });
200        }
201
202        Ok(())
203    }
204}
205
206fn ensure_fields_no_attribute(fields: &Fields, traits: &[Trait]) -> syn::Result<()> {
207    match fields {
208        Fields::Unit => (),
209        Fields::Named(fields) => {
210            for field in fields.named.iter() {
211                let _ = FieldAttributeBuilder {
212                    enable_flag:       false,
213                    enable_expression: false,
214                }
215                .build_from_attributes(&field.attrs, traits, &field.ty)?;
216            }
217        },
218        Fields::Unnamed(fields) => {
219            for field in fields.unnamed.iter() {
220                let _ = FieldAttributeBuilder {
221                    enable_flag:       false,
222                    enable_expression: false,
223                }
224                .build_from_attributes(&field.attrs, traits, &field.ty)?;
225            }
226        },
227    }
228
229    Ok(())
230}