// educe/trait_handlers/default/default_union.rs

1use quote::quote;
2use syn::{spanned::Spanned, Data, DeriveInput, Field, Meta, Type};
3
4use super::{
5    models::{FieldAttribute, FieldAttributeBuilder, TypeAttributeBuilder},
6    TraitHandler,
7};
8use crate::Trait;
9
/// Handler that generates a `::core::default::Default` implementation (and, when the type-level
/// `new` option is set, an inherent `new()` constructor forwarding to it) for a union type.
pub(crate) struct DefaultUnionHandler;
11
12impl TraitHandler for DefaultUnionHandler {
13    fn trait_meta_handler(
14        ast: &DeriveInput,
15        token_stream: &mut proc_macro2::TokenStream,
16        traits: &[Trait],
17        meta: &Meta,
18    ) -> syn::Result<()> {
19        let type_attribute = TypeAttributeBuilder {
20            enable_flag:       true,
21            enable_new:        true,
22            enable_expression: true,
23            enable_bound:      true,
24        }
25        .build_from_default_meta(meta)?;
26
27        let mut default_types: Vec<&Type> = Vec::new();
28
29        let mut default_token_stream = proc_macro2::TokenStream::new();
30
31        if let Data::Union(data) = &ast.data {
32            if let Some(expression) = type_attribute.expression {
33                for field in data.fields.named.iter() {
34                    let _ = FieldAttributeBuilder {
35                        enable_flag:       false,
36                        enable_expression: false,
37                    }
38                    .build_from_attributes(&field.attrs, traits, &field.ty)?;
39                }
40
41                default_token_stream.extend(quote!(#expression));
42            } else {
43                let (field, field_attribute) =
44                    {
45                        let fields = &data.fields.named;
46
47                        if fields.len() == 1 {
48                            let field = &fields[0];
49
50                            let field_attribute = FieldAttributeBuilder {
51                                enable_flag:       true,
52                                enable_expression: true,
53                            }
54                            .build_from_attributes(&field.attrs, traits, &field.ty)?;
55
56                            (field, field_attribute)
57                        } else {
58                            let mut default_field: Option<(&Field, FieldAttribute)> = None;
59
60                            for field in fields {
61                                let field_attribute = FieldAttributeBuilder {
62                                    enable_flag:       true,
63                                    enable_expression: true,
64                                }
65                                .build_from_attributes(&field.attrs, traits, &field.ty)?;
66
67                                if field_attribute.flag || field_attribute.expression.is_some() {
68                                    if default_field.is_some() {
69                                        return Err(super::panic::multiple_default_fields(
70                                            field_attribute.span,
71                                        ));
72                                    }
73
74                                    default_field = Some((field, field_attribute));
75                                }
76                            }
77
78                            if let Some(default_field) = default_field {
79                                default_field
80                            } else {
81                                return Err(super::panic::no_default_field(meta.span()));
82                            }
83                        }
84                    };
85
86                let mut fields_token_stream = proc_macro2::TokenStream::new();
87
88                let field_name = field.ident.as_ref().unwrap();
89
90                if let Some(expression) = field_attribute.expression {
91                    fields_token_stream.extend(quote! {
92                        #field_name: #expression,
93                    });
94                } else {
95                    let ty = &field.ty;
96
97                    default_types.push(ty);
98
99                    fields_token_stream.extend(quote! {
100                        #field_name: <#ty as ::core::default::Default>::default(),
101                    });
102                }
103
104                default_token_stream.extend(quote! {
105                    Self {
106                        #fields_token_stream
107                    }
108                });
109            }
110        }
111
112        let ident = &ast.ident;
113
114        let bound = type_attribute.bound.into_where_predicates_by_generic_parameters_check_types(
115            &ast.generics.params,
116            &syn::parse2(quote!(::core::default::Default)).unwrap(),
117            &default_types,
118            &[],
119        );
120
121        let mut generics = ast.generics.clone();
122        let where_clause = generics.make_where_clause();
123
124        for where_predicate in bound {
125            where_clause.predicates.push(where_predicate);
126        }
127
128        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
129
130        token_stream.extend(quote! {
131            impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
132                #[inline]
133                fn default() -> Self {
134                    #default_token_stream
135                }
136            }
137        });
138
139        if type_attribute.new {
140            token_stream.extend(quote! {
141                impl #impl_generics #ident #ty_generics #where_clause {
142                    /// Returns the "default value" for a type.
143                    #[inline]
144                    pub fn new() -> Self {
145                        <Self as ::core::default::Default>::default()
146                    }
147                }
148            });
149        }
150
151        Ok(())
152    }
153}