rouille/websocket/low_level.rs
1// Copyright (c) 2016 The Rouille developers
2// Licensed under the Apache License, Version 2.0
3// <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>,
6// at your option. All files in the project carrying such
7// notice may not be copied, modified, or distributed except
8// according to those terms.
9
10//! Low-level parsing of websocket frames.
11//!
12//! Usage:
13//!
14//! - Create a `StateMachine` with `StateMachine::new`.
15//! - Whenever data is received on the socket, call `StateMachine::feed`.
16//! - The returned iterator produces zero, one or multiple `Element` objects containing what was
17//! received.
18//! - For `Element::Data`, the `Data` object is an iterator over the decoded bytes.
19//! - If `Element::Error` is produced, immediately end the connection.
20//!
21//! Glossary:
22//!
//! - A websocket stream is made of multiple *messages*.
//! - Each message is made of one or more *frames*. See
//!   <https://tools.ietf.org/html/rfc6455#section-5.4>.
//! - Each frame can be received progressively, over several packets; every piece that arrives
//!   is reported as its own `Element` object (below).
26
27/// A websocket element decoded from the data given to `StateMachine::feed`.
28#[derive(Debug, PartialEq, Eq)]
29pub enum Element<'a> {
30 /// A new frame has started.
31 FrameStart {
32 /// If true, this is the last frame of the message.
33 fin: bool,
34 /// Length of the frame in bytes.
35 length: u64,
36 /// Opcode. See https://tools.ietf.org/html/rfc6455#section-5.2.
37 opcode: u8,
38 },
39
40 /// Data was received as part of the current frame.
41 Data {
42 /// The decoded data. An iterator that produces `u8`s.
43 data: Data<'a>,
44 /// If true, this is the last packet in the frame.
45 last_in_frame: bool,
46 },
47
48 /// An error in the stream. The connection must be dropped ASAP.
49 Error {
50 /// A description of the error. Can or cannot be be returned to the client.
51 desc: &'static str,
52 },
53}
54
/// Payload bytes of a frame, unmasked on the fly. Implements `Iterator<Item = u8>`.
#[derive(Debug, PartialEq, Eq)]
pub struct Data<'a> {
    // Bytes exactly as they were received, i.e. still masked.
    data: &'a [u8],
    // Mask of the frame these bytes belong to (a copy of the value held by the state machine).
    mask: u32,
    // Same meaning as `StateMachineInner::InData::offset`. Advanced on every iteration.
    offset: u8,
}
65
66/// A websocket state machine. Contains partial data.
67pub struct StateMachine {
68 // Actual state.
69 inner: StateMachineInner,
70 // Contains the start of the header. Must be empty if `inner` is equal to `InData`.
71 buffer: Vec<u8>, // TODO: use SmallVec?
72}
73
// Position of the decoder within the stream.
enum StateMachineInner {
    // Waiting for (the rest of) a frame header. `StateMachine::buffer` holds the header bytes
    // received so far.
    InHeader,
    // Inside the payload of a frame. `StateMachine::buffer` must be empty in this state.
    InData {
        // Mask used to decode the payload.
        mask: u32,
        // Position of the next expected byte within the payload, modulo 4 (a value from 0 to 3).
        // Selects which byte of `mask` applies to it.
        offset: u8,
        // How many payload bytes of the frame have not been received yet.
        remaining_len: u64,
    },
}
88
89impl StateMachine {
90 /// Initializes a new state machine for a new stream. Expects to see a new frame as the first
91 /// packet.
92 pub fn new() -> StateMachine {
93 StateMachine {
94 inner: StateMachineInner::InHeader,
95 buffer: Vec::with_capacity(14),
96 }
97 }
98
99 /// Feeds data to the state machine. Returns an iterator to the list of elements that were
100 /// received.
101 #[inline]
102 pub fn feed<'a>(&'a mut self, data: &'a [u8]) -> ElementsIter<'a> {
103 ElementsIter { state: self, data }
104 }
105}
106
// Helpers for reading big-endian integers out of a byte iterator.
// They could probably be swapped for something more robust such as `nom`, if we are willing to
// take the hit of adding another dependency.

