1use std::{
2 num::NonZeroU32,
3 sync::atomic::{AtomicU32, Ordering::SeqCst},
4};
5
6use enumflags2::{bitflags, BitFlags};
7use serde::{Deserialize, Serialize};
8use serde_repr::{Deserialize_repr, Serialize_repr};
9
10use static_assertions::assert_impl_all;
11use zbus_names::{BusName, ErrorName, InterfaceName, MemberName, UniqueName};
12use zvariant::{
13 serialized::{self, Context},
14 Endian, ObjectPath, Signature, Type as VariantType,
15};
16
17use crate::{
18 message::{Field, FieldCode, Fields},
19 Error,
20};
21
22pub(crate) const PRIMARY_HEADER_SIZE: usize = 12;
23pub(crate) const MIN_MESSAGE_SIZE: usize = PRIMARY_HEADER_SIZE + 4;
24pub(crate) const MAX_MESSAGE_SIZE: usize = 128 * 1024 * 1024; #[repr(u8)]
28#[derive(Debug, Copy, Clone, Deserialize_repr, PartialEq, Eq, Serialize_repr, VariantType)]
29pub enum EndianSig {
30 Big = b'B',
32
33 Little = b'l',
35}
36
37assert_impl_all!(EndianSig: Send, Sync, Unpin);
38
39impl TryFrom<u8> for EndianSig {
41 type Error = Error;
42
43 fn try_from(val: u8) -> Result<EndianSig, Error> {
44 match val {
45 b'B' => Ok(EndianSig::Big),
46 b'l' => Ok(EndianSig::Little),
47 _ => Err(Error::IncorrectEndian),
48 }
49 }
50}
51
52#[cfg(target_endian = "big")]
53pub const NATIVE_ENDIAN_SIG: EndianSig = EndianSig::Big;
55#[cfg(target_endian = "little")]
56pub const NATIVE_ENDIAN_SIG: EndianSig = EndianSig::Little;
58
59impl From<Endian> for EndianSig {
60 fn from(endian: Endian) -> Self {
61 match endian {
62 Endian::Little => EndianSig::Little,
63 Endian::Big => EndianSig::Big,
64 }
65 }
66}
67
68impl From<EndianSig> for Endian {
69 fn from(endian_sig: EndianSig) -> Self {
70 match endian_sig {
71 EndianSig::Little => Endian::Little,
72 EndianSig::Big => Endian::Big,
73 }
74 }
75}
76
77#[repr(u8)]
79#[derive(
80 Debug, Copy, Clone, Deserialize_repr, PartialEq, Eq, Hash, Serialize_repr, VariantType,
81)]
82pub enum Type {
83 MethodCall = 1,
85 MethodReturn = 2,
87 Error = 3,
89 Signal = 4,
91}
92
93assert_impl_all!(Type: Send, Sync, Unpin);
94
95#[bitflags]
97#[repr(u8)]
98#[derive(Debug, Copy, Clone, PartialEq, Eq, VariantType)]
99pub enum Flags {
100 NoReplyExpected = 0x1,
108 NoAutoStart = 0x2,
110 AllowInteractiveAuth = 0x4,
115}
116
117assert_impl_all!(Flags: Send, Sync, Unpin);
118
119#[derive(Clone, Debug, Serialize, Deserialize, VariantType)]
123pub struct PrimaryHeader {
124 endian_sig: EndianSig,
125 msg_type: Type,
126 flags: BitFlags<Flags>,
127 protocol_version: u8,
128 body_len: u32,
129 serial_num: NonZeroU32,
130}
131
132assert_impl_all!(PrimaryHeader: Send, Sync, Unpin);
133
134impl PrimaryHeader {
135 pub fn new(msg_type: Type, body_len: u32) -> Self {
137 Self {
138 endian_sig: NATIVE_ENDIAN_SIG,
139 msg_type,
140 flags: BitFlags::empty(),
141 protocol_version: 1,
142 body_len,
143 serial_num: SERIAL_NUM.fetch_add(1, SeqCst).try_into().unwrap(),
144 }
145 }
146
147 pub(crate) fn read(buf: &[u8]) -> Result<(PrimaryHeader, u32), Error> {
148 let endian = Endian::from(EndianSig::try_from(buf[0])?);
149 let ctx = Context::new_dbus(endian, 0);
150 let data = serialized::Data::new(buf, ctx);
151
152 Self::read_from_data(&data)
153 }
154
155 pub(crate) fn read_from_data(
156 data: &serialized::Data<'_, '_>,
157 ) -> Result<(PrimaryHeader, u32), Error> {
158 let (primary_header, size) = data.deserialize()?;
159 assert_eq!(size, PRIMARY_HEADER_SIZE);
160 let (fields_len, _) = data.slice(PRIMARY_HEADER_SIZE..).deserialize()?;
161 Ok((primary_header, fields_len))
162 }
163
164 pub fn endian_sig(&self) -> EndianSig {
166 self.endian_sig
167 }
168
169 pub fn set_endian_sig(&mut self, sig: EndianSig) {
171 self.endian_sig = sig;
172 }
173
174 pub fn msg_type(&self) -> Type {
176 self.msg_type
177 }
178
179 pub fn set_msg_type(&mut self, msg_type: Type) {
181 self.msg_type = msg_type;
182 }
183
184 pub fn flags(&self) -> BitFlags<Flags> {
186 self.flags
187 }
188
189 pub fn set_flags(&mut self, flags: BitFlags<Flags>) {
191 self.flags = flags;
192 }
193
194 pub fn protocol_version(&self) -> u8 {
198 self.protocol_version
199 }
200
201 pub fn set_protocol_version(&mut self, version: u8) {
205 self.protocol_version = version;
206 }
207
208 pub fn body_len(&self) -> u32 {
210 self.body_len
211 }
212
213 pub fn set_body_len(&mut self, len: u32) {
215 self.body_len = len;
216 }
217
218 pub fn serial_num(&self) -> NonZeroU32 {
222 self.serial_num
223 }
224
225 pub fn set_serial_num(&mut self, serial_num: NonZeroU32) {
229 self.serial_num = serial_num;
230 }
231}
232
233#[derive(Debug, Clone, Serialize, Deserialize, VariantType)]
240pub struct Header<'m> {
241 primary: PrimaryHeader,
242 #[serde(borrow)]
243 fields: Fields<'m>,
244}
245
246assert_impl_all!(Header<'_>: Send, Sync, Unpin);
247
/// Looks up the header field `$kind` via `$hdr.fields()`.
///
/// Evaluates to `Some($map(payload))` when the field is present and `None` otherwise.
/// `$map` is a parenthesized closure receiving a reference to the field's payload; when
/// omitted, it is the identity. Panics via `unreachable!` if the entry stored under
/// `FieldCode::$kind` is not `Field::$kind`.
macro_rules! get_field {
    ($hdr:ident, $kind:ident) => {
        get_field!($hdr, $kind, (|inner| inner))
    };
    ($hdr:ident, $kind:ident, $map:tt) => {
        #[allow(clippy::redundant_closure_call)]
        match $hdr.fields().get_field(FieldCode::$kind) {
            None => None,
            Some(Field::$kind(inner)) => Some($map(inner)),
            Some(_) => unreachable!("FieldCode and Field mismatch"),
        }
    };
}
262
/// Variant of `get_field!` for `u32` payloads: copies the value out, yielding `Option<u32>`.
macro_rules! get_field_u32 {
    ($hdr:ident, $kind:ident) => {
        get_field!($hdr, $kind, (|n: &u32| -> u32 { *n }))
    };
}
268
269impl<'m> Header<'m> {
270 pub(super) fn new(primary: PrimaryHeader, fields: Fields<'m>) -> Self {
272 Self { primary, fields }
273 }
274
275 pub fn primary(&self) -> &PrimaryHeader {
277 &self.primary
278 }
279
280 pub fn primary_mut(&mut self) -> &mut PrimaryHeader {
282 &mut self.primary
283 }
284
285 pub fn into_primary(self) -> PrimaryHeader {
287 self.primary
288 }
289
290 fn fields(&self) -> &Fields<'m> {
292 &self.fields
293 }
294
295 pub(super) fn fields_mut(&mut self) -> &mut Fields<'m> {
297 &mut self.fields
298 }
299
300 pub fn message_type(&self) -> Type {
302 self.primary().msg_type()
303 }
304
305 pub fn path(&self) -> Option<&ObjectPath<'m>> {
307 get_field!(self, Path)
308 }
309
310 pub fn interface(&self) -> Option<&InterfaceName<'m>> {
312 get_field!(self, Interface)
313 }
314
315 pub fn member(&self) -> Option<&MemberName<'m>> {
317 get_field!(self, Member)
318 }
319
320 pub fn error_name(&self) -> Option<&ErrorName<'m>> {
322 get_field!(self, ErrorName)
323 }
324
325 pub fn reply_serial(&self) -> Option<NonZeroU32> {
327 match self.fields().get_field(FieldCode::ReplySerial) {
328 Some(Field::ReplySerial(value)) => Some(*value),
329 Some(_) => unreachable!("FieldCode and Field mismatch"),
331 None => None,
332 }
333 }
334
335 pub fn destination(&self) -> Option<&BusName<'m>> {
337 get_field!(self, Destination)
338 }
339
340 pub fn sender(&self) -> Option<&UniqueName<'m>> {
342 get_field!(self, Sender)
343 }
344
345 pub fn signature(&self) -> Option<&Signature<'m>> {
347 get_field!(self, Signature)
348 }
349
350 pub fn unix_fds(&self) -> Option<u32> {
352 get_field_u32!(self, UnixFDs)
353 }
354}
355
/// Process-wide counter from which `PrimaryHeader::new` takes message serial numbers.
///
/// It starts at 1 because serial numbers are `NonZeroU32`.
static SERIAL_NUM: AtomicU32 = AtomicU32::new(1);
357
#[cfg(test)]
mod tests {
    use crate::message::{Field, Fields, Header, PrimaryHeader, Type};

