use crate::{
    body::{boxed, BoxBody},
    request::SanitizeHeaders,
    Status,
};
use bytes::Bytes;
use pin_project::pin_project;
use std::{
    fmt,
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};
use tower_layer::Layer;
use tower_service::Service;
/// A hook that inspects, and optionally rewrites, a request before the wrapped service sees it.
///
/// When driven by [`InterceptedService`], the interceptor receives only the request's
/// metadata and extensions; the message itself is replaced by `()`. Returning
/// `Err(status)` rejects the request: the inner service is not called and `status` is
/// turned into the response instead.
pub trait Interceptor {
    /// Inspects `request` and either hands it back (possibly modified) to continue
    /// processing, or returns a [`Status`] to reject it.
    fn call(&mut self, request: crate::Request<()>) -> Result<crate::Request<()>, Status>;
}

/// Any `FnMut(Request<()>) -> Result<Request<()>, Status>` function or closure can be
/// used directly as an [`Interceptor`].
impl<F> Interceptor for F
where
    F: FnMut(crate::Request<()>) -> Result<crate::Request<()>, Status>,
{
    fn call(&mut self, req: crate::Request<()>) -> Result<crate::Request<()>, Status> {
        (*self)(req)
    }
}
/// Creates an [`InterceptorLayer`] that wraps every service it is applied to with the
/// interceptor `f`.
pub fn interceptor<F: Interceptor>(f: F) -> InterceptorLayer<F> {
    InterceptorLayer { f }
}
/// A [`Layer`] that wraps each service it is applied to in an [`InterceptedService`].
///
/// Each wrapped service receives its own clone of the interceptor.
#[derive(Debug, Clone, Copy)]
pub struct InterceptorLayer<F> {
    f: F,
}

impl<S, F: Interceptor + Clone> Layer<S> for InterceptorLayer<F> {
    type Service = InterceptedService<S, F>;

    fn layer(&self, service: S) -> Self::Service {
        let interceptor = self.f.clone();
        InterceptedService::new(service, interceptor)
    }
}
/// A service wrapper that runs an [`Interceptor`] on each request before delegating to
/// the inner service.
#[derive(Clone, Copy)]
pub struct InterceptedService<S, F> {
    inner: S,
    f: F,
}

impl<S, F> InterceptedService<S, F> {
    /// Wraps `service` so that `f` is invoked on every request before it is forwarded.
    pub fn new(service: S, f: F) -> Self
    where
        F: Interceptor,
    {
        Self { f, inner: service }
    }
}

// `F` is not required to implement `Debug` (closures generally don't), so the
// interceptor is rendered by its type name only.
impl<S: fmt::Debug, F> fmt::Debug for InterceptedService<S, F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let interceptor_type = std::any::type_name::<F>();
        f.debug_struct("InterceptedService")
            .field("inner", &self.inner)
            .field("f", &format_args!("{}", interceptor_type))
            .finish()
    }
}
impl<S, F, ReqBody, ResBody> Service<http::Request<ReqBody>> for InterceptedService<S, F>
where
    ResBody: Default + http_body::Body<Data = Bytes> + Send + 'static,
    F: Interceptor,
    S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>>,
    S::Error: Into<crate::Error>,
    ResBody: http_body::Body<Data = bytes::Bytes> + Send + 'static,
    ResBody::Error: Into<crate::Error>,
{
    type Response = http::Response<BoxBody>;
    type Error = S::Error;
    type Future = ResponseFuture<S::Future>;
    #[inline]
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }
    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
        let uri = req.uri().clone();
        let method = req.method().clone();
        let version = req.version();
        let req = crate::Request::from_http(req);
        let (metadata, extensions, msg) = req.into_parts();
        match self
            .f
            .call(crate::Request::from_parts(metadata, extensions, ()))
        {
            Ok(req) => {
                let (metadata, extensions, _) = req.into_parts();
                let req = crate::Request::from_parts(metadata, extensions, msg);
                let req = req.into_http(uri, method, version, SanitizeHeaders::No);
                ResponseFuture::future(self.inner.call(req))
            }
            Err(status) => ResponseFuture::status(status),
        }
    }
}
/// Reports the wrapped service's name, so the wrapper is registered under the same name
/// as the service it intercepts.
impl<S: crate::server::NamedService, F> crate::server::NamedService for InterceptedService<S, F> {
    const NAME: &'static str = S::NAME;
}
/// Future returned by [`InterceptedService`].
///
/// It either drives the inner service's future, or, when the interceptor rejected the
/// request, resolves immediately to the rejection [`Status`] encoded as an HTTP response.
#[pin_project]
#[derive(Debug)]
pub struct ResponseFuture<F> {
    #[pin]
    kind: Kind<F>,
}

impl<F> ResponseFuture<F> {
    /// The request was accepted; wait on the inner service's response future.
    fn future(future: F) -> Self {
        let kind = Kind::Future(future);
        Self { kind }
    }

