// educe/trait_handlers/partial_eq/partial_eq_enum.rs

use quote::{format_ident, quote};
use syn::{Data, DeriveInput, Fields, Meta, Type};

use super::{
    models::{FieldAttributeBuilder, TypeAttributeBuilder},
    TraitHandler,
};
use crate::Trait;

/// Handler whose `TraitHandler` implementation generates `::core::cmp::PartialEq`
/// (and, when requested, `::core::cmp::Eq`) for enums. It carries no state.
pub(crate) struct PartialEqEnumHandler;

12impl TraitHandler for PartialEqEnumHandler {
13    #[inline]
14    fn trait_meta_handler(
15        ast: &DeriveInput,
16        token_stream: &mut proc_macro2::TokenStream,
17        traits: &[Trait],
18        meta: &Meta,
19    ) -> syn::Result<()> {
20        let type_attribute =
21            TypeAttributeBuilder {
22                enable_flag: true, enable_unsafe: false, enable_bound: true
23            }
24            .build_from_partial_eq_meta(meta)?;
25
26        let mut partial_eq_types: Vec<&Type> = Vec::new();
27
28        let mut eq_token_stream = proc_macro2::TokenStream::new();
29
30        let mut arms_token_stream = proc_macro2::TokenStream::new();
31
32        if let Data::Enum(data) = &ast.data {
33            for variant in data.variants.iter() {
34                let _ = TypeAttributeBuilder {
35                    enable_flag:   false,
36                    enable_unsafe: false,
37                    enable_bound:  false,
38                }
39                .build_from_attributes(&variant.attrs, traits)?;
40
41                let variant_ident = &variant.ident;
42
43                match &variant.fields {
44                    Fields::Unit => {
45                        arms_token_stream.extend(quote! {
46                            Self::#variant_ident => {
47                                if let Self::#variant_ident = other {
48                                    // same
49                                } else {
50                                    return false;
51                                }
52                            }
53                        });
54                    },
55                    Fields::Named(_) => {
56                        let mut pattern_self_token_stream = proc_macro2::TokenStream::new();
57                        let mut pattern_other_token_stream = proc_macro2::TokenStream::new();
58                        let mut block_token_stream = proc_macro2::TokenStream::new();
59
60                        for field in variant.fields.iter() {
61                            let field_attribute = FieldAttributeBuilder {
62                                enable_ignore: true,
63                                enable_method: true,
64                            }
65                            .build_from_attributes(&field.attrs, traits)?;
66
67                            let field_name_real = field.ident.as_ref().unwrap();
68                            let field_name_var_self = format_ident!("_s_{}", field_name_real);
69                            let field_name_var_other = format_ident!("_o_{}", field_name_real);
70
71                            if field_attribute.ignore {
72                                pattern_self_token_stream.extend(quote!(#field_name_real: _,));
73                                pattern_other_token_stream.extend(quote!(#field_name_real: _,));
74
75                                continue;
76                            }
77
78                            pattern_self_token_stream
79                                .extend(quote!(#field_name_real: #field_name_var_self,));
80                            pattern_other_token_stream
81                                .extend(quote!(#field_name_real: #field_name_var_other,));
82
83                            if let Some(method) = field_attribute.method {
84                                block_token_stream.extend(quote! {
85                                    if !#method(#field_name_var_self, #field_name_var_other) {
86                                        return false;
87                                    }
88                                });
89                            } else {
90                                let ty = &field.ty;
91
92                                partial_eq_types.push(ty);
93
94                                block_token_stream.extend(quote! {
95                                    if ::core::cmp::PartialEq::ne(#field_name_var_self, #field_name_var_other) {
96                                        return false;
97                                    }
98                                });
99                            }
100                        }
101
102                        arms_token_stream.extend(quote! {
103                            Self::#variant_ident { #pattern_self_token_stream } => {
104                                if let Self::#variant_ident { #pattern_other_token_stream } = other {
105                                    #block_token_stream
106                                } else {
107                                    return false;
108                                }
109                            }
110                        });
111                    },
112                    Fields::Unnamed(_) => {
113                        let mut pattern_token_stream = proc_macro2::TokenStream::new();
114                        let mut pattern2_token_stream = proc_macro2::TokenStream::new();
115                        let mut block_token_stream = proc_macro2::TokenStream::new();
116
117                        for (index, field) in variant.fields.iter().enumerate() {
118                            let field_attribute = FieldAttributeBuilder {
119                                enable_ignore: true,
120                                enable_method: true,
121                            }
122                            .build_from_attributes(&field.attrs, traits)?;
123
124                            if field_attribute.ignore {
125                                pattern_token_stream.extend(quote!(_,));
126                                pattern2_token_stream.extend(quote!(_,));
127
128                                continue;
129                            }
130
131                            let field_name_var_self = format_ident!("_{}", index);
132
133                            let field_name_var_other = format_ident!("_{}", field_name_var_self);
134
135                            pattern_token_stream.extend(quote!(#field_name_var_self,));
136                            pattern2_token_stream.extend(quote!(#field_name_var_other,));
137
138                            if let Some(method) = field_attribute.method {
139                                block_token_stream.extend(quote! {
140                                    if !#method(#field_name_var_self, #field_name_var_other) {
141                                        return false;
142                                    }
143                                });
144                            } else {
145                                let ty = &field.ty;
146
147                                partial_eq_types.push(ty);
148
149                                block_token_stream.extend(quote! {
150                                    if ::core::cmp::PartialEq::ne(#field_name_var_self, #field_name_var_other) {
151                                        return false;
152                                    }
153                                });
154                            }
155                        }
156
157                        arms_token_stream.extend(quote! {
158                            Self::#variant_ident ( #pattern_token_stream ) => {
159                                if let Self::#variant_ident ( #pattern2_token_stream ) = other {
160                                    #block_token_stream
161                                } else {
162                                    return false;
163                                }
164                            }
165                        });
166                    },
167                }
168            }
169        }
170
171        if !arms_token_stream.is_empty() {
172            eq_token_stream.extend(quote! {
173                match self {
174                    #arms_token_stream
175                }
176            });
177        }
178
179        let ident = &ast.ident;
180
181        let bound = type_attribute.bound.into_where_predicates_by_generic_parameters_check_types(
182            &ast.generics.params,
183            &syn::parse2(quote!(::core::cmp::PartialEq)).unwrap(),
184            &partial_eq_types,
185            &[],
186        );
187
188        let mut generics = ast.generics.clone();
189        let where_clause = generics.make_where_clause();
190
191        for where_predicate in bound {
192            where_clause.predicates.push(where_predicate);
193        }
194
195        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
196
197        token_stream.extend(quote! {
198            impl #impl_generics ::core::cmp::PartialEq for #ident #ty_generics #where_clause {
199                #[inline]
200                fn eq(&self, other: &Self) -> bool {
201                    #eq_token_stream
202
203                    true
204                }
205            }
206        });
207
208        #[cfg(feature = "Eq")]
209        if traits.contains(&Trait::Eq) {
210            token_stream.extend(quote! {
211                impl #impl_generics ::core::cmp::Eq for #ident #ty_generics #where_clause {
212                }
213            });
214        }
215
216        Ok(())
217    }
218}