// chunked_transfer/decoder.rs

1use std::error::Error;
2use std::fmt;
3use std::io::Error as IoError;
4use std::io::ErrorKind;
5use std::io::Read;
6use std::io::Result as IoResult;
7
/// Decodes an HTTP `chunked` transfer-encoded stream, yielding only the payload bytes.
///
/// Chunk-size lines, chunk extensions and the CRLF framing are consumed and discarded; each
/// read returns data from at most one chunk. The final zero-size chunk must be followed
/// directly by CRLF (trailer fields are not supported).
///
/// # Example
///
/// ```
/// use chunked_transfer::Decoder;
/// use std::io::Read;
///
/// let encoded = b"3\r\nhel\r\nb\r\nlo world!!!\r\n0\r\n\r\n";
/// let mut decoded = String::new();
///
/// let mut decoder = Decoder::new(encoded as &[u8]);
/// decoder.read_to_string(&mut decoded).unwrap();
///
/// assert_eq!(decoded, "hello world!!!");
/// ```
pub struct Decoder<R> {
    // where the chunk-encoded bytes come from
    source: R,

    // bytes still to be read from the chunk currently being decoded;
    // `None` when the decoder is between chunks
    remaining_chunks_size: Option<usize>,
}
32
33impl<R> Decoder<R>
34where
35    R: Read,
36{
37    pub fn new(source: R) -> Decoder<R> {
38        Decoder {
39            source,
40            remaining_chunks_size: None,
41        }
42    }
43
44    /// Returns the remaining bytes left in the chunk being read.
45    pub fn remaining_chunks_size(&self) -> Option<usize> {
46        self.remaining_chunks_size
47    }
48
49    /// Unwraps the Decoder into its inner `Read` source.
50    pub fn into_inner(self) -> R {
51        self.source
52    }
53
54    /// Gets a reference to the underlying value in this decoder.
55    pub fn get_ref(&self) -> &R {
56        &self.source
57    }
58
59    /// Gets a mutable reference to the underlying value in this decoder.
60    pub fn get_mut(&mut self) -> &mut R {
61        &mut self.source
62    }
63
64    fn read_chunk_size(&mut self) -> IoResult<usize> {
65        let mut chunk_size_bytes = Vec::new();
66        let mut has_ext = false;
67
68        loop {
69            let byte = match self.source.by_ref().bytes().next() {
70                Some(b) => b?,
71                None => return Err(IoError::new(ErrorKind::InvalidInput, DecoderError)),
72            };
73
74            if byte == b'\r' {
75                break;
76            }
77
78            if byte == b';' {
79                has_ext = true;
80                break;
81            }
82
83            chunk_size_bytes.push(byte);
84        }
85
86        // Ignore extensions for now
87        if has_ext {
88            loop {
89                let byte = match self.source.by_ref().bytes().next() {
90                    Some(b) => b?,
91                    None => return Err(IoError::new(ErrorKind::InvalidInput, DecoderError)),
92                };
93                if byte == b'\r' {
94                    break;
95                }
96            }
97        }
98
99        self.read_line_feed()?;
100
101        let chunk_size = String::from_utf8(chunk_size_bytes)
102            .ok()
103            .and_then(|c| usize::from_str_radix(c.trim(), 16).ok())
104            .ok_or_else(|| IoError::new(ErrorKind::InvalidInput, DecoderError))?;
105
106        Ok(chunk_size)
107    }
108
109    fn read_carriage_return(&mut self) -> IoResult<()> {
110        match self.source.by_ref().bytes().next() {
111            Some(Ok(b'\r')) => Ok(()),
112            _ => Err(IoError::new(ErrorKind::InvalidInput, DecoderError)),
113        }
114    }
115
116    fn read_line_feed(&mut self) -> IoResult<()> {
117        match self.source.by_ref().bytes().next() {
118            Some(Ok(b'\n')) => Ok(()),
119            _ => Err(IoError::new(ErrorKind::InvalidInput, DecoderError)),
120        }
121    }
122}
123
124impl<R> Read for Decoder<R>
125where
126    R: Read,
127{
128    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
129        let remaining_chunks_size = match self.remaining_chunks_size {
130            Some(c) => c,
131            None => {
132                // first possibility: we are not in a chunk, so we'll attempt to determine
133                // the chunks size
134                let chunk_size = self.read_chunk_size()?;
135
136                // if the chunk size is 0, we are at EOF
137                if chunk_size == 0 {
138                    self.read_carriage_return()?;
139                    self.read_line_feed()?;
140                    return Ok(0);
141                }
142
143                chunk_size
144            }
145        };
146
147        // second possibility: we continue reading from a chunk
148        if buf.len() < remaining_chunks_size {
149            let read = self.source.read(buf)?;
150            self.remaining_chunks_size = Some(remaining_chunks_size - read);
151            return Ok(read);
152        }
153
154        // third possibility: the read request goes further than the current chunk
155        // we simply read until the end of the chunk and return
156        assert!(buf.len() >= remaining_chunks_size);
157
158        let buf = &mut buf[..remaining_chunks_size];
159        let read = self.source.read(buf)?;
160
161        self.remaining_chunks_size = if read == remaining_chunks_size {
162            self.read_carriage_return()?;
163            self.read_line_feed()?;
164            None
165        } else {
166            Some(remaining_chunks_size - read)
167        };
168
169        Ok(read)
170    }
171}
172
/// Error payload carried by the `io::Error`s this module builds for malformed chunked input.
#[derive(Debug, Copy, Clone)]
struct DecoderError;