    /// The request was rejected by the interceptor with `status`.
    fn status(status: Status) -> Self {
        let kind = Kind::Status(Some(status));
        Self { kind }
    }
}

/// Internal state of a [`ResponseFuture`].
#[pin_project(project = KindProj)]
#[derive(Debug)]
enum Kind<F> {
    /// Pending response from the inner service.
    Future(#[pin] F),
    /// Rejection from the interceptor. It is taken (left as `None`) when the future resolves.
    Status(Option<Status>),
}
/// Resolves to the inner service's response (with its body boxed), or to the
/// interceptor's rejection status converted into an HTTP response.
///
/// # Panics
///
/// Panics if polled again after resolving from the rejected-status state, because the
/// `Status` has already been taken out.
impl<F, E, B> Future for ResponseFuture<F>
where
    F: Future<Output = Result<http::Response<B>, E>>,
    E: Into<crate::Error>,
    B: Default + http_body::Body<Data = Bytes> + Send + 'static,
    B::Error: Into<crate::Error>,
{
    type Output = Result<http::Response<BoxBody>, E>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.project().kind.project() {
            KindProj::Future(future) => future
                .poll(cx)
                .map(|result| result.map(|res| res.map(boxed))),
            KindProj::Status(status) => {
                // The status is moved out on the first poll; a second poll is a caller bug.
                let status = status
                    .take()
                    .expect("ResponseFuture polled after completion");
                // Keep the status's HTTP parts but give it an empty `B` body so both
                // branches produce the same boxed body type.
                let response = status.to_http().map(|_| B::default()).map(boxed);
                Poll::Ready(Ok(response))
            }
        }
    }
}
#[cfg(test)]
mod tests {
    // NOTE(review): items from `super` (e.g. `InterceptedService`, `Status`, `Bytes`) are
    // used below, so this `allow` may be unnecessary — confirm before removing.
    #[allow(unused_imports)]
    use super::*;
    use http::header::HeaderMap;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };
    use tower::ServiceExt;

    /// Minimal body used as both request and response body in these tests.
    /// It yields no data frames and no trailers.
    #[derive(Debug, Default)]
    struct TestBody;
    impl http_body::Body for TestBody {
        type Data = Bytes;
        type Error = Status;
        fn poll_data(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
            Poll::Ready(None)
        }
        fn poll_trailers(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
            Poll::Ready(Ok(None))
        }
    }

    /// A header set on the incoming HTTP request must be visible both to the
    /// interceptor (as metadata) and to the inner service (as an HTTP header).
    #[tokio::test]
    async fn doesnt_remove_headers_from_requests() {
        let svc = tower::service_fn(|request: http::Request<TestBody>| async move {
            assert_eq!(
                request
                    .headers()
                    .get("user-agent")
                    .expect("missing in leaf service"),
                "test-tonic"
            );
            Ok::<_, Status>(http::Response::new(TestBody))
        });
        let svc = InterceptedService::new(svc, |request: crate::Request<()>| {
            assert_eq!(
                request
                    .metadata()
                    .get("user-agent")
                    .expect("missing in interceptor"),
                "test-tonic"
            );
            Ok(request)
        });
        let request = http::Request::builder()
            .header("user-agent", "test-tonic")
            .body(TestBody)
            .unwrap();
        svc.oneshot(request).await.unwrap();
    }

    /// When the interceptor returns `Err(status)`, the service resolves to `Ok` with a
    /// response whose status line, version and headers match `status.to_http()`.
    #[tokio::test]
    async fn handles_intercepted_status_as_response() {
        let message = "Blocked by the interceptor";
        let expected = Status::permission_denied(message).to_http();
        let svc = tower::service_fn(|_: http::Request<TestBody>| async {
            Ok::<_, Status>(http::Response::new(TestBody))
        });
        let svc = InterceptedService::new(svc, |_: crate::Request<()>| {
            Err(Status::permission_denied(message))
        });
        let request = http::Request::builder().body(TestBody).unwrap();
        let response = svc.oneshot(request).await.unwrap();
        assert_eq!(expected.status(), response.status());
        assert_eq!(expected.version(), response.version());
        assert_eq!(expected.headers(), response.headers());
    }

    /// Round-tripping through the interceptor must preserve the HTTP method.
    /// `Ok` itself is used as a pass-through interceptor.
    #[tokio::test]
    async fn doesnt_change_http_method() {
        let svc = tower::service_fn(|request: http::Request<hyper::Body>| async move {
            assert_eq!(request.method(), http::Method::OPTIONS);
            Ok::<_, hyper::Error>(hyper::Response::new(hyper::Body::empty()))
        });
        let svc = InterceptedService::new(svc, Ok);
        let request = http::Request::builder()
            .method(http::Method::OPTIONS)
            .body(hyper::Body::empty())
            .unwrap();
        svc.oneshot(request).await.unwrap();
    }
}