// rouille/websocket/low_level.rs

// Copyright (c) 2016 The Rouille developers
// Licensed under the Apache License, Version 2.0
// <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT
// license <LICENSE-MIT or http://opensource.org/licenses/MIT>,
// at your option. All files in the project carrying such
// notice may not be copied, modified, or distributed except
// according to those terms.

//! Low-level parsing of websocket frames.
//!
//! Usage:
//!
//! - Create a `StateMachine` with `StateMachine::new`.
//! - Whenever data is received on the socket, call `StateMachine::feed`.
//! - The returned iterator produces zero, one or multiple `Element` objects containing what was
//!   received.
//! - For `Element::Data`, the `Data` object is an iterator over the decoded bytes.
//! - If `Element::Error` is produced, immediately end the connection.
//!
//! Glossary:
//!
//! - A websocket stream is made of multiple *messages*.
//! - Each message is made of one or more *frames*.
//!   See <https://tools.ietf.org/html/rfc6455#section-5.4>.
//! - Each frame can be received progressively, where each packet is an `Element` object (below).

27/// A websocket element decoded from the data given to `StateMachine::feed`.
28#[derive(Debug, PartialEq, Eq)]
29pub enum Element<'a> {
30    /// A new frame has started.
31    FrameStart {
32        /// If true, this is the last frame of the message.
33        fin: bool,
34        /// Length of the frame in bytes.
35        length: u64,
36        /// Opcode. See https://tools.ietf.org/html/rfc6455#section-5.2.
37        opcode: u8,
38    },
39
40    /// Data was received as part of the current frame.
41    Data {
42        /// The decoded data. An iterator that produces `u8`s.
43        data: Data<'a>,
44        /// If true, this is the last packet in the frame.
45        last_in_frame: bool,
46    },
47
48    /// An error in the stream. The connection must be dropped ASAP.
49    Error {
50        /// A description of the error. Can or cannot be be returned to the client.
51        desc: &'static str,
52    },
53}
/// Decoded data. Implements `Iterator<Item = u8>`.
///
/// The payload is unmasked lazily, one byte per call to `next`, so no copy of the input is made.
#[derive(Debug, PartialEq, Eq)]
pub struct Data<'a> {
    // Source data. Undecoded (still masked), exactly as received from the socket.
    data: &'a [u8],
    // Copy of the mask of the current frame. The first mask byte sits in the most significant
    // byte of the integer.
    mask: u32,
    // Same as `StateMachineInner::InData::offset`: index (0 to 3) of the mask byte that applies
    // to the next byte of `data`. Updated at each iteration.
    offset: u8,
}
66/// A websocket state machine. Contains partial data.
67pub struct StateMachine {
68    // Actual state.
69    inner: StateMachineInner,
70    // Contains the start of the header. Must be empty if `inner` is equal to `InData`.
71    buffer: Vec<u8>, // TODO: use SmallVec?
72}
// Position of the parser inside the websocket stream.
enum StateMachineInner {
    // If `StateMachine::inner` is `InHeader`, then `buffer` contains the start of the header.
    InHeader,
    // If `StateMachine::inner` is `InData`, then `buffer` must be empty.
    InData {
        // Mask to decode the message.
        mask: u32,
        // Value between 0 and 3: the number of payload bytes of this frame already handed out,
        // modulo 4. This is the index of the mask byte that applies to the next expected byte.
        offset: u8,
        // Number of bytes remaining in the frame.
        remaining_len: u64,
    },
}
89impl StateMachine {
90    /// Initializes a new state machine for a new stream. Expects to see a new frame as the first
91    /// packet.
92    pub fn new() -> StateMachine {
93        StateMachine {
94            inner: StateMachineInner::InHeader,
95            buffer: Vec::with_capacity(14),
96        }
97    }
98
99    /// Feeds data to the state machine. Returns an iterator to the list of elements that were
100    /// received.
101    #[inline]
102    pub fn feed<'a>(&'a mut self, data: &'a [u8]) -> ElementsIter<'a> {
103        ElementsIter { state: self, data }
104    }
105}

