// rosrust/tcpros/publisher.rs

1use super::error::{ErrorKind, Result, ResultExt};
2use super::header;
3use super::util::streamfork::{fork, DataStream, TargetList};
4use super::util::tcpconnection;
5use super::{Message, Topic};
6use crate::util::FAILED_TO_LOCK;
7use crate::RawMessageDescription;
8use error_chain::bail;
9use log::error;
10use std::collections::HashMap;
11use std::net::{TcpListener, TcpStream, ToSocketAddrs};
12use std::sync::{atomic, Arc, Mutex};
13
/// Publishing side of a TCPROS topic connection.
///
/// Owns the fan-out data stream to all connected subscribers and records the
/// TCP port the acceptor is listening on. Dropping the publisher clears
/// `exists`, which signals the accept loop (see `Publisher::new`) to stop
/// taking new subscribers.
pub struct Publisher {
    /// Fan-out handle; messages sent here are forked to every subscriber.
    subscriptions: DataStream,
    /// Local port of the TCP listener accepting subscriber connections.
    pub port: u16,
    pub topic: Topic,
    /// Last message stored by a latching stream; written to each newly
    /// connecting subscriber (empty vector until a latched send happens).
    last_message: Arc<Mutex<Arc<Vec<u8>>>>,
    /// Default queue bound inherited by streams created from this publisher.
    queue_size: usize,
    /// Set to false on drop so the accept loop can shut down.
    exists: Arc<atomic::AtomicBool>,
}
22
impl Drop for Publisher {
    /// Flags the publisher as gone; the accept-loop closure in
    /// `Publisher::new` checks this flag and stops accepting new subscribers.
    fn drop(&mut self) {
        self.exists.store(false, atomic::Ordering::SeqCst);
    }
}
28
29fn match_headers(
30    fields: &HashMap<String, String>,
31    topic: &str,
32    message_description: &RawMessageDescription,
33) -> Result<()> {
34    header::match_field(fields, "md5sum", &message_description.md5sum)
35        .or_else(|e| header::match_field(fields, "md5sum", "*").or(Err(e)))?;
36    header::match_field(fields, "type", &message_description.msg_type)
37        .or_else(|e| header::match_field(fields, "type", "*").or(Err(e)))?;
38    header::match_field(fields, "topic", topic)?;
39    Ok(())
40}
41
42fn read_request<U: std::io::Read>(
43    mut stream: &mut U,
44    topic: &str,
45    message_description: &RawMessageDescription,
46) -> Result<String> {
47    let fields = header::decode(&mut stream)?;
48    match_headers(&fields, topic, message_description)?;
49    let caller_id = fields
50        .get("callerid")
51        .ok_or_else(|| ErrorKind::HeaderMissingField("callerid".into()))?;
52    Ok(caller_id.clone())
53}
54
55fn write_response<U: std::io::Write>(
56    mut stream: &mut U,
57    caller_id: &str,
58    topic: &str,
59    message_description: &RawMessageDescription,
60) -> Result<()> {
61    let mut fields = HashMap::<String, String>::new();
62    fields.insert(String::from("md5sum"), message_description.md5sum.clone());
63    fields.insert(String::from("type"), message_description.msg_type.clone());
64    fields.insert(String::from("topic"), String::from(topic));
65    fields.insert(String::from("callerid"), caller_id.into());
66    fields.insert(
67        String::from("message_definition"),
68        message_description.msg_definition.clone(),
69    );
70    header::encode(&mut stream, &fields)?;
71    Ok(())
72}
73
74fn exchange_headers<U>(
75    mut stream: &mut U,
76    topic: &str,
77    pub_caller_id: &str,
78    message_description: &RawMessageDescription,
79) -> Result<String>
80where
81    U: std::io::Write + std::io::Read,
82{
83    let caller_id = read_request(&mut stream, topic, message_description)?;
84    write_response(&mut stream, pub_caller_id, topic, message_description)?;
85    Ok(caller_id)
86}
87
/// Handles one freshly accepted subscriber connection.
///
/// Runs the TCPROS header handshake, writes the stored `last_message` bytes
/// to the new subscriber, and hands the stream over to `targets` so future
/// messages are forked to it. The returned feedback tells the accept loop
/// whether to keep listening for more subscribers.
fn process_subscriber<U>(
    topic: &str,
    mut stream: U,
    targets: &TargetList<U>,
    last_message: &Mutex<Arc<Vec<u8>>>,
    pub_caller_id: &str,
    message_description: &RawMessageDescription,
) -> tcpconnection::Feedback
where
    U: std::io::Read + std::io::Write + Send,
{
    let result = exchange_headers(&mut stream, topic, pub_caller_id, message_description)
        .chain_err(|| ErrorKind::TopicConnectionFail(topic.into()));
    let caller_id = match result {
        Ok(caller_id) => caller_id,
        Err(err) => {
            // A failed handshake only drops this connection; log the full
            // error chain and keep accepting other subscribers.
            let info = err
                .iter()
                .map(|v| format!("{}", v))
                .collect::<Vec<_>>()
                .join("\nCaused by:");
            error!("{}", info);
            return tcpconnection::Feedback::AcceptNextStream;
        }
    };

    // Replay the latched message to the new subscriber. For a non-latching
    // publisher this buffer stays empty (see `Publisher::new`), making the
    // write a no-op.
    if let Err(err) = stream.write_all(&last_message.lock().expect(FAILED_TO_LOCK)) {
        error!("{}", err);
        return tcpconnection::Feedback::AcceptNextStream;
    }

    // `add` takes ownership of the stream; an error here means the fork
    // thread's channel is gone, so there is no publisher left to serve.
    if targets.add(caller_id, stream).is_err() {
        // The TCP listener gets shut down when streamfork's thread deallocates.
        // This happens only when all the corresponding publisher streams get deallocated,
        // causing streamfork's data channel to shut down
        return tcpconnection::Feedback::StopAccepting;
    }

    tcpconnection::Feedback::AcceptNextStream
}
128
impl Publisher {
    /// Binds a TCP listener on `address` and starts accepting subscribers for
    /// `topic` on a background thread (via `tcpconnection::iterate`).
    ///
    /// `queue_size` bounds the fork channel and becomes the default queue
    /// limit for streams created from this publisher; `caller_id` is sent to
    /// subscribers in the handshake response.
    ///
    /// # Errors
    /// Fails if the listener cannot be bound or its local address cannot be
    /// read.
    pub fn new<U>(
        address: U,
        topic: &str,
        queue_size: usize,
        caller_id: &str,
        message_description: RawMessageDescription,
    ) -> Result<Publisher>
    where
        U: ToSocketAddrs,
    {
        let listener = TcpListener::bind(address)?;
        let socket_address = listener.local_addr()?;

        let publisher_exists = Arc::new(atomic::AtomicBool::new(true));

        let port = socket_address.port();
        let (targets, data) = fork(queue_size);
        // Starts empty; only latching streams overwrite it (see `send`).
        let last_message = Arc::new(Mutex::new(Arc::new(Vec::new())));

        let iterate_handler = {
            // Clone everything the acceptor thread must own; the originals
            // stay with the returned Publisher.
            let publisher_exists = publisher_exists.clone();
            let topic = String::from(topic);
            let last_message = Arc::clone(&last_message);
            let caller_id = String::from(caller_id);
            let message_description = message_description.clone();

            move |stream: TcpStream| {
                // Stop the accept loop once the publisher has been dropped.
                if !publisher_exists.load(atomic::Ordering::SeqCst) {
                    return tcpconnection::Feedback::StopAccepting;
                }
                process_subscriber(
                    &topic,
                    stream,
                    &targets,
                    &last_message,
                    &caller_id,
                    &message_description,
                )
            }
        };

        tcpconnection::iterate(listener, format!("topic '{}'", topic), iterate_handler);

        let topic = Topic {
            name: String::from(topic),
            msg_type: message_description.msg_type,
            md5sum: message_description.md5sum,
        };

        Ok(Publisher {
            subscriptions: data,
            port,
            topic,
            last_message,
            queue_size,
            exists: publisher_exists,
        })
    }

