// zbus/connection/socket/mod.rs
#[cfg(feature = "p2p")]
2pub mod channel;
3#[cfg(feature = "p2p")]
4pub use channel::Channel;
5
6mod split;
7pub use split::{BoxedSplit, Split};
8
9mod tcp;
10mod unix;
11mod vsock;
12
13#[cfg(not(feature = "tokio"))]
14use async_io::Async;
15#[cfg(not(feature = "tokio"))]
16use std::sync::Arc;
17use std::{io, mem};
18use tracing::trace;
19
20use crate::{
21 fdo::ConnectionCredentials,
22 message::{
23 header::{MAX_MESSAGE_SIZE, MIN_MESSAGE_SIZE},
24 PrimaryHeader,
25 },
26 padding_for_8_bytes, Message,
27};
28#[cfg(unix)]
29use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
30use zvariant::{
31 serialized::{self, Context},
32 Endian,
33};
34
/// Outcome of one [`ReadHalf::recvmsg`] call on Unix: the number of bytes read, paired with
/// any file descriptors that were received alongside those bytes.
#[cfg(unix)]
type RecvmsgResult = std::io::Result<(usize, Vec<std::os::fd::OwnedFd>)>;

/// Outcome of one [`ReadHalf::recvmsg`] call on non-Unix targets: only the number of bytes
/// read, as no file descriptors are carried.
#[cfg(not(unix))]
type RecvmsgResult = std::io::Result<usize>;
40
41pub trait Socket {
55 type ReadHalf: ReadHalf;
56 type WriteHalf: WriteHalf;
57
58 fn split(self) -> Split<Self::ReadHalf, Self::WriteHalf>
60 where
61 Self: Sized;
62}
63
64#[async_trait::async_trait]
68pub trait ReadHalf: std::fmt::Debug + Send + Sync + 'static {
69 async fn receive_message(
88 &mut self,
89 seq: u64,
90 already_received_bytes: &mut Vec<u8>,
91 #[cfg(unix)] already_received_fds: &mut Vec<std::os::fd::OwnedFd>,
92 ) -> crate::Result<Message> {
93 #[cfg(unix)]
94 let mut fds = vec![];
95 let mut bytes = if already_received_bytes.len() < MIN_MESSAGE_SIZE {
96 let mut bytes = vec![];
97 if !already_received_bytes.is_empty() {
98 mem::swap(already_received_bytes, &mut bytes);
99 }
100 let mut pos = bytes.len();
101 bytes.resize(MIN_MESSAGE_SIZE, 0);
102 while pos < MIN_MESSAGE_SIZE {
109 let res = self.recvmsg(&mut bytes[pos..]).await?;
110 let len = {
111 #[cfg(unix)]
112 {
113 fds.extend(res.1);
114 res.0
115 }
116 #[cfg(not(unix))]
117 {
118 res
119 }
120 };
121 pos += len;
122 if len == 0 {
123 return Err(std::io::Error::new(
124 std::io::ErrorKind::UnexpectedEof,
125 "failed to receive message",
126 )
127 .into());
128 }
129 }
130
131 bytes
132 } else {
133 already_received_bytes.drain(..MIN_MESSAGE_SIZE).collect()
134 };
135
136 let (primary_header, fields_len) = PrimaryHeader::read(&bytes)?;
137 let header_len = MIN_MESSAGE_SIZE + fields_len as usize;
138 let body_padding = padding_for_8_bytes(header_len);
139 let body_len = primary_header.body_len() as usize;
140 let total_len = header_len + body_padding + body_len;
141 if total_len > MAX_MESSAGE_SIZE {
142 return Err(crate::Error::ExcessData);
143 }
144
145 if !already_received_bytes.is_empty() {
148 let pending = total_len - bytes.len();
150 let to_take = std::cmp::min(pending, already_received_bytes.len());
151 bytes.extend(already_received_bytes.drain(..to_take));
152 }
153 let mut pos = bytes.len();
154 bytes.resize(total_len, 0);
155
156 while pos < total_len {
158 let res = self.recvmsg(&mut bytes[pos..]).await?;
159 let read = {
160 #[cfg(unix)]
161 {
162 fds.extend(res.1);
163 res.0
164 }
165 #[cfg(not(unix))]
166 {
167 res
168 }
169 };
170 pos += read;
171 if read == 0 {
172 return Err(crate::Error::InputOutput(
173 std::io::Error::new(
174 std::io::ErrorKind::UnexpectedEof,
175 "failed to receive message",
176 )
177 .into(),
178 ));
179 }
180 }
181
182 let endian = Endian::from(primary_header.endian_sig());
184
185 #[cfg(unix)]
186 if !already_received_fds.is_empty() {
187 use crate::message::{header::PRIMARY_HEADER_SIZE, Field};
188
189 let ctxt = Context::new_dbus(endian, PRIMARY_HEADER_SIZE);
190 let encoded_fields =
191 serialized::Data::new(&bytes[PRIMARY_HEADER_SIZE..header_len], ctxt);
192 let fields: crate::message::Fields<'_> = encoded_fields.deserialize()?.0;
193 let num_required_fds = match fields.get_field(crate::message::FieldCode::UnixFDs) {
194 Some(Field::UnixFDs(num_fds)) => *num_fds as usize,
195 _ => 0,
196 };
197 let num_pending = num_required_fds
198 .checked_sub(fds.len())
199 .ok_or_else(|| crate::Error::ExcessData)?;
200 if num_pending == 0 {
202 return Err(crate::Error::MissingParameter("Missing file descriptors"));
203 }
204 let mut already_received: Vec<_> = already_received_fds.drain(..num_pending).collect();
206 mem::swap(&mut already_received, &mut fds);
207 fds.extend(already_received);
208 }
209
210 let ctxt = Context::new_dbus(endian, 0);
211 #[cfg(unix)]
212 let bytes = serialized::Data::new_fds(bytes, ctxt, fds);
213 #[cfg(not(unix))]
214 let bytes = serialized::Data::new(bytes, ctxt);
215 Message::from_raw_parts(bytes, seq)
216 }
217
218 async fn recvmsg(&mut self, _buf: &mut [u8]) -> RecvmsgResult {
226 unimplemented!("`ReadHalf` implementers must either override `read_message` or `recvmsg`");
227 }
228
229 fn can_pass_unix_fd(&self) -> bool {
233 false
234 }
235
236 async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
238 Ok(ConnectionCredentials::default())
239 }
240}
241
242#[async_trait::async_trait]
246pub trait WriteHalf: std::fmt::Debug + Send + Sync + 'static {
247 async fn send_message(&mut self, msg: &Message) -> crate::Result<()> {
254 let data = msg.data();
255 let serial = msg.primary_header().serial_num();
256
257 trace!("Sending message: {:?}", msg);
258 let mut pos = 0;
259 while pos < data.len() {
260 #[cfg(unix)]
261 let fds = if pos == 0 {
262 data.fds().iter().map(|f| f.as_fd()).collect()
263 } else {
264 vec![]
265 };
266 pos += self
267 .sendmsg(
268 &data[pos..],
269 #[cfg(unix)]
270 &fds,
271 )
272 .await?;
273 }
274 trace!("Sent message with serial: {}", serial);
275
276 Ok(())
277 }
278
279 async fn sendmsg(
294 &mut self,
295 _buffer: &[u8],
296 #[cfg(unix)] _fds: &[BorrowedFd<'_>],
297 ) -> io::Result<usize> {
298 unimplemented!("`WriteHalf` implementers must either override `send_message` or `sendmsg`");
299 }
300
301 #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
306 async fn send_zero_byte(&mut self) -> io::Result<Option<usize>> {
307 Ok(None)
308 }
309
310 async fn close(&mut self) -> io::Result<()>;
314
315 fn can_pass_unix_fd(&self) -> bool {
319 false
320 }
321
322 async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
324 Ok(ConnectionCredentials::default())
325 }
326}
327
328#[async_trait::async_trait]
329impl ReadHalf for Box<dyn ReadHalf> {
330 fn can_pass_unix_fd(&self) -> bool {
331 (**self).can_pass_unix_fd()
332 }
333
334 async fn receive_message(
335 &mut self,
336 seq: u64,
337 already_received_bytes: &mut Vec<u8>,
338 #[cfg(unix)] already_received_fds: &mut Vec<std::os::fd::OwnedFd>,
339 ) -> crate::Result<Message> {
340 (**self)
341 .receive_message(
342 seq,
343 already_received_bytes,
344 #[cfg(unix)]
345 already_received_fds,
346 )
347 .await
348 }
349
350 async fn recvmsg(&mut self, buf: &mut [u8]) -> RecvmsgResult {
351 (**self).recvmsg(buf).await
352 }
353
354 async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
355 (**self).peer_credentials().await
356 }
357}
358
359#[async_trait::async_trait]
360impl WriteHalf for Box<dyn WriteHalf> {
361 async fn send_message(&mut self, msg: &Message) -> crate::Result<()> {
362 (**self).send_message(msg).await
363 }
364
365 async fn sendmsg(
366 &mut self,
367 buffer: &[u8],
368 #[cfg(unix)] fds: &[BorrowedFd<'_>],
369 ) -> io::Result<usize> {
370 (**self)
371 .sendmsg(
372 buffer,
373 #[cfg(unix)]
374 fds,
375 )
376 .await
377 }
378
379 #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
380 async fn send_zero_byte(&mut self) -> io::Result<Option<usize>> {
381 (**self).send_zero_byte().await
382 }
383
384 async fn close(&mut self) -> io::Result<()> {
385 (**self).close().await
386 }
387
388 fn can_pass_unix_fd(&self) -> bool {
389 (**self).can_pass_unix_fd()
390 }
391
392 async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
393 (**self).peer_credentials().await
394 }
395}
396
397#[cfg(not(feature = "tokio"))]
398impl<T> Socket for Async<T>
399where
400 T: std::fmt::Debug + Send + Sync,
401 Arc<Async<T>>: ReadHalf + WriteHalf,
402{
403 type ReadHalf = Arc<Async<T>>;
404 type WriteHalf = Arc<Async<T>>;
405
406 fn split(self) -> Split<Self::ReadHalf, Self::WriteHalf> {
407 let arc = Arc::new(self);
408
409 Split {
410 read: arc.clone(),
411 write: arc,
412 }
413 }
414}