// zbus/connection/socket/unix.rs

1#[cfg(not(feature = "tokio"))]
2use async_io::Async;
3#[cfg(unix)]
4use std::os::unix::io::{AsRawFd, BorrowedFd, FromRawFd, RawFd};
5#[cfg(all(unix, not(feature = "tokio")))]
6use std::os::unix::net::UnixStream;
7#[cfg(not(feature = "tokio"))]
8use std::sync::Arc;
9#[cfg(unix)]
10use std::{
11    future::poll_fn,
12    io::{self, IoSlice, IoSliceMut},
13    os::fd::OwnedFd,
14    task::Poll,
15};
16#[cfg(all(windows, not(feature = "tokio")))]
17use uds_windows::UnixStream;
18
19#[cfg(unix)]
20use nix::{
21    cmsg_space,
22    sys::socket::{recvmsg, sendmsg, ControlMessage, ControlMessageOwned, MsgFlags, UnixAddr},
23};
24
25#[cfg(unix)]
26use crate::utils::FDS_MAX;
27
28#[cfg(all(unix, not(feature = "tokio")))]
29#[async_trait::async_trait]
30impl super::ReadHalf for Arc<Async<UnixStream>> {
31    async fn recvmsg(&mut self, buf: &mut [u8]) -> super::RecvmsgResult {
32        poll_fn(|cx| {
33            let (len, fds) = loop {
34                match fd_recvmsg(self.as_raw_fd(), buf) {
35                    Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
36                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => match self.poll_readable(cx)
37                    {
38                        Poll::Pending => return Poll::Pending,
39                        Poll::Ready(res) => res?,
40                    },
41                    v => break v?,
42                }
43            };
44            Poll::Ready(Ok((len, fds)))
45        })
46        .await
47    }
48
49    /// Supports passing file descriptors.
50    fn can_pass_unix_fd(&self) -> bool {
51        true
52    }
53
54    async fn peer_credentials(&mut self) -> io::Result<crate::fdo::ConnectionCredentials> {
55        get_unix_peer_creds(self).await
56    }
57}
58
59#[cfg(all(unix, not(feature = "tokio")))]
60#[async_trait::async_trait]
61impl super::WriteHalf for Arc<Async<UnixStream>> {
62    async fn sendmsg(
63        &mut self,
64        buffer: &[u8],
65        #[cfg(unix)] fds: &[BorrowedFd<'_>],
66    ) -> io::Result<usize> {
67        poll_fn(|cx| loop {
68            match fd_sendmsg(
69                self.as_raw_fd(),
70                buffer,
71                #[cfg(unix)]
72                fds,
73            ) {
74                Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
75                Err(e) if e.kind() == io::ErrorKind::WouldBlock => match self.poll_writable(cx) {
76                    Poll::Pending => return Poll::Pending,
77                    Poll::Ready(res) => res?,
78                },
79                v => return Poll::Ready(v),
80            }
81        })
82        .await
83    }
84
85    async fn close(&mut self) -> io::Result<()> {
86        let stream = self.clone();
87        crate::Task::spawn_blocking(
88            move || stream.get_ref().shutdown(std::net::Shutdown::Both),
89            "close socket",
90        )
91        .await
92    }
93
94    #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
95    async fn send_zero_byte(&mut self) -> io::Result<Option<usize>> {
96        send_zero_byte(self).await.map(Some)
97    }
98
99    /// Supports passing file descriptors.
100    fn can_pass_unix_fd(&self) -> bool {
101        true
102    }
103
104    async fn peer_credentials(&mut self) -> io::Result<crate::fdo::ConnectionCredentials> {
105        get_unix_peer_creds(self).await
106    }
107}
108
#[cfg(all(unix, feature = "tokio"))]
impl super::Socket for tokio::net::UnixStream {
    type ReadHalf = tokio::net::unix::OwnedReadHalf;
    type WriteHalf = tokio::net::unix::OwnedWriteHalf;

