rosrust/tcpros/subscriber.rs

use super::error::{ErrorKind, Result, ResultExt};
use super::header::{decode, encode, match_field};
use super::{Message, Topic};
use crate::rosmsg::RosMsg;
use crate::util::lossy_channel::{lossy_channel, LossyReceiver, LossySender};
use crate::SubscriptionHandler;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use crossbeam::channel::{bounded, select, Receiver, Sender, TrySendError};
use log::error;
use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::sync::Arc;
use std::thread;

enum DataStreamConnectionChange {
    Connect(
        usize,
        LossySender<MessageInfo>,
        Sender<HashMap<String, String>>,
    ),
    Disconnect(usize),
}

pub struct SubscriberRosConnection {
    next_data_stream_id: usize,
    data_stream_tx: Sender<DataStreamConnectionChange>,
    publishers_stream: Sender<SocketAddr>,
    topic: Topic,
    pub connected_ids: BTreeSet<usize>,
    pub connected_publishers: BTreeSet<String>,
}

impl SubscriberRosConnection {
    pub fn new(
        caller_id: &str,
        topic: &str,
        msg_definition: String,
        msg_type: String,
        md5sum: String,
    ) -> SubscriberRosConnection {
        let subscriber_connection_queue_size = 8;
        let (data_stream_tx, data_stream_rx) = bounded(subscriber_connection_queue_size);
        let publisher_connection_queue_size = 8;
        let (pub_tx, pub_rx) = bounded(publisher_connection_queue_size);
        let caller_id = String::from(caller_id);
        let topic_name = String::from(topic);
        thread::spawn({
            let msg_type = msg_type.clone();
            let md5sum = md5sum.clone();
            move || {
                join_connections(
                    data_stream_rx,
                    pub_rx,
                    &caller_id,
                    &topic_name,
                    &msg_definition,
                    &md5sum,
                    &msg_type,
                )
            }
        });
        let topic = Topic {
            name: String::from(topic),
            msg_type,
            md5sum,
        };
        SubscriberRosConnection {
            next_data_stream_id: 1,
            data_stream_tx,
            publishers_stream: pub_tx,
            topic,
            connected_ids: BTreeSet::new(),
            connected_publishers: BTreeSet::new(),
        }
    }

    // TODO: allow synchronous handling for subscribers
    // This creates a new thread to call on_message. Next API change should
    // allow subscribing with either callback or inline handler of the queue.
    // The queue is lossy, so it wouldn't be blocking.
    pub fn add_subscriber<T, H>(&mut self, queue_size: usize, handler: H) -> usize
    where
        T: Message,
        H: SubscriptionHandler<T>,
    {
        let data_stream_id = self.next_data_stream_id;
        self.connected_ids.insert(data_stream_id);
        self.next_data_stream_id += 1;
        let (data_tx, data_rx) = lossy_channel(queue_size);
        let (connection_tx, connection_rx) = bounded(8);
        if self
            .data_stream_tx
            .send(DataStreamConnectionChange::Connect(
                data_stream_id,
                data_tx,
                connection_tx,
            ))
            .is_err()
        {
            // TODO: we might want to panic here
            error!("Subscriber failed to connect to data stream");
        }
        thread::spawn(move || handle_data::<T, H>(data_rx, connection_rx, handler));
        data_stream_id
    }

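    // Illustrative only: a rough sketch of how a handler plugs into this connection,
    // assuming `SubscriptionHandler<T>` exposes the `connection`/`message` callbacks
    // that `handle_data` below invokes; `std_msgs::String`, `definition`, `md5sum`,
    // the URI, and the socket address are placeholders.
    //
    //     struct PrintHandler;
    //
    //     impl SubscriptionHandler<std_msgs::String> for PrintHandler {
    //         fn connection(&mut self, headers: HashMap<String, String>) {
    //             println!("new publisher: {:?}", headers.get("callerid"));
    //         }
    //         fn message(&mut self, message: std_msgs::String, callerid: &str) {
    //             println!("[{}] {}", callerid, message.data);
    //         }
    //     }
    //
    //     let mut conn = SubscriberRosConnection::new(
    //         "/listener", "/chatter", definition, "std_msgs/String".into(), md5sum);
    //     let id = conn.add_subscriber(100, PrintHandler);
    //     conn.connect_to("http://publisher_host:11311/", ("publisher_host", 45678))?;
    //     // ... and once the subscription is no longer needed:
    //     conn.remove_subscriber(id);
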
    pub fn remove_subscriber(&mut self, id: usize) {
        self.connected_ids.remove(&id);
        if self
            .data_stream_tx
            .send(DataStreamConnectionChange::Disconnect(id))
            .is_err()
        {
            // TODO: we might want to panic here
            error!("Subscriber failed to disconnect from data stream");
        }
    }

    pub fn has_subscribers(&self) -> bool {
        !self.connected_ids.is_empty()
    }

    #[inline]
    pub fn publisher_count(&self) -> usize {
        self.connected_publishers.len()
    }

    #[inline]
    pub fn publisher_uris(&self) -> Vec<String> {
        self.connected_publishers.iter().cloned().collect()
    }

    #[allow(clippy::useless_conversion)]
    pub fn connect_to<U: ToSocketAddrs>(
        &mut self,
        publisher: &str,
        addresses: U,
    ) -> std::io::Result<()> {
        for address in addresses.to_socket_addrs()? {
            // This should never fail, so it's safe to unwrap.
            // Failure could only be caused by the join_connections
            // thread not running, which only happens after the
            // Subscriber has been dropped.
            self.publishers_stream
                .send(address)
                .expect("Connected thread died");
        }
        self.connected_publishers.insert(publisher.to_owned());
        Ok(())
    }

    pub fn is_connected_to(&self, publisher: &str) -> bool {
        self.connected_publishers.contains(publisher)
    }

    pub fn limit_publishers_to(&mut self, publishers: &BTreeSet<String>) {
        let difference: Vec<String> = self
            .connected_publishers
            .difference(publishers)
            .cloned()
            .collect();
        for item in difference {
            self.connected_publishers.remove(&item);
        }
    }

    pub fn get_topic(&self) -> &Topic {
        &self.topic
    }
}

fn handle_data<T, H>(
    data: LossyReceiver<MessageInfo>,
    connections: Receiver<HashMap<String, String>>,
    mut handler: H,
) where
    T: Message,
    H: SubscriptionHandler<T>,
{
    loop {
        select! {
            // The kill channel fires when the matching LossySender gets closed,
            // e.g. after remove_subscriber.
            recv(data.kill_rx.kill_rx) -> _ => break,
            recv(data.data_rx) -> msg => match msg {
                Err(_) => break,
                Ok(buffer) => match RosMsg::decode_slice(&buffer.data) {
                    Ok(value) => handler.message(value, &buffer.caller_id),
                    Err(err) => error!("Failed to decode message: {}", err),
                },
            },
            recv(connections) -> msg => match msg {
                Err(_) => break,
                Ok(conn) => handler.connection(conn),
            },
        }
    }
}

fn join_connections(
    subscribers: Receiver<DataStreamConnectionChange>,
    publishers: Receiver<SocketAddr>,
    caller_id: &str,
    topic: &str,
    msg_definition: &str,
    md5sum: &str,
    msg_type: &str,
) {
    type Sub = (LossySender<MessageInfo>, Sender<HashMap<String, String>>);
    let mut subs: BTreeMap<usize, Sub> = BTreeMap::new();
    let mut existing_headers: Vec<HashMap<String, String>> = Vec::new();

    let (data_tx, data_rx): (Sender<MessageInfo>, Receiver<MessageInfo>) = bounded(8);

    // Ends when the subscriber or publisher sender is dropped, which happens when the Subscriber is destroyed
    loop {
        select! {
            recv(data_rx) -> msg => {
                match msg {
                    Err(_) => break,
                    Ok(v) => for sub in subs.values() {
                        if sub.0.try_send(v.clone()).is_err() {
                            error!("Failed to send data to subscriber");
                        }
                    }
                }
            }
            recv(subscribers) -> msg => {
                match msg {
                    Err(_) => break,
                    Ok(DataStreamConnectionChange::Connect(id, data, conn)) => {
                        for header in &existing_headers {
                            if conn.send(header.clone()).is_err() {
                                error!("Failed to send connection info for subscriber");
                            };
                        }
                        subs.insert(id, (data, conn));
                    }
                    Ok(DataStreamConnectionChange::Disconnect(id)) => {
                        if let Some((mut data, _)) = subs.remove(&id) {
                            if data.close().is_err() {
                                error!("Subscriber data stream to topic has already been killed");
                            }
                        }
                    }
                }
            }
            recv(publishers) -> msg => {
                match msg {
                    Err(_) => break,
                    Ok(publisher) => {
                        let result = join_connection(
                            &data_tx,
                            &publisher,
                            caller_id,
                            topic,
                            msg_definition,
                            md5sum,
                            msg_type,
                        )
                        .chain_err(|| ErrorKind::TopicConnectionFail(topic.into()));
                        match result {
                            Ok(headers) => {
                                for sub in subs.values() {
                                    if sub.1.send(headers.clone()).is_err() {
                                        error!("Failed to send connection info for subscriber");
                                    }
                                }
                                existing_headers.push(headers);
                            }
                            Err(err) => {
                                let info = err
                                    .iter()
                                    .map(|v| format!("{}", v))
                                    .collect::<Vec<_>>()
                                    .join("\nCaused by:");
                                error!("{}", info);
                            }
                        }
                    }
                }
            }
        }
    }
}

fn join_connection(
    data_stream: &Sender<MessageInfo>,
    publisher: &SocketAddr,
    caller_id: &str,
    topic: &str,
    msg_definition: &str,
    md5sum: &str,
    msg_type: &str,
) -> Result<HashMap<String, String>> {
    let mut stream = TcpStream::connect(publisher)?;
    let headers = exchange_headers::<_>(
        &mut stream,
        caller_id,
        topic,
        msg_definition,
        md5sum,
        msg_type,
    )?;
    let pub_caller_id = headers.get("callerid").cloned();
    let target = data_stream.clone();
    thread::spawn(move || {
        let pub_caller_id = Arc::new(pub_caller_id.unwrap_or_default());
        while let Ok(buffer) = package_to_vector(&mut stream) {
            if let Err(TrySendError::Disconnected(_)) =
                target.try_send(MessageInfo::new(Arc::clone(&pub_caller_id), buffer))
            {
                // The data receiver has been dropped after the kill signal
                // sent when the Subscriber was destroyed
                break;
            }
        }
    });
    Ok(headers)
}

fn write_request<U: std::io::Write>(
    mut stream: &mut U,
    caller_id: &str,
    topic: &str,
    msg_definition: &str,
    md5sum: &str,
    msg_type: &str,
) -> Result<()> {
    let mut fields = HashMap::<String, String>::new();
    fields.insert(String::from("message_definition"), msg_definition.into());
    fields.insert(String::from("callerid"), caller_id.into());
    fields.insert(String::from("topic"), topic.into());
    fields.insert(String::from("md5sum"), md5sum.into());
    fields.insert(String::from("type"), msg_type.into());
    encode(&mut stream, &fields)?;
    Ok(())
}

fn read_response<U: std::io::Read>(
    mut stream: &mut U,
    md5sum: &str,
    msg_type: &str,
) -> Result<HashMap<String, String>> {
    let fields = decode(&mut stream)?;
    if md5sum != "*" {
        match_field(&fields, "md5sum", md5sum)?;
    }
    if msg_type != "*" {
        match_field(&fields, "type", msg_type)?;
    }
    Ok(fields)
}

fn exchange_headers<U>(
    stream: &mut U,
    caller_id: &str,
    topic: &str,
    msg_definition: &str,
    md5sum: &str,
    msg_type: &str,
) -> Result<HashMap<String, String>>
where
    U: std::io::Write + std::io::Read,
{
    write_request::<U>(stream, caller_id, topic, msg_definition, md5sum, msg_type)?;
    read_response::<U>(stream, md5sum, msg_type)
}
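
// For orientation, the request that write_request assembles for a std_msgs/String
// subscription would carry fields along these lines (values illustrative only):
//
//     callerid           -> "/listener"
//     topic              -> "/chatter"
//     type               -> "std_msgs/String"
//     md5sum             -> "992ce8a1687cec8c8bd883ec73ca41d1"
//     message_definition -> "string data\n"
//
// read_response then checks the publisher's reply for a matching md5sum and type,
// unless the subscriber itself advertises the wildcard "*".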

#[inline]
fn package_to_vector<R: std::io::Read>(stream: &mut R) -> std::io::Result<Vec<u8>> {
    let length = stream.read_u32::<LittleEndian>()?;
    let u32_size = std::mem::size_of::<u32>();
    let num_bytes = length as usize + u32_size;

    // Allocate memory of the proper size for the incoming message. We
    // do not initialize the memory to zero here (as would be safe)
    // because it is expensive and ultimately unnecessary. We know the
    // length of the message and if the length is incorrect, the
    // stream reading functions will bail with an Error rather than
    // leaving memory uninitialized.
    let mut out = Vec::<u8>::with_capacity(num_bytes);

    let out_ptr = out.as_mut_ptr();
    // Write the already-read length prefix into the start of the buffer.
    std::io::Cursor::new(unsafe { std::slice::from_raw_parts_mut(out_ptr as *mut u8, u32_size) })
        .write_u32::<LittleEndian>(length)?;

    // Read data from stream.
    let read_buf = unsafe { std::slice::from_raw_parts_mut(out_ptr as *mut u8, num_bytes) };
    stream.read_exact(&mut read_buf[u32_size..])?;

    // Don't drop the original Vec which has size==0 and instead use
    // its memory to initialize a new Vec with size == capacity == num_bytes.
    std::mem::forget(out);

    // Return the new, now full and "safely" initialized Vec.
    Ok(unsafe { Vec::from_raw_parts(out_ptr, num_bytes, num_bytes) })
}
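
// For reference, a fully safe (but zero-initializing) version of package_to_vector
// could look like the hypothetical sketch below; the unsafe path above exists purely
// to avoid paying for the zeroing.
//
//     fn package_to_vector_zeroed<R: std::io::Read>(stream: &mut R) -> std::io::Result<Vec<u8>> {
//         let length = stream.read_u32::<LittleEndian>()?;
//         let u32_size = std::mem::size_of::<u32>();
//         let mut out = vec![0u8; length as usize + u32_size];
//         // Re-encode the length prefix, then read the payload in after it.
//         std::io::Cursor::new(&mut out[..u32_size]).write_u32::<LittleEndian>(length)?;
//         stream.read_exact(&mut out[u32_size..])?;
//         Ok(out)
//     }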

#[derive(Clone)]
struct MessageInfo {
    caller_id: Arc<String>,
    data: Vec<u8>,
}

impl MessageInfo {
    fn new(caller_id: Arc<String>, data: Vec<u8>) -> Self {
        Self { caller_id, data }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    static FAILED_TO_READ_WRITE_VECTOR: &str = "Failed to read or write from vector";

    #[test]
    fn package_to_vector_creates_right_buffer_from_reader() {
        let input = [7, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7];
        let data =
            package_to_vector(&mut std::io::Cursor::new(input)).expect(FAILED_TO_READ_WRITE_VECTOR);
        assert_eq!(data, [7, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7]);
    }

    #[test]
    fn package_to_vector_respects_provided_length() {
        let input = [7, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let data =
            package_to_vector(&mut std::io::Cursor::new(input)).expect(FAILED_TO_READ_WRITE_VECTOR);
        assert_eq!(data, [7, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7]);
    }

    #[test]
    fn package_to_vector_fails_if_stream_is_shorter_than_annotated() {
        let input = [7, 0, 0, 0, 1, 2, 3, 4, 5];
        package_to_vector(&mut std::io::Cursor::new(input)).unwrap_err();
    }

    #[test]
    fn package_to_vector_leaves_cursor_at_end_of_reading() {
        let input = [7, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 4, 0, 0, 0, 11, 12, 13, 14];
        let mut cursor = std::io::Cursor::new(input);
        let data = package_to_vector(&mut cursor).expect(FAILED_TO_READ_WRITE_VECTOR);
        assert_eq!(data, [7, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7]);
        let data = package_to_vector(&mut cursor).expect(FAILED_TO_READ_WRITE_VECTOR);
        assert_eq!(data, [4, 0, 0, 0, 11, 12, 13, 14]);
    }
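
    // A possible additional case, not part of the original suite: an empty payload
    // should still round-trip its four-byte length prefix (illustrative, following
    // the same pattern as the tests above).
    #[test]
    fn package_to_vector_handles_empty_payload() {
        let input = [0, 0, 0, 0];
        let data =
            package_to_vector(&mut std::io::Cursor::new(input)).expect(FAILED_TO_READ_WRITE_VECTOR);
        assert_eq!(data, [0, 0, 0, 0]);
    }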
}