zbus/message/
fields.rs

1use serde::{Deserialize, Serialize};
2use static_assertions::assert_impl_all;
3use std::num::NonZeroU32;
4use zbus_names::{BusName, ErrorName, InterfaceName, MemberName, UniqueName};
5use zvariant::{ObjectPath, Signature, Type};
6
7use crate::{
8    message::{Field, FieldCode, Header, Message},
9    Result,
10};
11
// Upper bound used to preallocate `Fields`. The real number of header fields is
// smaller (at most 10, and usually far fewer); 16 rounds that up to the next
// multiple of 8.
const MAX_FIELDS_IN_MESSAGE: usize = 2 * 8;
14
15/// A collection of [`Field`] instances.
16///
17/// [`Field`]: enum.Field.html
18#[derive(Debug, Clone, Serialize, Deserialize, Type)]
19pub(crate) struct Fields<'m>(#[serde(borrow)] Vec<Field<'m>>);
20
21assert_impl_all!(Fields<'_>: Send, Sync, Unpin);
22
23impl<'m> Fields<'m> {
24    /// Creates an empty collection of fields.
25    pub fn new() -> Self {
26        Self::default()
27    }
28
29    /// Appends a [`Field`] to the collection of fields in the message.
30    ///
31    /// [`Field`]: enum.Field.html
32    pub fn add<'f: 'm>(&mut self, field: Field<'f>) {
33        self.0.push(field);
34    }
35
36    /// Replaces a [`Field`] from the collection of fields with one with the same code,
37    /// returning the old value if present.
38    ///
39    /// [`Field`]: enum.Field.html
40    pub fn replace<'f: 'm>(&mut self, field: Field<'f>) -> Option<Field<'m>> {
41        let code = field.code();
42        if let Some(found) = self.0.iter_mut().find(|f| f.code() == code) {
43            return Some(std::mem::replace(found, field));
44        }
45        self.add(field);
46        None
47    }
48
49    /// Returns a slice with all the [`Field`] in the message.
50    ///
51    /// [`Field`]: enum.Field.html
52    pub fn get(&self) -> &[Field<'m>] {
53        &self.0
54    }
55
56    /// Gets a reference to a specific [`Field`] by its code.
57    ///
58    /// Returns `None` if the message has no such field.
59    ///
60    /// [`Field`]: enum.Field.html
61    pub fn get_field(&self, code: FieldCode) -> Option<&Field<'m>> {
62        self.0.iter().find(|f| f.code() == code)
63    }
64
65    /// Remove the field matching the `code`.
66    ///
67    /// Returns `true` if a field was found and removed, `false` otherwise.
68    pub(crate) fn remove(&mut self, code: FieldCode) -> bool {
69        match self.0.iter().enumerate().find(|(_, f)| f.code() == code) {
70            Some((i, _)) => {
71                self.0.remove(i);
72
73                true
74            }
75            None => false,
76        }
77    }
78}
79
/// Byte range `[start, end)` of a string-valued header field inside a message buffer,
/// used by `QuickFields`.
///
/// Two sentinel encodings, both with `end == 0`, carry no range:
/// - `start == 0, end == 0` (the [`Default`] value): the field was not cached;
/// - `start == 1, end == 0` ([`FieldPos::new_not_present`]): the field is absent or
///   could not be located in the buffer.
///
/// [`FieldPos::read`] returns `None` for either sentinel.
#[derive(Clone, Copy, Debug, Default)]
pub(crate) struct FieldPos {
    start: u32,
    end: u32,
}

impl FieldPos {
    /// Returns the sentinel meaning "this field is not present".
    pub fn new_not_present() -> Self {
        Self { start: 1, end: 0 }
    }

    /// Locates `field_buf` inside `msg_buf` by comparing their addresses.
    ///
    /// Returns `None` in two cases:
    /// - `field_buf` does not lie entirely within `msg_buf`;
    /// - either bound does not fit in a `u32`.
    pub fn build(msg_buf: &[u8], field_buf: &str) -> Option<Self> {
        let base = msg_buf.as_ptr() as usize;
        let begin = (field_buf.as_ptr() as usize).checked_sub(base)?;
        // `checked_add` avoids overflow when `field_buf` lives far past `msg_buf`.
        // Such a range is rejected by the bound check below anyway.
        let finish = begin.checked_add(field_buf.len())?;
        if finish > msg_buf.len() {
            return None;
        }
        Some(Self {
            start: u32::try_from(begin).ok()?,
            end: u32::try_from(finish).ok()?,
        })
    }

    /// Records where `field` sits inside `msg_buf`.
    ///
    /// A missing field, or one that cannot be located inside `msg_buf`, is recorded
    /// as [`FieldPos::new_not_present`].
    pub fn new<T>(msg_buf: &[u8], field: Option<&T>) -> Self
    where
        T: std::ops::Deref<Target = str>,
    {
        match field.and_then(|value| Self::build(msg_buf, &**value)) {
            Some(pos) => pos,
            None => Self::new_not_present(),
        }
    }

    /// Rebuilds the cached field as a `T` that borrows from `msg_buf`.
    ///
    /// Returns `None` for the "not cached" and "not present" sentinels.
    ///
    /// # Panics
    ///
    /// `msg_buf` must be the same buffer that was passed to [`FieldPos::build`].
    /// Otherwise this method panics in any of these cases:
    /// - the stored range is out of bounds;
    /// - the bytes are not valid UTF-8;
    /// - `T::try_from` rejects the text.
    pub fn read<'m, T>(&self, msg_buf: &'m [u8]) -> Option<T>
    where
        T: TryFrom<&'m str>,
        T::Error: std::fmt::Debug,
    {
        let Self { start, end } = *self;
        if end == 0 && start <= 1 {
            return None;
        }
        let bytes = &msg_buf[start as usize..end as usize];
        let text = std::str::from_utf8(bytes).expect("Invalid utf8 when reconstructing string");
        // Presumably the range came from an already-validated header field, so
        // re-parsing it is expected to succeed.
        let value = T::try_from(text).expect("Invalid field reconstruction");
        Some(value)
    }
}
142
143/// A cache of the Message header fields.
144#[derive(Debug, Default, Copy, Clone)]
145pub(crate) struct QuickFields {
146    path: FieldPos,
147    interface: FieldPos,
148    member: FieldPos,
149    error_name: FieldPos,
150    reply_serial: Option<NonZeroU32>,
151    destination: FieldPos,
152    sender: FieldPos,
153    signature: FieldPos,
154    unix_fds: Option<u32>,
155}
156
157impl QuickFields {
158    pub fn new(buf: &[u8], header: &Header<'_>) -> Result<Self> {
159        Ok(Self {
160            path: FieldPos::new(buf, header.path()),
161            interface: FieldPos::new(buf, header.interface()),
162            member: FieldPos::new(buf, header.member()),
163            error_name: FieldPos::new(buf, header.error_name()),
164            reply_serial: header.reply_serial(),
165            destination: FieldPos::new(buf, header.destination()),
166            sender: FieldPos::new(buf, header.sender()),
167            signature: FieldPos::new(buf, header.signature()),
168            unix_fds: header.unix_fds(),
169        })
170    }
171
172    pub fn path<'m>(&self, msg: &'m Message) -> Option<ObjectPath<'m>> {
173        self.path.read(msg.data())
174    }
175
176    pub fn interface<'m>(&self, msg: &'m Message) -> Option<InterfaceName<'m>> {
177        self.interface.read(msg.data())
178    }
179
180    pub fn member<'m>(&self, msg: &'m Message) -> Option<MemberName<'m>> {
181        self.member.read(msg.data())
182    }
183
184    pub fn error_name<'m>(&self, msg: &'m Message) -> Option<ErrorName<'m>> {
185        self.error_name.read(msg.data())
186    }
187
188    pub fn reply_serial(&self) -> Option<NonZeroU32> {
189        self.reply_serial
190    }
191
192    pub fn destination<'m>(&self, msg: &'m Message) -> Option<BusName<'m>> {
193        self.destination.read(msg.data())
194    }
195
196    pub fn sender<'m>(&self, msg: &'m Message) -> Option<UniqueName<'m>> {
197        self.sender.read(msg.data())
198    }
199
200    pub fn signature<'m>(&self, msg: &'m Message) -> Option<Signature<'m>> {
201        self.signature.read(msg.data())
202    }
203
204    pub fn unix_fds(&self) -> Option<u32> {
205        self.unix_fds
206    }
207}
208
209impl<'m> Default for Fields<'m> {
210    fn default() -> Self {
211        Self(Vec::with_capacity(MAX_FIELDS_IN_MESSAGE))
212    }
213}
214
215impl<'m> std::ops::Deref for Fields<'m> {
216    type Target = [Field<'m>];
217
218    fn deref(&self) -> &Self::Target {
219        self.get()
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::{Field, FieldCode, Fields};

    #[test]
    fn test() {
        // `add` appends unconditionally, so duplicate codes accumulate.
        let mut mf = Fields::new();
        assert_eq!(mf.len(), 0);
        mf.add(Field::ReplySerial(42.try_into().unwrap()));
        assert_eq!(mf.len(), 1);
        mf.add(Field::ReplySerial(43.try_into().unwrap()));
        assert_eq!(mf.len(), 2);

        // `replace` inserts when the code is absent and swaps in place when present,
        // handing back the previous value.
        let mut mf = Fields::new();
        assert_eq!(mf.len(), 0);
        let old = mf.replace(Field::ReplySerial(42.try_into().unwrap()));
        assert!(old.is_none());
        assert_eq!(mf.len(), 1);
        let old = mf.replace(Field::ReplySerial(43.try_into().unwrap()));
        assert!(matches!(old, Some(Field::ReplySerial(s)) if s.get() == 42));
        assert_eq!(mf.len(), 1);
        assert!(matches!(
            mf.get_field(FieldCode::ReplySerial),
            Some(Field::ReplySerial(s)) if s.get() == 43
        ));

        // `remove` reports whether a matching field was found and removed.
        assert!(mf.remove(FieldCode::ReplySerial));
        assert_eq!(mf.len(), 0);
        assert!(mf.get_field(FieldCode::ReplySerial).is_none());
        assert!(!mf.remove(FieldCode::ReplySerial));
    }
}