rosrust/
dynamic_msg.rs

1use crate::error::{Result, ResultExt};
2use crate::{Duration, RosMsg, Time};
3use lazy_static::lazy_static;
4use regex::RegexBuilder;
5use ros_message::{DataType, FieldCase, FieldInfo, MessagePath, MessageValue, Msg, Value};
6use std::collections::HashMap;
7use std::convert::TryInto;
8use std::io;
9
10#[derive(Clone, Debug)]
11pub struct DynamicMsg {
12    msg: Msg,
13    dependencies: HashMap<MessagePath, Msg>,
14}
15
16fn get_field<'a>(value: &'a MessageValue, name: &str) -> io::Result<&'a Value> {
17    value.get(name).ok_or_else(|| {
18        io::Error::new(
19            io::ErrorKind::Other,
20            format!("Missing field `{}` in value", name),
21        )
22    })
23}
24
25impl DynamicMsg {
26    pub fn new(message_type: &str, message_definition: &str) -> Result<Self> {
27        lazy_static! {
28            static ref RE_DESCRIPTOR_MESSAGES_SPLITTER: regex::Regex = RegexBuilder::new("^=+$")
29                .multi_line(true)
30                .build()
31                .expect("Invalid regex `^=+$`");
32        }
33        let mut message_bodies = RE_DESCRIPTOR_MESSAGES_SPLITTER.split(message_definition);
34        let message_src = message_bodies.next().chain_err(|| {
35            format!(
36                "Message definition for {} is missing main message body",
37                message_type,
38            )
39        })?;
40        let msg = Self::parse_msg(message_type, message_src)?;
41        let mut dependencies = HashMap::new();
42        for message_body in message_bodies {
43            let dependency = Self::parse_dependency(message_body)?;
44            dependencies.insert(dependency.path().clone(), dependency);
45        }
46
47        Ok(DynamicMsg { msg, dependencies })
48    }
49
50    pub fn msg(&self) -> &Msg {
51        &self.msg
52    }
53
54    pub fn dependency(&self, path: &MessagePath) -> Option<&Msg> {
55        self.dependencies.get(path)
56    }
57
58    pub fn from_headers(headers: HashMap<String, String>) -> Result<Self> {
59        let message_type = headers.get("type").chain_err(|| "Missing header `type`")?;
60        let message_definition = headers
61            .get("message_definition")
62            .chain_err(|| "Missing header `message_definition`")?;
63        Self::new(message_type, message_definition)
64    }
65
66    fn parse_msg(message_type: &str, message_src: &str) -> Result<Msg> {
67        let message_path = message_type
68            .try_into()
69            .chain_err(|| format!("Message type {} is invalid", message_type))?;
70        Msg::new(message_path, message_src)
71            .chain_err(|| format!("Failed to parse message {}", message_type))
72    }
73
74    fn parse_dependency(message_body: &str) -> Result<Msg> {
75        lazy_static! {
76            static ref RE_DESCRIPTOR_MSG_TYPE: regex::Regex =
77                regex::Regex::new(r#"^\s*MSG:\s*(\S+)\s*$"#).unwrap();
78        }
79        let message_body = message_body.trim();
80        let (message_type_line, message_src) = message_body
81            .split_once('\n')
82            .chain_err(|| "Message dependency is missing type declaration")?;
83        let cap = RE_DESCRIPTOR_MSG_TYPE
84            .captures(message_type_line)
85            .chain_err(|| format!("Failed to parse message type line `{}`", message_type_line))?;
86        let message_type = cap
87            .get(1)
88            .chain_err(|| format!("Failed to parse message type line `{}`", message_type_line))?;
89        Self::parse_msg(message_type.as_str(), message_src)
90    }
91
92    pub fn encode(&self, value: &MessageValue, mut w: impl io::Write) -> io::Result<()> {
93        self.encode_message(&self.msg, value, &mut w)
94    }
95
96    pub fn decode(&self, mut r: impl io::Read) -> io::Result<MessageValue> {
97        self.decode_message(&self.msg, &mut r)
98    }
99
100    fn get_dependency(&self, path: &MessagePath) -> io::Result<&Msg> {
101        self.dependencies.get(path).ok_or_else(|| {
102            io::Error::new(
103                io::ErrorKind::Other,
104                format!("Missing message dependency: {}", path),
105            )
106        })
107    }
108
109    fn encode_message(
110        &self,
111        msg: &Msg,
112        value: &MessageValue,
113        w: &mut impl io::Write,
114    ) -> io::Result<()> {
115        for field in msg.fields() {
116            match field.case() {
117                FieldCase::Const(_) => continue,
118                FieldCase::Unit => {
119                    let field_value = get_field(value, field.name())?;
120                    self.encode_field(msg.path(), field, field_value, w)?;
121                }
122                FieldCase::Vector => {
123                    let field_value = get_field(value, field.name())?;
124                    self.encode_field_array(msg.path(), field, field_value, None, w)?;
125                }
126                FieldCase::Array(l) => {
127                    let field_value = get_field(value, field.name())?;
128                    self.encode_field_array(msg.path(), field, field_value, Some(*l), w)?;
129                }
130            }
131        }
132        Ok(())
133    }
134
135    fn encode_field(
136        &self,
137        parent: &MessagePath,
138        field: &FieldInfo,
139        value: &Value,
140        w: &mut impl std::io::Write,
141    ) -> io::Result<()> {
142        match (field.datatype(), value) {
143            (DataType::Bool, Value::Bool(v)) => v.encode(w),
144            (DataType::I8(_), Value::I8(v)) => v.encode(w),
145            (DataType::I16, Value::I16(v)) => v.encode(w),
146            (DataType::I32, Value::I32(v)) => v.encode(w),
147            (DataType::I64, Value::I64(v)) => v.encode(w),
148            (DataType::U8(_), Value::U8(v)) => v.encode(w),
149            (DataType::U16, Value::U16(v)) => v.encode(w),
150            (DataType::U32, Value::U32(v)) => v.encode(w),
151            (DataType::U64, Value::U64(v)) => v.encode(w),
152            (DataType::F32, Value::F32(v)) => v.encode(w),
153            (DataType::F64, Value::F64(v)) => v.encode(w),
154            (DataType::String, Value::String(v)) => v.encode(w),
155            (DataType::Time, Value::Time(time)) => time.encode(w),
156            (DataType::Duration, Value::Duration(duration)) => duration.encode(w),
157            (DataType::LocalMessage(name), Value::Message(v)) => {
158                let path = parent.peer(name);
159                let dependency = self.get_dependency(&path)?;
160                self.encode_message(dependency, v, w)
161            }
162            (DataType::GlobalMessage(path), Value::Message(v)) => {
163                let dependency = self.get_dependency(path)?;
164                self.encode_message(dependency, v, w)
165            }
166            _ => Err(io::Error::new(
167                io::ErrorKind::InvalidData,
168                "Passed in dynamic data value does not match message format",
169            )),
170        }
171    }
172
173    fn encode_field_array(
174        &self,
175        parent: &MessagePath,
176        field: &FieldInfo,
177        value: &Value,
178        array_length: Option<usize>,
179        w: &mut impl std::io::Write,
180    ) -> io::Result<()> {
181        let value = match value {
182            Value::Array(v) => v,
183            Value::Bool(_)
184            | Value::I8(_)
185            | Value::I16(_)
186            | Value::I32(_)
187            | Value::I64(_)
188            | Value::U8(_)
189            | Value::U16(_)
190            | Value::U32(_)
191            | Value::U64(_)
192            | Value::F32(_)
193            | Value::F64(_)
194            | Value::String(_)
195            | Value::Time { .. }
196            | Value::Duration { .. }
197            | Value::Message(_) => {
198                return Err(io::Error::new(
199                    io::ErrorKind::InvalidData,
200                    "Passed in dynamic message field is not an array",
201                ));
202            }
203        };
204        match array_length {
205            Some(array_length) => {
206                if array_length != value.len() {
207                    return Err(io::Error::new(
208                        io::ErrorKind::InvalidData,
209                        "Passed in dynamic message array field has wrong length",
210                    ));
211                }
212            }
213            None => {
214                (value.len() as u32).encode(w.by_ref())?;
215            }
216        }
217        for value in value {
218            self.encode_field(parent, field, value, w.by_ref())?;
219        }
220        Ok(())
221    }
222
223    fn decode_message(&self, msg: &Msg, r: &mut impl io::Read) -> io::Result<MessageValue> {
224        let mut output = MessageValue::new();
225        for field in msg.fields() {
226            let value = match field.case() {
227                FieldCase::Const(_) => continue,
228                FieldCase::Unit => self.decode_field(msg.path(), field, r)?,
229                FieldCase::Vector => self.decode_field_array(msg.path(), field, None, r)?,
230                FieldCase::Array(l) => self.decode_field_array(msg.path(), field, Some(*l), r)?,
231            };
232            output.insert(field.name().into(), value);
233        }
234        Ok(output)
235    }
236
237    fn decode_field(
238        &self,
239        parent: &MessagePath,
240        field: &FieldInfo,
241        r: &mut impl io::Read,
242    ) -> io::Result<Value> {
243        Ok(match field.datatype() {
244            DataType::Bool => bool::decode(r)?.into(),
245            DataType::I8(_) => i8::decode(r)?.into(),
246            DataType::I16 => i16::decode(r)?.into(),
247            DataType::I32 => i32::decode(r)?.into(),
248            DataType::I64 => i64::decode(r)?.into(),
249            DataType::U8(_) => u8::decode(r)?.into(),
250            DataType::U16 => u16::decode(r)?.into(),
251            DataType::U32 => u32::decode(r)?.into(),
252            DataType::U64 => u64::decode(r)?.into(),
253            DataType::F32 => f32::decode(r)?.into(),
254            DataType::F64 => f64::decode(r)?.into(),
255            DataType::String => String::decode(r)?.into(),
256            DataType::Time => Time::decode(r)?.into(),
257            DataType::Duration => Duration::decode(r)?.into(),
258            DataType::LocalMessage(name) => {
259                let path = parent.peer(name);
260                let dependency = self.get_dependency(&path)?;
261                self.decode_message(dependency, r)?.into()
262            }
263            DataType::GlobalMessage(path) => {
264                let dependency = self.get_dependency(path)?;
265                self.decode_message(dependency, r)?.into()
266            }
267        })
268    }
269
270    fn decode_field_array(
271        &self,
272        parent: &MessagePath,
273        field: &FieldInfo,
274        array_length: Option<usize>,
275        r: &mut impl io::Read,
276    ) -> io::Result<Value> {
277        let array_length = match array_length {
278            Some(v) => v,
279            None => u32::decode(r.by_ref())? as usize,
280        };
281        // TODO: optimize by checking data type only once
282        (0..array_length)
283            .map(|_| self.decode_field(parent, field, r))
284            .collect()
285    }
286}