refactor: ClaimCheckerFn (simplification)

This commit is contained in:
cduvray 2023-09-19 08:04:22 +02:00 committed by cduvray
parent b42aab8d31
commit 6e19f31c77
2 changed files with 7 additions and 28 deletions

View file

@ -13,33 +13,14 @@ use crate::{
oidc, Refresh, RegisteredClaims,
};
pub trait ClaimsChecker<C> {
fn check(&self, claims: &C) -> bool;
}
#[derive(Clone)]
pub struct FnClaimsChecker<C>
where
C: Clone + Send + Sync,
{
pub checker_fn: Arc<Box<dyn Fn(&C) -> bool + Send + Sync>>,
}
impl<C> ClaimsChecker<C> for FnClaimsChecker<C>
where
C: Clone + Send + Sync,
{
fn check(&self, claims: &C) -> bool {
(self.checker_fn)(claims)
}
}
pub type ClaimsCheckerFn<C> = Arc<Box<dyn Fn(&C) -> bool + Send + Sync>>;
pub struct Authorizer<C = RegisteredClaims>
where
C: Clone + Send,
{
pub key_source: KeySource,
pub claims_checker: Option<FnClaimsChecker<C>>,
pub claims_checker: Option<ClaimsCheckerFn<C>>,
pub validation: crate::validation::Validation,
pub jwt_source: JwtSource,
}
@ -70,7 +51,7 @@ where
{
pub(crate) async fn build(
key_source_type: KeySourceType,
claims_checker: Option<FnClaimsChecker<C>>,
claims_checker: Option<ClaimsCheckerFn<C>>,
refresh: Option<Refresh>,
validation: crate::validation::Validation,
jwt_source: JwtSource,
@ -212,7 +193,7 @@ where
let token_data = decode::<C>(token, &val_key.key, jwt_validation)?;
if let Some(ref checker) = self.claims_checker {
if !checker.check(&token_data.claims) {
if !checker(&token_data.claims) {
return Err(AuthError::InvalidClaims());
}
}

View file

@ -3,7 +3,7 @@ use std::sync::Arc;
use serde::de::DeserializeOwned;
use crate::{
authorizer::{FnClaimsChecker, KeySourceType},
authorizer::{ClaimsCheckerFn, KeySourceType},
error::InitError,
layer::{AuthorizationLayer, JwtSource},
Authorizer, Refresh, RefreshStrategy, RegisteredClaims, Validation,
@ -19,7 +19,7 @@ where
{
key_source_type: KeySourceType,
refresh: Option<Refresh>,
claims_checker: Option<FnClaimsChecker<C>>,
claims_checker: Option<ClaimsCheckerFn<C>>,
validation: Option<Validation>,
jwt_source: JwtSource,
}
@ -158,9 +158,7 @@ where
where
F: Fn(&C) -> bool + Send + Sync + 'static,
{
self.claims_checker = Some(FnClaimsChecker {
checker_fn: Arc::new(Box::new(checker_fn)),
});
self.claims_checker = Some(Arc::new(Box::new(checker_fn)));
self
}