1use super::error::{ErrorKind, Result, ResultExt};
2use super::header::{decode, encode, match_field};
3use super::{Message, Topic};
4use crate::rosmsg::RosMsg;
5use crate::util::lossy_channel::{lossy_channel, LossyReceiver, LossySender};
6use crate::SubscriptionHandler;
7use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
8use crossbeam::channel::{bounded, select, Receiver, Sender, TrySendError};
9use log::error;
10use std::collections::{BTreeMap, BTreeSet, HashMap};
11use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
12use std::sync::Arc;
13use std::thread;
14
15enum DataStreamConnectionChange {
16 Connect(
17 usize,
18 LossySender<MessageInfo>,
19 Sender<HashMap<String, String>>,
20 ),
21 Disconnect(usize),
22}
23
24pub struct SubscriberRosConnection {
25 next_data_stream_id: usize,
26 data_stream_tx: Sender<DataStreamConnectionChange>,
27 publishers_stream: Sender<SocketAddr>,
28 topic: Topic,
29 pub connected_ids: BTreeSet<usize>,
30 pub connected_publishers: BTreeSet<String>,
31}
32
33impl SubscriberRosConnection {
34 pub fn new(
35 caller_id: &str,
36 topic: &str,
37 msg_definition: String,
38 msg_type: String,
39 md5sum: String,
40 ) -> SubscriberRosConnection {
41 let subscriber_connection_queue_size = 8;
42 let (data_stream_tx, data_stream_rx) = bounded(subscriber_connection_queue_size);
43 let publisher_connection_queue_size = 8;
44 let (pub_tx, pub_rx) = bounded(publisher_connection_queue_size);
45 let caller_id = String::from(caller_id);
46 let topic_name = String::from(topic);
47 thread::spawn({
48 let msg_type = msg_type.clone();
49 let md5sum = md5sum.clone();
50 move || {
51 join_connections(
52 data_stream_rx,
53 pub_rx,
54 &caller_id,
55 &topic_name,
56 &msg_definition,
57 &md5sum,
58 &msg_type,
59 )
60 }
61 });
62 let topic = Topic {
63 name: String::from(topic),
64 msg_type,
65 md5sum,
66 };
67 SubscriberRosConnection {
68 next_data_stream_id: 1,
69 data_stream_tx,
70 publishers_stream: pub_tx,
71 topic,
72 connected_ids: BTreeSet::new(),
73 connected_publishers: BTreeSet::new(),
74 }
75 }
76
77 pub fn add_subscriber<T, H>(&mut self, queue_size: usize, handler: H) -> usize
82 where
83 T: Message,
84 H: SubscriptionHandler<T>,
85 {
86 let data_stream_id = self.next_data_stream_id;
87 self.connected_ids.insert(data_stream_id);
88 self.next_data_stream_id += 1;
89 let (data_tx, data_rx) = lossy_channel(queue_size);
90 let (connection_tx, connection_rx) = bounded(8);
91 if self
92 .data_stream_tx
93 .send(DataStreamConnectionChange::Connect(
94 data_stream_id,
95 data_tx,
96 connection_tx,
97 ))
98 .is_err()
99 {
100 error!("Subscriber failed to connect to data stream");
102 }
103 thread::spawn(move || handle_data::<T, H>(data_rx, connection_rx, handler));
104 data_stream_id
105 }
106
107 pub fn remove_subscriber(&mut self, id: usize) {
108 self.connected_ids.remove(&id);
109 if self
110 .data_stream_tx
111 .send(DataStreamConnectionChange::Disconnect(id))
112 .is_err()
113 {
114 error!("Subscriber failed to disconnect from data stream");
116 }
117 }
118
119 pub fn has_subscribers(&self) -> bool {
120 !self.connected_ids.is_empty()
121 }
122
123 #[inline]
124 pub fn publisher_count(&self) -> usize {
125 self.connected_publishers.len()
126 }
127
128 #[inline]
129 pub fn publisher_uris(&self) -> Vec<String> {
130 self.connected_publishers.iter().cloned().collect()
131 }
132
133 #[allow(clippy::useless_conversion)]
134 pub fn connect_to<U: ToSocketAddrs>(
135 &mut self,
136 publisher: &str,
137 addresses: U,
138 ) -> std::io::Result<()> {
139 for address in addresses.to_socket_addrs()? {
140 self.publishers_stream
145 .send(address)
146 .expect("Connected thread died");
147 }
148 self.connected_publishers.insert(publisher.to_owned());
149 Ok(())
150 }
151
152 pub fn is_connected_to(&self, publisher: &str) -> bool {
153 self.connected_publishers.contains(publisher)
154 }
155
156 pub fn limit_publishers_to(&mut self, publishers: &BTreeSet<String>) {
157 let difference: Vec<String> = self
158 .connected_publishers
159 .difference(publishers)
160 .cloned()
161 .collect();
162 for item in difference {
163 self.connected_publishers.remove(&item);
164 }
165 }
166
167 pub fn get_topic(&self) -> &Topic {
168 &self.topic
169 }
170}
171
172fn handle_data<T, H>(
173 data: LossyReceiver<MessageInfo>,
174 connections: Receiver<HashMap<String, String>>,
175 mut handler: H,
176) where
177 T: Message,
178 H: SubscriptionHandler<T>,
179{
180 loop {
181 select! {
182 recv(data.kill_rx.kill_rx) -> _ => break,
183 recv(data.data_rx) -> msg => match msg {
184 Err(_) => break,
185 Ok(buffer) => match RosMsg::decode_slice(&buffer.data) {
186 Ok(value) => handler.message(value, &buffer.caller_id),
187 Err(err) => error!("Failed to decode message: {}", err),
188 },
189 },
190 recv(connections) -> msg => match msg {
191 Err(_) => break,
192 Ok(conn) => handler.connection(conn),
193 },
194 }
195 }
196}
197
198fn join_connections(
199 subscribers: Receiver<DataStreamConnectionChange>,
200 publishers: Receiver<SocketAddr>,
201 caller_id: &str,
202 topic: &str,
203 msg_definition: &str,
204 md5sum: &str,
205 msg_type: &str,
206) {
207 type Sub = (LossySender<MessageInfo>, Sender<HashMap<String, String>>);
208 let mut subs: BTreeMap<usize, Sub> = BTreeMap::new();
209 let mut existing_headers: Vec<HashMap<String, String>> = Vec::new();
210
211 let (data_tx, data_rx): (Sender<MessageInfo>, Receiver<MessageInfo>) = bounded(8);
212
213 loop {
215 select! {
216 recv(data_rx) -> msg => {
217 match msg {
218 Err(_) => break,
219 Ok(v) => for sub in subs.values() {
220 if sub.0.try_send(v.clone()).is_err() {
221 error!("Failed to send data to subscriber");
222 }
223 }
224 }
225 }
226 recv(subscribers) -> msg => {
227 match msg {
228 Err(_) => break,
229 Ok(DataStreamConnectionChange::Connect(id, data, conn)) => {
230 for header in &existing_headers {
231 if conn.send(header.clone()).is_err() {
232 error!("Failed to send connection info for subscriber");
233 };
234 }
235 subs.insert(id, (data, conn));
236 }
237 Ok(DataStreamConnectionChange::Disconnect(id)) => {
238 if let Some((mut data, _)) = subs.remove(&id) {
239 if data.close().is_err() {
240 error!("Subscriber data stream to topic has already been killed");
241 }
242 }
243 }
244 }
245 }
246 recv(publishers) -> msg => {
247 match msg {
248 Err(_) => break,
249 Ok(publisher) => {
250 let result = join_connection(
251 &data_tx,
252 &publisher,
253 caller_id,
254 topic,
255 msg_definition,
256 md5sum,
257 msg_type,
258 )
259 .chain_err(|| ErrorKind::TopicConnectionFail(topic.into()));
260 match result {
261 Ok(headers) => {
262 for sub in subs.values() {
263 if sub.1.send(headers.clone()).is_err() {
264 error!("Failed to send connection info for subscriber");
265 }
266 }
267 existing_headers.push(headers);
268 }
269 Err(err) => {
270 let info = err
271 .iter()
272 .map(|v| format!("{}", v))
273 .collect::<Vec<_>>()
274 .join("\nCaused by:");
275 error!("{}", info);
276 }
277 }
278 }
279 }
280 }
281 }
282 }
283}
284
285fn join_connection(
286 data_stream: &Sender<MessageInfo>,
287 publisher: &SocketAddr,
288 caller_id: &str,
289 topic: &str,
290 msg_definition: &str,
291 md5sum: &str,
292 msg_type: &str,
293) -> Result<HashMap<String, String>> {
294 let mut stream = TcpStream::connect(publisher)?;
295 let headers = exchange_headers::<_>(
296 &mut stream,
297 caller_id,
298 topic,
299 msg_definition,
300 md5sum,
301 msg_type,
302 )?;
303 let pub_caller_id = headers.get("callerid").cloned();
304 let target = data_stream.clone();
305 thread::spawn(move || {
306 let pub_caller_id = Arc::new(pub_caller_id.unwrap_or_default());
307 while let Ok(buffer) = package_to_vector(&mut stream) {
308 if let Err(TrySendError::Disconnected(_)) =
309 target.try_send(MessageInfo::new(Arc::clone(&pub_caller_id), buffer))
310 {
311 break;
314 }
315 }
316 });
317 Ok(headers)
318}
319
320fn write_request<U: std::io::Write>(
321 mut stream: &mut U,
322 caller_id: &str,
323 topic: &str,
324 msg_definition: &str,
325 md5sum: &str,
326 msg_type: &str,
327) -> Result<()> {
328 let mut fields = HashMap::<String, String>::new();
329 fields.insert(String::from("message_definition"), msg_definition.into());
330 fields.insert(String::from("callerid"), caller_id.into());
331 fields.insert(String::from("topic"), topic.into());
332 fields.insert(String::from("md5sum"), md5sum.into());
333 fields.insert(String::from("type"), msg_type.into());
334 encode(&mut stream, &fields)?;
335 Ok(())
336}
337
338fn read_response<U: std::io::Read>(
339 mut stream: &mut U,
340 md5sum: &str,
341 msg_type: &str,
342) -> Result<HashMap<String, String>> {
343 let fields = decode(&mut stream)?;
344 if md5sum != "*" {
345 match_field(&fields, "md5sum", md5sum)?;
346 }
347 if msg_type != "*" {
348 match_field(&fields, "type", msg_type)?;
349 }
350 Ok(fields)
351}
352
353fn exchange_headers<U>(
354 stream: &mut U,
355 caller_id: &str,
356 topic: &str,
357 msg_definition: &str,
358 md5sum: &str,
359 msg_type: &str,
360) -> Result<HashMap<String, String>>
361where
362 U: std::io::Write + std::io::Read,
363{
364 write_request::<U>(stream, caller_id, topic, msg_definition, md5sum, msg_type)?;
365 read_response::<U>(stream, md5sum, msg_type)
366}
367
/// Reads one length-prefixed package from `stream`.
///
/// The wire format is a little-endian `u32` payload length followed by that
/// many payload bytes. The returned buffer contains the four length bytes
/// followed by the payload, i.e. exactly the bytes consumed from `stream`.
///
/// The buffer is zero-initialised before reading, so no uninitialised memory
/// is ever exposed to the reader and no `unsafe` code is needed.
///
/// # Errors
///
/// * Any error from the underlying reader, including `UnexpectedEof` when the
///   stream ends before the length prefix or the announced payload is complete.
/// * `InvalidData` when the announced length plus the prefix does not fit in
///   `usize` (only possible on targets narrower than 64 bits).
#[inline]
fn package_to_vector<R: std::io::Read>(stream: &mut R) -> std::io::Result<Vec<u8>> {
    const PREFIX_LEN: usize = std::mem::size_of::<u32>();

    let mut prefix = [0u8; PREFIX_LEN];
    stream.read_exact(&mut prefix)?;
    let length = u32::from_le_bytes(prefix);

    let num_bytes = (length as usize).checked_add(PREFIX_LEN).ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            "package length does not fit in memory",
        )
    })?;

    // The prefix is copied verbatim so the output mirrors the wire bytes.
    let mut out = vec![0u8; num_bytes];
    out[..PREFIX_LEN].copy_from_slice(&prefix);
    stream.read_exact(&mut out[PREFIX_LEN..])?;

    Ok(out)
}
398
/// A raw message buffer paired with the caller id of the publisher it came from.
///
/// The caller id is reference-counted, so cloning a `MessageInfo` shares the
/// id string instead of copying it.
#[derive(Clone)]
struct MessageInfo {
    /// Caller id of the sending publisher.
    caller_id: Arc<String>,
    /// Message bytes as produced by `package_to_vector`.
    data: Vec<u8>,
}

