rouille/lib.rs
1// Copyright (c) 2016 The Rouille developers
2// Licensed under the Apache License, Version 2.0
3// <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>,
6// at your option. All files in the project carrying such
7// notice may not be copied, modified, or distributed except
8// according to those terms.
9
10//! The rouille library is very easy to get started with.
11//!
12//! Listening to a port is done by calling the [`start_server`](fn.start_server.html) function:
13//!
14//! ```no_run
15//! use rouille::Request;
16//! use rouille::Response;
17//!
18//! rouille::start_server("0.0.0.0:80", move |request| {
19//! Response::text("hello world")
20//! });
21//! ```
22//!
//! Whenever an HTTP request is received on the address passed as the first parameter, the closure
//! passed as the second parameter is called. This closure must then return a
//! [`Response`](struct.Response.html) that will be sent back to the client.
26//!
27//! See the documentation of [`start_server`](fn.start_server.html) for more details.
28//!
29//! # Analyzing the request
30//!
31//! The parameter that the closure receives is a [`Request`](struct.Request.html) object that
32//! represents the request made by the client.
33//!
//! The `Request` object itself provides some getters, but most advanced functionality is
//! provided by other modules of this crate.
36//!
37//! - In order to dispatch between various code depending on the URL, you can use the
38//! [`router!`](macro.router.html) macro.
39//! - In order to analyze the body of the request, like handling JSON input, form input, etc. you
40//! can take a look at [the `input` module](input/index.html).
41//!
42//! # Returning a response
43//!
//! Once you have analyzed the request, it is time to return a response by returning a
//! [`Response`](struct.Response.html) object.
46//!
//! All the members of `Response` are public, so you can customize it as you want. There are also
//! several constructors that let you build a basic `Response`, which you can then modify.
49//!
50//! In order to serve static files, take a look at
51//! [the `match_assets` function](fn.match_assets.html).
52//!
53//! In order to apply content encodings (including compression such as gzip or deflate), see
54//! the [content_encoding module](content_encoding/index.html), and specifically the
55//! [content_encoding::apply](content_encoding/fn.apply.html) function.
56
57#![deny(unsafe_code)]
58
59extern crate base64;
60#[cfg(feature = "brotli")]
61extern crate brotli;
62extern crate chrono;
63#[cfg(feature = "gzip")]
64extern crate deflate;
65extern crate filetime;
66extern crate multipart;
67extern crate rand;
68extern crate serde;
69#[macro_use]
70extern crate serde_derive;
71pub extern crate percent_encoding;
72extern crate serde_json;
73extern crate sha1_smol;
74extern crate threadpool;
75extern crate time;
76extern crate tiny_http;
77pub extern crate url;
78
79// https://github.com/servo/rust-url/blob/e121d8d0aafd50247de5f5310a227ecb1efe6ffe/percent_encoding/lib.rs#L126
80pub const DEFAULT_ENCODE_SET: &percent_encoding::AsciiSet = &percent_encoding::CONTROLS
81 .add(b' ')
82 .add(b'"')
83 .add(b'#')
84 .add(b'<')
85 .add(b'>')
86 .add(b'`')
87 .add(b'?')
88 .add(b'{')
89 .add(b'}');
90
91pub use assets::extension_to_mime;
92pub use assets::match_assets;
93pub use log::{log, log_custom};
94pub use response::{Response, ResponseBody};
95pub use tiny_http::ReadWrite;
96
97use std::error::Error;
98use std::fmt;
99use std::io::Cursor;
100use std::io::Read;
101use std::io::Result as IoResult;
102use std::marker::PhantomData;
103use std::net::SocketAddr;
104use std::net::ToSocketAddrs;
105use std::panic;
106use std::panic::AssertUnwindSafe;
107use std::slice::Iter as SliceIter;
108use std::sync::atomic::{AtomicUsize, Ordering};
109use std::sync::mpsc;
110use std::sync::Arc;
111use std::sync::Mutex;
112use std::thread;
113use std::time::Duration;
114
115pub mod cgi;
116pub mod content_encoding;
117pub mod input;
118pub mod proxy;
119pub mod session;
120pub mod websocket;
121
122mod assets;
123mod find_route;
124mod log;
125mod response;
126mod router;
127#[doc(hidden)]
128pub mod try_or_400;
129
130/// This macro assumes that the current function returns a `Response` and takes a `Result`.
131/// If the expression you pass to the macro is an error, then a 404 response is returned.
132#[macro_export]
133macro_rules! try_or_404 {
134 ($result:expr) => {
135 match $result {
136 Ok(r) => r,
137 Err(_) => return $crate::Response::empty_404(),
138 }
139 };
140}
141
142/// This macro assumes that the current function returns a `Response`. If the condition you pass
143/// to the macro is false, then a 400 response is returned.
144///
145/// # Example
146///
147/// ```
148/// # #[macro_use] extern crate rouille;
149/// # fn main() {
150/// use rouille::Request;
151/// use rouille::Response;
152///
153/// fn handle_something(request: &Request) -> Response {
154/// let data = try_or_400!(post_input!(request, {
155/// field1: u32,
156/// field2: String,
157/// }));
158///
159/// assert_or_400!(data.field1 >= 2);
160/// Response::text("hello")
161/// }
162/// # }
163/// ```
164#[macro_export]
165macro_rules! assert_or_400 {
166 ($cond:expr) => {
167 if !$cond {
168 return $crate::Response::empty_400();
169 }
170 };
171}
172
173/// Starts a server and uses the given requests handler.
174///
175/// The request handler takes a `&Request` and must return a `Response` to send to the user.
176///
177/// > **Note**: `start_server` is meant to be an easy-to-use function. If you want more control,
178/// > see [the `Server` struct](struct.Server.html).
179///
180/// # Common mistakes
181///
182/// The handler must capture its environment by value and not by reference (`'static`). If you
183/// use closure, don't forget to put `move` in front of the closure.
184///
185/// The handler must also be thread-safe (`Send` and `Sync`).
186/// For example this handler isn't thread-safe:
187///
188/// ```should_fail
189/// let mut requests_counter = 0;
190///
191/// rouille::start_server("localhost:80", move |request| {
192/// requests_counter += 1;
193///
194/// // ... rest of the handler ...
195/// # panic!()
196/// })
197/// ```
198///
199/// Multiple requests can be processed simultaneously, therefore you can't mutably access
200/// variables from the outside.
201///
202/// Instead you must use a `Mutex`:
203///
204/// ```no_run
205/// use std::sync::Mutex;
206/// let requests_counter = Mutex::new(0);
207///
208/// rouille::start_server("localhost:80", move |request| {
209/// *requests_counter.lock().unwrap() += 1;
210///
211/// // rest of the handler
212/// # panic!()
213/// })
214/// ```
215///
216/// # Panic handling in the handler
217///
218/// If your request handler panics, a 500 error will automatically be sent to the client.
219///
220/// # Panic
221///
222/// This function will panic if the server starts to fail (for example if you use a port that is
223/// already occupied) or if the socket is force-closed by the operating system.
224///
225/// If you need to handle these situations, please see `Server`.
226pub fn start_server<A, F>(addr: A, handler: F) -> !
227where
228 A: ToSocketAddrs,
229 F: Send + Sync + 'static + Fn(&Request) -> Response,
230{
231 Server::new(addr, handler)
232 .expect("Failed to start server")
233 .run();
234 panic!("The server socket closed unexpectedly")
235}
236
237/// Identical to `start_server` but uses a `ThreadPool` of the given size.
238///
239/// When `pool_size` is `None`, the thread pool size will default to `8 * num-cpus`.
240/// `pool_size` must be greater than zero or this function will panic.
241pub fn start_server_with_pool<A, F>(addr: A, pool_size: Option<usize>, handler: F) -> !
242where
243 A: ToSocketAddrs,
244 F: Send + Sync + 'static + Fn(&Request) -> Response,
245{
246 let pool_size = pool_size.unwrap_or_else(|| {
247 8 * thread::available_parallelism()
248 .map(|n| n.get())
249 .unwrap_or(1)
250 });
251
252 Server::new(addr, handler)
253 .expect("Failed to start server")
254 .pool_size(pool_size)
255 .run();
256 panic!("The server socket closed unexpectedly")
257}
258
259struct AtomicCounter(Arc<AtomicUsize>);
260
261impl AtomicCounter {
262 fn new(count: &Arc<AtomicUsize>) -> Self {
263 count.fetch_add(1, Ordering::Relaxed);
264 AtomicCounter(Arc::clone(count))
265 }
266}
267
268impl Drop for AtomicCounter {
269 fn drop(&mut self) {
270 self.0.fetch_sub(1, Ordering::Release);
271 }
272}
273
274/// Executes a function in either a thread of a thread pool
275enum Executor {
276 Threaded { count: Arc<AtomicUsize> },
277 Pooled { pool: threadpool::ThreadPool },
278}
279impl Executor {
280 /// `size` must be greater than zero or the call to `ThreadPool::new` will panic.
281 fn with_size(size: usize) -> Self {
282 let pool = threadpool::ThreadPool::new(size);
283 Executor::Pooled { pool }
284 }
285
286 #[inline]
287 fn execute<F: FnOnce() + Send + 'static>(&self, f: F) {
288 match *self {
289 Executor::Threaded { ref count } => {
290 let counter = AtomicCounter::new(count);
291 thread::spawn(move || {
292 let _counter = counter;
293 f()
294 });
295 }
296 Executor::Pooled { ref pool } => {
297 pool.execute(f);
298 }
299 }
300 }
301
302 fn join(&self) {
303 match *self {
304 Executor::Threaded { ref count } => {
305 while count.load(Ordering::Acquire) > 0 {
306 thread::sleep(Duration::from_millis(100));
307 }
308 }
309 Executor::Pooled { ref pool } => {
310 pool.join();
311 }
312 }
313 }
314}
315
316impl Default for Executor {
317 fn default() -> Self {
318 Executor::Threaded {
319 count: Arc::new(AtomicUsize::new(0)),
320 }
321 }
322}
323
324/// A listening server.
325///
326/// This struct is the more manual server creation API of rouille and can be used as an alternative
327/// to the `start_server` function.
328///
329/// The `start_server` function is just a shortcut for `Server::new` followed with `run`. See the
330/// documentation of the `start_server` function for more details about the handler.
331///
332/// # Example
333///
334/// ```no_run
335/// use rouille::Server;
336/// use rouille::Response;
337///
338/// let server = Server::new("localhost:0", |request| {
339/// Response::text("hello world")
340/// }).unwrap();
341/// println!("Listening on {:?}", server.server_addr());
342/// server.run();
343/// ```
344pub struct Server<F> {
345 server: tiny_http::Server,
346 handler: Arc<AssertUnwindSafe<F>>,
347 executor: Executor,
348}
349
350impl<F> Server<F>
351where
352 F: Send + Sync + 'static + Fn(&Request) -> Response,
353{
354 /// Builds a new `Server` object.
355 ///
356 /// After this function returns, the HTTP server is listening.
357 ///
358 /// Returns an error if there was an error while creating the listening socket, for example if
359 /// the port is already in use.
360 pub fn new<A>(addr: A, handler: F) -> Result<Server<F>, Box<dyn Error + Send + Sync + 'static>>
361 where
362 A: ToSocketAddrs,
363 {
364 let server = tiny_http::Server::http(addr)?;
365 Ok(Server {
366 server,
367 executor: Executor::default(),
368 handler: Arc::new(AssertUnwindSafe(handler)), // TODO: using AssertUnwindSafe here is wrong, but unwind safety has some usability problems in Rust in general
369 })
370 }
371
372 /// Builds a new `Server` object with SSL support.
373 ///
374 /// After this function returns, the HTTPS server is listening.
375 ///
376 /// Returns an error if there was an error while creating the listening socket, for example if
377 /// the port is already in use.
378 #[cfg(any(feature = "ssl", feature = "rustls"))]
379 pub fn new_ssl<A>(
380 addr: A,
381 handler: F,
382 certificate: Vec<u8>,
383 private_key: Vec<u8>,
384 ) -> Result<Server<F>, Box<dyn Error + Send + Sync + 'static>>
385 where
386 A: ToSocketAddrs,
387 {
388 let ssl_config = tiny_http::SslConfig {
389 certificate,
390 private_key,
391 };
392 let server = tiny_http::Server::https(addr, ssl_config)?;
393 Ok(Server {
394 server,
395 executor: Executor::default(),
396 handler: Arc::new(AssertUnwindSafe(handler)), // TODO: using AssertUnwindSafe here is wrong, but unwind safety has some usability problems in Rust in general
397 })
398 }
399
400 /// Use a `ThreadPool` of the given size to process requests
401 ///
402 /// `pool_size` must be greater than zero or this function will panic.
403 pub fn pool_size(mut self, pool_size: usize) -> Self {
404 self.executor = Executor::with_size(pool_size);
405 self
406 }
407
408 /// Returns the address of the listening socket.
409 #[inline]
410 pub fn server_addr(&self) -> SocketAddr {
411 self.server
412 .server_addr()
413 .to_ip()
414 .expect("Unexpected Unix socket listener")
415 }
416
417 /// Runs the server forever, or until the listening socket is somehow force-closed by the
418 /// operating system.
419 #[inline]
420 pub fn run(self) {
421 for request in self.server.incoming_requests() {
422 self.process(request);
423 }
424 }
425
426 /// Processes all the client requests waiting to be processed, then returns.
427 ///
428 /// This function executes very quickly, as each client requests that needs to be processed
429 /// is processed in a separate thread.
430 #[inline]
431 pub fn poll(&self) {
432 while let Ok(Some(request)) = self.server.try_recv() {
433 self.process(request);
434 }
435 }
436
437 /// Creates a new thread for the server that can be gracefully stopped later.
438 ///
439 /// This function returns a tuple of a `JoinHandle` and a `Sender`.
440 /// You must call `JoinHandle::join()` otherwise the server will not run until completion.
441 /// The server can be stopped at will by sending it an empty `()` message from another thread.
442 /// There may be a maximum of a 1 second delay between sending the stop message and the server
443 /// stopping. This delay may be shortened in future.
444 ///
445 /// ```no_run
446 /// use std::thread;
447 /// use std::time::Duration;
448 /// use rouille::Server;
449 /// use rouille::Response;
450 ///
451 /// let server = Server::new("localhost:0", |request| {
452 /// Response::text("hello world")
453 /// }).unwrap();
454 /// println!("Listening on {:?}", server.server_addr());
455 /// let (handle, sender) = server.stoppable();
456 ///
457 /// // Stop the server in 3 seconds
458 /// thread::spawn(move || {
459 /// thread::sleep(Duration::from_secs(3));
460 /// sender.send(()).unwrap();
461 /// });
462 ///
463 /// // Block the main thread until the server is stopped
464 /// handle.join().unwrap();
465 /// ```
466 #[inline]
467 pub fn stoppable(self) -> (thread::JoinHandle<()>, mpsc::Sender<()>) {
468 let (tx, rx) = mpsc::channel();
469 let handle = thread::spawn(move || {
470 while rx.try_recv().is_err() {
471 // In order to reduce CPU load wait 1s for a recv before looping again
472 while let Ok(Some(request)) = self.server.recv_timeout(Duration::from_secs(1)) {
473 self.process(request);
474 }
475 }
476 });
477
478 (handle, tx)
479 }
480
481 /// Same as `poll()` but blocks for at most `duration` before returning.
482 ///
483 /// This function can be used implement a custom server loop in a more CPU-efficient manner
484 /// than calling `poll`.
485 ///
486 /// # Example
487 ///
488 /// ```no_run
489 /// use rouille::Server;
490 /// use rouille::Response;
491 ///
492 /// let server = Server::new("localhost:0", |request| {
493 /// Response::text("hello world")
494 /// }).unwrap();
495 /// println!("Listening on {:?}", server.server_addr());
496 ///
497 /// loop {
498 /// server.poll_timeout(std::time::Duration::from_millis(100));
499 /// }
500 /// ```
501 #[inline]
502 pub fn poll_timeout(&self, dur: std::time::Duration) {
503 while let Ok(Some(request)) = self.server.recv_timeout(dur) {
504 self.process(request);
505 }
506 }
507
508 /// Waits for all in-flight requests to be processed. This is useful for implementing a graceful
509 /// shutdown.
510 ///
511 /// Note: new connections may still be accepted while we wait, and this function does not guarantee
512 /// to wait for those new requests. To implement a graceful shutdown or a clean rolling-update,
513 /// the following approach should be used:
514 ///
515 /// 1) Stop routing requests to this server. For a rolling update, requests should be routed
516 /// to the new instance. This logic typically sits outside of your application.
517 ///
518 /// 2) Drain the queue of incoming connections by calling `poll_timeout` with a short timeout.
519 ///
520 /// 3) Wait for in-flight requests to be processed by using this method.
521 ///
522 /// # Example
523 /// ```no_run
524 /// # use std::time::Duration;
525 /// # use rouille::Server;
526 /// #
527 /// # let server = Server::new("", |_| unimplemented!()).unwrap();
528 /// # fn is_stopping() -> bool { unimplemented!() }
529 ///
530 /// // Accept connections until we receive a SIGTERM
531 /// while !is_stopping() {
532 /// server.poll_timeout(Duration::from_millis(100));
533 /// }
534 ///
535 /// // We received a SIGTERM, but there still may be some queued connections,
536 /// // so wait for them to be accepted.
537 /// println!("Shutting down gracefully...");
538 /// server.poll_timeout(Duration::from_millis(100));
539 ///
540 /// // We can expect there to be no more queued connections now, but slow requests
541 /// // may still be in-flight, so wait for them to finish.
542 /// server.join();
543 /// ```
544 pub fn join(&self) {
545 self.executor.join();
546 }
547
548 // Internal function, called when we got a request from tiny-http that needs to be processed.
549 fn process(&self, request: tiny_http::Request) {
550 // We spawn a thread so that requests are processed in parallel.
551 let handler = self.handler.clone();
552 self.executor.execute(|| {
553 // Small helper struct that makes it possible to put
554 // a `tiny_http::Request` inside a `Box<Read>`.
555 struct RequestRead(Arc<Mutex<Option<tiny_http::Request>>>);
556 impl Read for RequestRead {
557 #[inline]
558 fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
559 self.0
560 .lock()
561 .unwrap()
562 .as_mut()
563 .unwrap()
564 .as_reader()
565 .read(buf)
566 }
567 }
568
569 // Building the `Request` object.
570 let tiny_http_request;
571 let rouille_request = {
572 let url = request.url().to_owned();
573 let method = request.method().as_str().to_owned();
574 let headers = request
575 .headers()
576 .iter()
577 .map(|h| (h.field.to_string(), h.value.clone().into()))
578 .collect();
579 let remote_addr = request.remote_addr().copied();
580
581 tiny_http_request = Arc::new(Mutex::new(Some(request)));
582 let data = Arc::new(Mutex::new(Some(
583 Box::new(RequestRead(tiny_http_request.clone())) as Box<_>,
584 )));
585
586 Request {
587 url,
588 method,
589 headers,
590 https: false,
591 data,
592 remote_addr,
593 }
594 };
595
596 // Calling the handler ; this most likely takes a lot of time.
597 // If the handler panics, we build a dummy response.
598 let mut rouille_response = {
599 // We don't use the `rouille_request` anymore after the panic, so it's ok to assert
600 // it's unwind safe.
601 let rouille_request = AssertUnwindSafe(rouille_request);
602 let res = panic::catch_unwind(move || {
603 let rouille_request = rouille_request;
604 handler(&rouille_request)
605 });
606
607 match res {
608 Ok(r) => r,
609 Err(_) => Response::html(
610 "<h1>Internal Server Error</h1>\
611 <p>An internal error has occurred on the server.</p>",
612 )
613 .with_status_code(500),
614 }
615 };
616
617 // writing the response
618 let (res_data, res_len) = rouille_response.data.into_reader_and_size();
619 let mut response = tiny_http::Response::empty(rouille_response.status_code)
620 .with_data(res_data, res_len);
621
622 let mut upgrade_header = "".into();
623
624 for (key, value) in rouille_response.headers {
625 if key.eq_ignore_ascii_case("Content-Length") {
626 continue;
627 }
628
629 if key.eq_ignore_ascii_case("Upgrade") {
630 upgrade_header = value;
631 continue;
632 }
633
634 if let Ok(header) = tiny_http::Header::from_bytes(key.as_bytes(), value.as_bytes())
635 {
636 response.add_header(header);
637 } else {
638 // TODO: ?
639 }
640 }
641
642 if let Some(ref mut upgrade) = rouille_response.upgrade {
643 let trq = tiny_http_request.lock().unwrap().take().unwrap();
644 let socket = trq.upgrade(&upgrade_header, response);
645 upgrade.build(socket);
646 } else {
647 // We don't really care if we fail to send the response to the client, as there's
648 // nothing we can do anyway.
649 let _ = tiny_http_request
650 .lock()
651 .unwrap()
652 .take()
653 .unwrap()
654 .respond(response);
655 }
656 });
657 }
658}
659
660/// Trait for objects that can take ownership of a raw connection to the client data.
661///
662/// The purpose of this trait is to be used with the `Connection: Upgrade` header, hence its name.
663pub trait Upgrade {
664 /// Initializes the object with the given socket.
665 fn build(&mut self, socket: Box<dyn ReadWrite + Send>);
666}
667
668/// Represents a request that your handler must answer to.
669///
670/// This can be either a real request (received by the HTTP server) or a mock object created with
671/// one of the `fake_*` constructors.
672pub struct Request {
673 method: String,
674 url: String,
675 headers: Vec<(String, String)>,
676 https: bool,
677 data: Arc<Mutex<Option<Box<dyn Read + Send>>>>,
678 remote_addr: Option<SocketAddr>,
679}
680
681impl fmt::Debug for Request {
682 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
683 f.debug_struct("Request")
684 .field("method", &self.method)
685 .field("url", &self.url)
686 .field("headers", &self.headers)
687 .field("https", &self.https)
688 .field("remote_addr", &self.remote_addr)
689 .finish()
690 }
691}
692
693impl Request {
694 /// Builds a fake HTTP request to be used during tests.
695 ///
696 /// The remote address of the client will be `127.0.0.1:12345`. Use `fake_http_from` to
697 /// specify what the client's address should be.
698 pub fn fake_http<U, M>(
699 method: M,
700 url: U,
701 headers: Vec<(String, String)>,
702 data: Vec<u8>,
703 ) -> Request
704 where
705 U: Into<String>,
706 M: Into<String>,
707 {
708 let data = Arc::new(Mutex::new(Some(Box::new(Cursor::new(data)) as Box<_>)));
709 let remote_addr = Some("127.0.0.1:12345".parse().unwrap());
710
711 Request {
712 url: url.into(),
713 method: method.into(),
714 https: false,
715 data,
716 headers,
717 remote_addr,
718 }
719 }
720
721 /// Builds a fake HTTP request to be used during tests.
722 pub fn fake_http_from<U, M>(
723 from: SocketAddr,
724 method: M,
725 url: U,
726 headers: Vec<(String, String)>,
727 data: Vec<u8>,
728 ) -> Request
729 where
730 U: Into<String>,
731 M: Into<String>,
732 {
733 let data = Arc::new(Mutex::new(Some(Box::new(Cursor::new(data)) as Box<_>)));
734
735 Request {
736 url: url.into(),
737 method: method.into(),
738 https: false,
739 data,
740 headers,
741 remote_addr: Some(from),
742 }
743 }
744
745 /// Builds a fake HTTPS request to be used during tests.
746 ///
747 /// The remote address of the client will be `127.0.0.1:12345`. Use `fake_https_from` to
748 /// specify what the client's address should be.
749 pub fn fake_https<U, M>(
750 method: M,
751 url: U,
752 headers: Vec<(String, String)>,
753 data: Vec<u8>,
754 ) -> Request
755 where
756 U: Into<String>,
757 M: Into<String>,
758 {
759 let data = Arc::new(Mutex::new(Some(Box::new(Cursor::new(data)) as Box<_>)));
760 let remote_addr = Some("127.0.0.1:12345".parse().unwrap());
761
762 Request {
763 url: url.into(),
764 method: method.into(),
765 https: true,
766 data,
767 headers,
768 remote_addr,
769 }
770 }
771
772 /// Builds a fake HTTPS request to be used during tests.
773 pub fn fake_https_from<U, M>(
774 from: SocketAddr,
775 method: M,
776 url: U,
777 headers: Vec<(String, String)>,
778 data: Vec<u8>,
779 ) -> Request
780 where
781 U: Into<String>,
782 M: Into<String>,
783 {
784 let data = Arc::new(Mutex::new(Some(Box::new(Cursor::new(data)) as Box<_>)));
785
786 Request {
787 url: url.into(),
788 method: method.into(),
789 https: true,
790 data,
791 headers,
792 remote_addr: Some(from),
793 }
794 }
795
796 /// If the decoded URL of the request starts with `prefix`, builds a new `Request` that is
797 /// the same as the original but without that prefix.
798 ///
799 /// # Example
800 ///
801 /// ```
802 /// # use rouille::Request;
803 /// # use rouille::Response;
804 /// fn handle(request: &Request) -> Response {
805 /// if let Some(request) = request.remove_prefix("/static") {
806 /// return rouille::match_assets(&request, "/static");
807 /// }
808 ///
809 /// // ...
810 /// # panic!()
811 /// }
812 /// ```
813 pub fn remove_prefix(&self, prefix: &str) -> Option<Request> {
814 if !self.url().starts_with(prefix) {
815 return None;
816 }
817
818 // TODO: url-encoded characters in the prefix are not implemented
819 assert!(self.url.starts_with(prefix));
820 Some(Request {
821 method: self.method.clone(),
822 url: self.url[prefix.len()..].to_owned(),
823 headers: self.headers.clone(), // TODO: expensive
824 https: self.https,
825 data: self.data.clone(),
826 remote_addr: self.remote_addr,
827 })
828 }
829
830 /// Returns `true` if the request uses HTTPS, and `false` if it uses HTTP.
831 ///
832 /// # Example
833 ///
834 /// ```
835 /// use rouille::{Request, Response};
836 ///
837 /// fn handle(request: &Request) -> Response {
838 /// if !request.is_secure() {
839 /// return Response::redirect_303(format!("https://example.com"));
840 /// }
841 ///
842 /// // ...
843 /// # panic!()
844 /// }
845 /// ```
846 #[inline]
847 pub fn is_secure(&self) -> bool {
848 self.https
849 }
850
851 /// Returns the method of the request (`GET`, `POST`, etc.).
852 #[inline]
853 pub fn method(&self) -> &str {
854 &self.method
855 }
856
857 /// Returns the raw URL requested by the client. It is not decoded and thus can contain strings
858 /// such as `%20`, and the query parameters such as `?p=hello`.
859 ///
860 /// See also `url()`.
861 ///
862 /// # Example
863 ///
864 /// ```
865 /// use rouille::Request;
866 ///
867 /// let request = Request::fake_http("GET", "/hello%20world?foo=bar", vec![], vec![]);
868 /// assert_eq!(request.raw_url(), "/hello%20world?foo=bar");
869 /// ```
870 #[inline]
871 pub fn raw_url(&self) -> &str {
872 &self.url
873 }
874
875 /// Returns the raw query string requested by the client. In other words, everything after the
876 /// first `?` in the raw url.
877 ///
878 /// Returns the empty string if no query string.
879 #[inline]
880 pub fn raw_query_string(&self) -> &str {
881 if let Some(pos) = self.url.bytes().position(|c| c == b'?') {
882 self.url.split_at(pos + 1).1
883 } else {
884 ""
885 }
886 }
887
888 /// Returns the URL requested by the client.
889 ///
890 /// Contrary to `raw_url`, special characters have been decoded and the query string
891 /// (eg `?p=hello`) has been removed.
892 ///
893 /// If there is any non-unicode character in the URL, it will be replaced with `U+FFFD`.
894 ///
895 /// > **Note**: This function will decode the token `%2F` will be decoded as `/`. However the
896 /// > official specifications say that such a token must not count as a delimiter for URL paths.
897 /// > In other words, `/hello/world` is not the same as `/hello%2Fworld`.
898 ///
899 /// # Example
900 ///
901 /// ```
902 /// use rouille::Request;
903 ///
904 /// let request = Request::fake_http("GET", "/hello%20world?foo=bar", vec![], vec![]);
905 /// assert_eq!(request.url(), "/hello world");
906 /// ```
907 pub fn url(&self) -> String {
908 let url = self.url.as_bytes();
909 let url = if let Some(pos) = url.iter().position(|&c| c == b'?') {
910 &url[..pos]
911 } else {
912 url
913 };
914
915 percent_encoding::percent_decode(url)
916 .decode_utf8_lossy()
917 .into_owned()
918 }
919
920 /// Returns the value of a GET parameter or None if it doesn't exist.
921 pub fn get_param(&self, param_name: &str) -> Option<String> {
922 let name_pattern = &format!("{}=", param_name);
923 let param_pairs = self.raw_query_string().split('&');
924 param_pairs
925 .filter(|pair| pair.starts_with(name_pattern) || pair == ¶m_name)
926 .map(|pair| pair.split('=').nth(1).unwrap_or(""))
927 .next()
928 .map(|value| {
929 percent_encoding::percent_decode(value.replace('+', " ").as_bytes())
930 .decode_utf8_lossy()
931 .into_owned()
932 })
933 }
934
935 /// Returns the value of a header of the request.
936 ///
937 /// Returns `None` if no such header could be found.
938 #[inline]
939 pub fn header(&self, key: &str) -> Option<&str> {
940 self.headers
941 .iter()
942 .find(|&&(ref k, _)| k.eq_ignore_ascii_case(key))
943 .map(|&(_, ref v)| &v[..])
944 }
945
946 /// Returns a list of all the headers of the request.
947 #[inline]
948 pub fn headers(&self) -> HeadersIter {
949 HeadersIter {
950 iter: self.headers.iter(),
951 }
952 }
953
954 /// Returns the state of the `DNT` (Do Not Track) header.
955 ///
956 /// If the header is missing or is malformed, `None` is returned. If the header exists,
957 /// `Some(true)` is returned if `DNT` is `1` and `Some(false)` is returned if `DNT` is `0`.
958 ///
959 /// # Example
960 ///
961 /// ```
962 /// use rouille::{Request, Response};
963 ///
964 /// # fn track_user(request: &Request) {}
965 /// fn handle(request: &Request) -> Response {
966 /// if !request.do_not_track().unwrap_or(false) {
967 /// track_user(&request);
968 /// }
969 ///
970 /// // ...
971 /// # panic!()
972 /// }
973 /// ```
974 pub fn do_not_track(&self) -> Option<bool> {
975 match self.header("DNT") {
976 Some(h) if h == "1" => Some(true),
977 Some(h) if h == "0" => Some(false),
978 _ => None,
979 }
980 }
981
982 /// Returns the body of the request.
983 ///
984 /// The body can only be retrieved once. Returns `None` is the body has already been retrieved
985 /// before.
986 ///
987 /// # Example
988 ///
989 /// ```
990 /// use std::io::Read;
991 /// use rouille::{Request, Response, ResponseBody};
992 ///
993 /// fn echo(request: &Request) -> Response {
994 /// let mut data = request.data().expect("Oops, body already retrieved, problem \
995 /// in the server");
996 ///
997 /// let mut buf = Vec::new();
998 /// match data.read_to_end(&mut buf) {
999 /// Ok(_) => (),
1000 /// Err(_) => return Response::text("Failed to read body")
1001 /// };
1002 ///
1003 /// Response {
1004 /// data: ResponseBody::from_data(buf),
1005 /// .. Response::text("")
1006 /// }
1007 /// }
1008 /// ```
1009 pub fn data(&self) -> Option<RequestBody> {
1010 let reader = self.data.lock().unwrap().take();
1011 reader.map(|r| RequestBody {
1012 body: r,
1013 marker: PhantomData,
1014 })
1015 }
1016
1017 /// Returns the address of the client that made this request.
1018 ///
1019 /// # Example
1020 ///
1021 /// ```
1022 /// use rouille::{Request, Response};
1023 ///
1024 /// fn handle(request: &Request) -> Response {
1025 /// Response::text(format!("Your IP is: {:?}", request.remote_addr()))
1026 /// }
1027 /// ```
1028 #[inline]
1029 pub fn remote_addr(&self) -> &SocketAddr {
1030 self.remote_addr
1031 .as_ref()
1032 .expect("Unexpected Unix socket for request")
1033 }
1034}
1035
1036/// Iterator to the list of headers in a request.
1037#[derive(Debug, Clone)]
1038pub struct HeadersIter<'a> {
1039 iter: SliceIter<'a, (String, String)>,
1040}
1041
1042impl<'a> Iterator for HeadersIter<'a> {
1043 type Item = (&'a str, &'a str);
1044
1045 #[inline]
1046 fn next(&mut self) -> Option<Self::Item> {
1047 self.iter.next().map(|&(ref k, ref v)| (&k[..], &v[..]))
1048 }
1049
1050 #[inline]
1051 fn size_hint(&self) -> (usize, Option<usize>) {
1052 self.iter.size_hint()
1053 }
1054}
1055
1056impl<'a> ExactSizeIterator for HeadersIter<'a> {}
1057
1058/// Gives access to the body of a request.
1059///
1060/// In order to obtain this object, call `request.data()`.
1061pub struct RequestBody<'a> {
1062 body: Box<dyn Read + Send>,
1063 marker: PhantomData<&'a ()>,
1064}
1065
1066impl<'a> Read for RequestBody<'a> {
1067 #[inline]
1068 fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
1069 self.body.read(buf)
1070 }
1071}
1072
1073#[cfg(test)]
1074mod tests {
1075 use Request;
1076
1077 #[test]
1078 fn header() {
1079 let request = Request::fake_http(
1080 "GET",
1081 "/",
1082 vec![("Host".to_owned(), "localhost".to_owned())],
1083 vec![],
1084 );
1085 assert_eq!(request.header("Host"), Some("localhost"));
1086 assert_eq!(request.header("host"), Some("localhost"));
1087 }
1088
1089 #[test]
1090 fn get_param() {
1091 let request = Request::fake_http("GET", "/?p=hello", vec![], vec![]);
1092 assert_eq!(request.get_param("p"), Some("hello".to_owned()));
1093 }
1094
1095 #[test]
1096 fn get_param_multiple_param() {
1097 let request = Request::fake_http("GET", "/?foo=bar&message=hello", vec![], vec![]);
1098 assert_eq!(request.get_param("message"), Some("hello".to_owned()));
1099 }
1100 #[test]
1101 fn get_param_no_match() {
1102 let request = Request::fake_http("GET", "/?hello=world", vec![], vec![]);
1103 assert_eq!(request.get_param("foo"), None);
1104 }
1105
1106 #[test]
1107 fn get_param_partial_suffix_match() {
1108 let request = Request::fake_http("GET", "/?hello=world", vec![], vec![]);
1109 assert_eq!(request.get_param("lo"), None);
1110 }
1111
1112 #[test]
1113 fn get_param_partial_prefix_match() {
1114 let request = Request::fake_http("GET", "/?hello=world", vec![], vec![]);
1115 assert_eq!(request.get_param("he"), None);
1116 }
1117
1118 #[test]
1119 fn get_param_superstring_match() {
1120 let request = Request::fake_http("GET", "/?jan=01", vec![], vec![]);
1121 assert_eq!(request.get_param("january"), None);
1122 }
1123
1124 #[test]
1125 fn get_param_flag_with_equals() {
1126 let request = Request::fake_http("GET", "/?flag=", vec![], vec![]);
1127 assert_eq!(request.get_param("flag"), Some("".to_owned()));
1128 }
1129
1130 #[test]
1131 fn get_param_flag_without_equals() {
1132 let request = Request::fake_http("GET", "/?flag", vec![], vec![]);
1133 assert_eq!(request.get_param("flag"), Some("".to_owned()));
1134 }
1135
1136 #[test]
1137 fn get_param_flag_with_multiple_params() {
1138 let request = Request::fake_http("GET", "/?flag&foo=bar", vec![], vec![]);
1139 assert_eq!(request.get_param("flag"), Some("".to_owned()));
1140 }
1141
1142 #[test]
1143 fn body_twice() {
1144 let request = Request::fake_http("GET", "/", vec![], vec![62, 62, 62]);
1145 assert!(request.data().is_some());
1146 assert!(request.data().is_none());
1147 }
1148
1149 #[test]
1150 fn url_strips_get_query() {
1151 let request = Request::fake_http("GET", "/?p=hello", vec![], vec![]);
1152 assert_eq!(request.url(), "/");
1153 }
1154
1155 #[test]
1156 fn urlencode_query_string() {
1157 let request = Request::fake_http("GET", "/?p=hello%20world", vec![], vec![]);
1158 assert_eq!(request.get_param("p"), Some("hello world".to_owned()));
1159 }
1160
1161 #[test]
1162 fn plus_in_query_string() {
1163 let request = Request::fake_http("GET", "/?p=hello+world", vec![], vec![]);
1164 assert_eq!(request.get_param("p"), Some("hello world".to_owned()));
1165 }
1166
1167 #[test]
1168 fn encoded_plus_in_query_string() {
1169 let request = Request::fake_http("GET", "/?p=hello%2Bworld", vec![], vec![]);
1170 assert_eq!(request.get_param("p"), Some("hello+world".to_owned()));
1171 }
1172
1173 #[test]
1174 fn url_encode() {
1175 let request = Request::fake_http("GET", "/hello%20world", vec![], vec![]);
1176 assert_eq!(request.url(), "/hello world");
1177 }
1178
1179 #[test]
1180 fn plus_in_url() {
1181 let request = Request::fake_http("GET", "/hello+world", vec![], vec![]);
1182 assert_eq!(request.url(), "/hello+world");
1183 }
1184
1185 #[test]
1186 fn dnt() {
1187 let request =
1188 Request::fake_http("GET", "/", vec![("DNT".to_owned(), "1".to_owned())], vec![]);
1189 assert_eq!(request.do_not_track(), Some(true));
1190
1191 let request =
1192 Request::fake_http("GET", "/", vec![("DNT".to_owned(), "0".to_owned())], vec![]);
1193 assert_eq!(request.do_not_track(), Some(false));
1194
1195 let request = Request::fake_http("GET", "/", vec![], vec![]);
1196 assert_eq!(request.do_not_track(), None);
1197
1198 let request = Request::fake_http(
1199 "GET",
1200 "/",
1201 vec![("DNT".to_owned(), "malformed".to_owned())],
1202 vec![],
1203 );
1204 assert_eq!(request.do_not_track(), None);
1205 }
1206}