1use super::error::{ErrorKind, Result};
2use super::header;
3use super::util::tcpconnection;
4use super::{ServicePair, ServiceResult};
5use crate::rosmsg::{encode_str, RosMsg};
6use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
7use error_chain::bail;
8use log::error;
9use std::collections::HashMap;
10use std::io;
11use std::net::{TcpListener, TcpStream};
12use std::sync::{atomic, Arc};
13use std::thread;
14
/// A ROS service server bound to a local TCP port.
///
/// Dropping a `Service` clears its shared `exists` flag. The accept loop started by
/// `Service::new` checks that flag before serving each newly accepted connection.
#[derive(Debug)]
pub struct Service {
    /// Address clients use to reach this service, formatted as `rosrpc://<hostname>:<port>`.
    pub api: String,
    /// Message type reported by the service's `ServicePair` implementation.
    pub msg_type: String,
    /// Name of the advertised service.
    pub service: String,
    /// Shared liveness flag; set to `false` when this handle is dropped.
    exists: Arc<atomic::AtomicBool>,
}
21
22impl Drop for Service {
23 fn drop(&mut self) {
24 self.exists.store(false, atomic::Ordering::SeqCst);
25 }
26}
27
28impl Service {
29 pub fn new<T, F>(
30 hostname: &str,
31 bind_address: &str,
32 port: u16,
33 service: &str,
34 node_name: &str,
35 handler: F,
36 ) -> Result<Service>
37 where
38 T: ServicePair,
39 F: Fn(T::Request) -> ServiceResult<T::Response> + Send + Sync + 'static,
40 {
41 let listener = TcpListener::bind((bind_address, port))?;
42 let socket_address = listener.local_addr()?;
43 let api = format!("rosrpc://{}:{}", hostname, socket_address.port());
44
45 let service_exists = Arc::new(atomic::AtomicBool::new(true));
46
47 let iterate_handler = {
48 let service_exists = service_exists.clone();
49 let service = String::from(service);
50 let node_name = String::from(node_name);
51 let handler = Arc::new(handler);
52 move |stream: TcpStream| {
53 if !service_exists.load(atomic::Ordering::SeqCst) {
54 return tcpconnection::Feedback::StopAccepting;
55 }
56 consume_client::<T, _, _>(&service, &node_name, Arc::clone(&handler), stream);
57 tcpconnection::Feedback::AcceptNextStream
58 }
59 };
60
61 tcpconnection::iterate(listener, format!("service '{}'", service), iterate_handler);
62
63 Ok(Service {
64 api,
65 msg_type: T::msg_type(),
66 service: String::from(service),
67 exists: service_exists,
68 })
69 }
70}
71
/// Kind of connection a client opened, as determined from its connection header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RequestType {
    /// The header had `probe` set to `"1"`; the client only wants the reply header.
    Probe,
    /// A real service call; request/response exchange follows the handshake.
    Action,
}
76
77fn consume_client<T, U, F>(service: &str, node_name: &str, handler: Arc<F>, mut stream: U)
78where
79 T: ServicePair,
80 U: std::io::Read + std::io::Write + Send + 'static,
81 F: Fn(T::Request) -> ServiceResult<T::Response> + Send + Sync + 'static,
82{
83 match exchange_headers::<T, _>(&mut stream, service, node_name) {
85 Err(err) => {
86 if !err.is_closed_connection() {
88 error!(
89 "Failed to exchange headers for service '{}': {}",
90 service, err
91 );
92 }
93 }
94 Ok(RequestType::Action) => spawn_request_handler::<T, U, F>(stream, Arc::clone(&handler)),
96 Ok(RequestType::Probe) => (),
97 }
98}
99
100fn exchange_headers<T, U>(stream: &mut U, service: &str, node_name: &str) -> Result<RequestType>
101where
102 T: ServicePair,
103 U: std::io::Write + std::io::Read,
104{
105 let req_type = read_request::<T, U>(stream, service)?;
106 write_response::<T, U>(stream, node_name)?;
107 Ok(req_type)
108}
109
110fn read_request<T: ServicePair, U: std::io::Read>(
111 stream: &mut U,
112 service: &str,
113) -> Result<RequestType> {
114 let fields = header::decode(stream)?;
115 header::match_field(&fields, "service", service)?;
116 if fields.get("callerid").is_none() {
117 bail!(ErrorKind::HeaderMissingField("callerid".into()));
118 }
119 if header::match_field(&fields, "probe", "1").is_ok() {
120 return Ok(RequestType::Probe);
121 }
122 header::match_field(&fields, "md5sum", &T::md5sum())?;
123 Ok(RequestType::Action)
124}
125
126fn write_response<T, U>(stream: &mut U, node_name: &str) -> Result<()>
127where
128 T: ServicePair,
129 U: std::io::Write,
130{
131 let mut fields = HashMap::<String, String>::new();
132 fields.insert(String::from("callerid"), String::from(node_name));
133 fields.insert(String::from("md5sum"), T::md5sum());
134 fields.insert(String::from("type"), T::msg_type());
135 header::encode(stream, &fields)?;
136 Ok(())
137}
138
139fn spawn_request_handler<T, U, F>(stream: U, handler: Arc<F>)
140where
141 T: ServicePair,
142 U: std::io::Read + std::io::Write + Send + 'static,
143 F: Fn(T::Request) -> ServiceResult<T::Response> + Send + Sync + 'static,
144{
145 thread::spawn(move || {
146 if let Err(err) = handle_request_loop::<T, U, F>(stream, &handler) {
147 if !err.is_closed_connection() {
148 let info = err
149 .iter()
150 .map(|v| format!("{}", v))
151 .collect::<Vec<_>>()
152 .join("\nCaused by:");
153 error!("{}", info);
154 }
155 }
156 });
157}
158
159fn handle_request_loop<T, U, F>(mut stream: U, handler: &F) -> Result<()>
160where
161 T: ServicePair,
162 U: std::io::Read + std::io::Write,
163 F: Fn(T::Request) -> ServiceResult<T::Response>,
164{
165 let _length = stream.read_u32::<LittleEndian>();
168 if let Ok(req) = RosMsg::decode(&mut stream) {
171 match handler(req) {
173 Ok(res) => {
174 stream.write_u8(1)?;
176 let mut writer = io::Cursor::new(Vec::with_capacity(128));
177 writer.set_position(4);
179
180 res.encode(&mut writer)?;
181
182 let message_length = (writer.position() - 4) as u32;
184 writer.set_position(0);
185 message_length.encode(&mut writer)?;
186
187 stream.write_all(&writer.into_inner())?;
188 }
189 Err(message) => {
190 stream.write_u8(0)?;
192 RosMsg::encode(&message, &mut stream)?;
193 }
194 };
195 }
196
197 stream.write_u8(0)?;
200 encode_str("Failed to parse passed arguments", &mut stream)?;
201 Ok(())
202}