impl DecoderError {
    /// Message shared by `Display` and `Error::description`, so the two cannot drift apart.
    const MESSAGE: &'static str = "Error while decoding chunks";
}

impl fmt::Display for DecoderError {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        fmt.write_str(Self::MESSAGE)
    }
}

impl Error for DecoderError {
    fn description(&self) -> &str {
        Self::MESSAGE
    }
}
187
#[cfg(test)]
mod test {
    use super::Decoder;
    use std::io;
    use std::io::Read;

    /// Decodes everything `source` yields into a `String`, returning the first error.
    fn decode_all<R: Read>(source: R) -> io::Result<String> {
        let mut decoded = String::new();
        Decoder::new(source).read_to_string(&mut decoded)?;
        Ok(decoded)
    }

    /// Chunk-size parsing cases adapted from Hyper
    /// (https://github.com/hyperium/hyper, Copyright (c) 2014 Sean McArthur).
    #[test]
    fn test_read_chunk_size() {
        let accepted: &[(&str, usize)] = &[
            ("1\r\n", 1),
            ("01\r\n", 1),
            ("0\r\n", 0),
            ("00\r\n", 0),
            ("A\r\n", 10),
            ("a\r\n", 10),
            ("Ff\r\n", 255),
            ("Ff   \r\n", 255),
            // Acceptable (if not fully valid) extensions do not influence the size.
            ("1;extension\r\n", 1),
            ("a;ext name=value\r\n", 10),
            ("1;extension;extension2\r\n", 1),
            ("1;;;  ;\r\n", 1),
            ("2; extension...\r\n", 2),
            ("3   ; extension=123\r\n", 3),
            ("3   ;\r\n", 3),
            ("3   ;   \r\n", 3),
        ];
        for &(input, expected) in accepted {
            let actual = Decoder::new(input.as_bytes()).read_chunk_size().unwrap();
            assert_eq!(expected, actual, "chunk size line {:?}", input);
        }

        let rejected = [
            // Missing LF or CRLF.
            "F\rF",
            "F",
            // Invalid hex digit.
            "X\r\n",
            "1X\r\n",
            "-\r\n",
            "-1\r\n",
            // Invalid extensions cause an error.
            "1 invalid extension\r\n",
            "1 A\r\n",
            "1;no CRLF",
        ];
        for input in rejected.iter() {
            let kind = Decoder::new(input.as_bytes())
                .read_chunk_size()
                .unwrap_err()
                .kind();
            assert_eq!(kind, io::ErrorKind::InvalidInput, "chunk size line {:?}", input);
        }
    }

    #[test]
    fn test_valid_chunk_decode() {
        let source = io::Cursor::new(b"3\r\nhel\r\nb\r\nlo world!!!\r\n0\r\n\r\n".to_vec());
        assert_eq!(decode_all(source).unwrap(), "hello world!!!");
    }

    #[test]
    fn test_decode_zero_length() {
        assert_eq!(decode_all(b"0\r\n\r\n" as &[u8]).unwrap(), "");
    }

    #[test]
    fn test_decode_invalid_chunk_length() {
        assert!(decode_all(b"m\r\n\r\n" as &[u8]).is_err());
    }

    #[test]
    fn invalid_input1() {
        // Declared size (2) disagrees with the chunk data ("hel").
        let source = io::Cursor::new(b"2\r\nhel\r\nb\r\nlo world!!!\r\n0\r\n".to_vec());
        assert!(decode_all(source).is_err());
    }

    #[test]
    fn invalid_input2() {
        // Chunk-size line terminated by a bare CR.
        let source = io::Cursor::new(b"3\rhel\r\nb\r\nlo world!!!\r\n0\r\n".to_vec());
        assert!(decode_all(source).is_err());
    }
}