xml_rpc/
server.rs

1use rouille;
2use serde::{Deserialize, Serialize};
3use std;
4use std::collections::HashMap;
5
6use super::error::{ErrorKind, Result};
7use super::xmlfmt::{error, from_params, into_params, parse, Call, Fault, Response, Value};
8
9type Handler = Box<dyn Fn(Vec<Value>) -> Response + Send + Sync>;
10type HandlerMap = HashMap<String, Handler>;
11
12pub fn on_decode_fail(err: &error::Error) -> Response {
13    Err(Fault::new(
14        400,
15        format!("Failed to decode request: {}", err),
16    ))
17}
18
19pub fn on_encode_fail(err: &error::Error) -> Response {
20    Err(Fault::new(
21        500,
22        format!("Failed to encode response: {}", err),
23    ))
24}
25
26fn on_missing_method(_: Vec<Value>) -> Response {
27    Err(Fault::new(404, "Requested method does not exist"))
28}
29
30pub struct Server {
31    handlers: HandlerMap,
32    on_missing_method: Handler,
33}
34
35impl Default for Server {
36    fn default() -> Self {
37        Server {
38            handlers: HashMap::new(),
39            on_missing_method: Box::new(on_missing_method),
40        }
41    }
42}
43
44impl Server {
45    pub fn new() -> Server {
46        Server::default()
47    }
48
49    pub fn register_value<K, T>(&mut self, name: K, handler: T)
50    where
51        K: Into<String>,
52        T: Fn(Vec<Value>) -> Response + Send + Sync + 'static,
53    {
54        self.handlers.insert(name.into(), Box::new(handler));
55    }
56
57    pub fn register<'a, K, Treq, Tres, Thandler, Tef, Tdf>(
58        &mut self,
59        name: K,
60        handler: Thandler,
61        encode_fail: Tef,
62        decode_fail: Tdf,
63    ) where
64        K: Into<String>,
65        Treq: Deserialize<'a>,
66        Tres: Serialize,
67        Thandler: Fn(Treq) -> std::result::Result<Tres, Fault> + Send + Sync + 'static,
68        Tef: Fn(&error::Error) -> Response + Send + Sync + 'static,
69        Tdf: Fn(&error::Error) -> Response + Send + Sync + 'static,
70    {
71        self.register_value(name, move |req| {
72            let params = match from_params(req) {
73                Ok(v) => v,
74                Err(err) => return decode_fail(&err),
75            };
76            let response = handler(params)?;
77            into_params(&response).or_else(|v| encode_fail(&v))
78        });
79    }
80
81    pub fn register_simple<'a, K, Treq, Tres, Thandler>(&mut self, name: K, handler: Thandler)
82    where
83        K: Into<String>,
84        Treq: Deserialize<'a>,
85        Tres: Serialize,
86        Thandler: Fn(Treq) -> std::result::Result<Tres, Fault> + Send + Sync + 'static,
87    {
88        self.register(name, handler, on_encode_fail, on_decode_fail);
89    }
90
91    pub fn set_on_missing<T>(&mut self, handler: T)
92    where
93        T: Fn(Vec<Value>) -> Response + Send + Sync + 'static,
94    {
95        self.on_missing_method = Box::new(handler);
96    }
97
98    pub fn bind(
99        self,
100        uri: &std::net::SocketAddr,
101    ) -> Result<BoundServer<impl Fn(&rouille::Request) -> rouille::Response + Send + Sync + 'static>>
102    {
103        rouille::Server::new(uri, move |req| self.handle_outer(req))
104            .map_err(|err| ErrorKind::BindFail(err.to_string()).into())
105            .map(BoundServer::new)
106    }
107
108    fn handle_outer(&self, request: &rouille::Request) -> rouille::Response {
109        use super::xmlfmt::value::ToXml;
110
111        let body = match request.data() {
112            Some(data) => data,
113            None => return rouille::Response::empty_400(),
114        };
115
116        // TODO: use the right error type
117        let call: Call = match parse::call(body) {
118            Ok(data) => data,
119            Err(_err) => return rouille::Response::empty_400(),
120        };
121        let res = self.handle(call);
122        let body = res.to_xml();
123        rouille::Response::from_data("text/xml", body)
124    }
125
126    fn handle(&self, req: Call) -> Response {
127        self.handlers
128            .get(&req.name)
129            .unwrap_or(&self.on_missing_method)(req.params)
130    }
131}
132
133pub struct BoundServer<F>
134where
135    F: Send + Sync + 'static + Fn(&rouille::Request) -> rouille::Response,
136{
137    server: rouille::Server<F>,
138    // server: hyper::Server<NewService, hyper::Body>,
139}
140
141impl<F> BoundServer<F>
142where
143    F: Send + Sync + 'static + Fn(&rouille::Request) -> rouille::Response,
144{
145    fn new(server: rouille::Server<F>) -> Self {
146        Self { server }
147    }
148
149    pub fn local_addr(&self) -> std::net::SocketAddr {
150        self.server.server_addr()
151    }
152
153    pub fn run(self) {
154        self.server.run()
155    }
156
157    pub fn poll(&self) {
158        self.server.poll()
159    }
160}