// educe/trait_handlers/ord/ord_enum.rs

1use std::collections::BTreeMap;
2
3use quote::{format_ident, quote};
4use syn::{spanned::Spanned, Data, DeriveInput, Field, Fields, Ident, Meta, Path, Type};
5
6use super::{
7    models::{FieldAttribute, FieldAttributeBuilder, TypeAttributeBuilder},
8    TraitHandler,
9};
10use crate::{common::tools::DiscriminantType, Trait};
11
/// Zero-sized marker type carrying the `Ord` derive implementation for enums.
pub(crate) struct OrdEnumHandler;
13
14impl TraitHandler for OrdEnumHandler {
15    #[inline]
16    fn trait_meta_handler(
17        ast: &DeriveInput,
18        token_stream: &mut proc_macro2::TokenStream,
19        traits: &[Trait],
20        meta: &Meta,
21    ) -> syn::Result<()> {
22        let type_attribute = TypeAttributeBuilder {
23            enable_flag: true, enable_bound: true
24        }
25        .build_from_ord_meta(meta)?;
26
27        let mut ord_types: Vec<&Type> = Vec::new();
28
29        let mut cmp_token_stream = proc_macro2::TokenStream::new();
30
31        let discriminant_type = DiscriminantType::from_ast(ast)?;
32
33        let mut arms_token_stream = proc_macro2::TokenStream::new();
34
35        let mut all_unit = true;
36
37        if let Data::Enum(data) = &ast.data {
38            for variant in data.variants.iter() {
39                let _ = TypeAttributeBuilder {
40                    enable_flag: false, enable_bound: false
41                }
42                .build_from_attributes(&variant.attrs, traits)?;
43
44                let variant_ident = &variant.ident;
45
46                let built_in_cmp: Path = syn::parse2(quote!(::core::cmp::Ord::cmp)).unwrap();
47
48                match &variant.fields {
49                    Fields::Unit => {
50                        arms_token_stream.extend(quote! {
51                            Self::#variant_ident => {
52                                return ::core::cmp::Ordering::Equal;
53                            }
54                        });
55                    },
56                    Fields::Named(_) => {
57                        all_unit = false;
58
59                        let mut pattern_self_token_stream = proc_macro2::TokenStream::new();
60                        let mut pattern_other_token_stream = proc_macro2::TokenStream::new();
61                        let mut block_token_stream = proc_macro2::TokenStream::new();
62
63                        let mut fields: BTreeMap<isize, (&Field, Ident, Ident, FieldAttribute)> =
64                            BTreeMap::new();
65
66                        for (index, field) in variant.fields.iter().enumerate() {
67                            let field_attribute = FieldAttributeBuilder {
68                                enable_ignore: true,
69                                enable_method: true,
70                                enable_rank:   true,
71                                rank:          isize::MIN + index as isize,
72                            }
73                            .build_from_attributes(&field.attrs, traits)?;
74
75                            let field_name_real = field.ident.as_ref().unwrap();
76                            let field_name_var_self = format_ident!("_s_{}", field_name_real);
77                            let field_name_var_other = format_ident!("_o_{}", field_name_real);
78
79                            if field_attribute.ignore {
80                                pattern_self_token_stream.extend(quote!(#field_name_real: _,));
81                                pattern_other_token_stream.extend(quote!(#field_name_real: _,));
82
83                                continue;
84                            }
85
86                            pattern_self_token_stream
87                                .extend(quote!(#field_name_real: #field_name_var_self,));
88                            pattern_other_token_stream
89                                .extend(quote!(#field_name_real: #field_name_var_other,));
90
91                            let rank = field_attribute.rank;
92
93                            if fields.contains_key(&rank) {
94                                return Err(super::panic::reuse_a_rank(
95                                    field_attribute.rank_span.unwrap_or_else(|| field.span()),
96                                    rank,
97                                ));
98                            }
99
100                            fields.insert(
101                                rank,
102                                (field, field_name_var_self, field_name_var_other, field_attribute),
103                            );
104                        }
105
106                        for (field, field_name_var_self, field_name_var_other, field_attribute) in
107                            fields.values()
108                        {
109                            let cmp = field_attribute.method.as_ref().unwrap_or_else(|| {
110                                ord_types.push(&field.ty);
111
112                                &built_in_cmp
113                            });
114
115                            block_token_stream.extend(quote! {
116                                match #cmp(#field_name_var_self, #field_name_var_other) {
117                                    ::core::cmp::Ordering::Equal => (),
118                                    ::core::cmp::Ordering::Greater => return ::core::cmp::Ordering::Greater,
119                                    ::core::cmp::Ordering::Less => return ::core::cmp::Ordering::Less,
120                                }
121                            });
122                        }
123
124                        arms_token_stream.extend(quote! {
125                            Self::#variant_ident { #pattern_self_token_stream } => {
126                                if let Self::#variant_ident { #pattern_other_token_stream } = other {
127                                    #block_token_stream
128                                }
129                            }
130                        });
131                    },
132                    Fields::Unnamed(_) => {
133                        all_unit = false;
134
135                        let mut pattern_token_stream = proc_macro2::TokenStream::new();
136                        let mut pattern2_token_stream = proc_macro2::TokenStream::new();
137                        let mut block_token_stream = proc_macro2::TokenStream::new();
138
139                        let mut fields: BTreeMap<isize, (&Field, Ident, Ident, FieldAttribute)> =
140                            BTreeMap::new();
141
142                        for (index, field) in variant.fields.iter().enumerate() {
143                            let field_attribute = FieldAttributeBuilder {
144                                enable_ignore: true,
145                                enable_method: true,
146                                enable_rank:   true,
147                                rank:          isize::MIN + index as isize,
148                            }
149                            .build_from_attributes(&field.attrs, traits)?;
150
151                            let field_name_var_self = format_ident!("_{}", index);
152
153                            if field_attribute.ignore {
154                                pattern_token_stream.extend(quote!(_,));
155                                pattern2_token_stream.extend(quote!(_,));
156
157                                continue;
158                            }
159
160                            let field_name_var_other = format_ident!("_{}", field_name_var_self);
161
162                            pattern_token_stream.extend(quote!(#field_name_var_self,));
163                            pattern2_token_stream.extend(quote!(#field_name_var_other,));
164
165                            let rank = field_attribute.rank;
166
167                            if fields.contains_key(&rank) {
168                                return Err(super::panic::reuse_a_rank(
169                                    field_attribute.rank_span.unwrap_or_else(|| field.span()),
170                                    rank,
171                                ));
172                            }
173
174                            fields.insert(
175                                rank,
176                                (field, field_name_var_self, field_name_var_other, field_attribute),
177                            );
178                        }
179
180                        for (field, field_name, field_name2, field_attribute) in fields.values() {
181                            let cmp = field_attribute.method.as_ref().unwrap_or_else(|| {
182                                ord_types.push(&field.ty);
183
184                                &built_in_cmp
185                            });
186
187                            block_token_stream.extend(quote! {
188                                match #cmp(#field_name, #field_name2) {
189                                    ::core::cmp::Ordering::Equal => (),
190                                    ::core::cmp::Ordering::Greater => return ::core::cmp::Ordering::Greater,
191                                    ::core::cmp::Ordering::Less => return ::core::cmp::Ordering::Less,
192                                }
193                            });
194                        }
195
196                        arms_token_stream.extend(quote! {
197                            Self::#variant_ident ( #pattern_token_stream ) => {
198                                if let Self::#variant_ident ( #pattern2_token_stream ) = other {
199                                    #block_token_stream
200                                }
201                            }
202                        });
203                    },
204                }
205            }
206        }
207
208        if arms_token_stream.is_empty() {
209            cmp_token_stream.extend(quote!(::core::cmp::Ordering::Equal));
210        } else {
211            let discriminant_cmp = quote! {
212                unsafe {
213                    ::core::cmp::Ord::cmp(&*<*const _>::from(self).cast::<#discriminant_type>(), &*<*const _>::from(other).cast::<#discriminant_type>())
214                }
215            };
216
217            cmp_token_stream.extend(if all_unit {
218                quote! {
219                    match #discriminant_cmp {
220                        ::core::cmp::Ordering::Equal => ::core::cmp::Ordering::Equal,
221                        ::core::cmp::Ordering::Greater => ::core::cmp::Ordering::Greater,
222                        ::core::cmp::Ordering::Less => ::core::cmp::Ordering::Less,
223                    }
224                }
225            } else {
226                quote! {
227                    match #discriminant_cmp {
228                        ::core::cmp::Ordering::Equal => {
229                            match self {
230                                #arms_token_stream
231                            }
232
233                            ::core::cmp::Ordering::Equal
234                        },
235                        ::core::cmp::Ordering::Greater => ::core::cmp::Ordering::Greater,
236                        ::core::cmp::Ordering::Less => ::core::cmp::Ordering::Less,
237                    }
238                }
239            });
240        }
241
242        let ident = &ast.ident;
243
244        let bound = type_attribute.bound.into_where_predicates_by_generic_parameters_check_types(
245            &ast.generics.params,
246            &syn::parse2(quote!(::core::cmp::Ord)).unwrap(),
247            &ord_types,
248            &crate::trait_handlers::ord::supertraits(traits),
249        );
250
251        let mut generics = ast.generics.clone();
252        let where_clause = generics.make_where_clause();
253
254        for where_predicate in bound {
255            where_clause.predicates.push(where_predicate);
256        }
257
258        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
259
260        token_stream.extend(quote! {
261            impl #impl_generics ::core::cmp::Ord for #ident #ty_generics #where_clause {
262                #[inline]
263                fn cmp(&self, other: &Self) -> ::core::cmp::Ordering {
264                    #cmp_token_stream
265                }
266            }
267        });
268
269        #[cfg(feature = "PartialOrd")]
270        if traits.contains(&Trait::PartialOrd) {
271            token_stream.extend(quote! {
272                impl #impl_generics ::core::cmp::PartialOrd for #ident #ty_generics #where_clause {
273                    #[inline]
274                    fn partial_cmp(&self, other: &Self) -> Option<::core::cmp::Ordering> {
275                        Some(::core::cmp::Ord::cmp(self, other))
276                    }
277                }
278            });
279        }
280
281        Ok(())
282    }
283}