ureq/chunked/
decoder.rs
1use std::error::Error;
7use std::fmt;
8use std::io::Error as IoError;
9use std::io::ErrorKind;
10use std::io::Read;
11use std::io::Result as IoResult;
12
/// Decodes a `Transfer-Encoding: chunked` byte stream pulled from `source`.
///
/// Reading from the decoder yields the chunk payloads back to back, with the framing
/// (chunk-size lines, chunk extensions and the CRLF after each chunk) stripped.
#[derive(Debug)]
pub struct Decoder<R> {
    /// The reader the encoded stream is pulled from.
    source: R,

    /// Payload bytes still to be read from the current chunk, or `None` when the next
    /// thing expected from `source` is a chunk-size line.
    remaining_chunks_size: Option<usize>,
}
37
38impl<R> Decoder<R>
39where
40 R: Read,
41{
42 pub fn new(source: R) -> Decoder<R> {
43 Decoder {
44 source,
45 remaining_chunks_size: None,
46 }
47 }
48
49 pub fn into_inner(self) -> R {
51 self.source
52 }
53
54 fn read_chunk_size(&mut self) -> IoResult<usize> {
55 let mut chunk_size_bytes = Vec::new();
56 let mut has_ext = false;
57
58 loop {
59 let byte = match self.source.by_ref().bytes().next() {
60 Some(b) => b?,
61 None => return Err(IoError::new(ErrorKind::InvalidInput, DecoderError)),
62 };
63
64 if byte == b'\r' {
65 break;
66 }
67
68 if byte == b';' {
69 has_ext = true;
70 break;
71 }
72
73 chunk_size_bytes.push(byte);
74 }
75
76 if has_ext {
78 loop {
79 let byte = match self.source.by_ref().bytes().next() {
80 Some(b) => b?,
81 None => return Err(IoError::new(ErrorKind::InvalidInput, DecoderError)),
82 };
83 if byte == b'\r' {
84 break;
85 }
86 }
87 }
88
89 self.read_line_feed()?;
90
91 let chunk_size = String::from_utf8(chunk_size_bytes)
92 .ok()
93 .and_then(|c| usize::from_str_radix(c.trim(), 16).ok())
94 .ok_or_else(|| IoError::new(ErrorKind::InvalidInput, DecoderError))?;
95
96 Ok(chunk_size)
97 }
98
99 fn read_carriage_return(&mut self) -> IoResult<()> {
100 match self.source.by_ref().bytes().next() {
101 Some(Ok(b'\r')) => Ok(()),
102 _ => Err(IoError::new(ErrorKind::InvalidInput, DecoderError)),
103 }
104 }
105
106 fn read_line_feed(&mut self) -> IoResult<()> {
107 match self.source.by_ref().bytes().next() {
108 Some(Ok(b'\n')) => Ok(()),
109 _ => Err(IoError::new(ErrorKind::InvalidInput, DecoderError)),
110 }
111 }
112
113 fn read_end(&mut self) -> IoResult<()> {
115 fn expect_or_end(
116 bytes: &mut impl Iterator<Item = IoResult<u8>>,
117 expected: u8,
118 ) -> IoResult<()> {
119 match bytes.next() {
120 Some(Ok(c)) => {
121 if c == expected {
122 Ok(())
123 } else {
124 Err(IoError::new(ErrorKind::InvalidInput, DecoderError))
125 }
126 }
127 Some(Err(e)) => {
128 match e.kind() {
129 ErrorKind::ConnectionReset | ErrorKind::ConnectionAborted => Ok(()),
131 _ => Err(IoError::new(ErrorKind::InvalidInput, DecoderError)),
132 }
133 }
134 None => Ok(()), }
136 }
137
138 let mut bytes = self.source.by_ref().bytes();
139
140 expect_or_end(&mut bytes, b'\r')?;
141 expect_or_end(&mut bytes, b'\n')?;
142
143 Ok(())
144 }
145}
146
147impl<R> Read for Decoder<R>
148where
149 R: Read,
150{
151 fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
152 let remaining_chunks_size = match self.remaining_chunks_size {
153 Some(c) => c,
154 None => {
155 let chunk_size = self.read_chunk_size()?;
158
159 if chunk_size == 0 {
161 self.read_end()?;
162 return Ok(0);
163 }
164
165 chunk_size
166 }
167 };
168
169 if buf.len() < remaining_chunks_size {
171 let read = self.source.read(buf)?;
172 self.remaining_chunks_size = Some(remaining_chunks_size - read);
173 return Ok(read);
174 }
175
176 assert!(buf.len() >= remaining_chunks_size);
179
180 let buf = &mut buf[..remaining_chunks_size];
181 let read = self.source.read(buf)?;
182
183 self.remaining_chunks_size = if read == remaining_chunks_size {
184 self.read_carriage_return()?;
185 self.read_line_feed()?;
186 None
187 } else {
188 Some(remaining_chunks_size - read)
189 };
190
191 Ok(read)
192 }
193}
194
/// Payload of the `io::Error`s this module reports for malformed chunked framing.
#[derive(Debug, Copy, Clone)]
struct DecoderError;

