strum_macros/macros/
enum_table.rs1use proc_macro2::{Span, TokenStream};
2use quote::{format_ident, quote};
3use syn::{spanned::Spanned, Data, DeriveInput, Fields};
4
5use crate::helpers::{non_enum_error, snakify, HasStrumVariantProperties};
6
7pub fn enum_table_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
8 let name = &ast.ident;
9 let gen = &ast.generics;
10 let vis = &ast.vis;
11 let mut doc_comment = format!("A map over the variants of `{}`", name);
12
13 if gen.lifetimes().count() > 0 {
14 return Err(syn::Error::new(
15 Span::call_site(),
16 "`EnumTable` doesn't support enums with lifetimes.",
17 ));
18 }
19
20 let variants = match &ast.data {
21 Data::Enum(v) => &v.variants,
22 _ => return Err(non_enum_error()),
23 };
24
25 let table_name = format_ident!("{}Table", name);
26
27 let mut pascal_idents = Vec::new();
29 let mut snake_idents = Vec::new();
31 let mut get_matches = Vec::new();
33 let mut get_matches_mut = Vec::new();
35 let mut set_matches = Vec::new();
37 let mut closure_fields = Vec::new();
39 let mut transform_fields = Vec::new();
41
42 let mut disabled_variants = Vec::new();
44 let mut disabled_matches = Vec::new();
46
47 for variant in variants {
48 if variant.get_variant_properties()?.disabled.is_some() {
50 let disabled_ident = &variant.ident;
51 let panic_message = format!(
52 "Can't use `{}` with `{}` - variant is disabled for Strum features",
53 disabled_ident, table_name
54 );
55 disabled_variants.push(disabled_ident);
56 disabled_matches.push(quote!(#name::#disabled_ident => panic!(#panic_message),));
57 continue;
58 }
59
60 if !matches!(variant.fields, Fields::Unit) {
62 return Err(syn::Error::new(
63 variant.fields.span(),
64 "`EnumTable` doesn't support enums with non-unit variants",
65 ));
66 };
67
68 let pascal_case = &variant.ident;
69 let snake_case = format_ident!("_{}", snakify(&pascal_case.to_string()));
70
71 get_matches.push(quote! {#name::#pascal_case => &self.#snake_case,});
72 get_matches_mut.push(quote! {#name::#pascal_case => &mut self.#snake_case,});
73 set_matches.push(quote! {#name::#pascal_case => self.#snake_case = new_value,});
74 closure_fields.push(quote! {#snake_case: func(#name::#pascal_case),});
75 transform_fields.push(quote! {#snake_case: func(#name::#pascal_case, &self.#snake_case),});
76 pascal_idents.push(pascal_case);
77 snake_idents.push(snake_case);
78 }
79
80 if pascal_idents.is_empty() {
82 return Err(syn::Error::new(
83 variants.span(),
84 "`EnumTable` requires at least one non-disabled variant",
85 ));
86 }
87
88 if !disabled_variants.is_empty() {
90 doc_comment.push_str(&format!(
91 "\n# Panics\nIndexing `{}` with any of the following variants will cause a panic:",
92 table_name
93 ));
94 for variant in disabled_variants {
95 doc_comment.push_str(&format!("\n\n- `{}::{}`", name, variant));
96 }
97 }
98
99 let doc_new = format!(
100 "Create a new {} with a value for each variant of {}",
101 table_name, name
102 );
103 let doc_closure = format!(
104 "Create a new {} by running a function on each variant of `{}`",
105 table_name, name
106 );
107 let doc_transform = format!("Create a new `{}` by running a function on each variant of `{}` and the corresponding value in the current `{0}`", table_name, name);
108 let doc_filled = format!(
109 "Create a new `{}` with the same value in each field.",
110 table_name
111 );
112 let doc_option_all = format!("Converts `{}<Option<T>>` into `Option<{0}<T>>`. Returns `Some` if all fields are `Some`, otherwise returns `None`.", table_name);
113 let doc_result_all_ok = format!("Converts `{}<Result<T, E>>` into `Result<{0}<T>, E>`. Returns `Ok` if all fields are `Ok`, otherwise returns `Err`.", table_name);
114
115 Ok(quote! {
116 #[doc = #doc_comment]
117 #[allow(
118 missing_copy_implementations,
119 )]
120 #[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
121 #vis struct #table_name<T> {
122 #(#snake_idents: T,)*
123 }
124
125 impl<T: Clone> #table_name<T> {
126 #[doc = #doc_filled]
127 #vis fn filled(value: T) -> #table_name<T> {
128 #table_name {
129 #(#snake_idents: value.clone(),)*
130 }
131 }
132 }
133
134 impl<T> #table_name<T> {
135 #[doc = #doc_new]
136 #[inline]
137 #vis fn new(
138 #(#snake_idents: T,)*
139 ) -> #table_name<T> {
140 #table_name {
141 #(#snake_idents,)*
142 }
143 }
144
145 #[doc = #doc_closure]
146 #[inline]
147 #vis fn from_closure<F: Fn(#name)->T>(func: F) -> #table_name<T> {
148 #table_name {
149 #(#closure_fields)*
150 }
151 }
152
153 #[doc = #doc_transform]
154 #[inline]
155 #vis fn transform<U, F: Fn(#name, &T)->U>(&self, func: F) -> #table_name<U> {
156 #table_name {
157 #(#transform_fields)*
158 }
159 }
160
161 }
162
163 impl<T> ::core::ops::Index<#name> for #table_name<T> {
164 type Output = T;
165
166 #[inline]
167 fn index(&self, idx: #name) -> &T {
168 match idx {
169 #(#get_matches)*
170 #(#disabled_matches)*
171 }
172 }
173 }
174
175 impl<T> ::core::ops::IndexMut<#name> for #table_name<T> {
176 #[inline]
177 fn index_mut(&mut self, idx: #name) -> &mut T {
178 match idx {
179 #(#get_matches_mut)*
180 #(#disabled_matches)*
181 }
182 }
183 }
184
185 impl<T> #table_name<::core::option::Option<T>> {
186 #[doc = #doc_option_all]
187 #[inline]
188 #vis fn all(self) -> ::core::option::Option<#table_name<T>> {
189 if let #table_name {
190 #(#snake_idents: ::core::option::Option::Some(#snake_idents),)*
191 } = self {
192 ::core::option::Option::Some(#table_name {
193 #(#snake_idents,)*
194 })
195 } else {
196 ::core::option::Option::None
197 }
198 }
199 }
200
201 impl<T, E> #table_name<::core::result::Result<T, E>> {
202 #[doc = #doc_result_all_ok]
203 #[inline]
204 #vis fn all_ok(self) -> ::core::result::Result<#table_name<T>, E> {
205 ::core::result::Result::Ok(#table_name {
206 #(#snake_idents: self.#snake_idents?,)*
207 })
208 }
209 }
210 })
211}