zvariant_derive/
value.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{quote, ToTokens};
3use syn::{
4    spanned::Spanned, Attribute, Data, DataEnum, DeriveInput, Error, Expr, Fields, Generics, Ident,
5    Lifetime, LifetimeParam,
6};
7
8use crate::utils::*;
9
/// Which zvariant container the derive generates conversions for.
///
/// * `Value` — infallible `From<T> for Value<'a>` plus `TryFrom<Value<'a>> for T`.
/// * `OwnedValue` — fallible `TryFrom<T> for OwnedValue` plus `TryFrom<OwnedValue> for T`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValueType {
    /// Target the borrowed `Value<'a>` type.
    Value,
    /// Target the owned `OwnedValue` type.
    OwnedValue,
}
14
15pub fn expand_derive(ast: DeriveInput, value_type: ValueType) -> Result<TokenStream, Error> {
16    let zv = zvariant_path();
17
18    match &ast.data {
19        Data::Struct(ds) => match &ds.fields {
20            Fields::Named(_) | Fields::Unnamed(_) => {
21                let StructAttributes { signature, .. } = StructAttributes::parse(&ast.attrs)?;
22                let signature = signature.map(|signature| match signature.as_str() {
23                    "dict" => "a{sv}".to_string(),
24                    _ => signature,
25                });
26
27                impl_struct(
28                    value_type,
29                    ast.ident,
30                    ast.generics,
31                    &ds.fields,
32                    signature,
33                    &zv,
34                )
35            }
36            Fields::Unit => Err(Error::new(ast.span(), "Unit structures not supported")),
37        },
38        Data::Enum(data) => impl_enum(value_type, ast.ident, ast.generics, ast.attrs, data, &zv),
39        _ => Err(Error::new(
40            ast.span(),
41            "only structs and enums are supported",
42        )),
43    }
44}
45
46fn impl_struct(
47    value_type: ValueType,
48    name: Ident,
49    generics: Generics,
50    fields: &Fields,
51    signature: Option<String>,
52    zv: &TokenStream,
53) -> Result<TokenStream, Error> {
54    let statc_lifetime = LifetimeParam::new(Lifetime::new("'static", Span::call_site()));
55    let (
56        value_type,
57        value_lifetime,
58        into_value_trait,
59        into_value_method,
60        into_value_error_decl,
61        into_value_ret,
62        into_value_error_transform,
63    ) = match value_type {
64        ValueType::Value => {
65            let mut lifetimes = generics.lifetimes();
66            let value_lifetime = lifetimes
67                .next()
68                .cloned()
69                .unwrap_or_else(|| statc_lifetime.clone());
70            if lifetimes.next().is_some() {
71                return Err(Error::new(
72                    name.span(),
73                    "Type with more than 1 lifetime not supported",
74                ));
75            }
76
77            (
78                quote! { #zv::Value<#value_lifetime> },
79                value_lifetime,
80                quote! { From },
81                quote! { from },
82                quote! {},
83                quote! { Self },
84                quote! {},
85            )
86        }
87        ValueType::OwnedValue => (
88            quote! { #zv::OwnedValue },
89            statc_lifetime,
90            quote! { TryFrom },
91            quote! { try_from },
92            quote! { type Error = #zv::Error; },
93            quote! { #zv::Result<Self> },
94            quote! { .map_err(::std::convert::Into::into) },
95        ),
96    };
97
98    let type_params = generics.type_params().cloned().collect::<Vec<_>>();
99    let (from_value_where_clause, into_value_where_clause) = if !type_params.is_empty() {
100        (
101            Some(quote! {
102                where
103                #(
104                    #type_params: ::std::convert::TryFrom<#zv::Value<#value_lifetime>> + #zv::Type,
105                    <#type_params as ::std::convert::TryFrom<#zv::Value<#value_lifetime>>>::Error: ::std::convert::Into<#zv::Error>
106                ),*
107            }),
108            Some(quote! {
109                where
110                #(
111                    #type_params: ::std::convert::Into<#zv::Value<#value_lifetime>> + #zv::Type
112                ),*
113            }),
114        )
115    } else {
116        (None, None)
117    };
118    let (impl_generics, ty_generics, _) = generics.split_for_impl();
119    match fields {
120        Fields::Named(_) => {
121            let field_names: Vec<_> = fields
122                .iter()
123                .map(|field| field.ident.to_token_stream())
124                .collect();
125            let (from_value_impl, into_value_impl) = match signature {
126                Some(signature) if signature == "a{sv}" => (
127                    // User wants the type to be encoded as a dict.
128                    // FIXME: Not the most efficient implementation.
129                    quote! {
130                        let mut fields = <::std::collections::HashMap::<::std::string::String, #zv::Value>>::try_from(value)?;
131
132                        ::std::result::Result::Ok(Self {
133                            #(
134                                #field_names:
135                                    fields
136                                        .remove(stringify!(#field_names))
137                                        .ok_or_else(|| #zv::Error::IncorrectType)?
138                                        .downcast()?
139                            ),*
140                        })
141                    },
142                    quote! {
143                        let mut fields = ::std::collections::HashMap::new();
144                        #(
145                            fields.insert(stringify!(#field_names), #zv::Value::from(s.#field_names));
146                        )*
147
148                        <#value_type>::#into_value_method(#zv::Value::from(fields))
149                            #into_value_error_transform
150                    },
151                ),
152                Some(_) | None => (
153                    quote! {
154                        let mut fields = #zv::Structure::try_from(value)?.into_fields();
155
156                        ::std::result::Result::Ok(Self {
157                            #(
158                                #field_names: fields.remove(0).downcast()?
159                            ),*
160                        })
161                    },
162                    quote! {
163                        <#value_type>::#into_value_method(#zv::StructureBuilder::new()
164                        #(
165                            .add_field(s.#field_names)
166                        )*
167                        .build())
168                        #into_value_error_transform
169                    },
170                ),
171            };
172            Ok(quote! {
173                impl #impl_generics ::std::convert::TryFrom<#value_type> for #name #ty_generics
174                    #from_value_where_clause
175                {
176                    type Error = #zv::Error;
177
178                    #[inline]
179                    fn try_from(value: #value_type) -> #zv::Result<Self> {
180                        #from_value_impl
181                    }
182                }
183
184                impl #impl_generics #into_value_trait<#name #ty_generics> for #value_type
185                    #into_value_where_clause
186                {
187                    #into_value_error_decl
188
189                    #[inline]
190                    fn #into_value_method(s: #name #ty_generics) -> #into_value_ret {
191                        #into_value_impl
192                    }
193                }
194            })
195        }
196        Fields::Unnamed(_) if fields.iter().next().is_some() => {
197            // Newtype struct.
198            Ok(quote! {
199                impl #impl_generics ::std::convert::TryFrom<#value_type> for #name #ty_generics
200                    #from_value_where_clause
201                {
202                    type Error = #zv::Error;
203
204                    #[inline]
205                    fn try_from(value: #value_type) -> #zv::Result<Self> {
206                        ::std::convert::TryInto::try_into(value).map(Self)
207                    }
208                }
209
210                impl #impl_generics #into_value_trait<#name #ty_generics> for #value_type
211                    #into_value_where_clause
212                {
213                    #into_value_error_decl
214
215                    #[inline]
216                    fn #into_value_method(s: #name #ty_generics) -> #into_value_ret {
217                        <#value_type>::#into_value_method(s.0) #into_value_error_transform
218                    }
219                }
220            })
221        }
222        Fields::Unnamed(_) => panic!("impl_struct must not be called for tuples"),
223        Fields::Unit => panic!("impl_struct must not be called for unit structures"),
224    }
225}
226
227fn impl_enum(
228    value_type: ValueType,
229    name: Ident,
230    _generics: Generics,
231    attrs: Vec<Attribute>,
232    data: &DataEnum,
233    zv: &TokenStream,
234) -> Result<TokenStream, Error> {
235    let repr: TokenStream = match attrs.iter().find(|attr| attr.path().is_ident("repr")) {
236        Some(repr_attr) => repr_attr.parse_args()?,
237        None => quote! { u32 },
238    };
239
240    let mut variant_names = vec![];
241    let mut variant_values = vec![];
242    for variant in &data.variants {
243        // Ensure all variants of the enum are unit type
244        match variant.fields {
245            Fields::Unit => {
246                variant_names.push(&variant.ident);
247                let value = match &variant
248                    .discriminant
249                    .as_ref()
250                    .ok_or_else(|| Error::new(variant.span(), "expected `Name = Value` variants"))?
251                    .1
252                {
253                    Expr::Lit(lit_exp) => &lit_exp.lit,
254                    _ => {
255                        return Err(Error::new(
256                            variant.span(),
257                            "expected `Name = Value` variants",
258                        ))
259                    }
260                };
261                variant_values.push(value);
262            }
263            _ => return Err(Error::new(variant.span(), "must be a unit variant")),
264        }
265    }
266
267    let (value_type, into_value) = match value_type {
268        ValueType::Value => (
269            quote! { #zv::Value<'_> },
270            quote! {
271                impl ::std::convert::From<#name> for #zv::Value<'_> {
272                    #[inline]
273                    fn from(e: #name) -> Self {
274                        let u: #repr = match e {
275                            #(
276                                #name::#variant_names => #variant_values
277                            ),*
278                        };
279
280                        <#zv::Value as ::std::convert::From<_>>::from(u).into()
281                    }
282                }
283            },
284        ),
285        ValueType::OwnedValue => (
286            quote! { #zv::OwnedValue },
287            quote! {
288                impl ::std::convert::TryFrom<#name> for #zv::OwnedValue {
289                    type Error = #zv::Error;
290
291                    #[inline]
292                    fn try_from(e: #name) -> #zv::Result<Self> {
293                        let u: #repr = match e {
294                            #(
295                                #name::#variant_names => #variant_values
296                            ),*
297                        };
298
299                        <#zv::OwnedValue as ::std::convert::TryFrom<_>>::try_from(
300                            <#zv::Value as ::std::convert::From<_>>::from(u)
301                        )
302                    }
303                }
304            },
305        ),
306    };
307
308    Ok(quote! {
309        impl ::std::convert::TryFrom<#value_type> for #name {
310            type Error = #zv::Error;
311
312            #[inline]
313            fn try_from(value: #value_type) -> #zv::Result<Self> {
314                let v: #repr = ::std::convert::TryInto::try_into(value)?;
315
316                ::std::result::Result::Ok(match v {
317                    #(
318                        #variant_values => #name::#variant_names
319                     ),*,
320                    _ => return ::std::result::Result::Err(#zv::Error::IncorrectType),
321                })
322            }
323        }
324
325        #into_value
326    })
327}