// Consumes two bytes from `input` and assembles them into a big-endian `u16`.
// Panics if `input` runs out of bytes first.
fn read_u16_be<'a, T: Iterator<Item = &'a u8>>(input: &mut T) -> u16 {
    let mut bytes = [0u8; 2];
    for slot in bytes.iter_mut() {
        *slot = *input.next().unwrap();
    }
    u16::from_be_bytes(bytes)
}
114
// Consumes four bytes from `input` and assembles them into a big-endian `u32`.
// Panics if `input` runs out of bytes first.
fn read_u32_be<'a, T: Iterator<Item = &'a u8>>(input: &mut T) -> u32 {
    let mut bytes = [0u8; 4];
    for slot in bytes.iter_mut() {
        *slot = *input.next().unwrap();
    }
    u32::from_be_bytes(bytes)
}
124
// Consumes eight bytes from `input` and assembles them into a big-endian `u64`.
// Panics if `input` runs out of bytes first.
fn read_u64_be<'a, T: Iterator<Item = &'a u8>>(input: &mut T) -> u64 {
    let mut bytes = [0u8; 8];
    for slot in bytes.iter_mut() {
        *slot = *input.next().unwrap();
    }
    u64::from_be_bytes(bytes)
}
138
139/// Iterator to the list of elements that were received.
140pub struct ElementsIter<'a> {
141 state: &'a mut StateMachine,
142 data: &'a [u8],
143}
144
145impl<'a> Iterator for ElementsIter<'a> {
146 type Item = Element<'a>;
147
148 fn next(&mut self) -> Option<Element<'a>> {
149 if self.data.is_empty() {
150 return None;
151 }
152
153 match self.state.inner {
154 // First situation, we are in the header.
155 StateMachineInner::InHeader => {
156 // We need at least 6 bytes for a successful header. Otherwise we just return.
157 let total_buffered = self.state.buffer.len() + self.data.len();
158 if total_buffered < 6 {
159 self.state.buffer.extend_from_slice(self.data);
160 self.data = &[];
161 return None;
162 }
163
164 // Retrieve the first two bytes of the header.
165 let (first_byte, second_byte) = {
166 let mut mask_iter = self.state.buffer.iter().chain(self.data.iter());
167 let first_byte = *mask_iter.next().unwrap();
168 let second_byte = *mask_iter.next().unwrap();
169 (first_byte, second_byte)
170 };
171
172 // Reserved bits must be zero, otherwise error.
173 if (first_byte & 0x70) != 0 {
174 return Some(Element::Error {
175 desc: "Reserved bits must be zero",
176 });
177 }
178
179 // Client-to-server messages **must** be encoded.
180 if (second_byte & 0x80) == 0 {
181 return Some(Element::Error {
182 desc: "Client-to-server messages must be masked",
183 });
184 }
185
186 // Find the length of the frame and the mask.
187 let (length, mask) = match second_byte & 0x7f {
188 126 => {
189 if total_buffered < 8 {
190 self.state.buffer.extend_from_slice(self.data);
191 self.data = &[];
192 return None;
193 }
194
195 let mut mask_iter =
196 self.state.buffer.iter().chain(self.data.iter()).skip(2);
197
198 let length = read_u16_be(&mut mask_iter) as u64;
199 let mask = read_u32_be(&mut mask_iter);
200 (length, mask)
201 }
202 127 => {
203 if total_buffered < 14 {
204 self.state.buffer.extend_from_slice(self.data);
205 self.data = &[];
206 return None;
207 }
208
209 let mut mask_iter =
210 self.state.buffer.iter().chain(self.data.iter()).skip(2);
211
212 let length = {
213 let length = read_u64_be(&mut mask_iter);
214 // The most significant bit must be zero according to the specs.
215 if (length & 0x8000000000000000) != 0 {
216 return Some(Element::Error {
217 desc: "Most-significant bit of the length must be zero",
218 });
219 }
220 length
221 };
222
223 let mask = read_u32_be(&mut mask_iter);
224
225 (length, mask)
226 }
227 n => {
228 let mut mask_iter =
229 self.state.buffer.iter().chain(self.data.iter()).skip(2);
230
231 let mask = read_u32_be(&mut mask_iter);
232 (u64::from(n), mask)
233 }
234 };
235
236 // Builds a slice containing the start of the data.
237 let data_start = {
238 let data_start_off = match second_byte & 0x7f {
239 126 => 8,
240 127 => 14,
241 _ => 6,
242 };
243
244 assert!(self.state.buffer.len() < data_start_off);
245 &self.data[(data_start_off - self.state.buffer.len())..]
246 };
247
248 // Update ourselves for the next loop and return a FrameStart message.
249 self.data = data_start;
250 self.state.buffer.clear();
251 self.state.inner = StateMachineInner::InData {
252 mask,
253 remaining_len: length,
254 offset: 0,
255 };
256 Some(Element::FrameStart {
257 fin: (first_byte & 0x80) != 0,
258 length,
259 opcode: first_byte & 0xf,
260 })
261 }
262
263 // Second situation, we are in the message and we don't have enough data to finish the
264 // current frame.
265 StateMachineInner::InData {
266 mask,
267 ref mut remaining_len,
268 ref mut offset,
269 } if *remaining_len > self.data.len() as u64 => {
270 let data = Data {
271 data: self.data,
272 mask,
273 offset: *offset,
274 };
275
276 *offset += (self.data.len() % 4) as u8;
277 *offset %= 4;
278 *remaining_len -= self.data.len() as u64;
279
280 self.data = &[];
281
282 Some(Element::Data {
283 data,
284 last_in_frame: false,
285 })
286 }
287
288 // Third situation, we have enough data to finish the frame.
289 StateMachineInner::InData {
290 mask,
291 remaining_len,
292 offset,
293 } => {
294 debug_assert!(self.data.len() as u64 >= remaining_len);
295
296 let data = Data {
297 data: &self.data[0..remaining_len as usize],
298 mask,
299 offset,
300 };
301
302 self.data = &self.data[remaining_len as usize..];
303 self.state.inner = StateMachineInner::InHeader;
304 debug_assert!(self.state.buffer.is_empty());
305
306 Some(Element::Data {
307 data,
308 last_in_frame: true,
309 })
310 }
311 }
312 }
313}
314
315impl<'a> Iterator for Data<'a> {
316 type Item = u8;
317
318 #[inline]
319 fn next(&mut self) -> Option<u8> {
320 if self.data.is_empty() {
321 return None;
322 }
323
324 let byte = self.data[0];
325 let mask = ((self.mask >> ((3 - self.offset) * 8)) & 0xff) as u8;
326 let decoded = byte ^ mask;
327
328 self.data = &self.data[1..];
329 self.offset = (self.offset + 1) % 4;
330
331 Some(decoded)
332 }
333
334 #[inline]
335 fn size_hint(&self) -> (usize, Option<usize>) {
336 let l = self.data.len();
337 (l, Some(l))
338 }
339}
340
341impl<'a> ExactSizeIterator for Data<'a> {}
342
#[cfg(test)]
mod tests {
    use super::{Element, StateMachine};

    /// Decodes a single, complete, masked text frame whose payload is "Hello".
    #[test]
    fn basic() {
        // 0x81: FIN + text opcode. 0x85: masked, 5-byte payload. Then the 4-byte mask and the
        // masked payload.
        let frame: [u8; 11] = [
            0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58,
        ];

        let mut machine = StateMachine::new();
        let mut elements = machine.feed(&frame);

        let start = elements.next().unwrap();
        assert_eq!(
            start,
            Element::FrameStart {
                fin: true,
                length: 5,
                opcode: 1
            }
        );

        if let Element::Data {
            data,
            last_in_frame,
        } = elements.next().unwrap()
        {
            assert!(last_in_frame);
            assert_eq!(data.collect::<Vec<_>>(), b"Hello");
        } else {
            panic!();
        }
    }
}
377}