rouille/websocket/websocket.rs
1// Copyright (c) 2016 The Rouille developers
2// Licensed under the Apache License, Version 2.0
3// <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>,
6// at your option. All files in the project carrying such
7// notice may not be copied, modified, or distributed except
8// according to those terms.
9
10use std::io;
11use std::io::Write;
12use std::mem;
13use std::sync::mpsc::Sender;
14use ReadWrite;
15use Upgrade;
16
17use websocket::low_level;
18
19/// A successful websocket. An open channel of communication. Implements `Read` and `Write`.
20pub struct Websocket {
21 // The socket. `None` if closed.
22 socket: Option<Box<dyn ReadWrite + Send>>,
23 // The websocket state machine.
24 state_machine: low_level::StateMachine,
25 // True if the fragmented message currently being processed is binary. False if string. Pings
26 // are excluded.
27 current_message_binary: bool,
28 // Buffer for the fragmented message currently being processed. Pings are excluded.
29 current_message_payload: Vec<u8>,
30 // Opcode of the fragment currently being processed.
31 current_frame_opcode: u8,
32 // Fin flag of the fragment currently being processed.
33 current_frame_fin: bool,
34 // Data of the fragment currently being processed.
35 current_frame_payload: Vec<u8>,
36 // Queue of the messages that are going to be returned by `next()`.
37 messages_in_queue: Vec<Message>,
38}
39
/// A complete message received from the client over a websocket connection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Message {
    /// A text message. A Javascript client produces one of these by passing a string to
    /// `send()`.
    Text(String),

    /// A binary message. A Javascript client produces one of these by passing a blob or an
    /// arraybuffer to `send()`.
    Binary(Vec<u8>),
}
51
/// Error that can happen when sending a message to the client.
#[derive(Debug)]
pub enum SendError {
    /// Failed to transfer the message on the socket.
    IoError(io::Error),

    /// The websocket connection is closed.
    Closed,
}

impl From<io::Error> for SendError {
    #[inline]
    fn from(err: io::Error) -> SendError {
        SendError::IoError(err)
    }
}

impl std::fmt::Display for SendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The underlying I/O error is exposed through `source()` rather than repeated here,
        // so error reporters that walk the source chain don't print it twice.
        match self {
            SendError::IoError(_) => f.write_str("failed to send the websocket message"),
            SendError::Closed => f.write_str("the websocket connection is closed"),
        }
    }
}

