ureq/chunked/
decoder.rs

1// Copyright 2015 The tiny-http Contributors
2// Copyright 2015 The rust-chunked-transfer Contributors
3// Forked into ureq, 2024, from https://github.com/frewsxcv/rust-chunked-transfer
4// Forked under dual MIT and Apache 2.0 license (see adjacent LICENSE-MIT and LICENSE-APACHE file)
5
6use std::error::Error;
7use std::fmt;
8use std::io::Error as IoError;
9use std::io::ErrorKind;
10use std::io::Read;
11use std::io::Result as IoResult;
12
/// Reads HTTP chunks and sends back real data.
///
/// `Decoder` wraps a reader positioned at the start of a body sent with
/// `Transfer-Encoding: chunked` and implements [`Read`], yielding only the payload bytes.
/// Chunk extensions are skipped. Trailer fields after the last chunk are not supported.
///
/// # Errors
///
/// Reads fail with [`ErrorKind::InvalidInput`] when the framing is malformed, with
/// [`ErrorKind::UnexpectedEof`] when the source ends in the middle of a chunk body, and
/// with the source's own error when the underlying reader fails.
///
/// # Example
///
/// ```ignore
/// use chunked_transfer::Decoder;
/// use std::io::Read;
///
/// let encoded = b"3\r\nhel\r\nb\r\nlo world!!!\r\n0\r\n\r\n";
/// let mut decoded = String::new();
///
/// let mut decoder = Decoder::new(encoded as &[u8]);
/// decoder.read_to_string(&mut decoded).unwrap();
///
/// assert_eq!(decoded, "hello world!!!");
/// ```
pub struct Decoder<R> {
    // where the chunks come from
    source: R,

    // where we currently are in the chunked framing
    state: DecoderState,
}

/// Position of a [`Decoder`] within the chunked framing.
#[derive(Debug, Clone, Copy)]
enum DecoderState {
    /// Between chunks: the next bytes of the source form a chunk-size line.
    ChunkHeader,
    /// Inside a chunk body with this many payload bytes still to read (never zero).
    ChunkBody(usize),
    /// The terminating zero-sized chunk has been consumed; every further read is EOF.
    Finished,
}

/// Builds the error reported for malformed chunk framing.
fn malformed_input() -> IoError {
    IoError::new(ErrorKind::InvalidInput, DecoderError)
}