// Helpers for decoding big-endian integers from the raw frame header bytes.
// These could probably be replaced with something more robust like `nom` if we want to
// take the hit of adding another dependency.
/// Reads the next two bytes of `input` as a big-endian `u16`, advancing the iterator past them.
///
/// # Panics
///
/// Panics if `input` yields fewer than two bytes. Callers are expected to have checked that
/// enough bytes are available before calling.
fn read_u16_be<'a, T: Iterator<Item = &'a u8>>(input: &mut T) -> u16 {
    let mut buf = [0u8; 2];
    for slot in buf.iter_mut() {
        *slot = *input
            .next()
            .expect("read_u16_be: input holds fewer than 2 bytes");
    }
    u16::from_be_bytes(buf)
}
/// Reads the next four bytes of `input` as a big-endian `u32`, advancing the iterator past them.
///
/// # Panics
///
/// Panics if `input` yields fewer than four bytes. Callers are expected to have checked that
/// enough bytes are available before calling.
fn read_u32_be<'a, T: Iterator<Item = &'a u8>>(input: &mut T) -> u32 {
    let mut buf = [0u8; 4];
    for slot in buf.iter_mut() {
        *slot = *input
            .next()
            .expect("read_u32_be: input holds fewer than 4 bytes");
    }
    u32::from_be_bytes(buf)
}
/// Reads the next eight bytes of `input` as a big-endian `u64`, advancing the iterator past
/// them.
///
/// # Panics
///
/// Panics if `input` yields fewer than eight bytes. Callers are expected to have checked that
/// enough bytes are available before calling.
fn read_u64_be<'a, T: Iterator<Item = &'a u8>>(input: &mut T) -> u64 {
    let mut buf = [0u8; 8];
    for slot in buf.iter_mut() {
        *slot = *input
            .next()
            .expect("read_u64_be: input holds fewer than 8 bytes");
    }
    u64::from_be_bytes(buf)
}
139/// Iterator to the list of elements that were received.
140pub struct ElementsIter<'a> {
141    state: &'a mut StateMachine,
142    data: &'a [u8],
143}
145impl<'a> Iterator for ElementsIter<'a> {
146    type Item = Element<'a>;
147
148    fn next(&mut self) -> Option<Element<'a>> {
149        if self.data.is_empty() {
150            return None;
151        }
152
153        match self.state.inner {
154            // First situation, we are in the header.
155            StateMachineInner::InHeader => {
156                // We need at least 6 bytes for a successful header. Otherwise we just return.
157                let total_buffered = self.state.buffer.len() + self.data.len();
158                if total_buffered < 6 {
159                    self.state.buffer.extend_from_slice(self.data);
160                    self.data = &[];
161                    return None;
162                }
163
164                // Retrieve the first two bytes of the header.
165                let (first_byte, second_byte) = {
166                    let mut mask_iter = self.state.buffer.iter().chain(self.data.iter());
167                    let first_byte = *mask_iter.next().unwrap();
168                    let second_byte = *mask_iter.next().unwrap();
169                    (first_byte, second_byte)
170                };
171
172                // Reserved bits must be zero, otherwise error.
173                if (first_byte & 0x70) != 0 {
174                    return Some(Element::Error {
175                        desc: "Reserved bits must be zero",
176                    });
177                }
178
179                // Client-to-server messages **must** be encoded.
180                if (second_byte & 0x80) == 0 {
181                    return Some(Element::Error {
182                        desc: "Client-to-server messages must be masked",
183                    });
184                }
185
186                // Find the length of the frame and the mask.
187                let (length, mask) = match second_byte & 0x7f {
188                    126 => {
189                        if total_buffered < 8 {
190                            self.state.buffer.extend_from_slice(self.data);
191                            self.data = &[];
192                            return None;
193                        }
194
195                        let mut mask_iter =
196                            self.state.buffer.iter().chain(self.data.iter()).skip(2);
197
198                        let length = read_u16_be(&mut mask_iter) as u64;
199                        let mask = read_u32_be(&mut mask_iter);
200                        (length, mask)
201                    }
202                    127 => {
203                        if total_buffered < 14 {
204                            self.state.buffer.extend_from_slice(self.data);
205                            self.data = &[];
206                            return None;
207                        }
208
209                        let mut mask_iter =
210                            self.state.buffer.iter().chain(self.data.iter()).skip(2);
211
212                        let length = {
213                            let length = read_u64_be(&mut mask_iter);
214                            // The most significant bit must be zero according to the specs.
215                            if (length & 0x8000000000000000) != 0 {
216                                return Some(Element::Error {
217                                    desc: "Most-significant bit of the length must be zero",
218                                });
219                            }
220                            length
221                        };
222
223                        let mask = read_u32_be(&mut mask_iter);
224
225                        (length, mask)
226                    }
227                    n => {
228                        let mut mask_iter =
229                            self.state.buffer.iter().chain(self.data.iter()).skip(2);
230
231                        let mask = read_u32_be(&mut mask_iter);
232                        (u64::from(n), mask)
233                    }
234                };
235
236                // Builds a slice containing the start of the data.
237                let data_start = {
238                    let data_start_off = match second_byte & 0x7f {
239                        126 => 8,
240                        127 => 14,
241                        _ => 6,
242                    };
243
244                    assert!(self.state.buffer.len() < data_start_off);
245                    &self.data[(data_start_off - self.state.buffer.len())..]
246                };
247
248                // Update ourselves for the next loop and return a FrameStart message.
249                self.data = data_start;
250                self.state.buffer.clear();
251                self.state.inner = StateMachineInner::InData {
252                    mask,
253                    remaining_len: length,
254                    offset: 0,
255                };
256                Some(Element::FrameStart {
257                    fin: (first_byte & 0x80) != 0,
258                    length,
259                    opcode: first_byte & 0xf,
260                })
261            }
262
263            // Second situation, we are in the message and we don't have enough data to finish the
264            // current frame.
265            StateMachineInner::InData {
266                mask,
267                ref mut remaining_len,
268                ref mut offset,
269            } if *remaining_len > self.data.len() as u64 => {
270                let data = Data {
271                    data: self.data,
272                    mask,
273                    offset: *offset,
274                };
275
276                *offset += (self.data.len() % 4) as u8;
277                *offset %= 4;
278                *remaining_len -= self.data.len() as u64;
279
280                self.data = &[];
281
282                Some(Element::Data {
283                    data,
284                    last_in_frame: false,
285                })
286            }
287
288            // Third situation, we have enough data to finish the frame.
289            StateMachineInner::InData {
290                mask,
291                remaining_len,
292                offset,
293            } => {
294                debug_assert!(self.data.len() as u64 >= remaining_len);
295
296                let data = Data {
297                    data: &self.data[0..remaining_len as usize],
298                    mask,
299                    offset,
300                };
301
302                self.data = &self.data[remaining_len as usize..];
303                self.state.inner = StateMachineInner::InHeader;
304                debug_assert!(self.state.buffer.is_empty());
305
306                Some(Element::Data {
307                    data,
308                    last_in_frame: true,
309                })
310            }
311        }
312    }
313}
315impl<'a> Iterator for Data<'a> {
316    type Item = u8;
317
318    #[inline]
319    fn next(&mut self) -> Option<u8> {
320        if self.data.is_empty() {
321            return None;
322        }
323
324        let byte = self.data[0];
325        let mask = ((self.mask >> ((3 - self.offset) * 8)) & 0xff) as u8;
326        let decoded = byte ^ mask;
327
328        self.data = &self.data[1..];
329        self.offset = (self.offset + 1) % 4;
330
331        Some(decoded)
332    }
333
334    #[inline]
335    fn size_hint(&self) -> (usize, Option<usize>) {
336        let l = self.data.len();
337        (l, Some(l))
338    }
339}
341impl<'a> ExactSizeIterator for Data<'a> {}
#[cfg(test)]
mod tests {
    use super::Element;
    use super::StateMachine;

