zbus/connection/handshake/
common.rs

1use std::collections::VecDeque;
2use tracing::{instrument, trace};
3
4use super::{AuthMechanism, BoxedSplit, Command};
5use crate::{Error, Result};
6
7// Common code for the client and server side of the handshake.
8#[derive(Debug)]
9pub(super) struct Common {
10    socket: BoxedSplit,
11    recv_buffer: Vec<u8>,
12    #[cfg(unix)]
13    received_fds: Vec<std::os::fd::OwnedFd>,
14    cap_unix_fd: bool,
15    // the current AUTH mechanism is front, ordered by priority
16    mechanisms: VecDeque<AuthMechanism>,
17    first_command: bool,
18}
19
20impl Common {
21    /// Start a handshake on this client socket
22    pub fn new(socket: BoxedSplit, mechanisms: VecDeque<AuthMechanism>) -> Self {
23        Self {
24            socket,
25            recv_buffer: Vec::new(),
26            #[cfg(unix)]
27            received_fds: Vec::new(),
28            cap_unix_fd: false,
29            mechanisms,
30            first_command: true,
31        }
32    }
33
34    #[cfg(all(unix, feature = "p2p"))]
35    pub fn socket(&self) -> &BoxedSplit {
36        &self.socket
37    }
38
39    pub fn socket_mut(&mut self) -> &mut BoxedSplit {
40        &mut self.socket
41    }
42
43    pub fn set_cap_unix_fd(&mut self, cap_unix_fd: bool) {
44        self.cap_unix_fd = cap_unix_fd;
45    }
46
47    #[cfg(feature = "p2p")]
48    pub fn mechanisms(&self) -> &VecDeque<AuthMechanism> {
49        &self.mechanisms
50    }
51
52    pub fn into_components(self) -> IntoComponentsReturn {
53        (
54            self.socket,
55            self.recv_buffer,
56            #[cfg(unix)]
57            self.received_fds,
58            self.cap_unix_fd,
59            self.mechanisms,
60        )
61    }
62
63    #[instrument(skip(self))]
64    pub async fn write_command(&mut self, command: Command) -> Result<()> {
65        self.write_commands(&[command], None).await
66    }
67
68    #[instrument(skip(self))]
69    pub async fn write_commands(
70        &mut self,
71        commands: &[Command],
72        extra_bytes: Option<&[u8]>,
73    ) -> Result<()> {
74        let mut send_buffer =
75            commands
76                .iter()
77                .map(Vec::<u8>::from)
78                .fold(vec![], |mut acc, mut c| {
79                    if self.first_command {
80                        // The first command is sent by the client so we can assume it's the client.
81                        self.first_command = false;
82                        // leading 0 is sent separately for `freebsd` and `dragonfly`.
83                        #[cfg(not(any(target_os = "freebsd", target_os = "dragonfly")))]
84                        acc.push(b'\0');
85                    }
86                    acc.append(&mut c);
87                    acc.extend_from_slice(b"\r\n");
88                    acc
89                });
90        if let Some(extra_bytes) = extra_bytes {
91            send_buffer.extend_from_slice(extra_bytes);
92        }
93        while !send_buffer.is_empty() {
94            let written = self
95                .socket
96                .write_mut()
97                .sendmsg(
98                    &send_buffer,
99                    #[cfg(unix)]
100                    &[],
101                )
102                .await?;
103            send_buffer.drain(..written);
104        }
105        trace!("Wrote all commands");
106        Ok(())
107    }
108
109    #[instrument(skip(self))]
110    pub async fn read_command(&mut self) -> Result<Command> {
111        self.read_commands(1)
112            .await
113            .map(|cmds| cmds.into_iter().next().unwrap())
114    }
115
116    #[instrument(skip(self))]
117    pub async fn read_commands(&mut self, n_commands: usize) -> Result<Vec<Command>> {
118        let mut commands = Vec::with_capacity(n_commands);
119        let mut n_received_commands = 0;
120        'outer: loop {
121            while let Some(lf_index) = self.recv_buffer.iter().position(|b| *b == b'\n') {
122                if self.recv_buffer[lf_index - 1] != b'\r' {
123                    return Err(Error::Handshake("Invalid line ending in handshake".into()));
124                }
125
126                #[allow(unused_mut)]
127                let mut start_index = 0;
128                if self.first_command {
129                    // The first command is sent by the client so we can assume it's the server.
130                    self.first_command = false;
131                    if self.recv_buffer[0] != b'\0' {
132                        return Err(Error::Handshake(
133                            "First client byte is not NUL!".to_string(),
134                        ));
135                    }
136
137                    start_index = 1;
138                };
139
140                let line_bytes = self.recv_buffer.drain(..=lf_index);
141                let line = std::str::from_utf8(&line_bytes.as_slice()[start_index..])
142                    .map_err(|e| Error::Handshake(e.to_string()))?;
143
144                trace!("Reading {line}");
145                commands.push(line.parse()?);
146                n_received_commands += 1;
147
148                if n_received_commands == n_commands {
149                    break 'outer;
150                }
151            }
152
153            let mut buf = vec![0; 1024];
154            let res = self.socket.read_mut().recvmsg(&mut buf).await?;
155            let read = {
156                #[cfg(unix)]
157                {
158                    let (read, fds) = res;
159                    if !fds.is_empty() {
160                        // Most likely belonging to the messages already received.
161                        self.received_fds.extend(fds);
162                    }
163                    read
164                }
165                #[cfg(not(unix))]
166                {
167                    res
168                }
169            };
170            if read == 0 {
171                return Err(Error::Handshake("Unexpected EOF during handshake".into()));
172            }
173            self.recv_buffer.extend(&buf[..read]);
174        }
175
176        Ok(commands)
177    }
178
179    pub fn next_mechanism(&mut self) -> Result<AuthMechanism> {
180        self.mechanisms
181            .pop_front()
182            .ok_or_else(|| Error::Handshake("Exhausted available AUTH mechanisms".into()))
183    }
184}
185
186#[cfg(unix)]
187type IntoComponentsReturn = (
188    BoxedSplit,
189    Vec<u8>,
190    Vec<std::os::fd::OwnedFd>,
191    bool,
192    VecDeque<AuthMechanism>,
193);
194#[cfg(not(unix))]
195type IntoComponentsReturn = (BoxedSplit, Vec<u8>, bool, VecDeque<AuthMechanism>);