rouille/websocket/
websocket.rs

// Copyright (c) 2016 The Rouille developers
// Licensed under the Apache License, Version 2.0
// <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT
// license <LICENSE-MIT or http://opensource.org/licenses/MIT>,
// at your option. All files in the project carrying such
// notice may not be copied, modified, or distributed except
// according to those terms.
9
use std::error;
use std::fmt;
use std::io;
use std::io::Write;
use std::mem;
use std::sync::mpsc::Sender;
use ReadWrite;
use Upgrade;

use websocket::low_level;
18
19/// A successful websocket. An open channel of communication. Implements `Read` and `Write`.
20pub struct Websocket {
21    // The socket. `None` if closed.
22    socket: Option<Box<dyn ReadWrite + Send>>,
23    // The websocket state machine.
24    state_machine: low_level::StateMachine,
25    // True if the fragmented message currently being processed is binary. False if string. Pings
26    // are excluded.
27    current_message_binary: bool,
28    // Buffer for the fragmented message currently being processed. Pings are excluded.
29    current_message_payload: Vec<u8>,
30    // Opcode of the fragment currently being processed.
31    current_frame_opcode: u8,
32    // Fin flag of the fragment currently being processed.
33    current_frame_fin: bool,
34    // Data of the fragment currently being processed.
35    current_frame_payload: Vec<u8>,
36    // Queue of the messages that are going to be returned by `next()`.
37    messages_in_queue: Vec<Message>,
38}
39
/// A complete message received from the client over a websocket connection.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Message {
    /// A text message. A JavaScript client produces one by passing a string to `send()`.
    Text(String),

    /// A binary message. A JavaScript client produces one by passing a blob or an arraybuffer
    /// to `send()`.
    Binary(Vec<u8>),
}
51
/// Error that can happen when sending a message to the client.
#[derive(Debug)]
pub enum SendError {
    /// Failed to transfer the message on the socket.
    IoError(io::Error),

    /// The websocket connection is closed.
    Closed,
}

impl From<io::Error> for SendError {
    #[inline]
    fn from(err: io::Error) -> SendError {
        SendError::IoError(err)
    }
}

// The underlying I/O error is exposed through `source()` rather than repeated in the
// `Display` output, so error reporters that walk the chain don't print it twice.
impl fmt::Display for SendError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            SendError::IoError(_) => f.write_str("failed to send the message on the websocket"),
            SendError::Closed => f.write_str("the websocket connection is closed"),
        }
    }
}

