diff --git a/jwt-authorizer/src/authorizer.rs b/jwt-authorizer/src/authorizer.rs index 1465590..d80963c 100644 --- a/jwt-authorizer/src/authorizer.rs +++ b/jwt-authorizer/src/authorizer.rs @@ -48,7 +48,7 @@ pub enum KeySourceType { impl Authorizer where - C: DeserializeOwned + Clone + Send + Sync, + C: DeserializeOwned + Clone + Send, { pub(crate) async fn build( key_source_type: KeySourceType, diff --git a/jwt-authorizer/src/layer.rs b/jwt-authorizer/src/layer.rs index 83297a3..6fdff6f 100644 --- a/jwt-authorizer/src/layer.rs +++ b/jwt-authorizer/src/layer.rs @@ -1,3 +1,4 @@ +use axum::body::Body; use axum::http::Request; use futures_core::ready; use futures_util::future::{self, BoxFuture}; @@ -8,6 +9,7 @@ use std::future::Future; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use tokio::sync::Mutex; use tower_layer::Layer; use tower_service::Service; @@ -27,8 +29,8 @@ pub trait Authorize { impl Authorize for AuthorizationService where - B: Send + Sync + 'static, - C: Clone + DeserializeOwned + Send + Sync + 'static, + B: Send + 'static, + C: Clone + DeserializeOwned + Send + 'static, { type RequestBody = B; type Future = BoxFuture<'static, Result, AuthError>>; @@ -59,7 +61,9 @@ where Ok(tdata) => { // Set `token_data` as a request extension so it can be accessed by other // services down the stack. - request.extensions_mut().insert(tdata); + + let something = Arc::new(Mutex::new(tdata)); + request.extensions_mut().insert(something); Ok(request) } @@ -119,7 +123,7 @@ pub enum JwtSource { #[derive(Clone)] pub struct AuthorizationService where - C: Clone + DeserializeOwned + Send + Sync, + C: Clone + DeserializeOwned + Send, { pub inner: S, pub auths: Vec>>, @@ -127,7 +131,7 @@ where impl AuthorizationService where - C: Clone + DeserializeOwned + Send + Sync, + C: Clone + DeserializeOwned + Send, { pub fn get_ref(&self) -> &S { &self.inner @@ -156,6 +160,34 @@ where } } +impl Service> for AuthorizationService +where + S: Service> + Clone, + S::Response: From, + C: Clone + DeserializeOwned + Send + Sync + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + let inner = self.inner.clone(); + // take the service that was ready + let inner = std::mem::replace(&mut self.inner, inner); + + let auth_fut = self.authorize(req); + + ResponseFuture { + state: State::Authorize { auth_fut }, + service: inner, + } + } +} +/* impl Service> for AuthorizationService where ReqBody: Send + Sync + 'static, @@ -184,17 +216,17 @@ where } } } +*/ #[pin_project] /// Response future for [`AuthorizationService`]. -pub struct ResponseFuture +pub struct ResponseFuture where - S: Service>, - ReqBody: Send + Sync + 'static, + S: Service>, C: Clone + DeserializeOwned + Send + Sync + 'static, { #[pin] - state: State< as Authorize>::Future, S::Future>, + state: State< as Authorize>::Future, S::Future>, service: S, } @@ -210,11 +242,10 @@ enum State { }, } -impl Future for ResponseFuture +impl Future for ResponseFuture where - S: Service>, + S: Service>, S::Response: From, - ReqBody: Send + Sync + 'static, C: Clone + DeserializeOwned + Send + Sync, { type Output = Result; diff --git a/jwt-authorizer/tests/integration_tests.rs b/jwt-authorizer/tests/integration_tests.rs index 294a5c3..f31a253 100644 --- a/jwt-authorizer/tests/integration_tests.rs +++ b/jwt-authorizer/tests/integration_tests.rs @@ -8,9 +8,9 @@ use std::{ time::Duration, }; +use axum::body::Body; use axum::{response::Response, routing::get, Json, Router}; use http::{header::AUTHORIZATION, Request, StatusCode}; -use hyper::Body; use jwt_authorizer::{IntoLayer, JwtAuthorizer, JwtClaims, Refresh, RefreshStrategy, Validation}; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; @@ -127,7 +127,7 @@ fn init_test() { } async fn make_proteced_request(app: &mut Router, bearer: &str) -> Response { - app.ready() + >>::ready(app) .await .unwrap() .call( @@ -142,7 +142,7 @@ async fn make_proteced_request(app: &mut Router, bearer: &str) -> Response { } async fn make_public_request(app: &mut Router) -> Response { - app.ready() + >>::ready(app) .await .unwrap() .call(Request::builder().uri("/public").body(Body::empty()).unwrap())