rouille/websocket/
mod.rs

1// Copyright (c) 2016 The Rouille developers
2// Licensed under the Apache License, Version 2.0
3// <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>,
6// at your option. All files in the project carrying such
7// notice may not be copied, modified, or distributed except
8// according to those terms.
9
10//! Support for websockets.
11//!
12//! Using websockets is done with the following steps:
13//!
14//! - The websocket client (usually the browser through some Javascript) must send a request to the
15//!   server to initiate the process. Examples for how to do this in Javascript are out of scope
16//!   of this documentation but should be easy to find on the web.
17//! - The server written with rouille must answer that request with the `start()` function defined
18//!   in this module. This function returns an error if the request is not a websocket
19//!   initialization request.
20//! - The `start()` function also returns a `Receiver<Websocket>` object. Once that `Receiver`
21//!   contains a value, the connection has been initiated.
22//! - You can then use the `Websocket` object to communicate with the client through the `Read`
23//!   and `Write` traits.
24//!
25//! # Subprotocols
26//!
27//! The websocket connection will produce either text or binary messages. But these messages do not
28//! have a meaning per se, and must also be interpreted in some way. The way messages are
29//! interpreted during a websocket connection is called a *subprotocol*.
30//!
31//! When you call `start()` you have to indicate which subprotocol the connection is going to use.
32//! This subprotocol must match one of the subprotocols that were passed by the client during its
33//! request, otherwise `start()` will return an error. It is also possible to pass `None`, in which
34//! case the subprotocol is unknown to both the client and the server.
35//!
36//! There are usually three ways to handle subprotocols on the server-side:
37//!
38//! - You don't really care about subprotocols because you use websockets for your own needs. You
39//!   can just pass `None` to `start()`. The connection will thus never fail unless the client
40//!   decides to.
41//! - Your route only handles one subprotocol. Just pass this subprotocol to `start()` and you will
42//!   get an error (which you can handle for example with `try_or_400!`) if it's not supported by
43//!   the client.
44//! - Your route supports multiple subprotocols. This is the most complex situation as you will
45//!   have to enumerate the protocols with `requested_protocols()` and choose one.
46//!
47//! # Example
48//!
49//! ```
50//! # #[macro_use] extern crate rouille;
51//! use std::sync::Mutex;
52//! use std::sync::mpsc::Receiver;
53//!
54//! use rouille::Request;
55//! use rouille::Response;
56//! use rouille::websocket;
57//! # fn main() {}
58//!
59//! fn handle_request(request: &Request, websockets: &Mutex<Vec<Receiver<websocket::Websocket>>>)
60//!                   -> Response
61//! {
62//!     let (response, websocket) = try_or_400!(websocket::start(request, Some("my-subprotocol")));
63//!     websockets.lock().unwrap().push(websocket);
64//!     response
65//! }
66//! ```
67
68pub use self::websocket::Message;
69pub use self::websocket::SendError;
70pub use self::websocket::Websocket;
71
72use base64;
73use sha1_smol::Sha1;
74use std::borrow::Cow;
75use std::error;
76use std::fmt;
77use std::sync::mpsc;
78use std::vec::IntoIter as VecIntoIter;
79
80use Request;
81use Response;
82
83mod low_level;
84#[allow(clippy::module_inception)]
85mod websocket;
86
/// Error that can happen when attempting to start websocket.
#[derive(Debug)]
pub enum WebsocketError {
    /// The request does not match a websocket request.
    ///
    /// The conditions currently checked by `start()` are:
    /// - The method must be `GET`.
    /// - The `Connection` header must contain `upgrade` (ASCII case-insensitive).
    /// - The `Upgrade` header must contain `websocket` (ASCII case-insensitive).
    /// - The `Sec-WebSocket-Version` header must be `13`.
    /// - Must have a `Sec-WebSocket-Key` header.
    ///
    /// Note that `start()` does not currently verify the HTTP version (the check is still a
    /// TODO there) nor the presence of a `Host` header.
    InvalidWebsocketRequest,

    /// The subprotocol passed to the function was not requested by the client.
    WrongSubprotocol,
}

impl error::Error for WebsocketError {}

