1use ascii::AsciiString;
2
3use std::io::Error as IoError;
4use std::io::Result as IoResult;
5use std::io::{BufReader, BufWriter, ErrorKind, Read};
6
7use std::net::SocketAddr;
8use std::str::FromStr;
9
10use crate::common::{HTTPVersion, Method};
11use crate::util::RefinedTcpStream;
12use crate::util::{SequentialReader, SequentialReaderBuilder, SequentialWriterBuilder};
13use crate::Request;
14
/// One client socket, exposed as an [`Iterator`] of [`Request`]s.
///
/// Requests are parsed one after another from the read half of the socket;
/// every parsed request receives its own writer from `sink` and its own body
/// reader taken from `source`.
///
/// NOTE(review): field order also determines drop order of the socket-owning
/// builders/readers; keep it unless their `Drop` behaviour has been checked.
pub struct ClientConnection {
    /// Result of `peer_addr()` on the read socket, captured at construction.
    remote_addr: IoResult<Option<SocketAddr>>,

    /// Produces the successive readers over the buffered read half of the socket.
    source: SequentialReaderBuilder<BufReader<RefinedTcpStream>>,

    /// Produces the successive writers over the buffered write half of the socket.
    sink: SequentialWriterBuilder<BufWriter<RefinedTcpStream>>,

    /// Reader from which the next request line and headers are parsed.
    next_header_source: SequentialReader<BufReader<RefinedTcpStream>>,

    /// Once `true`, the iterator yields no further requests.
    no_more_requests: bool,

