refactor: move jwt_source to Authorizer

allows multiple sorces with multiple authorizers
This commit is contained in:
cduvray 2023-08-14 08:02:56 +02:00
parent 55c4f7cc16
commit 603c042ee3
2 changed files with 98 additions and 55 deletions

View file

@ -1,5 +1,7 @@
use std::{io::Read, sync::Arc}; use std::{io::Read, sync::Arc};
use headers::{authorization::Bearer, Authorization, HeaderMapExt};
use http::HeaderMap;
use jsonwebtoken::{decode, decode_header, jwk::JwkSet, Algorithm, DecodingKey, TokenData}; use jsonwebtoken::{decode, decode_header, jwk::JwkSet, Algorithm, DecodingKey, TokenData};
use reqwest::Url; use reqwest::Url;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
@ -7,6 +9,7 @@ use serde::de::DeserializeOwned;
use crate::{ use crate::{
error::{AuthError, InitError}, error::{AuthError, InitError},
jwks::{key_store_manager::KeyStoreManager, KeyData, KeySource}, jwks::{key_store_manager::KeyStoreManager, KeyData, KeySource},
layer::{self, JwtSource},
oidc, Refresh, oidc, Refresh,
}; };
@ -38,6 +41,7 @@ where
pub key_source: KeySource, pub key_source: KeySource,
pub claims_checker: Option<FnClaimsChecker<C>>, pub claims_checker: Option<FnClaimsChecker<C>>,
pub validation: crate::validation::Validation, pub validation: crate::validation::Validation,
pub jwt_source: JwtSource,
} }
fn read_data(path: &str) -> Result<Vec<u8>, InitError> { fn read_data(path: &str) -> Result<Vec<u8>, InitError> {
@ -69,6 +73,7 @@ where
claims_checker: Option<FnClaimsChecker<C>>, claims_checker: Option<FnClaimsChecker<C>>,
refresh: Option<Refresh>, refresh: Option<Refresh>,
validation: crate::validation::Validation, validation: crate::validation::Validation,
jwt_source: JwtSource,
) -> Result<Authorizer<C>, InitError> { ) -> Result<Authorizer<C>, InitError> {
Ok(match key_source_type { Ok(match key_source_type {
KeySourceType::RSA(path) => { KeySourceType::RSA(path) => {
@ -81,6 +86,7 @@ where
})), })),
claims_checker, claims_checker,
validation, validation,
jwt_source,
} }
} }
KeySourceType::RSAString(text) => { KeySourceType::RSAString(text) => {
@ -93,6 +99,7 @@ where
})), })),
claims_checker, claims_checker,
validation, validation,
jwt_source,
} }
} }
KeySourceType::EC(path) => { KeySourceType::EC(path) => {
@ -105,6 +112,7 @@ where
})), })),
claims_checker, claims_checker,
validation, validation,
jwt_source,
} }
} }
KeySourceType::ECString(text) => { KeySourceType::ECString(text) => {
@ -117,6 +125,7 @@ where
})), })),
claims_checker, claims_checker,
validation, validation,
jwt_source,
} }
} }
KeySourceType::ED(path) => { KeySourceType::ED(path) => {
@ -129,6 +138,7 @@ where
})), })),
claims_checker, claims_checker,
validation, validation,
jwt_source,
} }
} }
KeySourceType::EDString(text) => { KeySourceType::EDString(text) => {
@ -141,6 +151,7 @@ where
})), })),
claims_checker, claims_checker,
validation, validation,
jwt_source,
} }
} }
KeySourceType::Secret(secret) => { KeySourceType::Secret(secret) => {
@ -153,6 +164,7 @@ where
})), })),
claims_checker, claims_checker,
validation, validation,
jwt_source,
} }
} }
KeySourceType::JwksString(jwks_str) => { KeySourceType::JwksString(jwks_str) => {
@ -164,6 +176,7 @@ where
key_source: KeySource::SingleKeySource(Arc::new(k)), key_source: KeySource::SingleKeySource(Arc::new(k)),
claims_checker, claims_checker,
validation, validation,
jwt_source,
} }
} }
KeySourceType::Jwks(url) => { KeySourceType::Jwks(url) => {
@ -173,6 +186,7 @@ where
key_source: KeySource::KeyStoreSource(key_store_manager), key_source: KeySource::KeyStoreSource(key_store_manager),
claims_checker, claims_checker,
validation, validation,
jwt_source,
} }
} }
KeySourceType::Discovery(issuer_url) => { KeySourceType::Discovery(issuer_url) => {
@ -184,6 +198,7 @@ where
key_source: KeySource::KeyStoreSource(key_store_manager), key_source: KeySource::KeyStoreSource(key_store_manager),
claims_checker, claims_checker,
validation, validation,
jwt_source,
} }
} }
}) })
@ -204,6 +219,18 @@ where
Ok(token_data) Ok(token_data)
} }
pub fn extract_token(&self, h: &HeaderMap) -> Option<String> {
match &self.jwt_source {
layer::JwtSource::AuthorizationHeader => {
let bearer_o: Option<Authorization<Bearer>> = h.typed_get();
bearer_o.map(|b| String::from(b.0.token()))
}
layer::JwtSource::Cookie(name) => h
.typed_get::<headers::Cookie>()
.and_then(|c| c.get(name.as_str()).map(String::from)),
}
}
} }
#[cfg(test)] #[cfg(test)]
@ -212,16 +239,22 @@ mod tests {
use jsonwebtoken::{Algorithm, Header}; use jsonwebtoken::{Algorithm, Header};
use serde_json::Value; use serde_json::Value;
use crate::validation::Validation; use crate::{layer::JwtSource, validation::Validation};
use super::{Authorizer, KeySourceType}; use super::{Authorizer, KeySourceType};
#[tokio::test] #[tokio::test]
async fn build_from_secret() { async fn build_from_secret() {
let h = Header::new(Algorithm::HS256); let h = Header::new(Algorithm::HS256);
let a = Authorizer::<Value>::build(KeySourceType::Secret("xxxxxx".to_owned()), None, None, Validation::new()) let a = Authorizer::<Value>::build(
.await KeySourceType::Secret("xxxxxx".to_owned()),
.unwrap(); None,
None,
Validation::new(),
JwtSource::AuthorizationHeader,
)
.await
.unwrap();
let k = a.key_source.get_key(h); let k = a.key_source.get_key(h);
assert!(k.await.is_ok()); assert!(k.await.is_ok());
} }
@ -238,9 +271,15 @@ mod tests {
"e": "AQAB" "e": "AQAB"
}]} }]}
"#; "#;
let a = Authorizer::<Value>::build(KeySourceType::JwksString(jwks.to_owned()), None, None, Validation::new()) let a = Authorizer::<Value>::build(
.await KeySourceType::JwksString(jwks.to_owned()),
.unwrap(); None,
None,
Validation::new(),
JwtSource::AuthorizationHeader,
)
.await
.unwrap();
let k = a.key_source.get_key(Header::new(Algorithm::RS256)); let k = a.key_source.get_key(Header::new(Algorithm::RS256));
assert!(k.await.is_ok()); assert!(k.await.is_ok());
} }
@ -252,6 +291,7 @@ mod tests {
None, None,
None, None,
Validation::new(), Validation::new(),
JwtSource::AuthorizationHeader,
) )
.await .await
.unwrap(); .unwrap();
@ -263,6 +303,7 @@ mod tests {
None, None,
None, None,
Validation::new(), Validation::new(),
JwtSource::AuthorizationHeader,
) )
.await .await
.unwrap(); .unwrap();
@ -274,6 +315,7 @@ mod tests {
None, None,
None, None,
Validation::new(), Validation::new(),
JwtSource::AuthorizationHeader,
) )
.await .await
.unwrap(); .unwrap();
@ -288,6 +330,7 @@ mod tests {
None, None,
None, None,
Validation::new(), Validation::new(),
JwtSource::AuthorizationHeader,
) )
.await .await
.unwrap(); .unwrap();
@ -299,6 +342,7 @@ mod tests {
None, None,
None, None,
Validation::new(), Validation::new(),
JwtSource::AuthorizationHeader,
) )
.await .await
.unwrap(); .unwrap();
@ -310,6 +354,7 @@ mod tests {
None, None,
None, None,
Validation::new(), Validation::new(),
JwtSource::AuthorizationHeader,
) )
.await .await
.unwrap(); .unwrap();
@ -324,6 +369,7 @@ mod tests {
None, None,
None, None,
Validation::new(), Validation::new(),
JwtSource::AuthorizationHeader,
) )
.await; .await;
println!("{:?}", a.as_ref().err()); println!("{:?}", a.as_ref().err());
@ -332,8 +378,14 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn build_jwks_url_error() { async fn build_jwks_url_error() {
let a = let a = Authorizer::<Value>::build(
Authorizer::<Value>::build(KeySourceType::Jwks("://xxxx".to_owned()), None, None, Validation::default()).await; KeySourceType::Jwks("://xxxx".to_owned()),
None,
None,
Validation::default(),
JwtSource::AuthorizationHeader,
)
.await;
println!("{:?}", a.as_ref().err()); println!("{:?}", a.as_ref().err());
assert!(a.is_err()); assert!(a.is_err());
} }
@ -345,6 +397,7 @@ mod tests {
None, None,
None, None,
Validation::default(), Validation::default(),
JwtSource::AuthorizationHeader,
) )
.await; .await;
println!("{:?}", a.as_ref().err()); println!("{:?}", a.as_ref().err());

