strum_macros/macros/strings/
from_string.rs

1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::{parse_quote, Data, DeriveInput, Fields, Path};
4
5use crate::helpers::{
6    missing_parse_err_attr_error, non_enum_error, occurrence_error, HasInnerVariantProperties,
7    HasStrumVariantProperties, HasTypeProperties,
8};
9
10pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
11    let name = &ast.ident;
12    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
13    let variants = match &ast.data {
14        Data::Enum(v) => &v.variants,
15        _ => return Err(non_enum_error()),
16    };
17
18    let type_properties = ast.get_type_properties()?;
19    let strum_module_path = type_properties.crate_module_path();
20
21    let mut default_kw = None;
22    let (mut default_err_ty, mut default) = match (
23        type_properties.parse_err_ty,
24        type_properties.parse_err_fn,
25    ) {
26        (None, None) => (
27            quote! { #strum_module_path::ParseError },
28            quote! { ::core::result::Result::Err(#strum_module_path::ParseError::VariantNotFound) },
29        ),
30        (Some(ty), Some(f)) => {
31            let ty_path: Path = parse_quote!(#ty);
32            let fn_path: Path = parse_quote!(#f);
33
34            (
35                quote! { #ty_path },
36                quote! { ::core::result::Result::Err(#fn_path(s)) },
37            )
38        }
39        _ => return Err(missing_parse_err_attr_error()),
40    };
41    let mut phf_exact_match_arms = Vec::new();
42    let mut standard_match_arms = Vec::new();
43    for variant in variants {
44        let ident = &variant.ident;
45        let variant_properties = variant.get_variant_properties()?;
46
47        if variant_properties.disabled.is_some() {
48            continue;
49        }
50
51        if let Some(kw) = variant_properties.default {
52            if let Some(fst_kw) = default_kw {
53                return Err(occurrence_error(fst_kw, kw, "default"));
54            }
55
56            default_kw = Some(kw);
57            default_err_ty = quote! { #strum_module_path::ParseError };
58
59            match &variant.fields {
60                Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
61                    default = quote! {
62                        ::core::result::Result::Ok(#name::#ident(s.into()))
63                    };
64                }
65                Fields::Named(ref f) if f.named.len() == 1 => {
66                    let field_name = f.named.last().unwrap().ident.as_ref().unwrap();
67                    default = quote! {
68                        ::core::result::Result::Ok(#name::#ident { #field_name : s.into() } )
69                    };
70                }
71                _ => {
72                    return Err(syn::Error::new_spanned(
73                        variant,
74                        "Default only works on newtype structs with a single String field",
75                    ))
76                }
77            }
78
79            continue;
80        }
81
82        let params = match &variant.fields {
83            Fields::Unit => quote! {},
84            Fields::Unnamed(fields) => {
85                if let Some(ref value) = variant_properties.default_with {
86                    let func = proc_macro2::Ident::new(&value.value(), value.span());
87                    let defaults = vec![quote! { #func() }];
88                    quote! { (#(#defaults),*) }
89                } else {
90                    let defaults =
91                        ::core::iter::repeat(quote!(Default::default())).take(fields.unnamed.len());
92                    quote! { (#(#defaults),*) }
93                }
94            }
95            Fields::Named(fields) => {
96                let mut defaults = vec![];
97                for field in &fields.named {
98                    let meta = field.get_variant_inner_properties()?;
99                    let field = field.ident.as_ref().unwrap();
100
101                    if let Some(default_with) = meta.default_with {
102                        let func =
103                            proc_macro2::Ident::new(&default_with.value(), default_with.span());
104                        defaults.push(quote! {
105                            #field: #func()
106                        });
107                    } else {
108                        defaults.push(quote! { #field: Default::default() });
109                    }
110                }
111
112                quote! { {#(#defaults),*} }
113            }
114        };
115
116        let is_ascii_case_insensitive = variant_properties
117            .ascii_case_insensitive
118            .unwrap_or(type_properties.ascii_case_insensitive);
119
120        // If we don't have any custom variants, add the default serialized name.
121        for serialization in variant_properties.get_serializations(type_properties.case_style) {
122            if type_properties.use_phf {
123                phf_exact_match_arms.push(quote! { #serialization => #name::#ident #params, });
124
125                if is_ascii_case_insensitive {
126                    // Store the lowercase and UPPERCASE variants in the phf map to capture
127                    let ser_string = serialization.value();
128
129                    let lower =
130                        syn::LitStr::new(&ser_string.to_ascii_lowercase(), serialization.span());
131                    let upper =
132                        syn::LitStr::new(&ser_string.to_ascii_uppercase(), serialization.span());
133                    phf_exact_match_arms.push(quote! { #lower => #name::#ident #params, });
134                    phf_exact_match_arms.push(quote! { #upper => #name::#ident #params, });
135                    standard_match_arms.push(quote! { s if s.eq_ignore_ascii_case(#serialization) => #name::#ident #params, });
136                }
137            } else {
138                standard_match_arms.push(if !is_ascii_case_insensitive {
139                    quote! { #serialization => #name::#ident #params, }
140                } else {
141                    quote! { s if s.eq_ignore_ascii_case(#serialization) => #name::#ident #params, }
142                });
143            }
144        }
145    }
146
147    let phf_body = if phf_exact_match_arms.is_empty() {
148        quote!()
149    } else {
150        quote! {
151            use #strum_module_path::_private_phf_reexport_for_macro_if_phf_feature as phf;
152            static PHF: phf::Map<&'static str, #name> = phf::phf_map! {
153                #(#phf_exact_match_arms)*
154            };
155            if let Some(value) = PHF.get(s).cloned() {
156                return ::core::result::Result::Ok(value);
157            }
158        }
159    };
160
161    let standard_match_body = if standard_match_arms.is_empty() {
162        default
163    } else {
164        quote! {
165            ::core::result::Result::Ok(match s {
166                #(#standard_match_arms)*
167                _ => return #default,
168            })
169        }
170    };
171
172    let from_str = quote! {
173        #[allow(clippy::use_self)]
174        impl #impl_generics ::core::str::FromStr for #name #ty_generics #where_clause {
175            type Err = #default_err_ty;
176
177            #[inline]
178            fn from_str(s: &str) -> ::core::result::Result< #name #ty_generics , <Self as ::core::str::FromStr>::Err> {
179                #phf_body
180                #standard_match_body
181            }
182        }
183    };
184    let try_from_str = try_from_str(
185        name,
186        &impl_generics,
187        &ty_generics,
188        where_clause,
189        &default_err_ty,
190    );
191
192    Ok(quote! {
193        #from_str
194        #try_from_str
195    })
196}
197
198#[rustversion::before(1.34)]
199fn try_from_str(
200    _name: &proc_macro2::Ident,
201    _impl_generics: &syn::ImplGenerics,
202    _ty_generics: &syn::TypeGenerics,
203    _where_clause: Option<&syn::WhereClause>,
204    _strum_module_path: &syn::Path,
205) -> TokenStream {
206    Default::default()
207}
208
209#[rustversion::since(1.34)]
210fn try_from_str(
211    name: &proc_macro2::Ident,
212    impl_generics: &syn::ImplGenerics,
213    ty_generics: &syn::TypeGenerics,
214    where_clause: Option<&syn::WhereClause>,
215    default_err_ty: &TokenStream,
216) -> TokenStream {
217    quote! {
218        #[allow(clippy::use_self)]
219        impl #impl_generics ::core::convert::TryFrom<&str> for #name #ty_generics #where_clause {
220            type Error = #default_err_ty;
221
222            #[inline]
223            fn try_from(s: &str) -> ::core::result::Result< #name #ty_generics , <Self as ::core::convert::TryFrom<&str>>::Error> {
224                ::core::str::FromStr::from_str(s)
225            }
226        }
227    }
228}