strum_macros/macros/
enum_iter.rs1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use syn::{Data, DeriveInput, Fields, Ident};
4
5use crate::helpers::{non_enum_error, HasStrumVariantProperties, HasTypeProperties};
6
7pub fn enum_iter_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
8 let name = &ast.ident;
9 let gen = &ast.generics;
10 let (impl_generics, ty_generics, where_clause) = gen.split_for_impl();
11 let vis = &ast.vis;
12 let type_properties = ast.get_type_properties()?;
13 let strum_module_path = type_properties.crate_module_path();
14 let doc_comment = format!("An iterator over the variants of [{}]", name);
15
16 if gen.lifetimes().count() > 0 {
17 return Err(syn::Error::new(
18 Span::call_site(),
19 "This macro doesn't support enums with lifetimes. \
20 The resulting enums would be unbounded.",
21 ));
22 }
23
24 let phantom_data = if gen.type_params().count() > 0 {
25 let g = gen.type_params().map(|param| ¶m.ident);
26 quote! { < fn() -> ( #(#g),* ) > }
27 } else {
28 quote! { < fn() -> () > }
29 };
30
31 let variants = match &ast.data {
32 Data::Enum(v) => &v.variants,
33 _ => return Err(non_enum_error()),
34 };
35
36 let mut arms = Vec::new();
37 let mut idx = 0usize;
38 for variant in variants {
39 if variant.get_variant_properties()?.disabled.is_some() {
40 continue;
41 }
42
43 let ident = &variant.ident;
44 let params = match &variant.fields {
45 Fields::Unit => quote! {},
46 Fields::Unnamed(fields) => {
47 let defaults = ::core::iter::repeat(quote!(::core::default::Default::default()))
48 .take(fields.unnamed.len());
49 quote! { (#(#defaults),*) }
50 }
51 Fields::Named(fields) => {
52 let fields = fields
53 .named
54 .iter()
55 .map(|field| field.ident.as_ref().unwrap());
56 quote! { {#(#fields: ::core::default::Default::default()),*} }
57 }
58 };
59
60 arms.push(quote! {#idx => ::core::option::Option::Some(#name::#ident #params)});
61 idx += 1;
62 }
63
64 let variant_count = arms.len();
65 arms.push(quote! { _ => ::core::option::Option::None });
66 let iter_name = syn::parse_str::<Ident>(&format!("{}Iter", name)).unwrap();
67
68 let iter_name_debug_struct =
70 syn::parse_str::<syn::LitStr>(&format!("\"{}\"", iter_name)).unwrap();
71
72 Ok(quote! {
73 #[doc = #doc_comment]
74 #[allow(
75 missing_copy_implementations,
76 )]
77 #vis struct #iter_name #impl_generics {
78 idx: usize,
79 back_idx: usize,
80 marker: ::core::marker::PhantomData #phantom_data,
81 }
82
83 impl #impl_generics ::core::fmt::Debug for #iter_name #ty_generics #where_clause {
84 fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
85 f.debug_struct(#iter_name_debug_struct)
88 .field("len", &self.len())
89 .finish()
90 }
91 }
92
93 impl #impl_generics #iter_name #ty_generics #where_clause {
94 fn get(&self, idx: usize) -> ::core::option::Option<#name #ty_generics> {
95 match idx {
96 #(#arms),*
97 }
98 }
99 }
100
101 impl #impl_generics #strum_module_path::IntoEnumIterator for #name #ty_generics #where_clause {
102 type Iterator = #iter_name #ty_generics;
103
104 #[inline]
105 fn iter() -> #iter_name #ty_generics {
106 #iter_name {
107 idx: 0,
108 back_idx: 0,
109 marker: ::core::marker::PhantomData,
110 }
111 }
112 }
113
114 impl #impl_generics Iterator for #iter_name #ty_generics #where_clause {
115 type Item = #name #ty_generics;
116
117 #[inline]
118 fn next(&mut self) -> ::core::option::Option<<Self as Iterator>::Item> {
119 self.nth(0)
120 }
121
122 #[inline]
123 fn size_hint(&self) -> (usize, ::core::option::Option<usize>) {
124 let t = if self.idx + self.back_idx >= #variant_count { 0 } else { #variant_count - self.idx - self.back_idx };
125 (t, Some(t))
126 }
127
128 #[inline]
129 fn nth(&mut self, n: usize) -> ::core::option::Option<<Self as Iterator>::Item> {
130 let idx = self.idx + n + 1;
131 if idx + self.back_idx > #variant_count {
132 self.idx = #variant_count;
136 ::core::option::Option::None
137 } else {
138 self.idx = idx;
139 #iter_name::get(self, idx - 1)
140 }
141 }
142 }
143
144 impl #impl_generics ExactSizeIterator for #iter_name #ty_generics #where_clause {
145 #[inline]
146 fn len(&self) -> usize {
147 self.size_hint().0
148 }
149 }
150
151 impl #impl_generics DoubleEndedIterator for #iter_name #ty_generics #where_clause {
152 #[inline]
153 fn next_back(&mut self) -> ::core::option::Option<<Self as Iterator>::Item> {
154 let back_idx = self.back_idx + 1;
155
156 if self.idx + back_idx > #variant_count {
157 self.back_idx = #variant_count;
161 ::core::option::Option::None
162 } else {
163 self.back_idx = back_idx;
164 #iter_name::get(self, #variant_count - self.back_idx)
165 }
166 }
167 }
168
169 impl #impl_generics ::core::iter::FusedIterator for #iter_name #ty_generics #where_clause { }
170
171 impl #impl_generics Clone for #iter_name #ty_generics #where_clause {
172 #[inline]
173 fn clone(&self) -> #iter_name #ty_generics {
174 #iter_name {
175 idx: self.idx,
176 back_idx: self.back_idx,
177 marker: self.marker.clone(),
178 }
179 }
180 }
181 })
182}