as_derive_utils/
datastructure.rs

1use syn::{
2    self, Attribute, Data, DeriveInput, Field as SynField, Fields as SynFields, Generics, Ident,
3    Type, Visibility,
4};
5
6use quote::ToTokens;
7
8use proc_macro2::{Span, TokenStream};
9
10use std::fmt::{self, Display};
11
12mod field_map;
13mod type_param_map;
14
15pub use self::{field_map::FieldMap, type_param_map::TypeParamMap};
16
17//////////////////////////////////////////////////////////////////////////////
18
/// A type definition(enum,struct,union).
///
/// All the references in here borrow from the `DeriveInput`
/// that was passed to `DataStructure::new`.
#[derive(Clone, Debug, PartialEq, Hash)]
pub struct DataStructure<'a> {
    /// The visibility of the type definition.
    pub vis: &'a Visibility,
    /// The name of the type.
    pub name: &'a Ident,
    /// The generic parameters (and where clause) of the type.
    pub generics: &'a Generics,
    /// How many lifetime parameters `generics` declares.
    pub lifetime_count: usize,
    /// The total number of fields, summed over all the variants.
    pub field_count: usize,
    /// The total number of public fields, summed over all the variants.
    ///
    /// The fields of enum variants are given the visibility of the enum itself,
    /// so for an enum this counts either every field or none of them.
    pub pub_field_count: usize,

    /// The attributes on the type definition itself.
    pub attrs: &'a [Attribute],

    /// Whether this is a struct/union/enum.
    pub data_variant: DataVariant,

    /// The variants in the type definition.
    ///
    /// If it is a struct or a union this only has 1 element.
    pub variants: Vec<Struct<'a>>,
}
39
40impl<'a> DataStructure<'a> {
41    pub fn new(ast: &'a DeriveInput) -> Self {
42        let name = &ast.ident;
43
44        let data_variant: DataVariant;
45
46        let mut variants = Vec::new();
47
48        match &ast.data {
49            Data::Enum(enum_) => {
50                let override_vis = Some(&ast.vis);
51
52                for (variant, var) in enum_.variants.iter().enumerate() {
53                    variants.push(Struct::new(
54                        StructParams {
55                            discriminant: var.discriminant.as_ref().map(|(_, v)| v),
56                            variant,
57                            attrs: &var.attrs,
58                            name: &var.ident,
59                            override_vis,
60                        },
61                        &var.fields,
62                    ));
63                }
64                data_variant = DataVariant::Enum;
65            }
66            Data::Struct(struct_) => {
67                let override_vis = None;
68
69                variants.push(Struct::new(
70                    StructParams {
71                        discriminant: None,
72                        variant: 0,
73                        attrs: &[],
74                        name,
75                        override_vis,
76                    },
77                    &struct_.fields,
78                ));
79                data_variant = DataVariant::Struct;
80            }
81
82            Data::Union(union_) => {
83                let override_vis = None;
84
85                let fields = Some(&union_.fields.named);
86                let sk = StructKind::Braced;
87                let vari = Struct::with_fields(
88                    StructParams {
89                        discriminant: None,
90                        variant: 0,
91                        attrs: &[],
92                        name,
93                        override_vis,
94                    },
95                    sk,
96                    fields,
97                );
98                variants.push(vari);
99                data_variant = DataVariant::Union;
100            }
101        }
102
103        let mut field_count = 0;
104        let mut pub_field_count = 0;
105
106        for vari in &variants {
107            field_count += vari.fields.len();
108            pub_field_count += vari.pub_field_count;
109        }
110
111        Self {
112            vis: &ast.vis,
113            name,
114            attrs: &ast.attrs,
115            generics: &ast.generics,
116            lifetime_count: ast.generics.lifetimes().count(),
117            data_variant,
118            variants,
119            field_count,
120            pub_field_count,
121        }
122    }
123
124    pub fn has_public_fields(&self) -> bool {
125        self.pub_field_count != 0
126    }
127}
128
129//////////////////////////////////////////////////////////////////////////////
130
/// Whether the struct is tupled or not.
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Ord, Eq, Hash)]
pub enum StructKind {
    /// structs declared using the `struct Name( ... )` syntax.
    Tuple,
    /// structs declared using the `struct Name{ ... }` or `struct Name;` syntaxes
    Braced,
}
139
/// Which kind of type definition a `DataStructure` was constructed from.
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Ord, Eq, Hash)]
pub enum DataVariant {
    /// A `struct` definition.
    Struct,
    /// An `enum` definition.
    Enum,
    /// A `union` definition.
    Union,
}
146
/// The location of a field inside of a `DataStructure`.
#[derive(Copy, Clone, Debug, PartialEq, Hash)]
pub struct FieldIndex {
    /// The index of the variant that contains the field.
    ///
    /// This is always 0 for structs and unions.
    pub variant: usize,
    /// The position of the field inside of its variant,starting from 0.
    pub pos: usize,
}
152
153//////////////////////////////////////////////////////////////////////////////
154
/// The parameters shared by the constructors of `Struct`,
/// describing everything about a variant other than its fields.
#[derive(Copy, Clone)]
struct StructParams<'a> {
    /// The explicit discriminant of an enum variant,if any.
    discriminant: Option<&'a syn::Expr>,
    /// The index of the variant,copied into the `FieldIndex` of every field.
    variant: usize,
    /// The attributes stored in the constructed `Struct`.
    attrs: &'a [Attribute],
    /// The name of the struct/union/variant.
    ///
    /// Its span is also used for the identifiers generated for unnamed fields.
    name: &'a Ident,
    /// If this is `Some(_)`,every field uses this visibility instead of its own.
    override_vis: Option<&'a Visibility>,
}
163
/// A struct/union or a variant of an enum.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Hash)]
pub struct Struct<'a> {
    /// The attributes of this `Struct`.
    ///
    /// If this is a struct/union:`DataStructure::new` stores an empty slice here,
    /// the attributes of the type are in `DataStructure.attrs`.
    ///
    /// If this is an enum:these are the attributes on the variant.
    ///
    // NOTE(review): these docs used to say that for a struct/union this is the same
    // as `DataStructure.attrs`, while `DataStructure::new` passes `&[]`.
    // Confirm which of the two is intended.
    pub attrs: &'a [Attribute],
    /// The name of this `Struct`.
    ///
    /// If this is a struct/union:this is the same as DataStructure.name.
    ///
    /// If this is an enum:this is the name of the variant.
    pub name: &'a Ident,
    /// Whether this was declared with tuple or braced/unit syntax.
    pub kind: StructKind,
    /// The fields of this struct/union/variant,in declaration order.
    pub fields: Vec<Field<'a>>,
    /// How many of the `fields` are public.
    pub pub_field_count: usize,
    /// The value of this discriminant.
    ///
    /// If this is a Some(_):This is an enum variant with an explicit discriminant value.
    ///
    /// If this is a None:
    ///     This is either a struct/union or an enum variant without an explicit discriminant.
    pub discriminant: Option<&'a syn::Expr>,
}
191
192impl<'a> Struct<'a> {
193    fn new(p: StructParams<'a>, fields: &'a SynFields) -> Self {
194        let kind = match *fields {
195            SynFields::Named { .. } => StructKind::Braced,
196            SynFields::Unnamed { .. } => StructKind::Tuple,
197            SynFields::Unit { .. } => StructKind::Braced,
198        };
199        let fields = match fields {
200            SynFields::Named(f) => Some(&f.named),
201            SynFields::Unnamed(f) => Some(&f.unnamed),
202            SynFields::Unit => None,
203        };
204
205        Self::with_fields(p, kind, fields)
206    }
207
208    fn with_fields<I>(p: StructParams<'a>, kind: StructKind, fields: Option<I>) -> Self
209    where
210        I: IntoIterator<Item = &'a SynField>,
211    {
212        let fields = match fields {
213            Some(x) => Field::from_iter(p, x),
214            None => Vec::new(),
215        };
216
217        let mut pub_field_count = 0usize;
218
219        for field in &fields {
220            if field.is_public() {
221                pub_field_count += 1;
222            }
223        }
224
225        Self {
226            discriminant: p.discriminant,
227            attrs: p.attrs,
228            name: p.name,
229            kind,
230            pub_field_count,
231            fields,
232        }
233    }
234}
235
236//////////////////////////////////////////////////////////////////////////////
237
/// Represent a struct field
///
#[derive(Clone, Debug, PartialEq, Hash)]
pub struct Field<'a> {
    /// Which variant this field is in,and its position inside of that variant.
    pub index: FieldIndex,
    /// The attributes on the field.
    pub attrs: &'a [Attribute],
    /// The visibility of the field.
    ///
    /// For the fields of enum variants this is the visibility of the enum.
    pub vis: &'a Visibility,
    /// identifier for the field,which is either an index(in a tuple struct) or a name.
    pub ident: FieldIdent<'a>,
    /// The type of the field.
    pub ty: &'a Type,
}
249
250impl<'a> Field<'a> {
251    fn new(
252        index: FieldIndex,
253        field: &'a SynField,
254        span: Span,
255        override_vis: Option<&'a Visibility>,
256    ) -> Self {
257        let ident = match field.ident.as_ref() {
258            Some(ident) => FieldIdent::Named(ident),
259            None => FieldIdent::new_index(index.pos, span),
260        };
261
262        Self {
263            index,
264            attrs: &field.attrs,
265            vis: override_vis.unwrap_or(&field.vis),
266            ident,
267            ty: &field.ty,
268        }
269    }
270
271    pub fn is_public(&self) -> bool {
272        matches!(self.vis, Visibility::Public { .. })
273    }
274
275    /// Gets the identifier of this field usable for the variable in a pattern.
276    ///
277    /// You can match on a single field struct (tupled or braced) like this:
278    ///
279    /// ```rust
280    /// use as_derive_utils::datastructure::Struct;
281    ///
282    /// fn example(struct_: Struct<'_>) -> proc_macro2::TokenStream {
283    ///     let field = &struct_.field[0];
284    ///     let field_name = &field.ident;
285    ///     let variable = field.pat_ident();
286    ///    
287    ///     quote::quote!( let Foo{#field_name: #variable} = bar; )
288    /// }
289    /// ```
290    pub fn pat_ident(&self) -> &Ident {
291        match &self.ident {
292            FieldIdent::Index(_, ident) => ident,
293            FieldIdent::Named(ident) => ident,
294        }
295    }
296
297    fn from_iter<I>(p: StructParams<'a>, fields: I) -> Vec<Self>
298    where
299        I: IntoIterator<Item = &'a SynField>,
300    {
301        fields
302            .into_iter()
303            .enumerate()
304            .map(|(pos, f)| {
305                let fi = FieldIndex {
306                    variant: p.variant,
307                    pos,
308                };
309                Field::new(fi, f, p.name.span(), p.override_vis)
310            })
311            .collect()
312    }
313}
314
315//////////////////////////////////////////////////////////////////////////////
316
/// The identifier of a field,which is either an index(in a tuple struct) or a name.
#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)]
pub enum FieldIdent<'a> {
    /// An unnamed field:its position,
    /// and an identifier that can be used as a variable name for the field.
    Index(usize, Ident),
    /// A named field:the name it was declared with.
    Named(&'a Ident),
}
322
323impl<'a> Display for FieldIdent<'a> {
324    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
325        match self {
326            FieldIdent::Index(x, ..) => Display::fmt(x, f),
327            FieldIdent::Named(x) => Display::fmt(x, f),
328        }
329    }
330}
331
332impl<'a> ToTokens for FieldIdent<'a> {
333    fn to_tokens(&self, tokens: &mut TokenStream) {
334        match *self {
335            FieldIdent::Index(ind, ..) => syn::Index::from(ind).to_tokens(tokens),
336            FieldIdent::Named(name) => name.to_tokens(tokens),
337        }
338    }
339}
340
341impl<'a> FieldIdent<'a> {
342    fn new_index(index: usize, span: Span) -> Self {
343        FieldIdent::Index(index, Ident::new(&format!("field_{}", index), span))
344    }
345}