impl<R> Decoder<R>
where
    R: Read,
{
    /// Creates a decoder that expects `source` to be positioned at the first chunk-size line.
    pub fn new(source: R) -> Decoder<R> {
        Decoder {
            source,
            state: DecoderState::ChunkHeader,
        }
    }

    /// Unwraps the Decoder into its inner `Read` source.
    pub fn into_inner(self) -> R {
        self.source
    }

    /// Reads a single framing byte from the source.
    ///
    /// I/O errors are passed through untouched; end of input is reported as
    /// `InvalidInput` because framing bytes are never optional at the call sites.
    fn next_byte(&mut self) -> IoResult<u8> {
        match self.source.by_ref().bytes().next() {
            Some(byte) => byte,
            None => Err(malformed_input()),
        }
    }

    /// Consumes one byte and checks that it is `expected`.
    fn expect_byte(&mut self, expected: u8) -> IoResult<()> {
        if self.next_byte()? == expected {
            Ok(())
        } else {
            Err(malformed_input())
        }
    }

    /// Parses a chunk-size line (`HEXDIG+ [;extension] CRLF`) and returns the size.
    ///
    /// Spaces and tabs around the hex digits are tolerated. Anything else (a sign, a
    /// non-hex character, digits after whitespace, an empty size, or a value that does
    /// not fit in `usize`) yields `InvalidInput`. The line is parsed byte by byte without
    /// buffering it, so its length cannot be used to exhaust memory.
    fn read_chunk_size(&mut self) -> IoResult<usize> {
        // `None` until the first hex digit has been seen.
        let mut size: Option<usize> = None;
        // Set once whitespace follows the digits; any later digit is an error.
        let mut digits_ended = false;
        // Parse errors are remembered and reported only after the whole line has been
        // consumed, so the source ends up at the same position as for a valid line.
        let mut malformed = false;

        let has_ext = loop {
            match self.next_byte()? {
                b'\r' => break false,
                b';' => break true,
                b' ' | b'\t' => digits_ended = size.is_some(),
                byte => {
                    let digit = (byte as char).to_digit(16).filter(|_| !digits_ended);
                    let next = digit.and_then(|d| {
                        size.unwrap_or(0)
                            .checked_mul(16)?
                            .checked_add(d as usize)
                    });
                    match next {
                        Some(value) => size = Some(value),
                        None => malformed = true,
                    }
                }
            }
        };

        // Ignore extensions for now
        if has_ext {
            while self.next_byte()? != b'\r' {}
        }

        self.read_line_feed()?;

        match size {
            Some(value) if !malformed => Ok(value),
            _ => Err(malformed_input()),
        }
    }

    /// Consumes the `\r` that must follow a chunk body.
    fn read_carriage_return(&mut self) -> IoResult<()> {
        self.expect_byte(b'\r')
    }

    /// Consumes the `\n` that must follow a `\r`.
    fn read_line_feed(&mut self) -> IoResult<()> {
        self.expect_byte(b'\n')
    }

    /// Consumes the CRLF that closes the stream after the zero-sized chunk.
    ///
    /// Sometimes the last \r\n is missing, so end of input and a closed connection
    /// are both accepted in place of either byte. A wrong byte is `InvalidInput`;
    /// any other I/O error is returned as is.
    fn read_end(&mut self) -> IoResult<()> {
        for &expected in b"\r\n" {
            match self.source.by_ref().bytes().next() {
                Some(Ok(byte)) if byte == expected => {}
                Some(Ok(_)) => return Err(malformed_input()),
                // Closed connections are ok.
                Some(Err(e))
                    if matches!(
                        e.kind(),
                        ErrorKind::ConnectionReset | ErrorKind::ConnectionAborted
                    ) => {}
                Some(Err(e)) => return Err(e),
                // End of input is ok.
                None => {}
            }
        }

        Ok(())
    }
}

impl<R> Read for Decoder<R>
where
    R: Read,
{
    /// Reads decoded payload bytes into `buf`.
    ///
    /// Returns `Ok(0)` once the terminating chunk has been consumed, and keeps doing so
    /// on every later call. A single call never reads past the end of the current chunk.
    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
        // Nothing can be delivered, so leave the source untouched. This also keeps the
        // "zero bytes read" check below unambiguous.
        if buf.is_empty() {
            return Ok(0);
        }

        let remaining = match self.state {
            DecoderState::Finished => return Ok(0),
            DecoderState::ChunkBody(remaining) => remaining,
            DecoderState::ChunkHeader => {
                // we are not in a chunk, so determine the size of the next one
                let chunk_size = self.read_chunk_size()?;

                // if the chunk size is 0, we are at EOF
                if chunk_size == 0 {
                    self.read_end()?;
                    self.state = DecoderState::Finished;
                    return Ok(0);
                }

                // Record the position right away: if the body read below fails with a
                // retryable error, the next call must not parse payload as a header.
                self.state = DecoderState::ChunkBody(chunk_size);
                chunk_size
            }
        };

        // Never read beyond the current chunk.
        let wanted = remaining.min(buf.len());
        let read = self.source.read(&mut buf[..wanted])?;

        // The source ended while the chunk still owes us bytes: the body is truncated.
        if read == 0 {
            return Err(IoError::new(ErrorKind::UnexpectedEof, DecoderError));
        }

        let left = remaining.saturating_sub(read);
        if left == 0 {
            // The chunk is complete and must be followed by CRLF.
            self.read_carriage_return()?;
            self.read_line_feed()?;
            self.state = DecoderState::ChunkHeader;
        } else {
            self.state = DecoderState::ChunkBody(left);
        }

        Ok(read)
    }
}

/// Payload of the `std::io::Error`s produced by [`Decoder`] for broken chunk framing.
#[derive(Debug, Copy, Clone)]
struct DecoderError;

