educe/trait_handlers/deref_mut/deref_mut_enum.rs

1use quote::{format_ident, quote};
2use syn::{spanned::Spanned, Data, DeriveInput, Field, Fields, Ident, Meta};
3
4use super::{
5    models::{FieldAttributeBuilder, TypeAttributeBuilder},
6    TraitHandler,
7};
8use crate::{panic, supported_traits::Trait};
9
10pub(crate) struct DerefMutEnumHandler;
11
12impl TraitHandler for DerefMutEnumHandler {
13    #[inline]
14    fn trait_meta_handler(
15        ast: &DeriveInput,
16        token_stream: &mut proc_macro2::TokenStream,
17        traits: &[Trait],
18        meta: &Meta,
19    ) -> syn::Result<()> {
20        let _ = TypeAttributeBuilder {
21            enable_flag: true
22        }
23        .build_from_deref_mut_meta(meta)?;
24
25        let mut arms_token_stream = proc_macro2::TokenStream::new();
26
27        if let Data::Enum(data) = &ast.data {
28            type Variants<'a> = Vec<(&'a Ident, bool, usize, Ident)>;
29
30            let mut variants: Variants = Vec::new();
31
32            for variant in data.variants.iter() {
33                let _ = TypeAttributeBuilder {
34                    enable_flag: false
35                }
36                .build_from_attributes(&variant.attrs, traits)?;
37
38                if let Fields::Unit = &variant.fields {
39                    return Err(panic::trait_not_support_unit_variant(
40                        meta.path().get_ident().unwrap(),
41                        variant,
42                    ));
43                }
44
45                let fields = &variant.fields;
46
47                let (index, field) = if fields.len() == 1 {
48                    let field = fields.into_iter().next().unwrap();
49
50                    let _ = FieldAttributeBuilder {
51                        enable_flag: true
52                    }
53                    .build_from_attributes(&field.attrs, traits)?;
54
55                    (0usize, field)
56                } else {
57                    let mut deref_field: Option<(usize, &Field)> = None;
58
59                    for (index, field) in variant.fields.iter().enumerate() {
60                        let field_attribute = FieldAttributeBuilder {
61                            enable_flag: true
62                        }
63                        .build_from_attributes(&field.attrs, traits)?;
64
65                        if field_attribute.flag {
66                            if deref_field.is_some() {
67                                return Err(super::panic::multiple_deref_mut_fields_of_variant(
68                                    field_attribute.span,
69                                    variant,
70                                ));
71                            }
72
73                            deref_field = Some((index, field));
74                        }
75                    }
76
77                    if let Some(deref_field) = deref_field {
78                        deref_field
79                    } else {
80                        return Err(super::panic::no_deref_mut_field_of_variant(
81                            meta.span(),
82                            variant,
83                        ));
84                    }
85                };
86
87                let (field_name, is_tuple): (Ident, bool) = match field.ident.as_ref() {
88                    Some(ident) => (ident.clone(), false),
89                    None => (format_ident!("_{}", index), true),
90                };
91
92                variants.push((&variant.ident, is_tuple, index, field_name));
93            }
94
95            if variants.is_empty() {
96                return Err(super::panic::no_deref_mut_field(meta.span()));
97            }
98
99            for (variant_ident, is_tuple, index, field_name) in variants {
100                let mut pattern_token_stream = proc_macro2::TokenStream::new();
101
102                if is_tuple {
103                    for _ in 0..index {
104                        pattern_token_stream.extend(quote!(_,));
105                    }
106
107                    pattern_token_stream.extend(quote!( #field_name, .. ));
108
109                    arms_token_stream.extend(
110                        quote!( Self::#variant_ident ( #pattern_token_stream ) => #field_name, ),
111                    );
112                } else {
113                    pattern_token_stream.extend(quote!( #field_name, .. ));
114
115                    arms_token_stream.extend(
116                        quote!( Self::#variant_ident { #pattern_token_stream } => #field_name, ),
117                    );
118                }
119            }
120        }
121
122        let ident = &ast.ident;
123
124        let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
125
126        token_stream.extend(quote! {
127            impl #impl_generics ::core::ops::DerefMut for #ident #ty_generics #where_clause {
128                #[inline]
129                fn deref_mut(&mut self) -> &mut Self::Target {
130                    match self {
131                        #arms_token_stream
132                    }
133                }
134            }
135        });
136
137        Ok(())
138    }
139}