1use super::error::{ErrorKind, Result, ResultExt};
2use super::header::{decode, encode};
3use super::{ServicePair, ServiceResult};
4use crate::api::Master;
5use crate::rosmsg::RosMsg;
6use crate::util::FAILED_TO_LOCK;
7use byteorder::{LittleEndian, ReadBytesExt};
8use error_chain::bail;
9use log::error;
10use socket2::Socket;
11use std::collections::HashMap;
12use std::io;
13use std::io::{Read, Write};
14use std::net::{TcpStream, ToSocketAddrs};
15use std::sync::{Arc, Mutex};
16use std::thread;
17
18pub struct ClientResponse<T> {
19 handle: thread::JoinHandle<Result<ServiceResult<T>>>,
20}
21
22impl<T> ClientResponse<T> {
23 pub fn read(self) -> Result<ServiceResult<T>> {
24 self.handle
25 .join()
26 .unwrap_or_else(|_| Err(ErrorKind::ServiceResponseUnknown.into()))
27 }
28}
29
30impl<T: Send + 'static> ClientResponse<T> {
31 pub fn callback<F>(self, callback: F)
32 where
33 F: FnOnce(Result<ServiceResult<T>>) + Send + 'static,
34 {
35 thread::spawn(move || callback(self.read()));
36 }
37}
38
/// Identity values copied into every connection header sent by a `Client`.
struct ClientInfo {
    /// Service name, sent as the `service` header field.
    service: String,
    /// Caller identity, sent as the `callerid` header field.
    caller_id: String,
}
43
44struct UriCache {
45 master: std::sync::Arc<Master>,
46 data: Mutex<Option<String>>,
47 service: String,
48}
49
50impl UriCache {
51 fn get(&self) -> Result<String> {
52 if let Some(uri) = Option::<String>::clone(&self.data.lock().expect(FAILED_TO_LOCK)) {
53 return Ok(uri);
54 }
55 let new_uri = self
56 .master
57 .lookup_service(&self.service)
58 .chain_err(|| ErrorKind::ServiceConnectionFail(self.service.clone()))?;
59 *self.data.lock().expect(FAILED_TO_LOCK) = Some(new_uri.clone());
60 Ok(new_uri)
61 }
62
63 fn clear(&self) {
64 *self.data.lock().expect(FAILED_TO_LOCK) = None;
65 }
66}
67
68#[derive(Clone)]
69pub struct Client<T: ServicePair> {
70 info: std::sync::Arc<ClientInfo>,
71 uri_cache: std::sync::Arc<UriCache>,
72 phantom: std::marker::PhantomData<T>,
73}
74
75fn connect_to_tcp_attempt(
76 uri_cache: &UriCache,
77 timeout: Option<std::time::Duration>,
78) -> Result<TcpStream> {
79 let uri = uri_cache.get()?;
80 let trimmed_uri = uri.trim_start_matches("rosrpc://");
81 let stream = match timeout {
82 Some(timeout) => {
83 let invalid_addr_error = || {
84 ErrorKind::Io(io::Error::new(
85 io::ErrorKind::Other,
86 "Invalid socket address",
87 ))
88 };
89 let socket_addr = trimmed_uri
90 .to_socket_addrs()
91 .chain_err(invalid_addr_error)?
92 .next()
93 .ok_or_else(invalid_addr_error)?;
94 TcpStream::connect_timeout(&socket_addr, timeout)?
95 }
96 None => TcpStream::connect(trimmed_uri)?,
97 };
98 let socket: Socket = stream.into();
99 if let Some(timeout) = timeout {
100 socket.set_read_timeout(Some(timeout))?;
102 socket.set_write_timeout(Some(timeout))?;
103 }
104 socket.set_linger(None)?;
105 let stream: TcpStream = socket.into();
106 Ok(stream)
107}
108
109fn connect_to_tcp_with_multiple_attempts(
110 uri_cache: &UriCache,
111 attempts: usize,
112) -> Result<TcpStream> {
113 let mut err = io::Error::new(
114 io::ErrorKind::Other,
115 "Tried to connect via TCP with 0 connection attempts",
116 )
117 .into();
118 let mut repeat_delay_ms = 1;
119 for _ in 0..attempts {
120 let stream_result = connect_to_tcp_attempt(uri_cache, None);
121 match stream_result {
122 Ok(stream) => {
123 return Ok(stream);
124 }
125 Err(error) => err = error,
126 }
127 uri_cache.clear();
128 std::thread::sleep(std::time::Duration::from_millis(repeat_delay_ms));
129 repeat_delay_ms *= 2;
130 }
131 Err(err)
132}
133
134impl<T: ServicePair> Client<T> {
135 pub fn new(master: Arc<Master>, caller_id: &str, service: &str) -> Client<T> {
136 Client {
137 info: std::sync::Arc::new(ClientInfo {
138 caller_id: String::from(caller_id),
139 service: String::from(service),
140 }),
141 uri_cache: std::sync::Arc::new(UriCache {
142 master,
143 data: Mutex::new(None),
144 service: String::from(service),
145 }),
146 phantom: std::marker::PhantomData,
147 }
148 }
149
150 fn probe_inner(&self, timeout: std::time::Duration) -> Result<()> {
151 let mut stream = connect_to_tcp_attempt(&self.uri_cache, Some(timeout))?;
152 exchange_probe_headers(&mut stream, &self.info.caller_id, &self.info.service)?;
153 Ok(())
154 }
155
156 pub fn probe(&self, timeout: std::time::Duration) -> Result<()> {
157 let probe_result = self.probe_inner(timeout);
158 if probe_result.is_err() {
159 self.uri_cache.clear();
160 }
161 probe_result
162 }
163
164 pub fn req(&self, args: &T::Request) -> Result<ServiceResult<T::Response>> {
165 Self::request_body(
166 args,
167 &self.uri_cache,
168 &self.info.caller_id,
169 &self.info.service,
170 )
171 }
172
173 pub fn req_async(&self, args: T::Request) -> ClientResponse<T::Response> {
174 let info = Arc::clone(&self.info);
175 let uri_cache = Arc::clone(&self.uri_cache);
176 ClientResponse {
177 handle: thread::spawn(move || {
178 Self::request_body(&args, &uri_cache, &info.caller_id, &info.service)
179 }),
180 }
181 }
182
183 fn request_body(
184 args: &T::Request,
185 uri_cache: &UriCache,
186 caller_id: &str,
187 service: &str,
188 ) -> Result<ServiceResult<T::Response>> {
189 let mut stream = connect_to_tcp_with_multiple_attempts(uri_cache, 15)
190 .chain_err(|| ErrorKind::ServiceConnectionFail(service.into()))?;
191
192 exchange_headers::<T, _>(&mut stream, caller_id, service)?;
194
195 let mut writer = io::Cursor::new(Vec::with_capacity(128));
196 writer.set_position(4);
198
199 args.encode(&mut writer)?;
200
201 let message_length = (writer.position() - 4) as u32;
203 writer.set_position(0);
204 message_length.encode(&mut writer)?;
205
206 stream.write_all(&writer.into_inner())?;
208
209 let success = read_verification_byte(&mut stream)
211 .chain_err(|| ErrorKind::ServiceResponseInterruption)?;
212 Ok(if success {
213 let _length = stream.read_u32::<LittleEndian>();
217
218 let data = RosMsg::decode(&mut stream)?;
219
220 let mut dump = vec![];
221 if let Err(err) = stream.read_to_end(&mut dump) {
222 error!("Failed to read from TCP stream: {:?}", err)
223 }
224
225 Ok(data)
226 } else {
227 let data = RosMsg::decode(&mut stream)?;
229
230 let mut dump = vec![];
231 if let Err(err) = stream.read_to_end(&mut dump) {
232 error!("Failed to read from TCP stream: {:?}", err)
233 }
234
235 Err(data)
236 })
237 }
238}
239
/// Reads exactly one status byte from `reader`; any non-zero value means success.
///
/// # Errors
/// Propagates the reader's I/O error, including `UnexpectedEof` when no byte
/// is available.
#[inline]
fn read_verification_byte<R: std::io::Read>(reader: &mut R) -> std::io::Result<bool> {
    let mut status = [0u8; 1];
    reader.read_exact(&mut status)?;
    Ok(status[0] != 0)
}
244
245fn write_request<T, U>(mut stream: &mut U, caller_id: &str, service: &str) -> Result<()>
246where
247 T: ServicePair,
248 U: std::io::Write,
249{
250 let mut fields = HashMap::<String, String>::new();
251 fields.insert(String::from("callerid"), String::from(caller_id));
252 fields.insert(String::from("service"), String::from(service));
253 fields.insert(String::from("md5sum"), T::md5sum());
254 fields.insert(String::from("type"), T::msg_type());
255 encode(&mut stream, &fields)?;
256 Ok(())
257}
258
259fn write_probe_request<U>(mut stream: &mut U, caller_id: &str, service: &str) -> Result<()>
260where
261 U: std::io::Write,
262{
263 let mut fields = HashMap::<String, String>::new();
264 fields.insert(String::from("probe"), String::from("1"));
265 fields.insert(String::from("callerid"), String::from(caller_id));
266 fields.insert(String::from("service"), String::from(service));
267 fields.insert(String::from("md5sum"), String::from("*"));
268 encode(&mut stream, &fields)?;
269 Ok(())
270}
271
272fn read_response<U>(mut stream: &mut U) -> Result<()>
273where
274 U: std::io::Read,
275{
276 let fields = decode(&mut stream)?;
277 if fields.get("callerid").is_none() {
278 bail!(ErrorKind::HeaderMissingField("callerid".into()));
279 }
280 Ok(())
281}
282
283fn exchange_headers<T, U>(stream: &mut U, caller_id: &str, service: &str) -> Result<()>
284where
285 T: ServicePair,
286 U: std::io::Write + std::io::Read,
287{
288 write_request::<T, U>(stream, caller_id, service)?;
289 read_response::<U>(stream)
290}
291
292fn exchange_probe_headers<U>(stream: &mut U, caller_id: &str, service: &str) -> Result<()>
293where
294 U: std::io::Write + std::io::Read,
295{
296 write_probe_request::<U>(stream, caller_id, service)?;
297 read_response::<U>(stream)
298}