rouille/
lib.rs

1// Copyright (c) 2016 The Rouille developers
2// Licensed under the Apache License, Version 2.0
3// <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>,
6// at your option. All files in the project carrying such
7// notice may not be copied, modified, or distributed except
8// according to those terms.
9
10//! The rouille library is very easy to get started with.
11//!
//! Listening on a port is done by calling the [`start_server`](fn.start_server.html) function:
13//!
14//! ```no_run
15//! use rouille::Request;
16//! use rouille::Response;
17//!
18//! rouille::start_server("0.0.0.0:80", move |request| {
19//!     Response::text("hello world")
20//! });
21//! ```
22//!
23//! Whenever an HTTP request is received on the address passed as first parameter, the closure
24//! passed as second parameter is called. This closure must then return a
25//! [`Response`](struct.Response.html) that will be sent back to the client.
26//!
27//! See the documentation of [`start_server`](fn.start_server.html) for more details.
28//!
29//! # Analyzing the request
30//!
31//! The parameter that the closure receives is a [`Request`](struct.Request.html) object that
32//! represents the request made by the client.
33//!
34//! The `Request` object itself provides some getters, but most advanced functionalities are
35//! provided by other modules of this crate.
36//!
37//! - In order to dispatch between various code depending on the URL, you can use the
38//!   [`router!`](macro.router.html) macro.
39//! - In order to analyze the body of the request, like handling JSON input, form input, etc. you
40//!   can take a look at [the `input` module](input/index.html).
41//!
42//! # Returning a response
43//!
//! Once you have analyzed the request, it is time to return a response by returning a
45//! [`Response`](struct.Response.html) object.
46//!
47//! All the members of `Response` are public, so you can customize it as you want. There are also
//! several constructors that let you build a basic `Response`, which you can then modify.
49//!
50//! In order to serve static files, take a look at
51//! [the `match_assets` function](fn.match_assets.html).
52//!
53//! In order to apply content encodings (including compression such as gzip or deflate), see
54//! the [content_encoding module](content_encoding/index.html), and specifically the
55//! [content_encoding::apply](content_encoding/fn.apply.html) function.
56
57#![deny(unsafe_code)]
58
59extern crate base64;
60#[cfg(feature = "brotli")]
61extern crate brotli;
62extern crate chrono;
63#[cfg(feature = "gzip")]
64extern crate deflate;
65extern crate filetime;
66extern crate multipart;
67extern crate rand;
68extern crate serde;
69#[macro_use]
70extern crate serde_derive;
71pub extern crate percent_encoding;
72extern crate serde_json;
73extern crate sha1_smol;
74extern crate threadpool;
75extern crate time;
76extern crate tiny_http;
77pub extern crate url;
78
79// https://github.com/servo/rust-url/blob/e121d8d0aafd50247de5f5310a227ecb1efe6ffe/percent_encoding/lib.rs#L126
80pub const DEFAULT_ENCODE_SET: &percent_encoding::AsciiSet = &percent_encoding::CONTROLS
81    .add(b' ')
82    .add(b'"')
83    .add(b'#')
84    .add(b'<')
85    .add(b'>')
86    .add(b'`')
87    .add(b'?')
88    .add(b'{')
89    .add(b'}');
90
91pub use assets::extension_to_mime;
92pub use assets::match_assets;
93pub use log::{log, log_custom};
94pub use response::{Response, ResponseBody};
95pub use tiny_http::ReadWrite;
96
97use std::error::Error;
98use std::fmt;
99use std::io::Cursor;
100use std::io::Read;
101use std::io::Result as IoResult;
102use std::marker::PhantomData;
103use std::net::SocketAddr;
104use std::net::ToSocketAddrs;
105use std::panic;
106use std::panic::AssertUnwindSafe;
107use std::slice::Iter as SliceIter;
108use std::sync::atomic::{AtomicUsize, Ordering};
109use std::sync::mpsc;
110use std::sync::Arc;
111use std::sync::Mutex;
112use std::thread;
113use std::time::Duration;
114
115pub mod cgi;
116pub mod content_encoding;
117pub mod input;
118pub mod proxy;
119pub mod session;
120pub mod websocket;
121
122mod assets;
123mod find_route;
124mod log;
125mod response;
126mod router;
127#[doc(hidden)]
128pub mod try_or_400;
129
/// Unwraps a `Result`, or returns an empty 404 response from the enclosing function.
///
/// This macro assumes that the current function returns a `Response`. The expression passed to
/// the macro must evaluate to a `Result`: on `Ok(value)` the macro evaluates to `value`, and on
/// `Err(_)` the enclosing function returns `Response::empty_404()`. The error value is discarded.
///
/// # Example
///
/// ```
/// # #[macro_use] extern crate rouille;
/// # fn main() {
/// use rouille::Response;
///
/// fn show_item(id: &str) -> Response {
///     let id: u32 = try_or_404!(id.parse());
///     Response::text(format!("item {}", id))
/// }
/// # }
/// ```
#[macro_export]
macro_rules! try_or_404 {
    ($result:expr) => {
        // Fully-qualified variant paths: macro_rules resolves item names at the call site, so an
        // unqualified `Ok`/`Err` would break if the caller has its own item with that name.
        match $result {
            ::std::result::Result::Ok(r) => r,
            ::std::result::Result::Err(_) => return $crate::Response::empty_404(),
        }
    };
}
141
/// Returns an empty 400 response from the enclosing function when a condition does not hold.
///
/// This macro assumes that the current function returns a `Response`. The condition is
/// evaluated once; if it is false, `Response::empty_400()` is returned, otherwise execution
/// continues after the macro.
///
/// # Example
///
/// ```
/// # #[macro_use] extern crate rouille;
/// # fn main() {
/// use rouille::Request;
/// use rouille::Response;
///
/// fn handle_something(request: &Request) -> Response {
///     let data = try_or_400!(post_input!(request, {
///         field1: u32,
///         field2: String,
///     }));
///
///     assert_or_400!(data.field1 >= 2);
///     Response::text("hello")
/// }
/// # }
/// ```
#[macro_export]
macro_rules! assert_or_400 {
    ($cond:expr) => {
        // Bail out early with a "400 Bad Request" when the condition is false.
        if !($cond) {
            return $crate::Response::empty_400();
        }
    };
}
172
173/// Starts a server and uses the given requests handler.
174///
175/// The request handler takes a `&Request` and must return a `Response` to send to the user.
176///
177/// > **Note**: `start_server` is meant to be an easy-to-use function. If you want more control,
178/// > see [the `Server` struct](struct.Server.html).
179///
180/// # Common mistakes
181///
182/// The handler must capture its environment by value and not by reference (`'static`). If you
183/// use closure, don't forget to put `move` in front of the closure.
184///
185/// The handler must also be thread-safe (`Send` and `Sync`).
186/// For example this handler isn't thread-safe:
187///
188/// ```should_fail
189/// let mut requests_counter = 0;
190///
191/// rouille::start_server("localhost:80", move |request| {
192///     requests_counter += 1;
193///
194///     // ... rest of the handler ...
195/// # panic!()
196/// })
197/// ```
198///
199/// Multiple requests can be processed simultaneously, therefore you can't mutably access
200/// variables from the outside.
201///
202/// Instead you must use a `Mutex`:
203///
204/// ```no_run
205/// use std::sync::Mutex;
206/// let requests_counter = Mutex::new(0);
207///
208/// rouille::start_server("localhost:80", move |request| {
209///     *requests_counter.lock().unwrap() += 1;
210///
211///     // rest of the handler
212/// # panic!()
213/// })
214/// ```
215///
216/// # Panic handling in the handler
217///
218/// If your request handler panics, a 500 error will automatically be sent to the client.
219///
220/// # Panic
221///
222/// This function will panic if the server starts to fail (for example if you use a port that is
223/// already occupied) or if the socket is force-closed by the operating system.
224///
225/// If you need to handle these situations, please see `Server`.
226pub fn start_server<A, F>(addr: A, handler: F) -> !
227where
228    A: ToSocketAddrs,
229    F: Send + Sync + 'static + Fn(&Request) -> Response,
230{
231    Server::new(addr, handler)
232        .expect("Failed to start server")
233        .run();
234    panic!("The server socket closed unexpectedly")
235}
236
237/// Identical to `start_server` but uses a `ThreadPool` of the given size.
238///
239/// When `pool_size` is `None`, the thread pool size will default to `8 * num-cpus`.
240/// `pool_size` must be greater than zero or this function will panic.
241pub fn start_server_with_pool<A, F>(addr: A, pool_size: Option<usize>, handler: F) -> !
242where
243    A: ToSocketAddrs,
244    F: Send + Sync + 'static + Fn(&Request) -> Response,
245{
246    let pool_size = pool_size.unwrap_or_else(|| {
247        8 * thread::available_parallelism()
248            .map(|n| n.get())
249            .unwrap_or(1)
250    });
251
252    Server::new(addr, handler)
253        .expect("Failed to start server")
254        .pool_size(pool_size)
255        .run();
256    panic!("The server socket closed unexpectedly")
257}
258
/// Guard that increments a shared counter when created and decrements it when dropped.
///
/// Used by the threaded executor to know how many spawned request threads are still alive.
struct AtomicCounter(Arc<AtomicUsize>);

