fix: claims_checker

This commit is contained in:
cduvray 2023-01-20 08:05:42 +01:00
parent 9f459fb362
commit 6535408979
3 changed files with 31 additions and 9 deletions

View file

@ -66,7 +66,7 @@ where
}) })
} }
pub fn from(key_source_type: &KeySourceType) -> Result<Authorizer<C>, InitError> { pub fn from(key_source_type: &KeySourceType, claims_checker: Option<FnClaimsChecker<C>>) -> Result<Authorizer<C>, InitError> {
let key = match key_source_type { let key = match key_source_type {
KeySourceType::RSA(path) => DecodingKey::from_rsa_pem(&read_data(path.as_str())?)?, KeySourceType::RSA(path) => DecodingKey::from_rsa_pem(&read_data(path.as_str())?)?,
KeySourceType::EC(path) => DecodingKey::from_ec_der(&read_data(path.as_str())?), KeySourceType::EC(path) => DecodingKey::from_ec_der(&read_data(path.as_str())?),
@ -77,7 +77,7 @@ where
Ok(Authorizer { Ok(Authorizer {
key_source: KeySource::DecodingKeySource(key), key_source: KeySource::DecodingKeySource(key),
claims_checker: None, claims_checker,
}) })
} }
@ -116,7 +116,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn from_secret() { async fn from_secret() {
let h = Header::new(Algorithm::HS256); let h = Header::new(Algorithm::HS256);
let a = Authorizer::<Value>::from(&KeySourceType::Secret("xxxxxx")).unwrap(); let a = Authorizer::<Value>::from(&KeySourceType::Secret("xxxxxx"), None).unwrap();
let k = a.key_source.get_key(h); let k = a.key_source.get_key(h);
assert!(k.await.is_ok()); assert!(k.await.is_ok());
} }
@ -140,22 +140,22 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn from_file() { async fn from_file() {
let a = Authorizer::<Value>::from(&KeySourceType::RSA("../config/jwtRS256.key.pub".to_owned())).unwrap(); let a = Authorizer::<Value>::from(&KeySourceType::RSA("../config/jwtRS256.key.pub".to_owned()), None).unwrap();
let k = a.key_source.get_key(Header::new(Algorithm::RS256)); let k = a.key_source.get_key(Header::new(Algorithm::RS256));
assert!(k.await.is_ok()); assert!(k.await.is_ok());
let a = Authorizer::<Value>::from(&KeySourceType::EC("../config/ec256-public.pem".to_owned())).unwrap(); let a = Authorizer::<Value>::from(&KeySourceType::EC("../config/ec256-public.pem".to_owned()), None).unwrap();
let k = a.key_source.get_key(Header::new(Algorithm::ES256)); let k = a.key_source.get_key(Header::new(Algorithm::ES256));
assert!(k.await.is_ok()); assert!(k.await.is_ok());
let a = Authorizer::<Value>::from(&KeySourceType::ED("../config/ed25519-public.pem".to_owned())).unwrap(); let a = Authorizer::<Value>::from(&KeySourceType::ED("../config/ed25519-public.pem".to_owned()), None).unwrap();
let k = a.key_source.get_key(Header::new(Algorithm::EdDSA)); let k = a.key_source.get_key(Header::new(Algorithm::EdDSA));
assert!(k.await.is_ok()); assert!(k.await.is_ok());
} }
#[tokio::test] #[tokio::test]
async fn from_file_errors() { async fn from_file_errors() {
let a = Authorizer::<Value>::from(&KeySourceType::RSA("./config/does-not-exist.pem".to_owned())); let a = Authorizer::<Value>::from(&KeySourceType::RSA("./config/does-not-exist.pem".to_owned()), None);
println!("{:?}", a.as_ref().err()); println!("{:?}", a.as_ref().err());
assert!(a.is_err()); assert!(a.is_err());
} }

View file

@ -89,7 +89,7 @@ where
let auth = if let Some(ref key_source_type) = self.key_source_type { let auth = if let Some(ref key_source_type) = self.key_source_type {
match key_source_type { match key_source_type {
KeySourceType::RSA(_) | KeySourceType::EC(_) | KeySourceType::ED(_) | KeySourceType::Secret(_) => { KeySourceType::RSA(_) | KeySourceType::EC(_) | KeySourceType::ED(_) | KeySourceType::Secret(_) => {
Arc::new(Authorizer::from(key_source_type)?) Arc::new(Authorizer::from(key_source_type, self.claims_checker.clone())?)
} }
KeySourceType::Jwks(url) => { KeySourceType::Jwks(url) => {
Arc::new(Authorizer::from_jwks_url(url.as_str(), self.claims_checker.clone())?) Arc::new(Authorizer::from_jwks_url(url.as_str(), self.claims_checker.clone())?)

View file

@ -6,7 +6,7 @@ mod tests {
http::{Request, StatusCode}, http::{Request, StatusCode},
routing::get, Router, response::Response, routing::get, Router, response::Response,
}; };
use http::header; use http::{header, HeaderValue};
use serde::Deserialize; use serde::Deserialize;
use tower::ServiceExt; use tower::ServiceExt;
@ -80,6 +80,28 @@ mod tests {
// TODO: check error code (https://datatracker.ietf.org/doc/html/rfc6750#section-3.1) // TODO: check error code (https://datatracker.ietf.org/doc/html/rfc6750#section-3.1)
} }
#[tokio::test]
async fn protected_with_claims_check() {
let rsp_ok = make_proteced_request(
JwtAuthorizer::new().from_rsa_pem("../config/jwtRS256.key.pub").with_check(|_|true),
JWT_RSA_OK
).await;
assert_eq!(rsp_ok.status(), StatusCode::OK);
let rsp_ko = make_proteced_request(
JwtAuthorizer::new().from_rsa_pem("../config/jwtRS256.key.pub").with_check(|_|false),
JWT_RSA_OK
).await;
assert_eq!(rsp_ko.status(), StatusCode::UNAUTHORIZED);
let h = rsp_ko.headers().get(http::header::WWW_AUTHENTICATE);
assert!(h.is_some(), "WWW-AUTHENTICATE header missing!");
assert_eq!(h.unwrap(), HeaderValue::from_static("Bearer error=\"insufficient_scope\""), "Bad WWW-AUTHENTICATE header!");
}
// Unreachable jwks endpoint, should build (endpoint can comme on line later ), // Unreachable jwks endpoint, should build (endpoint can comme on line later ),
// but should be 500 when checking. // but should be 500 when checking.
#[tokio::test] #[tokio::test]