impl std::error::Error for SendError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            SendError::IoError(err) => Some(err),
            SendError::Closed => None,
        }
    }
}
68
69impl Websocket {
70 /// Sends text data over the websocket.
71 ///
72 /// Returns an error if the message didn't send correctly or if the connection is closed.
73 ///
74 /// If the client is in javascript, the message will contain a string.
75 #[inline]
76 pub fn send_text(&mut self, data: &str) -> Result<(), SendError> {
77 let socket = match self.socket {
78 Some(ref mut s) => s,
79 None => return Err(SendError::Closed),
80 };
81
82 send(data.as_bytes(), Write::by_ref(socket), 0x1)?;
83 Ok(())
84 }
85
86 /// Sends binary data over the websocket.
87 ///
88 /// Returns an error if the message didn't send correctly or if the connection is closed.
89 ///
90 /// If the client is in javascript, the message will contain a blob or an arraybuffer.
91 #[inline]
92 pub fn send_binary(&mut self, data: &[u8]) -> Result<(), SendError> {
93 let socket = match self.socket {
94 Some(ref mut s) => s,
95 None => return Err(SendError::Closed),
96 };
97
98 send(data, Write::by_ref(socket), 0x2)?;
99 Ok(())
100 }
101
102 /// Returns `true` if the websocket has been closed by either the client (voluntarily or not)
103 /// or by the server (if the websocket protocol was violated).
104 #[inline]
105 pub fn is_closed(&self) -> bool {
106 self.socket.is_none()
107 }
108
109 // TODO: give access to close reason
110}
111
112impl Upgrade for Sender<Websocket> {
113 fn build(&mut self, socket: Box<dyn ReadWrite + Send>) {
114 let websocket = Websocket {
115 socket: Some(socket),
116 state_machine: low_level::StateMachine::new(),
117 current_message_binary: false,
118 current_message_payload: Vec::new(),
119 current_frame_opcode: 0,
120 current_frame_fin: false,
121 current_frame_payload: Vec::new(),
122 messages_in_queue: Vec::new(),
123 };
124
125 let _ = self.send(websocket);
126 }
127}
128
129impl Iterator for Websocket {
130 type Item = Message;
131
132 fn next(&mut self) -> Option<Message> {
133 loop {
134 // If the socket is `None`, the connection has been closed.
135 self.socket.as_ref()?;
136
137 // There may be some messages waiting to be processed.
138 if !self.messages_in_queue.is_empty() {
139 return Some(self.messages_in_queue.remove(0));
140 }
141
142 // Read `n` bytes in `buf`.
143 let mut buf = [0; 256];
144 let n = match self.socket.as_mut().unwrap().read(&mut buf) {
145 Ok(n) if n == 0 => {
146 // Read returning zero means EOF
147 self.socket = None;
148 return None;
149 }
150 Ok(n) => n,
151 Err(ref err) if err.kind() == io::ErrorKind::Interrupted => 0,
152 Err(_) => {
153 self.socket = None;
154 return None;
155 }
156 };
157
158 // Fill `messages_in_queue` by analyzing the packets.
159 for element in self.state_machine.feed(&buf[0..n]) {
160 match element {
161 low_level::Element::FrameStart { fin, opcode, .. } => {
162 debug_assert!(self.current_frame_payload.is_empty());
163 self.current_frame_fin = fin;
164 self.current_frame_opcode = opcode;
165 }
166
167 low_level::Element::Data {
168 data,
169 last_in_frame,
170 } => {
171 // Under normal circumstances we just handle data by pushing it to
172 // `current_frame_payload`.
173 self.current_frame_payload.extend(data);
174
175 // But if the frame is finished we additionally need to dispatch it.
176 if last_in_frame {
177 match self.current_frame_opcode {
178 // Frame is a continuation of the current message.
179 0x0 => {
180 self.current_message_payload
181 .append(&mut self.current_frame_payload);
182
183 // If the message is finished, dispatch it.
184 if self.current_frame_fin {
185 let binary = mem::take(&mut self.current_message_payload);
186
187 if self.current_message_binary {
188 self.messages_in_queue.push(Message::Binary(binary));
189 } else {
190 let string = match String::from_utf8(binary) {
191 Ok(s) => s,
192 Err(_) => {
193 // Closing connection because text wasn't UTF-8
194 let _ = send(
195 b"1007 Invalid UTF-8 encoding",
196 Write::by_ref(
197 self.socket.as_mut().unwrap(),
198 ),
199 0x8,
200 );
201 self.socket = None;
202 return None;
203 }
204 };
205
206 self.messages_in_queue.push(Message::Text(string));
207 }
208 }
209 }
210
211 // Frame is an individual text frame.
212 0x1 => {
213 // If we're in the middle of a message, this frame is invalid
214 // and we need to close.
215 if !self.current_message_payload.is_empty() {
216 let _ = send(
217 b"1002 Expected continuation frame",
218 Write::by_ref(self.socket.as_mut().unwrap()),
219 0x8,
220 );
221 self.socket = None;
222 return None;
223 }
224
225 if self.current_frame_fin {
226 // There's only one frame in this message.
227 let binary = mem::take(&mut self.current_frame_payload);
228 let string = match String::from_utf8(binary) {
229 Ok(s) => s,
230 Err(_err) => {
231 // Closing connection because text wasn't UTF-8
232 let _ = send(
233 b"1007 Invalid UTF-8 encoding",
234 Write::by_ref(self.socket.as_mut().unwrap()),
235 0x8,
236 );
237 self.socket = None;
238 return None;
239 }
240 };
241
242 self.messages_in_queue.push(Message::Text(string));
243 } else {
244 // Start of a fragmented message.
245 self.current_message_binary = false;
246 self.current_message_payload
247 .append(&mut self.current_frame_payload);
248 }
249 }
250
251 // Frame is an individual binary frame.
252 0x2 => {
253 // If we're in the middle of a message, this frame is invalid
254 // and we need to close.
255 if !self.current_message_payload.is_empty() {
256 let _ = send(
257 b"1002 Expected continuation frame",
258 Write::by_ref(self.socket.as_mut().unwrap()),
259 0x8,
260 );
261 self.socket = None;
262 return None;
263 }
264
265 if self.current_frame_fin {
266 let binary = mem::take(&mut self.current_frame_payload);
267 self.messages_in_queue.push(Message::Binary(binary));
268 } else {
269 // Start of a fragmented message.
270 self.current_message_binary = true;
271 self.current_message_payload
272 .append(&mut self.current_frame_payload);
273 }
274 }
275
276 // Close request.
277 0x8 => {
278 // We need to send a confirmation.
279 let _ = send(
280 &self.current_frame_payload,
281 Write::by_ref(self.socket.as_mut().unwrap()),
282 0x8,
283 );
284 // Since the packets are always received in order, and since
285 // the server is considered dead as soon as it sends the
286 // confirmation, we have no risk of losing packets.
287 self.socket = None;
288 return None;
289 }
290
291 // Ping.
292 0x9 => {
293 // Send the pong.
294 let _ = send(
295 &self.current_frame_payload,
296 Write::by_ref(self.socket.as_mut().unwrap()),
297 0xA,
298 );
299 }
300
301 // Pong. We ignore this as there's nothing to do.
302 0xA => {}
303
304 // Unknown opcode means error and close.
305 _ => {
306 let _ = send(
307 b"Unknown opcode",
308 Write::by_ref(self.socket.as_mut().unwrap()),
309 0x8,
310 );
311 self.socket = None;
312 return None;
313 }
314 }
315
316 self.current_frame_payload.clear();
317 }
318 }
319
320 low_level::Element::Error { desc } => {
321 // The low level layer signaled an error. Sending it to client and closing.
322 let _ = send(
323 desc.as_bytes(),
324 Write::by_ref(self.socket.as_mut().unwrap()),
325 0x8,
326 );
327 self.socket = None;
328 return None;
329 }
330 }
331 }
332 }
333 }
334}
335
/// Writes `data` to `dest` as one final (FIN = 1), unmasked websocket frame with the given
/// `opcode`, then flushes `dest`.
///
/// The payload length uses the shortest encoding RFC 6455 §5.2 allows:
///
/// - the 7-bit field for up to 125 bytes;
/// - a 16-bit extension for up to 65535 bytes;
/// - a 64-bit extension otherwise.
///
/// # Panics
///
/// If `opcode` does not fit in 4 bits, or if `data` is 2^63 bytes or longer.
///
/// # Errors
///
/// Any I/O error produced while writing to or flushing `dest`.
// TODO: message fragmentation?
fn send<W: Write>(data: &[u8], mut dest: W, opcode: u8) -> io::Result<()> {
    assert!(opcode <= 0xf);

    // Assemble the whole header (2 to 10 bytes) up front so it reaches `dest` in a single write
    // instead of up to three; the socket handed to us is not necessarily buffered.
    let mut header = [0u8; 10];
    header[0] = 0x80 | opcode;
    let header_len = match data.len() {
        len @ 0..=125 => {
            header[1] = len as u8;
            2
        }
        len @ 126..=0xFFFF => {
            header[1] = 126;
            header[2..4].copy_from_slice(&(len as u16).to_be_bytes());
            4
        }
        len => {
            let len = len as u64;
            // The most significant bit of the 64-bit length must be zero.
            assert!(len < 0x8000_0000_0000_0000);
            header[1] = 127;
            header[2..10].copy_from_slice(&len.to_be_bytes());
            10
        }
    };

    dest.write_all(&header[..header_len])?;
    dest.write_all(data)?;
    dest.flush()
}
363
#[cfg(test)]
mod tests {
    use super::*;