impl error::Error for SendError {
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        match *self {
            SendError::IoError(ref err) => Some(err),
            SendError::Closed => None,
        }
    }
}
68
69impl Websocket {
70    /// Sends text data over the websocket.
71    ///
72    /// Returns an error if the message didn't send correctly or if the connection is closed.
73    ///
74    /// If the client is in javascript, the message will contain a string.
75    #[inline]
76    pub fn send_text(&mut self, data: &str) -> Result<(), SendError> {
77        let socket = match self.socket {
78            Some(ref mut s) => s,
79            None => return Err(SendError::Closed),
80        };
81
82        send(data.as_bytes(), Write::by_ref(socket), 0x1)?;
83        Ok(())
84    }
85
86    /// Sends binary data over the websocket.
87    ///
88    /// Returns an error if the message didn't send correctly or if the connection is closed.
89    ///
90    /// If the client is in javascript, the message will contain a blob or an arraybuffer.
91    #[inline]
92    pub fn send_binary(&mut self, data: &[u8]) -> Result<(), SendError> {
93        let socket = match self.socket {
94            Some(ref mut s) => s,
95            None => return Err(SendError::Closed),
96        };
97
98        send(data, Write::by_ref(socket), 0x2)?;
99        Ok(())
100    }
101
102    /// Returns `true` if the websocket has been closed by either the client (voluntarily or not)
103    /// or by the server (if the websocket protocol was violated).
104    #[inline]
105    pub fn is_closed(&self) -> bool {
106        self.socket.is_none()
107    }
108
109    // TODO: give access to close reason
110}
111
112impl Upgrade for Sender<Websocket> {
113    fn build(&mut self, socket: Box<dyn ReadWrite + Send>) {
114        let websocket = Websocket {
115            socket: Some(socket),
116            state_machine: low_level::StateMachine::new(),
117            current_message_binary: false,
118            current_message_payload: Vec::new(),
119            current_frame_opcode: 0,
120            current_frame_fin: false,
121            current_frame_payload: Vec::new(),
122            messages_in_queue: Vec::new(),
123        };
124
125        let _ = self.send(websocket);
126    }
127}
128
129impl Iterator for Websocket {
130    type Item = Message;
131
132    fn next(&mut self) -> Option<Message> {
133        loop {
134            // If the socket is `None`, the connection has been closed.
135            self.socket.as_ref()?;
136
137            // There may be some messages waiting to be processed.
138            if !self.messages_in_queue.is_empty() {
139                return Some(self.messages_in_queue.remove(0));
140            }
141
142            // Read `n` bytes in `buf`.
143            let mut buf = [0; 256];
144            let n = match self.socket.as_mut().unwrap().read(&mut buf) {
145                Ok(n) if n == 0 => {
146                    // Read returning zero means EOF
147                    self.socket = None;
148                    return None;
149                }
150                Ok(n) => n,
151                Err(ref err) if err.kind() == io::ErrorKind::Interrupted => 0,
152                Err(_) => {
153                    self.socket = None;
154                    return None;
155                }
156            };
157
158            // Fill `messages_in_queue` by analyzing the packets.
159            for element in self.state_machine.feed(&buf[0..n]) {
160                match element {
161                    low_level::Element::FrameStart { fin, opcode, .. } => {
162                        debug_assert!(self.current_frame_payload.is_empty());
163                        self.current_frame_fin = fin;
164                        self.current_frame_opcode = opcode;
165                    }
166
167                    low_level::Element::Data {
168                        data,
169                        last_in_frame,
170                    } => {
171                        // Under normal circumstances we just handle data by pushing it to
172                        // `current_frame_payload`.
173                        self.current_frame_payload.extend(data);
174
175                        // But if the frame is finished we additionally need to dispatch it.
176                        if last_in_frame {
177                            match self.current_frame_opcode {
178                                // Frame is a continuation of the current message.
179                                0x0 => {
180                                    self.current_message_payload
181                                        .append(&mut self.current_frame_payload);
182
183                                    // If the message is finished, dispatch it.
184                                    if self.current_frame_fin {
185                                        let binary = mem::take(&mut self.current_message_payload);
186
187                                        if self.current_message_binary {
188                                            self.messages_in_queue.push(Message::Binary(binary));
189                                        } else {
190                                            let string = match String::from_utf8(binary) {
191                                                Ok(s) => s,
192                                                Err(_) => {
193                                                    // Closing connection because text wasn't UTF-8
194                                                    let _ = send(
195                                                        b"1007 Invalid UTF-8 encoding",
196                                                        Write::by_ref(
197                                                            self.socket.as_mut().unwrap(),
198                                                        ),
199                                                        0x8,
200                                                    );
201                                                    self.socket = None;
202                                                    return None;
203                                                }
204                                            };
205
206                                            self.messages_in_queue.push(Message::Text(string));
207                                        }
208                                    }
209                                }
210
211                                // Frame is an individual text frame.
212                                0x1 => {
213                                    // If we're in the middle of a message, this frame is invalid
214                                    // and we need to close.
215                                    if !self.current_message_payload.is_empty() {
216                                        let _ = send(
217                                            b"1002 Expected continuation frame",
218                                            Write::by_ref(self.socket.as_mut().unwrap()),
219                                            0x8,
220                                        );
221                                        self.socket = None;
222                                        return None;
223                                    }
224
225                                    if self.current_frame_fin {
226                                        // There's only one frame in this message.
227                                        let binary = mem::take(&mut self.current_frame_payload);
228                                        let string = match String::from_utf8(binary) {
229                                            Ok(s) => s,
230                                            Err(_err) => {
231                                                // Closing connection because text wasn't UTF-8
232                                                let _ = send(
233                                                    b"1007 Invalid UTF-8 encoding",
234                                                    Write::by_ref(self.socket.as_mut().unwrap()),
235                                                    0x8,
236                                                );
237                                                self.socket = None;
238                                                return None;
239                                            }
240                                        };
241
242                                        self.messages_in_queue.push(Message::Text(string));
243                                    } else {
244                                        // Start of a fragmented message.
245                                        self.current_message_binary = false;
246                                        self.current_message_payload
247                                            .append(&mut self.current_frame_payload);
248                                    }
249                                }
250
251                                // Frame is an individual binary frame.
252                                0x2 => {
253                                    // If we're in the middle of a message, this frame is invalid
254                                    // and we need to close.
255                                    if !self.current_message_payload.is_empty() {
256                                        let _ = send(
257                                            b"1002 Expected continuation frame",
258                                            Write::by_ref(self.socket.as_mut().unwrap()),
259                                            0x8,
260                                        );
261                                        self.socket = None;
262                                        return None;
263                                    }
264
265                                    if self.current_frame_fin {
266                                        let binary = mem::take(&mut self.current_frame_payload);
267                                        self.messages_in_queue.push(Message::Binary(binary));
268                                    } else {
269                                        // Start of a fragmented message.
270                                        self.current_message_binary = true;
271                                        self.current_message_payload
272                                            .append(&mut self.current_frame_payload);
273                                    }
274                                }
275
276                                // Close request.
277                                0x8 => {
278                                    // We need to send a confirmation.
279                                    let _ = send(
280                                        &self.current_frame_payload,
281                                        Write::by_ref(self.socket.as_mut().unwrap()),
282                                        0x8,
283                                    );
284                                    // Since the packets are always received in order, and since
285                                    // the server is considered dead as soon as it sends the
286                                    // confirmation, we have no risk of losing packets.
287                                    self.socket = None;
288                                    return None;
289                                }
290
291                                // Ping.
292                                0x9 => {
293                                    // Send the pong.
294                                    let _ = send(
295                                        &self.current_frame_payload,
296                                        Write::by_ref(self.socket.as_mut().unwrap()),
297                                        0xA,
298                                    );
299                                }
300
301                                // Pong. We ignore this as there's nothing to do.
302                                0xA => {}
303
304                                // Unknown opcode means error and close.
305                                _ => {
306                                    let _ = send(
307                                        b"Unknown opcode",
308                                        Write::by_ref(self.socket.as_mut().unwrap()),
309                                        0x8,
310                                    );
311                                    self.socket = None;
312                                    return None;
313                                }
314                            }
315
316                            self.current_frame_payload.clear();
317                        }
318                    }
319
320                    low_level::Element::Error { desc } => {
321                        // The low level layer signaled an error. Sending it to client and closing.
322                        let _ = send(
323                            desc.as_bytes(),
324                            Write::by_ref(self.socket.as_mut().unwrap()),
325                            0x8,
326                        );
327                        self.socket = None;
328                        return None;
329                    }
330                }
331            }
332        }
333    }
334}
335
// Sends `data` to `dest` as a single websocket frame with the FIN bit set and the given opcode,
// then flushes `dest`.
//
// The payload length uses the shortest encoding RFC 6455 allows: up to 125 bytes inline, up to
// 65535 bytes as a 16-bit extension, and larger payloads as a 64-bit extension. The header is
// assembled in memory and written with a single call, so an unbuffered socket isn't hit with a
// separate write per header field.
//
// Panics if `opcode` doesn't fit in 4 bits.
// TODO: message fragmentation?
fn send<W: Write>(data: &[u8], mut dest: W, opcode: u8) -> io::Result<()> {
    assert!(opcode <= 0xf);

    // Largest header: FIN/opcode byte + length marker byte + 8-byte extended length.
    let mut header = [0u8; 10];
    header[0] = 0x80 | opcode;

    let header_len = if data.len() >= 65536 {
        let len = data.len() as u64;
        // The most significant bit of the 64-bit length must be zero.
        assert!(len < 0x8000_0000_0000_0000);
        header[1] = 127;
        header[2..10].copy_from_slice(&len.to_be_bytes());
        10
    } else if data.len() >= 126 {
        header[1] = 126;
        // Lossless: this branch only handles lengths below 65536.
        header[2..4].copy_from_slice(&(data.len() as u16).to_be_bytes());
        4
    } else {
        // Lossless: this branch only handles lengths below 126.
        header[1] = data.len() as u8;
        2
    };

    dest.write_all(&header[..header_len])?;
    dest.write_all(data)?;
    dest.flush()
}
363
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ws_framing_short() {
        let data: &[u8] = &[0xAB, 0xAB, 0xAB, 0xAB];
        let mut buf = Vec::new();

        send(data, &mut buf, 0x2).unwrap();

        // Expected
        // 0x82 (FIN = 1 | RSV1/2/3 = 0 | OPCODE = 2)
        // 0x04 (len = 4 bytes)
        // 0xABABABAB (payload = 4 bytes)
        assert_eq!(&buf, &[0x82, 0x04, 0xAB, 0xAB, 0xAB, 0xAB]);
    }

    #[test]
    fn test_ws_framing_medium() {
        // 125 bytes is the largest payload whose length fits in the first length byte.
        let data: [u8; 125] = [0xAB; 125];
        let mut buf = Vec::new();

        send(&data, &mut buf, 0x2).unwrap();

        // Expected
        // 0x82 (FIN = 1 | RSV1/2/3 = 0 | OPCODE = 2)
        // 0x7D (len = 125 bytes)
        // 0xABABABAB... (payload = 125 bytes)
        assert_eq!(&buf[0..2], &[0x82, 0x7D]);
        assert_eq!(&buf[2..], &data[..]);
    }

    #[test]
    fn test_ws_framing_long() {
        let data: [u8; 65534] = [0xAB; 65534];
        let mut buf = Vec::new();

        send(&data, &mut buf, 0x2).unwrap();

        // Expected
        // 0x82 (FIN = 1 | RSV1/2/3 = 0 | OPCODE = 2)
        // 0x7E (len = 126 = extended 7+16)
        // 0xFFFE (extended_len = 65534 - Network Byte Order)
        // 0xABABABAB... (payload = 65534 bytes)
        assert_eq!(&buf[0..4], &[0x82, 0x7E, 0xFF, 0xFE]);
        assert_eq!(&buf[4..], &data[..]);
    }

    #[test]
    fn test_ws_framing_very_long() {
        let data: [u8; 0x100FF] = [0xAB; 0x100FF]; // 65791
        let mut buf = Vec::new();

        send(&data, &mut buf, 0x2).unwrap();

        // Expected
        // 0x82 (FIN = 1 | RSV1/2/3 = 0 | OPCODE = 2)
        // 0x7F (len = 127 = extended 7+64)
        // 0x00000000000100FF (extended_len = 65791 - Network Byte Order)
        // 0xABABABAB... (payload = 65791 bytes)
        assert_eq!(
            &buf[0..10],
            &[0x82, 0x7F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0xFF]
        );
        assert_eq!(&buf[10..], &data[..]);
    }

    #[test]
    fn test_ws_framing_empty() {
        let mut buf: Vec<u8> = Vec::new();

        send(&[], &mut buf, 0x8).unwrap();

        // Expected
        // 0x88 (FIN = 1 | RSV1/2/3 = 0 | OPCODE = 8)
        // 0x00 (len = 0 bytes, no payload)
        assert_eq!(buf, vec![0x88u8, 0x00]);
    }

    #[test]
    fn test_ws_framing_text_opcode() {
        let mut buf: Vec<u8> = Vec::new();

        send(b"hi", &mut buf, 0x1).unwrap();

        // 0x81 (FIN = 1 | OPCODE = 1), 0x02 (len = 2 bytes), then the payload.
        assert_eq!(buf, vec![0x81u8, 0x02, b'h', b'i']);
    }

    #[test]
    fn test_ws_framing_16bit_boundaries() {
        // 126 bytes is the smallest payload that needs the 16-bit extended length.
        let data = vec![0xABu8; 126];
        let mut buf: Vec<u8> = Vec::new();
        send(&data, &mut buf, 0x2).unwrap();
        assert_eq!(&buf[0..4], &[0x82, 0x7E, 0x00, 0x7E]);
        assert_eq!(&buf[4..], &data[..]);

        // 65535 bytes is the largest payload that fits in the 16-bit extended length.
        let data = vec![0xABu8; 65535];
        let mut buf: Vec<u8> = Vec::new();
        send(&data, &mut buf, 0x2).unwrap();
        assert_eq!(&buf[0..4], &[0x82, 0x7E, 0xFF, 0xFF]);
        assert_eq!(&buf[4..], &data[..]);
    }

    #[test]
    fn test_ws_framing_64bit_boundary() {
        // 65536 bytes is the smallest payload that needs the 64-bit extended length.
        let data = vec![0xABu8; 65536];
        let mut buf: Vec<u8> = Vec::new();

        send(&data, &mut buf, 0x2).unwrap();

        assert_eq!(
            &buf[0..10],
            &[0x82, 0x7F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00]
        );
        assert_eq!(&buf[10..], &data[..]);
    }
}