refactor: JwtAuthorizer::IntoLayer -> Authorizer::IntoLayer

- better error management (avoids composite errors when transforming multiple builder into layer)
This commit is contained in:
cduvray 2023-08-14 11:26:49 +02:00
parent 3d5367da88
commit e815d35a55
9 changed files with 140 additions and 117 deletions

View file

@ -10,7 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## 0.11 (2023-xx-xx) ## 0.11 (2023-xx-xx)
- support for multiple authorizers - support for multiple authorizers
- JwtAuthorizer.layer() deprecated in favor of JwtAuthorizer.into_layer() - JwtAuthorizer::layer() deprecated in favor of JwtAuthorizer::build() and IntoLayer::into_layer()
## 0.10.1 (2023-07-11) ## 0.10.1 (2023-07-11)

View file

@ -1,5 +1,7 @@
use axum::{routing::get, Router}; use axum::{routing::get, Router};
use jwt_authorizer::{error::InitError, AuthError, IntoLayer, JwtAuthorizer, JwtClaims, Refresh, RefreshStrategy}; use jwt_authorizer::{
error::InitError, AuthError, Authorizer, IntoLayer, JwtAuthorizer, JwtClaims, Refresh, RefreshStrategy,
};
use serde::Deserialize; use serde::Deserialize;
use std::net::SocketAddr; use std::net::SocketAddr;
use tower_http::trace::TraceLayer; use tower_http::trace::TraceLayer;
@ -37,19 +39,21 @@ async fn main() -> Result<(), InitError> {
// First let's create an authorizer builder from a Oidc Discovery // First let's create an authorizer builder from a Oidc Discovery
// User is a struct deserializable from JWT claims representing the authorized user // User is a struct deserializable from JWT claims representing the authorized user
// let jwt_auth: JwtAuthorizer<User> = JwtAuthorizer::from_oidc("https://accounts.google.com/") // let jwt_auth: JwtAuthorizer<User> = JwtAuthorizer::from_oidc("https://accounts.google.com/")
let jwt_auth: JwtAuthorizer<User> = JwtAuthorizer::from_oidc(issuer_uri) let auth: Authorizer<User> = JwtAuthorizer::from_oidc(issuer_uri)
// .no_refresh() // .no_refresh()
.refresh(Refresh { .refresh(Refresh {
strategy: RefreshStrategy::Interval, strategy: RefreshStrategy::Interval,
..Default::default() ..Default::default()
}) })
.check(claim_checker); .check(claim_checker)
.build()
.await?;
// actual router demo // actual router demo
let api = Router::new() let api = Router::new()
.route("/protected", get(protected)) .route("/protected", get(protected))
// adding the authorizer layer // adding the authorizer layer
.layer(jwt_auth.into_layer().await?); .layer(auth.into_layer());
let app = Router::new() let app = Router::new()
// public endpoint // public endpoint

View file

@ -21,7 +21,7 @@ JWT authoriser Layer for Axum and Tonic.
## Usage Example ## Usage Example
```rust ```rust
# use jwt_authorizer::{AuthError, IntoLayer, JwtAuthorizer, JwtClaims, RegisteredClaims}; # use jwt_authorizer::{AuthError, Authorizer, JwtAuthorizer, JwtClaims, RegisteredClaims, IntoLayer};
# use axum::{routing::get, Router}; # use axum::{routing::get, Router};
# use serde::Deserialize; # use serde::Deserialize;
@ -29,12 +29,12 @@ JWT authoriser Layer for Axum and Tonic.
// let's create an authorizer builder from a JWKS Endpoint // let's create an authorizer builder from a JWKS Endpoint
// (a serializable struct can be used to represent jwt claims, JwtAuthorizer<RegisteredClaims> is the default) // (a serializable struct can be used to represent jwt claims, JwtAuthorizer<RegisteredClaims> is the default)
let jwt_auth: JwtAuthorizer = let auth: Authorizer =
JwtAuthorizer::from_jwks_url("http://localhost:3000/oidc/jwks"); JwtAuthorizer::from_jwks_url("http://localhost:3000/oidc/jwks").build().await.unwrap();
// adding the authorization layer // adding the authorization layer
let app = Router::new().route("/protected", get(protected)) let app = Router::new().route("/protected", get(protected))
.layer(jwt_auth.into_layer().await.unwrap()); .layer(auth.into_layer());
// proteced handler with user injection (mapping some jwt claims) // proteced handler with user injection (mapping some jwt claims)
async fn protected(JwtClaims(user): JwtClaims<RegisteredClaims>) -> Result<String, AuthError> { async fn protected(JwtClaims(user): JwtClaims<RegisteredClaims>) -> Result<String, AuthError> {

View file

@ -9,8 +9,8 @@ use serde::de::DeserializeOwned;
use crate::{ use crate::{
error::{AuthError, InitError}, error::{AuthError, InitError},
jwks::{key_store_manager::KeyStoreManager, KeyData, KeySource}, jwks::{key_store_manager::KeyStoreManager, KeyData, KeySource},
layer::{self, JwtSource}, layer::{self, AsyncAuthorizationLayer, JwtSource},
oidc, Refresh, oidc, Refresh, RegisteredClaims,
}; };
pub trait ClaimsChecker<C> { pub trait ClaimsChecker<C> {
@ -34,7 +34,7 @@ where
} }
} }
pub struct Authorizer<C> pub struct Authorizer<C = RegisteredClaims>
where where
C: Clone, C: Clone,
{ {
@ -233,6 +233,40 @@ where
} }
} }
pub trait IntoLayer<C>
where
C: Clone + DeserializeOwned + Send,
{
fn into_layer(self) -> AsyncAuthorizationLayer<C>;
}
impl<C> IntoLayer<C> for Vec<Authorizer<C>>
where
C: Clone + DeserializeOwned + Send,
{
fn into_layer(self) -> AsyncAuthorizationLayer<C> {
AsyncAuthorizationLayer::new(self.into_iter().map(Arc::new).collect())
}
}
impl<C, const N: usize> IntoLayer<C> for [Authorizer<C>; N]
where
C: Clone + DeserializeOwned + Send,
{
fn into_layer(self) -> AsyncAuthorizationLayer<C> {
AsyncAuthorizationLayer::new(self.into_iter().map(Arc::new).collect())
}
}
impl<C> IntoLayer<C> for Authorizer<C>
where
C: Clone + DeserializeOwned + Send,
{
fn into_layer(self) -> AsyncAuthorizationLayer<C> {
AsyncAuthorizationLayer::new(vec![Arc::new(self)])
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {

View file

@ -1,8 +1,6 @@
use axum::async_trait;
use axum::http::Request; use axum::http::Request;
use futures_core::ready; use futures_core::ready;
use futures_util::future::{self, BoxFuture}; use futures_util::future::{self, BoxFuture};
use futures_util::stream::{FuturesUnordered, StreamExt};
use jsonwebtoken::TokenData; use jsonwebtoken::TokenData;
use pin_project::pin_project; use pin_project::pin_project;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
@ -184,7 +182,7 @@ where
} }
/// Build axum layer /// Build axum layer
#[deprecated(since = "0.10.0", note = "please use `to_layer()` instead")] #[deprecated(since = "0.10.0", note = "please use `IntoLayer::into_layer()` instead")]
pub async fn layer(self) -> Result<AsyncAuthorizationLayer<C>, InitError> { pub async fn layer(self) -> Result<AsyncAuthorizationLayer<C>, InitError> {
let val = self.validation.unwrap_or_default(); let val = self.validation.unwrap_or_default();
let auth = Arc::new( let auth = Arc::new(
@ -192,57 +190,11 @@ where
); );
Ok(AsyncAuthorizationLayer::new(vec![auth])) Ok(AsyncAuthorizationLayer::new(vec![auth]))
} }
}
#[async_trait] pub async fn build(self) -> Result<Authorizer<C>, InitError> {
impl<C> IntoLayer<C> for JwtAuthorizer<C>
where
C: Clone + DeserializeOwned + Send + Sync,
{
async fn into_layer(self) -> Result<AsyncAuthorizationLayer<C>, InitError> {
let val = self.validation.unwrap_or_default(); let val = self.validation.unwrap_or_default();
let auth = Arc::new(
Authorizer::build(self.key_source_type, self.claims_checker, self.refresh, val, self.jwt_source).await?,
);
Ok(AsyncAuthorizationLayer::new(vec![auth]))
}
}
#[async_trait] Authorizer::build(self.key_source_type, self.claims_checker, self.refresh, val, self.jwt_source).await
impl<C, T> IntoLayer<C> for T
where
T: IntoIterator<Item = JwtAuthorizer<C>> + Send + Sync,
C: Clone + DeserializeOwned + Send + Sync,
{
async fn into_layer(self) -> Result<AsyncAuthorizationLayer<C>, InitError> {
let mut errs = Vec::<InitError>::new();
let mut auths = Vec::<Arc<Authorizer<C>>>::new();
let mut auths_futs: FuturesUnordered<_> = self
.into_iter()
.map(|a| {
Authorizer::build(
a.key_source_type,
a.claims_checker,
a.refresh,
a.validation.unwrap_or_default(),
a.jwt_source,
)
})
.collect();
while let Some(a) = auths_futs.next().await {
match a {
Ok(res) => auths.push(Arc::new(res)),
Err(err) => errs.push(err),
}
}
if let Some(e) = errs.into_iter().next() {
// TODO: composite build error (containing all errors)
Err(e)
} else {
Ok(AsyncAuthorizationLayer::new(auths))
}
} }
} }
@ -330,14 +282,6 @@ where
} }
} }
#[async_trait]
pub trait IntoLayer<C>
where
C: Clone + DeserializeOwned + Send,
{
async fn into_layer(self) -> Result<AsyncAuthorizationLayer<C>, InitError>;
}
// ---------- AsyncAuthorizationService -------- // ---------- AsyncAuthorizationService --------
/// Source of the bearer token /// Source of the bearer token
@ -486,33 +430,40 @@ where
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::{IntoLayer, JwtAuthorizer}; use crate::{authorizer::Authorizer, IntoLayer, JwtAuthorizer, RegisteredClaims};
use super::AsyncAuthorizationLayer;
#[tokio::test]
async fn auth_into_layer() {
let auth1: Authorizer = JwtAuthorizer::from_secret("aaa").build().await.unwrap();
let layer = auth1.into_layer();
assert_eq!(1, layer.auths.len());
}
#[tokio::test]
async fn auths_into_layer() {
let auth1 = JwtAuthorizer::from_secret("aaa").build().await.unwrap();
let auth2 = JwtAuthorizer::from_secret("bbb").build().await.unwrap();
let layer: AsyncAuthorizationLayer<RegisteredClaims> = [auth1, auth2].into_layer();
assert_eq!(2, layer.auths.len());
}
#[tokio::test]
async fn vec_auths_into_layer() {
let auth1 = JwtAuthorizer::from_secret("aaa").build().await.unwrap();
let auth2 = JwtAuthorizer::from_secret("bbb").build().await.unwrap();
let layer: AsyncAuthorizationLayer<RegisteredClaims> = vec![auth1, auth2].into_layer();
assert_eq!(2, layer.auths.len());
}
#[tokio::test] #[tokio::test]
async fn jwt_auth_to_layer() { async fn jwt_auth_to_layer() {
let auth1: JwtAuthorizer = JwtAuthorizer::from_secret("aaa"); let auth1: JwtAuthorizer = JwtAuthorizer::from_secret("aaa");
let layer = auth1.into_layer().await; #[allow(deprecated)]
let layer = auth1.layer().await;
assert!(layer.is_ok()); assert!(layer.is_ok());
} }
#[tokio::test]
async fn vec_to_layer() {
let auth1: JwtAuthorizer = JwtAuthorizer::from_secret("aaa");
let auth2: JwtAuthorizer = JwtAuthorizer::from_secret("bbb");
let av = vec![auth1, auth2];
let layer = av.into_layer().await;
assert!(layer.is_ok());
}
#[tokio::test]
async fn vec_to_layer_errors() {
let auth1: JwtAuthorizer = JwtAuthorizer::from_ec_pem("aaa");
let auth2: JwtAuthorizer = JwtAuthorizer::from_ed_pem("bbb");
let av = vec![auth1, auth2];
let layer = av.into_layer().await;
assert!(layer.is_err());
if let Err(err) = layer {
assert_eq!(err.to_string(), "No such file or directory (os error 2)");
}
}
} }

