use crate::{Error, Request, Response};
/// Chained processing of request (and response).
///
/// # Middleware as `fn`
///
/// The middleware trait is implemented for all functions that have the signature
///
/// `Fn(Request, MiddlewareNext) -> Result<Response, Error>`
///
/// That means the easiest way to implement middleware is by providing a `fn`, like so
///
/// ```no_run
/// # use ureq::{Request, Response, MiddlewareNext, Error};
/// fn my_middleware(req: Request, next: MiddlewareNext) -> Result<Response, Error> {
///     // do middleware things
///
///     // continue the middleware chain
///     next.handle(req)
/// }
/// ```
///
/// # Adding headers
///
/// A common use case is to add headers to the outgoing request. Here is an example of how to do that.
///
/// ```no_run
/// # #[cfg(feature = "json")]
/// # fn main() -> Result<(), ureq::Error> {
/// # use ureq::{Request, Response, MiddlewareNext, Error};
/// # ureq::is_test(true);
/// fn my_middleware(req: Request, next: MiddlewareNext) -> Result<Response, Error> {
///     // set my bespoke header and continue the chain
///     next.handle(req.set("X-My-Header", "value_42"))
/// }
///
/// let agent = ureq::builder()
///     .middleware(my_middleware)
///     .build();
///
/// let result: serde_json::Value =
///     agent.get("http://httpbin.org/headers").call()?.into_json()?;
///
/// assert_eq!(&result["headers"]["X-My-Header"], "value_42");
///
/// # Ok(()) }
/// # #[cfg(not(feature = "json"))]
/// # fn main() {}
/// ```
///
/// # State
///
/// To maintain state between middleware invocations, we need to do something more elaborate than
/// a simple `fn` and implement the `Middleware` trait directly.
///
/// ## Example with mutex lock
///
/// The `examples` directory contains an additional example, `count-bytes.rs`, which uses
/// a mutex lock as shown below.
///
/// ```no_run
/// # use ureq::{Request, Response, Middleware, MiddlewareNext, Error};
/// # use std::sync::{Arc, Mutex};
/// struct MyState {
///     // whatever is needed
/// }
///
/// struct MyMiddleware(Arc<Mutex<MyState>>);
///
/// impl Middleware for MyMiddleware {
///     fn handle(&self, request: Request, next: MiddlewareNext) -> Result<Response, Error> {
///         // These extra braces ensure we release the Mutex lock before continuing the
///         // chain. There could also be scenarios where we want to maintain the lock through
///         // the invocation, which would block other requests from proceeding concurrently
///         // through the middleware.
///         {
///             let mut state = self.0.lock().unwrap();
///             // do stuff with state
///         }
///
///         // continue middleware chain
///         next.handle(request)
///     }
/// }
/// ```
///
/// ## Example with atomic
///
/// This example shows how we can increase a counter for each request going
/// through the agent.
///
/// ```no_run
/// # fn main() -> Result<(), ureq::Error> {
/// # ureq::is_test(true);
/// use ureq::{Request, Response, Middleware, MiddlewareNext, Error};
/// use std::sync::atomic::{AtomicU64, Ordering};
/// use std::sync::Arc;
///
/// // Middleware that stores a counter state. This example uses an AtomicU64
/// // since the middleware is potentially shared by multiple threads running
/// // requests at the same time.
/// struct MyCounter(Arc<AtomicU64>);
///
/// impl Middleware for MyCounter {
///     fn handle(&self, req: Request, next: MiddlewareNext) -> Result<Response, Error> {
///         // increase the counter for each invocation
///         self.0.fetch_add(1, Ordering::SeqCst);
///
///         // continue the middleware chain
///         next.handle(req)
///     }
/// }
///
/// let shared_counter = Arc::new(AtomicU64::new(0));
///
/// let agent = ureq::builder()
///     // Add our middleware
///     .middleware(MyCounter(shared_counter.clone()))
///     .build();
///
/// agent.get("http://httpbin.org/get").call()?;
/// agent.get("http://httpbin.org/get").call()?;
///
/// // Check we did indeed increase the counter twice.
/// assert_eq!(shared_counter.load(Ordering::SeqCst), 2);
///
/// # Ok(()) }
/// ```
pub trait Middleware: Send + Sync + 'static {
    /// Handle the middleware logic.
    fn handle(&self, request: Request, next: MiddlewareNext) -> Result<Response, Error>;
}
/// Continuation of a [`Middleware`] chain.
pub struct MiddlewareNext<'a> {
    pub(crate) chain: &'a mut (dyn Iterator<Item = &'a dyn Middleware>),
    // Since request_fn consumes the Payload<'a>, we must have an FnOnce.
    //
    // It's possible to get rid of this Box if we make MiddlewareNext generic
    // over some type variable, i.e. MiddlewareNext<'a, R> where R: FnOnce...
    // However, that would "leak" into Middleware::handle, introducing a complicated
    // type signature that is irrelevant to someone implementing a middleware
    // (a sketch of that alternative follows after this struct, for illustration).
    //
    // So in the name of having a sane external API, we accept this Box.
    pub(crate) request_fn: Box<dyn FnOnce(Request) -> Result<Response, Error> + 'a>,
}
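// For illustration only -- not part of the API. The boxless, generic alternative
// mentioned in the comment above would look roughly like the sketch below; the extra
// type parameter `R` would then also have to surface in `Middleware::handle`, which is
// why the Box is preferred:
//
//   pub struct MiddlewareNext<'a, R>
//   where
//       R: FnOnce(Request) -> Result<Response, Error>,
//   {
//       chain: &'a mut (dyn Iterator<Item = &'a dyn Middleware>),
//       request_fn: R,
//   }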
impl<'a> MiddlewareNext<'a> {
    /// Continue the middleware chain by providing (a possibly amended) [`Request`].
    pub fn handle(self, request: Request) -> Result<Response, Error> {
        if let Some(step) = self.chain.next() {
            // Hand the request to the next middleware in the chain.
            step.handle(request, self)
        } else {
            // End of the chain: run the actual request.
            (self.request_fn)(request)
        }
    }
}
impl<F> Middleware for F
where
    F: Fn(Request, MiddlewareNext) -> Result<Response, Error> + Send + Sync + 'static,
{
    fn handle(&self, request: Request, next: MiddlewareNext) -> Result<Response, Error> {
        (self)(request, next)
    }
}