impl AtomicCounter {
    /// Increments `count` and returns a guard holding its own clone of the `Arc`.
    fn new(count: &Arc<AtomicUsize>) -> Self {
        let shared = Arc::clone(count);
        shared.fetch_add(1, Ordering::Relaxed);
        AtomicCounter(shared)
    }
}

impl Drop for AtomicCounter {
    fn drop(&mut self) {
        // `Release` pairs with the `Acquire` load performed while waiting for the count to reach 0.
        let AtomicCounter(ref counter) = *self;
        counter.fetch_sub(1, Ordering::Release);
    }
}
273
274/// Executes a function in either a thread of a thread pool
275enum Executor {
276    Threaded { count: Arc<AtomicUsize> },
277    Pooled { pool: threadpool::ThreadPool },
278}
279impl Executor {
280    /// `size` must be greater than zero or the call to `ThreadPool::new` will panic.
281    fn with_size(size: usize) -> Self {
282        let pool = threadpool::ThreadPool::new(size);
283        Executor::Pooled { pool }
284    }
285
286    #[inline]
287    fn execute<F: FnOnce() + Send + 'static>(&self, f: F) {
288        match *self {
289            Executor::Threaded { ref count } => {
290                let counter = AtomicCounter::new(count);
291                thread::spawn(move || {
292                    let _counter = counter;
293                    f()
294                });
295            }
296            Executor::Pooled { ref pool } => {
297                pool.execute(f);
298            }
299        }
300    }
301
302    fn join(&self) {
303        match *self {
304            Executor::Threaded { ref count } => {
305                while count.load(Ordering::Acquire) > 0 {
306                    thread::sleep(Duration::from_millis(100));
307                }
308            }
309            Executor::Pooled { ref pool } => {
310                pool.join();
311            }
312        }
313    }
314}
315
316impl Default for Executor {
317    fn default() -> Self {
318        Executor::Threaded {
319            count: Arc::new(AtomicUsize::new(0)),
320        }
321    }
322}
323
/// A listening server.
///
/// This struct is the more manual server creation API of rouille and can be used as an alternative
/// to the `start_server` function.
///
/// The `start_server` function is just a shortcut for `Server::new` followed with `run`. See the
/// documentation of the `start_server` function for more details about the handler.
///
/// # Example
///
/// ```no_run
/// use rouille::Server;
/// use rouille::Response;
///
/// let server = Server::new("localhost:0", |request| {
///     Response::text("hello world")
/// }).unwrap();
/// println!("Listening on {:?}", server.server_addr());
/// server.run();
/// ```
pub struct Server<F> {
    // Underlying tiny_http listener that incoming requests are received from.
    server: tiny_http::Server,
    // User-supplied handler. It is shared through an `Arc` so that each processed request can
    // hold its own clone; `AssertUnwindSafe` allows calling it inside `panic::catch_unwind`.
    handler: Arc<AssertUnwindSafe<F>>,
    // Runs request processing: one new thread per request by default, or a fixed-size thread
    // pool once `pool_size` has been called.
    executor: Executor,
}
349
350impl<F> Server<F>
351where
352    F: Send + Sync + 'static + Fn(&Request) -> Response,
353{
354    /// Builds a new `Server` object.
355    ///
356    /// After this function returns, the HTTP server is listening.
357    ///
358    /// Returns an error if there was an error while creating the listening socket, for example if
359    /// the port is already in use.
360    pub fn new<A>(addr: A, handler: F) -> Result<Server<F>, Box<dyn Error + Send + Sync + 'static>>
361    where
362        A: ToSocketAddrs,
363    {
364        let server = tiny_http::Server::http(addr)?;
365        Ok(Server {
366            server,
367            executor: Executor::default(),
368            handler: Arc::new(AssertUnwindSafe(handler)), // TODO: using AssertUnwindSafe here is wrong, but unwind safety has some usability problems in Rust in general
369        })
370    }
371
372    /// Builds a new `Server` object with SSL support.
373    ///
374    /// After this function returns, the HTTPS server is listening.
375    ///
376    /// Returns an error if there was an error while creating the listening socket, for example if
377    /// the port is already in use.
378    #[cfg(any(feature = "ssl", feature = "rustls"))]
379    pub fn new_ssl<A>(
380        addr: A,
381        handler: F,
382        certificate: Vec<u8>,
383        private_key: Vec<u8>,
384    ) -> Result<Server<F>, Box<dyn Error + Send + Sync + 'static>>
385    where
386        A: ToSocketAddrs,
387    {
388        let ssl_config = tiny_http::SslConfig {
389            certificate,
390            private_key,
391        };
392        let server = tiny_http::Server::https(addr, ssl_config)?;
393        Ok(Server {
394            server,
395            executor: Executor::default(),
396            handler: Arc::new(AssertUnwindSafe(handler)), // TODO: using AssertUnwindSafe here is wrong, but unwind safety has some usability problems in Rust in general
397        })
398    }
399
400    /// Use a `ThreadPool` of the given size to process requests
401    ///
402    /// `pool_size` must be greater than zero or this function will panic.
403    pub fn pool_size(mut self, pool_size: usize) -> Self {
404        self.executor = Executor::with_size(pool_size);
405        self
406    }
407
408    /// Returns the address of the listening socket.
409    #[inline]
410    pub fn server_addr(&self) -> SocketAddr {
411        self.server
412            .server_addr()
413            .to_ip()
414            .expect("Unexpected Unix socket listener")
415    }
416
417    /// Runs the server forever, or until the listening socket is somehow force-closed by the
418    /// operating system.
419    #[inline]
420    pub fn run(self) {
421        for request in self.server.incoming_requests() {
422            self.process(request);
423        }
424    }
425
426    /// Processes all the client requests waiting to be processed, then returns.
427    ///
428    /// This function executes very quickly, as each client requests that needs to be processed
429    /// is processed in a separate thread.
430    #[inline]
431    pub fn poll(&self) {
432        while let Ok(Some(request)) = self.server.try_recv() {
433            self.process(request);
434        }
435    }
436
437    /// Creates a new thread for the server that can be gracefully stopped later.
438    ///
439    /// This function returns a tuple of a `JoinHandle` and a `Sender`.
440    /// You must call `JoinHandle::join()` otherwise the server will not run until completion.
441    /// The server can be stopped at will by sending it an empty `()` message from another thread.
442    /// There may be a maximum of a 1 second delay between sending the stop message and the server
443    /// stopping. This delay may be shortened in future.
444    ///
445    /// ```no_run
446    /// use std::thread;
447    /// use std::time::Duration;
448    /// use rouille::Server;
449    /// use rouille::Response;
450    ///
451    /// let server = Server::new("localhost:0", |request| {
452    ///     Response::text("hello world")
453    /// }).unwrap();
454    /// println!("Listening on {:?}", server.server_addr());
455    /// let (handle, sender) = server.stoppable();
456    ///
457    /// // Stop the server in 3 seconds
458    /// thread::spawn(move || {
459    ///     thread::sleep(Duration::from_secs(3));
460    ///     sender.send(()).unwrap();
461    /// });
462    ///
463    /// // Block the main thread until the server is stopped
464    /// handle.join().unwrap();
465    /// ```
466    #[inline]
467    pub fn stoppable(self) -> (thread::JoinHandle<()>, mpsc::Sender<()>) {
468        let (tx, rx) = mpsc::channel();
469        let handle = thread::spawn(move || {
470            while rx.try_recv().is_err() {
471                // In order to reduce CPU load wait 1s for a recv before looping again
472                while let Ok(Some(request)) = self.server.recv_timeout(Duration::from_secs(1)) {
473                    self.process(request);
474                }
475            }
476        });
477
478        (handle, tx)
479    }
480
481    /// Same as `poll()` but blocks for at most `duration` before returning.
482    ///
483    /// This function can be used implement a custom server loop in a more CPU-efficient manner
484    /// than calling `poll`.
485    ///
486    /// # Example
487    ///
488    /// ```no_run
489    /// use rouille::Server;
490    /// use rouille::Response;
491    ///
492    /// let server = Server::new("localhost:0", |request| {
493    ///     Response::text("hello world")
494    /// }).unwrap();
495    /// println!("Listening on {:?}", server.server_addr());
496    ///
497    /// loop {
498    ///     server.poll_timeout(std::time::Duration::from_millis(100));
499    /// }
500    /// ```
501    #[inline]
502    pub fn poll_timeout(&self, dur: std::time::Duration) {
503        while let Ok(Some(request)) = self.server.recv_timeout(dur) {
504            self.process(request);
505        }
506    }
507
508    /// Waits for all in-flight requests to be processed. This is useful for implementing a graceful
509    /// shutdown.
510    ///
511    /// Note: new connections may still be accepted while we wait, and this function does not guarantee
512    /// to wait for those new requests. To implement a graceful shutdown or a clean rolling-update,
513    /// the following approach should be used:
514    ///
515    /// 1) Stop routing requests to this server. For a rolling update, requests should be routed
516    ///    to the new instance. This logic typically sits outside of your application.
517    ///
518    /// 2) Drain the queue of incoming connections by calling `poll_timeout` with a short timeout.
519    ///
520    /// 3) Wait for in-flight requests to be processed by using this method.
521    ///
522    /// # Example
523    /// ```no_run
524    /// # use std::time::Duration;
525    /// # use rouille::Server;
526    /// #
527    /// # let server = Server::new("", |_| unimplemented!()).unwrap();
528    /// # fn is_stopping() -> bool { unimplemented!() }
529    ///
530    /// // Accept connections until we receive a SIGTERM
531    /// while !is_stopping() {
532    ///     server.poll_timeout(Duration::from_millis(100));
533    /// }
534    ///
535    /// // We received a SIGTERM, but there still may be some queued connections,
536    /// // so wait for them to be accepted.
537    /// println!("Shutting down gracefully...");
538    /// server.poll_timeout(Duration::from_millis(100));
539    ///
540    /// // We can expect there to be no more queued connections now, but slow requests
541    /// // may still be in-flight, so wait for them to finish.
542    /// server.join();
543    /// ```
544    pub fn join(&self) {
545        self.executor.join();
546    }
547
548    // Internal function, called when we got a request from tiny-http that needs to be processed.
549    fn process(&self, request: tiny_http::Request) {
550        // We spawn a thread so that requests are processed in parallel.
551        let handler = self.handler.clone();
552        self.executor.execute(|| {
553            // Small helper struct that makes it possible to put
554            // a `tiny_http::Request` inside a `Box<Read>`.
555            struct RequestRead(Arc<Mutex<Option<tiny_http::Request>>>);
556            impl Read for RequestRead {
557                #[inline]
558                fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
559                    self.0
560                        .lock()
561                        .unwrap()
562                        .as_mut()
563                        .unwrap()
564                        .as_reader()
565                        .read(buf)
566                }
567            }
568
569            // Building the `Request` object.
570            let tiny_http_request;
571            let rouille_request = {
572                let url = request.url().to_owned();
573                let method = request.method().as_str().to_owned();
574                let headers = request
575                    .headers()
576                    .iter()
577                    .map(|h| (h.field.to_string(), h.value.clone().into()))
578                    .collect();
579                let remote_addr = request.remote_addr().copied();
580
581                tiny_http_request = Arc::new(Mutex::new(Some(request)));
582                let data = Arc::new(Mutex::new(Some(
583                    Box::new(RequestRead(tiny_http_request.clone())) as Box<_>,
584                )));
585
586                Request {
587                    url,
588                    method,
589                    headers,
590                    https: false,
591                    data,
592                    remote_addr,
593                }
594            };
595
596            // Calling the handler ; this most likely takes a lot of time.
597            // If the handler panics, we build a dummy response.
598            let mut rouille_response = {
599                // We don't use the `rouille_request` anymore after the panic, so it's ok to assert
600                // it's unwind safe.
601                let rouille_request = AssertUnwindSafe(rouille_request);
602                let res = panic::catch_unwind(move || {
603                    let rouille_request = rouille_request;
604                    handler(&rouille_request)
605                });
606
607                match res {
608                    Ok(r) => r,
609                    Err(_) => Response::html(
610                        "<h1>Internal Server Error</h1>\
611                                        <p>An internal error has occurred on the server.</p>",
612                    )
613                    .with_status_code(500),
614                }
615            };
616
617            // writing the response
618            let (res_data, res_len) = rouille_response.data.into_reader_and_size();
619            let mut response = tiny_http::Response::empty(rouille_response.status_code)
620                .with_data(res_data, res_len);
621
622            let mut upgrade_header = "".into();
623
624            for (key, value) in rouille_response.headers {
625                if key.eq_ignore_ascii_case("Content-Length") {
626                    continue;
627                }
628
629                if key.eq_ignore_ascii_case("Upgrade") {
630                    upgrade_header = value;
631                    continue;
632                }
633
634                if let Ok(header) = tiny_http::Header::from_bytes(key.as_bytes(), value.as_bytes())
635                {
636                    response.add_header(header);
637                } else {
638                    // TODO: ?
639                }
640            }
641
642            if let Some(ref mut upgrade) = rouille_response.upgrade {
643                let trq = tiny_http_request.lock().unwrap().take().unwrap();
644                let socket = trq.upgrade(&upgrade_header, response);
645                upgrade.build(socket);
646            } else {
647                // We don't really care if we fail to send the response to the client, as there's
648                // nothing we can do anyway.
649                let _ = tiny_http_request
650                    .lock()
651                    .unwrap()
652                    .take()
653                    .unwrap()
654                    .respond(response);
655            }
656        });
657    }
658}
659
/// Trait for objects that can take ownership of a raw connection to the client data.
///
/// The purpose of this trait is to be used with the `Connection: Upgrade` header, hence its name.
/// When a handler returns a `Response` whose `upgrade` field is `Some`, the server calls
/// `tiny_http::Request::upgrade` with the response and the value of the response's `Upgrade`
/// header, then passes the resulting socket to `build`.
pub trait Upgrade {
    /// Initializes the object with the given socket.
    ///
    /// Called from the thread that processed the request; after this call the server does not
    /// use the connection any more, so the object owns it from then on.
    fn build(&mut self, socket: Box<dyn ReadWrite + Send>);
}
667
/// Represents a request that your handler must answer to.
///
/// This can be either a real request (received by the HTTP server) or a mock object created with
/// one of the `fake_*` constructors.
pub struct Request {
    /// HTTP method as sent by the client (`GET`, `POST`, etc.); returned by `method()`.
    method: String,
    /// Raw URL, not percent-decoded and including any query string; see `raw_url()` and `url()`.
    url: String,
    /// `(name, value)` header pairs in the order they were collected; `header()` looks names up
    /// case-insensitively.
    headers: Vec<(String, String)>,
    /// Value returned by `is_secure()`.
    https: bool,
    /// Request body reader. `remove_prefix` clones this `Arc`, so derived requests share the same
    /// body. NOTE(review): the `Option` presumably lets body-reading code take the reader out
    /// exactly once — confirm against the `input` module.
    data: Arc<Mutex<Option<Box<dyn Read + Send>>>>,
    /// Address of the client, if known; the `fake_*` constructors always set it.
    remote_addr: Option<SocketAddr>,
}
680
681impl fmt::Debug for Request {
682    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
683        f.debug_struct("Request")
684            .field("method", &self.method)
685            .field("url", &self.url)
686            .field("headers", &self.headers)
687            .field("https", &self.https)
688            .field("remote_addr", &self.remote_addr)
689            .finish()
690    }
691}
692
693impl Request {
694    /// Builds a fake HTTP request to be used during tests.
695    ///
696    /// The remote address of the client will be `127.0.0.1:12345`. Use `fake_http_from` to
697    /// specify what the client's address should be.
698    pub fn fake_http<U, M>(
699        method: M,
700        url: U,
701        headers: Vec<(String, String)>,
702        data: Vec<u8>,
703    ) -> Request
704    where
705        U: Into<String>,
706        M: Into<String>,
707    {
708        let data = Arc::new(Mutex::new(Some(Box::new(Cursor::new(data)) as Box<_>)));
709        let remote_addr = Some("127.0.0.1:12345".parse().unwrap());
710
711        Request {
712            url: url.into(),
713            method: method.into(),
714            https: false,
715            data,
716            headers,
717            remote_addr,
718        }
719    }
720
721    /// Builds a fake HTTP request to be used during tests.
722    pub fn fake_http_from<U, M>(
723        from: SocketAddr,
724        method: M,
725        url: U,
726        headers: Vec<(String, String)>,
727        data: Vec<u8>,
728    ) -> Request
729    where
730        U: Into<String>,
731        M: Into<String>,
732    {
733        let data = Arc::new(Mutex::new(Some(Box::new(Cursor::new(data)) as Box<_>)));
734
735        Request {
736            url: url.into(),
737            method: method.into(),
738            https: false,
739            data,
740            headers,
741            remote_addr: Some(from),
742        }
743    }
744
745    /// Builds a fake HTTPS request to be used during tests.
746    ///
747    /// The remote address of the client will be `127.0.0.1:12345`. Use `fake_https_from` to
748    /// specify what the client's address should be.
749    pub fn fake_https<U, M>(
750        method: M,
751        url: U,
752        headers: Vec<(String, String)>,
753        data: Vec<u8>,
754    ) -> Request
755    where
756        U: Into<String>,
757        M: Into<String>,
758    {
759        let data = Arc::new(Mutex::new(Some(Box::new(Cursor::new(data)) as Box<_>)));
760        let remote_addr = Some("127.0.0.1:12345".parse().unwrap());
761
762        Request {
763            url: url.into(),
764            method: method.into(),
765            https: true,
766            data,
767            headers,
768            remote_addr,
769        }
770    }
771
772    /// Builds a fake HTTPS request to be used during tests.
773    pub fn fake_https_from<U, M>(
774        from: SocketAddr,
775        method: M,
776        url: U,
777        headers: Vec<(String, String)>,
778        data: Vec<u8>,
779    ) -> Request
780    where
781        U: Into<String>,
782        M: Into<String>,
783    {
784        let data = Arc::new(Mutex::new(Some(Box::new(Cursor::new(data)) as Box<_>)));
785
786        Request {
787            url: url.into(),
788            method: method.into(),
789            https: true,
790            data,
791            headers,
792            remote_addr: Some(from),
793        }
794    }
795
796    /// If the decoded URL of the request starts with `prefix`, builds a new `Request` that is
797    /// the same as the original but without that prefix.
798    ///
799    /// # Example
800    ///
801    /// ```
802    /// # use rouille::Request;
803    /// # use rouille::Response;
804    /// fn handle(request: &Request) -> Response {
805    ///     if let Some(request) = request.remove_prefix("/static") {
806    ///         return rouille::match_assets(&request, "/static");
807    ///     }
808    ///
809    ///     // ...
810    ///     # panic!()
811    /// }
812    /// ```
813    pub fn remove_prefix(&self, prefix: &str) -> Option<Request> {
814        if !self.url().starts_with(prefix) {
815            return None;
816        }
817
818        // TODO: url-encoded characters in the prefix are not implemented
819        assert!(self.url.starts_with(prefix));
820        Some(Request {
821            method: self.method.clone(),
822            url: self.url[prefix.len()..].to_owned(),
823            headers: self.headers.clone(), // TODO: expensive
824            https: self.https,
825            data: self.data.clone(),
826            remote_addr: self.remote_addr,
827        })
828    }
829
830    /// Returns `true` if the request uses HTTPS, and `false` if it uses HTTP.
831    ///
832    /// # Example
833    ///
834    /// ```
835    /// use rouille::{Request, Response};
836    ///
837    /// fn handle(request: &Request) -> Response {
838    ///     if !request.is_secure() {
839    ///         return Response::redirect_303(format!("https://example.com"));
840    ///     }
841    ///
842    ///     // ...
843    /// # panic!()
844    /// }
845    /// ```
846    #[inline]
847    pub fn is_secure(&self) -> bool {
848        self.https
849    }
850
851    /// Returns the method of the request (`GET`, `POST`, etc.).
852    #[inline]
853    pub fn method(&self) -> &str {
854        &self.method
855    }
856
857    /// Returns the raw URL requested by the client. It is not decoded and thus can contain strings
858    /// such as `%20`, and the query parameters such as `?p=hello`.
859    ///
860    /// See also `url()`.
861    ///
862    /// # Example
863    ///
864    /// ```
865    /// use rouille::Request;
866    ///
867    /// let request = Request::fake_http("GET", "/hello%20world?foo=bar", vec![], vec![]);
868    /// assert_eq!(request.raw_url(), "/hello%20world?foo=bar");
869    /// ```
870    #[inline]
871    pub fn raw_url(&self) -> &str {
872        &self.url
873    }
874
875    /// Returns the raw query string requested by the client. In other words, everything after the
876    /// first `?` in the raw url.
877    ///
878    /// Returns the empty string if no query string.
879    #[inline]
880    pub fn raw_query_string(&self) -> &str {
881        if let Some(pos) = self.url.bytes().position(|c| c == b'?') {
882            self.url.split_at(pos + 1).1
883        } else {
884            ""
885        }
886    }
887
888    /// Returns the URL requested by the client.
889    ///
890    /// Contrary to `raw_url`, special characters have been decoded and the query string
891    /// (eg `?p=hello`) has been removed.
892    ///
893    /// If there is any non-unicode character in the URL, it will be replaced with `U+FFFD`.
894    ///
895    /// > **Note**: This function will decode the token `%2F` will be decoded as `/`. However the
896    /// > official specifications say that such a token must not count as a delimiter for URL paths.
897    /// > In other words, `/hello/world` is not the same as `/hello%2Fworld`.
898    ///
899    /// # Example
900    ///
901    /// ```
902    /// use rouille::Request;
903    ///
904    /// let request = Request::fake_http("GET", "/hello%20world?foo=bar", vec![], vec![]);
905    /// assert_eq!(request.url(), "/hello world");
906    /// ```
907    pub fn url(&self) -> String {
908        let url = self.url.as_bytes();
909        let url = if let Some(pos) = url.iter().position(|&c| c == b'?') {
910            &url[..pos]
911        } else {
912            url
913        };
914
915        percent_encoding::percent_decode(url)
916            .decode_utf8_lossy()
917            .into_owned()
918    }
919
920    /// Returns the value of a GET parameter or None if it doesn't exist.
921    pub fn get_param(&self, param_name: &str) -> Option<String> {
922        let name_pattern = &format!("{}=", param_name);
923        let param_pairs = self.raw_query_string().split('&');
924        param_pairs
925            .filter(|pair| pair.starts_with(name_pattern) || pair == &param_name)
926            .map(|pair| pair.split('=').nth(1).unwrap_or(""))
927            .next()
928            .map(|value| {
929                percent_encoding::percent_decode(value.replace('+', " ").as_bytes())
930                    .decode_utf8_lossy()
931                    .into_owned()
932            })
933    }
934
935    /// Returns the value of a header of the request.
936    ///
937    /// Returns `None` if no such header could be found.
938    #[inline]
939    pub fn header(&self, key: &str) -> Option<&str> {
940        self.headers
941            .iter()
942            .find(|&&(ref k, _)| k.eq_ignore_ascii_case(key))
943            .map(|&(_, ref v)| &v[..])
944    }
945
946    /// Returns a list of all the headers of the request.
947    #[inline]
948    pub fn headers(&self) -> HeadersIter {
949        HeadersIter {
950            iter: self.headers.iter(),
951        }
952    }
953
954    /// Returns the state of the `DNT` (Do Not Track) header.
955    ///
956    /// If the header is missing or is malformed, `None` is returned. If the header exists,
957    /// `Some(true)` is returned if `DNT` is `1` and `Some(false)` is returned if `DNT` is `0`.
958    ///
959    /// # Example
960    ///
961    /// ```
962    /// use rouille::{Request, Response};
963    ///
964    /// # fn track_user(request: &Request) {}
965    /// fn handle(request: &Request) -> Response {
966    ///     if !request.do_not_track().unwrap_or(false) {
967    ///         track_user(&request);
968    ///     }
969    ///
970    ///     // ...
971    /// # panic!()
972    /// }
973    /// ```
974    pub fn do_not_track(&self) -> Option<bool> {
975        match self.header("DNT") {
976            Some(h) if h == "1" => Some(true),
977            Some(h) if h == "0" => Some(false),
978            _ => None,
979        }
980    }
981
982    /// Returns the body of the request.
983    ///
984    /// The body can only be retrieved once. Returns `None` is the body has already been retrieved
985    /// before.
986    ///
987    /// # Example
988    ///
989    /// ```
990    /// use std::io::Read;
991    /// use rouille::{Request, Response, ResponseBody};
992    ///
993    /// fn echo(request: &Request) -> Response {
994    ///     let mut data = request.data().expect("Oops, body already retrieved, problem \
995    ///                                           in the server");
996    ///
997    ///     let mut buf = Vec::new();
998    ///     match data.read_to_end(&mut buf) {
999    ///         Ok(_) => (),
1000    ///         Err(_) => return Response::text("Failed to read body")
1001    ///     };
1002    ///
1003    ///     Response {
1004    ///         data: ResponseBody::from_data(buf),
1005    ///         .. Response::text("")
1006    ///     }
1007    /// }
1008    /// ```
1009    pub fn data(&self) -> Option<RequestBody> {
1010        let reader = self.data.lock().unwrap().take();
1011        reader.map(|r| RequestBody {
1012            body: r,
1013            marker: PhantomData,
1014        })
1015    }
1016
1017    /// Returns the address of the client that made this request.
1018    ///
1019    /// # Example
1020    ///
1021    /// ```
1022    /// use rouille::{Request, Response};
1023    ///
1024    /// fn handle(request: &Request) -> Response {
1025    ///     Response::text(format!("Your IP is: {:?}", request.remote_addr()))
1026    /// }
1027    /// ```
1028    #[inline]
1029    pub fn remote_addr(&self) -> &SocketAddr {
1030        self.remote_addr
1031            .as_ref()
1032            .expect("Unexpected Unix socket for request")
1033    }
1034}
1035
/// Iterator over the `(name, value)` header pairs of a request.
///
/// Created by `Request::headers()`. Names and values are yielded as string slices borrowed
/// from the request.
#[derive(Clone, Debug)]
pub struct HeadersIter<'a> {
    iter: SliceIter<'a, (String, String)>,
}