    /// Creates a typed stream for sending messages through this publisher,
    /// with its queue capped at `queue_size`.
    ///
    /// # Errors
    /// Fails when `message_description` does not match the publisher's topic
    /// type (see `PublisherStream::new`).
    pub fn stream<T: Message>(
        &self,
        queue_size: usize,
        message_description: RawMessageDescription,
    ) -> Result<PublisherStream<T>> {
        let mut stream = PublisherStream::new(self, message_description)?;
        stream.set_queue_size_max(queue_size);
        Ok(stream)
    }

    /// Returns the topic this publisher serves.
    pub fn get_topic(&self) -> &Topic {
        &self.topic
    }
}
203
// TODO: publisher should only be removed from master API once the publisher and all
// publisher streams are gone. This should be done with a RAII Arc, residing next to
// the datastream. So maybe replace DataStream with a wrapper that holds that Arc too.
207
/// Cloneable, typed handle for sending messages of type `T` through a
/// `Publisher`'s fan-out stream.
#[derive(Clone)]
pub struct PublisherStream<T: Message> {
    /// Shared fan-out channel owned by the parent publisher.
    stream: DataStream,
    /// Shared latch buffer; written by `send` when latching is enabled.
    last_message: Arc<Mutex<Arc<Vec<u8>>>>,
    /// Zero-sized marker tying this stream to message type `T`.
    datatype: std::marker::PhantomData<T>,
    /// When true, `send` stores each message for replay to new subscribers.
    latching: bool,
}
215
impl<T: Message> PublisherStream<T> {
    /// Creates a stream bound to `publisher`, verifying that the requested
    /// message type matches the publisher's topic type.
    ///
    /// # Errors
    /// Returns `ErrorKind::MessageTypeMismatch` when the types differ.
    fn new(
        publisher: &Publisher,
        message_description: RawMessageDescription,
    ) -> Result<PublisherStream<T>> {
        if publisher.topic.msg_type != message_description.msg_type {
            bail!(ErrorKind::MessageTypeMismatch(
                publisher.topic.msg_type.clone(),
                message_description.msg_type,
            ));
        }
        let mut stream = PublisherStream {
            stream: publisher.subscriptions.clone(),
            datatype: std::marker::PhantomData,
            last_message: Arc::clone(&publisher.last_message),
            latching: false,
        };
        // Inherit the publisher's default queue bound; `Publisher::stream`
        // may override this right afterwards.
        stream.set_queue_size_max(publisher.queue_size);
        Ok(stream)
    }

