tiny_http/
client.rs

1use ascii::AsciiString;
2
3use std::io::Error as IoError;
4use std::io::Result as IoResult;
5use std::io::{BufReader, BufWriter, ErrorKind, Read};
6
7use std::net::SocketAddr;
8use std::str::FromStr;
9
10use crate::common::{HTTPVersion, Method};
11use crate::util::RefinedTcpStream;
12use crate::util::{SequentialReader, SequentialReaderBuilder, SequentialWriterBuilder};
13use crate::Request;
14
15/// A ClientConnection is an object that will store a socket to a client
16/// and return Request objects.
17pub struct ClientConnection {
18    // address of the client
19    remote_addr: IoResult<Option<SocketAddr>>,
20
21    // sequence of Readers to the stream, so that the data is not read in
22    //  the wrong order
23    source: SequentialReaderBuilder<BufReader<RefinedTcpStream>>,
24
25    // sequence of Writers to the stream, to avoid writing response #2 before
26    //  response #1
27    sink: SequentialWriterBuilder<BufWriter<RefinedTcpStream>>,
28
29    // Reader to read the next header from
30    next_header_source: SequentialReader<BufReader<RefinedTcpStream>>,
31
32    // set to true if we know that the previous request is the last one
33    no_more_requests: bool,
34
35    // true if the connection goes through SSL
36    secure: bool,
37}
38
39/// Error that can happen when reading a request.
40#[derive(Debug)]
41enum ReadError {
42    WrongRequestLine,
43    WrongHeader(HTTPVersion),
44    /// the client sent an unrecognized `Expect` header
45    ExpectationFailed(HTTPVersion),
46    ReadIoError(IoError),
47}
48
49impl ClientConnection {
50    /// Creates a new `ClientConnection` that takes ownership of the `TcpStream`.
51    pub fn new(
52        write_socket: RefinedTcpStream,
53        mut read_socket: RefinedTcpStream,
54    ) -> ClientConnection {
55        let remote_addr = read_socket.peer_addr();
56        let secure = read_socket.secure();
57
58        let mut source = SequentialReaderBuilder::new(BufReader::with_capacity(1024, read_socket));
59        let first_header = source.next().unwrap();
60
61        ClientConnection {
62            source,
63            sink: SequentialWriterBuilder::new(BufWriter::with_capacity(1024, write_socket)),
64            remote_addr,
65            next_header_source: first_header,
66            no_more_requests: false,
67            secure,
68        }
69    }
70
71    /// true if the connection is HTTPS
72    pub fn secure(&self) -> bool {
73        self.secure
74    }
75
76    /// Reads the next line from self.next_header_source.
77    ///
78    /// Reads until `CRLF` is reached. The next read will start
79    ///  at the first byte of the new line.
80    fn read_next_line(&mut self) -> IoResult<AsciiString> {
81        let mut buf = Vec::new();
82        let mut prev_byte_was_cr = false;
83
84        loop {
85            let byte = self.next_header_source.by_ref().bytes().next();
86
87            let byte = match byte {
88                Some(b) => b?,
89                None => return Err(IoError::new(ErrorKind::ConnectionAborted, "Unexpected EOF")),
90            };
91
92            if byte == b'\n' && prev_byte_was_cr {
93                buf.pop(); // removing the '\r'
94                return AsciiString::from_ascii(buf)
95                    .map_err(|_| IoError::new(ErrorKind::InvalidInput, "Header is not in ASCII"));
96            }
97
98            prev_byte_was_cr = byte == b'\r';
99
100            buf.push(byte);
101        }
102    }
103
104    /// Reads a request from the stream.
105    /// Blocks until the header has been read.
106    fn read(&mut self) -> Result<Request, ReadError> {
107        let (method, path, version, headers) = {
108            // reading the request line
109            let (method, path, version) = {
110                let line = self.read_next_line().map_err(ReadError::ReadIoError)?;
111
112                parse_request_line(
113                    line.as_str().trim(), // TODO: remove this conversion
114                )?
115            };
116
117            // getting all headers
118            let headers = {
119                let mut headers = Vec::new();
120                loop {
121                    let line = self.read_next_line().map_err(ReadError::ReadIoError)?;
122
123                    if line.is_empty() {
124                        break;
125                    };
126                    headers.push(match FromStr::from_str(line.as_str().trim()) {
127                        // TODO: remove this conversion
128                        Ok(h) => h,
129                        _ => return Err(ReadError::WrongHeader(version)),
130                    });
131                }
132
133                headers
134            };
135
136            (method, path, version, headers)
137        };
138
139        // building the writer for the request
140        let writer = self.sink.next().unwrap();
141
142        // follow-up for next potential request
143        let mut data_source = self.source.next().unwrap();
144        std::mem::swap(&mut self.next_header_source, &mut data_source);
145
146        // building the next reader
147        let request = crate::request::new_request(
148            self.secure,
149            method,
150            path,
151            version.clone(),
152            headers,
153            *self.remote_addr.as_ref().unwrap(),
154            data_source,
155            writer,
156        )
157        .map_err(|e| {
158            use crate::request;
159            match e {
160                request::RequestCreationError::CreationIoError(e) => ReadError::ReadIoError(e),
161                request::RequestCreationError::ExpectationFailed => {
162                    ReadError::ExpectationFailed(version)
163                }
164            }
165        })?;
166
167        // return the request
168        Ok(request)
169    }
170}
171
172impl Iterator for ClientConnection {
173    type Item = Request;
174
175    /// Blocks until the next Request is available.
176    /// Returns None when no new Requests will come from the client.
177    fn next(&mut self) -> Option<Request> {
178        use crate::{Response, StatusCode};
179
180        // the client sent a "connection: close" header in this previous request
181        //  or is using HTTP 1.0, meaning that no new request will come
182        if self.no_more_requests {
183            return None;
184        }
185
186        loop {
187            let rq = match self.read() {
188                Err(ReadError::WrongRequestLine) => {
189                    let writer = self.sink.next().unwrap();
190                    let response = Response::new_empty(StatusCode(400));
191                    response
192                        .raw_print(writer, HTTPVersion(1, 1), &[], false, None)
193                        .ok();
194                    return None; // we don't know where the next request would start,
195                                 // se we have to close
196                }
197
198                Err(ReadError::WrongHeader(ver)) => {
199                    let writer = self.sink.next().unwrap();
200                    let response = Response::new_empty(StatusCode(400));
201                    response.raw_print(writer, ver, &[], false, None).ok();
202                    return None; // we don't know where the next request would start,
203                                 // se we have to close
204                }
205
206                Err(ReadError::ReadIoError(ref err)) if err.kind() == ErrorKind::TimedOut => {
207                    // request timeout
208                    let writer = self.sink.next().unwrap();
209                    let response = Response::new_empty(StatusCode(408));
210                    response
211                        .raw_print(writer, HTTPVersion(1, 1), &[], false, None)
212                        .ok();
213                    return None; // closing the connection
214                }
215
216                Err(ReadError::ExpectationFailed(ver)) => {
217                    let writer = self.sink.next().unwrap();
218                    let response = Response::new_empty(StatusCode(417));
219                    response.raw_print(writer, ver, &[], true, None).ok();
220                    return None; // TODO: should be recoverable, but needs handling in case of body
221                }
222
223                Err(ReadError::ReadIoError(_)) => return None,
224
225                Ok(rq) => rq,
226            };
227
228            // checking HTTP version
229            if *rq.http_version() > (1, 1) {
230                let writer = self.sink.next().unwrap();
231                let response = Response::from_string(
232                    "This server only supports HTTP versions 1.0 and 1.1".to_owned(),
233                )
234                .with_status_code(StatusCode(505));
235                response
236                    .raw_print(writer, HTTPVersion(1, 1), &[], false, None)
237                    .ok();
238                continue;
239            }
240
241            // updating the status of the connection
242            let connection_header = rq
243                .headers()
244                .iter()
245                .find(|h| h.field.equiv("Connection"))
246                .map(|h| h.value.as_str());
247
248            let lowercase = connection_header.map(|h| h.to_ascii_lowercase());
249
250            match lowercase {
251                Some(ref val) if val.contains("close") => self.no_more_requests = true,
252                Some(ref val) if val.contains("upgrade") => self.no_more_requests = true,
253                Some(ref val)
254                    if !val.contains("keep-alive") && *rq.http_version() == HTTPVersion(1, 0) =>
255                {
256                    self.no_more_requests = true
257                }
258                None if *rq.http_version() == HTTPVersion(1, 0) => self.no_more_requests = true,
259                _ => (),
260            };
261
262            // returning the request
263            return Some(rq);
264        }
265    }
266}
267
268/// Parses a "HTTP/1.1" string.
269fn parse_http_version(version: &str) -> Result<HTTPVersion, ReadError> {
270    let (major, minor) = match version {
271        "HTTP/0.9" => (0, 9),
272        "HTTP/1.0" => (1, 0),
273        "HTTP/1.1" => (1, 1),
274        "HTTP/2.0" => (2, 0),
275        "HTTP/3.0" => (3, 0),
276        _ => return Err(ReadError::WrongRequestLine),
277    };
278
279    Ok(HTTPVersion(major, minor))
280}
281
282/// Parses the request line of the request.
283/// eg. GET / HTTP/1.1
284fn parse_request_line(line: &str) -> Result<(Method, String, HTTPVersion), ReadError> {
285    let mut parts = line.split(' ');
286
287    let method = parts.next().and_then(|w| w.parse().ok());
288    let path = parts.next().map(ToOwned::to_owned);
289    let version = parts.next().and_then(|w| parse_http_version(w).ok());
290
291    method
292        .and_then(|method| Some((method, path?, version?)))
293        .ok_or(ReadError::WrongRequestLine)
294}
295
#[cfg(test)]
mod test {
    use crate::common::HTTPVersion;

