// tiny_http/util/refined_tcp_stream.rs

1use std::io::Result as IoResult;
2use std::io::{Read, Write};
3use std::net::{Shutdown, SocketAddr};
4
5use crate::connection::Connection;
6#[cfg(any(feature = "ssl-openssl", feature = "ssl-rustls"))]
7use crate::ssl::SslStream;
8
9pub(crate) enum Stream {
10    Http(Connection),
11    #[cfg(any(feature = "ssl-openssl", feature = "ssl-rustls"))]
12    Https(SslStream),
13}
14
15impl Clone for Stream {
16    fn clone(&self) -> Self {
17        match self {
18            Stream::Http(tcp_stream) => Stream::Http(tcp_stream.try_clone().unwrap()),
19            #[cfg(any(feature = "ssl-openssl", feature = "ssl-rustls"))]
20            Stream::Https(ssl_stream) => Stream::Https(ssl_stream.clone()),
21        }
22    }
23}
24
25impl From<Connection> for Stream {
26    fn from(tcp_stream: Connection) -> Self {
27        Stream::Http(tcp_stream)
28    }
29}
30
31impl Stream {
32    fn secure(&self) -> bool {
33        match self {
34            Stream::Http(_) => false,
35            #[cfg(any(feature = "ssl-openssl", feature = "ssl-rustls"))]
36            Stream::Https(_) => true,
37        }
38    }
39
40    fn peer_addr(&mut self) -> IoResult<Option<SocketAddr>> {
41        match self {
42            Stream::Http(tcp_stream) => tcp_stream.peer_addr(),
43            #[cfg(any(feature = "ssl-openssl", feature = "ssl-rustls"))]
44            Stream::Https(ssl_stream) => ssl_stream.peer_addr(),
45        }
46    }
47
48    fn shutdown(&mut self, how: Shutdown) -> IoResult<()> {
49        match self {
50            Stream::Http(tcp_stream) => tcp_stream.shutdown(how),
51            #[cfg(any(feature = "ssl-openssl", feature = "ssl-rustls"))]
52            Stream::Https(ssl_stream) => ssl_stream.shutdown(how),
53        }
54    }
55}
56
57impl Read for Stream {
58    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
59        match self {
60            Stream::Http(tcp_stream) => tcp_stream.read(buf),
61            #[cfg(any(feature = "ssl-openssl", feature = "ssl-rustls"))]
62            Stream::Https(ssl_stream) => ssl_stream.read(buf),
63        }
64    }
65}
66
67impl Write for Stream {
68    fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
69        match self {
70            Stream::Http(tcp_stream) => tcp_stream.write(buf),
71            #[cfg(any(feature = "ssl-openssl", feature = "ssl-rustls"))]
72            Stream::Https(ssl_stream) => ssl_stream.write(buf),
73        }
74    }
75
76    fn flush(&mut self) -> IoResult<()> {
77        match self {
78            Stream::Http(tcp_stream) => tcp_stream.flush(),
79            #[cfg(any(feature = "ssl-openssl", feature = "ssl-rustls"))]
80            Stream::Https(ssl_stream) => ssl_stream.flush(),
81        }
82    }
83}
84
85pub struct RefinedTcpStream {
86    stream: Stream,
87    close_read: bool,
88    close_write: bool,
89}
90
91impl RefinedTcpStream {
92    pub(crate) fn new<S>(stream: S) -> (RefinedTcpStream, RefinedTcpStream)
93    where
94        S: Into<Stream>,
95    {
96        let stream: Stream = stream.into();
97
98        let (read, write) = (stream.clone(), stream);
99
100        let read = RefinedTcpStream {
101            stream: read,
102            close_read: true,
103            close_write: false,
104        };
105
106        let write = RefinedTcpStream {
107            stream: write,
108            close_read: false,
109            close_write: true,
110        };
111
112        (read, write)
113    }
114
115    /// Returns true if this struct wraps around a secure connection.
116    #[inline]
117    pub(crate) fn secure(&self) -> bool {
118        self.stream.secure()
119    }
120
121    pub(crate) fn peer_addr(&mut self) -> IoResult<Option<SocketAddr>> {
122        self.stream.peer_addr()
123    }
124}
125
126impl Drop for RefinedTcpStream {
127    fn drop(&mut self) {
128        if self.close_read {
129            self.stream.shutdown(Shutdown::Read).ok();
130        }
131
132        if self.close_write {
133            self.stream.shutdown(Shutdown::Write).ok();
134        }
135    }
136}
137
138impl Read for RefinedTcpStream {
139    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
140        self.stream.read(buf)
141    }
142}
143
144impl Write for RefinedTcpStream {
145    fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
146        self.stream.write(buf)
147    }
148
149    fn flush(&mut self) -> IoResult<()> {
150        self.stream.flush()
151    }
152}