impl fmt::Display for DecoderError {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        fmt.write_str("Error while decoding chunks")
    }
}

// `Error::description` is deprecated; the human-readable message comes from `Display`.
impl Error for DecoderError {}
209
#[cfg(test)]
mod test {
    use super::Decoder;
    use std::io;
    use std::io::Read;

    /// Runs `input` through a fresh decoder and returns the whole decoded body.
    fn decode_all(input: &[u8]) -> io::Result<String> {
        let mut body = String::new();
        Decoder::new(input).read_to_string(&mut body).map(|_| body)
    }

    #[test]
    fn test_read_chunk_size() {
        // (size line, expected size) pairs that must parse successfully.
        let accepted: &[(&str, usize)] = &[
            ("1\r\n", 1),
            ("01\r\n", 1),
            ("0\r\n", 0),
            ("00\r\n", 0),
            ("A\r\n", 10),
            ("a\r\n", 10),
            ("Ff\r\n", 255),
            ("Ff \r\n", 255),
            ("1;extension\r\n", 1),
            ("a;ext name=value\r\n", 10),
            ("1;extension;extension2\r\n", 1),
            ("1;;; ;\r\n", 1),
            ("2; extension...\r\n", 2),
            ("3 ; extension=123\r\n", 3),
            ("3 ;\r\n", 3),
            ("3 ; \r\n", 3),
        ];
        for &(line, expected) in accepted {
            let size = Decoder::new(line.as_bytes()).read_chunk_size().unwrap();
            assert_eq!(size, expected, "input {:?}", line);
        }

        // Size lines that must be rejected as invalid input.
        let rejected = [
            "F\rF",
            "F",
            "X\r\n",
            "1X\r\n",
            "-\r\n",
            "-1\r\n",
            "1 invalid extension\r\n",
            "1 A\r\n",
            "1;no CRLF",
        ];
        for line in &rejected {
            let kind = Decoder::new(line.as_bytes())
                .read_chunk_size()
                .unwrap_err()
                .kind();
            assert_eq!(kind, io::ErrorKind::InvalidInput, "input {:?}", line);
        }
    }

    #[test]
    fn test_valid_chunk_decode() {
        let body = decode_all(b"3\r\nhel\r\nb\r\nlo world!!!\r\n0\r\n\r\n").unwrap();
        assert_eq!(body, "hello world!!!");
    }

    #[test]
    fn test_decode_zero_length() {
        assert_eq!(decode_all(b"0\r\n\r\n").unwrap(), "");
    }

    #[test]
    fn test_decode_invalid_chunk_length() {
        assert!(decode_all(b"m\r\n\r\n").is_err());
    }

    #[test]
    fn invalid_input1() {
        // Declared size 2 but the chunk carries 3 bytes before its CRLF.
        assert!(decode_all(b"2\r\nhel\r\nb\r\nlo world!!!\r\n0\r\n").is_err());
    }

    #[test]
    fn invalid_input2() {
        // Size line terminated by a bare CR.
        assert!(decode_all(b"3\rhel\r\nb\r\nlo world!!!\r\n0\r\n").is_err());
    }

    #[test]
    fn test_decode_end_missing_last_crlf() {
        // The final CRLF after the zero-size chunk is missing; decoding still succeeds.
        let body = decode_all(b"3\r\nhel\r\nb\r\nlo world!!!\r\n0\r\n").unwrap();
        assert_eq!(body, "hello world!!!");
    }
}