rosrust_codegen/
helpers.rs

1use crate::alerts::MESSAGE_NAME_SHOULD_BE_VALID;
2use crate::error::{ErrorKind, Result, ResultExt};
3use crate::msg::{Msg, Srv};
4use error_chain::bail;
5use lazy_static::lazy_static;
6use ros_message::MessagePath;
7use std::collections::{HashMap, HashSet, LinkedList};
8use std::fs::{read_dir, File};
9use std::path::{Path, PathBuf};
10
11pub fn calculate_md5(message_map: &MessageMap) -> Result<HashMap<MessagePath, String>> {
12    let mut representations = HashMap::<MessagePath, String>::new();
13    let mut hashes = HashMap::<MessagePath, String>::new();
14    while hashes.len() < message_map.messages.len() {
15        let mut changed = false;
16        for (key, value) in &message_map.messages {
17            if hashes.contains_key(key) {
18                continue;
19            }
20            if let Ok(answer) = value.get_md5_representation(&hashes) {
21                hashes.insert(key.clone(), calculate_md5_from_representation(&answer));
22                representations.insert(key.clone(), answer);
23                changed = true;
24            }
25        }
26        if !changed {
27            break;
28        }
29    }
30    for message in message_map.services.keys() {
31        let key_req = message.peer(format!("{}Req", message.name()));
32        let key_res = message.peer(format!("{}Res", message.name()));
33        let req = match representations.get(&key_req) {
34            Some(v) => v,
35            None => bail!("Message map does not contain all needed elements"),
36        };
37        let res = match representations.get(&key_res) {
38            Some(v) => v,
39            None => bail!("Message map does not contain all needed elements"),
40        };
41        hashes.insert(
42            message.clone(),
43            calculate_md5_from_representation(&format!("{}{}", req, res)),
44        );
45    }
46    if hashes.len() < message_map.messages.len() + message_map.services.len() {
47        bail!("Message map does not contain all needed elements");
48    }
49    Ok(hashes)
50}
51
52fn calculate_md5_from_representation(v: &str) -> String {
53    use md5::{Digest, Md5};
54    let mut hasher = Md5::new();
55    hasher.update(v);
56    hex::encode(hasher.finalize())
57}
58
59pub fn generate_message_definition<S: std::hash::BuildHasher>(
60    message_map: &HashMap<MessagePath, Msg, S>,
61    message: &Msg,
62) -> Result<String> {
63    let mut handled_messages = HashSet::<MessagePath>::new();
64    let mut result = message.0.source().to_owned();
65    let mut pending = message
66        .dependencies()
67        .into_iter()
68        .collect::<LinkedList<_>>();
69    while let Some(value) = pending.pop_front() {
70        if handled_messages.contains(&value) {
71            continue;
72        }
73        handled_messages.insert(value.clone());
74        result += "\n\n========================================";
75        result += "========================================";
76        result += &format!("\nMSG: {}\n", value);
77        let message = match message_map.get(&value) {
78            Some(msg) => msg,
79            None => bail!("Message map does not contain all needed elements"),
80        };
81        for dependency in message.dependencies() {
82            pending.push_back(dependency);
83        }
84        result += message.0.source();
85    }
86    result += "\n";
87    Ok(result)
88}
89
/// Messages and services keyed by their ROS path, as collected by [`get_message_map`].
pub struct MessageMap {
    /// Loaded messages. When built by [`get_message_map`], this also holds the request and
    /// response messages of every loaded service.
    pub messages: HashMap<MessagePath, Msg>,
    /// Loaded services. Their request/response messages live in `messages`.
    pub services: HashMap<MessagePath, Srv>,
}
94
95pub fn get_message_map(
96    ignore_bad_messages: bool,
97    folders: &[&str],
98    message_paths: &[MessagePath],
99) -> Result<MessageMap> {
100    let mut message_locations = HashMap::new();
101    let mut service_locations = HashMap::new();
102
103    let mut messages_and_services = vec![];
104    for folder in folders {
105        messages_and_services.append(&mut find_all_messages_and_services(Path::new(folder)));
106    }
107
108    for (message_path, file_path, message_type) in messages_and_services {
109        match message_type {
110            MessageType::Message => message_locations.insert(message_path, file_path),
111            MessageType::Service => service_locations.insert(message_path, file_path),
112        };
113    }
114
115    let mut messages = HashMap::new();
116    let mut services = HashMap::new();
117    let mut pending = message_paths.to_vec();
118    while let Some(message_path) = pending.pop() {
119        if messages.contains_key(&message_path) {
120            continue;
121        }
122        match get_message_or_service(
123            ignore_bad_messages,
124            folders,
125            &message_locations,
126            &service_locations,
127            message_path,
128        )? {
129            MessageCase::Message(message) => {
130                for dependency in message.dependencies() {
131                    pending.push(dependency);
132                }
133                messages.insert(message.0.path().clone(), message);
134            }
135            MessageCase::Service(service, req, res) => {
136                for dependency in req.dependencies() {
137                    pending.push(dependency);
138                }
139                for dependency in res.dependencies() {
140                    pending.push(dependency);
141                }
142                messages.insert(req.0.path().clone(), req);
143                messages.insert(res.0.path().clone(), res);
144                services.insert(service.path.clone(), service);
145            }
146        }
147    }
148    Ok(MessageMap { messages, services })
149}
150
/// Kind of ROS interface definition file found on disk.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum MessageType {
    /// A message definition (`.msg` file).
    Message,
    /// A service definition (`.srv` file).
    Service,
}
155
156fn find_all_messages_and_services(root: &Path) -> Vec<(MessagePath, PathBuf, MessageType)> {
157    if !root.is_dir() {
158        return identify_message_or_service(root).into_iter().collect();
159    }
160    let mut items = vec![];
161    if let Ok(children) = read_dir(root) {
162        for child in children.filter_map(|v| v.ok()) {
163            items.append(&mut find_all_messages_and_services(&child.path()));
164        }
165    }
166    items
167}
168
169fn identify_message_or_service(filename: &Path) -> Option<(MessagePath, PathBuf, MessageType)> {
170    let extension = filename.extension()?;
171    let message = filename.file_stem()?;
172    let parent = filename.parent()?;
173    let grandparent = parent.parent()?;
174    let package = grandparent.file_name()?;
175    if Some(extension) != parent.file_name() {
176        return None;
177    }
178    let message_type = match extension.to_str() {
179        Some("msg") => MessageType::Message,
180        Some("srv") => MessageType::Service,
181        _ => return None,
182    };
183    Some((
184        MessagePath::new(package.to_str()?, message.to_str()?).ok()?,
185        filename.into(),
186        message_type,
187    ))
188}
189
/// Outcome of resolving a single message path in `get_message_or_service`.
enum MessageCase {
    /// A plain message.
    Message(Msg),
    /// A service followed by its request message and then its response message.
    Service(Srv, Msg, Msg),
}
194
// Message definitions compiled into this crate. `get_message_or_service` falls back to these
// only after no `.msg`/`.srv` file for the path could be opened.
lazy_static! {
    static ref IN_MEMORY_MESSAGES: HashMap<MessagePath, &'static str> =
        generate_in_memory_messages();
}
199
200fn generate_in_memory_messages() -> HashMap<MessagePath, &'static str> {
201    let mut output = HashMap::new();
202    output.insert(
203        MessagePath::new("rosgraph_msgs", "Clock").expect(MESSAGE_NAME_SHOULD_BE_VALID),
204        include_str!("in_memory_messages/Clock.msg"),
205    );
206    output.insert(
207        MessagePath::new("rosgraph_msgs", "Log").expect(MESSAGE_NAME_SHOULD_BE_VALID),
208        include_str!("in_memory_messages/Log.msg"),
209    );
210    output.insert(
211        MessagePath::new("std_msgs", "Header").expect(MESSAGE_NAME_SHOULD_BE_VALID),
212        include_str!("in_memory_messages/Header.msg"),
213    );
214    output
215}
216
217fn get_message_or_service(
218    ignore_bad_messages: bool,
219    folders: &[&str],
220    message_locations: &HashMap<MessagePath, PathBuf>,
221    service_locations: &HashMap<MessagePath, PathBuf>,
222    path: MessagePath,
223) -> Result<MessageCase> {
224    use std::io::Read;
225
226    if let Some(full_path) = message_locations.get(&path) {
227        if let Ok(mut f) = File::open(full_path) {
228            let mut contents = String::new();
229            f.read_to_string(&mut contents)
230                .chain_err(|| "Failed to read file to string!")?;
231            return create_message(path, &contents, ignore_bad_messages).map(MessageCase::Message);
232        }
233    }
234    if let Some(full_path) = service_locations.get(&path) {
235        if let Ok(mut f) = File::open(full_path) {
236            let mut contents = String::new();
237            f.read_to_string(&mut contents)
238                .chain_err(|| "Failed to read file to string!")?;
239
240            let service = ros_message::Srv::new(path.clone(), &contents)
241                .or_else(|err| {
242                    if ignore_bad_messages {
243                        ros_message::Srv::new(path.clone(), "\n\n---\n\n")
244                    } else {
245                        Err(err)
246                    }
247                })
248                .chain_err(|| "Failed to build service messages")?;
249
250            return Ok(MessageCase::Service(
251                Srv {
252                    path: service.path().clone(),
253                    source: service.source().into(),
254                },
255                Msg(service.request().clone()),
256                Msg(service.response().clone()),
257            ));
258        }
259    }
260    if let Some(contents) = IN_MEMORY_MESSAGES.get(&path) {
261        return Msg::new(path, contents).map(MessageCase::Message);
262    }
263    if ignore_bad_messages {
264        return Msg::new(path, "").map(MessageCase::Message);
265    }
266    bail!(ErrorKind::MessageNotFound(
267        path.to_string(),
268        folders.join("\n"),
269    ))
270}
271
272fn create_message(message: MessagePath, contents: &str, ignore_bad_messages: bool) -> Result<Msg> {
273    Msg::new(message.clone(), contents).or_else(|err| {
274        if ignore_bad_messages {
275            Msg::new(message, "")
276        } else {
277            Err(err)
278        }
279    })
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use ros_message::MessagePath;

    static FILEPATH: &str = "../msg_examples";

    /// Builds a message path that is known to be valid.
    fn msg_path(package: &str, name: &str) -> MessagePath {
        MessagePath::new(package, name).unwrap()
    }

    /// Loads `roots` (and their dependencies) from the example folder, failing on any error.
    fn load(roots: &[(&str, &str)]) -> MessageMap {
        let roots = roots
            .iter()
            .map(|&(package, name)| msg_path(package, name))
            .collect::<Vec<_>>();
        get_message_map(false, &[FILEPATH], &roots).unwrap()
    }

    /// Asserts that `map.messages` holds exactly the `expected` message paths.
    fn assert_messages(map: &MessageMap, expected: &[(&str, &str)]) {
        assert_eq!(map.messages.len(), expected.len());
        for &(package, name) in expected {
            assert!(
                map.messages.contains_key(&msg_path(package, name)),
                "missing {}/{}",
                package,
                name
            );
        }
    }

    #[test]
    fn get_message_map_fetches_leaf_message() {
        let message_map = load(&[("geometry_msgs", "Point")]);
        assert_messages(&message_map, &[("geometry_msgs", "Point")]);
    }

    #[test]
    fn get_message_map_fetches_message_and_dependencies() {
        let message_map = load(&[("geometry_msgs", "Pose")]);
        assert_messages(
            &message_map,
            &[
                ("geometry_msgs", "Point"),
                ("geometry_msgs", "Pose"),
                ("geometry_msgs", "Quaternion"),
            ],
        );
    }

    #[test]
    fn get_message_map_traverses_whole_dependency_tree() {
        let message_map = load(&[("geometry_msgs", "PoseStamped")]);
        assert_messages(
            &message_map,
            &[
                ("geometry_msgs", "Point"),
                ("geometry_msgs", "Pose"),
                ("geometry_msgs", "PoseStamped"),
                ("geometry_msgs", "Quaternion"),
                ("std_msgs", "Header"),
            ],
        );
    }

    #[test]
    fn get_message_map_traverses_all_passed_messages_dependency_tree() {
        let message_map = load(&[
            ("geometry_msgs", "PoseStamped"),
            ("sensor_msgs", "Imu"),
            ("rosgraph_msgs", "Clock"),
            ("rosgraph_msgs", "Log"),
        ]);
        assert_messages(
            &message_map,
            &[
                ("geometry_msgs", "Vector3"),
                ("geometry_msgs", "Point"),
                ("geometry_msgs", "Pose"),
                ("geometry_msgs", "PoseStamped"),
                ("geometry_msgs", "Quaternion"),
                ("sensor_msgs", "Imu"),
                ("std_msgs", "Header"),
                ("rosgraph_msgs", "Clock"),
                ("rosgraph_msgs", "Log"),
            ],
        );
    }

    #[test]
    fn calculate_md5_works() {
        let message_map = load(&[
            ("geometry_msgs", "PoseStamped"),
            ("sensor_msgs", "Imu"),
            ("rosgraph_msgs", "Clock"),
            ("rosgraph_msgs", "Log"),
        ]);
        let hashes = calculate_md5(&message_map).unwrap();
        assert_eq!(hashes.len(), 9);
        let expected = [
            ("geometry_msgs", "Vector3", "4a842b65f413084dc2b10fb484ea7f17"),
            ("geometry_msgs", "Point", "4a842b65f413084dc2b10fb484ea7f17"),
            ("geometry_msgs", "Quaternion", "a779879fadf0160734f906b8c19c7004"),
            ("geometry_msgs", "Pose", "e45d45a5a1ce597b249e23fb30fc871f"),
            ("std_msgs", "Header", "2176decaecbce78abc3b96ef049fabed"),
            ("geometry_msgs", "PoseStamped", "d3812c3cbc69362b77dc0b19b345f8f5"),
            ("sensor_msgs", "Imu", "6a62c6daae103f4ff57a132d6f95cec2"),
            ("rosgraph_msgs", "Clock", "a9c97c1d230cfc112e270351a944ee47"),
            ("rosgraph_msgs", "Log", "acffd30cd6b6de30f120938c17c593fb"),
        ];
        for &(package, name, md5) in &expected {
            assert_eq!(hashes[&msg_path(package, name)], md5, "{}/{}", package, name);
        }
    }

    #[test]
    fn generate_message_definition_works() {
        let messages = load(&[("geometry_msgs", "Vector3")]).messages;
        let definition = generate_message_definition(
            &messages,
            &messages[&msg_path("geometry_msgs", "Vector3")],
        )
        .unwrap();
        assert_eq!(
            definition,
            "# This represents a vector in free space. \n# It is only meant to represent \
             a direction. Therefore, it does not\n# make sense to apply a translation to \
             it (e.g., when applying a \n# generic rigid transformation to a Vector3, tf2 \
             will only apply the\n# rotation). If you want your data to be translatable \
             too, use the\n# geometry_msgs/Point message instead.\n\nfloat64 x\nfloat64 \
             y\nfloat64 z\n"
        );

        let messages = load(&[("geometry_msgs", "PoseStamped")]).messages;
        let definition = generate_message_definition(
            &messages,
            &messages[&msg_path("geometry_msgs", "PoseStamped")],
        )
        .unwrap();
        assert_eq!(
            definition,
            "# A Pose with reference coordinate frame and timestamp\n\
Header header\n\
Pose pose\n\
\n\
================================================================================\n\
MSG: std_msgs/Header\n\
# Standard metadata for higher-level stamped data types.\n\
# This is generally used to communicate timestamped data \n\
# in a particular coordinate frame.\n\
# \n\
# sequence ID: consecutively increasing ID \n\
uint32 seq\n\
#Two-integer timestamp that is expressed as:\n\
# * stamp.sec: seconds (stamp_secs) since epoch (in Python the variable is called 'secs')\n\
# * stamp.nsec: nanoseconds since stamp_secs (in Python the variable is called 'nsecs')\n\
# time-handling sugar is provided by the client library\n\
time stamp\n\
#Frame this data is associated with\n\
# 0: no frame\n\
# 1: global frame\n\
string frame_id\n\
\n\
================================================================================\n\
MSG: geometry_msgs/Pose\n\
# A representation of pose in free space, composed of position and orientation. \n\
Point position\n\
Quaternion orientation\n\
\n\
================================================================================\n\
MSG: geometry_msgs/Point\n\
# This contains the position of a point in free space\n\
float64 x\n\
float64 y\n\
float64 z\n\
\n\
================================================================================\n\
MSG: geometry_msgs/Quaternion\n\
# This represents an orientation in free space in quaternion form.\n\
\n\
float64 x\n\
float64 y\n\
float64 z\n\
float64 w\n\
"
        );
    }

    #[test]
    fn calculate_md5_works_for_services() {
        let message_map = load(&[
            ("diagnostic_msgs", "AddDiagnostics"),
            ("simple_srv", "Something"),
        ]);
        let hashes = calculate_md5(&message_map).unwrap();
        assert_eq!(hashes.len(), 11);
        assert_eq!(
            hashes[&msg_path("diagnostic_msgs", "AddDiagnostics")],
            "e6ac9bbde83d0d3186523c3687aecaee"
        );
        assert_eq!(
            hashes[&msg_path("simple_srv", "Something")],
            "63715c08716373d8624430cde1434192"
        );
    }

    #[test]
    fn parse_tricky_srv_files() {
        load(&[
            ("empty_srv", "Empty"),
            ("empty_req_srv", "EmptyRequest"),
            ("tricky_comment_srv", "TrickyComment"),
        ]);
    }
}