impl MessageInfo {
    /// Bundles `data` with the `caller_id` of its sender.
    fn new(caller_id: Arc<String>, data: Vec<u8>) -> Self {
        MessageInfo { caller_id, data }
    }
}
410
#[cfg(test)]
mod tests {
    use super::*;

    static FAILED_TO_READ_WRITE_VECTOR: &str = "Failed to read or write from vector";

    /// Wraps `bytes` in a cursor so they can be consumed through `Read`.
    fn cursor_over(bytes: &[u8]) -> std::io::Cursor<&[u8]> {
        std::io::Cursor::new(bytes)
    }

    #[test]
    fn package_to_vector_creates_right_buffer_from_reader() {
        let mut reader = cursor_over(&[7, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7]);
        let package = package_to_vector(&mut reader).expect(FAILED_TO_READ_WRITE_VECTOR);
        assert_eq!(package, [7, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7]);
    }

    #[test]
    fn package_to_vector_respects_provided_length() {
        // Bytes past the announced length must not end up in the package.
        let mut reader = cursor_over(&[7, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
        let package = package_to_vector(&mut reader).expect(FAILED_TO_READ_WRITE_VECTOR);
        assert_eq!(package, [7, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7]);
    }

    #[test]
    fn package_to_vector_fails_if_stream_is_shorter_than_annotated() {
        let mut reader = cursor_over(&[7, 0, 0, 0, 1, 2, 3, 4, 5]);
        assert!(package_to_vector(&mut reader).is_err());
    }

    #[test]
    fn package_to_vector_fails_leaves_cursor_at_end_of_reading() {
        // Two back-to-back packages: the second read must start exactly
        // where the first one stopped.
        let mut reader =
            cursor_over(&[7, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 4, 0, 0, 0, 11, 12, 13, 14]);
        let first = package_to_vector(&mut reader).expect(FAILED_TO_READ_WRITE_VECTOR);
        assert_eq!(first, [7, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7]);
        let second = package_to_vector(&mut reader).expect(FAILED_TO_READ_WRITE_VECTOR);
        assert_eq!(second, [4, 0, 0, 0, 11, 12, 13, 14]);
    }
}