    // The text "Hello" as one final, masked text frame (mask 0x37fa213d).
    const HELLO_FRAME: &[u8] = &[
        0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58,
    ];

    // Asserts that `element` is a `Data` element that decodes to `expected`.
    fn expect_data(element: Element, expected: &[u8], expected_last: bool) {
        match element {
            Element::Data {
                data,
                last_in_frame,
            } => {
                assert_eq!(last_in_frame, expected_last);
                assert_eq!(data.collect::<Vec<_>>(), expected);
            }
            other => panic!("expected a data element, got {:?}", other),
        }
    }

    #[test]
    fn basic() {
        let mut machine = StateMachine::new();

        let data = &[
            0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58,
        ];
        let mut iter = machine.feed(data);

        assert_eq!(
            iter.next().unwrap(),
            Element::FrameStart {
                fin: true,
                length: 5,
                opcode: 1
            }
        );

        match iter.next().unwrap() {
            Element::Data {
                data,
                last_in_frame,
            } => {
                assert!(last_in_frame);
                assert_eq!(data.collect::<Vec<_>>(), b"Hello");
            }
            _ => panic!(),
        }
    }

    // A header that arrives in two pieces is buffered and parsed once complete.
    #[test]
    fn header_split_across_feeds() {
        let mut machine = StateMachine::new();

        assert!(machine.feed(&HELLO_FRAME[..3]).next().is_none());

        let mut iter = machine.feed(&HELLO_FRAME[3..]);
        assert_eq!(
            iter.next().unwrap(),
            Element::FrameStart {
                fin: true,
                length: 5,
                opcode: 1
            }
        );
        expect_data(iter.next().unwrap(), b"Hello", true);
        assert!(iter.next().is_none());
    }

