refactor: ClaimCheckerFn (simplification)

This commit is contained in:
cduvray 2023-09-19 08:04:22 +02:00 committed by cduvray
parent b42aab8d31
commit 6e19f31c77
2 changed files with 7 additions and 28 deletions

View file

@ -13,33 +13,14 @@ use crate::{
oidc, Refresh, RegisteredClaims, oidc, Refresh, RegisteredClaims,
}; };
pub trait ClaimsChecker<C> { pub type ClaimsCheckerFn<C> = Arc<Box<dyn Fn(&C) -> bool + Send + Sync>>;
fn check(&self, claims: &C) -> bool;
}
#[derive(Clone)]
pub struct FnClaimsChecker<C>
where
C: Clone + Send + Sync,
{
pub checker_fn: Arc<Box<dyn Fn(&C) -> bool + Send + Sync>>,
}
impl<C> ClaimsChecker<C> for FnClaimsChecker<C>
where
C: Clone + Send + Sync,
{
fn check(&self, claims: &C) -> bool {
(self.checker_fn)(claims)
}
}
pub struct Authorizer<C = RegisteredClaims> pub struct Authorizer<C = RegisteredClaims>
where where
C: Clone + Send, C: Clone + Send,
{ {
pub key_source: KeySource, pub key_source: KeySource,
pub claims_checker: Option<FnClaimsChecker<C>>, pub claims_checker: Option<ClaimsCheckerFn<C>>,
pub validation: crate::validation::Validation, pub validation: crate::validation::Validation,
pub jwt_source: JwtSource, pub jwt_source: JwtSource,
} }
@ -70,7 +51,7 @@ where
{ {
pub(crate) async fn build( pub(crate) async fn build(
key_source_type: KeySourceType, key_source_type: KeySourceType,
claims_checker: Option<FnClaimsChecker<C>>, claims_checker: Option<ClaimsCheckerFn<C>>,
refresh: Option<Refresh>, refresh: Option<Refresh>,
validation: crate::validation::Validation, validation: crate::validation::Validation,
jwt_source: JwtSource, jwt_source: JwtSource,
@ -212,7 +193,7 @@ where
let token_data = decode::<C>(token, &val_key.key, jwt_validation)?; let token_data = decode::<C>(token, &val_key.key, jwt_validation)?;
if let Some(ref checker) = self.claims_checker { if let Some(ref checker) = self.claims_checker {
if !checker.check(&token_data.claims) { if !checker(&token_data.claims) {
return Err(AuthError::InvalidClaims()); return Err(AuthError::InvalidClaims());
} }
} }

View file

@ -3,7 +3,7 @@ use std::sync::Arc;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use crate::{ use crate::{
authorizer::{FnClaimsChecker, KeySourceType}, authorizer::{ClaimsCheckerFn, KeySourceType},
error::InitError, error::InitError,
layer::{AuthorizationLayer, JwtSource}, layer::{AuthorizationLayer, JwtSource},
Authorizer, Refresh, RefreshStrategy, RegisteredClaims, Validation, Authorizer, Refresh, RefreshStrategy, RegisteredClaims, Validation,
@ -19,7 +19,7 @@ where
{ {
key_source_type: KeySourceType, key_source_type: KeySourceType,
refresh: Option<Refresh>, refresh: Option<Refresh>,
claims_checker: Option<FnClaimsChecker<C>>, claims_checker: Option<ClaimsCheckerFn<C>>,
validation: Option<Validation>, validation: Option<Validation>,
jwt_source: JwtSource, jwt_source: JwtSource,
} }
@ -158,9 +158,7 @@ where
where where
F: Fn(&C) -> bool + Send + Sync + 'static, F: Fn(&C) -> bool + Send + Sync + 'static,
{ {
self.claims_checker = Some(FnClaimsChecker { self.claims_checker = Some(Arc::new(Box::new(checker_fn)));
checker_fn: Arc::new(Box::new(checker_fn)),
});
self self
} }