fix: use Request, ignore props of ReqBody

This commit is contained in:
Daniel Gallups 2023-12-01 18:47:08 -05:00
parent dd2a48b00c
commit 526fc77dae
2 changed files with 19 additions and 22 deletions

View file

@ -1,5 +1,4 @@
use axum::body::Body; use axum::extract::Request;
use axum::http::Request;
use futures_core::ready; use futures_core::ready;
use futures_util::future::{self, BoxFuture}; use futures_util::future::{self, BoxFuture};
use jsonwebtoken::TokenData; use jsonwebtoken::TokenData;
@ -17,28 +16,25 @@ use crate::authorizer::Authorizer;
use crate::AuthError; use crate::AuthError;
/// Trait for authorizing requests. /// Trait for authorizing requests.
pub trait Authorize<B> { pub trait Authorize {
type RequestBody; type Future: Future<Output = Result<Request, AuthError>>;
type Future: Future<Output = Result<Request<Self::RequestBody>, AuthError>>;
/// Authorize the request. /// Authorize the request.
/// ///
/// If the future resolves to `Ok(request)` then the request is allowed through, otherwise not. /// If the future resolves to `Ok(request)` then the request is allowed through, otherwise not.
fn authorize(&self, request: Request<B>) -> Self::Future; fn authorize(&self, request: Request) -> Self::Future;
} }
impl<B, S, C> Authorize<B> for AuthorizationService<S, C> impl<S, C> Authorize for AuthorizationService<S, C>
where where
B: Send + 'static, C: Clone + DeserializeOwned + Send + Sync + 'static,
C: Clone + DeserializeOwned + Send + 'static,
{ {
type RequestBody = B; type Future = BoxFuture<'static, Result<Request, AuthError>>;
type Future = BoxFuture<'static, Result<Request<B>, AuthError>>;
/// The authorizers are sequentially applied (check_auth) until one of them validates the token. /// The authorizers are sequentially applied (check_auth) until one of them validates the token.
/// If no authorizer validates the token the request is rejected. /// If no authorizer validates the token the request is rejected.
/// ///
fn authorize(&self, mut request: Request<B>) -> Self::Future { fn authorize(&self, mut request: Request) -> Self::Future {
let tkns_auths: Vec<(String, Arc<Authorizer<C>>)> = self let tkns_auths: Vec<(String, Arc<Authorizer<C>>)> = self
.auths .auths
.iter() .iter()
@ -160,9 +156,9 @@ where
} }
} }
impl<S, C> Service<Request<Body>> for AuthorizationService<S, C> impl<S, C> Service<Request> for AuthorizationService<S, C>
where where
S: Service<Request<Body>> + Clone, S: Service<Request> + Clone,
S::Response: From<AuthError>, S::Response: From<AuthError>,
C: Clone + DeserializeOwned + Send + Sync + 'static, C: Clone + DeserializeOwned + Send + Sync + 'static,
{ {
@ -174,7 +170,7 @@ where
self.inner.poll_ready(cx) self.inner.poll_ready(cx)
} }
fn call(&mut self, req: Request<Body>) -> Self::Future { fn call(&mut self, req: Request) -> Self::Future {
let inner = self.inner.clone(); let inner = self.inner.clone();
// take the service that was ready // take the service that was ready
let inner = std::mem::replace(&mut self.inner, inner); let inner = std::mem::replace(&mut self.inner, inner);
@ -192,11 +188,11 @@ where
/// Response future for [`AuthorizationService`]. /// Response future for [`AuthorizationService`].
pub struct ResponseFuture<S, C> pub struct ResponseFuture<S, C>
where where
S: Service<Request<Body>>, S: Service<Request>,
C: Clone + DeserializeOwned + Send + Sync + 'static, C: Clone + DeserializeOwned + Send + Sync + 'static,
{ {
#[pin] #[pin]
state: State<<AuthorizationService<S, C> as Authorize<Body>>::Future, S::Future>, state: State<<AuthorizationService<S, C> as Authorize>::Future, S::Future>,
service: S, service: S,
} }
@ -214,7 +210,7 @@ enum State<A, SFut> {
impl<S, C> Future for ResponseFuture<S, C> impl<S, C> Future for ResponseFuture<S, C>
where where
S: Service<Request<Body>>, S: Service<Request>,
S::Response: From<AuthError>, S::Response: From<AuthError>,
C: Clone + DeserializeOwned + Send + Sync, C: Clone + DeserializeOwned + Send + Sync,
{ {

View file

@ -8,8 +8,7 @@ use std::{
time::Duration, time::Duration,
}; };
use axum::body::Body; use axum::{body::Body, response::Response, routing::get, Json, Router};
use axum::{response::Response, routing::get, Json, Router};
use http::{header::AUTHORIZATION, Request, StatusCode}; use http::{header::AUTHORIZATION, Request, StatusCode};
use jwt_authorizer::{IntoLayer, JwtAuthorizer, JwtClaims, Refresh, RefreshStrategy, Validation}; use jwt_authorizer::{IntoLayer, JwtAuthorizer, JwtClaims, Refresh, RefreshStrategy, Validation};
use lazy_static::lazy_static; use lazy_static::lazy_static;
@ -127,7 +126,8 @@ fn init_test() {
} }
async fn make_proteced_request(app: &mut Router, bearer: &str) -> Response { async fn make_proteced_request(app: &mut Router, bearer: &str) -> Response {
<Router as tower::ServiceExt<Request<Body>>>::ready(app) app.as_service()
.ready()
.await .await
.unwrap() .unwrap()
.call( .call(
@ -142,7 +142,8 @@ async fn make_proteced_request(app: &mut Router, bearer: &str) -> Response {
} }
async fn make_public_request(app: &mut Router) -> Response { async fn make_public_request(app: &mut Router) -> Response {
<Router as tower::ServiceExt<Request<Body>>>::ready(app) app.as_service()
.ready()
.await .await
.unwrap() .unwrap()
.call(Request::builder().uri("/public").body(Body::empty()).unwrap()) .call(Request::builder().uri("/public").body(Body::empty()).unwrap())