impl<'a> Iterator for HeadersIter<'a> {
    type Item = (&'a str, &'a str);

    #[inline]
    fn next(&mut self) -> Option<(&'a str, &'a str)> {
        let pair = self.iter.next()?;
        Some((pair.0.as_str(), pair.1.as_str()))
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        // The underlying slice iterator knows its exact remaining length.
        self.iter.size_hint()
    }
}

impl<'a> ExactSizeIterator for HeadersIter<'a> {}
1057
/// Readable handle on the body of a request.
///
/// Obtained by calling `request.data()`. Reads are forwarded to the underlying body reader;
/// the lifetime `'a` ties the handle to the request it was taken from.
pub struct RequestBody<'a> {
    body: Box<dyn Read + Send>,
    marker: PhantomData<&'a ()>,
}

impl<'a> Read for RequestBody<'a> {
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
        let inner: &mut (dyn Read + Send) = &mut *self.body;
        inner.read(buf)
    }
}
1072
#[cfg(test)]
mod tests {
    use Request;

    /// Builds a `GET` request for `url` with no headers and an empty body.
    fn fake_get(url: &str) -> Request {
        Request::fake_http("GET", url, vec![], vec![])
    }

    /// Builds a `GET /` request with an empty body and a single header.
    fn fake_get_root_with_header(name: &str, value: &str) -> Request {
        Request::fake_http(
            "GET",
            "/",
            vec![(name.to_owned(), value.to_owned())],
            vec![],
        )
    }

