refactor: move jwt_source to Authorizer

allows multiple sorces with multiple authorizers
This commit is contained in:
cduvray 2023-08-14 08:02:56 +02:00
parent 55c4f7cc16
commit 603c042ee3
2 changed files with 98 additions and 55 deletions

View file

@ -1,5 +1,7 @@
use std::{io::Read, sync::Arc};
use headers::{authorization::Bearer, Authorization, HeaderMapExt};
use http::HeaderMap;
use jsonwebtoken::{decode, decode_header, jwk::JwkSet, Algorithm, DecodingKey, TokenData};
use reqwest::Url;
use serde::de::DeserializeOwned;
@ -7,6 +9,7 @@ use serde::de::DeserializeOwned;
use crate::{
error::{AuthError, InitError},
jwks::{key_store_manager::KeyStoreManager, KeyData, KeySource},
layer::{self, JwtSource},
oidc, Refresh,
};
@ -38,6 +41,7 @@ where
pub key_source: KeySource,
pub claims_checker: Option<FnClaimsChecker<C>>,
pub validation: crate::validation::Validation,
pub jwt_source: JwtSource,
}
fn read_data(path: &str) -> Result<Vec<u8>, InitError> {
@ -69,6 +73,7 @@ where
claims_checker: Option<FnClaimsChecker<C>>,
refresh: Option<Refresh>,
validation: crate::validation::Validation,
jwt_source: JwtSource,
) -> Result<Authorizer<C>, InitError> {
Ok(match key_source_type {
KeySourceType::RSA(path) => {
@ -81,6 +86,7 @@ where
})),
claims_checker,
validation,
jwt_source,
}
}
KeySourceType::RSAString(text) => {
@ -93,6 +99,7 @@ where
})),
claims_checker,
validation,
jwt_source,
}
}
KeySourceType::EC(path) => {
@ -105,6 +112,7 @@ where
})),
claims_checker,
validation,
jwt_source,
}
}
KeySourceType::ECString(text) => {
@ -117,6 +125,7 @@ where
})),
claims_checker,
validation,
jwt_source,
}
}
KeySourceType::ED(path) => {
@ -129,6 +138,7 @@ where
})),
claims_checker,
validation,
jwt_source,
}
}
KeySourceType::EDString(text) => {
@ -141,6 +151,7 @@ where
})),
claims_checker,
validation,
jwt_source,
}
}
KeySourceType::Secret(secret) => {
@ -153,6 +164,7 @@ where
})),
claims_checker,
validation,
jwt_source,
}
}
KeySourceType::JwksString(jwks_str) => {
@ -164,6 +176,7 @@ where
key_source: KeySource::SingleKeySource(Arc::new(k)),
claims_checker,
validation,
jwt_source,
}
}
KeySourceType::Jwks(url) => {
@ -173,6 +186,7 @@ where
key_source: KeySource::KeyStoreSource(key_store_manager),
claims_checker,
validation,
jwt_source,
}
}
KeySourceType::Discovery(issuer_url) => {
@ -184,6 +198,7 @@ where
key_source: KeySource::KeyStoreSource(key_store_manager),
claims_checker,
validation,
jwt_source,
}
}
})
@ -204,6 +219,18 @@ where
Ok(token_data)
}
pub fn extract_token(&self, h: &HeaderMap) -> Option<String> {
match &self.jwt_source {
layer::JwtSource::AuthorizationHeader => {
let bearer_o: Option<Authorization<Bearer>> = h.typed_get();
bearer_o.map(|b| String::from(b.0.token()))
}
layer::JwtSource::Cookie(name) => h
.typed_get::<headers::Cookie>()
.and_then(|c| c.get(name.as_str()).map(String::from)),
}
}
}
#[cfg(test)]
@ -212,14 +239,20 @@ mod tests {
use jsonwebtoken::{Algorithm, Header};
use serde_json::Value;
use crate::validation::Validation;
use crate::{layer::JwtSource, validation::Validation};
use super::{Authorizer, KeySourceType};
#[tokio::test]
async fn build_from_secret() {
let h = Header::new(Algorithm::HS256);
let a = Authorizer::<Value>::build(KeySourceType::Secret("xxxxxx".to_owned()), None, None, Validation::new())
let a = Authorizer::<Value>::build(
KeySourceType::Secret("xxxxxx".to_owned()),
None,
None,
Validation::new(),
JwtSource::AuthorizationHeader,
)
.await
.unwrap();
let k = a.key_source.get_key(h);
@ -238,7 +271,13 @@ mod tests {
"e": "AQAB"
}]}
"#;
let a = Authorizer::<Value>::build(KeySourceType::JwksString(jwks.to_owned()), None, None, Validation::new())
let a = Authorizer::<Value>::build(
KeySourceType::JwksString(jwks.to_owned()),
None,
None,
Validation::new(),
JwtSource::AuthorizationHeader,
)
.await
.unwrap();
let k = a.key_source.get_key(Header::new(Algorithm::RS256));
@ -252,6 +291,7 @@ mod tests {
None,
None,
Validation::new(),
JwtSource::AuthorizationHeader,
)
.await
.unwrap();
@ -263,6 +303,7 @@ mod tests {
None,
None,
Validation::new(),
JwtSource::AuthorizationHeader,
)
.await
.unwrap();
@ -274,6 +315,7 @@ mod tests {
None,
None,
Validation::new(),
JwtSource::AuthorizationHeader,
)
.await
.unwrap();
@ -288,6 +330,7 @@ mod tests {
None,
None,
Validation::new(),
JwtSource::AuthorizationHeader,
)
.await
.unwrap();
@ -299,6 +342,7 @@ mod tests {
None,
None,
Validation::new(),
JwtSource::AuthorizationHeader,
)
.await
.unwrap();
@ -310,6 +354,7 @@ mod tests {
None,
None,
Validation::new(),
JwtSource::AuthorizationHeader,
)
.await
.unwrap();
@ -324,6 +369,7 @@ mod tests {
None,
None,
Validation::new(),
JwtSource::AuthorizationHeader,
)
.await;
println!("{:?}", a.as_ref().err());
@ -332,8 +378,14 @@ mod tests {
#[tokio::test]
async fn build_jwks_url_error() {
let a =
Authorizer::<Value>::build(KeySourceType::Jwks("://xxxx".to_owned()), None, None, Validation::default()).await;
let a = Authorizer::<Value>::build(
KeySourceType::Jwks("://xxxx".to_owned()),
None,
None,
Validation::default(),
JwtSource::AuthorizationHeader,
)
.await;
println!("{:?}", a.as_ref().err());
assert!(a.is_err());
}
@ -345,6 +397,7 @@ mod tests {
None,
None,
Validation::default(),
JwtSource::AuthorizationHeader,
)
.await;
println!("{:?}", a.as_ref().err());

View file

@ -1,10 +1,8 @@
use axum::async_trait;
use axum::http::Request;
use futures_core::ready;
use futures_util::future::BoxFuture;
use futures_util::future::{self, BoxFuture};
use futures_util::stream::{FuturesUnordered, StreamExt};
use headers::authorization::Bearer;
use headers::{Authorization, HeaderMapExt};
use jsonwebtoken::TokenData;
use pin_project::pin_project;
use serde::de::DeserializeOwned;
@ -20,7 +18,7 @@ use crate::claims::RegisteredClaims;
use crate::error::InitError;
use crate::jwks::key_store_manager::Refresh;
use crate::validation::Validation;
use crate::{layer, AuthError, RefreshStrategy};
use crate::{AuthError, RefreshStrategy};
/// Authorizer Layer builder
///
@ -189,8 +187,10 @@ where
#[deprecated(since = "0.10.0", note = "please use `to_layer()` instead")]
pub async fn layer(self) -> Result<AsyncAuthorizationLayer<C>, InitError> {
let val = self.validation.unwrap_or_default();
let auth = Arc::new(Authorizer::build(self.key_source_type, self.claims_checker, self.refresh, val).await?);
Ok(AsyncAuthorizationLayer::new(vec![auth], self.jwt_source))
let auth = Arc::new(
Authorizer::build(self.key_source_type, self.claims_checker, self.refresh, val, self.jwt_source).await?,
);
Ok(AsyncAuthorizationLayer::new(vec![auth]))
}
}
@ -201,8 +201,10 @@ where
{
async fn to_layer(self) -> Result<AsyncAuthorizationLayer<C>, InitError> {
let val = self.validation.unwrap_or_default();
let auth = Arc::new(Authorizer::build(self.key_source_type, self.claims_checker, self.refresh, val).await?);
Ok(AsyncAuthorizationLayer::new(vec![auth], self.jwt_source))
let auth = Arc::new(
Authorizer::build(self.key_source_type, self.claims_checker, self.refresh, val, self.jwt_source).await?,
);
Ok(AsyncAuthorizationLayer::new(vec![auth]))
}
}
@ -222,6 +224,7 @@ where
a.claims_checker,
a.refresh,
a.validation.unwrap_or_default(),
a.jwt_source,
)
})
.collect();
@ -237,8 +240,7 @@ where
// TODO: composite build error (containing all errors)
Err(e)
} else {
// TODO: jwt_source per Authorizer
Ok(AsyncAuthorizationLayer::new(auths, JwtSource::AuthorizationHeader))
Ok(AsyncAuthorizationLayer::new(auths))
}
}
}
@ -263,29 +265,26 @@ where
type Future = BoxFuture<'static, Result<Request<B>, AuthError>>;
fn authorize(&self, mut request: Request<B>) -> Self::Future {
// TODO: extract token per authorizer (jwt_source shloud be per authorizer)
let h = request.headers();
let token_o = match &self.jwt_source {
layer::JwtSource::AuthorizationHeader => {
let bearer_o: Option<Authorization<Bearer>> = h.typed_get();
bearer_o.map(|b| String::from(b.0.token()))
}
layer::JwtSource::Cookie(name) => h
.typed_get::<headers::Cookie>()
.and_then(|c| c.get(name.as_str()).map(String::from)),
};
let tkns_auths: Vec<(Option<String>, Arc<Authorizer<C>>)> = self
.auths
.iter()
.map(|a| (a.extract_token(request.headers()), a.clone()))
.collect();
let authorizers: Vec<Arc<Authorizer<C>>> = self.auths.iter().cloned().collect();
if !tkns_auths.iter().any(|(t, _)| t.is_some()) {
return Box::pin(future::ready(Err(AuthError::MissingToken())));
}
Box::pin(async move {
if let Some(token) = token_o {
let mut token_data: Result<TokenData<C>, AuthError> = Err(AuthError::NoAuthorizer());
for auth in authorizers {
for (tkn, auth) in tkns_auths {
if let Some(token) = tkn {
token_data = auth.check_auth(token.as_str()).await;
if token_data.is_ok() {
break;
}
}
}
match token_data {
Ok(tdata) => {
// Set `token_data` as a request extension so it can be accessed by other
@ -296,9 +295,6 @@ where
}
Err(err) => Err(err), // TODO: error containing all errors (not just the last one)
}
} else {
Err(AuthError::MissingToken())
}
})
}
}
@ -311,15 +307,14 @@ where
C: Clone + DeserializeOwned + Send,
{
auths: Vec<Arc<Authorizer<C>>>,
jwt_source: JwtSource,
}
impl<C> AsyncAuthorizationLayer<C>
where
C: Clone + DeserializeOwned + Send,
{
pub fn new(auths: Vec<Arc<Authorizer<C>>>, jwt_source: JwtSource) -> AsyncAuthorizationLayer<C> {
Self { auths, jwt_source }
pub fn new(auths: Vec<Arc<Authorizer<C>>>) -> AsyncAuthorizationLayer<C> {
Self { auths }
}
}
@ -330,7 +325,7 @@ where
type Service = AsyncAuthorizationService<S, C>;
fn layer(&self, inner: S) -> Self::Service {
AsyncAuthorizationService::new(inner, self.auths.clone(), self.jwt_source.clone())
AsyncAuthorizationService::new(inner, self.auths.clone())
}
}
@ -366,7 +361,6 @@ where
{
pub inner: S,
pub auths: Vec<Arc<Authorizer<C>>>,
pub jwt_source: JwtSource,
}
impl<S, C> AsyncAuthorizationService<S, C>
@ -395,12 +389,8 @@ where
/// Authorize requests using a custom scheme.
///
/// The `Authorization` header is required to have the value provided.
pub fn new(inner: S, auths: Vec<Arc<Authorizer<C>>>, jwt_source: JwtSource) -> AsyncAuthorizationService<S, C> {
Self {
inner,
auths,
jwt_source,
}
pub fn new(inner: S, auths: Vec<Arc<Authorizer<C>>>) -> AsyncAuthorizationService<S, C> {
Self { inner, auths }
}
}