feat: ToAuthorizationLayer

This commit is contained in:
cduvray 2023-08-14 08:02:56 +02:00
parent 57fbc6e399
commit d7d945c075
5 changed files with 124 additions and 28 deletions

View file

@ -1,5 +1,7 @@
use axum::{routing::get, Router}; use axum::{routing::get, Router};
use jwt_authorizer::{error::InitError, AuthError, JwtAuthorizer, JwtClaims, Refresh, RefreshStrategy}; use jwt_authorizer::{
error::InitError, AuthError, JwtAuthorizer, JwtClaims, Refresh, RefreshStrategy, ToAuthorizationLayer,
};
use serde::Deserialize; use serde::Deserialize;
use std::net::SocketAddr; use std::net::SocketAddr;
use tower_http::trace::TraceLayer; use tower_http::trace::TraceLayer;
@ -49,7 +51,7 @@ async fn main() -> Result<(), InitError> {
let api = Router::new() let api = Router::new()
.route("/protected", get(protected)) .route("/protected", get(protected))
// adding the authorizer layer // adding the authorizer layer
.layer(jwt_auth.layer().await?); .layer(jwt_auth.to_layer().await?);
let app = Router::new() let app = Router::new()
// public endpoint // public endpoint

View file

@ -65,7 +65,7 @@ where
C: DeserializeOwned + Clone + Send + Sync, C: DeserializeOwned + Clone + Send + Sync,
{ {
pub(crate) async fn build( pub(crate) async fn build(
key_source_type: &KeySourceType, key_source_type: KeySourceType,
claims_checker: Option<FnClaimsChecker<C>>, claims_checker: Option<FnClaimsChecker<C>>,
refresh: Option<Refresh>, refresh: Option<Refresh>,
validation: crate::validation::Validation, validation: crate::validation::Validation,
@ -157,7 +157,7 @@ where
} }
KeySourceType::JwksString(jwks_str) => { KeySourceType::JwksString(jwks_str) => {
// TODO: expose it in JwtAuthorizer or remove // TODO: expose it in JwtAuthorizer or remove
let set: JwkSet = serde_json::from_str(jwks_str)?; let set: JwkSet = serde_json::from_str(jwks_str.as_str())?;
// TODO: replace [0] by kid/alg search // TODO: replace [0] by kid/alg search
let k = KeyData::from_jwk(&set.keys[0]).map_err(InitError::KeyDecodingError)?; let k = KeyData::from_jwk(&set.keys[0]).map_err(InitError::KeyDecodingError)?;
Authorizer { Authorizer {
@ -167,7 +167,7 @@ where
} }
} }
KeySourceType::Jwks(url) => { KeySourceType::Jwks(url) => {
let jwks_url = Url::parse(url).map_err(|e| InitError::JwksUrlError(e.to_string()))?; let jwks_url = Url::parse(url.as_str()).map_err(|e| InitError::JwksUrlError(e.to_string()))?;
let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default()); let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default());
Authorizer { Authorizer {
key_source: KeySource::KeyStoreSource(key_store_manager), key_source: KeySource::KeyStoreSource(key_store_manager),
@ -176,7 +176,7 @@ where
} }
} }
KeySourceType::Discovery(issuer_url) => { KeySourceType::Discovery(issuer_url) => {
let jwks_url = Url::parse(&oidc::discover_jwks(issuer_url).await?) let jwks_url = Url::parse(&oidc::discover_jwks(issuer_url.as_str()).await?)
.map_err(|e| InitError::JwksUrlError(e.to_string()))?; .map_err(|e| InitError::JwksUrlError(e.to_string()))?;
let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default()); let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default());
@ -219,7 +219,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn build_from_secret() { async fn build_from_secret() {
let h = Header::new(Algorithm::HS256); let h = Header::new(Algorithm::HS256);
let a = Authorizer::<Value>::build(&KeySourceType::Secret("xxxxxx".to_owned()), None, None, Validation::new()) let a = Authorizer::<Value>::build(KeySourceType::Secret("xxxxxx".to_owned()), None, None, Validation::new())
.await .await
.unwrap(); .unwrap();
let k = a.key_source.get_key(h); let k = a.key_source.get_key(h);
@ -238,7 +238,7 @@ mod tests {
"e": "AQAB" "e": "AQAB"
}]} }]}
"#; "#;
let a = Authorizer::<Value>::build(&KeySourceType::JwksString(jwks.to_owned()), None, None, Validation::new()) let a = Authorizer::<Value>::build(KeySourceType::JwksString(jwks.to_owned()), None, None, Validation::new())
.await .await
.unwrap(); .unwrap();
let k = a.key_source.get_key(Header::new(Algorithm::RS256)); let k = a.key_source.get_key(Header::new(Algorithm::RS256));
@ -248,7 +248,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn build_from_file() { async fn build_from_file() {
let a = Authorizer::<Value>::build( let a = Authorizer::<Value>::build(
&KeySourceType::RSA("../config/rsa-public1.pem".to_owned()), KeySourceType::RSA("../config/rsa-public1.pem".to_owned()),
None, None,
None, None,
Validation::new(), Validation::new(),
@ -259,7 +259,7 @@ mod tests {
assert!(k.await.is_ok()); assert!(k.await.is_ok());
let a = Authorizer::<Value>::build( let a = Authorizer::<Value>::build(
&KeySourceType::EC("../config/ecdsa-public1.pem".to_owned()), KeySourceType::EC("../config/ecdsa-public1.pem".to_owned()),
None, None,
None, None,
Validation::new(), Validation::new(),
@ -270,7 +270,7 @@ mod tests {
assert!(k.await.is_ok()); assert!(k.await.is_ok());
let a = Authorizer::<Value>::build( let a = Authorizer::<Value>::build(
&KeySourceType::ED("../config/ed25519-public1.pem".to_owned()), KeySourceType::ED("../config/ed25519-public1.pem".to_owned()),
None, None,
None, None,
Validation::new(), Validation::new(),
@ -284,7 +284,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn build_from_text() { async fn build_from_text() {
let a = Authorizer::<Value>::build( let a = Authorizer::<Value>::build(
&KeySourceType::RSAString(include_str!("../../config/rsa-public1.pem").to_owned()), KeySourceType::RSAString(include_str!("../../config/rsa-public1.pem").to_owned()),
None, None,
None, None,
Validation::new(), Validation::new(),
@ -295,7 +295,7 @@ mod tests {
assert!(k.await.is_ok()); assert!(k.await.is_ok());
let a = Authorizer::<Value>::build( let a = Authorizer::<Value>::build(
&KeySourceType::ECString(include_str!("../../config/ecdsa-public1.pem").to_owned()), KeySourceType::ECString(include_str!("../../config/ecdsa-public1.pem").to_owned()),
None, None,
None, None,
Validation::new(), Validation::new(),
@ -306,7 +306,7 @@ mod tests {
assert!(k.await.is_ok()); assert!(k.await.is_ok());
let a = Authorizer::<Value>::build( let a = Authorizer::<Value>::build(
&KeySourceType::EDString(include_str!("../../config/ed25519-public1.pem").to_owned()), KeySourceType::EDString(include_str!("../../config/ed25519-public1.pem").to_owned()),
None, None,
None, None,
Validation::new(), Validation::new(),
@ -320,7 +320,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn build_file_errors() { async fn build_file_errors() {
let a = Authorizer::<Value>::build( let a = Authorizer::<Value>::build(
&KeySourceType::RSA("./config/does-not-exist.pem".to_owned()), KeySourceType::RSA("./config/does-not-exist.pem".to_owned()),
None, None,
None, None,
Validation::new(), Validation::new(),
@ -333,7 +333,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn build_jwks_url_error() { async fn build_jwks_url_error() {
let a = let a =
Authorizer::<Value>::build(&KeySourceType::Jwks("://xxxx".to_owned()), None, None, Validation::default()).await; Authorizer::<Value>::build(KeySourceType::Jwks("://xxxx".to_owned()), None, None, Validation::default()).await;
println!("{:?}", a.as_ref().err()); println!("{:?}", a.as_ref().err());
assert!(a.is_err()); assert!(a.is_err());
} }
@ -341,7 +341,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn build_discovery_url_error() { async fn build_discovery_url_error() {
let a = Authorizer::<Value>::build( let a = Authorizer::<Value>::build(
&KeySourceType::Discovery("://xxxx".to_owned()), KeySourceType::Discovery("://xxxx".to_owned()),
None, None,
None, None,
Validation::default(), Validation::default(),

View file

@ -1,6 +1,8 @@
use axum::async_trait;
use axum::http::Request; use axum::http::Request;
use futures_core::ready; use futures_core::ready;
use futures_util::future::BoxFuture; use futures_util::future::BoxFuture;
use futures_util::stream::{FuturesUnordered, StreamExt};
use headers::authorization::Bearer; use headers::authorization::Bearer;
use headers::{Authorization, HeaderMapExt}; use headers::{Authorization, HeaderMapExt};
use jsonwebtoken::TokenData; use jsonwebtoken::TokenData;
@ -184,12 +186,63 @@ where
} }
/// Build axum layer /// Build axum layer
#[deprecated(since = "0.10.0", note = "please use `to_layer()` instead")]
pub async fn layer(self) -> Result<AsyncAuthorizationLayer<C>, InitError> { pub async fn layer(self) -> Result<AsyncAuthorizationLayer<C>, InitError> {
let val = self.validation.unwrap_or_default(); let val = self.validation.unwrap_or_default();
let auth = Arc::new(Authorizer::build(&self.key_source_type, self.claims_checker, self.refresh, val).await?); let auth = Arc::new(Authorizer::build(self.key_source_type, self.claims_checker, self.refresh, val).await?);
Ok(AsyncAuthorizationLayer::new(auth, self.jwt_source)) Ok(AsyncAuthorizationLayer::new(vec![auth], self.jwt_source))
} }
} }
#[async_trait]
impl<C> ToAuthorizationLayer<C> for JwtAuthorizer<C>
where
C: Clone + DeserializeOwned + Send + Sync,
{
async fn to_layer(self) -> Result<AsyncAuthorizationLayer<C>, InitError> {
let val = self.validation.unwrap_or_default();
let auth = Arc::new(Authorizer::build(self.key_source_type, self.claims_checker, self.refresh, val).await?);
Ok(AsyncAuthorizationLayer::new(vec![auth], self.jwt_source))
}
}
#[async_trait]
impl<C> ToAuthorizationLayer<C> for Vec<JwtAuthorizer<C>>
where
C: Clone + DeserializeOwned + Send + Sync,
{
async fn to_layer(self) -> Result<AsyncAuthorizationLayer<C>, InitError> {
let mut errs = Vec::<InitError>::new();
let mut auths = Vec::<Arc<Authorizer<C>>>::new();
let mut auths_futs: FuturesUnordered<_> = self
.into_iter()
.map(|a| {
Authorizer::build(
a.key_source_type,
a.claims_checker,
a.refresh,
a.validation.unwrap_or_default(),
)
})
.collect();
while let Some(a) = auths_futs.next().await {
match a {
Ok(res) => auths.push(Arc::new(res)),
Err(err) => errs.push(err),
}
}
if let Some(e) = errs.into_iter().next() {
// TODO: composite build error (containing all errors)
Err(e)
} else {
// TODO: jwt_source per Authorizer
Ok(AsyncAuthorizationLayer::new(auths, JwtSource::AuthorizationHeader))
}
}
}
/// Trait for authorizing requests. /// Trait for authorizing requests.
pub trait AsyncAuthorizer<B> { pub trait AsyncAuthorizer<B> {
type RequestBody; type RequestBody;
@ -257,7 +310,7 @@ pub struct AsyncAuthorizationLayer<C>
where where
C: Clone + DeserializeOwned + Send, C: Clone + DeserializeOwned + Send,
{ {
auth: Arc<Authorizer<C>>, auths: Vec<Arc<Authorizer<C>>>,
jwt_source: JwtSource, jwt_source: JwtSource,
} }
@ -265,8 +318,8 @@ impl<C> AsyncAuthorizationLayer<C>
where where
C: Clone + DeserializeOwned + Send, C: Clone + DeserializeOwned + Send,
{ {
pub fn new(auth: Arc<Authorizer<C>>, jwt_source: JwtSource) -> AsyncAuthorizationLayer<C> { pub fn new(auths: Vec<Arc<Authorizer<C>>>, jwt_source: JwtSource) -> AsyncAuthorizationLayer<C> {
Self { auth, jwt_source } Self { auths, jwt_source }
} }
} }
@ -277,10 +330,18 @@ where
type Service = AsyncAuthorizationService<S, C>; type Service = AsyncAuthorizationService<S, C>;
fn layer(&self, inner: S) -> Self::Service { fn layer(&self, inner: S) -> Self::Service {
AsyncAuthorizationService::new(inner, self.auth.clone(), self.jwt_source.clone()) AsyncAuthorizationService::new(inner, self.auths.clone(), self.jwt_source.clone())
} }
} }
#[async_trait]
pub trait ToAuthorizationLayer<C>
where
C: Clone + DeserializeOwned + Send,
{
async fn to_layer(self) -> Result<AsyncAuthorizationLayer<C>, InitError>;
}
// ---------- AsyncAuthorizationService -------- // ---------- AsyncAuthorizationService --------
/// Source of the bearer token /// Source of the bearer token
@ -334,10 +395,10 @@ where
/// Authorize requests using a custom scheme. /// Authorize requests using a custom scheme.
/// ///
/// The `Authorization` header is required to have the value provided. /// The `Authorization` header is required to have the value provided.
pub fn new(inner: S, auth: Arc<Authorizer<C>>, jwt_source: JwtSource) -> AsyncAuthorizationService<S, C> { pub fn new(inner: S, auths: Vec<Arc<Authorizer<C>>>, jwt_source: JwtSource) -> AsyncAuthorizationService<S, C> {
Self { Self {
inner, inner,
auths: vec![auth], auths: auths,
jwt_source, jwt_source,
} }
} }
@ -431,3 +492,36 @@ where
} }
} }
} }
#[cfg(test)]
mod tests {
use crate::{JwtAuthorizer, ToAuthorizationLayer};
#[tokio::test]
async fn jwt_auth_to_layer() {
let auth1: JwtAuthorizer = JwtAuthorizer::from_secret("aaa");
let layer = auth1.to_layer().await;
assert!(layer.is_ok());
}
#[tokio::test]
async fn vec_to_layer() {
let auth1: JwtAuthorizer = JwtAuthorizer::from_secret("aaa");
let auth2: JwtAuthorizer = JwtAuthorizer::from_secret("bbb");
let av = vec![auth1, auth2];
let layer = av.to_layer().await;
assert!(layer.is_ok());
}
#[tokio::test]
async fn vec_to_layer_errors() {
let auth1: JwtAuthorizer = JwtAuthorizer::from_ec_pem("aaa");
let auth2: JwtAuthorizer = JwtAuthorizer::from_ed_pem("bbb");
let av = vec![auth1, auth2];
let layer = av.to_layer().await;
assert!(layer.is_err());
if let Err(err) = layer {
assert_eq!(err.to_string(), "No such file or directory (os error 2)");
}
}
}

View file

@ -8,7 +8,7 @@ use serde::de::DeserializeOwned;
pub use self::error::AuthError; pub use self::error::AuthError;
pub use claims::{NumericDate, OneOrArray, RegisteredClaims}; pub use claims::{NumericDate, OneOrArray, RegisteredClaims};
pub use jwks::key_store_manager::{Refresh, RefreshStrategy}; pub use jwks::key_store_manager::{Refresh, RefreshStrategy};
pub use layer::JwtAuthorizer; pub use layer::{JwtAuthorizer, ToAuthorizationLayer};
pub use validation::Validation; pub use validation::Validation;
pub mod authorizer; pub mod authorizer;

View file

@ -12,7 +12,7 @@ mod tests {
BoxError, Router, BoxError, Router,
}; };
use http::{header, HeaderValue}; use http::{header, HeaderValue};
use jwt_authorizer::{layer::JwtSource, validation::Validation, JwtAuthorizer, JwtClaims}; use jwt_authorizer::{layer::JwtSource, validation::Validation, JwtAuthorizer, JwtClaims, ToAuthorizationLayer};
use serde::Deserialize; use serde::Deserialize;
use tower::{util::MapErrLayer, ServiceExt}; use tower::{util::MapErrLayer, ServiceExt};
@ -32,7 +32,7 @@ mod tests {
tower::buffer::BufferLayer::new(1), tower::buffer::BufferLayer::new(1),
MapErrLayer::new(|e: BoxError| -> Infallible { panic!("{}", e) }), MapErrLayer::new(|e: BoxError| -> Infallible { panic!("{}", e) }),
), ),
jwt_auth.layer().await.unwrap(), jwt_auth.to_layer().await.unwrap(),
), ),
), ),
) )