    // Frames `data` with `opcode` into an in-memory buffer and returns the bytes written.
    fn frame(data: &[u8], opcode: u8) -> Vec<u8> {
        let mut buf = Vec::new();
        send(data, &mut buf, opcode).unwrap();
        buf
    }

    #[test]
    fn test_ws_framing_empty() {
        // 0x82 (FIN = 1 | RSV1/2/3 = 0 | OPCODE = 2), 0x00 (len = 0 bytes), no payload
        assert_eq!(frame(&[], 0x2), vec![0x82, 0x00]);
    }

    #[test]
    fn test_ws_framing_short() {
        let data: &[u8] = &[0xAB, 0xAB, 0xAB, 0xAB];
        let buf = frame(data, 0x2);

        // Expected
        // 0x82 (FIN = 1 | RSV1/2/3 = 0 | OPCODE = 2)
        // 0x04 (len = 4 bytes)
        // 0xABABABAB (payload = 4 bytes)
        assert_eq!(&buf, &[0x82, 0x04, 0xAB, 0xAB, 0xAB, 0xAB]);
    }

    #[test]
    fn test_ws_framing_opcodes() {
        // Text, close, ping and pong frames only differ in the low nibble of the first byte.
        assert_eq!(frame(b"hi", 0x1), vec![0x81, 0x02, b'h', b'i']);
        assert_eq!(frame(b"", 0x8), vec![0x88, 0x00]);
        assert_eq!(frame(b"p", 0x9), vec![0x89, 0x01, b'p']);
        assert_eq!(frame(b"p", 0xA), vec![0x8A, 0x01, b'p']);
    }

