zbus/connection/socket/
mod.rs

1#[cfg(feature = "p2p")]
2pub mod channel;
3#[cfg(feature = "p2p")]
4pub use channel::Channel;
5
6mod split;
7pub use split::{BoxedSplit, Split};
8
9mod tcp;
10mod unix;
11mod vsock;
12
13#[cfg(not(feature = "tokio"))]
14use async_io::Async;
15#[cfg(not(feature = "tokio"))]
16use std::sync::Arc;
17use std::{io, mem};
18use tracing::trace;
19
20use crate::{
21    fdo::ConnectionCredentials,
22    message::{
23        header::{MAX_MESSAGE_SIZE, MIN_MESSAGE_SIZE},
24        PrimaryHeader,
25    },
26    padding_for_8_bytes, Message,
27};
28#[cfg(unix)]
29use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
30use zvariant::{
31    serialized::{self, Context},
32    Endian,
33};
34
/// Outcome of [`ReadHalf::recvmsg`] on unix: the number of bytes read, paired with any file
/// descriptors that arrived together with those bytes.
#[cfg(unix)]
type RecvmsgResult = std::io::Result<(usize, Vec<std::os::fd::OwnedFd>)>;

/// Outcome of [`ReadHalf::recvmsg`] on non-unix targets: only the number of bytes read, as file
/// descriptors are only handled on unix.
#[cfg(not(unix))]
type RecvmsgResult = std::io::Result<usize>;
40
41/// Trait representing some transport layer over which the DBus protocol can be used
42///
43/// In order to allow simultaneous reading and writing, this trait requires you to split the socket
44/// into a read half and a write half. The reader and writer halves can be any types that implement
45/// [`ReadHalf`] and [`WriteHalf`] respectively.
46///
47/// The crate provides implementations for `async_io` and `tokio`'s `UnixStream` wrappers if you
48/// enable the corresponding crate features (`async_io` is enabled by default).
49///
50/// You can implement it manually to integrate with other runtimes or other dbus transports.  Feel
51/// free to submit pull requests to add support for more runtimes to zbus itself so rust's orphan
52/// rules don't force the use of a wrapper struct (and to avoid duplicating the work across many
53/// projects).
54pub trait Socket {
55    type ReadHalf: ReadHalf;
56    type WriteHalf: WriteHalf;
57
58    /// Split the socket into a read half and a write half.
59    fn split(self) -> Split<Self::ReadHalf, Self::WriteHalf>
60    where
61        Self: Sized;
62}
63
64/// The read half of a socket.
65///
66/// See [`Socket`] for more details.
67#[async_trait::async_trait]
68pub trait ReadHalf: std::fmt::Debug + Send + Sync + 'static {
69    /// Receive a message on the socket.
70    ///
71    /// This is the higher-level method to receive a full D-Bus message.
72    ///
73    /// The default implementation uses `recvmsg` to receive the message. Implementers should
74    /// override either this or `recvmsg`. Note that if you override this method, zbus will not be
75    /// able perform an authentication handshake and hence will skip the handshake. Therefore your
76    /// implementation will only be useful for pre-authenticated connections or connections that do
77    /// not require authentication.
78    ///
79    /// # Parameters
80    ///
81    /// - `seq`: The sequence number of the message. The returned message should have this sequence.
82    /// - `already_received_bytes`: Sometimes, zbus already received some bytes from the socket
83    ///   belonging to the first message(s) (as part of the connection handshake process). This is
84    ///   the buffer containing those bytes (if any). If you're implementing this method, most
85    ///   likely you can safely ignore this parameter.
86    /// - `already_received_fds`: Same goes for file descriptors belonging to first messages.
87    async fn receive_message(
88        &mut self,
89        seq: u64,
90        already_received_bytes: &mut Vec<u8>,
91        #[cfg(unix)] already_received_fds: &mut Vec<std::os::fd::OwnedFd>,
92    ) -> crate::Result<Message> {
93        #[cfg(unix)]
94        let mut fds = vec![];
95        let mut bytes = if already_received_bytes.len() < MIN_MESSAGE_SIZE {
96            let mut bytes = vec![];
97            if !already_received_bytes.is_empty() {
98                mem::swap(already_received_bytes, &mut bytes);
99            }
100            let mut pos = bytes.len();
101            bytes.resize(MIN_MESSAGE_SIZE, 0);
102            // We don't have enough data to make a proper message header yet.
103            // Some partial read may be in raw_in_buffer, so we try to complete it
104            // until we have MIN_MESSAGE_SIZE bytes
105            //
106            // Given that MIN_MESSAGE_SIZE is 16, this codepath is actually extremely unlikely
107            // to be taken more than once
108            while pos < MIN_MESSAGE_SIZE {
109                let res = self.recvmsg(&mut bytes[pos..]).await?;
110                let len = {
111                    #[cfg(unix)]
112                    {
113                        fds.extend(res.1);
114                        res.0
115                    }
116                    #[cfg(not(unix))]
117                    {
118                        res
119                    }
120                };
121                pos += len;
122                if len == 0 {
123                    return Err(std::io::Error::new(
124                        std::io::ErrorKind::UnexpectedEof,
125                        "failed to receive message",
126                    )
127                    .into());
128                }
129            }
130
131            bytes
132        } else {
133            already_received_bytes.drain(..MIN_MESSAGE_SIZE).collect()
134        };
135
136        let (primary_header, fields_len) = PrimaryHeader::read(&bytes)?;
137        let header_len = MIN_MESSAGE_SIZE + fields_len as usize;
138        let body_padding = padding_for_8_bytes(header_len);
139        let body_len = primary_header.body_len() as usize;
140        let total_len = header_len + body_padding + body_len;
141        if total_len > MAX_MESSAGE_SIZE {
142            return Err(crate::Error::ExcessData);
143        }
144
145        // By this point we have a full primary header, so we know the exact length of the complete
146        // message.
147        if !already_received_bytes.is_empty() {
148            // still have some bytes buffered.
149            let pending = total_len - bytes.len();
150            let to_take = std::cmp::min(pending, already_received_bytes.len());
151            bytes.extend(already_received_bytes.drain(..to_take));
152        }
153        let mut pos = bytes.len();
154        bytes.resize(total_len, 0);
155
156        // Read the rest, if any
157        while pos < total_len {
158            let res = self.recvmsg(&mut bytes[pos..]).await?;
159            let read = {
160                #[cfg(unix)]
161                {
162                    fds.extend(res.1);
163                    res.0
164                }
165                #[cfg(not(unix))]
166                {
167                    res
168                }
169            };
170            pos += read;
171            if read == 0 {
172                return Err(crate::Error::InputOutput(
173                    std::io::Error::new(
174                        std::io::ErrorKind::UnexpectedEof,
175                        "failed to receive message",
176                    )
177                    .into(),
178                ));
179            }
180        }
181
182        // If we reach here, the message is complete; return it
183        let endian = Endian::from(primary_header.endian_sig());
184
185        #[cfg(unix)]
186        if !already_received_fds.is_empty() {
187            use crate::message::{header::PRIMARY_HEADER_SIZE, Field};
188
189            let ctxt = Context::new_dbus(endian, PRIMARY_HEADER_SIZE);
190            let encoded_fields =
191                serialized::Data::new(&bytes[PRIMARY_HEADER_SIZE..header_len], ctxt);
192            let fields: crate::message::Fields<'_> = encoded_fields.deserialize()?.0;
193            let num_required_fds = match fields.get_field(crate::message::FieldCode::UnixFDs) {
194                Some(Field::UnixFDs(num_fds)) => *num_fds as usize,
195                _ => 0,
196            };
197            let num_pending = num_required_fds
198                .checked_sub(fds.len())
199                .ok_or_else(|| crate::Error::ExcessData)?;
200            // If we had previously received FDs, `num_pending` has to be > 0
201            if num_pending == 0 {
202                return Err(crate::Error::MissingParameter("Missing file descriptors"));
203            }
204            // All previously received FDs must go first in the list.
205            let mut already_received: Vec<_> = already_received_fds.drain(..num_pending).collect();
206            mem::swap(&mut already_received, &mut fds);
207            fds.extend(already_received);
208        }
209
210        let ctxt = Context::new_dbus(endian, 0);
211        #[cfg(unix)]
212        let bytes = serialized::Data::new_fds(bytes, ctxt, fds);
213        #[cfg(not(unix))]
214        let bytes = serialized::Data::new(bytes, ctxt);
215        Message::from_raw_parts(bytes, seq)
216    }
217
218    /// Attempt to receive bytes from the socket.
219    ///
220    /// On success, returns the number of bytes read as well as a `Vec` containing
221    /// any associated file descriptors.
222    ///
223    /// The default implementation simply panics. Implementers must override either `read_message`
224    /// or this method.
225    async fn recvmsg(&mut self, _buf: &mut [u8]) -> RecvmsgResult {
226        unimplemented!("`ReadHalf` implementers must either override `read_message` or `recvmsg`");
227    }
228
229    /// Supports passing file descriptors.
230    ///
231    /// Default implementation returns `false`.
232    fn can_pass_unix_fd(&self) -> bool {
233        false
234    }
235
236    /// Return the peer credentials.
237    async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
238        Ok(ConnectionCredentials::default())
239    }
240}
241
242/// The write half of a socket.
243///
244/// See [`Socket`] for more details.
245#[async_trait::async_trait]
246pub trait WriteHalf: std::fmt::Debug + Send + Sync + 'static {
247    /// Send a message on the socket.
248    ///
249    /// This is the higher-level method to send a full D-Bus message.
250    ///
251    /// The default implementation uses `sendmsg` to send the message. Implementers should override
252    /// either this or `sendmsg`.
253    async fn send_message(&mut self, msg: &Message) -> crate::Result<()> {
254        let data = msg.data();
255        let serial = msg.primary_header().serial_num();
256
257        trace!("Sending message: {:?}", msg);
258        let mut pos = 0;
259        while pos < data.len() {
260            #[cfg(unix)]
261            let fds = if pos == 0 {
262                data.fds().iter().map(|f| f.as_fd()).collect()
263            } else {
264                vec![]
265            };
266            pos += self
267                .sendmsg(
268                    &data[pos..],
269                    #[cfg(unix)]
270                    &fds,
271                )
272                .await?;
273        }
274        trace!("Sent message with serial: {}", serial);
275
276        Ok(())
277    }
278
279    /// Attempt to send a message on the socket
280    ///
281    /// On success, return the number of bytes written. There may be a partial write, in
282    /// which case the caller is responsible of sending the remaining data by calling this
283    /// method again until everything is written or it returns an error of kind `WouldBlock`.
284    ///
285    /// If at least one byte has been written, then all the provided file descriptors will
286    /// have been sent as well, and should not be provided again in subsequent calls.
287    ///
288    /// If the underlying transport does not support transmitting file descriptors, this
289    /// will return `Err(ErrorKind::InvalidInput)`.
290    ///
291    /// The default implementation simply panics. Implementers must override either `send_message`
292    /// or this method.
293    async fn sendmsg(
294        &mut self,
295        _buffer: &[u8],
296        #[cfg(unix)] _fds: &[BorrowedFd<'_>],
297    ) -> io::Result<usize> {
298        unimplemented!("`WriteHalf` implementers must either override `send_message` or `sendmsg`");
299    }
300
301    /// The dbus daemon on `freebsd` and `dragonfly` currently requires sending the zero byte
302    /// as a separate message with SCM_CREDS, as part of the `EXTERNAL` authentication on unix
303    /// sockets. This method is used by the authentication machinery in zbus to send this
304    /// zero byte. Socket implementations based on unix sockets should implement this method.
305    #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
306    async fn send_zero_byte(&mut self) -> io::Result<Option<usize>> {
307        Ok(None)
308    }
309
310    /// Close the socket.
311    ///
312    /// After this call, it is valid for all reading and writing operations to fail.
313    async fn close(&mut self) -> io::Result<()>;
314
315    /// Supports passing file descriptors.
316    ///
317    /// Default implementation returns `false`.
318    fn can_pass_unix_fd(&self) -> bool {
319        false
320    }
321
322    /// Return the peer credentials.
323    async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
324        Ok(ConnectionCredentials::default())
325    }
326}
327
328#[async_trait::async_trait]
329impl ReadHalf for Box<dyn ReadHalf> {
330    fn can_pass_unix_fd(&self) -> bool {
331        (**self).can_pass_unix_fd()
332    }
333
334    async fn receive_message(
335        &mut self,
336        seq: u64,
337        already_received_bytes: &mut Vec<u8>,
338        #[cfg(unix)] already_received_fds: &mut Vec<std::os::fd::OwnedFd>,
339    ) -> crate::Result<Message> {
340        (**self)
341            .receive_message(
342                seq,
343                already_received_bytes,
344                #[cfg(unix)]
345                already_received_fds,
346            )
347            .await
348    }
349
350    async fn recvmsg(&mut self, buf: &mut [u8]) -> RecvmsgResult {
351        (**self).recvmsg(buf).await
352    }
353
354    async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
355        (**self).peer_credentials().await
356    }
357}
358
359#[async_trait::async_trait]
360impl WriteHalf for Box<dyn WriteHalf> {
361    async fn send_message(&mut self, msg: &Message) -> crate::Result<()> {
362        (**self).send_message(msg).await
363    }
364
365    async fn sendmsg(
366        &mut self,
367        buffer: &[u8],
368        #[cfg(unix)] fds: &[BorrowedFd<'_>],
369    ) -> io::Result<usize> {
370        (**self)
371            .sendmsg(
372                buffer,
373                #[cfg(unix)]
374                fds,
375            )
376            .await
377    }
378
379    #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
380    async fn send_zero_byte(&mut self) -> io::Result<Option<usize>> {
381        (**self).send_zero_byte().await
382    }
383
384    async fn close(&mut self) -> io::Result<()> {
385        (**self).close().await
386    }
387
388    fn can_pass_unix_fd(&self) -> bool {
389        (**self).can_pass_unix_fd()
390    }
391
392    async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
393        (**self).peer_credentials().await
394    }
395}
396
397#[cfg(not(feature = "tokio"))]
398impl<T> Socket for Async<T>
399where
400    T: std::fmt::Debug + Send + Sync,
401    Arc<Async<T>>: ReadHalf + WriteHalf,
402{
403    type ReadHalf = Arc<Async<T>>;
404    type WriteHalf = Arc<Async<T>>;
405
406    fn split(self) -> Split<Self::ReadHalf, Self::WriteHalf> {
407        let arc = Arc::new(self);
408
409        Split {
410            read: arc.clone(),
411            write: arc,
412        }
413    }
414}