zbus/message/builder.rs

1use std::{
2    io::{Cursor, Write},
3    sync::Arc,
4};
5#[cfg(unix)]
6use zvariant::OwnedFd;
7
8use enumflags2::BitFlags;
9use zbus_names::{BusName, ErrorName, InterfaceName, MemberName, UniqueName};
10use zvariant::{serialized, Endian};
11
12use crate::{
13    message::{Field, FieldCode, Fields, Flags, Header, Message, PrimaryHeader, Sequence, Type},
14    utils::padding_for_8_bytes,
15    zvariant::{serialized::Context, DynamicType, ObjectPath, Signature},
16    EndianSig, Error, Result,
17};
18
19use crate::message::{fields::QuickFields, header::MAX_MESSAGE_SIZE};
20
21#[cfg(unix)]
22type BuildGenericResult = Vec<OwnedFd>;
23
24#[cfg(not(unix))]
25type BuildGenericResult = ();
26
// Create a D-Bus serialization `Context` using the endianness currently recorded in
// `$builder`'s primary header, with `$offset` bytes assumed to come before the data.
macro_rules! dbus_context {
    ($builder:ident, $offset:expr) => {{
        let endian_sig = $builder.header.primary().endian_sig();
        Context::new_dbus(endian_sig.into(), $offset)
    }};
}
32
33/// A builder for [`Message`]
34#[derive(Debug, Clone)]
35pub struct Builder<'a> {
36    header: Header<'a>,
37}
38
39impl<'a> Builder<'a> {
40    fn new(msg_type: Type) -> Self {
41        let primary = PrimaryHeader::new(msg_type, 0);
42        let fields = Fields::new();
43        let header = Header::new(primary, fields);
44        Self { header }
45    }
46
47    /// Create a message of type [`Type::MethodCall`].
48    #[deprecated(since = "4.0.0", note = "Please use `Message::method` instead")]
49    pub fn method_call<'p: 'a, 'm: 'a, P, M>(path: P, method_name: M) -> Result<Self>
50    where
51        P: TryInto<ObjectPath<'p>>,
52        M: TryInto<MemberName<'m>>,
53        P::Error: Into<Error>,
54        M::Error: Into<Error>,
55    {
56        Self::new(Type::MethodCall).path(path)?.member(method_name)
57    }
58
59    /// Create a message of type [`Type::Signal`].
60    #[deprecated(since = "4.0.0", note = "Please use `Message::signal` instead")]
61    pub fn signal<'p: 'a, 'i: 'a, 'm: 'a, P, I, M>(path: P, interface: I, name: M) -> Result<Self>
62    where
63        P: TryInto<ObjectPath<'p>>,
64        I: TryInto<InterfaceName<'i>>,
65        M: TryInto<MemberName<'m>>,
66        P::Error: Into<Error>,
67        I::Error: Into<Error>,
68        M::Error: Into<Error>,
69    {
70        Self::new(Type::Signal)
71            .path(path)?
72            .interface(interface)?
73            .member(name)
74    }
75
76    /// Create a message of type [`Type::MethodReturn`].
77    #[deprecated(since = "4.0.0", note = "Please use `Message::method_reply` instead")]
78    pub fn method_return(reply_to: &Header<'_>) -> Result<Self> {
79        Self::new(Type::MethodReturn).reply_to(reply_to)
80    }
81
82    /// Create a message of type [`Type::Error`].
83    #[deprecated(since = "4.0.0", note = "Please use `Message::method_error` instead")]
84    pub fn error<'e: 'a, E>(reply_to: &Header<'_>, name: E) -> Result<Self>
85    where
86        E: TryInto<ErrorName<'e>>,
87        E::Error: Into<Error>,
88    {
89        Self::new(Type::Error).error_name(name)?.reply_to(reply_to)
90    }
91
92    /// Add flags to the message.
93    ///
94    /// See [`Flags`] documentation for the meaning of the flags.
95    ///
96    /// The function will return an error if invalid flags are given for the message type.
97    pub fn with_flags(mut self, flag: Flags) -> Result<Self> {
98        if self.header.message_type() != Type::MethodCall
99            && BitFlags::from_flag(flag).contains(Flags::NoReplyExpected)
100        {
101            return Err(Error::InvalidField);
102        }
103        let flags = self.header.primary().flags() | flag;
104        self.header.primary_mut().set_flags(flags);
105        Ok(self)
106    }
107
108    /// Set the unique name of the sending connection.
109    pub fn sender<'s: 'a, S>(mut self, sender: S) -> Result<Self>
110    where
111        S: TryInto<UniqueName<'s>>,
112        S::Error: Into<Error>,
113    {
114        self.header
115            .fields_mut()
116            .replace(Field::Sender(sender.try_into().map_err(Into::into)?));
117        Ok(self)
118    }
119
120    /// Set the object to send a call to, or the object a signal is emitted from.
121    pub fn path<'p: 'a, P>(mut self, path: P) -> Result<Self>
122    where
123        P: TryInto<ObjectPath<'p>>,
124        P::Error: Into<Error>,
125    {
126        self.header
127            .fields_mut()
128            .replace(Field::Path(path.try_into().map_err(Into::into)?));
129        Ok(self)
130    }
131
132    /// Set the interface to invoke a method call on, or that a signal is emitted from.
133    pub fn interface<'i: 'a, I>(mut self, interface: I) -> Result<Self>
134    where
135        I: TryInto<InterfaceName<'i>>,
136        I::Error: Into<Error>,
137    {
138        self.header
139            .fields_mut()
140            .replace(Field::Interface(interface.try_into().map_err(Into::into)?));
141        Ok(self)
142    }
143
144    /// Set the member, either the method name or signal name.
145    pub fn member<'m: 'a, M>(mut self, member: M) -> Result<Self>
146    where
147        M: TryInto<MemberName<'m>>,
148        M::Error: Into<Error>,
149    {
150        self.header
151            .fields_mut()
152            .replace(Field::Member(member.try_into().map_err(Into::into)?));
153        Ok(self)
154    }
155
156    fn error_name<'e: 'a, E>(mut self, error: E) -> Result<Self>
157    where
158        E: TryInto<ErrorName<'e>>,
159        E::Error: Into<Error>,
160    {
161        self.header
162            .fields_mut()
163            .replace(Field::ErrorName(error.try_into().map_err(Into::into)?));
164        Ok(self)
165    }
166
167    /// Set the name of the connection this message is intended for.
168    pub fn destination<'d: 'a, D>(mut self, destination: D) -> Result<Self>
169    where
170        D: TryInto<BusName<'d>>,
171        D::Error: Into<Error>,
172    {
173        self.header.fields_mut().replace(Field::Destination(
174            destination.try_into().map_err(Into::into)?,
175        ));
176        Ok(self)
177    }
178
179    fn reply_to(mut self, reply_to: &Header<'_>) -> Result<Self> {
180        let serial = reply_to.primary().serial_num();
181        self.header.fields_mut().replace(Field::ReplySerial(serial));
182        self = self.endian(reply_to.primary().endian_sig().into());
183
184        if let Some(sender) = reply_to.sender() {
185            self.destination(sender.to_owned())
186        } else {
187            Ok(self)
188        }
189    }
190
191    /// Set the endianness of the message.
192    ///
193    /// The default endianness is native.
194    pub fn endian(mut self, endian: Endian) -> Self {
195        let sig = EndianSig::from(endian);
196        self.header.primary_mut().set_endian_sig(sig);
197
198        self
199    }
200
201    /// Build the [`Message`] with the given body.
202    ///
203    /// You may pass `()` as the body if the message has no body.
204    ///
205    /// The caller is currently required to ensure that the resulting message contains the headers
206    /// as compliant with the [specification]. Additional checks may be added to this builder over
207    /// time as needed.
208    ///
209    /// [specification]:
210    /// https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-header-fields
211    pub fn build<B>(self, body: &B) -> Result<Message>
212    where
213        B: serde::ser::Serialize + DynamicType,
214    {
215        let ctxt = dbus_context!(self, 0);
216
217        // Note: this iterates the body twice, but we prefer efficient handling of large messages
218        // to efficient handling of ones that are complex to serialize.
219        let body_size = zvariant::serialized_size(ctxt, body)?;
220
221        let signature = body.dynamic_signature();
222
223        self.build_generic(signature, body_size, move |cursor| {
224            // SAFETY: build_generic puts FDs and the body in the same Message.
225            unsafe { zvariant::to_writer(cursor, ctxt, body) }
226                .map(|s| {
227                    #[cfg(unix)]
228                    {
229                        s.into_fds()
230                    }
231                    #[cfg(not(unix))]
232                    {
233                        let _ = s;
234                    }
235                })
236                .map_err(Into::into)
237        })
238    }
239
240    /// Create a new message from a raw slice of bytes to populate the body with, rather than by
241    /// serializing a value. The message body will be the exact bytes.
242    ///
243    /// # Safety
244    ///
245    /// This method is unsafe because it can be used to build an invalid message.
246    pub unsafe fn build_raw_body<'b, S>(
247        self,
248        body_bytes: &[u8],
249        signature: S,
250        #[cfg(unix)] fds: Vec<OwnedFd>,
251    ) -> Result<Message>
252    where
253        S: TryInto<Signature<'b>>,
254        S::Error: Into<Error>,
255    {
256        let signature: Signature<'b> = signature.try_into().map_err(Into::into)?;
257        let body_size = serialized::Size::new(body_bytes.len(), dbus_context!(self, 0));
258        #[cfg(unix)]
259        let body_size = {
260            let num_fds = fds.len().try_into().map_err(|_| Error::ExcessData)?;
261            body_size.set_num_fds(num_fds)
262        };
263
264        self.build_generic(
265            signature,
266            body_size,
267            move |cursor: &mut Cursor<&mut Vec<u8>>| {
268                cursor.write_all(body_bytes)?;
269
270                #[cfg(unix)]
271                return Ok::<Vec<OwnedFd>, Error>(fds);
272
273                #[cfg(not(unix))]
274                return Ok::<(), Error>(());
275            },
276        )
277    }
278
279    fn build_generic<WriteFunc>(
280        self,
281        mut signature: Signature<'_>,
282        body_size: serialized::Size,
283        write_body: WriteFunc,
284    ) -> Result<Message>
285    where
286        WriteFunc: FnOnce(&mut Cursor<&mut Vec<u8>>) -> Result<BuildGenericResult>,
287    {
288        let ctxt = dbus_context!(self, 0);
289        let mut header = self.header;
290
291        if !signature.is_empty() {
292            if signature.starts_with(zvariant::STRUCT_SIG_START_STR) {
293                // Remove leading and trailing STRUCT delimiters
294                signature = signature.slice(1..signature.len() - 1);
295            }
296            header.fields_mut().add(Field::Signature(signature));
297        }
298
299        let body_len_u32 = body_size.size().try_into().map_err(|_| Error::ExcessData)?;
300        header.primary_mut().set_body_len(body_len_u32);
301
302        #[cfg(unix)]
303        {
304            let fds_len = body_size.num_fds();
305            if fds_len != 0 {
306                header.fields_mut().add(Field::UnixFDs(fds_len));
307            }
308        }
309
310        let hdr_len = *zvariant::serialized_size(ctxt, &header)?;
311        // We need to align the body to 8-byte boundary.
312        let body_padding = padding_for_8_bytes(hdr_len);
313        let body_offset = hdr_len + body_padding;
314        let total_len = body_offset + body_size.size();
315        if total_len > MAX_MESSAGE_SIZE {
316            return Err(Error::ExcessData);
317        }
318        let mut bytes: Vec<u8> = Vec::with_capacity(total_len);
319        let mut cursor = Cursor::new(&mut bytes);
320
321        // SAFETY: There are no FDs involved.
322        unsafe { zvariant::to_writer(&mut cursor, ctxt, &header) }?;
323        for _ in 0..body_padding {
324            cursor.write_all(&[0u8])?;
325        }
326        #[cfg(unix)]
327        let fds: Vec<_> = write_body(&mut cursor)?.into_iter().collect();
328        #[cfg(not(unix))]
329        write_body(&mut cursor)?;
330
331        let primary_header = header.into_primary();
332        #[cfg(unix)]
333        let bytes = serialized::Data::new_fds(bytes, ctxt, fds);
334        #[cfg(not(unix))]
335        let bytes = serialized::Data::new(bytes, ctxt);
336        let (header, actual_hdr_len): (Header<'_>, _) = bytes.deserialize()?;
337        assert_eq!(hdr_len, actual_hdr_len);
338        let quick_fields = QuickFields::new(&bytes, &header)?;
339
340        Ok(Message {
341            inner: Arc::new(super::Inner {
342                primary_header,
343                quick_fields,
344                bytes,
345                body_offset,
346                recv_seq: Sequence::default(),
347            }),
348        })
349    }
350}
351
352impl<'m> From<Header<'m>> for Builder<'m> {
353    fn from(mut header: Header<'m>) -> Self {
354        // Signature and Fds are added by body* methods.
355        let fields = header.fields_mut();
356        fields.remove(FieldCode::Signature);
357        fields.remove(FieldCode::UnixFDs);
358
359        Self { header }
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::Message;
    use crate::Error;
    use test_log::test;
    use zvariant::Endian;

    /// A body supplied as raw bytes through `build_raw_body` must deserialize back into the
    /// values those bytes encode.
    #[test]
    fn test_raw() -> Result<(), Error> {
        // `ai` value `[1, 2, 3, 4]` encoded little-endian: the array's byte length (16) as a
        // u32, followed by the four i32 elements.
        let raw_body: &[u8] = &[16, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0];
        // The builder defaults to native endianness. Pin it to little-endian so the
        // hand-encoded bytes above are also interpreted correctly on big-endian hosts.
        let message_builder = Message::signal("/", "test.test", "test")?.endian(Endian::Little);
        // SAFETY: `raw_body` is a valid little-endian encoding of signature `ai` and no FDs
        // are attached.
        let message = unsafe {
            message_builder.build_raw_body(
                raw_body,
                "ai",
                #[cfg(unix)]
                vec![],
            )?
        };

        let output: Vec<i32> = message.body().deserialize()?;
        assert_eq!(output, vec![1, 2, 3, 4]);

        Ok(())
    }
}