zbus_macros/
error.rs

1use proc_macro2::TokenStream;
2use quote::{quote, ToTokens};
3use syn::{spanned::Spanned, Data, DeriveInput, Error, Fields, Ident, Variant};
4use zvariant_utils::def_attrs;
5
// Attribute parsers generated by `def_attrs!`. The parsed values are consumed by
// `expand_derive` in this file.
def_attrs! {
    // NOTE(review): presumably selects the `#[zbus(...)]` attribute namespace; confirm against
    // the `def_attrs!` documentation.
    crate zbus;

    // Attributes read from the derive input itself (the enum).
    pub StructAttributes("struct") {
        // Prefix joined with each variant's name (`{prefix}.{name}`) to form its error name;
        // `expand_derive` falls back to "org.freedesktop.DBus" when it is not given.
        prefix str,
        // Whether a `Display` impl is generated; `expand_derive` treats "not given" as `true`.
        impl_display bool
    };

    // Attributes read from each enum variant.
    pub VariantAttributes("enum variant") {
        // Name part of the error name; `expand_derive` falls back to the variant identifier.
        name str,
        // Flags the variant that wraps a `zbus::Error`; `expand_derive` accepts at most one.
        error none
    };
}
19
20use crate::utils::*;
21
22pub fn expand_derive(input: DeriveInput) -> Result<TokenStream, Error> {
23    let StructAttributes {
24        prefix,
25        impl_display,
26    } = StructAttributes::parse(&input.attrs)?;
27    let prefix = prefix.unwrap_or_else(|| "org.freedesktop.DBus".to_string());
28    let generate_display = impl_display.unwrap_or(true);
29
30    let (_vis, name, _generics, data) = match input.data {
31        Data::Enum(data) => (input.vis, input.ident, input.generics, data),
32        _ => return Err(Error::new(input.span(), "only enums supported")),
33    };
34
35    let zbus = zbus_path();
36    let mut replies = quote! {};
37    let mut error_names = quote! {};
38    let mut error_descriptions = quote! {};
39    let mut error_converts = quote! {};
40
41    let mut zbus_error_variant = None;
42
43    for variant in data.variants {
44        let VariantAttributes { name, error } = VariantAttributes::parse(&variant.attrs)?;
45        let ident = &variant.ident;
46        let name = name.unwrap_or_else(|| ident.to_string());
47
48        let fqn = if !error {
49            format!("{prefix}.{name}")
50        } else {
51            // The ZBus error variant will always be a hardcoded string.
52            String::from("org.freedesktop.zbus.Error")
53        };
54
55        let error_name = quote! {
56            #zbus::names::ErrorName::from_static_str_unchecked(#fqn)
57        };
58        let e = match variant.fields {
59            Fields::Unit => quote! {
60                Self::#ident => #error_name,
61            },
62            Fields::Unnamed(_) => quote! {
63                Self::#ident(..) => #error_name,
64            },
65            Fields::Named(_) => quote! {
66                Self::#ident { .. } => #error_name,
67            },
68        };
69        error_names.extend(e);
70
71        if error {
72            if zbus_error_variant.is_some() {
73                panic!("More than 1 `#[zbus(error)]` variant found");
74            }
75
76            zbus_error_variant = Some(quote! { #ident });
77        }
78
79        // FIXME: this will error if the first field is not a string as per the dbus spec, but we
80        // may support other cases?
81        let e = match &variant.fields {
82            Fields::Unit => quote! {
83                Self::#ident => None,
84            },
85            Fields::Unnamed(_) => {
86                if error {
87                    quote! {
88                        Self::#ident(#zbus::Error::MethodError(_, desc, _)) => desc.as_deref(),
89                        Self::#ident(_) => None,
90                    }
91                } else {
92                    quote! {
93                        Self::#ident(desc, ..) => Some(&desc),
94                    }
95                }
96            }
97            Fields::Named(n) => {
98                let f = &n
99                    .named
100                    .first()
101                    .ok_or_else(|| Error::new(n.span(), "expected at least one field"))?
102                    .ident;
103                quote! {
104                    Self::#ident { #f, } => Some(#f),
105                }
106            }
107        };
108        error_descriptions.extend(e);
109
110        // The conversion for #[zbus(error)] variant is handled separately/explicitly.
111        if !error {
112            // FIXME: deserialize msg to error field instead, to support variable args
113            let e = match &variant.fields {
114                Fields::Unit => quote! {
115                    #fqn => Self::#ident,
116                },
117                Fields::Unnamed(_) => quote! {
118                    #fqn => { Self::#ident(::std::clone::Clone::clone(desc).unwrap_or_default()) },
119                },
120                Fields::Named(n) => {
121                    let f = &n
122                        .named
123                        .first()
124                        .ok_or_else(|| Error::new(n.span(), "expected at least one field"))?
125                        .ident;
126                    quote! {
127                        #fqn => {
128                            let desc = ::std::clone::Clone::clone(desc).unwrap_or_default();
129
130                            Self::#ident { #f: desc }
131                        }
132                    }
133                }
134            };
135            error_converts.extend(e);
136        }
137
138        let r = gen_reply_for_variant(&variant, error)?;
139        replies.extend(r);
140    }
141
142    let from_zbus_error_impl = zbus_error_variant
143        .map(|ident| {
144            quote! {
145                impl ::std::convert::From<#zbus::Error> for #name {
146                    fn from(value: #zbus::Error) -> #name {
147                        if let #zbus::Error::MethodError(name, desc, _) = &value {
148                            match name.as_str() {
149                                #error_converts
150                                _ => Self::#ident(value),
151                            }
152                        } else {
153                            Self::#ident(value)
154                        }
155                    }
156                }
157            }
158        })
159        .unwrap_or_default();
160
161    let display_impl = if generate_display {
162        quote! {
163            impl ::std::fmt::Display for #name {
164                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
165                    let name = #zbus::DBusError::name(self);
166                    let description = #zbus::DBusError::description(self).unwrap_or("no description");
167                    ::std::write!(f, "{}: {}", name, description)
168                }
169            }
170        }
171    } else {
172        quote! {}
173    };
174
175    Ok(quote! {
176        impl #zbus::DBusError for #name {
177            fn name(&self) -> #zbus::names::ErrorName {
178                match self {
179                    #error_names
180                }
181            }
182
183            fn description(&self) -> Option<&str> {
184                match self {
185                    #error_descriptions
186                }
187            }
188
189            fn create_reply(&self, call: &#zbus::message::Header) -> #zbus::Result<#zbus::message::Message> {
190                let name = self.name();
191                match self {
192                    #replies
193                }
194            }
195        }
196
197        #display_impl
198
199        impl ::std::error::Error for #name {}
200
201        #from_zbus_error_impl
202    })
203}
204
205fn gen_reply_for_variant(
206    variant: &Variant,
207    zbus_error_variant: bool,
208) -> Result<TokenStream, Error> {
209    let zbus = zbus_path();
210    let ident = &variant.ident;
211    match &variant.fields {
212        Fields::Unit => Ok(quote! {
213            Self::#ident => #zbus::message::Builder::error(call, name)?.build(&()),
214        }),
215        Fields::Unnamed(f) => {
216            // Name the unnamed fields as the number of the field with an 'f' in front.
217            let in_fields = (0..f.unnamed.len())
218                .map(|n| Ident::new(&format!("f{n}"), ident.span()).to_token_stream())
219                .collect::<Vec<_>>();
220            let out_fields = if zbus_error_variant {
221                let error_field = in_fields.first().ok_or_else(|| {
222                    Error::new(
223                        ident.span(),
224                        "expected at least one field for #[zbus(error)] variant",
225                    )
226                })?;
227                vec![quote! {
228                    match #error_field {
229                        #zbus::Error::MethodError(name, desc, _) => {
230                            ::std::clone::Clone::clone(desc)
231                        }
232                        _ => None,
233                    }
234                    .unwrap_or_else(|| ::std::string::ToString::to_string(#error_field))
235                }]
236            } else {
237                // FIXME: Workaround for https://github.com/rust-lang/rust-clippy/issues/10577
238                #[allow(clippy::redundant_clone)]
239                in_fields.clone()
240            };
241
242            Ok(quote! {
243                Self::#ident(#(#in_fields),*) => #zbus::message::Builder::error(call, name)?.build(&(#(#out_fields),*)),
244            })
245        }
246        Fields::Named(f) => {
247            let fields = f.named.iter().map(|v| v.ident.as_ref()).collect::<Vec<_>>();
248            Ok(quote! {
249                Self::#ident { #(#fields),* } => #zbus::message::Builder::error(call, name)?.build(&(#(#fields),*)),
250            })
251        }
252    }
253}