View file

@ -1,10 +1,8 @@
use axum::async_trait; use axum::async_trait;
use axum::http::Request; use axum::http::Request;
use futures_core::ready; use futures_core::ready;
use futures_util::future::BoxFuture; use futures_util::future::{self, BoxFuture};
use futures_util::stream::{FuturesUnordered, StreamExt}; use futures_util::stream::{FuturesUnordered, StreamExt};
use headers::authorization::Bearer;
use headers::{Authorization, HeaderMapExt};
use jsonwebtoken::TokenData; use jsonwebtoken::TokenData;
use pin_project::pin_project; use pin_project::pin_project;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
@ -20,7 +18,7 @@ use crate::claims::RegisteredClaims;
use crate::error::InitError; use crate::error::InitError;
use crate::jwks::key_store_manager::Refresh; use crate::jwks::key_store_manager::Refresh;
use crate::validation::Validation; use crate::validation::Validation;
use crate::{layer, AuthError, RefreshStrategy}; use crate::{AuthError, RefreshStrategy};
/// Authorizer Layer builder /// Authorizer Layer builder
/// ///
@ -189,8 +187,10 @@ where
#[deprecated(since = "0.10.0", note = "please use `to_layer()` instead")] #[deprecated(since = "0.10.0", note = "please use `to_layer()` instead")]
pub async fn layer(self) -> Result<AsyncAuthorizationLayer<C>, InitError> { pub async fn layer(self) -> Result<AsyncAuthorizationLayer<C>, InitError> {
let val = self.validation.unwrap_or_default(); let val = self.validation.unwrap_or_default();
let auth = Arc::new(Authorizer::build(self.key_source_type, self.claims_checker, self.refresh, val).await?); let auth = Arc::new(
Ok(AsyncAuthorizationLayer::new(vec![auth], self.jwt_source)) Authorizer::build(self.key_source_type, self.claims_checker, self.refresh, val, self.jwt_source).await?,
);
Ok(AsyncAuthorizationLayer::new(vec![auth]))
} }
} }
@ -201,8 +201,10 @@ where
{ {
async fn to_layer(self) -> Result<AsyncAuthorizationLayer<C>, InitError> { async fn to_layer(self) -> Result<AsyncAuthorizationLayer<C>, InitError> {
let val = self.validation.unwrap_or_default(); let val = self.validation.unwrap_or_default();
let auth = Arc::new(Authorizer::build(self.key_source_type, self.claims_checker, self.refresh, val).await?); let auth = Arc::new(
Ok(AsyncAuthorizationLayer::new(vec![auth], self.jwt_source)) Authorizer::build(self.key_source_type, self.claims_checker, self.refresh, val, self.jwt_source).await?,
);
Ok(AsyncAuthorizationLayer::new(vec![auth]))
} }
} }
@ -222,6 +224,7 @@ where
a.claims_checker, a.claims_checker,
a.refresh, a.refresh,
a.validation.unwrap_or_default(), a.validation.unwrap_or_default(),
a.jwt_source,
) )
}) })
.collect(); .collect();
@ -237,8 +240,7 @@ where
// TODO: composite build error (containing all errors) // TODO: composite build error (containing all errors)
Err(e) Err(e)
} else { } else {
// TODO: jwt_source per Authorizer Ok(AsyncAuthorizationLayer::new(auths))
Ok(AsyncAuthorizationLayer::new(auths, JwtSource::AuthorizationHeader))
} }
} }
} }
@ -263,41 +265,35 @@ where
type Future = BoxFuture<'static, Result<Request<B>, AuthError>>; type Future = BoxFuture<'static, Result<Request<B>, AuthError>>;
fn authorize(&self, mut request: Request<B>) -> Self::Future { fn authorize(&self, mut request: Request<B>) -> Self::Future {
// TODO: extract token per authorizer (jwt_source shloud be per authorizer) let tkns_auths: Vec<(Option<String>, Arc<Authorizer<C>>)> = self
let h = request.headers(); .auths
let token_o = match &self.jwt_source { .iter()
layer::JwtSource::AuthorizationHeader => { .map(|a| (a.extract_token(request.headers()), a.clone()))
let bearer_o: Option<Authorization<Bearer>> = h.typed_get(); .collect();
bearer_o.map(|b| String::from(b.0.token()))
}
layer::JwtSource::Cookie(name) => h
.typed_get::<headers::Cookie>()
.and_then(|c| c.get(name.as_str()).map(String::from)),
};
let authorizers: Vec<Arc<Authorizer<C>>> = self.auths.iter().cloned().collect(); if !tkns_auths.iter().any(|(t, _)| t.is_some()) {
return Box::pin(future::ready(Err(AuthError::MissingToken())));
}
Box::pin(async move { Box::pin(async move {
if let Some(token) = token_o { let mut token_data: Result<TokenData<C>, AuthError> = Err(AuthError::NoAuthorizer());
let mut token_data: Result<TokenData<C>, AuthError> = Err(AuthError::NoAuthorizer()); for (tkn, auth) in tkns_auths {
for auth in authorizers { if let Some(token) = tkn {
token_data = auth.check_auth(token.as_str()).await; token_data = auth.check_auth(token.as_str()).await;
if token_data.is_ok() { if token_data.is_ok() {
break; break;
} }
} }
match token_data { }
Ok(tdata) => { match token_data {
// Set `token_data` as a request extension so it can be accessed by other Ok(tdata) => {
// services down the stack. // Set `token_data` as a request extension so it can be accessed by other
request.extensions_mut().insert(tdata); // services down the stack.
request.extensions_mut().insert(tdata);
Ok(request) Ok(request)
}
Err(err) => Err(err), // TODO: error containing all errors (not just the last one)
} }
} else { Err(err) => Err(err), // TODO: error containing all errors (not just the last one)
Err(AuthError::MissingToken())
} }
}) })
} }
@ -311,15 +307,14 @@ where
C: Clone + DeserializeOwned + Send, C: Clone + DeserializeOwned + Send,
{ {
auths: Vec<Arc<Authorizer<C>>>, auths: Vec<Arc<Authorizer<C>>>,
jwt_source: JwtSource,
} }
impl<C> AsyncAuthorizationLayer<C> impl<C> AsyncAuthorizationLayer<C>
where where
C: Clone + DeserializeOwned + Send, C: Clone + DeserializeOwned + Send,
{ {
pub fn new(auths: Vec<Arc<Authorizer<C>>>, jwt_source: JwtSource) -> AsyncAuthorizationLayer<C> { pub fn new(auths: Vec<Arc<Authorizer<C>>>) -> AsyncAuthorizationLayer<C> {
Self { auths, jwt_source } Self { auths }
} }
} }
@ -330,7 +325,7 @@ where
type Service = AsyncAuthorizationService<S, C>; type Service = AsyncAuthorizationService<S, C>;
fn layer(&self, inner: S) -> Self::Service { fn layer(&self, inner: S) -> Self::Service {
AsyncAuthorizationService::new(inner, self.auths.clone(), self.jwt_source.clone()) AsyncAuthorizationService::new(inner, self.auths.clone())
} }
} }
@ -366,7 +361,6 @@ where
{ {
pub inner: S, pub inner: S,
pub auths: Vec<Arc<Authorizer<C>>>, pub auths: Vec<Arc<Authorizer<C>>>,
pub jwt_source: JwtSource,
} }
impl<S, C> AsyncAuthorizationService<S, C> impl<S, C> AsyncAuthorizationService<S, C>
@ -395,12 +389,8 @@ where
/// Authorize requests using a custom scheme. /// Authorize requests using a custom scheme.
/// ///
/// The `Authorization` header is required to have the value provided. /// The `Authorization` header is required to have the value provided.
pub fn new(inner: S, auths: Vec<Arc<Authorizer<C>>>, jwt_source: JwtSource) -> AsyncAuthorizationService<S, C> { pub fn new(inner: S, auths: Vec<Arc<Authorizer<C>>>) -> AsyncAuthorizationService<S, C> {
Self { Self { inner, auths }
inner,
auths,
jwt_source,
}
} }
} }