    /// Splits the stream into independently owned read and write halves (tokio `into_split`).
    fn split(self) -> super::Split<Self::ReadHalf, Self::WriteHalf> {
        let halves = self.into_split();
        super::Split {
            read: halves.0,
            write: halves.1,
        }
    }
}
120
#[cfg(all(unix, feature = "tokio"))]
#[async_trait::async_trait]
impl super::ReadHalf for tokio::net::unix::OwnedReadHalf {
    /// Reads into `buf`, returning the byte count plus any file descriptors received with it.
    ///
    /// The receive runs inside `try_io`. On `WouldBlock` it waits on `poll_read_ready` before
    /// retrying, and `Interrupted` is retried immediately.
    async fn recvmsg(&mut self, buf: &mut [u8]) -> super::RecvmsgResult {
        let stream: &tokio::net::UnixStream = self.as_ref();
        poll_fn(|cx| loop {
            // The receive goes through our own `fd_recvmsg` rather than tokio's read API,
            // because passed file descriptors arrive in the ancillary data.
            let attempt = stream.try_io(tokio::io::Interest::READABLE, || {
                fd_recvmsg(stream.as_raw_fd(), buf)
            });
            let err = match attempt {
                Err(err) => err,
                received => return Poll::Ready(received),
            };
            match err.kind() {
                io::ErrorKind::Interrupted => continue,
                io::ErrorKind::WouldBlock => match stream.poll_read_ready(cx) {
                    Poll::Ready(readiness) => readiness?,
                    Poll::Pending => return Poll::Pending,
                },
                _ => return Poll::Ready(Err(err)),
            }
        })
        .await
    }

    /// File descriptors can be passed over this socket.
    fn can_pass_unix_fd(&self) -> bool {
        true
    }

    /// Returns the credentials of the process on the other end of the socket.
    async fn peer_credentials(&mut self) -> io::Result<crate::fdo::ConnectionCredentials> {
        let stream: &tokio::net::UnixStream = self.as_ref();
        get_unix_peer_creds(stream).await
    }
}
156
#[cfg(all(unix, feature = "tokio"))]
#[async_trait::async_trait]
impl super::WriteHalf for tokio::net::unix::OwnedWriteHalf {
    /// Writes `buffer` to the socket, attaching `fds` as ancillary data.
    ///
    /// The send runs inside `try_io`. On `WouldBlock` it waits on `poll_write_ready` before
    /// retrying, and `Interrupted` is retried immediately. Returns the number of bytes written.
    async fn sendmsg(
        &mut self,
        buffer: &[u8],
        #[cfg(unix)] fds: &[BorrowedFd<'_>],
    ) -> io::Result<usize> {
        let stream: &tokio::net::UnixStream = self.as_ref();
        poll_fn(|cx| loop {
            let attempt = stream.try_io(tokio::io::Interest::WRITABLE, || {
                fd_sendmsg(
                    stream.as_raw_fd(),
                    buffer,
                    #[cfg(unix)]
                    fds,
                )
            });
            let err = match attempt {
                Err(err) => err,
                sent => return Poll::Ready(sent),
            };
            match err.kind() {
                io::ErrorKind::Interrupted => continue,
                io::ErrorKind::WouldBlock => match stream.poll_write_ready(cx) {
                    Poll::Ready(readiness) => readiness?,
                    Poll::Pending => return Poll::Pending,
                },
                _ => return Poll::Ready(Err(err)),
            }
        })
        .await
    }

    /// Shuts down the write half through tokio's `AsyncWriteExt::shutdown`.
    async fn close(&mut self) -> io::Result<()> {
        tokio::io::AsyncWriteExt::shutdown(self).await
    }

    /// Sends the standalone zero byte (with `SCM_CREDS`) and reports how many bytes went out.
    #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
    async fn send_zero_byte(&mut self) -> io::Result<Option<usize>> {
        let stream: &tokio::net::UnixStream = self.as_ref();
        let written = send_zero_byte(stream).await?;
        Ok(Some(written))
    }

    /// File descriptors can be passed over this socket.
    fn can_pass_unix_fd(&self) -> bool {
        true
    }

