ureq/
middleware.rs

1use crate::{Error, Request, Response};
2
3/// Chained processing of request (and response).
4///
5/// # Middleware as `fn`
6///
7/// The middleware trait is implemented for all functions that have the signature
8///
9/// `Fn(Request, MiddlewareNext) -> Result<Response, Error>`
10///
11/// That means the easiest way to implement middleware is by providing a `fn`, like so
12///
13/// ```no_run
14/// # use ureq::{Request, Response, MiddlewareNext, Error};
15/// fn my_middleware(req: Request, next: MiddlewareNext) -> Result<Response, Error> {
16///     // do middleware things
17///
18///     // continue the middleware chain
19///     next.handle(req)
20/// }
21/// ```
22///
23/// # Adding headers
24///
/// A common use case is to add headers to the outgoing request. Here is an example of how.
26///
27/// ```no_run
28/// # #[cfg(feature = "json")]
29/// # fn main() -> Result<(), ureq::Error> {
30/// # use ureq::{Request, Response, MiddlewareNext, Error};
31/// # ureq::is_test(true);
32/// fn my_middleware(req: Request, next: MiddlewareNext) -> Result<Response, Error> {
33///     // set my bespoke header and continue the chain
34///     next.handle(req.set("X-My-Header", "value_42"))
35/// }
36///
37/// let agent = ureq::builder()
38///     .middleware(my_middleware)
39///     .build();
40///
41/// let result: serde_json::Value =
42///     agent.get("http://httpbin.org/headers").call()?.into_json()?;
43///
44/// assert_eq!(&result["headers"]["X-My-Header"], "value_42");
45///
46/// # Ok(()) }
47/// # #[cfg(not(feature = "json"))]
48/// # fn main() {}
49/// ```
50///
51/// # State
52///
53/// To maintain state between middleware invocations, we need to do something more elaborate than
54/// the simple `fn` and implement the `Middleware` trait directly.
55///
56/// ## Example with mutex lock
57///
/// In the `examples` directory there is an additional example, `count-bytes.rs`, which uses
/// a mutex lock as shown below.
60///
61/// ```no_run
62/// # use ureq::{Request, Response, Middleware, MiddlewareNext, Error};
63/// # use std::sync::{Arc, Mutex};
64/// struct MyState {
65///     // whatever is needed
66/// }
67///
68/// struct MyMiddleware(Arc<Mutex<MyState>>);
69///
70/// impl Middleware for MyMiddleware {
71///     fn handle(&self, request: Request, next: MiddlewareNext) -> Result<Response, Error> {
///         // These extra braces ensure we release the Mutex lock before continuing the
73///         // chain. There could also be scenarios where we want to maintain the lock through
74///         // the invocation, which would block other requests from proceeding concurrently
75///         // through the middleware.
76///         {
77///             let mut state = self.0.lock().unwrap();
78///             // do stuff with state
79///         }
80///
81///         // continue middleware chain
82///         next.handle(request)
83///     }
84/// }
85/// ```
86///
87/// ## Example with atomic
88///
89/// This example shows how we can increase a counter for each request going
90/// through the agent.
91///
92/// ```no_run
93/// # fn main() -> Result<(), ureq::Error> {
94/// # ureq::is_test(true);
95/// use ureq::{Request, Response, Middleware, MiddlewareNext, Error};
96/// use std::sync::atomic::{AtomicU64, Ordering};
97/// use std::sync::Arc;
98///
99/// // Middleware that stores a counter state. This example uses an AtomicU64
100/// // since the middleware is potentially shared by multiple threads running
101/// // requests at the same time.
102/// struct MyCounter(Arc<AtomicU64>);
103///
104/// impl Middleware for MyCounter {
105///     fn handle(&self, req: Request, next: MiddlewareNext) -> Result<Response, Error> {
106///         // increase the counter for each invocation
107///         self.0.fetch_add(1, Ordering::SeqCst);
108///
109///         // continue the middleware chain
110///         next.handle(req)
111///     }
112/// }
113///
114/// let shared_counter = Arc::new(AtomicU64::new(0));
115///
116/// let agent = ureq::builder()
117///     // Add our middleware
118///     .middleware(MyCounter(shared_counter.clone()))
119///     .build();
120///
121/// agent.get("http://httpbin.org/get").call()?;
122/// agent.get("http://httpbin.org/get").call()?;
123///
124/// // Check we did indeed increase the counter twice.
125/// assert_eq!(shared_counter.load(Ordering::SeqCst), 2);
126///
127/// # Ok(()) }
128/// ```
pub trait Middleware: Send + Sync + 'static {
    /// Handles one request as a step in the middleware chain.
    ///
    /// `request` is the (possibly already amended) request and `next` is the continuation
    /// of the chain. An implementation typically does its work and then calls
    /// `next.handle(request)` to pass the request on, returning the resulting
    /// `Result<Response, Error>`.
    fn handle(&self, request: Request, next: MiddlewareNext) -> Result<Response, Error>;
}
133
134/// Continuation of a [`Middleware`] chain.
135pub struct MiddlewareNext<'a> {
136    pub(crate) chain: &'a mut (dyn Iterator<Item = &'a dyn Middleware>),
137    // Since request_fn consumes the Payload<'a>, we must have an FnOnce.
138    //
139    // It's possible to get rid of this Box if we make MiddlewareNext generic
140    // over some type variable, i.e. MiddlewareNext<'a, R> where R: FnOnce...
141    // however that would "leak" to Middleware::handle introducing a complicated
142    // type signature that is totally irrelevant for someone implementing a middleware.
143    //
144    // So in the name of having a sane external API, we accept this Box.
145    pub(crate) request_fn: Box<dyn FnOnce(Request) -> Result<Response, Error> + 'a>,
146}
147
148impl<'a> MiddlewareNext<'a> {
149    /// Continue the middleware chain by providing (a possibly amended) [`Request`].
150    pub fn handle(self, request: Request) -> Result<Response, Error> {
151        if let Some(step) = self.chain.next() {
152            step.handle(request, self)
153        } else {
154            (self.request_fn)(request)
155        }
156    }
157}
158
159impl<F> Middleware for F
160where
161    F: Fn(Request, MiddlewareNext) -> Result<Response, Error> + Send + Sync + 'static,
162{
163    fn handle(&self, request: Request, next: MiddlewareNext) -> Result<Response, Error> {
164        (self)(request, next)
165    }
166}