impl fmt::Display for WebsocketError {
    /// Writes a short, lowercase, human-readable description of the error.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // The messages are static, so they are written directly instead of going through
        // the formatting machinery.
        f.write_str(match *self {
            WebsocketError::InvalidWebsocketRequest => {
                "the request does not match a websocket request"
            }
            WebsocketError::WrongSubprotocol => {
                "the subprotocol passed to the function was not requested by the client"
            }
        })
    }
}
122
123/// Builds a `Response` that initiates the websocket protocol.
124pub fn start<S>(
125    request: &Request,
126    subprotocol: Option<S>,
127) -> Result<(Response, mpsc::Receiver<Websocket>), WebsocketError>
128where
129    S: Into<Cow<'static, str>>,
130{
131    let subprotocol = subprotocol.map(|s| s.into());
132
133    if request.method() != "GET" {
134        return Err(WebsocketError::InvalidWebsocketRequest);
135    }
136
137    // TODO:
138    /*if request.http_version() < &HTTPVersion(1, 1) {
139        return Err(WebsocketError::InvalidWebsocketRequest);
140    }*/
141
142    match request.header("Connection") {
143        Some(h) if h.to_ascii_lowercase().contains("upgrade") => (),
144        _ => return Err(WebsocketError::InvalidWebsocketRequest),
145    }
146
147    match request.header("Upgrade") {
148        Some(h) if h.to_ascii_lowercase().contains("websocket") => (),
149        _ => return Err(WebsocketError::InvalidWebsocketRequest),
150    }
151
152    // TODO: there are some version shenanigans to handle
153    // see https://tools.ietf.org/html/rfc6455#section-4.4
154    match request.header("Sec-WebSocket-Version") {
155        Some(h) if h == "13" => (),
156        _ => return Err(WebsocketError::InvalidWebsocketRequest),
157    }
158
159    if let Some(ref sp) = subprotocol {
160        if !requested_protocols(request).any(|p| &p == sp) {
161            return Err(WebsocketError::WrongSubprotocol);
162        }
163    }
164
165    let key = {
166        let in_key = match request.header("Sec-WebSocket-Key") {
167            Some(h) => h,
168            None => return Err(WebsocketError::InvalidWebsocketRequest),
169        };
170
171        convert_key(in_key)
172    };
173
174    let (tx, rx) = mpsc::channel();
175
176    let mut response = Response::text("");
177    response.status_code = 101;
178    response
179        .headers
180        .push(("Upgrade".into(), "websocket".into()));
181    if let Some(sp) = subprotocol {
182        response.headers.push(("Sec-Websocket-Protocol".into(), sp));
183    }
184    response
185        .headers
186        .push(("Sec-Websocket-Accept".into(), key.into()));
187    response.upgrade = Some(Box::new(tx) as Box<_>);
188    Ok((response, rx))
189}
190
191/// Returns a list of the websocket protocols requested by the client.
192///
193/// # Example
194///
195/// ```
196/// use rouille::websocket;
197///
198/// # let request: rouille::Request = return;
199/// for protocol in websocket::requested_protocols(&request) {
200///     // ...
201/// }
202/// ```
203// TODO: return references to the request
204pub fn requested_protocols(request: &Request) -> RequestedProtocolsIter {
205    match request.header("Sec-WebSocket-Protocol") {
206        None => RequestedProtocolsIter {
207            iter: Vec::new().into_iter(),
208        },
209        Some(h) => {
210            let iter = h
211                .split(',')
212                .map(|s| s.trim())
213                .filter(|s| !s.is_empty())
214                .map(|s| s.to_owned())
215                .collect::<Vec<_>>()
216                .into_iter();
217            RequestedProtocolsIter { iter }
218        }
219    }
220}
221
/// Iterator over the list of protocols requested by the client.
///
/// Yields each protocol as an owned `String`, in the order they were collected by
/// `requested_protocols()`.
#[derive(Debug)]
pub struct RequestedProtocolsIter {
    // Owned protocol names that have not been yielded yet.
    iter: VecIntoIter<String>,
}

impl Iterator for RequestedProtocolsIter {
    type Item = String;

    /// Returns the next requested protocol, or `None` once all have been yielded.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next()
    }

    /// Forwards the exact bounds of the underlying `Vec` iterator, which is what makes the
    /// `ExactSizeIterator` implementation below valid.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}

impl ExactSizeIterator for RequestedProtocolsIter {}
242
243/// Turns a `Sec-WebSocket-Key` into a `Sec-WebSocket-Accept`.
244fn convert_key(input: &str) -> String {
245    let mut sha1 = Sha1::new();
246    sha1.update(input.as_bytes());
247    sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
248
249    base64::encode_config(&sha1.digest().bytes(), base64::STANDARD)
250}