Make layer generic over Request body type

Nothing in the layer implementation actually depends on the Request's
body type. So generalise over the body type, allowing the service
implementation not longer be tied to axum specifically.
This commit is contained in:
Sjoerd Simons 2024-08-15 21:28:53 +02:00 committed by cduvray
parent d75fec0409
commit ac444f9286
2 changed files with 21 additions and 17 deletions

View file

@ -222,7 +222,7 @@ where
self self
} }
/// Build axum layer /// Build layer
#[deprecated(since = "0.10.0", note = "please use `IntoLayer::into_layer()` instead")] #[deprecated(since = "0.10.0", note = "please use `IntoLayer::into_layer()` instead")]
pub async fn layer(self) -> Result<AuthorizationLayer<C>, InitError> { pub async fn layer(self) -> Result<AuthorizationLayer<C>, InitError> {
let val = self.validation.unwrap_or_default(); let val = self.validation.unwrap_or_default();

View file

@ -1,6 +1,6 @@
use axum::extract::Request;
use futures_core::ready; use futures_core::ready;
use futures_util::future::{self, BoxFuture}; use futures_util::future::{self, BoxFuture};
use http::Request;
use jsonwebtoken::TokenData; use jsonwebtoken::TokenData;
use pin_project::pin_project; use pin_project::pin_project;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
@ -15,25 +15,26 @@ use crate::authorizer::Authorizer;
use crate::AuthError; use crate::AuthError;
/// Trait for authorizing requests. /// Trait for authorizing requests.
pub trait Authorize { pub trait Authorize<B> {
type Future: Future<Output = Result<Request, AuthError>>; type Future: Future<Output = Result<Request<B>, AuthError>>;
/// Authorize the request. /// Authorize the request.
/// ///
/// If the future resolves to `Ok(request)` then the request is allowed through, otherwise not. /// If the future resolves to `Ok(request)` then the request is allowed through, otherwise not.
fn authorize(&self, request: Request) -> Self::Future; fn authorize(&self, request: Request<B>) -> Self::Future;
} }
impl<S, C> Authorize for AuthorizationService<S, C> impl<S, B, C> Authorize<B> for AuthorizationService<S, C>
where where
B: Send + 'static,
C: Clone + DeserializeOwned + Send + Sync + 'static, C: Clone + DeserializeOwned + Send + Sync + 'static,
{ {
type Future = BoxFuture<'static, Result<Request, AuthError>>; type Future = BoxFuture<'static, Result<Request<B>, AuthError>>;
/// The authorizers are sequentially applied (check_auth) until one of them validates the token. /// The authorizers are sequentially applied (check_auth) until one of them validates the token.
/// If no authorizer validates the token the request is rejected. /// If no authorizer validates the token the request is rejected.
/// ///
fn authorize(&self, mut request: Request) -> Self::Future { fn authorize(&self, mut request: Request<B>) -> Self::Future {
let tkns_auths: Vec<(String, Arc<Authorizer<C>>)> = self let tkns_auths: Vec<(String, Arc<Authorizer<C>>)> = self
.auths .auths
.iter() .iter()
@ -154,21 +155,22 @@ where
} }
} }
impl<S, C> Service<Request> for AuthorizationService<S, C> impl<S, C, B> Service<Request<B>> for AuthorizationService<S, C>
where where
S: Service<Request> + Clone, B: Send + 'static,
S: Service<Request<B>> + Clone,
S::Response: From<AuthError>, S::Response: From<AuthError>,
C: Clone + DeserializeOwned + Send + Sync + 'static, C: Clone + DeserializeOwned + Send + Sync + 'static,
{ {
type Response = S::Response; type Response = S::Response;
type Error = S::Error; type Error = S::Error;
type Future = ResponseFuture<S, C>; type Future = ResponseFuture<S, C, B>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx) self.inner.poll_ready(cx)
} }
fn call(&mut self, req: Request) -> Self::Future { fn call(&mut self, req: Request<B>) -> Self::Future {
let inner = self.inner.clone(); let inner = self.inner.clone();
// take the service that was ready // take the service that was ready
let inner = std::mem::replace(&mut self.inner, inner); let inner = std::mem::replace(&mut self.inner, inner);
@ -184,13 +186,14 @@ where
#[pin_project] #[pin_project]
/// Response future for [`AuthorizationService`]. /// Response future for [`AuthorizationService`].
pub struct ResponseFuture<S, C> pub struct ResponseFuture<S, C, B>
where where
S: Service<Request>, B: Send + 'static,
S: Service<Request<B>>,
C: Clone + DeserializeOwned + Send + Sync + 'static, C: Clone + DeserializeOwned + Send + Sync + 'static,
{ {
#[pin] #[pin]
state: State<<AuthorizationService<S, C> as Authorize>::Future, S::Future>, state: State<<AuthorizationService<S, C> as Authorize<B>>::Future, S::Future>,
service: S, service: S,
} }
@ -206,9 +209,10 @@ enum State<A, SFut> {
}, },
} }
impl<S, C> Future for ResponseFuture<S, C> impl<S, C, B> Future for ResponseFuture<S, C, B>
where where
S: Service<Request>, B: Send,
S: Service<Request<B>>,
S::Response: From<AuthError>, S::Response: From<AuthError>,
C: Clone + DeserializeOwned + Send + Sync, C: Clone + DeserializeOwned + Send + Sync,
{ {