refacor: KeyStore, KeySource

- (performance) build DecodingKey once (per refresh)
- (security) store algorithm in KeyData
This commit is contained in:
cduvray 2023-03-15 08:21:04 +01:00
parent 8f55bf9d3e
commit ca14e15b67
6 changed files with 125 additions and 63 deletions

View file

@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased ## Unreleased
## 0.8.1 (?-?-?)
No public API changes, no new features.
### Changed
- KeyStore, KeySource refactor for better performance and security
## 0.8.0 (2023-02-28) ## 0.8.0 (2023-02-28)
### Added ### Added

View file

@ -1,12 +1,12 @@
use std::io::Read; use std::{io::Read, sync::Arc};
use jsonwebtoken::{decode, decode_header, jwk::JwkSet, DecodingKey, TokenData}; use jsonwebtoken::{decode, decode_header, jwk::JwkSet, Algorithm, DecodingKey, TokenData};
use reqwest::Url; use reqwest::Url;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use crate::{ use crate::{
error::{AuthError, InitError}, error::{AuthError, InitError},
jwks::{key_store_manager::KeyStoreManager, KeySource}, jwks::{key_store_manager::KeyStoreManager, KeyData, KeySource},
oidc, Refresh, oidc, Refresh,
}; };
@ -71,7 +71,11 @@ where
KeySourceType::RSA(path) => { KeySourceType::RSA(path) => {
let key = DecodingKey::from_rsa_pem(&read_data(path.as_str())?)?; let key = DecodingKey::from_rsa_pem(&read_data(path.as_str())?)?;
Authorizer { Authorizer {
key_source: KeySource::DecodingKeySource(key), key_source: KeySource::SingleKeySource(Arc::new(KeyData {
kid: None,
alg: vec![Algorithm::RS256, Algorithm::RS384, Algorithm::RS512],
key,
})),
claims_checker, claims_checker,
validation, validation,
} }
@ -79,7 +83,11 @@ where
KeySourceType::EC(path) => { KeySourceType::EC(path) => {
let key = DecodingKey::from_ec_pem(&read_data(path.as_str())?)?; let key = DecodingKey::from_ec_pem(&read_data(path.as_str())?)?;
Authorizer { Authorizer {
key_source: KeySource::DecodingKeySource(key), key_source: KeySource::SingleKeySource(Arc::new(KeyData {
kid: None,
alg: vec![Algorithm::ES256, Algorithm::ES384],
key,
})),
claims_checker, claims_checker,
validation, validation,
} }
@ -87,7 +95,11 @@ where
KeySourceType::ED(path) => { KeySourceType::ED(path) => {
let key = DecodingKey::from_ed_pem(&read_data(path.as_str())?)?; let key = DecodingKey::from_ed_pem(&read_data(path.as_str())?)?;
Authorizer { Authorizer {
key_source: KeySource::DecodingKeySource(key), key_source: KeySource::SingleKeySource(Arc::new(KeyData {
kid: None,
alg: vec![Algorithm::EdDSA],
key,
})),
claims_checker, claims_checker,
validation, validation,
} }
@ -95,7 +107,11 @@ where
KeySourceType::Secret(secret) => { KeySourceType::Secret(secret) => {
let key = DecodingKey::from_secret(secret.as_bytes()); let key = DecodingKey::from_secret(secret.as_bytes());
Authorizer { Authorizer {
key_source: KeySource::DecodingKeySource(key), key_source: KeySource::SingleKeySource(Arc::new(KeyData {
kid: None,
alg: vec![Algorithm::HS256, Algorithm::HS384, Algorithm::HS512],
key,
})),
claims_checker, claims_checker,
validation, validation,
} }
@ -104,9 +120,9 @@ where
// TODO: expose it in JwtAuthorizer or remove // TODO: expose it in JwtAuthorizer or remove
let set: JwkSet = serde_json::from_str(jwks_str)?; let set: JwkSet = serde_json::from_str(jwks_str)?;
// TODO: replace [0] by kid/alg search // TODO: replace [0] by kid/alg search
let k = DecodingKey::from_jwk(&set.keys[0])?; let k = KeyData::from_jwk(&set.keys[0]).map_err(InitError::KeyDecodingError)?;
Authorizer { Authorizer {
key_source: KeySource::DecodingKeySource(k), key_source: KeySource::SingleKeySource(Arc::new(k)),
claims_checker, claims_checker,
validation, validation,
} }
@ -136,11 +152,10 @@ where
pub async fn check_auth(&self, token: &str) -> Result<TokenData<C>, AuthError> { pub async fn check_auth(&self, token: &str) -> Result<TokenData<C>, AuthError> {
let header = decode_header(token)?; let header = decode_header(token)?;
// TODO: build validation only once or cache it (store it in key_source?) // TODO: (optimisation) build & store jwt_validation in key data, to avoid rebuilding it for each check
// (problem: alg family is checked in jsonwebtoken but may change with store refresh) let val_key = self.key_source.get_key(header).await?;
let jwt_validation = &self.validation.to_jwt_validation(header.alg); let jwt_validation = &self.validation.to_jwt_validation(val_key.alg.clone());
let decoding_key = self.key_source.get_key(header).await?; let token_data = decode::<C>(token, &val_key.key, jwt_validation)?;
let token_data = decode::<C>(token, &decoding_key, jwt_validation)?;
if let Some(ref checker) = self.claims_checker { if let Some(ref checker) = self.claims_checker {
if !checker.check(&token_data.claims) { if !checker.check(&token_data.claims) {

View file

@ -18,7 +18,7 @@ pub enum InitError {
KeyFileError(#[from] std::io::Error), KeyFileError(#[from] std::io::Error),
#[error(transparent)] #[error(transparent)]
KeyFileDecodingError(#[from] jsonwebtoken::errors::Error), KeyDecodingError(#[from] jsonwebtoken::errors::Error),
#[error("Builder Error {0}")] #[error("Builder Error {0}")]
DiscoveryError(String), DiscoveryError(String),
@ -35,8 +35,8 @@ pub enum AuthError {
#[error(transparent)] #[error(transparent)]
JwksSerialisationError(#[from] serde_json::Error), JwksSerialisationError(#[from] serde_json::Error),
#[error(transparent)] #[error("JwksRefreshError {0}")]
JwksRefreshError(#[from] reqwest::Error), JwksRefreshError(String),
#[error("InvalidKey {0}")] #[error("InvalidKey {0}")]
InvalidKey(String), InvalidKey(String),

View file

@ -1,7 +1,4 @@
use jsonwebtoken::{ use jsonwebtoken::{jwk::JwkSet, Algorithm};
jwk::{Jwk, JwkSet},
Algorithm, DecodingKey,
};
use reqwest::Url; use reqwest::Url;
use std::{ use std::{
sync::Arc, sync::Arc,
@ -11,8 +8,10 @@ use tokio::sync::Mutex;
use crate::error::AuthError; use crate::error::AuthError;
use super::KeyData;
/// Defines the strategy for the JWKS refresh. /// Defines the strategy for the JWKS refresh.
#[derive(Clone, Copy)] #[derive(Clone)]
pub enum RefreshStrategy { pub enum RefreshStrategy {
/// refresh periodicaly /// refresh periodicaly
Interval, Interval,
@ -25,7 +24,7 @@ pub enum RefreshStrategy {
} }
/// JWKS Refresh configuration /// JWKS Refresh configuration
#[derive(Clone, Copy)] #[derive(Clone)]
pub struct Refresh { pub struct Refresh {
pub strategy: RefreshStrategy, pub strategy: RefreshStrategy,
/// After the refresh interval the store will/can be refreshed. /// After the refresh interval the store will/can be refreshed.
@ -60,7 +59,7 @@ pub struct KeyStoreManager {
pub struct KeyStore { pub struct KeyStore {
/// key set /// key set
jwks: JwkSet, keys: Vec<Arc<KeyData>>,
/// time of the last successfully loaded jwkset /// time of the last successfully loaded jwkset
load_time: Option<Instant>, load_time: Option<Instant>,
/// time of the last failed load /// time of the last failed load
@ -73,14 +72,14 @@ impl KeyStoreManager {
key_url, key_url,
refresh, refresh,
keystore: Arc::new(Mutex::new(KeyStore { keystore: Arc::new(Mutex::new(KeyStore {
jwks: JwkSet { keys: vec![] }, keys: vec![],
load_time: None, load_time: None,
fail_time: None, fail_time: None,
})), })),
} }
} }
pub(crate) async fn get_key(&self, header: &jsonwebtoken::Header) -> Result<jsonwebtoken::DecodingKey, AuthError> { pub(crate) async fn get_key(&self, header: &jsonwebtoken::Header) -> Result<Arc<KeyData>, AuthError> {
let kstore = self.keystore.clone(); let kstore = self.keystore.clone();
let mut ks_gard = kstore.lock().await; let mut ks_gard = kstore.lock().await;
let key = match self.refresh.strategy { let key = match self.refresh.strategy {
@ -141,8 +140,7 @@ impl KeyStoreManager {
} }
} }
}; };
Ok(key.clone())
DecodingKey::from_jwk(key).map_err(|err| AuthError::InvalidKey(err.to_string()))
} }
} }
@ -169,46 +167,55 @@ impl KeyStore {
.await .await
.map_err(|e| { .map_err(|e| {
self.fail_time = Some(Instant::now()); self.fail_time = Some(Instant::now());
AuthError::JwksRefreshError(e) AuthError::JwksRefreshError(e.to_string())
})? })?
.json::<JwkSet>() .json::<JwkSet>()
.await .await
.map(|jwks| { .map(|jwks| {
self.load_time = Some(Instant::now()); self.load_time = Some(Instant::now());
self.jwks = jwks; // self.jwks = jwks;
let mut keys: Vec<Arc<KeyData>> = Vec::with_capacity(jwks.keys.len());
for jwk in jwks.keys {
match KeyData::from_jwk(&jwk) {
Ok(kdata) => keys.push(Arc::new(kdata)),
Err(err) => {
tracing::warn!("Jwk decoding error, the key will be ignored! ({})", err);
}
};
}
if keys.is_empty() {
Err(AuthError::JwksRefreshError("No valid keys in the Jwk Set!".to_owned()))
} else {
self.keys = keys;
self.fail_time = None; self.fail_time = None;
Ok(()) Ok(())
}
}) })
.map_err(|e| { .map_err(|e| {
self.fail_time = Some(Instant::now()); self.fail_time = Some(Instant::now());
AuthError::JwksRefreshError(e) AuthError::JwksRefreshError(e.to_string())
})? })?
} }
/// Find the key in the set that matches the given key id, if any. /// Find the key in the set that matches the given key id, if any.
pub fn find_kid(&self, kid: &str) -> Option<&Jwk> { pub fn find_kid(&self, kid: &str) -> Option<&Arc<KeyData>> {
self.jwks.find(kid) self.keys.iter().find(|k| k.kid.is_some() && k.kid.as_ref().unwrap() == kid)
} }
/// Find the key in the set that matches the given key id, if any. /// Find the key in the set that matches the given key id, if any.
pub fn find_alg(&self, alg: &Algorithm) -> Option<&Jwk> { pub fn find_alg(&self, alg: &Algorithm) -> Option<&Arc<KeyData>> {
self.jwks.keys.iter().find(|jwk| { self.keys.iter().find(|k| k.alg.contains(alg))
if let Some(ref a) = jwk.common.algorithm {
alg == a
} else {
false
}
})
} }
/// Find first key. /// Find first key.
pub fn find_first(&self) -> Option<&Jwk> { pub fn find_first(&self) -> Option<&Arc<KeyData>> {
self.jwks.keys.get(0) self.keys.get(0)
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use jsonwebtoken::Algorithm; use jsonwebtoken::Algorithm;
@ -220,8 +227,18 @@ mod tests {
}; };
use crate::jwks::key_store_manager::{KeyStore, KeyStoreManager}; use crate::jwks::key_store_manager::{KeyStore, KeyStoreManager};
use crate::jwks::KeyData;
use crate::{Refresh, RefreshStrategy}; use crate::{Refresh, RefreshStrategy};
const JWK_RSA01: &str = r#"{
"kty": "RSA",
"n": "2pQeZdxa7q093K7bj5h6-leIpxfTnuAxzXdhjfGEJHxmt2ekHyCBWWWXCBiDn2RTcEBcy6gZqOW45Uy_tw-5e-Px1xFj1PykGEkRlOpYSAeWsNaAWvvpGB9m4zQ0PgZeMDDXE5IIBrY6YAzmGQxV-fcGGLhJnXl0-5_z7tKC7RvBoT3SGwlc_AmJqpFtTpEBn_fDnyqiZbpcjXYLExFpExm41xDitRKHWIwfc3dV8_vlNntlxCPGy_THkjdXJoHv2IJmlhvmr5_h03iGMLWDKSywxOol_4Wc1BT7Hb6byMxW40GKwSJJ4p7W8eI5mqggRHc8jlwSsTN9LZ2VOvO-XiVShZRVg7JeraGAfWwaIgIJ1D8C1h5Pi0iFpp2suxpHAXHfyLMJXuVotpXbDh4NDX-A4KRMgaxcfAcui_x6gybksq6gF90-9nfQfmVMVJctZ6M-FvRr-itd1Nef5WAtwUp1qyZygAXU3cH3rarscajmurOsP6dE1OHl3grY_eZhQxk33VBK9lavqNKPg6Q_PLiq1ojbYBj3bcYifJrsNeQwxldQP83aWt5rGtgZTehKVJwa40Uy_Grae1iRnsDtdSy5sTJIJ6EiShnWAdMoGejdiI8vpkjrdU8SWH8lv1KXI54DsbyAuke2cYz02zPWc6JEotQqI0HwhzU0KHyoY4s",
"e": "AQAB",
"kid": "rsa01",
"alg": "RS256",
"use": "sig"
}"#;
const JWK_ED01: &str = r#"{ const JWK_ED01: &str = r#"{
"kty": "OKP", "kty": "OKP",
"use": "sig", "use": "sig",
@ -264,7 +281,7 @@ mod tests {
fn keystore_can_refresh() { fn keystore_can_refresh() {
// FAIL, NO LOAD // FAIL, NO LOAD
let ks = KeyStore { let ks = KeyStore {
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![] }, keys: vec![],
fail_time: Instant::now().checked_sub(Duration::from_secs(5)), fail_time: Instant::now().checked_sub(Duration::from_secs(5)),
load_time: None, load_time: None,
}; };
@ -273,7 +290,7 @@ mod tests {
// NO FAIL, LOAD // NO FAIL, LOAD
let ks = KeyStore { let ks = KeyStore {
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![] }, keys: vec![],
fail_time: None, fail_time: None,
load_time: Instant::now().checked_sub(Duration::from_secs(5)), load_time: Instant::now().checked_sub(Duration::from_secs(5)),
}; };
@ -282,7 +299,7 @@ mod tests {
// FAIL, LOAD // FAIL, LOAD
let ks = KeyStore { let ks = KeyStore {
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![] }, keys: vec![],
fail_time: Instant::now().checked_sub(Duration::from_secs(5)), fail_time: Instant::now().checked_sub(Duration::from_secs(5)),
load_time: Instant::now().checked_sub(Duration::from_secs(10)), load_time: Instant::now().checked_sub(Duration::from_secs(10)),
}; };
@ -293,25 +310,28 @@ mod tests {
#[test] #[test]
fn find_kid() { fn find_kid() {
let jwk0: Jwk = serde_json::from_str(r#"{"kid":"1","kty":"RSA","alg":"RS256","n":"xxxx","e":"AQAB"}"#).unwrap(); let jwk0: Jwk = serde_json::from_str(JWK_RSA01).unwrap();
let jwk1: Jwk = serde_json::from_str(r#"{"kid":"2","kty":"RSA","alg":"RS256","n":"xxxx","e":"AQAB"}"#).unwrap(); let jwk1: Jwk = serde_json::from_str(JWK_EC01).unwrap();
let ks = KeyStore { let ks = KeyStore {
load_time: None, load_time: None,
fail_time: None, fail_time: None,
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![jwk0, jwk1] }, keys: vec![
Arc::new(KeyData::from_jwk(&jwk0).unwrap()),
Arc::new(KeyData::from_jwk(&jwk1).unwrap()),
],
}; };
assert!(ks.find_kid("1").is_some()); assert!(ks.find_kid("rsa01").is_some());
assert!(ks.find_kid("2").is_some()); assert!(ks.find_kid("ec01").is_some());
assert!(ks.find_kid("3").is_none()); assert!(ks.find_kid("3").is_none());
} }
#[test] #[test]
fn find_alg() { fn find_alg() {
let jwk0: Jwk = serde_json::from_str(r#"{"kty": "RSA", "alg": "RS256", "n": "xxx","e": "yyy"}"#).unwrap(); let jwk0: Jwk = serde_json::from_str(JWK_RSA01).unwrap();
let ks = KeyStore { let ks = KeyStore {
load_time: None, load_time: None,
fail_time: None, fail_time: None,
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![jwk0] }, keys: vec![Arc::new(KeyData::from_jwk(&jwk0).unwrap())],
}; };
assert!(ks.find_alg(&Algorithm::RS256).is_some()); assert!(ks.find_alg(&Algorithm::RS256).is_some());
assert!(ks.find_alg(&Algorithm::EdDSA).is_none()); assert!(ks.find_alg(&Algorithm::EdDSA).is_none());

View file

@ -1,4 +1,6 @@
use jsonwebtoken::{DecodingKey, Header}; use std::sync::Arc;
use jsonwebtoken::{jwk::Jwk, Algorithm, DecodingKey, Header};
use crate::error::AuthError; use crate::error::AuthError;
@ -8,17 +10,34 @@ pub mod key_store_manager;
#[derive(Clone)] #[derive(Clone)]
pub enum KeySource { pub enum KeySource {
/// KeyDataSource managing a refreshable key sets
KeyStoreSource(KeyStoreManager), KeyStoreSource(KeyStoreManager),
DecodingKeySource(DecodingKey), /// Manages one public key, initialized on startup
SingleKeySource(Arc<KeyData>),
}
#[derive(Clone)]
pub struct KeyData {
pub kid: Option<String>,
pub alg: Vec<Algorithm>,
pub key: DecodingKey,
}
impl KeyData {
pub fn from_jwk(key: &Jwk) -> Result<KeyData, jsonwebtoken::errors::Error> {
Ok(KeyData {
kid: key.common.key_id.clone(),
alg: vec![key.common.algorithm.unwrap_or(Algorithm::RS256)], // TODO: is this good default?
key: DecodingKey::from_jwk(key)?,
})
}
} }
impl KeySource { impl KeySource {
pub async fn get_key(&self, header: Header) -> Result<DecodingKey, AuthError> { pub async fn get_key(&self, header: Header) -> Result<Arc<KeyData>, AuthError> {
match self { match self {
KeySource::KeyStoreSource(kstore) => kstore.get_key(&header).await, KeySource::KeyStoreSource(kstore) => kstore.get_key(&header).await,
KeySource::DecodingKeySource(key) => { KeySource::SingleKeySource(key) => Ok(key.clone()),
Ok(key.clone()) // TODO: clone -> &
}
} }
} }
} }

View file

@ -82,7 +82,7 @@ impl Validation {
self self
} }
pub(crate) fn to_jwt_validation(&self, alg: Algorithm) -> jsonwebtoken::Validation { pub(crate) fn to_jwt_validation(&self, alg: Vec<Algorithm>) -> jsonwebtoken::Validation {
let required_claims = if self.validate_exp { let required_claims = if self.validate_exp {
let mut claims = HashSet::with_capacity(1); let mut claims = HashSet::with_capacity(1);
claims.insert("exp".to_owned()); claims.insert("exp".to_owned());
@ -103,7 +103,7 @@ impl Validation {
jwt_validation.iss = iss; jwt_validation.iss = iss;
jwt_validation.aud = aud; jwt_validation.aud = aud;
jwt_validation.sub = None; jwt_validation.sub = None;
jwt_validation.algorithms = vec![alg]; jwt_validation.algorithms = alg;
if !self.validate_signature { if !self.validate_signature {
jwt_validation.insecure_disable_signature_validation(); jwt_validation.insecure_disable_signature_validation();
} }