// strum_macros/macros/enum_iter.rs

1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use syn::{Data, DeriveInput, Fields, Ident};
4
5use crate::helpers::{non_enum_error, HasStrumVariantProperties, HasTypeProperties};
6
7pub fn enum_iter_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
8    let name = &ast.ident;
9    let gen = &ast.generics;
10    let (impl_generics, ty_generics, where_clause) = gen.split_for_impl();
11    let vis = &ast.vis;
12    let type_properties = ast.get_type_properties()?;
13    let strum_module_path = type_properties.crate_module_path();
14    let doc_comment = format!("An iterator over the variants of [{}]", name);
15
16    if gen.lifetimes().count() > 0 {
17        return Err(syn::Error::new(
18            Span::call_site(),
19            "This macro doesn't support enums with lifetimes. \
20             The resulting enums would be unbounded.",
21        ));
22    }
23
24    let phantom_data = if gen.type_params().count() > 0 {
25        let g = gen.type_params().map(|param| &param.ident);
26        quote! { < fn() -> ( #(#g),* ) > }
27    } else {
28        quote! { < fn() -> () > }
29    };
30
31    let variants = match &ast.data {
32        Data::Enum(v) => &v.variants,
33        _ => return Err(non_enum_error()),
34    };
35
36    let mut arms = Vec::new();
37    let mut idx = 0usize;
38    for variant in variants {
39        if variant.get_variant_properties()?.disabled.is_some() {
40            continue;
41        }
42
43        let ident = &variant.ident;
44        let params = match &variant.fields {
45            Fields::Unit => quote! {},
46            Fields::Unnamed(fields) => {
47                let defaults = ::core::iter::repeat(quote!(::core::default::Default::default()))
48                    .take(fields.unnamed.len());
49                quote! { (#(#defaults),*) }
50            }
51            Fields::Named(fields) => {
52                let fields = fields
53                    .named
54                    .iter()
55                    .map(|field| field.ident.as_ref().unwrap());
56                quote! { {#(#fields: ::core::default::Default::default()),*} }
57            }
58        };
59
60        arms.push(quote! {#idx => ::core::option::Option::Some(#name::#ident #params)});
61        idx += 1;
62    }
63
64    let variant_count = arms.len();
65    arms.push(quote! { _ => ::core::option::Option::None });
66    let iter_name = syn::parse_str::<Ident>(&format!("{}Iter", name)).unwrap();
67
68    // Create a string literal "MyEnumIter" to use in the debug impl.
69    let iter_name_debug_struct =
70        syn::parse_str::<syn::LitStr>(&format!("\"{}\"", iter_name)).unwrap();
71
72    Ok(quote! {
73        #[doc = #doc_comment]
74        #[allow(
75            missing_copy_implementations,
76        )]
77        #vis struct #iter_name #impl_generics {
78            idx: usize,
79            back_idx: usize,
80            marker: ::core::marker::PhantomData #phantom_data,
81        }
82
83        impl #impl_generics ::core::fmt::Debug for #iter_name #ty_generics #where_clause {
84            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
85                // We don't know if the variants implement debug themselves so the only thing we
86                // can really show is how many elements are left.
87                f.debug_struct(#iter_name_debug_struct)
88                    .field("len", &self.len())
89                    .finish()
90            }
91        }
92
93        impl #impl_generics #iter_name #ty_generics #where_clause {
94            fn get(&self, idx: usize) -> ::core::option::Option<#name #ty_generics> {
95                match idx {
96                    #(#arms),*
97                }
98            }
99        }
100
101        impl #impl_generics #strum_module_path::IntoEnumIterator for #name #ty_generics #where_clause {
102            type Iterator = #iter_name #ty_generics;
103
104            #[inline]
105            fn iter() -> #iter_name #ty_generics {
106                #iter_name {
107                    idx: 0,
108                    back_idx: 0,
109                    marker: ::core::marker::PhantomData,
110                }
111            }
112        }
113
114        impl #impl_generics Iterator for #iter_name #ty_generics #where_clause {
115            type Item = #name #ty_generics;
116
117            #[inline]
118            fn next(&mut self) -> ::core::option::Option<<Self as Iterator>::Item> {
119                self.nth(0)
120            }
121
122            #[inline]
123            fn size_hint(&self) -> (usize, ::core::option::Option<usize>) {
124                let t = if self.idx + self.back_idx >= #variant_count { 0 } else { #variant_count - self.idx - self.back_idx };
125                (t, Some(t))
126            }
127
128            #[inline]
129            fn nth(&mut self, n: usize) -> ::core::option::Option<<Self as Iterator>::Item> {
130                let idx = self.idx + n + 1;
131                if idx + self.back_idx > #variant_count {
132                    // We went past the end of the iterator. Freeze idx at #variant_count
133                    // so that it doesn't overflow if the user calls this repeatedly.
134                    // See PR #76 for context.
135                    self.idx = #variant_count;
136                    ::core::option::Option::None
137                } else {
138                    self.idx = idx;
139                    #iter_name::get(self, idx - 1)
140                }
141            }
142        }
143
144        impl #impl_generics ExactSizeIterator for #iter_name #ty_generics #where_clause {
145            #[inline]
146            fn len(&self) -> usize {
147                self.size_hint().0
148            }
149        }
150
151        impl #impl_generics DoubleEndedIterator for #iter_name #ty_generics #where_clause {
152            #[inline]
153            fn next_back(&mut self) -> ::core::option::Option<<Self as Iterator>::Item> {
154                let back_idx = self.back_idx + 1;
155
156                if self.idx + back_idx > #variant_count {
157                    // We went past the end of the iterator. Freeze back_idx at #variant_count
158                    // so that it doesn't overflow if the user calls this repeatedly.
159                    // See PR #76 for context.
160                    self.back_idx = #variant_count;
161                    ::core::option::Option::None
162                } else {
163                    self.back_idx = back_idx;
164                    #iter_name::get(self, #variant_count - self.back_idx)
165                }
166            }
167        }
168
169        impl #impl_generics ::core::iter::FusedIterator for #iter_name #ty_generics #where_clause { }
170
171        impl #impl_generics Clone for #iter_name #ty_generics #where_clause {
172            #[inline]
173            fn clone(&self) -> #iter_name #ty_generics {
174                #iter_name {
175                    idx: self.idx,
176                    back_idx: self.back_idx,
177                    marker: self.marker.clone(),
178                }
179            }
180        }
181    })
182}