    #[test]
    fn header() {
        let request = fake_get_root_with_header("Host", "localhost");
        // Lookups must ignore ASCII case.
        for key in &["Host", "host"] {
            assert_eq!(request.header(key), Some("localhost"));
        }
    }

    #[test]
    fn get_param() {
        assert_eq!(fake_get("/?p=hello").get_param("p"), Some("hello".to_owned()));
    }

    #[test]
    fn get_param_multiple_param() {
        let request = fake_get("/?foo=bar&message=hello");
        assert_eq!(request.get_param("message"), Some("hello".to_owned()));
    }

    #[test]
    fn get_param_no_match() {
        assert_eq!(fake_get("/?hello=world").get_param("foo"), None);
    }

    #[test]
    fn get_param_partial_suffix_match() {
        assert_eq!(fake_get("/?hello=world").get_param("lo"), None);
    }

    #[test]
    fn get_param_partial_prefix_match() {
        assert_eq!(fake_get("/?hello=world").get_param("he"), None);
    }

    #[test]
    fn get_param_superstring_match() {
        assert_eq!(fake_get("/?jan=01").get_param("january"), None);
    }

    #[test]
    fn get_param_flag_with_equals() {
        assert_eq!(fake_get("/?flag=").get_param("flag"), Some("".to_owned()));
    }