    /// Returns the credentials of the process on the other end of the socket.
    async fn peer_credentials(&mut self) -> io::Result<crate::fdo::ConnectionCredentials> {
        let stream: &tokio::net::UnixStream = self.as_ref();
        get_unix_peer_creds(stream).await
    }
}
206
#[cfg(all(windows, not(feature = "tokio")))]
#[async_trait::async_trait]
impl super::ReadHalf for Arc<Async<UnixStream>> {
    /// Reads into `buf` through the stream's `AsyncRead` implementation.
    ///
    /// This impl is only compiled on Windows, where no file descriptors are returned, so the
    /// result is just the byte count.
    async fn recvmsg(&mut self, buf: &mut [u8]) -> super::RecvmsgResult {
        let mut reader: &Async<UnixStream> = &**self;
        let len = futures_util::AsyncReadExt::read(&mut reader, buf).await?;
        Ok(len)
    }

    /// Looks up the peer's PID and the SID of its process token on a blocking task.
    ///
    /// When the peer PID is reported as 0, `None` is passed to `ProcessToken::open` instead
    /// of a PID.
    async fn peer_credentials(&mut self) -> std::io::Result<crate::fdo::ConnectionCredentials> {
        let stream = Arc::clone(self);
        crate::Task::spawn_blocking(
            move || {
                use crate::win32::{unix_stream_get_peer_pid, ProcessToken};

                let pid = unix_stream_get_peer_pid(stream.get_ref())? as _;
                let token_pid = if pid == 0 { None } else { Some(pid as _) };
                let sid = ProcessToken::open(token_pid)
                    .and_then(|process_token| process_token.sid())?;
                let credentials = crate::fdo::ConnectionCredentials::default()
                    .set_process_id(pid)
                    .set_windows_sid(sid);
                Ok(credentials)
            },
            "peer credentials",
        )
        .await
    }
}
241
#[cfg(all(windows, not(feature = "tokio")))]
#[async_trait::async_trait]
impl super::WriteHalf for Arc<Async<UnixStream>> {
    /// Writes `buf` through the stream's `AsyncWrite` implementation.
    ///
    /// No file descriptors are sent by this Windows-only impl.
    async fn sendmsg(
        &mut self,
        buf: &[u8],
        #[cfg(unix)] _fds: &[BorrowedFd<'_>],
    ) -> std::io::Result<usize> {
        let mut writer: &Async<UnixStream> = &**self;
        futures_util::AsyncWriteExt::write(&mut writer, buf).await
    }

    /// Shuts down both directions of the socket on a blocking task.
    ///
    /// The task holds its own `Arc` clone, so the stream stays alive for as long as it runs.
    async fn close(&mut self) -> std::io::Result<()> {
        let stream = Arc::clone(self);
        let shutdown = move || stream.get_ref().shutdown(std::net::Shutdown::Both);
        crate::Task::spawn_blocking(shutdown, "close socket").await
    }