View file

@ -6,9 +6,10 @@ use jsonwebtoken::TokenData;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
pub use self::error::AuthError; pub use self::error::AuthError;
pub use authorizer::{Authorizer, IntoLayer};
pub use claims::{NumericDate, OneOrArray, RegisteredClaims}; pub use claims::{NumericDate, OneOrArray, RegisteredClaims};
pub use jwks::key_store_manager::{Refresh, RefreshStrategy}; pub use jwks::key_store_manager::{Refresh, RefreshStrategy};
pub use layer::{IntoLayer, JwtAuthorizer}; pub use layer::JwtAuthorizer;
pub use validation::Validation; pub use validation::Validation;
pub mod authorizer; pub mod authorizer;

View file

@ -104,7 +104,7 @@ async fn app(jwt_auth: JwtAuthorizer<User>) -> Router {
let protected_route: Router = Router::new() let protected_route: Router = Router::new()
.route("/protected", get(protected_handler)) .route("/protected", get(protected_handler))
.route("/protected-with-user", get(protected_with_user)) .route("/protected-with-user", get(protected_with_user))
.layer(jwt_auth.into_layer().await.unwrap()); .layer(jwt_auth.build().await.unwrap().into_layer());
Router::new().merge(pub_route).merge(protected_route) Router::new().merge(pub_route).merge(protected_route)
} }