    #[test]
    fn get_param_flag_without_equals() {
        assert_eq!(fake_get("/?flag").get_param("flag"), Some("".to_owned()));
    }

    #[test]
    fn get_param_flag_with_multiple_params() {
        assert_eq!(fake_get("/?flag&foo=bar").get_param("flag"), Some("".to_owned()));
    }

    #[test]
    fn body_twice() {
        let request = Request::fake_http("GET", "/", vec![], vec![62, 62, 62]);
        // The first call takes the body; the second finds nothing left.
        assert!(request.data().is_some());
        assert!(request.data().is_none());
    }

    #[test]
    fn url_strips_get_query() {
        assert_eq!(fake_get("/?p=hello").url(), "/");
    }

    #[test]
    fn urlencode_query_string() {
        let request = fake_get("/?p=hello%20world");
        assert_eq!(request.get_param("p"), Some("hello world".to_owned()));
    }

    #[test]
    fn plus_in_query_string() {
        let request = fake_get("/?p=hello+world");
        assert_eq!(request.get_param("p"), Some("hello world".to_owned()));
    }

    #[test]
    fn encoded_plus_in_query_string() {
        let request = fake_get("/?p=hello%2Bworld");
        assert_eq!(request.get_param("p"), Some("hello+world".to_owned()));
    }

    #[test]
    fn url_encode() {
        assert_eq!(fake_get("/hello%20world").url(), "/hello world");
    }

    #[test]
    fn plus_in_url() {
        // Unlike in query strings, `+` in the path is not a space.
        assert_eq!(fake_get("/hello+world").url(), "/hello+world");
    }

    #[test]
    fn dnt() {
        assert_eq!(fake_get_root_with_header("DNT", "1").do_not_track(), Some(true));
        assert_eq!(fake_get_root_with_header("DNT", "0").do_not_track(), Some(false));
        assert_eq!(fake_get("/").do_not_track(), None);
        assert_eq!(fake_get_root_with_header("DNT", "malformed").do_not_track(), None);
    }
}