enum_ordinalize_derive/lib.rs
1/*!
2# Enum Ordinalize Derive
3
This library lets enums both expose the ordinal values of their variants and be constructed from an ordinal value. See the [`enum-ordinalize`](https://crates.io/crates/enum-ordinalize) crate.
5*/
6
7#![no_std]
8
9#[macro_use]
10extern crate alloc;
11
12mod int128;
13mod int_wrapper;
14mod panic;
15mod variant_type;
16
17use alloc::{string::ToString, vec::Vec};
18
19use proc_macro::TokenStream;
20use quote::quote;
21use syn::{
22 parse::{Parse, ParseStream},
23 parse_macro_input,
24 punctuated::Punctuated,
25 spanned::Spanned,
26 Data, DeriveInput, Expr, Fields, Ident, Lit, Meta, Token, UnOp, Visibility,
27};
28use variant_type::VariantType;
29
30use crate::{int128::Int128, int_wrapper::IntWrapper};
31
32#[proc_macro_derive(Ordinalize, attributes(ordinalize))]
33pub fn ordinalize_derive(input: TokenStream) -> TokenStream {
34 struct ConstMember {
35 vis: Option<Visibility>,
36 ident: Ident,
37 meta: Vec<Meta>,
38 function: bool,
39 }
40
41 impl Parse for ConstMember {
42 #[inline]
43 fn parse(input: ParseStream) -> syn::Result<Self> {
44 let vis = input.parse::<Visibility>().ok();
45
46 let _ = input.parse::<Token![const]>();
47
48 let function = input.parse::<Token![fn]>().is_ok();
49
50 let ident = input.parse::<Ident>()?;
51
52 let mut meta = Vec::new();
53
54 if !input.is_empty() {
55 input.parse::<Token![,]>()?;
56
57 if !input.is_empty() {
58 let result = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
59
60 let mut has_inline = false;
61
62 for m in result {
63 if m.path().is_ident("inline") {
64 has_inline = true;
65 }
66
67 meta.push(m);
68 }
69
70 if !has_inline {
71 meta.push(syn::parse_str("inline")?);
72 }
73 }
74 }
75
76 Ok(Self {
77 vis,
78 ident,
79 meta,
80 function,
81 })
82 }
83 }
84
85 struct ConstFunctionMember {
86 vis: Option<Visibility>,
87 ident: Ident,
88 meta: Vec<Meta>,
89 }
90
91 impl Parse for ConstFunctionMember {
92 #[inline]
93 fn parse(input: ParseStream) -> syn::Result<Self> {
94 let vis = input.parse::<Visibility>().ok();
95
96 let _ = input.parse::<Token![const]>();
97
98 input.parse::<Token![fn]>()?;
99
100 let ident = input.parse::<Ident>()?;
101
102 let mut meta = Vec::new();
103
104 if !input.is_empty() {
105 input.parse::<Token![,]>()?;
106
107 if !input.is_empty() {
108 let result = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
109
110 let mut has_inline = false;
111
112 for m in result {
113 if m.path().is_ident("inline") {
114 has_inline = true;
115 }
116
117 meta.push(m);
118 }
119
120 if !has_inline {
121 meta.push(syn::parse_str("inline")?);
122 }
123 }
124 }
125
126 Ok(Self {
127 vis,
128 ident,
129 meta,
130 })
131 }
132 }
133
134 struct MyDeriveInput {
135 ast: DeriveInput,
136 variant_type: VariantType,
137 values: Vec<IntWrapper>,
138 variant_idents: Vec<Ident>,
139 use_constant_counter: bool,
140 enable_trait: bool,
141 enable_variant_count: Option<ConstMember>,
142 enable_variants: Option<ConstMember>,
143 enable_values: Option<ConstMember>,
144 enable_from_ordinal_unsafe: Option<ConstFunctionMember>,
145 enable_from_ordinal: Option<ConstFunctionMember>,
146 enable_ordinal: Option<ConstFunctionMember>,
147 }
148
149 impl Parse for MyDeriveInput {
150 fn parse(input: ParseStream) -> syn::Result<Self> {
151 let ast = input.parse::<DeriveInput>()?;
152
153 let mut variant_type = VariantType::default();
154 let mut enable_trait = cfg!(feature = "traits");
155 let mut enable_variant_count = None;
156 let mut enable_variants = None;
157 let mut enable_values = None;
158 let mut enable_from_ordinal_unsafe = None;
159 let mut enable_from_ordinal = None;
160 let mut enable_ordinal = None;
161
162 for attr in ast.attrs.iter() {
163 let path = attr.path();
164
165 if let Some(ident) = path.get_ident() {
166 match ident.to_string().as_str() {
167 "repr" => {
168 // #[repr(u8)], #[repr(u16)], ..., etc.
169 if let Meta::List(list) = &attr.meta {
170 let result = list.parse_args_with(
171 Punctuated::<Ident, Token![,]>::parse_terminated,
172 )?;
173
174 if let Some(value) = result.into_iter().next() {
175 variant_type = VariantType::from_str(value.to_string());
176 }
177 }
178
179 break;
180 },
181 "ordinalize" => {
182 if let Meta::List(list) = &attr.meta {
183 let result = list.parse_args_with(
184 Punctuated::<Meta, Token![,]>::parse_terminated,
185 )?;
186
187 for meta in result {
188 let path = meta.path();
189
190 if let Some(ident) = path.get_ident() {
191 match ident.to_string().as_str() {
192 "impl_trait" => {
193 if let Meta::NameValue(meta) = &meta {
194 if let Expr::Lit(lit) = &meta.value {
195 if let Lit::Bool(value) = &lit.lit {
196 if cfg!(feature = "traits") {
197 enable_trait = value.value;
198 }
199 } else {
200 return Err(
201 panic::bool_attribute_usage(
202 ident,
203 ident.span(),
204 ),
205 );
206 }
207 } else {
208 return Err(panic::bool_attribute_usage(
209 ident,
210 ident.span(),
211 ));
212 }
213 } else {
214 return Err(panic::bool_attribute_usage(
215 ident,
216 ident.span(),
217 ));
218 }
219 },
220 "variant_count" => {
221 if let Meta::List(list) = &meta {
222 enable_variant_count = Some(list.parse_args()?);
223 } else {
224 return Err(panic::list_attribute_usage(
225 ident,
226 ident.span(),
227 ));
228 }
229 },
230 "variants" => {
231 if let Meta::List(list) = &meta {
232 enable_variants = Some(list.parse_args()?);
233 } else {
234 return Err(panic::list_attribute_usage(
235 ident,
236 ident.span(),
237 ));
238 }
239 },
240 "values" => {
241 if let Meta::List(list) = &meta {
242 enable_values = Some(list.parse_args()?);
243 } else {
244 return Err(panic::list_attribute_usage(
245 ident,
246 ident.span(),
247 ));
248 }
249 },
250 "from_ordinal_unsafe" => {
251 if let Meta::List(list) = &meta {
252 enable_from_ordinal_unsafe =
253 Some(list.parse_args()?);
254 } else {
255 return Err(panic::list_attribute_usage(
256 ident,
257 ident.span(),
258 ));
259 }
260 },
261 "from_ordinal" => {
262 if let Meta::List(list) = &meta {
263 enable_from_ordinal = Some(list.parse_args()?);
264 } else {
265 return Err(panic::list_attribute_usage(
266 ident,
267 ident.span(),
268 ));
269 }
270 },
271 "ordinal" => {
272 if let Meta::List(list) = &meta {
273 enable_ordinal = Some(list.parse_args()?);
274 } else {
275 return Err(panic::list_attribute_usage(
276 ident,
277 ident.span(),
278 ));
279 }
280 },
281 _ => {
282 return Err(panic::sub_attributes_for_ordinalize(
283 ident.span(),
284 ));
285 },
286 }
287 } else {
288 return Err(panic::list_attribute_usage(
289 ident,
290 ident.span(),
291 ));
292 }
293 }
294 } else {
295 return Err(panic::list_attribute_usage(ident, ident.span()));
296 }
297 },
298 _ => (),
299 }
300 }
301 }
302
303 let name = &ast.ident;
304
305 if let Data::Enum(data) = &ast.data {
306 let variant_count = data.variants.len();
307
308 if variant_count == 0 {
309 return Err(panic::no_variant(name.span()));
310 }
311
312 let mut values: Vec<IntWrapper> = Vec::with_capacity(variant_count);
313 let mut variant_idents: Vec<Ident> = Vec::with_capacity(variant_count);
314
315 let mut use_constant_counter = false;
316
317 if let VariantType::NonDetermined = variant_type {
318 let mut min = i128::MAX;
319 let mut max = i128::MIN;
320 let mut counter = 0;
321
322 for variant in data.variants.iter() {
323 if let Fields::Unit = variant.fields {
324 if let Some((_, exp)) = variant.discriminant.as_ref() {
325 match exp {
326 Expr::Lit(lit) => {
327 if let Lit::Int(lit) = &lit.lit {
328 counter = lit.base10_parse().map_err(|error| {
329 syn::Error::new(lit.span(), error)
330 })?;
331 } else {
332 return Err(panic::unsupported_discriminant(
333 lit.span(),
334 ));
335 }
336 },
337 Expr::Unary(unary) => {
338 if let UnOp::Neg(_) = unary.op {
339 match unary.expr.as_ref() {
340 Expr::Lit(lit) => {
341 if let Lit::Int(lit) = &lit.lit {
342 match lit.base10_parse::<i128>() {
343 Ok(i) => {
344 counter = -i;
345 },
346 Err(error) => {
347 // overflow
348 if lit.base10_digits() == "170141183460469231731687303715884105728" {
349 counter = i128::MIN;
350 } else {
351 return Err(syn::Error::new(lit.span(), error));
352 }
353 },
354 }
355 } else {
356 return Err(panic::unsupported_discriminant(lit.span()));
357 }
358 },
359 Expr::Path(_)
360 | Expr::Cast(_)
361 | Expr::Binary(_)
362 | Expr::Call(_) => {
363 return Err(panic::constant_variable_on_non_determined_size_enum(unary.expr.span()))
364 },
365 _ => return Err(panic::unsupported_discriminant(unary.expr.span())),
366 }
367 } else {
368 return Err(panic::unsupported_discriminant(
369 unary.op.span(),
370 ));
371 }
372 },
373 Expr::Path(_)
374 | Expr::Cast(_)
375 | Expr::Binary(_)
376 | Expr::Call(_) => {
377 return Err(
378 panic::constant_variable_on_non_determined_size_enum(
379 exp.span(),
380 ),
381 )
382 },
383 _ => return Err(panic::unsupported_discriminant(exp.span())),
384 }
385 };
386
387 if min > counter {
388 min = counter;
389 }
390
391 if max < counter {
392 max = counter;
393 }
394
395 variant_idents.push(variant.ident.clone());
396
397 values.push(IntWrapper::from(counter));
398
399 counter = counter.saturating_add(1);
400 } else {
401 return Err(panic::not_unit_variant(variant.span()));
402 }
403 }
404
405 if min >= i8::MIN as i128 && max <= i8::MAX as i128 {
406 variant_type = VariantType::I8;
407 } else if min >= i16::MIN as i128 && max <= i16::MAX as i128 {
408 variant_type = VariantType::I16;
409 } else if min >= i32::MIN as i128 && max <= i32::MAX as i128 {
410 variant_type = VariantType::I32;
411 } else if min >= i64::MIN as i128 && max <= i64::MAX as i128 {
412 variant_type = VariantType::I64;
413 } else {
414 variant_type = VariantType::I128;
415 }
416 } else {
417 let mut counter = Int128::ZERO;
418 let mut constant_counter = 0;
419 let mut last_exp: Option<&Expr> = None;
420
421 for variant in data.variants.iter() {
422 if let Fields::Unit = variant.fields {
423 if let Some((_, exp)) = variant.discriminant.as_ref() {
424 match exp {
425 Expr::Lit(lit) => {
426 if let Lit::Int(lit) = &lit.lit {
427 counter = lit.base10_parse().map_err(|error| {
428 syn::Error::new(lit.span(), error)
429 })?;
430
431 values.push(IntWrapper::from(counter));
432
433 counter.inc();
434
435 last_exp = None;
436 } else {
437 return Err(panic::unsupported_discriminant(
438 lit.span(),
439 ));
440 }
441 },
442 Expr::Unary(unary) => {
443 if let UnOp::Neg(_) = unary.op {
444 match unary.expr.as_ref() {
445 Expr::Lit(lit) => {
446 if let Lit::Int(lit) = &lit.lit {
447 counter = -lit.base10_parse().map_err(
448 |error| {
449 syn::Error::new(lit.span(), error)
450 },
451 )?;
452
453 values.push(IntWrapper::from(counter));
454
455 counter.inc();
456
457 last_exp = None;
458 } else {
459 return Err(
460 panic::unsupported_discriminant(
461 lit.span(),
462 ),
463 );
464 }
465 },
466 Expr::Path(_) => {
467 values.push(IntWrapper::from((exp, 0)));
468
469 last_exp = Some(exp);
470 constant_counter = 1;
471 },
472 Expr::Cast(_) | Expr::Binary(_) | Expr::Call(_) => {
473 values.push(IntWrapper::from((exp, 0)));
474
475 last_exp = Some(exp);
476 constant_counter = 1;
477
478 use_constant_counter = true;
479 },
480 _ => {
481 return Err(panic::unsupported_discriminant(
482 exp.span(),
483 ));
484 },
485 }
486 } else {
487 return Err(panic::unsupported_discriminant(
488 unary.op.span(),
489 ));
490 }
491 },
492 Expr::Path(_) => {
493 values.push(IntWrapper::from((exp, 0)));
494
495 last_exp = Some(exp);
496 constant_counter = 1;
497 },
498 Expr::Cast(_) | Expr::Binary(_) | Expr::Call(_) => {
499 values.push(IntWrapper::from((exp, 0)));
500
501 last_exp = Some(exp);
502 constant_counter = 1;
503
504 use_constant_counter = true;
505 },
506 _ => return Err(panic::unsupported_discriminant(exp.span())),
507 }
508 } else if let Some(exp) = last_exp {
509 values.push(IntWrapper::from((exp, constant_counter)));
510
511 constant_counter += 1;
512
513 use_constant_counter = true;
514 } else {
515 values.push(IntWrapper::from(counter));
516
517 counter.inc();
518 }
519
520 variant_idents.push(variant.ident.clone());
521 } else {
522 return Err(panic::not_unit_variant(variant.span()));
523 }
524 }
525 }
526
527 Ok(MyDeriveInput {
528 ast,
529 variant_type,
530 values,
531 variant_idents,
532 use_constant_counter,
533 enable_trait,
534 enable_variant_count,
535 enable_variants,
536 enable_values,
537 enable_from_ordinal_unsafe,
538 enable_from_ordinal,
539 enable_ordinal,
540 })
541 } else {
542 Err(panic::not_enum(ast.ident.span()))
543 }
544 }
545 }
546
547 // Parse the token stream
548 let derive_input = parse_macro_input!(input as MyDeriveInput);
549
550 let MyDeriveInput {
551 ast,
552 variant_type,
553 values,
554 variant_idents,
555 use_constant_counter,
556 enable_trait,
557 enable_variant_count,
558 enable_variants,
559 enable_values,
560 enable_ordinal,
561 enable_from_ordinal_unsafe,
562 enable_from_ordinal,
563 } = derive_input;
564
565 // Get the identifier of the type.
566 let name = &ast.ident;
567
568 let variant_count = values.len();
569
570 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
571
572 // Build the code
573 let mut expanded = proc_macro2::TokenStream::new();
574
575 if enable_trait {
576 #[cfg(feature = "traits")]
577 {
578 let from_ordinal_unsafe = if variant_count == 1 {
579 let variant_ident = &variant_idents[0];
580
581 quote! {
582 #[inline]
583 unsafe fn from_ordinal_unsafe(_number: #variant_type) -> Self {
584 Self::#variant_ident
585 }
586 }
587 } else {
588 quote! {
589 #[inline]
590 unsafe fn from_ordinal_unsafe(number: #variant_type) -> Self {
591 ::core::mem::transmute(number)
592 }
593 }
594 };
595
596 let from_ordinal = if use_constant_counter {
597 quote! {
598 #[inline]
599 fn from_ordinal(number: #variant_type) -> Option<Self> {
600 if false {
601 unreachable!()
602 } #( else if number == #values {
603 Some(Self::#variant_idents)
604 } )* else {
605 None
606 }
607 }
608 }
609 } else {
610 quote! {
611 #[inline]
612 fn from_ordinal(number: #variant_type) -> Option<Self> {
613 match number{
614 #(
615 #values => Some(Self::#variant_idents),
616 )*
617 _ => None
618 }
619 }
620 }
621 };
622
623 expanded.extend(quote! {
624 impl #impl_generics Ordinalize for #name #ty_generics #where_clause {
625 type VariantType = #variant_type;
626
627 const VARIANT_COUNT: usize = #variant_count;
628
629 const VARIANTS: &'static [Self] = &[#( Self::#variant_idents, )*];
630
631 const VALUES: &'static [#variant_type] = &[#( #values, )*];
632
633 #[inline]
634 fn ordinal(&self) -> #variant_type {
635 match self {
636 #(
637 Self::#variant_idents => #values,
638 )*
639 }
640 }
641
642 #from_ordinal_unsafe
643
644 #from_ordinal
645 }
646 });
647 }
648 }
649
650 let mut expanded_2 = proc_macro2::TokenStream::new();
651
652 if let Some(ConstMember {
653 vis,
654 ident,
655 meta,
656 function,
657 }) = enable_variant_count
658 {
659 expanded_2.extend(if function {
660 quote! {
661 #(#[#meta])*
662 #vis const fn #ident () -> usize {
663 #variant_count
664 }
665 }
666 } else {
667 quote! {
668 #(#[#meta])*
669 #vis const #ident: usize = #variant_count;
670 }
671 });
672 }
673
674 if let Some(ConstMember {
675 vis,
676 ident,
677 meta,
678 function,
679 }) = enable_variants
680 {
681 expanded_2.extend(if function {
682 quote! {
683 #(#[#meta])*
684 #vis const fn #ident () -> [Self; #variant_count] {
685 [#( Self::#variant_idents, )*]
686 }
687 }
688 } else {
689 quote! {
690 #(#[#meta])*
691 #vis const #ident: [Self; #variant_count] = [#( Self::#variant_idents, )*];
692 }
693 });
694 }
695
696 if let Some(ConstMember {
697 vis,
698 ident,
699 meta,
700 function,
701 }) = enable_values
702 {
703 expanded_2.extend(if function {
704 quote! {
705 #(#[#meta])*
706 #vis const fn #ident () -> [#variant_type; #variant_count] {
707 [#( #values, )*]
708 }
709 }
710 } else {
711 quote! {
712 #(#[#meta])*
713 #vis const #ident: [#variant_type; #variant_count] = [#( #values, )*];
714 }
715 });
716 }
717
718 if let Some(ConstFunctionMember {
719 vis,
720 ident,
721 meta,
722 }) = enable_from_ordinal_unsafe
723 {
724 let from_ordinal_unsafe = if variant_count == 1 {
725 let variant_ident = &variant_idents[0];
726
727 quote! {
728 #(#[#meta])*
729 #vis const unsafe fn #ident (_number: #variant_type) -> Self {
730 Self::#variant_ident
731 }
732 }
733 } else {
734 quote! {
735 #(#[#meta])*
736 #vis const unsafe fn #ident (number: #variant_type) -> Self {
737 ::core::mem::transmute(number)
738 }
739 }
740 };
741
742 expanded_2.extend(from_ordinal_unsafe);
743 }
744
745 if let Some(ConstFunctionMember {
746 vis,
747 ident,
748 meta,
749 }) = enable_from_ordinal
750 {
751 let from_ordinal = if use_constant_counter {
752 quote! {
753 #(#[#meta])*
754 #vis const fn #ident (number: #variant_type) -> Option<Self> {
755 if false {
756 unreachable!()
757 } #( else if number == #values {
758 Some(Self::#variant_idents)
759 } )* else {
760 None
761 }
762 }
763 }
764 } else {
765 quote! {
766 #(#[#meta])*
767 #vis const fn #ident (number: #variant_type) -> Option<Self> {
768 match number{
769 #(
770 #values => Some(Self::#variant_idents),
771 )*
772 _ => None
773 }
774 }
775 }
776 };
777
778 expanded_2.extend(from_ordinal);
779 }
780
781 if let Some(ConstFunctionMember {
782 vis,
783 ident,
784 meta,
785 }) = enable_ordinal
786 {
787 expanded_2.extend(quote! {
788 #(#[#meta])*
789 #vis const fn #ident (&self) -> #variant_type {
790 match self {
791 #(
792 Self::#variant_idents => #values,
793 )*
794 }
795 }
796 });
797 }
798
799 if !expanded_2.is_empty() {
800 expanded.extend(quote! {
801 impl #impl_generics #name #ty_generics #where_clause {
802 #expanded_2
803 }
804 });
805 }
806
807 expanded.into()
808}