    /// Well-formed HTTP/1.1 request line, plus lines missing a component.
    #[test]
    fn test_parse_request_line() {
        let (method, path, ver) = super::parse_request_line("GET /hello HTTP/1.1").unwrap();

        assert!(method == crate::Method::Get);
        assert!(path == "/hello");
        assert!(ver == HTTPVersion(1, 1));

        assert!(super::parse_request_line("GET /hello").is_err());
        assert!(super::parse_request_line("qsd qsd qsd").is_err());
    }

    /// HTTP/1.0 request lines are accepted with their own version.
    #[test]
    fn test_parse_request_line_http_1_0() {
        let (method, path, ver) = super::parse_request_line("POST /submit HTTP/1.0").unwrap();

        assert!(method == crate::Method::Post);
        assert!(path == "/submit");
        assert!(ver == HTTPVersion(1, 0));
    }

    /// Unknown versions and empty lines are rejected.
    #[test]
    fn test_parse_request_line_rejects_invalid() {
        assert!(super::parse_request_line("GET / HTTP/1.2").is_err());
        assert!(super::parse_request_line("").is_err());
    }

    /// Every recognised version token maps to its numeric pair; others fail.
    #[test]
    fn test_parse_http_version() {
        assert!(super::parse_http_version("HTTP/0.9").unwrap() == HTTPVersion(0, 9));
        assert!(super::parse_http_version("HTTP/1.0").unwrap() == HTTPVersion(1, 0));
        assert!(super::parse_http_version("HTTP/1.1").unwrap() == HTTPVersion(1, 1));
        assert!(super::parse_http_version("HTTP/2.0").unwrap() == HTTPVersion(2, 0));
        assert!(super::parse_http_version("HTTP/3.0").unwrap() == HTTPVersion(3, 0));

        assert!(super::parse_http_version("HTTP/1.2").is_err());
        assert!(super::parse_http_version("http/1.1").is_err());
        assert!(super::parse_http_version("").is_err());
    }
}