    /// Value of `secure()` reported by the read socket at construction.
    secure: bool,
}
38
39#[derive(Debug)]
41enum ReadError {
42 WrongRequestLine,
43 WrongHeader(HTTPVersion),
44 ExpectationFailed(HTTPVersion),
46 ReadIoError(IoError),
47}
48
49impl ClientConnection {
50 pub fn new(
52 write_socket: RefinedTcpStream,
53 mut read_socket: RefinedTcpStream,
54 ) -> ClientConnection {
55 let remote_addr = read_socket.peer_addr();
56 let secure = read_socket.secure();
57
58 let mut source = SequentialReaderBuilder::new(BufReader::with_capacity(1024, read_socket));
59 let first_header = source.next().unwrap();
60
61 ClientConnection {
62 source,
63 sink: SequentialWriterBuilder::new(BufWriter::with_capacity(1024, write_socket)),
64 remote_addr,
65 next_header_source: first_header,
66 no_more_requests: false,
67 secure,
68 }
69 }
70
71 pub fn secure(&self) -> bool {
73 self.secure
74 }
75
76 fn read_next_line(&mut self) -> IoResult<AsciiString> {
81 let mut buf = Vec::new();
82 let mut prev_byte_was_cr = false;
83
84 loop {
85 let byte = self.next_header_source.by_ref().bytes().next();
86
87 let byte = match byte {
88 Some(b) => b?,
89 None => return Err(IoError::new(ErrorKind::ConnectionAborted, "Unexpected EOF")),
90 };
91
92 if byte == b'\n' && prev_byte_was_cr {
93 buf.pop(); return AsciiString::from_ascii(buf)
95 .map_err(|_| IoError::new(ErrorKind::InvalidInput, "Header is not in ASCII"));
96 }
97
98 prev_byte_was_cr = byte == b'\r';
99
100 buf.push(byte);
101 }
102 }
103
104 fn read(&mut self) -> Result<Request, ReadError> {
107 let (method, path, version, headers) = {
108 let (method, path, version) = {
110 let line = self.read_next_line().map_err(ReadError::ReadIoError)?;
111
112 parse_request_line(
113 line.as_str().trim(), )?
115 };
116
117 let headers = {
119 let mut headers = Vec::new();
120 loop {
121 let line = self.read_next_line().map_err(ReadError::ReadIoError)?;
122
123 if line.is_empty() {
124 break;
125 };
126 headers.push(match FromStr::from_str(line.as_str().trim()) {
127 Ok(h) => h,
129 _ => return Err(ReadError::WrongHeader(version)),
130 });
131 }
132
133 headers
134 };
135
136 (method, path, version, headers)
137 };
138
139 let writer = self.sink.next().unwrap();
141
142 let mut data_source = self.source.next().unwrap();
144 std::mem::swap(&mut self.next_header_source, &mut data_source);
145
146 let request = crate::request::new_request(
148 self.secure,
149 method,
150 path,
151 version.clone(),
152 headers,
153 *self.remote_addr.as_ref().unwrap(),
154 data_source,
155 writer,
156 )
157 .map_err(|e| {
158 use crate::request;
159 match e {
160 request::RequestCreationError::CreationIoError(e) => ReadError::ReadIoError(e),
161 request::RequestCreationError::ExpectationFailed => {
162 ReadError::ExpectationFailed(version)
163 }
164 }
165 })?;
166
167 Ok(request)
169 }
170}
171
172impl Iterator for ClientConnection {
173 type Item = Request;
174
175 fn next(&mut self) -> Option<Request> {
178 use crate::{Response, StatusCode};
179
180 if self.no_more_requests {
183 return None;
184 }
185
186 loop {
187 let rq = match self.read() {
188 Err(ReadError::WrongRequestLine) => {
189 let writer = self.sink.next().unwrap();
190 let response = Response::new_empty(StatusCode(400));
191 response
192 .raw_print(writer, HTTPVersion(1, 1), &[], false, None)
193 .ok();
194 return None; }
197
198 Err(ReadError::WrongHeader(ver)) => {
199 let writer = self.sink.next().unwrap();
200 let response = Response::new_empty(StatusCode(400));
201 response.raw_print(writer, ver, &[], false, None).ok();
202 return None; }
205
206 Err(ReadError::ReadIoError(ref err)) if err.kind() == ErrorKind::TimedOut => {
207 let writer = self.sink.next().unwrap();
209 let response = Response::new_empty(StatusCode(408));
210 response
211 .raw_print(writer, HTTPVersion(1, 1), &[], false, None)
212 .ok();
213 return None; }
215
216 Err(ReadError::ExpectationFailed(ver)) => {
217 let writer = self.sink.next().unwrap();
218 let response = Response::new_empty(StatusCode(417));
219 response.raw_print(writer, ver, &[], true, None).ok();
220 return None; }
222
223 Err(ReadError::ReadIoError(_)) => return None,
224
225 Ok(rq) => rq,
226 };
227
228 if *rq.http_version() > (1, 1) {
230 let writer = self.sink.next().unwrap();
231 let response = Response::from_string(
232 "This server only supports HTTP versions 1.0 and 1.1".to_owned(),
233 )
234 .with_status_code(StatusCode(505));
235 response
236 .raw_print(writer, HTTPVersion(1, 1), &[], false, None)
237 .ok();
238 continue;
239 }
240
241 let connection_header = rq
243 .headers()
244 .iter()
245 .find(|h| h.field.equiv("Connection"))
246 .map(|h| h.value.as_str());
247
248 let lowercase = connection_header.map(|h| h.to_ascii_lowercase());
249
250 match lowercase {
251 Some(ref val) if val.contains("close") => self.no_more_requests = true,
252 Some(ref val) if val.contains("upgrade") => self.no_more_requests = true,
253 Some(ref val)
254 if !val.contains("keep-alive") && *rq.http_version() == HTTPVersion(1, 0) =>
255 {
256 self.no_more_requests = true
257 }
258 None if *rq.http_version() == HTTPVersion(1, 0) => self.no_more_requests = true,
259 _ => (),
260 };
261
262 return Some(rq);
264 }
265 }
266}
267
268fn parse_http_version(version: &str) -> Result<HTTPVersion, ReadError> {
270 let (major, minor) = match version {
271 "HTTP/0.9" => (0, 9),
272 "HTTP/1.0" => (1, 0),
273 "HTTP/1.1" => (1, 1),
274 "HTTP/2.0" => (2, 0),
275 "HTTP/3.0" => (3, 0),
276 _ => return Err(ReadError::WrongRequestLine),
277 };
278
279 Ok(HTTPVersion(major, minor))
280}
281
282fn parse_request_line(line: &str) -> Result<(Method, String, HTTPVersion), ReadError> {
285 let mut parts = line.split(' ');
286
287 let method = parts.next().and_then(|w| w.parse().ok());
288 let path = parts.next().map(ToOwned::to_owned);
289 let version = parts.next().and_then(|w| parse_http_version(w).ok());
290
291 method
292 .and_then(|method| Some((method, path?, version?)))
293 .ok_or(ReadError::WrongRequestLine)
294}
295
#[cfg(test)]
mod test {
    use super::parse_request_line;
    use crate::common::HTTPVersion;
    use crate::Method;

    /// A well-formed request line parses; truncated or garbage lines do not.
    #[test]
    fn test_parse_request_line() {
        let parsed = parse_request_line("GET /hello HTTP/1.1").unwrap();
        let (method, path, ver) = parsed;

        assert!(method == Method::Get);
        assert!(path == "/hello");
        assert!(ver == HTTPVersion(1, 1));

        for malformed in ["GET /hello", "qsd qsd qsd"].iter() {
            assert!(parse_request_line(malformed).is_err());
        }
    }
}