    /// Delegates to the `ReadHalf` implementation on the same stream type.
    async fn peer_credentials(&mut self) -> std::io::Result<crate::fdo::ConnectionCredentials> {
        super::ReadHalf::peer_credentials(self).await
    }
}
266
267#[cfg(unix)]
268fn fd_recvmsg(fd: RawFd, buffer: &mut [u8]) -> io::Result<(usize, Vec<OwnedFd>)> {
269    let mut iov = [IoSliceMut::new(buffer)];
270    let mut cmsgspace = cmsg_space!([RawFd; FDS_MAX]);
271
272    let msg = recvmsg::<UnixAddr>(fd, &mut iov, Some(&mut cmsgspace), MsgFlags::empty())?;
273    if msg.bytes == 0 {
274        return Err(io::Error::new(
275            io::ErrorKind::BrokenPipe,
276            "failed to read from socket",
277        ));
278    }
279    let mut fds = vec![];
280    for cmsg in msg.cmsgs()? {
281        #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
282        if let ControlMessageOwned::ScmCreds(_) = cmsg {
283            continue;
284        }
285        if let ControlMessageOwned::ScmRights(fd) = cmsg {
286            fds.extend(fd.iter().map(|&f| unsafe { OwnedFd::from_raw_fd(f) }));
287        } else {
288            return Err(io::Error::new(
289                io::ErrorKind::InvalidData,
290                "unexpected CMSG kind",
291            ));
292        }
293    }
294    Ok((msg.bytes, fds))
295}
296
297#[cfg(unix)]
298fn fd_sendmsg(fd: RawFd, buffer: &[u8], fds: &[BorrowedFd<'_>]) -> io::Result<usize> {
299    // FIXME: Remove this conversion once nix supports BorrowedFd here.
300    //
301    // Tracking issue: https://github.com/nix-rust/nix/issues/1750
302    let fds: Vec<_> = fds.iter().map(|f| f.as_raw_fd()).collect();
303    let cmsg = if !fds.is_empty() {
304        vec![ControlMessage::ScmRights(&fds)]
305    } else {
306        vec![]
307    };
308    let iov = [IoSlice::new(buffer)];
309    match sendmsg::<UnixAddr>(fd, &iov, &cmsg, MsgFlags::empty(), None) {
310        // can it really happen?
311        Ok(0) => Err(io::Error::new(
312            io::ErrorKind::WriteZero,
313            "failed to write to buffer",
314        )),
315        Ok(n) => Ok(n),
316        Err(e) => Err(e.into()),
317    }
318}
319
320#[cfg(unix)]
321async fn get_unix_peer_creds(fd: &impl AsRawFd) -> io::Result<crate::fdo::ConnectionCredentials> {
322    let fd = fd.as_raw_fd();
323    // FIXME: Is it likely enough for sending of 1 byte to block, to justify a task (possibly
324    // launching a thread in turn)?
325    crate::Task::spawn_blocking(move || get_unix_peer_creds_blocking(fd), "peer credentials").await
326}
327
328#[cfg(unix)]
329fn get_unix_peer_creds_blocking(fd: RawFd) -> io::Result<crate::fdo::ConnectionCredentials> {
330    // TODO: get this BorrowedFd directly from get_unix_peer_creds(), but this requires a
331    // 'static lifetime due to the Task.
332    let fd = unsafe { BorrowedFd::borrow_raw(fd) };
333
334    #[cfg(any(target_os = "android", target_os = "linux"))]
335    {
336        use nix::sys::socket::{getsockopt, sockopt::PeerCredentials};
337
338        getsockopt(&fd, PeerCredentials)
339            .map(|creds| {
340                crate::fdo::ConnectionCredentials::default()
341                    .set_process_id(creds.pid() as _)
342                    .set_unix_user_id(creds.uid())
343            })
344            .map_err(|e| e.into())
345    }
346
347    #[cfg(any(
348        target_os = "macos",
349        target_os = "ios",
350        target_os = "freebsd",
351        target_os = "dragonfly",
352        target_os = "openbsd",
353        target_os = "netbsd"
354    ))]
355    {
356        let uid = nix::unistd::getpeereid(fd).map(|(uid, _)| uid.into())?;
357        // FIXME: Handle pid fetching too.
358        Ok(crate::fdo::ConnectionCredentials::default().set_unix_user_id(uid))
359    }
360}
361
// Send a single zero byte together with an SCM_CREDS control message.
/// Sends a single NUL byte accompanied by an `SCM_CREDS` control message on the socket `fd`.
///
/// The actual `sendmsg` runs on a blocking task. Returns the number of bytes written.
#[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
async fn send_zero_byte(fd: &impl AsRawFd) -> io::Result<usize> {
    let raw = fd.as_raw_fd();
    let send = move || send_zero_byte_blocking(raw);
    crate::Task::spawn_blocking(send, "send zero byte").await
}
368
/// Synchronously sends one NUL byte on `fd` with an `SCM_CREDS` control message attached.
///
/// Returns the number of bytes written; `nix` errors are converted to `io::Error`.
#[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
fn send_zero_byte_blocking(fd: RawFd) -> io::Result<usize> {
    let payload = [IoSlice::new(b"\0")];
    let creds = [ControlMessage::ScmCreds];
    let sent = sendmsg::<()>(fd, &payload, &creds, MsgFlags::empty(), None)?;
    Ok(sent)
}