View file

@ -12,7 +12,12 @@ mod tests {
BoxError, Router, BoxError, Router,
}; };
use http::{header, HeaderValue}; use http::{header, HeaderValue};
use jwt_authorizer::{layer::JwtSource, validation::Validation, IntoLayer, JwtAuthorizer, JwtClaims}; use jwt_authorizer::{
authorizer::Authorizer,
layer::{AsyncAuthorizationLayer, JwtSource},
validation::Validation,
IntoLayer, JwtAuthorizer, JwtClaims,
};
use serde::Deserialize; use serde::Deserialize;
use tower::{util::MapErrLayer, ServiceExt}; use tower::{util::MapErrLayer, ServiceExt};
@ -23,7 +28,7 @@ mod tests {
sub: String, sub: String,
} }
async fn app(jwt_auth: impl IntoLayer<User>) -> Router { async fn app(layer: AsyncAuthorizationLayer<User>) -> Router {
Router::new().route("/public", get(|| async { "hello" })).route( Router::new().route("/public", get(|| async { "hello" })).route(
"/protected", "/protected",
get(|JwtClaims(user): JwtClaims<User>| async move { format!("hello: {}", user.sub) }).layer( get(|JwtClaims(user): JwtClaims<User>| async move { format!("hello: {}", user.sub) }).layer(
@ -32,18 +37,22 @@ mod tests {
tower::buffer::BufferLayer::new(1), tower::buffer::BufferLayer::new(1),
MapErrLayer::new(|e: BoxError| -> Infallible { panic!("{}", e) }), MapErrLayer::new(|e: BoxError| -> Infallible { panic!("{}", e) }),
), ),
jwt_auth.into_layer().await.unwrap(), layer,
), ),
), ),
) )
} }
async fn proteced_request_with_header( async fn proteced_request_with_header(jwt_auth: JwtAuthorizer<User>, header_name: &str, header_value: &str) -> Response {
jwt_auth: impl IntoLayer<User>, proteced_request_with_header_and_layer(jwt_auth.build().await.unwrap().into_layer(), header_name, header_value).await
}
async fn proteced_request_with_header_and_layer(
layer: AsyncAuthorizationLayer<User>,
header_name: &str, header_name: &str,
header_value: &str, header_value: &str,
) -> Response { ) -> Response {
app(jwt_auth) app(layer)
.await .await
.oneshot( .oneshot(
Request::builder() Request::builder()
@ -56,15 +65,18 @@ mod tests {
.unwrap() .unwrap()
} }
async fn make_proteced_request(jwt_auth: impl IntoLayer<User>, bearer: &str) -> Response { async fn make_proteced_request(jwt_auth: JwtAuthorizer<User>, bearer: &str) -> Response {
proteced_request_with_header(jwt_auth, "Authorization", &format!("Bearer {bearer}")).await proteced_request_with_header(jwt_auth, "Authorization", &format!("Bearer {bearer}")).await
} }
#[tokio::test] #[tokio::test]
async fn protected_without_jwt() { async fn protected_without_jwt() {
let jwt_auth: JwtAuthorizer<User> = JwtAuthorizer::from_rsa_pem("../config/rsa-public1.pem"); let auth: Authorizer<User> = JwtAuthorizer::from_rsa_pem("../config/rsa-public1.pem")
.build()
.await
.unwrap();
let response = app(jwt_auth) let response = app(auth.into_layer())
.await .await
.oneshot(Request::builder().uri("/protected").body(Body::empty()).unwrap()) .oneshot(Request::builder().uri("/protected").body(Body::empty()).unwrap())
.await .await
@ -342,24 +354,45 @@ mod tests {
// -------------------------- // --------------------------
#[tokio::test] #[tokio::test]
async fn multiple_authorizers() { async fn multiple_authorizers() {
let auths: Vec<JwtAuthorizer<User>> = vec![ let auths: Vec<Authorizer<User>> = vec![
JwtAuthorizer::from_ec_pem("../config/ecdsa-public1.pem"), JwtAuthorizer::from_ec_pem("../config/ecdsa-public1.pem")
JwtAuthorizer::from_rsa_pem("../config/rsa-public1.pem").jwt_source(JwtSource::Cookie("ccc".to_owned())), .build()
.await
.unwrap(),
JwtAuthorizer::from_rsa_pem("../config/rsa-public1.pem")
.jwt_source(JwtSource::Cookie("ccc".to_owned()))
.build()
.await
.unwrap(),
]; ];
// OK // OK
let response = let response = proteced_request_with_header_and_layer(
proteced_request_with_header(auths, header::COOKIE.as_str(), &format!("ccc={}", common::JWT_RSA1_OK)).await; auths.into_layer(),
header::COOKIE.as_str(),
&format!("ccc={}", common::JWT_RSA1_OK),
)
.await;
assert_eq!(response.status(), StatusCode::OK); assert_eq!(response.status(), StatusCode::OK);
let auths: [JwtAuthorizer<User>; 2] = [ let auths: [Authorizer<User>; 2] = [
JwtAuthorizer::from_ec_pem("../config/ecdsa-public1.pem"), JwtAuthorizer::from_ec_pem("../config/ecdsa-public1.pem")
JwtAuthorizer::from_rsa_pem("../config/rsa-public1.pem").jwt_source(JwtSource::Cookie("ccc".to_owned())), .build()
.await
.unwrap(),
JwtAuthorizer::from_rsa_pem("../config/rsa-public1.pem")
.jwt_source(JwtSource::Cookie("ccc".to_owned()))
.build()
.await
.unwrap(),
]; ];
// Cookie missing // Cookie missing
let response = let response = proteced_request_with_header_and_layer(
proteced_request_with_header(auths, header::COOKIE.as_str(), &format!("bad_cookie={}", common::JWT_EC2_OK)) auths.into_layer(),
header::COOKIE.as_str(),
&format!("bad_cookie={}", common::JWT_EC2_OK),
)
.await; .await;
assert_eq!(response.status(), StatusCode::UNAUTHORIZED); assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
assert_eq!(response.headers().get(header::WWW_AUTHENTICATE).unwrap(), &"Bearer"); assert_eq!(response.headers().get(header::WWW_AUTHENTICATE).unwrap(), &"Bearer");

View file

@ -83,7 +83,7 @@ async fn app(
jwt_auth: JwtAuthorizer<User>, jwt_auth: JwtAuthorizer<User>,
expected_sub: String, expected_sub: String,
) -> AsyncAuthorizationService<Buffer<tonic::transport::server::Routes, http::Request<tonic::transport::Body>>, User> { ) -> AsyncAuthorizationService<Buffer<tonic::transport::server::Routes, http::Request<tonic::transport::Body>>, User> {
let layer = jwt_auth.into_layer().await.unwrap(); let layer = jwt_auth.build().await.unwrap().into_layer();
tonic::transport::Server::builder() tonic::transport::Server::builder()
.layer(layer) .layer(layer)
.layer(tower::buffer::BufferLayer::new(1)) .layer(tower::buffer::BufferLayer::new(1))