    #[test]
    fn test_ws_framing_medium() {
        let data = [0xAB; 125];
        let buf = frame(&data, 0x2);

        // Expected
        // 0x82 (FIN = 1 | RSV1/2/3 = 0 | OPCODE = 2)
        // 0x7D (len = 125 bytes, the largest length that fits in 7 bits)
        // 0xABABABAB... (payload = 125 bytes)
        assert_eq!(&buf[0..2], &[0x82, 0x7D]);
        assert_eq!(&buf[2..], &data[..]);
    }

    #[test]
    fn test_ws_framing_first_extended_16() {
        let data = vec![0xAB; 126];
        let buf = frame(&data, 0x2);

        // 126 is the first length that needs the 16-bit extension: 0x7E then 0x007E.
        assert_eq!(&buf[0..4], &[0x82, 0x7E, 0x00, 0x7E]);
        assert_eq!(&buf[4..], &data[..]);
    }

    #[test]
    fn test_ws_framing_long() {
        let data = vec![0xAB; 65534];
        let buf = frame(&data, 0x2);

        // Expected
        // 0x82 (FIN = 1 | RSV1/2/3 = 0 | OPCODE = 2)
        // 0x7E (len = 126 = extended 7+16)
        // 0xFFFE (extended_len = 65534 - Network Byte Order)
        // 0xABABABAB... (payload = 65534 bytes)
        assert_eq!(&buf[0..4], &[0x82, 0x7E, 0xFF, 0xFE]);
        assert_eq!(&buf[4..], &data[..]);
    }

    #[test]
    fn test_ws_framing_last_extended_16() {
        let data = vec![0xAB; 65535];
        let buf = frame(&data, 0x2);

        // 65535 is the largest length that still fits in the 16-bit extension.
        assert_eq!(&buf[0..4], &[0x82, 0x7E, 0xFF, 0xFF]);
        assert_eq!(&buf[4..], &data[..]);
    }

    #[test]
    fn test_ws_framing_first_extended_64() {
        let data = vec![0xAB; 65536];
        let buf = frame(&data, 0x2);

        // 65536 is the first length that needs the 64-bit extension.
        assert_eq!(
            &buf[0..10],
            &[0x82, 0x7F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00]
        );
        assert_eq!(&buf[10..], &data[..]);
    }

    #[test]
    fn test_ws_framing_very_long() {
        let data = vec![0xAB; 0x100FF]; // 65791
        let buf = frame(&data, 0x2);

        // Expected
        // 0x82 (FIN = 1 | RSV1/2/3 = 0 | OPCODE = 2)
        // 0x7F (len = 127 = extended 7+64)
        // 0x00000000000100FF (extended_len = 65791 - Network Byte Order)
        // 0xABABABAB... (payload = 65791 bytes)
        assert_eq!(
            &buf[0..10],
            &[0x82, 0x7F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0xFF]
        );
        assert_eq!(&buf[10..], &data[..]);
    }

    #[test]
    #[should_panic]
    fn test_ws_framing_rejects_wide_opcode() {
        let mut buf = Vec::new();
        let _ = send(b"x", &mut buf, 0x10);
    }
}