enum_ordinalize_derive/
lib.rs

1/*!
2# Enum Ordinalize Derive
3
This library enables enums not only to obtain the ordinal values of their variants but also to be constructed from an ordinal value. See the [`enum-ordinalize`](https://crates.io/crates/enum-ordinalize) crate.
5*/
6
7#![no_std]
8
9#[macro_use]
10extern crate alloc;
11
12mod int128;
13mod int_wrapper;
14mod panic;
15mod variant_type;
16
17use alloc::{string::ToString, vec::Vec};
18
19use proc_macro::TokenStream;
20use quote::quote;
21use syn::{
22    parse::{Parse, ParseStream},
23    parse_macro_input,
24    punctuated::Punctuated,
25    spanned::Spanned,
26    Data, DeriveInput, Expr, Fields, Ident, Lit, Meta, Token, UnOp, Visibility,
27};
28use variant_type::VariantType;
29
30use crate::{int128::Int128, int_wrapper::IntWrapper};
31
32#[proc_macro_derive(Ordinalize, attributes(ordinalize))]
33pub fn ordinalize_derive(input: TokenStream) -> TokenStream {
34    struct ConstMember {
35        vis:      Option<Visibility>,
36        ident:    Ident,
37        meta:     Vec<Meta>,
38        function: bool,
39    }
40
41    impl Parse for ConstMember {
42        #[inline]
43        fn parse(input: ParseStream) -> syn::Result<Self> {
44            let vis = input.parse::<Visibility>().ok();
45
46            let _ = input.parse::<Token![const]>();
47
48            let function = input.parse::<Token![fn]>().is_ok();
49
50            let ident = input.parse::<Ident>()?;
51
52            let mut meta = Vec::new();
53
54            if !input.is_empty() {
55                input.parse::<Token![,]>()?;
56
57                if !input.is_empty() {
58                    let result = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
59
60                    let mut has_inline = false;
61
62                    for m in result {
63                        if m.path().is_ident("inline") {
64                            has_inline = true;
65                        }
66
67                        meta.push(m);
68                    }
69
70                    if !has_inline {
71                        meta.push(syn::parse_str("inline")?);
72                    }
73                }
74            }
75
76            Ok(Self {
77                vis,
78                ident,
79                meta,
80                function,
81            })
82        }
83    }
84
85    struct ConstFunctionMember {
86        vis:   Option<Visibility>,
87        ident: Ident,
88        meta:  Vec<Meta>,
89    }
90
91    impl Parse for ConstFunctionMember {
92        #[inline]
93        fn parse(input: ParseStream) -> syn::Result<Self> {
94            let vis = input.parse::<Visibility>().ok();
95
96            let _ = input.parse::<Token![const]>();
97
98            input.parse::<Token![fn]>()?;
99
100            let ident = input.parse::<Ident>()?;
101
102            let mut meta = Vec::new();
103
104            if !input.is_empty() {
105                input.parse::<Token![,]>()?;
106
107                if !input.is_empty() {
108                    let result = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
109
110                    let mut has_inline = false;
111
112                    for m in result {
113                        if m.path().is_ident("inline") {
114                            has_inline = true;
115                        }
116
117                        meta.push(m);
118                    }
119
120                    if !has_inline {
121                        meta.push(syn::parse_str("inline")?);
122                    }
123                }
124            }
125
126            Ok(Self {
127                vis,
128                ident,
129                meta,
130            })
131        }
132    }
133
134    struct MyDeriveInput {
135        ast:                        DeriveInput,
136        variant_type:               VariantType,
137        values:                     Vec<IntWrapper>,
138        variant_idents:             Vec<Ident>,
139        use_constant_counter:       bool,
140        enable_trait:               bool,
141        enable_variant_count:       Option<ConstMember>,
142        enable_variants:            Option<ConstMember>,
143        enable_values:              Option<ConstMember>,
144        enable_from_ordinal_unsafe: Option<ConstFunctionMember>,
145        enable_from_ordinal:        Option<ConstFunctionMember>,
146        enable_ordinal:             Option<ConstFunctionMember>,
147    }
148
149    impl Parse for MyDeriveInput {
150        fn parse(input: ParseStream) -> syn::Result<Self> {
151            let ast = input.parse::<DeriveInput>()?;
152
153            let mut variant_type = VariantType::default();
154            let mut enable_trait = cfg!(feature = "traits");
155            let mut enable_variant_count = None;
156            let mut enable_variants = None;
157            let mut enable_values = None;
158            let mut enable_from_ordinal_unsafe = None;
159            let mut enable_from_ordinal = None;
160            let mut enable_ordinal = None;
161
162            for attr in ast.attrs.iter() {
163                let path = attr.path();
164
165                if let Some(ident) = path.get_ident() {
166                    match ident.to_string().as_str() {
167                        "repr" => {
168                            // #[repr(u8)], #[repr(u16)], ..., etc.
169                            if let Meta::List(list) = &attr.meta {
170                                let result = list.parse_args_with(
171                                    Punctuated::<Ident, Token![,]>::parse_terminated,
172                                )?;
173
174                                if let Some(value) = result.into_iter().next() {
175                                    variant_type = VariantType::from_str(value.to_string());
176                                }
177                            }
178
179                            break;
180                        },
181                        "ordinalize" => {
182                            if let Meta::List(list) = &attr.meta {
183                                let result = list.parse_args_with(
184                                    Punctuated::<Meta, Token![,]>::parse_terminated,
185                                )?;
186
187                                for meta in result {
188                                    let path = meta.path();
189
190                                    if let Some(ident) = path.get_ident() {
191                                        match ident.to_string().as_str() {
192                                            "impl_trait" => {
193                                                if let Meta::NameValue(meta) = &meta {
194                                                    if let Expr::Lit(lit) = &meta.value {
195                                                        if let Lit::Bool(value) = &lit.lit {
196                                                            if cfg!(feature = "traits") {
197                                                                enable_trait = value.value;
198                                                            }
199                                                        } else {
200                                                            return Err(
201                                                                panic::bool_attribute_usage(
202                                                                    ident,
203                                                                    ident.span(),
204                                                                ),
205                                                            );
206                                                        }
207                                                    } else {
208                                                        return Err(panic::bool_attribute_usage(
209                                                            ident,
210                                                            ident.span(),
211                                                        ));
212                                                    }
213                                                } else {
214                                                    return Err(panic::bool_attribute_usage(
215                                                        ident,
216                                                        ident.span(),
217                                                    ));
218                                                }
219                                            },
220                                            "variant_count" => {
221                                                if let Meta::List(list) = &meta {
222                                                    enable_variant_count = Some(list.parse_args()?);
223                                                } else {
224                                                    return Err(panic::list_attribute_usage(
225                                                        ident,
226                                                        ident.span(),
227                                                    ));
228                                                }
229                                            },
230                                            "variants" => {
231                                                if let Meta::List(list) = &meta {
232                                                    enable_variants = Some(list.parse_args()?);
233                                                } else {
234                                                    return Err(panic::list_attribute_usage(
235                                                        ident,
236                                                        ident.span(),
237                                                    ));
238                                                }
239                                            },
240                                            "values" => {
241                                                if let Meta::List(list) = &meta {
242                                                    enable_values = Some(list.parse_args()?);
243                                                } else {
244                                                    return Err(panic::list_attribute_usage(
245                                                        ident,
246                                                        ident.span(),
247                                                    ));
248                                                }
249                                            },
250                                            "from_ordinal_unsafe" => {
251                                                if let Meta::List(list) = &meta {
252                                                    enable_from_ordinal_unsafe =
253                                                        Some(list.parse_args()?);
254                                                } else {
255                                                    return Err(panic::list_attribute_usage(
256                                                        ident,
257                                                        ident.span(),
258                                                    ));
259                                                }
260                                            },
261                                            "from_ordinal" => {
262                                                if let Meta::List(list) = &meta {
263                                                    enable_from_ordinal = Some(list.parse_args()?);
264                                                } else {
265                                                    return Err(panic::list_attribute_usage(
266                                                        ident,
267                                                        ident.span(),
268                                                    ));
269                                                }
270                                            },
271                                            "ordinal" => {
272                                                if let Meta::List(list) = &meta {
273                                                    enable_ordinal = Some(list.parse_args()?);
274                                                } else {
275                                                    return Err(panic::list_attribute_usage(
276                                                        ident,
277                                                        ident.span(),
278                                                    ));
279                                                }
280                                            },
281                                            _ => {
282                                                return Err(panic::sub_attributes_for_ordinalize(
283                                                    ident.span(),
284                                                ));
285                                            },
286                                        }
287                                    } else {
288                                        return Err(panic::list_attribute_usage(
289                                            ident,
290                                            ident.span(),
291                                        ));
292                                    }
293                                }
294                            } else {
295                                return Err(panic::list_attribute_usage(ident, ident.span()));
296                            }
297                        },
298                        _ => (),
299                    }
300                }
301            }
302
303            let name = &ast.ident;
304
305            if let Data::Enum(data) = &ast.data {
306                let variant_count = data.variants.len();
307
308                if variant_count == 0 {
309                    return Err(panic::no_variant(name.span()));
310                }
311
312                let mut values: Vec<IntWrapper> = Vec::with_capacity(variant_count);
313                let mut variant_idents: Vec<Ident> = Vec::with_capacity(variant_count);
314
315                let mut use_constant_counter = false;
316
317                if let VariantType::NonDetermined = variant_type {
318                    let mut min = i128::MAX;
319                    let mut max = i128::MIN;
320                    let mut counter = 0;
321
322                    for variant in data.variants.iter() {
323                        if let Fields::Unit = variant.fields {
324                            if let Some((_, exp)) = variant.discriminant.as_ref() {
325                                match exp {
326                                    Expr::Lit(lit) => {
327                                        if let Lit::Int(lit) = &lit.lit {
328                                            counter = lit.base10_parse().map_err(|error| {
329                                                syn::Error::new(lit.span(), error)
330                                            })?;
331                                        } else {
332                                            return Err(panic::unsupported_discriminant(
333                                                lit.span(),
334                                            ));
335                                        }
336                                    },
337                                    Expr::Unary(unary) => {
338                                        if let UnOp::Neg(_) = unary.op {
339                                            match unary.expr.as_ref() {
340                                            Expr::Lit(lit) => {
341                                                if let Lit::Int(lit) = &lit.lit {
342                                                    match lit.base10_parse::<i128>() {
343                                                        Ok(i) => {
344                                                            counter = -i;
345                                                        },
346                                                        Err(error) => {
347                                                            // overflow
348                                                            if lit.base10_digits() == "170141183460469231731687303715884105728" {
349                                                                counter = i128::MIN;
350                                                            } else {
351                                                                return Err(syn::Error::new(lit.span(), error));
352                                                            }
353                                                        },
354                                                    }
355                                                } else {
356                                                    return Err(panic::unsupported_discriminant(lit.span()));
357                                                }
358                                            },
359                                            Expr::Path(_)
360                                            | Expr::Cast(_)
361                                            | Expr::Binary(_)
362                                            | Expr::Call(_) => {
363                                                return Err(panic::constant_variable_on_non_determined_size_enum(unary.expr.span()))
364                                            },
365                                            _ => return Err(panic::unsupported_discriminant(unary.expr.span())),
366                                        }
367                                        } else {
368                                            return Err(panic::unsupported_discriminant(
369                                                unary.op.span(),
370                                            ));
371                                        }
372                                    },
373                                    Expr::Path(_)
374                                    | Expr::Cast(_)
375                                    | Expr::Binary(_)
376                                    | Expr::Call(_) => {
377                                        return Err(
378                                            panic::constant_variable_on_non_determined_size_enum(
379                                                exp.span(),
380                                            ),
381                                        )
382                                    },
383                                    _ => return Err(panic::unsupported_discriminant(exp.span())),
384                                }
385                            };
386
387                            if min > counter {
388                                min = counter;
389                            }
390
391                            if max < counter {
392                                max = counter;
393                            }
394
395                            variant_idents.push(variant.ident.clone());
396
397                            values.push(IntWrapper::from(counter));
398
399                            counter = counter.saturating_add(1);
400                        } else {
401                            return Err(panic::not_unit_variant(variant.span()));
402                        }
403                    }
404
405                    if min >= i8::MIN as i128 && max <= i8::MAX as i128 {
406                        variant_type = VariantType::I8;
407                    } else if min >= i16::MIN as i128 && max <= i16::MAX as i128 {
408                        variant_type = VariantType::I16;
409                    } else if min >= i32::MIN as i128 && max <= i32::MAX as i128 {
410                        variant_type = VariantType::I32;
411                    } else if min >= i64::MIN as i128 && max <= i64::MAX as i128 {
412                        variant_type = VariantType::I64;
413                    } else {
414                        variant_type = VariantType::I128;
415                    }
416                } else {
417                    let mut counter = Int128::ZERO;
418                    let mut constant_counter = 0;
419                    let mut last_exp: Option<&Expr> = None;
420
421                    for variant in data.variants.iter() {
422                        if let Fields::Unit = variant.fields {
423                            if let Some((_, exp)) = variant.discriminant.as_ref() {
424                                match exp {
425                                    Expr::Lit(lit) => {
426                                        if let Lit::Int(lit) = &lit.lit {
427                                            counter = lit.base10_parse().map_err(|error| {
428                                                syn::Error::new(lit.span(), error)
429                                            })?;
430
431                                            values.push(IntWrapper::from(counter));
432
433                                            counter.inc();
434
435                                            last_exp = None;
436                                        } else {
437                                            return Err(panic::unsupported_discriminant(
438                                                lit.span(),
439                                            ));
440                                        }
441                                    },
442                                    Expr::Unary(unary) => {
443                                        if let UnOp::Neg(_) = unary.op {
444                                            match unary.expr.as_ref() {
445                                                Expr::Lit(lit) => {
446                                                    if let Lit::Int(lit) = &lit.lit {
447                                                        counter = -lit.base10_parse().map_err(
448                                                            |error| {
449                                                                syn::Error::new(lit.span(), error)
450                                                            },
451                                                        )?;
452
453                                                        values.push(IntWrapper::from(counter));
454
455                                                        counter.inc();
456
457                                                        last_exp = None;
458                                                    } else {
459                                                        return Err(
460                                                            panic::unsupported_discriminant(
461                                                                lit.span(),
462                                                            ),
463                                                        );
464                                                    }
465                                                },
466                                                Expr::Path(_) => {
467                                                    values.push(IntWrapper::from((exp, 0)));
468
469                                                    last_exp = Some(exp);
470                                                    constant_counter = 1;
471                                                },
472                                                Expr::Cast(_) | Expr::Binary(_) | Expr::Call(_) => {
473                                                    values.push(IntWrapper::from((exp, 0)));
474
475                                                    last_exp = Some(exp);
476                                                    constant_counter = 1;
477
478                                                    use_constant_counter = true;
479                                                },
480                                                _ => {
481                                                    return Err(panic::unsupported_discriminant(
482                                                        exp.span(),
483                                                    ));
484                                                },
485                                            }
486                                        } else {
487                                            return Err(panic::unsupported_discriminant(
488                                                unary.op.span(),
489                                            ));
490                                        }
491                                    },
492                                    Expr::Path(_) => {
493                                        values.push(IntWrapper::from((exp, 0)));
494
495                                        last_exp = Some(exp);
496                                        constant_counter = 1;
497                                    },
498                                    Expr::Cast(_) | Expr::Binary(_) | Expr::Call(_) => {
499                                        values.push(IntWrapper::from((exp, 0)));
500
501                                        last_exp = Some(exp);
502                                        constant_counter = 1;
503
504                                        use_constant_counter = true;
505                                    },
506                                    _ => return Err(panic::unsupported_discriminant(exp.span())),
507                                }
508                            } else if let Some(exp) = last_exp {
509                                values.push(IntWrapper::from((exp, constant_counter)));
510
511                                constant_counter += 1;
512
513                                use_constant_counter = true;
514                            } else {
515                                values.push(IntWrapper::from(counter));
516
517                                counter.inc();
518                            }
519
520                            variant_idents.push(variant.ident.clone());
521                        } else {
522                            return Err(panic::not_unit_variant(variant.span()));
523                        }
524                    }
525                }
526
527                Ok(MyDeriveInput {
528                    ast,
529                    variant_type,
530                    values,
531                    variant_idents,
532                    use_constant_counter,
533                    enable_trait,
534                    enable_variant_count,
535                    enable_variants,
536                    enable_values,
537                    enable_from_ordinal_unsafe,
538                    enable_from_ordinal,
539                    enable_ordinal,
540                })
541            } else {
542                Err(panic::not_enum(ast.ident.span()))
543            }
544        }
545    }
546
547    // Parse the token stream
548    let derive_input = parse_macro_input!(input as MyDeriveInput);
549
550    let MyDeriveInput {
551        ast,
552        variant_type,
553        values,
554        variant_idents,
555        use_constant_counter,
556        enable_trait,
557        enable_variant_count,
558        enable_variants,
559        enable_values,
560        enable_ordinal,
561        enable_from_ordinal_unsafe,
562        enable_from_ordinal,
563    } = derive_input;
564
565    // Get the identifier of the type.
566    let name = &ast.ident;
567
568    let variant_count = values.len();
569
570    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
571
572    // Build the code
573    let mut expanded = proc_macro2::TokenStream::new();
574
575    if enable_trait {
576        #[cfg(feature = "traits")]
577        {
578            let from_ordinal_unsafe = if variant_count == 1 {
579                let variant_ident = &variant_idents[0];
580
581                quote! {
582                    #[inline]
583                    unsafe fn from_ordinal_unsafe(_number: #variant_type) -> Self {
584                        Self::#variant_ident
585                    }
586                }
587            } else {
588                quote! {
589                    #[inline]
590                    unsafe fn from_ordinal_unsafe(number: #variant_type) -> Self {
591                        ::core::mem::transmute(number)
592                    }
593                }
594            };
595
596            let from_ordinal = if use_constant_counter {
597                quote! {
598                    #[inline]
599                    fn from_ordinal(number: #variant_type) -> Option<Self> {
600                        if false {
601                            unreachable!()
602                        } #( else if number == #values {
603                            Some(Self::#variant_idents)
604                        } )* else {
605                            None
606                        }
607                    }
608                }
609            } else {
610                quote! {
611                    #[inline]
612                    fn from_ordinal(number: #variant_type) -> Option<Self> {
613                        match number{
614                            #(
615                                #values => Some(Self::#variant_idents),
616                            )*
617                            _ => None
618                        }
619                    }
620                }
621            };
622
623            expanded.extend(quote! {
624                impl #impl_generics Ordinalize for #name #ty_generics #where_clause {
625                    type VariantType = #variant_type;
626
627                    const VARIANT_COUNT: usize = #variant_count;
628
629                    const VARIANTS: &'static [Self] = &[#( Self::#variant_idents, )*];
630
631                    const VALUES: &'static [#variant_type] = &[#( #values, )*];
632
633                    #[inline]
634                    fn ordinal(&self) -> #variant_type {
635                        match self {
636                            #(
637                                Self::#variant_idents => #values,
638                            )*
639                        }
640                    }
641
642                    #from_ordinal_unsafe
643
644                    #from_ordinal
645                }
646            });
647        }
648    }
649
650    let mut expanded_2 = proc_macro2::TokenStream::new();
651
652    if let Some(ConstMember {
653        vis,
654        ident,
655        meta,
656        function,
657    }) = enable_variant_count
658    {
659        expanded_2.extend(if function {
660            quote! {
661                #(#[#meta])*
662                #vis const fn #ident () -> usize {
663                    #variant_count
664                }
665            }
666        } else {
667            quote! {
668                #(#[#meta])*
669                #vis const #ident: usize = #variant_count;
670            }
671        });
672    }
673
674    if let Some(ConstMember {
675        vis,
676        ident,
677        meta,
678        function,
679    }) = enable_variants
680    {
681        expanded_2.extend(if function {
682            quote! {
683                #(#[#meta])*
684                #vis const fn #ident () -> [Self; #variant_count] {
685                    [#( Self::#variant_idents, )*]
686                }
687            }
688        } else {
689            quote! {
690                #(#[#meta])*
691                #vis const #ident: [Self; #variant_count] = [#( Self::#variant_idents, )*];
692            }
693        });
694    }
695
696    if let Some(ConstMember {
697        vis,
698        ident,
699        meta,
700        function,
701    }) = enable_values
702    {
703        expanded_2.extend(if function {
704            quote! {
705                #(#[#meta])*
706                #vis const fn #ident () -> [#variant_type; #variant_count] {
707                    [#( #values, )*]
708                }
709            }
710        } else {
711            quote! {
712                #(#[#meta])*
713                #vis const #ident: [#variant_type; #variant_count] = [#( #values, )*];
714            }
715        });
716    }
717
718    if let Some(ConstFunctionMember {
719        vis,
720        ident,
721        meta,
722    }) = enable_from_ordinal_unsafe
723    {
724        let from_ordinal_unsafe = if variant_count == 1 {
725            let variant_ident = &variant_idents[0];
726
727            quote! {
728                #(#[#meta])*
729                #vis const unsafe fn #ident (_number: #variant_type) -> Self {
730                    Self::#variant_ident
731                }
732            }
733        } else {
734            quote! {
735                #(#[#meta])*
736                #vis const unsafe fn #ident (number: #variant_type) -> Self {
737                    ::core::mem::transmute(number)
738                }
739            }
740        };
741
742        expanded_2.extend(from_ordinal_unsafe);
743    }
744
745    if let Some(ConstFunctionMember {
746        vis,
747        ident,
748        meta,
749    }) = enable_from_ordinal
750    {
751        let from_ordinal = if use_constant_counter {
752            quote! {
753                #(#[#meta])*
754                #vis const fn #ident (number: #variant_type) -> Option<Self> {
755                    if false {
756                        unreachable!()
757                    } #( else if number == #values {
758                        Some(Self::#variant_idents)
759                    } )* else {
760                        None
761                    }
762                }
763            }
764        } else {
765            quote! {
766                #(#[#meta])*
767                #vis const fn #ident (number: #variant_type) -> Option<Self> {
768                    match number{
769                        #(
770                            #values => Some(Self::#variant_idents),
771                        )*
772                        _ => None
773                    }
774                }
775            }
776        };
777
778        expanded_2.extend(from_ordinal);
779    }
780
781    if let Some(ConstFunctionMember {
782        vis,
783        ident,
784        meta,
785    }) = enable_ordinal
786    {
787        expanded_2.extend(quote! {
788            #(#[#meta])*
789            #vis const fn #ident (&self) -> #variant_type {
790                match self {
791                    #(
792                        Self::#variant_idents => #values,
793                    )*
794                }
795            }
796        });
797    }
798
799    if !expanded_2.is_empty() {
800        expanded.extend(quote! {
801            impl #impl_generics #name #ty_generics #where_clause {
802                #expanded_2
803            }
804        });
805    }
806
807    expanded.into()
808}