educe/common/tools/discriminant_type.rs

1use proc_macro2::{Ident, Span, TokenStream};
2use quote::{ToTokens, TokenStreamExt};
3use syn::{
4    punctuated::Punctuated, spanned::Spanned, Data, DeriveInput, Expr, Lit, Meta, Token, UnOp,
5};
6
/// A primitive integer type that can represent an enum's discriminants.
///
/// Each variant corresponds to exactly one of Rust's built-in integer types
/// (`isize`, `i8` … `i128`, `usize`, `u8` … `u128`). The type is a plain
/// fieldless value, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum DiscriminantType {
    ISize,
    I8,
    I16,
    I32,
    I64,
    I128,
    USize,
    U8,
    U16,
    U32,
    U64,
    U128,
}
22
23impl DiscriminantType {
24    #[inline]
25    pub(crate) fn parse_str<S: AsRef<str>>(s: S) -> Option<Self> {
26        match s.as_ref() {
27            "i8" => Some(Self::I8),
28            "i16" => Some(Self::I16),
29            "i32" => Some(Self::I32),
30            "i64" => Some(Self::I64),
31            "i128" => Some(Self::I128),
32            "isize" => Some(Self::ISize),
33            "u8" => Some(Self::U8),
34            "u16" => Some(Self::U16),
35            "u32" => Some(Self::U32),
36            "u64" => Some(Self::U64),
37            "u128" => Some(Self::U128),
38            "usize" => Some(Self::USize),
39            _ => None,
40        }
41    }
42
43    #[inline]
44    pub(crate) const fn as_str(&self) -> &'static str {
45        match self {
46            Self::ISize => "isize",
47            Self::I8 => "i8",
48            Self::I16 => "i16",
49            Self::I32 => "i32",
50            Self::I64 => "i64",
51            Self::I128 => "i128",
52            Self::USize => "usize",
53            Self::U8 => "u8",
54            Self::U16 => "u16",
55            Self::U32 => "u32",
56            Self::U64 => "u64",
57            Self::U128 => "u128",
58        }
59    }
60}
61
62impl ToTokens for DiscriminantType {
63    #[inline]
64    fn to_tokens(&self, tokens: &mut TokenStream) {
65        tokens.append(Ident::new(self.as_str(), Span::call_site()));
66    }
67}
68
69impl DiscriminantType {
70    pub(crate) fn from_ast(ast: &DeriveInput) -> syn::Result<Self> {
71        if let Data::Enum(data) = &ast.data {
72            for attr in ast.attrs.iter() {
73                if attr.path().is_ident("repr") {
74                    // #[repr(u8)], #[repr(u16)], ..., etc.
75                    if let Meta::List(list) = &attr.meta {
76                        let result =
77                            list.parse_args_with(Punctuated::<Ident, Token![,]>::parse_terminated)?;
78
79                        if let Some(value) = result.into_iter().next() {
80                            if let Some(t) = Self::parse_str(value.to_string()) {
81                                return Ok(t);
82                            }
83                        }
84                    }
85                }
86            }
87
88            let mut min = i128::MAX;
89            let mut max = i128::MIN;
90            let mut counter = 0i128;
91
92            for variant in data.variants.iter() {
93                if let Some((_, exp)) = variant.discriminant.as_ref() {
94                    match exp {
95                        Expr::Lit(lit) => {
96                            if let Lit::Int(lit) = &lit.lit {
97                                counter = lit
98                                    .base10_parse()
99                                    .map_err(|error| syn::Error::new(lit.span(), error))?;
100                            } else {
101                                return Err(syn::Error::new(lit.span(), "not an integer"));
102                            }
103                        },
104                        Expr::Unary(unary) => {
105                            if let UnOp::Neg(_) = unary.op {
106                                if let Expr::Lit(lit) = unary.expr.as_ref() {
107                                    if let Lit::Int(lit) = &lit.lit {
108                                        match lit.base10_parse::<i128>() {
109                                            Ok(i) => {
110                                                counter = -i;
111                                            },
112                                            Err(error) => {
113                                                // overflow
114                                                if lit.base10_digits()
115                                                    == "170141183460469231731687303715884105728"
116                                                {
117                                                    counter = i128::MIN;
118                                                } else {
119                                                    return Err(syn::Error::new(lit.span(), error));
120                                                }
121                                            },
122                                        }
123                                    } else {
124                                        return Err(syn::Error::new(lit.span(), "not an integer"));
125                                    }
126                                } else {
127                                    return Err(syn::Error::new(
128                                        unary.expr.span(),
129                                        "not a literal",
130                                    ));
131                                }
132                            } else {
133                                return Err(syn::Error::new(
134                                    unary.op.span(),
135                                    "this operation is not allow here",
136                                ));
137                            }
138                        },
139                        _ => return Err(syn::Error::new(exp.span(), "not a literal")),
140                    }
141                }
142
143                if min > counter {
144                    min = counter;
145                }
146
147                if max < counter {
148                    max = counter;
149                }
150
151                counter = counter.saturating_add(1);
152            }
153
154            Ok(if min >= i8::MIN as i128 && max <= i8::MAX as i128 {
155                Self::I8
156            } else if min >= i16::MIN as i128 && max <= i16::MAX as i128 {
157                Self::I16
158            } else if min >= i32::MIN as i128 && max <= i32::MAX as i128 {
159                Self::I32
160            } else if min >= i64::MIN as i128 && max <= i64::MAX as i128 {
161                Self::I64
162            } else {
163                Self::I128
164            })
165        } else {
166            Err(syn::Error::new(ast.span(), "not an enum"))
167        }
168    }
169}