rouille/websocket/mod.rs
1// Copyright (c) 2016 The Rouille developers
2// Licensed under the Apache License, Version 2.0
3// <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>,
6// at your option. All files in the project carrying such
7// notice may not be copied, modified, or distributed except
8// according to those terms.
9
10//! Support for websockets.
11//!
12//! Using websockets is done with the following steps:
13//!
14//! - The websocket client (usually the browser through some Javascript) must send a request to the
15//! server to initiate the process. Examples for how to do this in Javascript are out of scope
16//! of this documentation but should be easy to find on the web.
17//! - The server written with rouille must answer that request with the `start()` function defined
18//! in this module. This function returns an error if the request is not a websocket
19//! initialization request.
20//! - The `start()` function also returns a `Receiver<Websocket>` object. Once that `Receiver`
21//! contains a value, the connection has been initiated.
22//! - You can then use the `Websocket` object to communicate with the client through the `Read`
23//! and `Write` traits.
24//!
25//! # Subprotocols
26//!
27//! The websocket connection will produce either text or binary messages. But these messages do not
28//! have a meaning per se, and must also be interpreted in some way. The way messages are
29//! interpreted during a websocket connection is called a *subprotocol*.
30//!
31//! When you call `start()` you have to indicate which subprotocol the connection is going to use.
32//! This subprotocol must match one of the subprotocols that were passed by the client during its
33//! request, otherwise `start()` will return an error. It is also possible to pass `None`, in which
34//! case the subprotocol is unknown to both the client and the server.
35//!
36//! There are usually three ways to handle subprotocols on the server-side:
37//!
38//! - You don't really care about subprotocols because you use websockets for your own needs. You
39//! can just pass `None` to `start()`. The connection will thus never fail unless the client
40//! decides to.
41//! - Your route only handles one subprotocol. Just pass this subprotocol to `start()` and you will
42//! get an error (which you can handle for example with `try_or_400!`) if it's not supported by
43//! the client.
44//! - Your route supports multiple subprotocols. This is the most complex situation as you will
45//! have to enumerate the protocols with `requested_protocols()` and choose one.
46//!
47//! # Example
48//!
49//! ```
50//! # #[macro_use] extern crate rouille;
51//! use std::sync::Mutex;
52//! use std::sync::mpsc::Receiver;
53//!
54//! use rouille::Request;
55//! use rouille::Response;
56//! use rouille::websocket;
57//! # fn main() {}
58//!
59//! fn handle_request(request: &Request, websockets: &Mutex<Vec<Receiver<websocket::Websocket>>>)
60//! -> Response
61//! {
62//! let (response, websocket) = try_or_400!(websocket::start(request, Some("my-subprotocol")));
63//! websockets.lock().unwrap().push(websocket);
64//! response
65//! }
66//! ```
67
68pub use self::websocket::Message;
69pub use self::websocket::SendError;
70pub use self::websocket::Websocket;
71
72use base64;
73use sha1_smol::Sha1;
74use std::borrow::Cow;
75use std::error;
76use std::fmt;
77use std::sync::mpsc;
78use std::vec::IntoIter as VecIntoIter;
79
80use Request;
81use Response;
82
83mod low_level;
84#[allow(clippy::module_inception)]
85mod websocket;
86
/// Error that can happen when attempting to start a websocket connection with `start()`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WebsocketError {
    /// The request does not match a websocket request.
    ///
    /// The conditions checked are:
    /// - The method must be `GET`.
    /// - The `Connection` header must include `upgrade` (ASCII case-insensitive).
    /// - The `Upgrade` header must include `websocket` (ASCII case-insensitive).
    /// - The `Sec-WebSocket-Version` header must be `13`.
    /// - Must have a `Sec-WebSocket-Key` header.
    ///
    /// The HTTP version and the presence of a `Host` header are not currently verified.
    InvalidWebsocketRequest,

    /// The subprotocol passed to the function was not requested by the client.
    WrongSubprotocol,
}

impl error::Error for WebsocketError {}

impl fmt::Display for WebsocketError {
    #[inline]
    fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        let description = match *self {
            WebsocketError::InvalidWebsocketRequest => {
                "the request does not match a websocket request"
            }
            WebsocketError::WrongSubprotocol => {
                "the subprotocol passed to the function was not requested by the client"
            }
        };