    /// Number of currently connected subscribers.
    #[inline]
    pub fn subscriber_count(&self) -> usize {
        self.stream.target_count()
    }

    /// Caller ids of the currently connected subscribers.
    #[inline]
    pub fn subscriber_names(&self) -> Vec<String> {
        self.stream.target_names()
    }

    /// Enables or disables latching: when enabled, `send` stores each message
    /// so newly connecting subscribers receive the most recent one.
    #[inline]
    pub fn set_latching(&mut self, latching: bool) {
        self.latching = latching;
    }

    /// Sets the queue size on the underlying fan-out stream.
    #[inline]
    pub fn set_queue_size(&mut self, queue_size: usize) {
        self.stream.set_queue_size(queue_size);
    }

    /// Sets the maximum queue size on the underlying fan-out stream.
    #[inline]
    pub fn set_queue_size_max(&mut self, queue_size: usize) {
        self.stream.set_queue_size_max(queue_size);
    }

    /// Encodes `message` and forwards it to all connected subscribers.
    ///
    /// # Errors
    /// Fails when encoding the message fails; delivery itself is handled by
    /// the fork thread.
    pub fn send(&self, message: &T) -> Result<()> {
        let bytes = Arc::new(message.encode_vec()?);

        if self.latching {
            // Remember the encoded bytes so late-joining subscribers get them.
            *self.last_message.lock().expect(FAILED_TO_LOCK) = Arc::clone(&bytes);
        }

        // Subscriptions can only be closed from the Publisher side
        // There is no way for the streamfork thread to fail by itself
        self.stream.send(bytes).expect("Connected thread died");
        Ok(())
    }
}