// educe/trait_handlers/ord/ord_enum.rs
use std::collections::{btree_map::Entry, BTreeMap};

use quote::{format_ident, quote};
use syn::{spanned::Spanned, Data, DeriveInput, Field, Fields, Ident, Meta, Path, Type};

use super::{
    models::{FieldAttribute, FieldAttributeBuilder, TypeAttributeBuilder},
    TraitHandler,
};
use crate::{common::tools::DiscriminantType, Trait};

/// [`TraitHandler`] that generates an `Ord` implementation for an enum.
pub(crate) struct OrdEnumHandler;

14impl TraitHandler for OrdEnumHandler {
15 #[inline]
16 fn trait_meta_handler(
17 ast: &DeriveInput,
18 token_stream: &mut proc_macro2::TokenStream,
19 traits: &[Trait],
20 meta: &Meta,
21 ) -> syn::Result<()> {
22 let type_attribute = TypeAttributeBuilder {
23 enable_flag: true, enable_bound: true
24 }
25 .build_from_ord_meta(meta)?;
26
27 let mut ord_types: Vec<&Type> = Vec::new();
28
29 let mut cmp_token_stream = proc_macro2::TokenStream::new();
30
31 let discriminant_type = DiscriminantType::from_ast(ast)?;
32
33 let mut arms_token_stream = proc_macro2::TokenStream::new();
34
35 let mut all_unit = true;
36
37 if let Data::Enum(data) = &ast.data {
38 for variant in data.variants.iter() {
39 let _ = TypeAttributeBuilder {
40 enable_flag: false, enable_bound: false
41 }
42 .build_from_attributes(&variant.attrs, traits)?;
43
44 let variant_ident = &variant.ident;
45
46 let built_in_cmp: Path = syn::parse2(quote!(::core::cmp::Ord::cmp)).unwrap();
47
48 match &variant.fields {
49 Fields::Unit => {
50 arms_token_stream.extend(quote! {
51 Self::#variant_ident => {
52 return ::core::cmp::Ordering::Equal;
53 }
54 });
55 },
56 Fields::Named(_) => {
57 all_unit = false;
58
59 let mut pattern_self_token_stream = proc_macro2::TokenStream::new();
60 let mut pattern_other_token_stream = proc_macro2::TokenStream::new();
61 let mut block_token_stream = proc_macro2::TokenStream::new();
62
63 let mut fields: BTreeMap<isize, (&Field, Ident, Ident, FieldAttribute)> =
64 BTreeMap::new();
65
66 for (index, field) in variant.fields.iter().enumerate() {
67 let field_attribute = FieldAttributeBuilder {
68 enable_ignore: true,
69 enable_method: true,
70 enable_rank: true,
71 rank: isize::MIN + index as isize,
72 }
73 .build_from_attributes(&field.attrs, traits)?;
74
75 let field_name_real = field.ident.as_ref().unwrap();
76 let field_name_var_self = format_ident!("_s_{}", field_name_real);
77 let field_name_var_other = format_ident!("_o_{}", field_name_real);
78
79 if field_attribute.ignore {
80 pattern_self_token_stream.extend(quote!(#field_name_real: _,));
81 pattern_other_token_stream.extend(quote!(#field_name_real: _,));
82
83 continue;
84 }
85
86 pattern_self_token_stream
87 .extend(quote!(#field_name_real: #field_name_var_self,));
88 pattern_other_token_stream
89 .extend(quote!(#field_name_real: #field_name_var_other,));
90
91 let rank = field_attribute.rank;
92
93 if fields.contains_key(&rank) {
94 return Err(super::panic::reuse_a_rank(
95 field_attribute.rank_span.unwrap_or_else(|| field.span()),
96 rank,
97 ));
98 }
99
100 fields.insert(
101 rank,
102 (field, field_name_var_self, field_name_var_other, field_attribute),
103 );
104 }
105
106 for (field, field_name_var_self, field_name_var_other, field_attribute) in
107 fields.values()
108 {
109 let cmp = field_attribute.method.as_ref().unwrap_or_else(|| {
110 ord_types.push(&field.ty);
111
112 &built_in_cmp
113 });
114
115 block_token_stream.extend(quote! {
116 match #cmp(#field_name_var_self, #field_name_var_other) {
117 ::core::cmp::Ordering::Equal => (),
118 ::core::cmp::Ordering::Greater => return ::core::cmp::Ordering::Greater,
119 ::core::cmp::Ordering::Less => return ::core::cmp::Ordering::Less,
120 }
121 });
122 }
123
124 arms_token_stream.extend(quote! {
125 Self::#variant_ident { #pattern_self_token_stream } => {
126 if let Self::#variant_ident { #pattern_other_token_stream } = other {
127 #block_token_stream
128 }
129 }
130 });
131 },
132 Fields::Unnamed(_) => {
133 all_unit = false;
134
135 let mut pattern_token_stream = proc_macro2::TokenStream::new();
136 let mut pattern2_token_stream = proc_macro2::TokenStream::new();
137 let mut block_token_stream = proc_macro2::TokenStream::new();
138
139 let mut fields: BTreeMap<isize, (&Field, Ident, Ident, FieldAttribute)> =
140 BTreeMap::new();
141
142 for (index, field) in variant.fields.iter().enumerate() {
143 let field_attribute = FieldAttributeBuilder {
144 enable_ignore: true,
145 enable_method: true,
146 enable_rank: true,
147 rank: isize::MIN + index as isize,
148 }
149 .build_from_attributes(&field.attrs, traits)?;
150
151 let field_name_var_self = format_ident!("_{}", index);
152
153 if field_attribute.ignore {
154 pattern_token_stream.extend(quote!(_,));
155 pattern2_token_stream.extend(quote!(_,));
156
157 continue;
158 }
159
160 let field_name_var_other = format_ident!("_{}", field_name_var_self);
161
162 pattern_token_stream.extend(quote!(#field_name_var_self,));
163 pattern2_token_stream.extend(quote!(#field_name_var_other,));
164
165 let rank = field_attribute.rank;
166
167 if fields.contains_key(&rank) {
168 return Err(super::panic::reuse_a_rank(
169 field_attribute.rank_span.unwrap_or_else(|| field.span()),
170 rank,
171 ));
172 }
173
174 fields.insert(
175 rank,
176 (field, field_name_var_self, field_name_var_other, field_attribute),
177 );
178 }
179
180 for (field, field_name, field_name2, field_attribute) in fields.values() {
181 let cmp = field_attribute.method.as_ref().unwrap_or_else(|| {
182 ord_types.push(&field.ty);
183
184 &built_in_cmp
185 });
186
187 block_token_stream.extend(quote! {
188 match #cmp(#field_name, #field_name2) {
189 ::core::cmp::Ordering::Equal => (),
190 ::core::cmp::Ordering::Greater => return ::core::cmp::Ordering::Greater,
191 ::core::cmp::Ordering::Less => return ::core::cmp::Ordering::Less,
192 }
193 });
194 }
195
196 arms_token_stream.extend(quote! {
197 Self::#variant_ident ( #pattern_token_stream ) => {
198 if let Self::#variant_ident ( #pattern2_token_stream ) = other {
199 #block_token_stream
200 }
201 }
202 });
203 },
204 }
205 }
206 }
207
208 if arms_token_stream.is_empty() {
209 cmp_token_stream.extend(quote!(::core::cmp::Ordering::Equal));
210 } else {
211 let discriminant_cmp = quote! {
212 unsafe {
213 ::core::cmp::Ord::cmp(&*<*const _>::from(self).cast::<#discriminant_type>(), &*<*const _>::from(other).cast::<#discriminant_type>())
214 }
215 };
216
217 cmp_token_stream.extend(if all_unit {
218 quote! {
219 match #discriminant_cmp {
220 ::core::cmp::Ordering::Equal => ::core::cmp::Ordering::Equal,
221 ::core::cmp::Ordering::Greater => ::core::cmp::Ordering::Greater,
222 ::core::cmp::Ordering::Less => ::core::cmp::Ordering::Less,
223 }
224 }
225 } else {
226 quote! {
227 match #discriminant_cmp {
228 ::core::cmp::Ordering::Equal => {
229 match self {
230 #arms_token_stream
231 }
232
233 ::core::cmp::Ordering::Equal
234 },
235 ::core::cmp::Ordering::Greater => ::core::cmp::Ordering::Greater,
236 ::core::cmp::Ordering::Less => ::core::cmp::Ordering::Less,
237 }
238 }
239 });
240 }
241
242 let ident = &ast.ident;
243
244 let bound = type_attribute.bound.into_where_predicates_by_generic_parameters_check_types(
245 &ast.generics.params,
246 &syn::parse2(quote!(::core::cmp::Ord)).unwrap(),
247 &ord_types,
248 &crate::trait_handlers::ord::supertraits(traits),
249 );
250
251 let mut generics = ast.generics.clone();
252 let where_clause = generics.make_where_clause();
253
254 for where_predicate in bound {
255 where_clause.predicates.push(where_predicate);
256 }
257
258 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
259
260 token_stream.extend(quote! {
261 impl #impl_generics ::core::cmp::Ord for #ident #ty_generics #where_clause {
262 #[inline]
263 fn cmp(&self, other: &Self) -> ::core::cmp::Ordering {
264 #cmp_token_stream
265 }
266 }
267 });
268
269 #[cfg(feature = "PartialOrd")]
270 if traits.contains(&Trait::PartialOrd) {
271 token_stream.extend(quote! {
272 impl #impl_generics ::core::cmp::PartialOrd for #ident #ty_generics #where_clause {
273 #[inline]
274 fn partial_cmp(&self, other: &Self) -> Option<::core::cmp::Ordering> {
275 Some(::core::cmp::Ord::cmp(self, other))
276 }
277 }
278 });
279 }
280
281 Ok(())
282 }
283}