ureq/
testserver.rs

1use std::io;
2use std::net::ToSocketAddrs;
3use std::net::{SocketAddr, TcpListener, TcpStream};
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::Arc;
6use std::thread;
7use std::time::Duration;
8
9use crate::{Agent, AgentBuilder};
10
/// Fallback connection handler used when the `testdeps` feature is disabled.
///
/// It neither reads from nor writes to the connection and always reports
/// success.
#[cfg(not(feature = "testdeps"))]
fn test_server_handler(_stream: TcpStream) -> io::Result<()> {
    // Nothing to serve: release the connection right away.
    drop(_stream);
    Ok(())
}
15
/// Connection handler used when the `testdeps` feature is enabled: serves a
/// single request on `stream` through `hootbin::serve_single`.
///
/// Always returns `Ok(())`. Errors from hootbin are printed rather than
/// propagated, except for an unexpected EOF, which is silently accepted.
#[cfg(feature = "testdeps")]
fn test_server_handler(stream: TcpStream) -> io::Result<()> {
    // hootbin takes separate input and output handles; both refer to the
    // same underlying socket.
    let output = stream.try_clone().expect("TcpStream to be clonable");

    if let Err(err) = hootbin::serve_single(stream, output, "https://hootbin.test/") {
        // The readiness probe made while the server starts up connects and
        // hangs up without sending anything, so it always ends in an
        // unexpected EOF. That is not worth reporting.
        let hung_up = matches!(
            &err,
            hootbin::Error::Io(ioe) if ioe.kind() == io::ErrorKind::UnexpectedEof
        );
        if !hung_up {
            println!("TestServer error: {:?}", err);
        }
    }

    Ok(())
}
36
37// An agent to be installed by default for tests and doctests, such
38// that all hostnames resolve to a TestServer on localhost.
39pub(crate) fn test_agent() -> Agent {
40    #[cfg(test)]
41    let _ = env_logger::try_init();
42
43    let testserver = TestServer::new(test_server_handler);
44    // Slightly tricky thing here: we want to make sure the TestServer lives
45    // as long as the agent. This is accomplished by `move`ing it into the
46    // closure, which becomes owned by the agent.
47    AgentBuilder::new()
48        .resolver(move |h: &str| -> io::Result<Vec<SocketAddr>> {
49            // Don't override resolution for HTTPS requests yet, since we
50            // don't have a setup for an HTTPS testserver. Also, skip localhost
51            // resolutions since those may come from a unittest that set up
52            // its own, specific testserver.
53            if h.ends_with(":443") || h.starts_with("localhost:") {
54                return Ok(h.to_socket_addrs()?.collect::<Vec<_>>());
55            }
56            let addr: SocketAddr = format!("127.0.0.1:{}", testserver.port).parse().unwrap();
57            Ok(vec![addr])
58        })
59        .build()
60}
61
/// Handle to a test HTTP server running on a background thread.
pub struct TestServer {
    /// Local port on 127.0.0.1 that the server's listener is bound to.
    pub port: u16,
    /// Shutdown flag shared with the accept loop; set to `true` when the
    /// server is dropped.
    pub done: Arc<AtomicBool>,
}
66
/// The lines of a request head captured by a test server: the request line
/// first, then one entry per header line. May be empty if nothing was read.
pub struct TestHeaders(Vec<String>);

#[allow(dead_code)]
impl TestHeaders {
    /// Return the path for a request, e.g. `/foo` from `GET /foo HTTP/1.1`.
    ///
    /// Returns `""` when no lines were captured, or when the request line
    /// has no second space-separated token, instead of panicking.
    pub fn path(&self) -> &str {
        self.0
            .first()
            .and_then(|request_line| request_line.split(' ').nth(1))
            .unwrap_or("")
    }

    /// Return the header lines, i.e. everything after the request line.
    ///
    /// Returns an empty slice when no lines were captured, instead of
    /// panicking on an out-of-range slice.
    #[cfg(feature = "cookies")]
    pub fn headers(&self) -> &[String] {
        self.0.get(1..).unwrap_or(&[])
    }
}
85
/// Read a request head from `stream` and discard whatever follows it.
///
/// Lines are collected until the first blank line, end of stream, or read
/// error (errors are logged to stderr). Any remaining readable data is then
/// drained and thrown away.
///
/// Note that the stream is switched to non-blocking mode for the drain and
/// is left in that mode when this function returns.
#[cfg(test)]
pub fn read_request(stream: &TcpStream) -> TestHeaders {
    use std::io::{BufRead, BufReader};

    // Collect the request line and headers, stopping at the blank line that
    // terminates the head.
    let head: Vec<String> = BufReader::new(stream)
        .lines()
        .map_while(|line| match line {
            Ok(text) if text.is_empty() => None,
            Ok(text) => Some(text),
            Err(e) => {
                eprintln!("testserver: in read_request: {}", e);
                None
            }
        })
        .collect();

    // Consume rest of body. TODO maybe capture the body for inspection in the test?
    // The peer may have nothing more to send, in which case a blocking
    // fill_buf() would hang. In non-blocking mode it returns an error
    // instead, which ends the drain.
    stream.set_nonblocking(true).ok();
    let mut body = BufReader::new(stream);
    loop {
        let available = match body.fill_buf() {
            Ok(chunk) => chunk.len(),
            Err(_) => break,
        };
        if available == 0 {
            break;
        }
        body.consume(available);
    }

    TestHeaders(head)
}
116
117impl TestServer {
118    pub fn new(handler: fn(TcpStream) -> io::Result<()>) -> Self {
119        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
120        let port = listener.local_addr().unwrap().port();
121        let done = Arc::new(AtomicBool::new(false));
122        let done_clone = done.clone();
123        thread::spawn(move || {
124            for stream in listener.incoming() {
125                if let Err(e) = stream {
126                    eprintln!("testserver: handling just-accepted stream: {}", e);
127                    break;
128                }
129                if done.load(Ordering::SeqCst) {
130                    break;
131                } else {
132                    thread::spawn(move || handler(stream.unwrap()));
133                }
134            }
135        });
136        // before returning from new(), ensure the server is ready to accept connections
137        while let Err(e) = TcpStream::connect(format!("127.0.0.1:{}", port)) {
138            match e.kind() {
139                io::ErrorKind::ConnectionRefused => {
140                    std::thread::sleep(Duration::from_millis(100));
141                    continue;
142                }
143                _ => eprintln!("testserver: pre-connect with error {}", e),
144            }
145        }
146        TestServer {
147            port,
148            done: done_clone,
149        }
150    }
151}
152
153impl Drop for TestServer {
154    fn drop(&mut self) {
155        self.done.store(true, Ordering::SeqCst);
156        // Connect once to unblock the listen loop.
157        if let Err(e) = TcpStream::connect(format!("localhost:{}", self.port)) {
158            eprintln!("error dropping testserver: {}", e);
159        }
160    }
161}