    use std::error::Error;
    use test_log::test;
    use zbus_names::{InterfaceName, MemberName};
    use zvariant::{ObjectPath, Signature};

    /// Builds two headers with disjoint sets of fields and checks that every accessor reports
    /// exactly the fields that were added to each.
    #[test]
    fn header() -> Result<(), Box<dyn Error>> {
        // First header: a signal carrying path, interface, member and sender.
        let path = ObjectPath::try_from("/some/path")?;
        let iface = InterfaceName::try_from("some.interface")?;
        let member = MemberName::try_from("Member")?;

        let mut signal_fields = Fields::new();
        for field in [
            Field::Path(path.clone()),
            Field::Interface(iface.clone()),
            Field::Member(member.clone()),
            Field::Sender(":1.84".try_into()?),
        ] {
            signal_fields.add(field);
        }
        let signal = Header::new(PrimaryHeader::new(Type::Signal, 77), signal_fields);

        assert_eq!(signal.message_type(), Type::Signal);
        assert_eq!(signal.path(), Some(&path));
        assert_eq!(signal.interface(), Some(&iface));
        assert_eq!(signal.member(), Some(&member));
        assert_eq!(signal.error_name(), None);
        assert_eq!(signal.destination(), None);
        assert_eq!(signal.reply_serial(), None);
        assert_eq!(signal.sender().unwrap(), ":1.84");
        assert_eq!(signal.signature(), None);
        assert_eq!(signal.unix_fds(), None);

        // Second header: a method return carrying the complementary set of fields.
        let mut reply_fields = Fields::new();
        for field in [
            Field::ErrorName("org.zbus.Error".try_into()?),
            Field::Destination(":1.11".try_into()?),
            Field::ReplySerial(88.try_into()?),
            Field::Signature(Signature::from_str_unchecked("say")),
            Field::UnixFDs(12),
        ] {
            reply_fields.add(field);
        }
        let reply = Header::new(PrimaryHeader::new(Type::MethodReturn, 77), reply_fields);

        assert_eq!(reply.message_type(), Type::MethodReturn);
        assert_eq!(reply.path(), None);
        assert_eq!(reply.interface(), None);
        assert_eq!(reply.member(), None);
        assert_eq!(reply.error_name().unwrap(), "org.zbus.Error");
        assert_eq!(reply.destination().unwrap(), ":1.11");
        assert_eq!(reply.reply_serial().map(Into::into), Some(88));
        assert_eq!(reply.sender(), None);
        assert_eq!(reply.signature(), Some(&Signature::from_str_unchecked("say")));
        assert_eq!(reply.unix_fds(), Some(12));

        Ok(())
    }
}