    // The mask position is carried over when a payload arrives in two pieces.
    #[test]
    fn payload_split_across_feeds() {
        let mut machine = StateMachine::new();

        {
            let mut iter = machine.feed(&HELLO_FRAME[..8]);
            assert_eq!(
                iter.next().unwrap(),
                Element::FrameStart {
                    fin: true,
                    length: 5,
                    opcode: 1
                }
            );
            expect_data(iter.next().unwrap(), b"He", false);
            assert!(iter.next().is_none());
        }

        let mut iter = machine.feed(&HELLO_FRAME[8..]);
        expect_data(iter.next().unwrap(), b"llo", true);
        assert!(iter.next().is_none());
    }

    // Length code 126 announces a 16-bit big-endian payload length.
    #[test]
    fn extended_16_bit_length() {
        let mut frame: Vec<u8> = vec![0x82, 0xfe, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00];
        frame.extend(std::iter::repeat(0xabu8).take(256));

        let mut machine = StateMachine::new();
        let mut iter = machine.feed(&frame);

        assert_eq!(
            iter.next().unwrap(),
            Element::FrameStart {
                fin: true,
                length: 256,
                opcode: 2
            }
        );
        expect_data(iter.next().unwrap(), &[0xab; 256], true);
        assert!(iter.next().is_none());
    }

    // Frames sent by a client must have the mask bit set.
    #[test]
    fn unmasked_frame_is_rejected() {
        let frame: &[u8] = &[0x81, 0x05, b'H', b'e', b'l', b'l', b'o'];

        let mut machine = StateMachine::new();
        assert_eq!(
            machine.feed(frame).next(),
            Some(Element::Error {
                desc: "Client-to-server messages must be masked",
            })
        );
    }
}