impl fmt::Display for DecoderError {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        write!(fmt, "Error while decoding chunks")
    }
}

impl Error for DecoderError {
    // Kept so that callers still using the legacy `description()` API see the same text.
    fn description(&self) -> &str {
        "Error while decoding chunks"
    }
}
209
#[cfg(test)]
mod test {
    use super::Decoder;
    use std::io;
    use std::io::Read;

    /// Runs `encoded` through a fresh `Decoder` and collects everything it yields as text.
    fn decode(encoded: &str) -> io::Result<String> {
        let source = io::Cursor::new(encoded.as_bytes().to_vec());
        let mut text = String::new();
        Decoder::new(source).read_to_string(&mut text)?;
        Ok(text)
    }

    /// This unit test is taken from Hyper
    /// https://github.com/hyperium/hyper
    /// Copyright (c) 2014 Sean McArthur
    #[test]
    fn test_read_chunk_size() {
        // Inputs that must parse, paired with the expected chunk size.
        const ACCEPTED: &[(&str, usize)] = &[
            ("1\r\n", 1),
            ("01\r\n", 1),
            ("0\r\n", 0),
            ("00\r\n", 0),
            ("A\r\n", 10),
            ("a\r\n", 10),
            ("Ff\r\n", 255),
            ("Ff   \r\n", 255),
            // Acceptable (if not fully valid) extensions do not influence the size
            ("1;extension\r\n", 1),
            ("a;ext name=value\r\n", 10),
            ("1;extension;extension2\r\n", 1),
            ("1;;;  ;\r\n", 1),
            ("2; extension...\r\n", 2),
            ("3   ; extension=123\r\n", 3),
            ("3   ;\r\n", 3),
            ("3   ;   \r\n", 3),
        ];

        // Inputs that must be rejected with `InvalidInput`.
        const REJECTED: &[&str] = &[
            // Missing LF or CRLF
            "F\rF",
            "F",
            // Invalid hex digit
            "X\r\n",
            "1X\r\n",
            "-\r\n",
            "-1\r\n",
            // Invalid extensions cause an error
            "1 invalid extension\r\n",
            "1 A\r\n",
            "1;no CRLF",
        ];

        for &(input, expected) in ACCEPTED {
            let mut decoder = Decoder::new(input.as_bytes());
            let actual = decoder.read_chunk_size().unwrap();
            assert_eq!(expected, actual);
        }

        for &input in REJECTED {
            let mut decoder = Decoder::new(input.as_bytes());
            let err_kind = decoder.read_chunk_size().unwrap_err().kind();
            assert_eq!(err_kind, io::ErrorKind::InvalidInput);
        }
    }

    #[test]
    fn test_valid_chunk_decode() {
        let decoded = decode("3\r\nhel\r\nb\r\nlo world!!!\r\n0\r\n\r\n").unwrap();

        assert_eq!(decoded, "hello world!!!");
    }

    #[test]
    fn test_decode_zero_length() {
        let decoded = decode("0\r\n\r\n").unwrap();

        assert_eq!(decoded, "");
    }

    #[test]
    fn test_decode_invalid_chunk_length() {
        assert!(decode("m\r\n\r\n").is_err());
    }

    #[test]
    fn invalid_input1() {
        // The first chunk announces 2 bytes but carries 3.
        assert!(decode("2\r\nhel\r\nb\r\nlo world!!!\r\n0\r\n").is_err());
    }

    #[test]
    fn invalid_input2() {
        // The first chunk-size line ends in a bare CR.
        assert!(decode("3\rhel\r\nb\r\nlo world!!!\r\n0\r\n").is_err());
    }

    #[test]
    fn test_decode_end_missing_last_crlf() {
        // This has been observed in the wild.
        // See https://github.com/algesten/ureq/issues/325

        // Missing last \r\n
        let decoded = decode("3\r\nhel\r\nb\r\nlo world!!!\r\n0\r\n").unwrap();

        assert_eq!(decoded, "hello world!!!");
    }
}