1use super::error::{ErrorKind, Result, ResultExt};
2use super::header;
3use super::util::streamfork::{fork, DataStream, TargetList};
4use super::util::tcpconnection;
5use super::{Message, Topic};
6use crate::util::FAILED_TO_LOCK;
7use crate::RawMessageDescription;
8use error_chain::bail;
9use log::error;
10use std::collections::HashMap;
11use std::net::{TcpListener, TcpStream, ToSocketAddrs};
12use std::sync::{atomic, Arc, Mutex};
13
14pub struct Publisher {
15 subscriptions: DataStream,
16 pub port: u16,
17 pub topic: Topic,
18 last_message: Arc<Mutex<Arc<Vec<u8>>>>,
19 queue_size: usize,
20 exists: Arc<atomic::AtomicBool>,
21}
22
23impl Drop for Publisher {
24 fn drop(&mut self) {
25 self.exists.store(false, atomic::Ordering::SeqCst);
26 }
27}
28
29fn match_headers(
30 fields: &HashMap<String, String>,
31 topic: &str,
32 message_description: &RawMessageDescription,
33) -> Result<()> {
34 header::match_field(fields, "md5sum", &message_description.md5sum)
35 .or_else(|e| header::match_field(fields, "md5sum", "*").or(Err(e)))?;
36 header::match_field(fields, "type", &message_description.msg_type)
37 .or_else(|e| header::match_field(fields, "type", "*").or(Err(e)))?;
38 header::match_field(fields, "topic", topic)?;
39 Ok(())
40}
41
42fn read_request<U: std::io::Read>(
43 mut stream: &mut U,
44 topic: &str,
45 message_description: &RawMessageDescription,
46) -> Result<String> {
47 let fields = header::decode(&mut stream)?;
48 match_headers(&fields, topic, message_description)?;
49 let caller_id = fields
50 .get("callerid")
51 .ok_or_else(|| ErrorKind::HeaderMissingField("callerid".into()))?;
52 Ok(caller_id.clone())
53}
54
55fn write_response<U: std::io::Write>(
56 mut stream: &mut U,
57 caller_id: &str,
58 topic: &str,
59 message_description: &RawMessageDescription,
60) -> Result<()> {
61 let mut fields = HashMap::<String, String>::new();
62 fields.insert(String::from("md5sum"), message_description.md5sum.clone());
63 fields.insert(String::from("type"), message_description.msg_type.clone());
64 fields.insert(String::from("topic"), String::from(topic));
65 fields.insert(String::from("callerid"), caller_id.into());
66 fields.insert(
67 String::from("message_definition"),
68 message_description.msg_definition.clone(),
69 );
70 header::encode(&mut stream, &fields)?;
71 Ok(())
72}
73
74fn exchange_headers<U>(
75 mut stream: &mut U,
76 topic: &str,
77 pub_caller_id: &str,
78 message_description: &RawMessageDescription,
79) -> Result<String>
80where
81 U: std::io::Write + std::io::Read,
82{
83 let caller_id = read_request(&mut stream, topic, message_description)?;
84 write_response(&mut stream, pub_caller_id, topic, message_description)?;
85 Ok(caller_id)
86}
87
88fn process_subscriber<U>(
89 topic: &str,
90 mut stream: U,
91 targets: &TargetList<U>,
92 last_message: &Mutex<Arc<Vec<u8>>>,
93 pub_caller_id: &str,
94 message_description: &RawMessageDescription,
95) -> tcpconnection::Feedback
96where
97 U: std::io::Read + std::io::Write + Send,
98{
99 let result = exchange_headers(&mut stream, topic, pub_caller_id, message_description)
100 .chain_err(|| ErrorKind::TopicConnectionFail(topic.into()));
101 let caller_id = match result {
102 Ok(caller_id) => caller_id,
103 Err(err) => {
104 let info = err
105 .iter()
106 .map(|v| format!("{}", v))
107 .collect::<Vec<_>>()
108 .join("\nCaused by:");
109 error!("{}", info);
110 return tcpconnection::Feedback::AcceptNextStream;
111 }
112 };
113
114 if let Err(err) = stream.write_all(&last_message.lock().expect(FAILED_TO_LOCK)) {
115 error!("{}", err);
116 return tcpconnection::Feedback::AcceptNextStream;
117 }
118
119 if targets.add(caller_id, stream).is_err() {
120 return tcpconnection::Feedback::StopAccepting;
124 }
125
126 tcpconnection::Feedback::AcceptNextStream
127}
128
129impl Publisher {
130 pub fn new<U>(
131 address: U,
132 topic: &str,
133 queue_size: usize,
134 caller_id: &str,
135 message_description: RawMessageDescription,
136 ) -> Result<Publisher>
137 where
138 U: ToSocketAddrs,
139 {
140 let listener = TcpListener::bind(address)?;
141 let socket_address = listener.local_addr()?;
142
143 let publisher_exists = Arc::new(atomic::AtomicBool::new(true));
144
145 let port = socket_address.port();
146 let (targets, data) = fork(queue_size);
147 let last_message = Arc::new(Mutex::new(Arc::new(Vec::new())));
148
149 let iterate_handler = {
150 let publisher_exists = publisher_exists.clone();
151 let topic = String::from(topic);
152 let last_message = Arc::clone(&last_message);
153 let caller_id = String::from(caller_id);
154 let message_description = message_description.clone();
155
156 move |stream: TcpStream| {
157 if !publisher_exists.load(atomic::Ordering::SeqCst) {
158 return tcpconnection::Feedback::StopAccepting;
159 }
160 process_subscriber(
161 &topic,
162 stream,
163 &targets,
164 &last_message,
165 &caller_id,
166 &message_description,
167 )
168 }
169 };
170
171 tcpconnection::iterate(listener, format!("topic '{}'", topic), iterate_handler);
172
173 let topic = Topic {
174 name: String::from(topic),
175 msg_type: message_description.msg_type,
176 md5sum: message_description.md5sum,
177 };
178
179 Ok(Publisher {
180 subscriptions: data,
181 port,
182 topic,
183 last_message,
184 queue_size,
185 exists: publisher_exists,
186 })
187 }
188
189 pub fn stream<T: Message>(
190 &self,
191 queue_size: usize,
192 message_description: RawMessageDescription,
193 ) -> Result<PublisherStream<T>> {
194 let mut stream = PublisherStream::new(self, message_description)?;
195 stream.set_queue_size_max(queue_size);
196 Ok(stream)
197 }
198
199 pub fn get_topic(&self) -> &Topic {
200 &self.topic
201 }
202}
203
204#[derive(Clone)]
209pub struct PublisherStream<T: Message> {
210 stream: DataStream,
211 last_message: Arc<Mutex<Arc<Vec<u8>>>>,
212 datatype: std::marker::PhantomData<T>,
213 latching: bool,
214}
215
216impl<T: Message> PublisherStream<T> {
217 fn new(
218 publisher: &Publisher,
219 message_description: RawMessageDescription,
220 ) -> Result<PublisherStream<T>> {
221 if publisher.topic.msg_type != message_description.msg_type {
222 bail!(ErrorKind::MessageTypeMismatch(
223 publisher.topic.msg_type.clone(),
224 message_description.msg_type,
225 ));
226 }
227 let mut stream = PublisherStream {
228 stream: publisher.subscriptions.clone(),
229 datatype: std::marker::PhantomData,
230 last_message: Arc::clone(&publisher.last_message),
231 latching: false,
232 };
233 stream.set_queue_size_max(publisher.queue_size);
234 Ok(stream)
235 }
236
237 #[inline]
238 pub fn subscriber_count(&self) -> usize {
239 self.stream.target_count()
240 }
241
242 #[inline]
243 pub fn subscriber_names(&self) -> Vec<String> {
244 self.stream.target_names()
245 }
246
247 #[inline]
248 pub fn set_latching(&mut self, latching: bool) {
249 self.latching = latching;
250 }
251
252 #[inline]
253 pub fn set_queue_size(&mut self, queue_size: usize) {
254 self.stream.set_queue_size(queue_size);
255 }
256
257 #[inline]
258 pub fn set_queue_size_max(&mut self, queue_size: usize) {
259 self.stream.set_queue_size_max(queue_size);
260 }
261
262 pub fn send(&self, message: &T) -> Result<()> {
263 let bytes = Arc::new(message.encode_vec()?);
264
265 if self.latching {
266 *self.last_message.lock().expect(FAILED_TO_LOCK) = Arc::clone(&bytes);
267 }
268
269 self.stream.send(bytes).expect("Connected thread died");
272 Ok(())
273 }
274}