diff --git a/jwt-authorizer/src/authorizer.rs b/jwt-authorizer/src/authorizer.rs index 66b5980..f5384a3 100644 --- a/jwt-authorizer/src/authorizer.rs +++ b/jwt-authorizer/src/authorizer.rs @@ -1,5 +1,7 @@ use std::{io::Read, sync::Arc}; +use headers::{authorization::Bearer, Authorization, HeaderMapExt}; +use http::HeaderMap; use jsonwebtoken::{decode, decode_header, jwk::JwkSet, Algorithm, DecodingKey, TokenData}; use reqwest::Url; use serde::de::DeserializeOwned; @@ -7,6 +9,7 @@ use serde::de::DeserializeOwned; use crate::{ error::{AuthError, InitError}, jwks::{key_store_manager::KeyStoreManager, KeyData, KeySource}, + layer::{self, JwtSource}, oidc, Refresh, }; @@ -38,6 +41,7 @@ where pub key_source: KeySource, pub claims_checker: Option>, pub validation: crate::validation::Validation, + pub jwt_source: JwtSource, } fn read_data(path: &str) -> Result, InitError> { @@ -69,6 +73,7 @@ where claims_checker: Option>, refresh: Option, validation: crate::validation::Validation, + jwt_source: JwtSource, ) -> Result, InitError> { Ok(match key_source_type { KeySourceType::RSA(path) => { @@ -81,6 +86,7 @@ where })), claims_checker, validation, + jwt_source, } } KeySourceType::RSAString(text) => { @@ -93,6 +99,7 @@ where })), claims_checker, validation, + jwt_source, } } KeySourceType::EC(path) => { @@ -105,6 +112,7 @@ where })), claims_checker, validation, + jwt_source, } } KeySourceType::ECString(text) => { @@ -117,6 +125,7 @@ where })), claims_checker, validation, + jwt_source, } } KeySourceType::ED(path) => { @@ -129,6 +138,7 @@ where })), claims_checker, validation, + jwt_source, } } KeySourceType::EDString(text) => { @@ -141,6 +151,7 @@ where })), claims_checker, validation, + jwt_source, } } KeySourceType::Secret(secret) => { @@ -153,6 +164,7 @@ where })), claims_checker, validation, + jwt_source, } } KeySourceType::JwksString(jwks_str) => { @@ -164,6 +176,7 @@ where key_source: KeySource::SingleKeySource(Arc::new(k)), claims_checker, validation, + jwt_source, } } KeySourceType::Jwks(url) => { @@ -173,6 +186,7 @@ where key_source: KeySource::KeyStoreSource(key_store_manager), claims_checker, validation, + jwt_source, } } KeySourceType::Discovery(issuer_url) => { @@ -184,6 +198,7 @@ where key_source: KeySource::KeyStoreSource(key_store_manager), claims_checker, validation, + jwt_source, } } }) @@ -204,6 +219,18 @@ where Ok(token_data) } + + pub fn extract_token(&self, h: &HeaderMap) -> Option { + match &self.jwt_source { + layer::JwtSource::AuthorizationHeader => { + let bearer_o: Option> = h.typed_get(); + bearer_o.map(|b| String::from(b.0.token())) + } + layer::JwtSource::Cookie(name) => h + .typed_get::() + .and_then(|c| c.get(name.as_str()).map(String::from)), + } + } } #[cfg(test)] @@ -212,16 +239,22 @@ mod tests { use jsonwebtoken::{Algorithm, Header}; use serde_json::Value; - use crate::validation::Validation; + use crate::{layer::JwtSource, validation::Validation}; use super::{Authorizer, KeySourceType}; #[tokio::test] async fn build_from_secret() { let h = Header::new(Algorithm::HS256); - let a = Authorizer::::build(KeySourceType::Secret("xxxxxx".to_owned()), None, None, Validation::new()) - .await - .unwrap(); + let a = Authorizer::::build( + KeySourceType::Secret("xxxxxx".to_owned()), + None, + None, + Validation::new(), + JwtSource::AuthorizationHeader, + ) + .await + .unwrap(); let k = a.key_source.get_key(h); assert!(k.await.is_ok()); } @@ -238,9 +271,15 @@ mod tests { "e": "AQAB" }]} "#; - let a = Authorizer::::build(KeySourceType::JwksString(jwks.to_owned()), None, None, Validation::new()) - .await - .unwrap(); + let a = Authorizer::::build( + KeySourceType::JwksString(jwks.to_owned()), + None, + None, + Validation::new(), + JwtSource::AuthorizationHeader, + ) + .await + .unwrap(); let k = a.key_source.get_key(Header::new(Algorithm::RS256)); assert!(k.await.is_ok()); } @@ -252,6 +291,7 @@ mod tests { None, None, Validation::new(), + JwtSource::AuthorizationHeader, ) .await .unwrap(); @@ -263,6 +303,7 @@ mod tests { None, None, Validation::new(), + JwtSource::AuthorizationHeader, ) .await .unwrap(); @@ -274,6 +315,7 @@ mod tests { None, None, Validation::new(), + JwtSource::AuthorizationHeader, ) .await .unwrap(); @@ -288,6 +330,7 @@ mod tests { None, None, Validation::new(), + JwtSource::AuthorizationHeader, ) .await .unwrap(); @@ -299,6 +342,7 @@ mod tests { None, None, Validation::new(), + JwtSource::AuthorizationHeader, ) .await .unwrap(); @@ -310,6 +354,7 @@ mod tests { None, None, Validation::new(), + JwtSource::AuthorizationHeader, ) .await .unwrap(); @@ -324,6 +369,7 @@ mod tests { None, None, Validation::new(), + JwtSource::AuthorizationHeader, ) .await; println!("{:?}", a.as_ref().err()); @@ -332,8 +378,14 @@ mod tests { #[tokio::test] async fn build_jwks_url_error() { - let a = - Authorizer::::build(KeySourceType::Jwks("://xxxx".to_owned()), None, None, Validation::default()).await; + let a = Authorizer::::build( + KeySourceType::Jwks("://xxxx".to_owned()), + None, + None, + Validation::default(), + JwtSource::AuthorizationHeader, + ) + .await; println!("{:?}", a.as_ref().err()); assert!(a.is_err()); } @@ -345,6 +397,7 @@ mod tests { None, None, Validation::default(), + JwtSource::AuthorizationHeader, ) .await; println!("{:?}", a.as_ref().err()); diff --git a/jwt-authorizer/src/layer.rs b/jwt-authorizer/src/layer.rs index 3a8e36b..e401f73 100644 --- a/jwt-authorizer/src/layer.rs +++ b/jwt-authorizer/src/layer.rs @@ -1,10 +1,8 @@ use axum::async_trait; use axum::http::Request; use futures_core::ready; -use futures_util::future::BoxFuture; +use futures_util::future::{self, BoxFuture}; use futures_util::stream::{FuturesUnordered, StreamExt}; -use headers::authorization::Bearer; -use headers::{Authorization, HeaderMapExt}; use jsonwebtoken::TokenData; use pin_project::pin_project; use serde::de::DeserializeOwned; @@ -20,7 +18,7 @@ use crate::claims::RegisteredClaims; use crate::error::InitError; use crate::jwks::key_store_manager::Refresh; use crate::validation::Validation; -use crate::{layer, AuthError, RefreshStrategy}; +use crate::{AuthError, RefreshStrategy}; /// Authorizer Layer builder /// @@ -189,8 +187,10 @@ where #[deprecated(since = "0.10.0", note = "please use `to_layer()` instead")] pub async fn layer(self) -> Result, InitError> { let val = self.validation.unwrap_or_default(); - let auth = Arc::new(Authorizer::build(self.key_source_type, self.claims_checker, self.refresh, val).await?); - Ok(AsyncAuthorizationLayer::new(vec![auth], self.jwt_source)) + let auth = Arc::new( + Authorizer::build(self.key_source_type, self.claims_checker, self.refresh, val, self.jwt_source).await?, + ); + Ok(AsyncAuthorizationLayer::new(vec![auth])) } } @@ -201,8 +201,10 @@ where { async fn to_layer(self) -> Result, InitError> { let val = self.validation.unwrap_or_default(); - let auth = Arc::new(Authorizer::build(self.key_source_type, self.claims_checker, self.refresh, val).await?); - Ok(AsyncAuthorizationLayer::new(vec![auth], self.jwt_source)) + let auth = Arc::new( + Authorizer::build(self.key_source_type, self.claims_checker, self.refresh, val, self.jwt_source).await?, + ); + Ok(AsyncAuthorizationLayer::new(vec![auth])) } } @@ -222,6 +224,7 @@ where a.claims_checker, a.refresh, a.validation.unwrap_or_default(), + a.jwt_source, ) }) .collect(); @@ -237,8 +240,7 @@ where // TODO: composite build error (containing all errors) Err(e) } else { - // TODO: jwt_source per Authorizer - Ok(AsyncAuthorizationLayer::new(auths, JwtSource::AuthorizationHeader)) + Ok(AsyncAuthorizationLayer::new(auths)) } } } @@ -263,41 +265,35 @@ where type Future = BoxFuture<'static, Result, AuthError>>; fn authorize(&self, mut request: Request) -> Self::Future { - // TODO: extract token per authorizer (jwt_source shloud be per authorizer) - let h = request.headers(); - let token_o = match &self.jwt_source { - layer::JwtSource::AuthorizationHeader => { - let bearer_o: Option> = h.typed_get(); - bearer_o.map(|b| String::from(b.0.token())) - } - layer::JwtSource::Cookie(name) => h - .typed_get::() - .and_then(|c| c.get(name.as_str()).map(String::from)), - }; + let tkns_auths: Vec<(Option, Arc>)> = self + .auths + .iter() + .map(|a| (a.extract_token(request.headers()), a.clone())) + .collect(); - let authorizers: Vec>> = self.auths.iter().cloned().collect(); + if !tkns_auths.iter().any(|(t, _)| t.is_some()) { + return Box::pin(future::ready(Err(AuthError::MissingToken()))); + } Box::pin(async move { - if let Some(token) = token_o { - let mut token_data: Result, AuthError> = Err(AuthError::NoAuthorizer()); - for auth in authorizers { + let mut token_data: Result, AuthError> = Err(AuthError::NoAuthorizer()); + for (tkn, auth) in tkns_auths { + if let Some(token) = tkn { token_data = auth.check_auth(token.as_str()).await; if token_data.is_ok() { break; } } - match token_data { - Ok(tdata) => { - // Set `token_data` as a request extension so it can be accessed by other - // services down the stack. - request.extensions_mut().insert(tdata); + } + match token_data { + Ok(tdata) => { + // Set `token_data` as a request extension so it can be accessed by other + // services down the stack. + request.extensions_mut().insert(tdata); - Ok(request) - } - Err(err) => Err(err), // TODO: error containing all errors (not just the last one) + Ok(request) } - } else { - Err(AuthError::MissingToken()) + Err(err) => Err(err), // TODO: error containing all errors (not just the last one) } }) } @@ -311,15 +307,14 @@ where C: Clone + DeserializeOwned + Send, { auths: Vec>>, - jwt_source: JwtSource, } impl AsyncAuthorizationLayer where C: Clone + DeserializeOwned + Send, { - pub fn new(auths: Vec>>, jwt_source: JwtSource) -> AsyncAuthorizationLayer { - Self { auths, jwt_source } + pub fn new(auths: Vec>>) -> AsyncAuthorizationLayer { + Self { auths } } } @@ -330,7 +325,7 @@ where type Service = AsyncAuthorizationService; fn layer(&self, inner: S) -> Self::Service { - AsyncAuthorizationService::new(inner, self.auths.clone(), self.jwt_source.clone()) + AsyncAuthorizationService::new(inner, self.auths.clone()) } } @@ -366,7 +361,6 @@ where { pub inner: S, pub auths: Vec>>, - pub jwt_source: JwtSource, } impl AsyncAuthorizationService @@ -395,12 +389,8 @@ where /// Authorize requests using a custom scheme. /// /// The `Authorization` header is required to have the value provided. - pub fn new(inner: S, auths: Vec>>, jwt_source: JwtSource) -> AsyncAuthorizationService { - Self { - inner, - auths, - jwt_source, - } + pub fn new(inner: S, auths: Vec>>) -> AsyncAuthorizationService { + Self { inner, auths } } }