diff --git a/demo-server/src/main.rs b/demo-server/src/main.rs index a1e7bfe..91f7879 100644 --- a/demo-server/src/main.rs +++ b/demo-server/src/main.rs @@ -1,5 +1,7 @@ use axum::{routing::get, Router}; -use jwt_authorizer::{error::InitError, AuthError, JwtAuthorizer, JwtClaims, Refresh, RefreshStrategy}; +use jwt_authorizer::{ + error::InitError, AuthError, JwtAuthorizer, JwtClaims, Refresh, RefreshStrategy, ToAuthorizationLayer, +}; use serde::Deserialize; use std::net::SocketAddr; use tower_http::trace::TraceLayer; @@ -49,7 +51,7 @@ async fn main() -> Result<(), InitError> { let api = Router::new() .route("/protected", get(protected)) // adding the authorizer layer - .layer(jwt_auth.layer().await?); + .layer(jwt_auth.to_layer().await?); let app = Router::new() // public endpoint diff --git a/jwt-authorizer/src/authorizer.rs b/jwt-authorizer/src/authorizer.rs index f49b6b1..66b5980 100644 --- a/jwt-authorizer/src/authorizer.rs +++ b/jwt-authorizer/src/authorizer.rs @@ -65,7 +65,7 @@ where C: DeserializeOwned + Clone + Send + Sync, { pub(crate) async fn build( - key_source_type: &KeySourceType, + key_source_type: KeySourceType, claims_checker: Option>, refresh: Option, validation: crate::validation::Validation, @@ -157,7 +157,7 @@ where } KeySourceType::JwksString(jwks_str) => { // TODO: expose it in JwtAuthorizer or remove - let set: JwkSet = serde_json::from_str(jwks_str)?; + let set: JwkSet = serde_json::from_str(jwks_str.as_str())?; // TODO: replace [0] by kid/alg search let k = KeyData::from_jwk(&set.keys[0]).map_err(InitError::KeyDecodingError)?; Authorizer { @@ -167,7 +167,7 @@ where } } KeySourceType::Jwks(url) => { - let jwks_url = Url::parse(url).map_err(|e| InitError::JwksUrlError(e.to_string()))?; + let jwks_url = Url::parse(url.as_str()).map_err(|e| InitError::JwksUrlError(e.to_string()))?; let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default()); Authorizer { key_source: KeySource::KeyStoreSource(key_store_manager), @@ -176,7 +176,7 @@ where } } KeySourceType::Discovery(issuer_url) => { - let jwks_url = Url::parse(&oidc::discover_jwks(issuer_url).await?) + let jwks_url = Url::parse(&oidc::discover_jwks(issuer_url.as_str()).await?) .map_err(|e| InitError::JwksUrlError(e.to_string()))?; let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default()); @@ -219,7 +219,7 @@ mod tests { #[tokio::test] async fn build_from_secret() { let h = Header::new(Algorithm::HS256); - let a = Authorizer::::build(&KeySourceType::Secret("xxxxxx".to_owned()), None, None, Validation::new()) + let a = Authorizer::::build(KeySourceType::Secret("xxxxxx".to_owned()), None, None, Validation::new()) .await .unwrap(); let k = a.key_source.get_key(h); @@ -238,7 +238,7 @@ mod tests { "e": "AQAB" }]} "#; - let a = Authorizer::::build(&KeySourceType::JwksString(jwks.to_owned()), None, None, Validation::new()) + let a = Authorizer::::build(KeySourceType::JwksString(jwks.to_owned()), None, None, Validation::new()) .await .unwrap(); let k = a.key_source.get_key(Header::new(Algorithm::RS256)); @@ -248,7 +248,7 @@ mod tests { #[tokio::test] async fn build_from_file() { let a = Authorizer::::build( - &KeySourceType::RSA("../config/rsa-public1.pem".to_owned()), + KeySourceType::RSA("../config/rsa-public1.pem".to_owned()), None, None, Validation::new(), @@ -259,7 +259,7 @@ mod tests { assert!(k.await.is_ok()); let a = Authorizer::::build( - &KeySourceType::EC("../config/ecdsa-public1.pem".to_owned()), + KeySourceType::EC("../config/ecdsa-public1.pem".to_owned()), None, None, Validation::new(), @@ -270,7 +270,7 @@ mod tests { assert!(k.await.is_ok()); let a = Authorizer::::build( - &KeySourceType::ED("../config/ed25519-public1.pem".to_owned()), + KeySourceType::ED("../config/ed25519-public1.pem".to_owned()), None, None, Validation::new(), @@ -284,7 +284,7 @@ mod tests { #[tokio::test] async fn build_from_text() { let a = Authorizer::::build( - &KeySourceType::RSAString(include_str!("../../config/rsa-public1.pem").to_owned()), + KeySourceType::RSAString(include_str!("../../config/rsa-public1.pem").to_owned()), None, None, Validation::new(), @@ -295,7 +295,7 @@ mod tests { assert!(k.await.is_ok()); let a = Authorizer::::build( - &KeySourceType::ECString(include_str!("../../config/ecdsa-public1.pem").to_owned()), + KeySourceType::ECString(include_str!("../../config/ecdsa-public1.pem").to_owned()), None, None, Validation::new(), @@ -306,7 +306,7 @@ mod tests { assert!(k.await.is_ok()); let a = Authorizer::::build( - &KeySourceType::EDString(include_str!("../../config/ed25519-public1.pem").to_owned()), + KeySourceType::EDString(include_str!("../../config/ed25519-public1.pem").to_owned()), None, None, Validation::new(), @@ -320,7 +320,7 @@ mod tests { #[tokio::test] async fn build_file_errors() { let a = Authorizer::::build( - &KeySourceType::RSA("./config/does-not-exist.pem".to_owned()), + KeySourceType::RSA("./config/does-not-exist.pem".to_owned()), None, None, Validation::new(), @@ -333,7 +333,7 @@ mod tests { #[tokio::test] async fn build_jwks_url_error() { let a = - Authorizer::::build(&KeySourceType::Jwks("://xxxx".to_owned()), None, None, Validation::default()).await; + Authorizer::::build(KeySourceType::Jwks("://xxxx".to_owned()), None, None, Validation::default()).await; println!("{:?}", a.as_ref().err()); assert!(a.is_err()); } @@ -341,7 +341,7 @@ mod tests { #[tokio::test] async fn build_discovery_url_error() { let a = Authorizer::::build( - &KeySourceType::Discovery("://xxxx".to_owned()), + KeySourceType::Discovery("://xxxx".to_owned()), None, None, Validation::default(), diff --git a/jwt-authorizer/src/layer.rs b/jwt-authorizer/src/layer.rs index 6583868..bad0b31 100644 --- a/jwt-authorizer/src/layer.rs +++ b/jwt-authorizer/src/layer.rs @@ -1,6 +1,8 @@ +use axum::async_trait; use axum::http::Request; use futures_core::ready; use futures_util::future::BoxFuture; +use futures_util::stream::{FuturesUnordered, StreamExt}; use headers::authorization::Bearer; use headers::{Authorization, HeaderMapExt}; use jsonwebtoken::TokenData; @@ -184,12 +186,63 @@ where } /// Build axum layer + #[deprecated(since = "0.10.0", note = "please use `to_layer()` instead")] pub async fn layer(self) -> Result, InitError> { let val = self.validation.unwrap_or_default(); - let auth = Arc::new(Authorizer::build(&self.key_source_type, self.claims_checker, self.refresh, val).await?); - Ok(AsyncAuthorizationLayer::new(auth, self.jwt_source)) + let auth = Arc::new(Authorizer::build(self.key_source_type, self.claims_checker, self.refresh, val).await?); + Ok(AsyncAuthorizationLayer::new(vec![auth], self.jwt_source)) } } + +#[async_trait] +impl ToAuthorizationLayer for JwtAuthorizer +where + C: Clone + DeserializeOwned + Send + Sync, +{ + async fn to_layer(self) -> Result, InitError> { + let val = self.validation.unwrap_or_default(); + let auth = Arc::new(Authorizer::build(self.key_source_type, self.claims_checker, self.refresh, val).await?); + Ok(AsyncAuthorizationLayer::new(vec![auth], self.jwt_source)) + } +} + +#[async_trait] +impl ToAuthorizationLayer for Vec> +where + C: Clone + DeserializeOwned + Send + Sync, +{ + async fn to_layer(self) -> Result, InitError> { + let mut errs = Vec::::new(); + let mut auths = Vec::>>::new(); + let mut auths_futs: FuturesUnordered<_> = self + .into_iter() + .map(|a| { + Authorizer::build( + a.key_source_type, + a.claims_checker, + a.refresh, + a.validation.unwrap_or_default(), + ) + }) + .collect(); + + while let Some(a) = auths_futs.next().await { + match a { + Ok(res) => auths.push(Arc::new(res)), + Err(err) => errs.push(err), + } + } + + if let Some(e) = errs.into_iter().next() { + // TODO: composite build error (containing all errors) + Err(e) + } else { + // TODO: jwt_source per Authorizer + Ok(AsyncAuthorizationLayer::new(auths, JwtSource::AuthorizationHeader)) + } + } +} + /// Trait for authorizing requests. pub trait AsyncAuthorizer { type RequestBody; @@ -257,7 +310,7 @@ pub struct AsyncAuthorizationLayer where C: Clone + DeserializeOwned + Send, { - auth: Arc>, + auths: Vec>>, jwt_source: JwtSource, } @@ -265,8 +318,8 @@ impl AsyncAuthorizationLayer where C: Clone + DeserializeOwned + Send, { - pub fn new(auth: Arc>, jwt_source: JwtSource) -> AsyncAuthorizationLayer { - Self { auth, jwt_source } + pub fn new(auths: Vec>>, jwt_source: JwtSource) -> AsyncAuthorizationLayer { + Self { auths, jwt_source } } } @@ -277,10 +330,18 @@ where type Service = AsyncAuthorizationService; fn layer(&self, inner: S) -> Self::Service { - AsyncAuthorizationService::new(inner, self.auth.clone(), self.jwt_source.clone()) + AsyncAuthorizationService::new(inner, self.auths.clone(), self.jwt_source.clone()) } } +#[async_trait] +pub trait ToAuthorizationLayer +where + C: Clone + DeserializeOwned + Send, +{ + async fn to_layer(self) -> Result, InitError>; +} + // ---------- AsyncAuthorizationService -------- /// Source of the bearer token @@ -334,10 +395,10 @@ where /// Authorize requests using a custom scheme. /// /// The `Authorization` header is required to have the value provided. - pub fn new(inner: S, auth: Arc>, jwt_source: JwtSource) -> AsyncAuthorizationService { + pub fn new(inner: S, auths: Vec>>, jwt_source: JwtSource) -> AsyncAuthorizationService { Self { inner, - auths: vec![auth], + auths: auths, jwt_source, } } @@ -431,3 +492,36 @@ where } } } + +#[cfg(test)] +mod tests { + use crate::{JwtAuthorizer, ToAuthorizationLayer}; + + #[tokio::test] + async fn jwt_auth_to_layer() { + let auth1: JwtAuthorizer = JwtAuthorizer::from_secret("aaa"); + let layer = auth1.to_layer().await; + assert!(layer.is_ok()); + } + + #[tokio::test] + async fn vec_to_layer() { + let auth1: JwtAuthorizer = JwtAuthorizer::from_secret("aaa"); + let auth2: JwtAuthorizer = JwtAuthorizer::from_secret("bbb"); + let av = vec![auth1, auth2]; + let layer = av.to_layer().await; + assert!(layer.is_ok()); + } + + #[tokio::test] + async fn vec_to_layer_errors() { + let auth1: JwtAuthorizer = JwtAuthorizer::from_ec_pem("aaa"); + let auth2: JwtAuthorizer = JwtAuthorizer::from_ed_pem("bbb"); + let av = vec![auth1, auth2]; + let layer = av.to_layer().await; + assert!(layer.is_err()); + if let Err(err) = layer { + assert_eq!(err.to_string(), "No such file or directory (os error 2)"); + } + } +} diff --git a/jwt-authorizer/src/lib.rs b/jwt-authorizer/src/lib.rs index abc1254..e35bbfb 100644 --- a/jwt-authorizer/src/lib.rs +++ b/jwt-authorizer/src/lib.rs @@ -8,7 +8,7 @@ use serde::de::DeserializeOwned; pub use self::error::AuthError; pub use claims::{NumericDate, OneOrArray, RegisteredClaims}; pub use jwks::key_store_manager::{Refresh, RefreshStrategy}; -pub use layer::JwtAuthorizer; +pub use layer::{JwtAuthorizer, ToAuthorizationLayer}; pub use validation::Validation; pub mod authorizer; diff --git a/jwt-authorizer/tests/tests.rs b/jwt-authorizer/tests/tests.rs index 37641c4..e2cf54a 100644 --- a/jwt-authorizer/tests/tests.rs +++ b/jwt-authorizer/tests/tests.rs @@ -12,7 +12,7 @@ mod tests { BoxError, Router, }; use http::{header, HeaderValue}; - use jwt_authorizer::{layer::JwtSource, validation::Validation, JwtAuthorizer, JwtClaims}; + use jwt_authorizer::{layer::JwtSource, validation::Validation, JwtAuthorizer, JwtClaims, ToAuthorizationLayer}; use serde::Deserialize; use tower::{util::MapErrLayer, ServiceExt}; @@ -32,7 +32,7 @@ mod tests { tower::buffer::BufferLayer::new(1), MapErrLayer::new(|e: BoxError| -> Infallible { panic!("{}", e) }), ), - jwt_auth.layer().await.unwrap(), + jwt_auth.to_layer().await.unwrap(), ), ), )