rosrust/tcpros/client.rs

1use super::error::{ErrorKind, Result, ResultExt};
2use super::header::{decode, encode};
3use super::{ServicePair, ServiceResult};
4use crate::api::Master;
5use crate::rosmsg::RosMsg;
6use crate::util::FAILED_TO_LOCK;
7use byteorder::{LittleEndian, ReadBytesExt};
8use error_chain::bail;
9use log::error;
10use socket2::Socket;
11use std::collections::HashMap;
12use std::io;
13use std::io::{Read, Write};
14use std::net::{TcpStream, ToSocketAddrs};
15use std::sync::{Arc, Mutex};
16use std::thread;
17
18pub struct ClientResponse<T> {
19    handle: thread::JoinHandle<Result<ServiceResult<T>>>,
20}
21
22impl<T> ClientResponse<T> {
23    pub fn read(self) -> Result<ServiceResult<T>> {
24        self.handle
25            .join()
26            .unwrap_or_else(|_| Err(ErrorKind::ServiceResponseUnknown.into()))
27    }
28}
29
30impl<T: Send + 'static> ClientResponse<T> {
31    pub fn callback<F>(self, callback: F)
32    where
33        F: FnOnce(Result<ServiceResult<T>>) + Send + 'static,
34    {
35        thread::spawn(move || callback(self.read()));
36    }
37}
38
/// Identity a `Client` presents in the connection headers it sends.
struct ClientInfo {
    /// Value written to the `callerid` header field.
    caller_id: String,
    /// Value written to the `service` header field.
    service: String,
}
43
44struct UriCache {
45    master: std::sync::Arc<Master>,
46    data: Mutex<Option<String>>,
47    service: String,
48}
49
50impl UriCache {
51    fn get(&self) -> Result<String> {
52        if let Some(uri) = Option::<String>::clone(&self.data.lock().expect(FAILED_TO_LOCK)) {
53            return Ok(uri);
54        }
55        let new_uri = self
56            .master
57            .lookup_service(&self.service)
58            .chain_err(|| ErrorKind::ServiceConnectionFail(self.service.clone()))?;
59        *self.data.lock().expect(FAILED_TO_LOCK) = Some(new_uri.clone());
60        Ok(new_uri)
61    }
62
63    fn clear(&self) {
64        *self.data.lock().expect(FAILED_TO_LOCK) = None;
65    }
66}
67
68#[derive(Clone)]
69pub struct Client<T: ServicePair> {
70    info: std::sync::Arc<ClientInfo>,
71    uri_cache: std::sync::Arc<UriCache>,
72    phantom: std::marker::PhantomData<T>,
73}
74
75fn connect_to_tcp_attempt(
76    uri_cache: &UriCache,
77    timeout: Option<std::time::Duration>,
78) -> Result<TcpStream> {
79    let uri = uri_cache.get()?;
80    let trimmed_uri = uri.trim_start_matches("rosrpc://");
81    let stream = match timeout {
82        Some(timeout) => {
83            let invalid_addr_error = || {
84                ErrorKind::Io(io::Error::new(
85                    io::ErrorKind::Other,
86                    "Invalid socket address",
87                ))
88            };
89            let socket_addr = trimmed_uri
90                .to_socket_addrs()
91                .chain_err(invalid_addr_error)?
92                .next()
93                .ok_or_else(invalid_addr_error)?;
94            TcpStream::connect_timeout(&socket_addr, timeout)?
95        }
96        None => TcpStream::connect(trimmed_uri)?,
97    };
98    let socket: Socket = stream.into();
99    if let Some(timeout) = timeout {
100        // In case defaults are not None, only apply if a timeout is passed
101        socket.set_read_timeout(Some(timeout))?;
102        socket.set_write_timeout(Some(timeout))?;
103    }
104    socket.set_linger(None)?;
105    let stream: TcpStream = socket.into();
106    Ok(stream)
107}
108
109fn connect_to_tcp_with_multiple_attempts(
110    uri_cache: &UriCache,
111    attempts: usize,
112) -> Result<TcpStream> {
113    let mut err = io::Error::new(
114        io::ErrorKind::Other,
115        "Tried to connect via TCP with 0 connection attempts",
116    )
117    .into();
118    let mut repeat_delay_ms = 1;
119    for _ in 0..attempts {
120        let stream_result = connect_to_tcp_attempt(uri_cache, None);
121        match stream_result {
122            Ok(stream) => {
123                return Ok(stream);
124            }
125            Err(error) => err = error,
126        }
127        uri_cache.clear();
128        std::thread::sleep(std::time::Duration::from_millis(repeat_delay_ms));
129        repeat_delay_ms *= 2;
130    }
131    Err(err)
132}
133
134impl<T: ServicePair> Client<T> {
135    pub fn new(master: Arc<Master>, caller_id: &str, service: &str) -> Client<T> {
136        Client {
137            info: std::sync::Arc::new(ClientInfo {
138                caller_id: String::from(caller_id),
139                service: String::from(service),
140            }),
141            uri_cache: std::sync::Arc::new(UriCache {
142                master,
143                data: Mutex::new(None),
144                service: String::from(service),
145            }),
146            phantom: std::marker::PhantomData,
147        }
148    }
149
150    fn probe_inner(&self, timeout: std::time::Duration) -> Result<()> {
151        let mut stream = connect_to_tcp_attempt(&self.uri_cache, Some(timeout))?;
152        exchange_probe_headers(&mut stream, &self.info.caller_id, &self.info.service)?;
153        Ok(())
154    }
155
156    pub fn probe(&self, timeout: std::time::Duration) -> Result<()> {
157        let probe_result = self.probe_inner(timeout);
158        if probe_result.is_err() {
159            self.uri_cache.clear();
160        }
161        probe_result
162    }
163
164    pub fn req(&self, args: &T::Request) -> Result<ServiceResult<T::Response>> {
165        Self::request_body(
166            args,
167            &self.uri_cache,
168            &self.info.caller_id,
169            &self.info.service,
170        )
171    }
172
173    pub fn req_async(&self, args: T::Request) -> ClientResponse<T::Response> {
174        let info = Arc::clone(&self.info);
175        let uri_cache = Arc::clone(&self.uri_cache);
176        ClientResponse {
177            handle: thread::spawn(move || {
178                Self::request_body(&args, &uri_cache, &info.caller_id, &info.service)
179            }),
180        }
181    }
182
183    fn request_body(
184        args: &T::Request,
185        uri_cache: &UriCache,
186        caller_id: &str,
187        service: &str,
188    ) -> Result<ServiceResult<T::Response>> {
189        let mut stream = connect_to_tcp_with_multiple_attempts(uri_cache, 15)
190            .chain_err(|| ErrorKind::ServiceConnectionFail(service.into()))?;
191
192        // Service request starts by exchanging connection headers
193        exchange_headers::<T, _>(&mut stream, caller_id, service)?;
194
195        let mut writer = io::Cursor::new(Vec::with_capacity(128));
196        // skip the first 4 bytes that will contain the message length
197        writer.set_position(4);
198
199        args.encode(&mut writer)?;
200
201        // write the message length to the start of the header
202        let message_length = (writer.position() - 4) as u32;
203        writer.set_position(0);
204        message_length.encode(&mut writer)?;
205
206        // Send request to service
207        stream.write_all(&writer.into_inner())?;
208
209        // Service responds with a boolean byte, signalling success
210        let success = read_verification_byte(&mut stream)
211            .chain_err(|| ErrorKind::ServiceResponseInterruption)?;
212        Ok(if success {
213            // Decode response as response type upon success
214
215            // TODO: validate response length
216            let _length = stream.read_u32::<LittleEndian>();
217
218            let data = RosMsg::decode(&mut stream)?;
219
220            let mut dump = vec![];
221            if let Err(err) = stream.read_to_end(&mut dump) {
222                error!("Failed to read from TCP stream: {:?}", err)
223            }
224
225            Ok(data)
226        } else {
227            // Decode response as string upon failure
228            let data = RosMsg::decode(&mut stream)?;
229
230            let mut dump = vec![];
231            if let Err(err) = stream.read_to_end(&mut dump) {
232                error!("Failed to read from TCP stream: {:?}", err)
233            }
234
235            Err(data)
236        })
237    }
238}
239
/// Reads the single status byte that precedes a service reply.
///
/// Returns `true` for any non-zero byte and `false` for zero. Fails (for example with
/// `UnexpectedEof`) when no byte can be read.
#[inline]
fn read_verification_byte<R: std::io::Read>(reader: &mut R) -> std::io::Result<bool> {
    let mut status = [0u8; 1];
    reader.read_exact(&mut status)?;
    Ok(status[0] != 0)
}
244
245fn write_request<T, U>(mut stream: &mut U, caller_id: &str, service: &str) -> Result<()>
246where
247    T: ServicePair,
248    U: std::io::Write,
249{
250    let mut fields = HashMap::<String, String>::new();
251    fields.insert(String::from("callerid"), String::from(caller_id));
252    fields.insert(String::from("service"), String::from(service));
253    fields.insert(String::from("md5sum"), T::md5sum());
254    fields.insert(String::from("type"), T::msg_type());
255    encode(&mut stream, &fields)?;
256    Ok(())
257}
258
259fn write_probe_request<U>(mut stream: &mut U, caller_id: &str, service: &str) -> Result<()>
260where
261    U: std::io::Write,
262{
263    let mut fields = HashMap::<String, String>::new();
264    fields.insert(String::from("probe"), String::from("1"));
265    fields.insert(String::from("callerid"), String::from(caller_id));
266    fields.insert(String::from("service"), String::from(service));
267    fields.insert(String::from("md5sum"), String::from("*"));
268    encode(&mut stream, &fields)?;
269    Ok(())
270}
271
272fn read_response<U>(mut stream: &mut U) -> Result<()>
273where
274    U: std::io::Read,
275{
276    let fields = decode(&mut stream)?;
277    if fields.get("callerid").is_none() {
278        bail!(ErrorKind::HeaderMissingField("callerid".into()));
279    }
280    Ok(())
281}
282
283fn exchange_headers<T, U>(stream: &mut U, caller_id: &str, service: &str) -> Result<()>
284where
285    T: ServicePair,
286    U: std::io::Write + std::io::Read,
287{
288    write_request::<T, U>(stream, caller_id, service)?;
289    read_response::<U>(stream)
290}
291
292fn exchange_probe_headers<U>(stream: &mut U, caller_id: &str, service: &str) -> Result<()>
293where
294    U: std::io::Write + std::io::Read,
295{
296    write_probe_request::<U>(stream, caller_id, service)?;
297    read_response::<U>(stream)
298}