tiny_http/util/
equal_reader.rs
1use std::io::Read;
2use std::io::Result as IoResult;
3use std::sync::mpsc::channel;
4use std::sync::mpsc::{Receiver, Sender};
5
/// Adapter that exposes at most `size` bytes of an inner [`Read`]er.
///
/// Reads are capped so they never run past the window. Once the window is
/// exhausted, further reads return `Ok(0)`. When the adapter is dropped, any
/// part of the window that was not consumed is read from the inner reader and
/// discarded, unless that drain stops early on an error or EOF.
pub struct EqualReader<R: Read> {
    /// Underlying source the window is carved out of.
    reader: R,
    /// Number of bytes of the window that have not been read yet.
    size: usize,
    /// Sending half of the channel on which the drop-time drain reports how it ended.
    last_read_signal: Sender<IoResult<()>>,
}
19
20impl<R> EqualReader<R>
21where
22 R: Read,
23{
24 pub fn new(reader: R, size: usize) -> (EqualReader<R>, Receiver<IoResult<()>>) {
25 let (tx, rx) = channel();
26
27 let r = EqualReader {
28 reader,
29 size,
30 last_read_signal: tx,
31 };
32
33 (r, rx)
34 }
35}
36
37impl<R> Read for EqualReader<R>
38where
39 R: Read,
40{
41 fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
42 if self.size == 0 {
43 return Ok(0);
44 }
45
46 let buf = if buf.len() < self.size {
47 buf
48 } else {
49 &mut buf[..self.size]
50 };
51
52 match self.reader.read(buf) {
53 Ok(len) => {
54 self.size -= len;
55 Ok(len)
56 }
57 err @ Err(_) => err,
58 }
59 }
60}
61
62impl<R> Drop for EqualReader<R>
63where
64 R: Read,
65{
66 fn drop(&mut self) {
67 let mut remaining_to_read = self.size;
68
69 while remaining_to_read > 0 {
70 let mut buf = vec![0; remaining_to_read];
71
72 match self.reader.read(&mut buf) {
73 Err(e) => {
74 self.last_read_signal.send(Err(e)).ok();
75 break;
76 }
77 Ok(0) => {
78 self.last_read_signal.send(Ok(())).ok();
79 break;
80 }
81 Ok(other) => {
82 remaining_to_read -= other;
83 }
84 }
85 }
86 }
87}
88
#[cfg(test)]
mod tests {
    use super::EqualReader;
    use std::io::{self, Cursor, ErrorKind, Read};

    /// Reading through the adapter stops at the window boundary, leaving the
    /// rest of the inner reader untouched.
    #[test]
    fn test_limit() {
        let mut org_reader = Cursor::new("hello world".to_string().into_bytes());

        {
            let (mut equal_reader, _) = EqualReader::new(org_reader.by_ref(), 5);

            let mut string = String::new();
            equal_reader.read_to_string(&mut string).unwrap();
            assert_eq!(string, "hello");
        }

        let mut string = String::new();
        org_reader.read_to_string(&mut string).unwrap();
        assert_eq!(string, " world");
    }

    /// Dropping the adapter after a partial read drains the unread part of the window.
    #[test]
    fn test_not_enough() {
        let mut org_reader = Cursor::new("hello world".to_string().into_bytes());

        {
            let (mut equal_reader, _) = EqualReader::new(org_reader.by_ref(), 5);

            let mut vec = [0];
            equal_reader.read_exact(&mut vec).unwrap();
            assert_eq!(vec[0], b'h');
        }

        let mut string = String::new();
        org_reader.read_to_string(&mut string).unwrap();
        assert_eq!(string, " world");
    }

    /// A fully consumed window leaves nothing to drain, so no outcome is sent
    /// and the channel simply disconnects.
    #[test]
    fn test_fully_consumed_sends_nothing() {
        let mut org_reader = Cursor::new(b"hello world".to_vec());

        let rx = {
            let (mut equal_reader, rx) = EqualReader::new(org_reader.by_ref(), 5);
            let mut string = String::new();
            equal_reader.read_to_string(&mut string).unwrap();
            assert_eq!(string, "hello");
            rx
        };

        assert!(rx.try_recv().is_err());
    }

    /// If the inner reader runs dry before the window is exhausted, the drain
    /// reports `Ok(())`.
    #[test]
    fn test_premature_eof_signals_ok() {
        let (equal_reader, rx) = EqualReader::new(Cursor::new(b"abc".to_vec()), 10);
        drop(equal_reader);
        assert!(rx.recv().unwrap().is_ok());
    }

    /// An I/O error hit while draining is forwarded on the channel.
    #[test]
    fn test_drain_error_is_forwarded() {
        struct Failing;

        impl Read for Failing {
            fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
                Err(io::Error::new(ErrorKind::Other, "boom"))
            }
        }

        let (equal_reader, rx) = EqualReader::new(Failing, 4);
        drop(equal_reader);
        let err = rx.recv().unwrap().unwrap_err();
        assert_eq!(err.kind(), ErrorKind::Other);
    }
}