// tiny_http/util/equal_reader.rs

1use std::io::Read;
2use std::io::Result as IoResult;
3use std::sync::mpsc::channel;
4use std::sync::mpsc::{Receiver, Sender};
5
/// Wraps a [`Read`] so that at most `size` bytes can be consumed through it.
///
/// Once `size` bytes have been read through the wrapper, further reads
/// report end-of-file (`Ok(0)`). If the wrapper is dropped before the whole
/// budget was consumed, its destructor reads the outstanding bytes from the
/// inner reader and throws them away. Unless the inner reader hits EOF or an
/// error first, the inner reader is then positioned just past the
/// `size`-byte region.
pub struct EqualReader<R: Read> {
    /// Underlying source of bytes.
    reader: R,
    /// Number of bytes that may still be read before EOF is reported.
    size: usize,
    /// Sending half of the channel handed out by `new`; the destructor may
    /// use it to report how discarding the leftover bytes went.
    last_read_signal: Sender<IoResult<()>>,
}

20impl<R> EqualReader<R>
21where
22    R: Read,
23{
24    pub fn new(reader: R, size: usize) -> (EqualReader<R>, Receiver<IoResult<()>>) {
25        let (tx, rx) = channel();
26
27        let r = EqualReader {
28            reader,
29            size,
30            last_read_signal: tx,
31        };
32
33        (r, rx)
34    }
35}
36
37impl<R> Read for EqualReader<R>
38where
39    R: Read,
40{
41    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
42        if self.size == 0 {
43            return Ok(0);
44        }
45
46        let buf = if buf.len() < self.size {
47            buf
48        } else {
49            &mut buf[..self.size]
50        };
51
52        match self.reader.read(buf) {
53            Ok(len) => {
54                self.size -= len;
55                Ok(len)
56            }
57            err @ Err(_) => err,
58        }
59    }
60}
61
62impl<R> Drop for EqualReader<R>
63where
64    R: Read,
65{
66    fn drop(&mut self) {
67        let mut remaining_to_read = self.size;
68
69        while remaining_to_read > 0 {
70            let mut buf = vec![0; remaining_to_read];
71
72            match self.reader.read(&mut buf) {
73                Err(e) => {
74                    self.last_read_signal.send(Err(e)).ok();
75                    break;
76                }
77                Ok(0) => {
78                    self.last_read_signal.send(Ok(())).ok();
79                    break;
80                }
81                Ok(other) => {
82                    remaining_to_read -= other;
83                }
84            }
85        }
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::EqualReader;
    use std::io::{self, Cursor, Read};

    /// A reader whose every read fails, used to exercise the error signal.
    struct FailingReader;

    impl Read for FailingReader {
        fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
            Err(io::Error::new(io::ErrorKind::Other, "boom"))
        }
    }

    #[test]
    fn test_limit() {
        let mut org_reader = Cursor::new("hello world".to_string().into_bytes());

        {
            let (mut equal_reader, _) = EqualReader::new(org_reader.by_ref(), 5);

            let mut string = String::new();
            equal_reader.read_to_string(&mut string).unwrap();
            assert_eq!(string, "hello");
        }

        let mut string = String::new();
        org_reader.read_to_string(&mut string).unwrap();
        assert_eq!(string, " world");
    }

    #[test]
    fn test_not_enough() {
        let mut org_reader = Cursor::new("hello world".to_string().into_bytes());

        {
            let (mut equal_reader, _) = EqualReader::new(org_reader.by_ref(), 5);

            let mut vec = [0];
            equal_reader.read_exact(&mut vec).unwrap();
            assert_eq!(vec[0], b'h');
        }

        let mut string = String::new();
        org_reader.read_to_string(&mut string).unwrap();
        assert_eq!(string, " world");
    }

    /// Running out of input before the budget is used up is reported as `Ok`.
    #[test]
    fn test_premature_eof_signals_ok() {
        let mut org_reader = Cursor::new(b"hello".to_vec());

        let (equal_reader, signal) = EqualReader::new(org_reader.by_ref(), 10);
        drop(equal_reader);

        assert!(signal.recv().unwrap().is_ok());
        assert_eq!(org_reader.position(), 5);
    }

    /// A failing inner reader during the drop-time drain is reported as `Err`.
    #[test]
    fn test_read_error_is_signalled() {
        let (equal_reader, signal) = EqualReader::new(FailingReader, 3);
        drop(equal_reader);

        let err = signal.recv().unwrap().unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::Other);
    }
}
131}