diff --git a/CHANGELOG.md b/CHANGELOG.md index cd86782..b052582 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +## 0.11 (2023-xx-xx) + +- support for multiple authorizers + - JwtAuthorizer::layer() deprecated in favor of JwtAuthorizer::build() and IntoLayer::into_layer() + ## 0.10.1 (2023-07-11) ### Fixed diff --git a/demo-server/src/main.rs b/demo-server/src/main.rs index a1e7bfe..df42105 100644 --- a/demo-server/src/main.rs +++ b/demo-server/src/main.rs @@ -1,5 +1,7 @@ use axum::{routing::get, Router}; -use jwt_authorizer::{error::InitError, AuthError, JwtAuthorizer, JwtClaims, Refresh, RefreshStrategy}; +use jwt_authorizer::{ + error::InitError, AuthError, Authorizer, IntoLayer, JwtAuthorizer, JwtClaims, Refresh, RefreshStrategy, +}; use serde::Deserialize; use std::net::SocketAddr; use tower_http::trace::TraceLayer; @@ -37,19 +39,21 @@ async fn main() -> Result<(), InitError> { // First let's create an authorizer builder from a Oidc Discovery // User is a struct deserializable from JWT claims representing the authorized user // let jwt_auth: JwtAuthorizer = JwtAuthorizer::from_oidc("https://accounts.google.com/") - let jwt_auth: JwtAuthorizer = JwtAuthorizer::from_oidc(issuer_uri) + let auth: Authorizer = JwtAuthorizer::from_oidc(issuer_uri) // .no_refresh() .refresh(Refresh { strategy: RefreshStrategy::Interval, ..Default::default() }) - .check(claim_checker); + .check(claim_checker) + .build() + .await?; // actual router demo let api = Router::new() .route("/protected", get(protected)) // adding the authorizer layer - .layer(jwt_auth.layer().await?); + .layer(auth.into_layer()); let app = Router::new() // public endpoint diff --git a/jwt-authorizer/docs/README.md b/jwt-authorizer/docs/README.md index bed9828..6e21748 100644 --- a/jwt-authorizer/docs/README.md +++ b/jwt-authorizer/docs/README.md @@ -14,12 +14,14 @@ JWT authoriser Layer for Axum and Tonic. - Claims extraction - Claims checker - Tracing support (error logging) +- *tonic* support +- multiple authorizers ## Usage Example ```rust -# use jwt_authorizer::{AuthError, JwtAuthorizer, JwtClaims, RegisteredClaims}; +# use jwt_authorizer::{AuthError, Authorizer, JwtAuthorizer, JwtClaims, RegisteredClaims, IntoLayer}; # use axum::{routing::get, Router}; # use serde::Deserialize; @@ -27,12 +29,12 @@ JWT authoriser Layer for Axum and Tonic. // let's create an authorizer builder from a JWKS Endpoint // (a serializable struct can be used to represent jwt claims, JwtAuthorizer is the default) - let jwt_auth: JwtAuthorizer = - JwtAuthorizer::from_jwks_url("http://localhost:3000/oidc/jwks"); + let auth: Authorizer = + JwtAuthorizer::from_jwks_url("http://localhost:3000/oidc/jwks").build().await.unwrap(); // adding the authorization layer let app = Router::new().route("/protected", get(protected)) - .layer(jwt_auth.layer().await.unwrap()); + .layer(auth.into_layer()); // proteced handler with user injection (mapping some jwt claims) async fn protected(JwtClaims(user): JwtClaims) -> Result { @@ -45,6 +47,11 @@ JWT authoriser Layer for Axum and Tonic. # }; ``` +## Multiple Authorizers + +A layer can be built using multiple authorizers (`IntoLayer` is implemented for `[Authorizer; N]` and for `Vec>`). +The authorizers are sequentially applied until one of them validates the token. If no authorizer validates it the request is rejected. + ## Validation Validation configuration object. diff --git a/jwt-authorizer/src/authorizer.rs b/jwt-authorizer/src/authorizer.rs index f49b6b1..5f4ac70 100644 --- a/jwt-authorizer/src/authorizer.rs +++ b/jwt-authorizer/src/authorizer.rs @@ -1,5 +1,7 @@ use std::{io::Read, sync::Arc}; +use headers::{authorization::Bearer, Authorization, HeaderMapExt}; +use http::HeaderMap; use jsonwebtoken::{decode, decode_header, jwk::JwkSet, Algorithm, DecodingKey, TokenData}; use reqwest::Url; use serde::de::DeserializeOwned; @@ -7,7 +9,8 @@ use serde::de::DeserializeOwned; use crate::{ error::{AuthError, InitError}, jwks::{key_store_manager::KeyStoreManager, KeyData, KeySource}, - oidc, Refresh, + layer::{self, AsyncAuthorizationLayer, JwtSource}, + oidc, Refresh, RegisteredClaims, }; pub trait ClaimsChecker { @@ -31,13 +34,14 @@ where } } -pub struct Authorizer +pub struct Authorizer where C: Clone, { pub key_source: KeySource, pub claims_checker: Option>, pub validation: crate::validation::Validation, + pub jwt_source: JwtSource, } fn read_data(path: &str) -> Result, InitError> { @@ -65,10 +69,11 @@ where C: DeserializeOwned + Clone + Send + Sync, { pub(crate) async fn build( - key_source_type: &KeySourceType, + key_source_type: KeySourceType, claims_checker: Option>, refresh: Option, validation: crate::validation::Validation, + jwt_source: JwtSource, ) -> Result, InitError> { Ok(match key_source_type { KeySourceType::RSA(path) => { @@ -81,6 +86,7 @@ where })), claims_checker, validation, + jwt_source, } } KeySourceType::RSAString(text) => { @@ -93,6 +99,7 @@ where })), claims_checker, validation, + jwt_source, } } KeySourceType::EC(path) => { @@ -105,6 +112,7 @@ where })), claims_checker, validation, + jwt_source, } } KeySourceType::ECString(text) => { @@ -117,6 +125,7 @@ where })), claims_checker, validation, + jwt_source, } } KeySourceType::ED(path) => { @@ -129,6 +138,7 @@ where })), claims_checker, validation, + jwt_source, } } KeySourceType::EDString(text) => { @@ -141,6 +151,7 @@ where })), claims_checker, validation, + jwt_source, } } KeySourceType::Secret(secret) => { @@ -153,30 +164,33 @@ where })), claims_checker, validation, + jwt_source, } } KeySourceType::JwksString(jwks_str) => { // TODO: expose it in JwtAuthorizer or remove - let set: JwkSet = serde_json::from_str(jwks_str)?; + let set: JwkSet = serde_json::from_str(jwks_str.as_str())?; // TODO: replace [0] by kid/alg search let k = KeyData::from_jwk(&set.keys[0]).map_err(InitError::KeyDecodingError)?; Authorizer { key_source: KeySource::SingleKeySource(Arc::new(k)), claims_checker, validation, + jwt_source, } } KeySourceType::Jwks(url) => { - let jwks_url = Url::parse(url).map_err(|e| InitError::JwksUrlError(e.to_string()))?; + let jwks_url = Url::parse(url.as_str()).map_err(|e| InitError::JwksUrlError(e.to_string()))?; let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default()); Authorizer { key_source: KeySource::KeyStoreSource(key_store_manager), claims_checker, validation, + jwt_source, } } KeySourceType::Discovery(issuer_url) => { - let jwks_url = Url::parse(&oidc::discover_jwks(issuer_url).await?) + let jwks_url = Url::parse(&oidc::discover_jwks(issuer_url.as_str()).await?) .map_err(|e| InitError::JwksUrlError(e.to_string()))?; let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default()); @@ -184,6 +198,7 @@ where key_source: KeySource::KeyStoreSource(key_store_manager), claims_checker, validation, + jwt_source, } } }) @@ -204,6 +219,52 @@ where Ok(token_data) } + + pub fn extract_token(&self, h: &HeaderMap) -> Option { + match &self.jwt_source { + layer::JwtSource::AuthorizationHeader => { + let bearer_o: Option> = h.typed_get(); + bearer_o.map(|b| String::from(b.0.token())) + } + layer::JwtSource::Cookie(name) => h + .typed_get::() + .and_then(|c| c.get(name.as_str()).map(String::from)), + } + } +} + +pub trait IntoLayer +where + C: Clone + DeserializeOwned + Send, +{ + fn into_layer(self) -> AsyncAuthorizationLayer; +} + +impl IntoLayer for Vec> +where + C: Clone + DeserializeOwned + Send, +{ + fn into_layer(self) -> AsyncAuthorizationLayer { + AsyncAuthorizationLayer::new(self.into_iter().map(Arc::new).collect()) + } +} + +impl IntoLayer for [Authorizer; N] +where + C: Clone + DeserializeOwned + Send, +{ + fn into_layer(self) -> AsyncAuthorizationLayer { + AsyncAuthorizationLayer::new(self.into_iter().map(Arc::new).collect()) + } +} + +impl IntoLayer for Authorizer +where + C: Clone + DeserializeOwned + Send, +{ + fn into_layer(self) -> AsyncAuthorizationLayer { + AsyncAuthorizationLayer::new(vec![Arc::new(self)]) + } } #[cfg(test)] @@ -212,16 +273,22 @@ mod tests { use jsonwebtoken::{Algorithm, Header}; use serde_json::Value; - use crate::validation::Validation; + use crate::{layer::JwtSource, validation::Validation}; use super::{Authorizer, KeySourceType}; #[tokio::test] async fn build_from_secret() { let h = Header::new(Algorithm::HS256); - let a = Authorizer::::build(&KeySourceType::Secret("xxxxxx".to_owned()), None, None, Validation::new()) - .await - .unwrap(); + let a = Authorizer::::build( + KeySourceType::Secret("xxxxxx".to_owned()), + None, + None, + Validation::new(), + JwtSource::AuthorizationHeader, + ) + .await + .unwrap(); let k = a.key_source.get_key(h); assert!(k.await.is_ok()); } @@ -238,9 +305,15 @@ mod tests { "e": "AQAB" }]} "#; - let a = Authorizer::::build(&KeySourceType::JwksString(jwks.to_owned()), None, None, Validation::new()) - .await - .unwrap(); + let a = Authorizer::::build( + KeySourceType::JwksString(jwks.to_owned()), + None, + None, + Validation::new(), + JwtSource::AuthorizationHeader, + ) + .await + .unwrap(); let k = a.key_source.get_key(Header::new(Algorithm::RS256)); assert!(k.await.is_ok()); } @@ -248,10 +321,11 @@ mod tests { #[tokio::test] async fn build_from_file() { let a = Authorizer::::build( - &KeySourceType::RSA("../config/rsa-public1.pem".to_owned()), + KeySourceType::RSA("../config/rsa-public1.pem".to_owned()), None, None, Validation::new(), + JwtSource::AuthorizationHeader, ) .await .unwrap(); @@ -259,10 +333,11 @@ mod tests { assert!(k.await.is_ok()); let a = Authorizer::::build( - &KeySourceType::EC("../config/ecdsa-public1.pem".to_owned()), + KeySourceType::EC("../config/ecdsa-public1.pem".to_owned()), None, None, Validation::new(), + JwtSource::AuthorizationHeader, ) .await .unwrap(); @@ -270,10 +345,11 @@ mod tests { assert!(k.await.is_ok()); let a = Authorizer::::build( - &KeySourceType::ED("../config/ed25519-public1.pem".to_owned()), + KeySourceType::ED("../config/ed25519-public1.pem".to_owned()), None, None, Validation::new(), + JwtSource::AuthorizationHeader, ) .await .unwrap(); @@ -284,10 +360,11 @@ mod tests { #[tokio::test] async fn build_from_text() { let a = Authorizer::::build( - &KeySourceType::RSAString(include_str!("../../config/rsa-public1.pem").to_owned()), + KeySourceType::RSAString(include_str!("../../config/rsa-public1.pem").to_owned()), None, None, Validation::new(), + JwtSource::AuthorizationHeader, ) .await .unwrap(); @@ -295,10 +372,11 @@ mod tests { assert!(k.await.is_ok()); let a = Authorizer::::build( - &KeySourceType::ECString(include_str!("../../config/ecdsa-public1.pem").to_owned()), + KeySourceType::ECString(include_str!("../../config/ecdsa-public1.pem").to_owned()), None, None, Validation::new(), + JwtSource::AuthorizationHeader, ) .await .unwrap(); @@ -306,10 +384,11 @@ mod tests { assert!(k.await.is_ok()); let a = Authorizer::::build( - &KeySourceType::EDString(include_str!("../../config/ed25519-public1.pem").to_owned()), + KeySourceType::EDString(include_str!("../../config/ed25519-public1.pem").to_owned()), None, None, Validation::new(), + JwtSource::AuthorizationHeader, ) .await .unwrap(); @@ -320,10 +399,11 @@ mod tests { #[tokio::test] async fn build_file_errors() { let a = Authorizer::::build( - &KeySourceType::RSA("./config/does-not-exist.pem".to_owned()), + KeySourceType::RSA("./config/does-not-exist.pem".to_owned()), None, None, Validation::new(), + JwtSource::AuthorizationHeader, ) .await; println!("{:?}", a.as_ref().err()); @@ -332,8 +412,14 @@ mod tests { #[tokio::test] async fn build_jwks_url_error() { - let a = - Authorizer::::build(&KeySourceType::Jwks("://xxxx".to_owned()), None, None, Validation::default()).await; + let a = Authorizer::::build( + KeySourceType::Jwks("://xxxx".to_owned()), + None, + None, + Validation::default(), + JwtSource::AuthorizationHeader, + ) + .await; println!("{:?}", a.as_ref().err()); assert!(a.is_err()); } @@ -341,10 +427,11 @@ mod tests { #[tokio::test] async fn build_discovery_url_error() { let a = Authorizer::::build( - &KeySourceType::Discovery("://xxxx".to_owned()), + KeySourceType::Discovery("://xxxx".to_owned()), None, None, Validation::default(), + JwtSource::AuthorizationHeader, ) .await; println!("{:?}", a.as_ref().err()); diff --git a/jwt-authorizer/src/error.rs b/jwt-authorizer/src/error.rs index 2575cca..35b829e 100644 --- a/jwt-authorizer/src/error.rs +++ b/jwt-authorizer/src/error.rs @@ -56,6 +56,9 @@ pub enum AuthError { #[error("Invalid Claim")] InvalidClaims(), + #[error("No Authorizer")] + NoAuthorizer(), + /// Used when a claim extractor is used and no authorization layer is in front the handler #[error("No Authorizer Layer")] NoAuthorizerLayer(), @@ -117,6 +120,10 @@ impl From for Response { debug!("AuthErrors::InvalidClaims"); tonic::Status::unauthenticated("error=\"insufficient_scope\"") } + AuthError::NoAuthorizer() => { + debug!("AuthErrors::NoAuthorizer"); + tonic::Status::unauthenticated("error=\"invalid_token\"") + } AuthError::NoAuthorizerLayer() => { debug!("AuthErrors::NoAuthorizerLayer"); tonic::Status::unauthenticated("error=\"no_authorizer_layer\"") @@ -174,6 +181,10 @@ impl IntoResponse for AuthError { debug!("AuthErrors::InvalidClaims"); response_wwwauth(StatusCode::FORBIDDEN, "error=\"insufficient_scope\"") } + AuthError::NoAuthorizer() => { + debug!("AuthErrors::NoAuthorizer"); + response_wwwauth(StatusCode::FORBIDDEN, "error=\"invalid_token\"") + } AuthError::NoAuthorizerLayer() => { debug!("AuthErrors::NoAuthorizerLayer"); // TODO: should it be a standard error? diff --git a/jwt-authorizer/src/layer.rs b/jwt-authorizer/src/layer.rs index e37f82c..a8b8fb5 100644 --- a/jwt-authorizer/src/layer.rs +++ b/jwt-authorizer/src/layer.rs @@ -1,8 +1,7 @@ use axum::http::Request; use futures_core::ready; -use futures_util::future::BoxFuture; -use headers::authorization::Bearer; -use headers::{Authorization, HeaderMapExt}; +use futures_util::future::{self, BoxFuture}; +use jsonwebtoken::TokenData; use pin_project::pin_project; use serde::de::DeserializeOwned; use std::future::Future; @@ -17,7 +16,7 @@ use crate::claims::RegisteredClaims; use crate::error::InitError; use crate::jwks::key_store_manager::Refresh; use crate::validation::Validation; -use crate::{layer, AuthError, RefreshStrategy}; +use crate::{AuthError, RefreshStrategy}; /// Authorizer Layer builder /// @@ -183,12 +182,22 @@ where } /// Build axum layer + #[deprecated(since = "0.10.0", note = "please use `IntoLayer::into_layer()` instead")] pub async fn layer(self) -> Result, InitError> { let val = self.validation.unwrap_or_default(); - let auth = Arc::new(Authorizer::build(&self.key_source_type, self.claims_checker, self.refresh, val).await?); - Ok(AsyncAuthorizationLayer::new(auth, self.jwt_source)) + let auth = Arc::new( + Authorizer::build(self.key_source_type, self.claims_checker, self.refresh, val, self.jwt_source).await?, + ); + Ok(AsyncAuthorizationLayer::new(vec![auth])) + } + + pub async fn build(self) -> Result, InitError> { + let val = self.validation.unwrap_or_default(); + + Authorizer::build(self.key_source_type, self.claims_checker, self.refresh, val, self.jwt_source).await } } + /// Trait for authorizing requests. pub trait AsyncAuthorizer { type RequestBody; @@ -208,30 +217,37 @@ where type RequestBody = B; type Future = BoxFuture<'static, Result, AuthError>>; + /// The authorizers are sequentially applied (check_auth) until one of them validates the token. + /// If no authorizer validates the token the request is rejected. + /// fn authorize(&self, mut request: Request) -> Self::Future { - let authorizer = self.auth.clone(); - let h = request.headers(); + let tkns_auths: Vec<(String, Arc>)> = self + .auths + .iter() + .filter_map(|a| a.extract_token(request.headers()).map(|t| (t, a.clone()))) + .collect(); + + if tkns_auths.is_empty() { + return Box::pin(future::ready(Err(AuthError::MissingToken()))); + } - let token = match &self.jwt_source { - layer::JwtSource::AuthorizationHeader => { - let bearer_o: Option> = h.typed_get(); - bearer_o.map(|b| String::from(b.0.token())) - } - layer::JwtSource::Cookie(name) => h - .typed_get::() - .and_then(|c| c.get(name.as_str()).map(String::from)), - }; Box::pin(async move { - if let Some(token) = token { - authorizer.check_auth(token.as_str()).await.map(|token_data| { + let mut token_data: Result, AuthError> = Err(AuthError::NoAuthorizer()); + for (token, auth) in tkns_auths { + token_data = auth.check_auth(token.as_str()).await; + if token_data.is_ok() { + break; + } + } + match token_data { + Ok(tdata) => { // Set `token_data` as a request extension so it can be accessed by other // services down the stack. - request.extensions_mut().insert(token_data); + request.extensions_mut().insert(tdata); - request - }) - } else { - Err(AuthError::MissingToken()) + Ok(request) + } + Err(err) => Err(err), // TODO: error containing all errors (not just the last one) or to choose one? } }) } @@ -244,16 +260,15 @@ pub struct AsyncAuthorizationLayer where C: Clone + DeserializeOwned + Send, { - auth: Arc>, - jwt_source: JwtSource, + auths: Vec>>, } impl AsyncAuthorizationLayer where C: Clone + DeserializeOwned + Send, { - pub fn new(auth: Arc>, jwt_source: JwtSource) -> AsyncAuthorizationLayer { - Self { auth, jwt_source } + pub fn new(auths: Vec>>) -> AsyncAuthorizationLayer { + Self { auths } } } @@ -264,7 +279,7 @@ where type Service = AsyncAuthorizationService; fn layer(&self, inner: S) -> Self::Service { - AsyncAuthorizationService::new(inner, self.auth.clone(), self.jwt_source.clone()) + AsyncAuthorizationService::new(inner, self.auths.clone()) } } @@ -291,8 +306,7 @@ where C: Clone + DeserializeOwned + Send + Sync, { pub inner: S, - pub auth: Arc>, - pub jwt_source: JwtSource, + pub auths: Vec>>, } impl AsyncAuthorizationService @@ -321,8 +335,8 @@ where /// Authorize requests using a custom scheme. /// /// The `Authorization` header is required to have the value provided. - pub fn new(inner: S, auth: Arc>, jwt_source: JwtSource) -> AsyncAuthorizationService { - Self { inner, auth, jwt_source } + pub fn new(inner: S, auths: Vec>>) -> AsyncAuthorizationService { + Self { inner, auths } } } @@ -414,3 +428,43 @@ where } } } + +#[cfg(test)] +mod tests { + use crate::{authorizer::Authorizer, IntoLayer, JwtAuthorizer, RegisteredClaims}; + + use super::AsyncAuthorizationLayer; + + #[tokio::test] + async fn auth_into_layer() { + let auth1: Authorizer = JwtAuthorizer::from_secret("aaa").build().await.unwrap(); + let layer = auth1.into_layer(); + assert_eq!(1, layer.auths.len()); + } + + #[tokio::test] + async fn auths_into_layer() { + let auth1 = JwtAuthorizer::from_secret("aaa").build().await.unwrap(); + let auth2 = JwtAuthorizer::from_secret("bbb").build().await.unwrap(); + + let layer: AsyncAuthorizationLayer = [auth1, auth2].into_layer(); + assert_eq!(2, layer.auths.len()); + } + + #[tokio::test] + async fn vec_auths_into_layer() { + let auth1 = JwtAuthorizer::from_secret("aaa").build().await.unwrap(); + let auth2 = JwtAuthorizer::from_secret("bbb").build().await.unwrap(); + + let layer: AsyncAuthorizationLayer = vec![auth1, auth2].into_layer(); + assert_eq!(2, layer.auths.len()); + } + + #[tokio::test] + async fn jwt_auth_to_layer() { + let auth1: JwtAuthorizer = JwtAuthorizer::from_secret("aaa"); + #[allow(deprecated)] + let layer = auth1.layer().await.unwrap(); + assert_eq!(1, layer.auths.len()); + } +} diff --git a/jwt-authorizer/src/lib.rs b/jwt-authorizer/src/lib.rs index 1c60ec9..1940777 100644 --- a/jwt-authorizer/src/lib.rs +++ b/jwt-authorizer/src/lib.rs @@ -5,6 +5,7 @@ use jsonwebtoken::TokenData; use serde::de::DeserializeOwned; pub use self::error::AuthError; +pub use authorizer::{Authorizer, IntoLayer}; pub use claims::{NumericDate, OneOrArray, RegisteredClaims}; pub use jwks::key_store_manager::{Refresh, RefreshStrategy}; pub use layer::JwtAuthorizer; diff --git a/jwt-authorizer/tests/integration_tests.rs b/jwt-authorizer/tests/integration_tests.rs index a9ac2ef..067b8cd 100644 --- a/jwt-authorizer/tests/integration_tests.rs +++ b/jwt-authorizer/tests/integration_tests.rs @@ -11,7 +11,7 @@ use std::{ use axum::{response::Response, routing::get, Json, Router}; use http::{header::AUTHORIZATION, Request, StatusCode}; use hyper::Body; -use jwt_authorizer::{JwtAuthorizer, JwtClaims, Refresh, RefreshStrategy}; +use jwt_authorizer::{IntoLayer, JwtAuthorizer, JwtClaims, Refresh, RefreshStrategy}; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -104,7 +104,7 @@ async fn app(jwt_auth: JwtAuthorizer) -> Router { let protected_route: Router = Router::new() .route("/protected", get(protected_handler)) .route("/protected-with-user", get(protected_with_user)) - .layer(jwt_auth.layer().await.unwrap()); + .layer(jwt_auth.build().await.unwrap().into_layer()); Router::new().merge(pub_route).merge(protected_route) } diff --git a/jwt-authorizer/tests/tests.rs b/jwt-authorizer/tests/tests.rs index 1e45fde..3f99917 100644 --- a/jwt-authorizer/tests/tests.rs +++ b/jwt-authorizer/tests/tests.rs @@ -12,7 +12,12 @@ mod tests { BoxError, Router, }; use http::{header, HeaderValue}; - use jwt_authorizer::{layer::JwtSource, validation::Validation, JwtAuthorizer, JwtClaims}; + use jwt_authorizer::{ + authorizer::Authorizer, + layer::{AsyncAuthorizationLayer, JwtSource}, + validation::Validation, + IntoLayer, JwtAuthorizer, JwtClaims, + }; use serde::Deserialize; use tower::{util::MapErrLayer, ServiceExt}; @@ -23,7 +28,7 @@ mod tests { sub: String, } - async fn app(jwt_auth: JwtAuthorizer) -> Router { + async fn app(layer: AsyncAuthorizationLayer) -> Router { Router::new().route("/public", get(|| async { "hello" })).route( "/protected", get(|JwtClaims(user): JwtClaims| async move { format!("hello: {}", user.sub) }).layer( @@ -32,14 +37,22 @@ mod tests { tower::buffer::BufferLayer::new(1), MapErrLayer::new(|e: BoxError| -> Infallible { panic!("{}", e) }), ), - jwt_auth.layer().await.unwrap(), + layer, ), ), ) } async fn proteced_request_with_header(jwt_auth: JwtAuthorizer, header_name: &str, header_value: &str) -> Response { - app(jwt_auth) + proteced_request_with_header_and_layer(jwt_auth.build().await.unwrap().into_layer(), header_name, header_value).await + } + + async fn proteced_request_with_header_and_layer( + layer: AsyncAuthorizationLayer, + header_name: &str, + header_value: &str, + ) -> Response { + app(layer) .await .oneshot( Request::builder() @@ -58,9 +71,12 @@ mod tests { #[tokio::test] async fn protected_without_jwt() { - let jwt_auth: JwtAuthorizer = JwtAuthorizer::from_rsa_pem("../config/rsa-public1.pem"); + let auth: Authorizer = JwtAuthorizer::from_rsa_pem("../config/rsa-public1.pem") + .build() + .await + .unwrap(); - let response = app(jwt_auth) + let response = app(auth.into_layer()) .await .oneshot(Request::builder().uri("/protected").body(Body::empty()).unwrap()) .await @@ -348,4 +364,53 @@ mod tests { &"Bearer error=\"invalid_token\"" ); } + + // -------------------------- + // Multiple Authorizers + // -------------------------- + #[tokio::test] + async fn multiple_authorizers() { + let auths: Vec> = vec![ + JwtAuthorizer::from_ec_pem("../config/ecdsa-public1.pem") + .build() + .await + .unwrap(), + JwtAuthorizer::from_rsa_pem("../config/rsa-public1.pem") + .jwt_source(JwtSource::Cookie("ccc".to_owned())) + .build() + .await + .unwrap(), + ]; + + // OK + let response = proteced_request_with_header_and_layer( + auths.into_layer(), + header::COOKIE.as_str(), + &format!("ccc={}", common::JWT_RSA1_OK), + ) + .await; + assert_eq!(response.status(), StatusCode::OK); + + let auths: [Authorizer; 2] = [ + JwtAuthorizer::from_ec_pem("../config/ecdsa-public1.pem") + .build() + .await + .unwrap(), + JwtAuthorizer::from_rsa_pem("../config/rsa-public1.pem") + .jwt_source(JwtSource::Cookie("ccc".to_owned())) + .build() + .await + .unwrap(), + ]; + + // Cookie missing + let response = proteced_request_with_header_and_layer( + auths.into_layer(), + header::COOKIE.as_str(), + &format!("bad_cookie={}", common::JWT_EC2_OK), + ) + .await; + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + assert_eq!(response.headers().get(header::WWW_AUTHENTICATE).unwrap(), &"Bearer"); + } } diff --git a/jwt-authorizer/tests/tonic.rs b/jwt-authorizer/tests/tonic.rs index ac8874b..da499a8 100644 --- a/jwt-authorizer/tests/tonic.rs +++ b/jwt-authorizer/tests/tonic.rs @@ -3,7 +3,7 @@ use std::{sync::Once, task::Poll}; use axum::body::HttpBody; use futures_core::future::BoxFuture; use http::header::AUTHORIZATION; -use jwt_authorizer::{layer::AsyncAuthorizationService, JwtAuthorizer}; +use jwt_authorizer::{layer::AsyncAuthorizationService, IntoLayer, JwtAuthorizer}; use serde::{Deserialize, Serialize}; use tonic::{server::UnaryService, transport::NamedService, IntoRequest, Status}; use tower::{buffer::Buffer, Service}; @@ -83,7 +83,7 @@ async fn app( jwt_auth: JwtAuthorizer, expected_sub: String, ) -> AsyncAuthorizationService>, User> { - let layer = jwt_auth.layer().await.unwrap(); + let layer = jwt_auth.build().await.unwrap().into_layer(); tonic::transport::Server::builder() .layer(layer) .layer(tower::buffer::BufferLayer::new(1))