educe/trait_handlers/into/into_enum.rs

1use std::collections::HashMap;
2
3use quote::{format_ident, quote};
4use syn::{Data, DeriveInput, Field, Fields, Ident, Meta, Path, Type};
5
6use super::{
7    models::{FieldAttribute, FieldAttributeBuilder, TypeAttributeBuilder},
8    TraitHandlerMultiple,
9};
10use crate::{panic, Trait};
11
12pub(crate) struct IntoEnumHandler;
13
14impl TraitHandlerMultiple for IntoEnumHandler {
15    #[inline]
16    fn trait_meta_handler(
17        ast: &DeriveInput,
18        token_stream: &mut proc_macro2::TokenStream,
19        traits: &[Trait],
20        meta: &[Meta],
21    ) -> syn::Result<()> {
22        let type_attribute = TypeAttributeBuilder {
23            enable_types: true
24        }
25        .build_from_into_meta(meta)?;
26
27        if let Data::Enum(data) = &ast.data {
28            let field_attributes: Vec<HashMap<usize, FieldAttribute>> = {
29                let mut map = Vec::new();
30
31                for variant in data.variants.iter() {
32                    let mut field_map = HashMap::new();
33
34                    let _ = TypeAttributeBuilder {
35                        enable_types: false
36                    }
37                    .build_from_attributes(&variant.attrs, traits)?;
38
39                    for (index, field) in variant.fields.iter().enumerate() {
40                        let field_attribute = FieldAttributeBuilder {
41                            enable_types: true
42                        }
43                        .build_from_attributes(&field.attrs, traits)?;
44
45                        for ty in field_attribute.types.keys() {
46                            if !type_attribute.types.contains_key(ty) {
47                                return Err(super::panic::no_into_impl(ty));
48                            }
49                        }
50
51                        field_map.insert(index, field_attribute);
52                    }
53
54                    map.push(field_map);
55                }
56
57                map
58            };
59
60            for (target_ty, bound) in type_attribute.types {
61                let mut into_types: Vec<&Type> = Vec::new();
62
63                let mut arms_token_stream = proc_macro2::TokenStream::new();
64
65                type Variants<'a> =
66                    Vec<(&'a Ident, bool, usize, Ident, &'a Type, Option<&'a Path>)>;
67
68                let mut variants: Variants = Vec::new();
69
70                for (variant, field_attributes) in data.variants.iter().zip(field_attributes.iter())
71                {
72                    if let Fields::Unit = &variant.fields {
73                        return Err(panic::trait_not_support_unit_variant(
74                            meta[0].path().get_ident().unwrap(),
75                            variant,
76                        ));
77                    }
78
79                    let (index, field, method) = {
80                        let fields = &variant.fields;
81
82                        if fields.len() == 1 {
83                            let field = fields.into_iter().next().unwrap();
84
85                            let method = if let Some(field_attribute) = field_attributes.get(&0) {
86                                if let Some(method) = field_attribute.types.get(&target_ty) {
87                                    method.as_ref()
88                                } else {
89                                    None
90                                }
91                            } else {
92                                None
93                            };
94
95                            (0usize, field, method)
96                        } else {
97                            let mut into_field: Option<(usize, &Field, Option<&Path>)> = None;
98
99                            for (index, field) in fields.iter().enumerate() {
100                                if let Some(field_attribute) = field_attributes.get(&index) {
101                                    if let Some((key, method)) =
102                                        field_attribute.types.get_key_value(&target_ty)
103                                    {
104                                        if into_field.is_some() {
105                                            return Err(super::panic::multiple_into_fields(key));
106                                        }
107
108                                        into_field = Some((index, field, method.as_ref()));
109                                    }
110                                }
111                            }
112
113                            if into_field.is_none() {
114                                // search the same type
115                                for (index, field) in fields.iter().enumerate() {
116                                    let field_ty = super::common::to_hash_type(&field.ty);
117
118                                    if target_ty.eq(&field_ty) {
119                                        if into_field.is_some() {
120                                            // multiple candidates
121                                            into_field = None;
122
123                                            break;
124                                        }
125
126                                        into_field = Some((index, field, None));
127                                    }
128                                }
129                            }
130
131                            if let Some(into_field) = into_field {
132                                into_field
133                            } else {
134                                return Err(super::panic::no_into_field(&target_ty));
135                            }
136                        }
137                    };
138
139                    let (field_name, is_tuple): (Ident, bool) = match field.ident.as_ref() {
140                        Some(ident) => (ident.clone(), false),
141                        None => (format_ident!("_{}", index), true),
142                    };
143
144                    variants.push((&variant.ident, is_tuple, index, field_name, &field.ty, method));
145                }
146
147                if variants.is_empty() {
148                    return Err(super::panic::no_into_field(&target_ty));
149                }
150
151                for (variant_ident, is_tuple, index, field_name, ty, method) in variants {
152                    let mut pattern_token_stream = proc_macro2::TokenStream::new();
153                    let mut body_token_stream = proc_macro2::TokenStream::new();
154
155                    if let Some(method) = method {
156                        body_token_stream.extend(quote!( #method(#field_name) ));
157                    } else {
158                        let field_ty = super::common::to_hash_type(ty);
159
160                        if target_ty.eq(&field_ty) {
161                            body_token_stream.extend(quote!( #field_name ));
162                        } else {
163                            into_types.push(ty);
164
165                            body_token_stream
166                                .extend(quote!( ::core::convert::Into::into(#field_name) ));
167                        }
168                    }
169
170                    if is_tuple {
171                        for _ in 0..index {
172                            pattern_token_stream.extend(quote!(_,));
173                        }
174
175                        pattern_token_stream.extend(quote!( #field_name, .. ));
176
177                        arms_token_stream.extend(
178                            quote!( Self::#variant_ident ( #pattern_token_stream ) => #body_token_stream, ),
179                        );
180                    } else {
181                        pattern_token_stream.extend(quote!( #field_name, .. ));
182
183                        arms_token_stream.extend(
184                            quote!( Self::#variant_ident { #pattern_token_stream } => #body_token_stream, ),
185                        );
186                    }
187                }
188
189                let ident = &ast.ident;
190
191                let bound = bound.into_where_predicates_by_generic_parameters_check_types(
192                    &ast.generics.params,
193                    &syn::parse2(quote!(::core::convert::Into<#target_ty>)).unwrap(),
194                    &into_types,
195                    &[],
196                );
197
198                // clone generics in order to not to affect other Into<T> implementations
199                let mut generics = ast.generics.clone();
200
201                let where_clause = generics.make_where_clause();
202
203                for where_predicate in bound {
204                    where_clause.predicates.push(where_predicate);
205                }
206
207                let (impl_generics, ty_generics, _) = ast.generics.split_for_impl();
208
209                token_stream.extend(quote! {
210                    impl #impl_generics ::core::convert::Into<#target_ty> for #ident #ty_generics #where_clause {
211                        #[inline]
212                        fn into(self) -> #target_ty {
213                            match self {
214                                #arms_token_stream
215                            }
216                        }
217                    }
218                });
219            }
220        }
221
222        Ok(())
223    }
224}