        // The message is a plain literal; no formatting machinery is needed.
        fmt.write_str(description)
    }
}
121}
122
123/// Builds a `Response` that initiates the websocket protocol.
124pub fn start<S>(
125 request: &Request,
126 subprotocol: Option<S>,
127) -> Result<(Response, mpsc::Receiver<Websocket>), WebsocketError>
128where
129 S: Into<Cow<'static, str>>,
130{
131 let subprotocol = subprotocol.map(|s| s.into());
132
133 if request.method() != "GET" {
134 return Err(WebsocketError::InvalidWebsocketRequest);
135 }
136
137 // TODO:
138 /*if request.http_version() < &HTTPVersion(1, 1) {
139 return Err(WebsocketError::InvalidWebsocketRequest);
140 }*/
141
142 match request.header("Connection") {
143 Some(h) if h.to_ascii_lowercase().contains("upgrade") => (),
144 _ => return Err(WebsocketError::InvalidWebsocketRequest),
145 }
146
147 match request.header("Upgrade") {
148 Some(h) if h.to_ascii_lowercase().contains("websocket") => (),
149 _ => return Err(WebsocketError::InvalidWebsocketRequest),
150 }
151
152 // TODO: there are some version shenanigans to handle
153 // see https://tools.ietf.org/html/rfc6455#section-4.4
154 match request.header("Sec-WebSocket-Version") {
155 Some(h) if h == "13" => (),
156 _ => return Err(WebsocketError::InvalidWebsocketRequest),
157 }
158
159 if let Some(ref sp) = subprotocol {
160 if !requested_protocols(request).any(|p| &p == sp) {
161 return Err(WebsocketError::WrongSubprotocol);
162 }
163 }
164
165 let key = {
166 let in_key = match request.header("Sec-WebSocket-Key") {
167 Some(h) => h,
168 None => return Err(WebsocketError::InvalidWebsocketRequest),
169 };
170
171 convert_key(in_key)
172 };
173
174 let (tx, rx) = mpsc::channel();
175
176 let mut response = Response::text("");
177 response.status_code = 101;
178 response
179 .headers
180 .push(("Upgrade".into(), "websocket".into()));
181 if let Some(sp) = subprotocol {
182 response.headers.push(("Sec-Websocket-Protocol".into(), sp));
183 }
184 response
185 .headers
186 .push(("Sec-Websocket-Accept".into(), key.into()));
187 response.upgrade = Some(Box::new(tx) as Box<_>);
188 Ok((response, rx))
189}
190
191/// Returns a list of the websocket protocols requested by the client.
192///
193/// # Example
194///
195/// ```
196/// use rouille::websocket;
197///
198/// # let request: rouille::Request = return;
199/// for protocol in websocket::requested_protocols(&request) {
200/// // ...
201/// }
202/// ```
203// TODO: return references to the request
204pub fn requested_protocols(request: &Request) -> RequestedProtocolsIter {
205 match request.header("Sec-WebSocket-Protocol") {
206 None => RequestedProtocolsIter {
207 iter: Vec::new().into_iter(),
208 },
209 Some(h) => {
210 let iter = h
211 .split(',')
212 .map(|s| s.trim())
213 .filter(|s| !s.is_empty())
214 .map(|s| s.to_owned())
215 .collect::<Vec<_>>()
216 .into_iter();
217 RequestedProtocolsIter { iter }
218 }
219 }
220}
221
/// Iterator over the subprotocols requested by the client.
///
/// Produced by `requested_protocols()`. Each item is an owned protocol name.
#[derive(Debug, Clone)]
pub struct RequestedProtocolsIter {
    iter: VecIntoIter<String>,
}

impl Iterator for RequestedProtocolsIter {
    type Item = String;

    #[inline]
    fn next(&mut self) -> Option<String> {
        self.iter.next()
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}

// The wrapped `vec::IntoIter` reports an exact length.
impl ExactSizeIterator for RequestedProtocolsIter {}

impl DoubleEndedIterator for RequestedProtocolsIter {
    #[inline]
    fn next_back(&mut self) -> Option<String> {
        self.iter.next_back()
    }
}

// `vec::IntoIter` keeps returning `None` once exhausted, so this wrapper is fused as well.
impl std::iter::FusedIterator for RequestedProtocolsIter {}
242
243/// Turns a `Sec-WebSocket-Key` into a `Sec-WebSocket-Accept`.
244fn convert_key(input: &str) -> String {
245 let mut sha1 = Sha1::new();
246 sha1.update(input.as_bytes());
247 sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
248
249 base64::encode_config(&sha1.digest().bytes(), base64::STANDARD)
250}