// educe/trait_handlers/partial_ord/partial_ord_enum.rs

1use std::collections::BTreeMap;
2
3use quote::{format_ident, quote};
4use syn::{spanned::Spanned, Data, DeriveInput, Field, Fields, Ident, Meta, Path, Type};
5
6use super::{
7    models::{FieldAttribute, FieldAttributeBuilder, TypeAttributeBuilder},
8    TraitHandler,
9};
10use crate::{common::tools::DiscriminantType, Trait};
11
12pub(crate) struct PartialOrdEnumHandler;
13
14impl TraitHandler for PartialOrdEnumHandler {
15    #[inline]
16    fn trait_meta_handler(
17        ast: &DeriveInput,
18        token_stream: &mut proc_macro2::TokenStream,
19        traits: &[Trait],
20        meta: &Meta,
21    ) -> syn::Result<()> {
22        let type_attribute = TypeAttributeBuilder {
23            enable_flag: true, enable_bound: true
24        }
25        .build_from_partial_ord_meta(meta)?;
26
27        let mut partial_ord_types: Vec<&Type> = Vec::new();
28
29        let mut partial_cmp_token_stream = proc_macro2::TokenStream::new();
30
31        let discriminant_type = DiscriminantType::from_ast(ast)?;
32
33        let mut arms_token_stream = proc_macro2::TokenStream::new();
34
35        let mut all_unit = true;
36
37        if let Data::Enum(data) = &ast.data {
38            for variant in data.variants.iter() {
39                let _ = TypeAttributeBuilder {
40                    enable_flag: false, enable_bound: false
41                }
42                .build_from_attributes(&variant.attrs, traits)?;
43
44                let variant_ident = &variant.ident;
45
46                let built_in_partial_cmp: Path =
47                    syn::parse2(quote!(::core::cmp::PartialOrd::partial_cmp)).unwrap();
48
49                match &variant.fields {
50                    Fields::Unit => {
51                        arms_token_stream.extend(quote! {
52                            Self::#variant_ident => {
53                                return Some(::core::cmp::Ordering::Equal);
54                            }
55                        });
56                    },
57                    Fields::Named(_) => {
58                        all_unit = false;
59
60                        let mut pattern_self_token_stream = proc_macro2::TokenStream::new();
61                        let mut pattern_other_token_stream = proc_macro2::TokenStream::new();
62                        let mut block_token_stream = proc_macro2::TokenStream::new();
63
64                        let mut fields: BTreeMap<isize, (&Field, Ident, Ident, FieldAttribute)> =
65                            BTreeMap::new();
66
67                        for (index, field) in variant.fields.iter().enumerate() {
68                            let field_attribute = FieldAttributeBuilder {
69                                enable_ignore: true,
70                                enable_method: true,
71                                enable_rank:   true,
72                                rank:          isize::MIN + index as isize,
73                            }
74                            .build_from_attributes(&field.attrs, traits)?;
75
76                            let field_name_real = field.ident.as_ref().unwrap();
77                            let field_name_var_self = format_ident!("_s_{}", field_name_real);
78                            let field_name_var_other = format_ident!("_o_{}", field_name_real);
79
80                            if field_attribute.ignore {
81                                pattern_self_token_stream.extend(quote!(#field_name_real: _,));
82                                pattern_other_token_stream.extend(quote!(#field_name_real: _,));
83
84                                continue;
85                            }
86
87                            pattern_self_token_stream
88                                .extend(quote!(#field_name_real: #field_name_var_self,));
89                            pattern_other_token_stream
90                                .extend(quote!(#field_name_real: #field_name_var_other,));
91
92                            let rank = field_attribute.rank;
93
94                            if fields.contains_key(&rank) {
95                                return Err(super::panic::reuse_a_rank(
96                                    field_attribute.rank_span.unwrap_or_else(|| field.span()),
97                                    rank,
98                                ));
99                            }
100
101                            fields.insert(
102                                rank,
103                                (field, field_name_var_self, field_name_var_other, field_attribute),
104                            );
105                        }
106
107                        for (field, field_name_var_self, field_name_var_other, field_attribute) in
108                            fields.values()
109                        {
110                            let partial_cmp =
111                                field_attribute.method.as_ref().unwrap_or_else(|| {
112                                    partial_ord_types.push(&field.ty);
113
114                                    &built_in_partial_cmp
115                                });
116
117                            block_token_stream.extend(quote! {
118                                match #partial_cmp(#field_name_var_self, #field_name_var_other) {
119                                    Some(::core::cmp::Ordering::Equal) => (),
120                                    Some(::core::cmp::Ordering::Greater) => return Some(::core::cmp::Ordering::Greater),
121                                    Some(::core::cmp::Ordering::Less) => return Some(::core::cmp::Ordering::Less),
122                                    None => return None,
123                                }
124                            });
125                        }
126
127                        arms_token_stream.extend(quote! {
128                            Self::#variant_ident { #pattern_self_token_stream } => {
129                                if let Self::#variant_ident { #pattern_other_token_stream } = other {
130                                    #block_token_stream
131                                }
132                            }
133                        });
134                    },
135                    Fields::Unnamed(_) => {
136                        all_unit = false;
137
138                        let mut pattern_token_stream = proc_macro2::TokenStream::new();
139                        let mut pattern2_token_stream = proc_macro2::TokenStream::new();
140                        let mut block_token_stream = proc_macro2::TokenStream::new();
141
142                        let mut fields: BTreeMap<isize, (&Field, Ident, Ident, FieldAttribute)> =
143                            BTreeMap::new();
144
145                        for (index, field) in variant.fields.iter().enumerate() {
146                            let field_attribute = FieldAttributeBuilder {
147                                enable_ignore: true,
148                                enable_method: true,
149                                enable_rank:   true,
150                                rank:          isize::MIN + index as isize,
151                            }
152                            .build_from_attributes(&field.attrs, traits)?;
153
154                            let field_name_var_self = format_ident!("_{}", index);
155
156                            if field_attribute.ignore {
157                                pattern_token_stream.extend(quote!(_,));
158                                pattern2_token_stream.extend(quote!(_,));
159
160                                continue;
161                            }
162
163                            let field_name_var_other = format_ident!("_{}", field_name_var_self);
164
165                            pattern_token_stream.extend(quote!(#field_name_var_self,));
166                            pattern2_token_stream.extend(quote!(#field_name_var_other,));
167
168                            let rank = field_attribute.rank;
169
170                            if fields.contains_key(&rank) {
171                                return Err(super::panic::reuse_a_rank(
172                                    field_attribute.rank_span.unwrap_or_else(|| field.span()),
173                                    rank,
174                                ));
175                            }
176
177                            fields.insert(
178                                rank,
179                                (field, field_name_var_self, field_name_var_other, field_attribute),
180                            );
181                        }
182
183                        for (field, field_name, field_name2, field_attribute) in fields.values() {
184                            let partial_cmp =
185                                field_attribute.method.as_ref().unwrap_or_else(|| {
186                                    partial_ord_types.push(&field.ty);
187
188                                    &built_in_partial_cmp
189                                });
190
191                            block_token_stream.extend(quote! {
192                                match #partial_cmp(#field_name, #field_name2) {
193                                    Some(::core::cmp::Ordering::Equal) => (),
194                                    Some(::core::cmp::Ordering::Greater) => return Some(::core::cmp::Ordering::Greater),
195                                    Some(::core::cmp::Ordering::Less) => return Some(::core::cmp::Ordering::Less),
196                                    None => return None,
197                                }
198                            });
199                        }
200
201                        arms_token_stream.extend(quote! {
202                            Self::#variant_ident ( #pattern_token_stream ) => {
203                                if let Self::#variant_ident ( #pattern2_token_stream ) = other {
204                                    #block_token_stream
205                                }
206                            }
207                        });
208                    },
209                }
210            }
211        }
212
213        if arms_token_stream.is_empty() {
214            partial_cmp_token_stream.extend(quote!(Some(::core::cmp::Ordering::Equal)));
215        } else {
216            let discriminant_cmp = quote! {
217                unsafe {
218                    ::core::cmp::Ord::cmp(&*<*const _>::from(self).cast::<#discriminant_type>(), &*<*const _>::from(other).cast::<#discriminant_type>())
219                }
220            };
221
222            partial_cmp_token_stream.extend(if all_unit {
223                quote! {
224                    match #discriminant_cmp {
225                        ::core::cmp::Ordering::Equal => Some(::core::cmp::Ordering::Equal),
226                        ::core::cmp::Ordering::Greater => Some(::core::cmp::Ordering::Greater),
227                        ::core::cmp::Ordering::Less => Some(::core::cmp::Ordering::Less),
228                    }
229                }
230            } else {
231                quote! {
232                    match #discriminant_cmp {
233                        ::core::cmp::Ordering::Equal => {
234                            match self {
235                                #arms_token_stream
236                            }
237
238                            Some(::core::cmp::Ordering::Equal)
239                        },
240                        ::core::cmp::Ordering::Greater => Some(::core::cmp::Ordering::Greater),
241                        ::core::cmp::Ordering::Less => Some(::core::cmp::Ordering::Less),
242                    }
243                }
244            });
245        }
246
247        let ident = &ast.ident;
248
249        let bound = type_attribute.bound.into_where_predicates_by_generic_parameters_check_types(
250            &ast.generics.params,
251            &syn::parse2(quote!(::core::cmp::PartialOrd)).unwrap(),
252            &partial_ord_types,
253            &[quote! {::core::cmp::PartialEq}],
254        );
255
256        let mut generics = ast.generics.clone();
257        let where_clause = generics.make_where_clause();
258
259        for where_predicate in bound {
260            where_clause.predicates.push(where_predicate);
261        }
262
263        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
264
265        token_stream.extend(quote! {
266            impl #impl_generics ::core::cmp::PartialOrd for #ident #ty_generics #where_clause {
267                #[inline]
268                fn partial_cmp(&self, other: &Self) -> Option<::core::cmp::Ordering> {
269                    #partial_cmp_token_stream
270                }
271            }
272        });
273
274        Ok(())
275    }
276}