zvariant_derive/
dict.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{format_ident, quote, ToTokens};
3use syn::{punctuated::Punctuated, spanned::Spanned, Data, DeriveInput, Error, Field};
4use zvariant_utils::{case, macros};
5
6use crate::utils::*;
7
8fn dict_name_for_field(
9    f: &Field,
10    rename_attr: Option<String>,
11    rename_all_attr: Option<&str>,
12) -> Result<String, Error> {
13    if let Some(name) = rename_attr {
14        Ok(name)
15    } else {
16        let ident = f.ident.as_ref().unwrap().to_string();
17
18        match rename_all_attr {
19            Some("lowercase") => Ok(ident.to_ascii_lowercase()),
20            Some("UPPERCASE") => Ok(ident.to_ascii_uppercase()),
21            Some("PascalCase") => Ok(case::pascal_or_camel_case(&ident, true)),
22            Some("camelCase") => Ok(case::pascal_or_camel_case(&ident, false)),
23            Some("snake_case") => Ok(case::snake_or_kebab_case(&ident, true)),
24            Some("kebab-case") => Ok(case::snake_or_kebab_case(&ident, false)),
25            None => Ok(ident),
26            Some(other) => Err(Error::new(
27                f.span(),
28                format!("invalid `rename_all` attribute value {other}"),
29            )),
30        }
31    }
32}
33
34pub fn expand_serialize_derive(input: DeriveInput) -> Result<TokenStream, Error> {
35    let (name, data) = match input.data {
36        Data::Struct(data) => (input.ident, data),
37        _ => return Err(Error::new(input.span(), "only structs supported")),
38    };
39
40    let StructAttributes { rename_all, .. } = StructAttributes::parse(&input.attrs)?;
41
42    let zv = zvariant_path();
43    let mut entries = quote! {};
44    let mut num_entries: usize = 0;
45
46    for f in &data.fields {
47        let FieldAttributes { rename } = FieldAttributes::parse(&f.attrs)?;
48
49        let name = &f.ident;
50        let dict_name = dict_name_for_field(f, rename, rename_all.as_deref())?;
51
52        let is_option = macros::ty_is_option(&f.ty);
53
54        let e = if is_option {
55            quote! {
56                if self.#name.is_some() {
57                    map.serialize_entry(#dict_name, &#zv::SerializeValue(self.#name.as_ref().unwrap()))?;
58                }
59            }
60        } else {
61            quote! {
62                map.serialize_entry(#dict_name, &#zv::SerializeValue(&self.#name))?;
63            }
64        };
65
66        entries.extend(e);
67        num_entries += 1;
68    }
69
70    let generics = input.generics;
71    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
72
73    let num_entries = num_entries.to_token_stream();
74    Ok(quote! {
75        #[allow(deprecated)]
76        impl #impl_generics #zv::export::serde::ser::Serialize for #name #ty_generics
77        #where_clause
78        {
79            fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error>
80            where
81                S: #zv::export::serde::ser::Serializer,
82            {
83                use #zv::export::serde::ser::SerializeMap;
84
85                // zbus doesn't care about number of entries (it would need bytes instead)
86                let mut map = serializer.serialize_map(::std::option::Option::Some(#num_entries))?;
87                #entries
88                map.end()
89            }
90        }
91    })
92}
93
94pub fn expand_deserialize_derive(input: DeriveInput) -> Result<TokenStream, Error> {
95    let (name, data) = match input.data {
96        Data::Struct(data) => (input.ident, data),
97        _ => return Err(Error::new(input.span(), "only structs supported")),
98    };
99
100    let StructAttributes {
101        rename_all,
102        deny_unknown_fields,
103        ..
104    } = StructAttributes::parse(&input.attrs)?;
105
106    let visitor = format_ident!("{}Visitor", name);
107    let zv = zvariant_path();
108    let mut fields = Vec::new();
109    let mut req_fields = Vec::new();
110    let mut dict_names = Vec::new();
111    let mut entries = Vec::new();
112
113    for f in &data.fields {
114        let FieldAttributes { rename } = FieldAttributes::parse(&f.attrs)?;
115
116        let name = &f.ident;
117        let dict_name = dict_name_for_field(f, rename, rename_all.as_deref())?;
118
119        let is_option = macros::ty_is_option(&f.ty);
120
121        entries.push(quote! {
122            #dict_name => {
123                // FIXME: add an option about strict parsing (instead of silently skipping the field)
124                #name = access.next_value::<#zv::DeserializeValue<_>>().map(|v| v.0).ok();
125            }
126        });
127
128        dict_names.push(dict_name);
129        fields.push(name);
130
131        if !is_option {
132            req_fields.push(name);
133        }
134    }
135
136    let fallback = if deny_unknown_fields {
137        quote! {
138            field => {
139                return ::std::result::Result::Err(
140                    <M::Error as #zv::export::serde::de::Error>::unknown_field(
141                        field,
142                        &[#(#dict_names),*],
143                    ),
144                );
145            }
146        }
147    } else {
148        quote! {
149            unknown => {
150                let _ = access.next_value::<#zv::Value>();
151            }
152        }
153    };
154    entries.push(fallback);
155
156    let (_, ty_generics, _) = input.generics.split_for_impl();
157    let mut generics = input.generics.clone();
158    let def = syn::LifetimeParam {
159        attrs: Vec::new(),
160        lifetime: syn::Lifetime::new("'de", Span::call_site()),
161        colon_token: None,
162        bounds: Punctuated::new(),
163    };
164    generics.params = Some(syn::GenericParam::Lifetime(def))
165        .into_iter()
166        .chain(generics.params)
167        .collect();
168
169    let (impl_generics, _, where_clause) = generics.split_for_impl();
170
171    Ok(quote! {
172        #[allow(deprecated)]
173        impl #impl_generics #zv::export::serde::de::Deserialize<'de> for #name #ty_generics
174        #where_clause
175        {
176            fn deserialize<D>(deserializer: D) -> ::std::result::Result<Self, D::Error>
177            where
178                D: #zv::export::serde::de::Deserializer<'de>,
179            {
180                struct #visitor #ty_generics(::std::marker::PhantomData<#name #ty_generics>);
181
182                impl #impl_generics #zv::export::serde::de::Visitor<'de> for #visitor #ty_generics {
183                    type Value = #name #ty_generics;
184
185                    fn expecting(&self, formatter: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
186                        formatter.write_str("a dictionary")
187                    }
188
189                    fn visit_map<M>(
190                        self,
191                        mut access: M,
192                    ) -> ::std::result::Result<Self::Value, M::Error>
193                    where
194                        M: #zv::export::serde::de::MapAccess<'de>,
195                    {
196                        #( let mut #fields = ::std::default::Default::default(); )*
197
198                        // does not check duplicated fields, since those shouldn't exist in stream
199                        while let ::std::option::Option::Some(key) = access.next_key::<&str>()? {
200                            match key {
201                                #(#entries)*
202                            }
203                        }
204
205                        #(let #req_fields = if let ::std::option::Option::Some(val) = #req_fields {
206                            val
207                        } else {
208                            return ::std::result::Result::Err(
209                                <M::Error as #zv::export::serde::de::Error>::missing_field(
210                                    ::std::stringify!(#req_fields),
211                                ),
212                            );
213                        };)*
214
215                        ::std::result::Result::Ok(#name { #(#fields),* })
216                    }
217                }
218
219
220                deserializer.deserialize_map(#visitor(::std::marker::PhantomData))
221            }
222        }
223    })
224}