zbus/connection/handshake/
client.rs
1use async_trait::async_trait;
2use std::collections::VecDeque;
3use tracing::{debug, instrument, trace, warn};
4
5use sha1::{Digest, Sha1};
6
7use crate::{conn::socket::ReadHalf, is_flatpak, names::OwnedUniqueName, Message};
8
9use super::{
10 random_ascii, sasl_auth_id, AuthMechanism, Authenticated, BoxedSplit, Command, Common, Cookie,
11 Error, Handshake, OwnedGuid, Result, Str,
12};
13
14#[derive(Debug)]
19pub struct Client {
20 common: Common,
21 server_guid: Option<OwnedGuid>,
22 bus: bool,
23}
24
25impl Client {
26 pub fn new(
28 socket: BoxedSplit,
29 mechanisms: Option<VecDeque<AuthMechanism>>,
30 server_guid: Option<OwnedGuid>,
31 bus: bool,
32 ) -> Client {
33 let mechanisms = mechanisms.unwrap_or_else(|| {
34 let mut mechanisms = VecDeque::new();
35 mechanisms.push_back(AuthMechanism::External);
36 mechanisms.push_back(AuthMechanism::Cookie);
37 mechanisms.push_back(AuthMechanism::Anonymous);
38 mechanisms
39 });
40
41 Client {
42 common: Common::new(socket, mechanisms),
43 server_guid,
44 bus,
45 }
46 }
47
48 async fn handle_cookie_challenge(&mut self, data: Vec<u8>) -> Result<Command> {
52 let context = std::str::from_utf8(&data)
53 .map_err(|_| Error::Handshake("Cookie context was not valid UTF-8".into()))?;
54 let mut split = context.split_ascii_whitespace();
55 let context = split
56 .next()
57 .ok_or_else(|| Error::Handshake("Missing cookie context name".into()))?;
58 let context = Str::from(context).try_into()?;
59 let id = split
60 .next()
61 .ok_or_else(|| Error::Handshake("Missing cookie ID".into()))?;
62 let id = id
63 .parse()
64 .map_err(|e| Error::Handshake(format!("Invalid cookie ID `{id}`: {e}")))?;
65 let server_challenge = split
66 .next()
67 .ok_or_else(|| Error::Handshake("Missing cookie challenge".into()))?;
68
69 let cookie = Cookie::lookup(&context, id).await?;
70 let cookie = cookie.cookie();
71 let client_challenge = random_ascii(16);
72 let sec = format!("{server_challenge}:{client_challenge}:{cookie}");
73 let sha1 = hex::encode(Sha1::digest(sec));
74 let data = format!("{client_challenge} {sha1}").into_bytes();
75
76 Ok(Command::Data(Some(data)))
77 }
78
79 fn set_guid(&mut self, guid: OwnedGuid) -> Result<()> {
80 match &self.server_guid {
81 Some(server_guid) if *server_guid != guid => {
82 return Err(Error::Handshake(format!(
83 "Server GUID mismatch: expected {server_guid}, got {guid}",
84 )));
85 }
86 Some(_) => (),
87 None => self.server_guid = Some(guid),
88 }
89
90 Ok(())
91 }
92
93 #[instrument(skip(self))]
96 #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
97 async fn send_zero_byte(&mut self) -> Result<()> {
98 let write = self.common.socket_mut().write_mut();
99
100 let written = match write.send_zero_byte().await.map_err(|e| {
101 Error::Handshake(format!("Could not send zero byte with credentials: {}", e))
102 })? {
103 None => write.sendmsg(&[0], &[]).await?,
106 Some(n) => n,
107 };
108
109 if written != 1 {
110 return Err(Error::Handshake(
111 "Could not send zero byte with credentials".to_string(),
112 ));
113 }
114
115 Ok(())
116 }
117
118 #[instrument(skip(self))]
123 async fn authenticate(&mut self) -> Result<Option<Command>> {
124 loop {
125 let mechanism = self.common.next_mechanism()?;
126 trace!("Trying {mechanism} mechanism");
127 let auth_cmd = match mechanism {
128 AuthMechanism::Anonymous => Command::Auth(Some(mechanism), Some("zbus".into())),
129 AuthMechanism::External => {
130 Command::Auth(Some(mechanism), Some(sasl_auth_id()?.into_bytes()))
131 }
132 AuthMechanism::Cookie => Command::Auth(
133 Some(AuthMechanism::Cookie),
134 Some(sasl_auth_id()?.into_bytes()),
135 ),
136 };
137 self.common.write_command(auth_cmd).await?;
138
139 match self.common.read_command().await? {
140 Command::Ok(guid) => {
141 trace!("Received OK from server");
142 self.set_guid(guid)?;
143
144 return Ok(None);
145 }
146 Command::Data(data) if mechanism == AuthMechanism::Cookie => {
147 let data = data.ok_or_else(|| {
148 Error::Handshake("Received DATA with no data from server".into())
149 })?;
150 trace!("Received cookie challenge from server");
151 let response = self.handle_cookie_challenge(data).await?;
152
153 return Ok(Some(response));
154 }
155 Command::Rejected(_) => debug!("{mechanism} rejected by the server"),
156 Command::Error(e) => debug!("Received error from server: {e}"),
157 cmd => {
158 return Err(Error::Handshake(format!(
159 "Unexpected command from server: {cmd}"
160 )))
161 }
162 }
163 }
164 }
165
166 #[instrument(skip(self))]
171 async fn send_secondary_commands(
172 &mut self,
173 challenge_response: Option<Command>,
174 ) -> Result<usize> {
175 let mut commands = Vec::with_capacity(4);
176 if let Some(response) = challenge_response {
177 commands.push(response);
178 }
179
180 let can_pass_fd = self.common.socket_mut().read_mut().can_pass_unix_fd();
181 if can_pass_fd {
182 if is_flatpak() {
186 self.common.write_command(Command::NegotiateUnixFD).await?;
187 match self.common.read_command().await? {
188 Command::AgreeUnixFD => self.common.set_cap_unix_fd(true),
189 Command::Error(e) => warn!("UNIX file descriptor passing rejected: {e}"),
190 cmd => {
191 return Err(Error::Handshake(format!(
192 "Unexpected command from server: {cmd}"
193 )))
194 }
195 }
196 } else {
197 commands.push(Command::NegotiateUnixFD);
198 }
199 };
200 commands.push(Command::Begin);
201 let hello_method = if self.bus {
202 Some(create_hello_method_call())
203 } else {
204 None
205 };
206
207 self.common
208 .write_commands(&commands, hello_method.as_ref().map(|m| &**m.data()))
209 .await?;
210
211 Ok(commands.len() - 1)
213 }
214
215 #[instrument(skip(self))]
216 async fn receive_secondary_responses(&mut self, expected_n_responses: usize) -> Result<()> {
217 for response in self.common.read_commands(expected_n_responses).await? {
218 match response {
219 Command::Ok(guid) => {
220 trace!("Received OK from server");
221 self.set_guid(guid)?;
222 }
223 Command::AgreeUnixFD => self.common.set_cap_unix_fd(true),
224 Command::Error(e) => warn!("UNIX file descriptor passing rejected: {e}"),
225 cmd => {
231 return Err(Error::Handshake(format!(
232 "Unexpected command from server: {cmd}"
233 )))
234 }
235 }
236 }
237
238 Ok(())
239 }
240}
241
242#[async_trait]
243impl Handshake for Client {
244 #[instrument(skip(self))]
245 async fn perform(mut self) -> Result<Authenticated> {
246 trace!("Initializing");
247
248 #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
249 self.send_zero_byte().await?;
250
251 let challenge_response = self.authenticate().await?;
252 let expected_n_responses = self.send_secondary_commands(challenge_response).await?;
253
254 if expected_n_responses > 0 {
255 self.receive_secondary_responses(expected_n_responses)
256 .await?;
257 }
258
259 trace!("Handshake done");
260 #[cfg(unix)]
261 let (socket, mut recv_buffer, received_fds, cap_unix_fd, _) = self.common.into_components();
262 #[cfg(not(unix))]
263 let (socket, mut recv_buffer, _, _) = self.common.into_components();
264 let (mut read, write) = socket.take();
265
266 let unique_name = if self.bus {
268 let unique_name = receive_hello_response(&mut read, &mut recv_buffer).await?;
269
270 Some(unique_name)
271 } else {
272 None
273 };
274
275 Ok(Authenticated {
276 socket_write: write,
277 socket_read: Some(read),
278 server_guid: self.server_guid.unwrap(),
279 #[cfg(unix)]
280 cap_unix_fd,
281 already_received_bytes: recv_buffer,
282 #[cfg(unix)]
283 already_received_fds: received_fds,
284 unique_name,
285 })
286 }
287}
288
289fn create_hello_method_call() -> Message {
290 Message::method("/org/freedesktop/DBus", "Hello")
291 .unwrap()
292 .destination("org.freedesktop.DBus")
293 .unwrap()
294 .interface("org.freedesktop.DBus")
295 .unwrap()
296 .build(&())
297 .unwrap()
298}
299
300async fn receive_hello_response(
301 read: &mut Box<dyn ReadHalf>,
302 recv_buffer: &mut Vec<u8>,
303) -> Result<OwnedUniqueName> {
304 use crate::message::Type;
305
306 let reply = read
307 .receive_message(
308 0,
309 recv_buffer,
310 #[cfg(unix)]
311 &mut vec![],
312 )
313 .await?;
314 match reply.message_type() {
315 Type::MethodReturn => reply.body().deserialize(),
316 Type::Error => Err(Error::from(reply)),
317 m => Err(Error::Handshake(format!("Unexpected messgage `{m:?}`"))),
318 }
319}