rosrust/tcpros/
service.rs

1use super::error::{ErrorKind, Result};
2use super::header;
3use super::util::tcpconnection;
4use super::{ServicePair, ServiceResult};
5use crate::rosmsg::{encode_str, RosMsg};
6use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
7use error_chain::bail;
8use log::error;
9use std::collections::HashMap;
10use std::io;
11use std::net::{TcpListener, TcpStream};
12use std::sync::{atomic, Arc};
13use std::thread;
14
/// Handle to a running TCPROS service server.
///
/// Dropping the handle clears the shared `exists` flag, which the connection
/// callback built in [`Service::new`] checks before serving each new stream.
pub struct Service {
    /// URI advertised for this service, formatted as `rosrpc://<hostname>:<port>`.
    pub api: String,
    /// Service type name, taken from `ServicePair::msg_type()`.
    pub msg_type: String,
    /// Name of the service as passed to [`Service::new`].
    pub service: String,
    /// Shared liveness flag: `true` while this handle is alive, `false` once dropped.
    exists: Arc<atomic::AtomicBool>,
}
21
22impl Drop for Service {
23    fn drop(&mut self) {
24        self.exists.store(false, atomic::Ordering::SeqCst);
25    }
26}
27
28impl Service {
29    pub fn new<T, F>(
30        hostname: &str,
31        bind_address: &str,
32        port: u16,
33        service: &str,
34        node_name: &str,
35        handler: F,
36    ) -> Result<Service>
37    where
38        T: ServicePair,
39        F: Fn(T::Request) -> ServiceResult<T::Response> + Send + Sync + 'static,
40    {
41        let listener = TcpListener::bind((bind_address, port))?;
42        let socket_address = listener.local_addr()?;
43        let api = format!("rosrpc://{}:{}", hostname, socket_address.port());
44
45        let service_exists = Arc::new(atomic::AtomicBool::new(true));
46
47        let iterate_handler = {
48            let service_exists = service_exists.clone();
49            let service = String::from(service);
50            let node_name = String::from(node_name);
51            let handler = Arc::new(handler);
52            move |stream: TcpStream| {
53                if !service_exists.load(atomic::Ordering::SeqCst) {
54                    return tcpconnection::Feedback::StopAccepting;
55                }
56                consume_client::<T, _, _>(&service, &node_name, Arc::clone(&handler), stream);
57                tcpconnection::Feedback::AcceptNextStream
58            }
59        };
60
61        tcpconnection::iterate(listener, format!("service '{}'", service), iterate_handler);
62
63        Ok(Service {
64            api,
65            msg_type: T::msg_type(),
66            service: String::from(service),
67            exists: service_exists,
68        })
69    }
70}
71
/// Kind of connection a client opened, as determined from its header.
enum RequestType {
    /// The header carried `probe=1`; no request handler is spawned.
    Probe,
    /// A regular service call; a request handler thread is spawned.
    Action,
}
76
77fn consume_client<T, U, F>(service: &str, node_name: &str, handler: Arc<F>, mut stream: U)
78where
79    T: ServicePair,
80    U: std::io::Read + std::io::Write + Send + 'static,
81    F: Fn(T::Request) -> ServiceResult<T::Response> + Send + Sync + 'static,
82{
83    // Service request starts by exchanging connection headers
84    match exchange_headers::<T, _>(&mut stream, service, node_name) {
85        Err(err) => {
86            // Connection can be closed when a client checks for a service.
87            if !err.is_closed_connection() {
88                error!(
89                    "Failed to exchange headers for service '{}': {}",
90                    service, err
91                );
92            }
93        }
94        // Spawn a thread for handling requests
95        Ok(RequestType::Action) => spawn_request_handler::<T, U, F>(stream, Arc::clone(&handler)),
96        Ok(RequestType::Probe) => (),
97    }
98}
99
100fn exchange_headers<T, U>(stream: &mut U, service: &str, node_name: &str) -> Result<RequestType>
101where
102    T: ServicePair,
103    U: std::io::Write + std::io::Read,
104{
105    let req_type = read_request::<T, U>(stream, service)?;
106    write_response::<T, U>(stream, node_name)?;
107    Ok(req_type)
108}
109
110fn read_request<T: ServicePair, U: std::io::Read>(
111    stream: &mut U,
112    service: &str,
113) -> Result<RequestType> {
114    let fields = header::decode(stream)?;
115    header::match_field(&fields, "service", service)?;
116    if fields.get("callerid").is_none() {
117        bail!(ErrorKind::HeaderMissingField("callerid".into()));
118    }
119    if header::match_field(&fields, "probe", "1").is_ok() {
120        return Ok(RequestType::Probe);
121    }
122    header::match_field(&fields, "md5sum", &T::md5sum())?;
123    Ok(RequestType::Action)
124}
125
126fn write_response<T, U>(stream: &mut U, node_name: &str) -> Result<()>
127where
128    T: ServicePair,
129    U: std::io::Write,
130{
131    let mut fields = HashMap::<String, String>::new();
132    fields.insert(String::from("callerid"), String::from(node_name));
133    fields.insert(String::from("md5sum"), T::md5sum());
134    fields.insert(String::from("type"), T::msg_type());
135    header::encode(stream, &fields)?;
136    Ok(())
137}
138
139fn spawn_request_handler<T, U, F>(stream: U, handler: Arc<F>)
140where
141    T: ServicePair,
142    U: std::io::Read + std::io::Write + Send + 'static,
143    F: Fn(T::Request) -> ServiceResult<T::Response> + Send + Sync + 'static,
144{
145    thread::spawn(move || {
146        if let Err(err) = handle_request_loop::<T, U, F>(stream, &handler) {
147            if !err.is_closed_connection() {
148                let info = err
149                    .iter()
150                    .map(|v| format!("{}", v))
151                    .collect::<Vec<_>>()
152                    .join("\nCaused by:");
153                error!("{}", info);
154            }
155        }
156    });
157}
158
159fn handle_request_loop<T, U, F>(mut stream: U, handler: &F) -> Result<()>
160where
161    T: ServicePair,
162    U: std::io::Read + std::io::Write,
163    F: Fn(T::Request) -> ServiceResult<T::Response>,
164{
165    // Receive request from client
166    // TODO: validate message length
167    let _length = stream.read_u32::<LittleEndian>();
168    // Break out of loop in case of failure to read request
169    // TODO: handle retained connections
170    if let Ok(req) = RosMsg::decode(&mut stream) {
171        // Call function that handles request and returns response
172        match handler(req) {
173            Ok(res) => {
174                // Send True flag and response in case of success
175                stream.write_u8(1)?;
176                let mut writer = io::Cursor::new(Vec::with_capacity(128));
177                // skip the first 4 bytes that will contain the message length
178                writer.set_position(4);
179
180                res.encode(&mut writer)?;
181
182                // write the message length to the start of the header
183                let message_length = (writer.position() - 4) as u32;
184                writer.set_position(0);
185                message_length.encode(&mut writer)?;
186
187                stream.write_all(&writer.into_inner())?;
188            }
189            Err(message) => {
190                // Send False flag and error message string in case of failure
191                stream.write_u8(0)?;
192                RosMsg::encode(&message, &mut stream)?;
193            }
194        };
195    }
196
197    // Upon failure to read request, send client failure message
198    // This can be caused by actual issues or by the client stopping the connection
199    stream.write_u8(0)?;
200    encode_str("Failed to parse passed arguments", &mut stream)?;
201    Ok(())
202}