From ca14e15b679bf6f971fea7585bb936ebc62264bb Mon Sep 17 00:00:00 2001 From: cduvray Date: Wed, 15 Mar 2023 08:21:04 +0100 Subject: [PATCH] refacor: KeyStore, KeySource - (performance) build DecodingKey once (per refresh) - (security) store algorithm in KeyData --- CHANGELOG.md | 8 ++ jwt-authorizer/src/authorizer.rs | 43 ++++++--- jwt-authorizer/src/error.rs | 6 +- jwt-authorizer/src/jwks/key_store_manager.rs | 96 ++++++++++++-------- jwt-authorizer/src/jwks/mod.rs | 31 +++++-- jwt-authorizer/src/validation.rs | 4 +- 6 files changed, 125 insertions(+), 63 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 77ac772..96c0709 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +## 0.8.1 (?-?-?) + +No public API changes, no new features. + +### Changed + +- KeyStore, KeySource refactor for better performance and security + ## 0.8.0 (2023-02-28) ### Added diff --git a/jwt-authorizer/src/authorizer.rs b/jwt-authorizer/src/authorizer.rs index 0af9907..78e8ef8 100644 --- a/jwt-authorizer/src/authorizer.rs +++ b/jwt-authorizer/src/authorizer.rs @@ -1,12 +1,12 @@ -use std::io::Read; +use std::{io::Read, sync::Arc}; -use jsonwebtoken::{decode, decode_header, jwk::JwkSet, DecodingKey, TokenData}; +use jsonwebtoken::{decode, decode_header, jwk::JwkSet, Algorithm, DecodingKey, TokenData}; use reqwest::Url; use serde::de::DeserializeOwned; use crate::{ error::{AuthError, InitError}, - jwks::{key_store_manager::KeyStoreManager, KeySource}, + jwks::{key_store_manager::KeyStoreManager, KeyData, KeySource}, oidc, Refresh, }; @@ -71,7 +71,11 @@ where KeySourceType::RSA(path) => { let key = DecodingKey::from_rsa_pem(&read_data(path.as_str())?)?; Authorizer { - key_source: KeySource::DecodingKeySource(key), + key_source: KeySource::SingleKeySource(Arc::new(KeyData { + kid: None, + alg: vec![Algorithm::RS256, Algorithm::RS384, Algorithm::RS512], + key, + })), claims_checker, validation, } @@ -79,7 +83,11 @@ where KeySourceType::EC(path) => { let key = DecodingKey::from_ec_pem(&read_data(path.as_str())?)?; Authorizer { - key_source: KeySource::DecodingKeySource(key), + key_source: KeySource::SingleKeySource(Arc::new(KeyData { + kid: None, + alg: vec![Algorithm::ES256, Algorithm::ES384], + key, + })), claims_checker, validation, } @@ -87,7 +95,11 @@ where KeySourceType::ED(path) => { let key = DecodingKey::from_ed_pem(&read_data(path.as_str())?)?; Authorizer { - key_source: KeySource::DecodingKeySource(key), + key_source: KeySource::SingleKeySource(Arc::new(KeyData { + kid: None, + alg: vec![Algorithm::EdDSA], + key, + })), claims_checker, validation, } @@ -95,7 +107,11 @@ where KeySourceType::Secret(secret) => { let key = DecodingKey::from_secret(secret.as_bytes()); Authorizer { - key_source: KeySource::DecodingKeySource(key), + key_source: KeySource::SingleKeySource(Arc::new(KeyData { + kid: None, + alg: vec![Algorithm::HS256, Algorithm::HS384, Algorithm::HS512], + key, + })), claims_checker, validation, } @@ -104,9 +120,9 @@ where // TODO: expose it in JwtAuthorizer or remove let set: JwkSet = serde_json::from_str(jwks_str)?; // TODO: replace [0] by kid/alg search - let k = DecodingKey::from_jwk(&set.keys[0])?; + let k = KeyData::from_jwk(&set.keys[0]).map_err(InitError::KeyDecodingError)?; Authorizer { - key_source: KeySource::DecodingKeySource(k), + key_source: KeySource::SingleKeySource(Arc::new(k)), claims_checker, validation, } @@ -136,11 +152,10 @@ where pub async fn check_auth(&self, token: &str) -> Result, AuthError> { let header = decode_header(token)?; - // TODO: build validation only once or cache it (store it in key_source?) - // (problem: alg family is checked in jsonwebtoken but may change with store refresh) - let jwt_validation = &self.validation.to_jwt_validation(header.alg); - let decoding_key = self.key_source.get_key(header).await?; - let token_data = decode::(token, &decoding_key, jwt_validation)?; + // TODO: (optimisation) build & store jwt_validation in key data, to avoid rebuilding it for each check + let val_key = self.key_source.get_key(header).await?; + let jwt_validation = &self.validation.to_jwt_validation(val_key.alg.clone()); + let token_data = decode::(token, &val_key.key, jwt_validation)?; if let Some(ref checker) = self.claims_checker { if !checker.check(&token_data.claims) { diff --git a/jwt-authorizer/src/error.rs b/jwt-authorizer/src/error.rs index ec14cd0..a688e56 100644 --- a/jwt-authorizer/src/error.rs +++ b/jwt-authorizer/src/error.rs @@ -18,7 +18,7 @@ pub enum InitError { KeyFileError(#[from] std::io::Error), #[error(transparent)] - KeyFileDecodingError(#[from] jsonwebtoken::errors::Error), + KeyDecodingError(#[from] jsonwebtoken::errors::Error), #[error("Builder Error {0}")] DiscoveryError(String), @@ -35,8 +35,8 @@ pub enum AuthError { #[error(transparent)] JwksSerialisationError(#[from] serde_json::Error), - #[error(transparent)] - JwksRefreshError(#[from] reqwest::Error), + #[error("JwksRefreshError {0}")] + JwksRefreshError(String), #[error("InvalidKey {0}")] InvalidKey(String), diff --git a/jwt-authorizer/src/jwks/key_store_manager.rs b/jwt-authorizer/src/jwks/key_store_manager.rs index 8d31ce2..d44f9e7 100644 --- a/jwt-authorizer/src/jwks/key_store_manager.rs +++ b/jwt-authorizer/src/jwks/key_store_manager.rs @@ -1,7 +1,4 @@ -use jsonwebtoken::{ - jwk::{Jwk, JwkSet}, - Algorithm, DecodingKey, -}; +use jsonwebtoken::{jwk::JwkSet, Algorithm}; use reqwest::Url; use std::{ sync::Arc, @@ -11,8 +8,10 @@ use tokio::sync::Mutex; use crate::error::AuthError; +use super::KeyData; + /// Defines the strategy for the JWKS refresh. -#[derive(Clone, Copy)] +#[derive(Clone)] pub enum RefreshStrategy { /// refresh periodicaly Interval, @@ -25,7 +24,7 @@ pub enum RefreshStrategy { } /// JWKS Refresh configuration -#[derive(Clone, Copy)] +#[derive(Clone)] pub struct Refresh { pub strategy: RefreshStrategy, /// After the refresh interval the store will/can be refreshed. @@ -60,7 +59,7 @@ pub struct KeyStoreManager { pub struct KeyStore { /// key set - jwks: JwkSet, + keys: Vec>, /// time of the last successfully loaded jwkset load_time: Option, /// time of the last failed load @@ -73,14 +72,14 @@ impl KeyStoreManager { key_url, refresh, keystore: Arc::new(Mutex::new(KeyStore { - jwks: JwkSet { keys: vec![] }, + keys: vec![], load_time: None, fail_time: None, })), } } - pub(crate) async fn get_key(&self, header: &jsonwebtoken::Header) -> Result { + pub(crate) async fn get_key(&self, header: &jsonwebtoken::Header) -> Result, AuthError> { let kstore = self.keystore.clone(); let mut ks_gard = kstore.lock().await; let key = match self.refresh.strategy { @@ -141,8 +140,7 @@ impl KeyStoreManager { } } }; - - DecodingKey::from_jwk(key).map_err(|err| AuthError::InvalidKey(err.to_string())) + Ok(key.clone()) } } @@ -169,46 +167,55 @@ impl KeyStore { .await .map_err(|e| { self.fail_time = Some(Instant::now()); - AuthError::JwksRefreshError(e) + AuthError::JwksRefreshError(e.to_string()) })? .json::() .await .map(|jwks| { self.load_time = Some(Instant::now()); - self.jwks = jwks; - self.fail_time = None; - Ok(()) + // self.jwks = jwks; + let mut keys: Vec> = Vec::with_capacity(jwks.keys.len()); + for jwk in jwks.keys { + match KeyData::from_jwk(&jwk) { + Ok(kdata) => keys.push(Arc::new(kdata)), + Err(err) => { + tracing::warn!("Jwk decoding error, the key will be ignored! ({})", err); + } + }; + } + if keys.is_empty() { + Err(AuthError::JwksRefreshError("No valid keys in the Jwk Set!".to_owned())) + } else { + self.keys = keys; + self.fail_time = None; + Ok(()) + } }) .map_err(|e| { self.fail_time = Some(Instant::now()); - AuthError::JwksRefreshError(e) + AuthError::JwksRefreshError(e.to_string()) })? } /// Find the key in the set that matches the given key id, if any. - pub fn find_kid(&self, kid: &str) -> Option<&Jwk> { - self.jwks.find(kid) + pub fn find_kid(&self, kid: &str) -> Option<&Arc> { + self.keys.iter().find(|k| k.kid.is_some() && k.kid.as_ref().unwrap() == kid) } /// Find the key in the set that matches the given key id, if any. - pub fn find_alg(&self, alg: &Algorithm) -> Option<&Jwk> { - self.jwks.keys.iter().find(|jwk| { - if let Some(ref a) = jwk.common.algorithm { - alg == a - } else { - false - } - }) + pub fn find_alg(&self, alg: &Algorithm) -> Option<&Arc> { + self.keys.iter().find(|k| k.alg.contains(alg)) } /// Find first key. - pub fn find_first(&self) -> Option<&Jwk> { - self.jwks.keys.get(0) + pub fn find_first(&self) -> Option<&Arc> { + self.keys.get(0) } } #[cfg(test)] mod tests { + use std::sync::Arc; use std::time::{Duration, Instant}; use jsonwebtoken::Algorithm; @@ -220,8 +227,18 @@ mod tests { }; use crate::jwks::key_store_manager::{KeyStore, KeyStoreManager}; + use crate::jwks::KeyData; use crate::{Refresh, RefreshStrategy}; + const JWK_RSA01: &str = r#"{ + "kty": "RSA", + "n": "2pQeZdxa7q093K7bj5h6-leIpxfTnuAxzXdhjfGEJHxmt2ekHyCBWWWXCBiDn2RTcEBcy6gZqOW45Uy_tw-5e-Px1xFj1PykGEkRlOpYSAeWsNaAWvvpGB9m4zQ0PgZeMDDXE5IIBrY6YAzmGQxV-fcGGLhJnXl0-5_z7tKC7RvBoT3SGwlc_AmJqpFtTpEBn_fDnyqiZbpcjXYLExFpExm41xDitRKHWIwfc3dV8_vlNntlxCPGy_THkjdXJoHv2IJmlhvmr5_h03iGMLWDKSywxOol_4Wc1BT7Hb6byMxW40GKwSJJ4p7W8eI5mqggRHc8jlwSsTN9LZ2VOvO-XiVShZRVg7JeraGAfWwaIgIJ1D8C1h5Pi0iFpp2suxpHAXHfyLMJXuVotpXbDh4NDX-A4KRMgaxcfAcui_x6gybksq6gF90-9nfQfmVMVJctZ6M-FvRr-itd1Nef5WAtwUp1qyZygAXU3cH3rarscajmurOsP6dE1OHl3grY_eZhQxk33VBK9lavqNKPg6Q_PLiq1ojbYBj3bcYifJrsNeQwxldQP83aWt5rGtgZTehKVJwa40Uy_Grae1iRnsDtdSy5sTJIJ6EiShnWAdMoGejdiI8vpkjrdU8SWH8lv1KXI54DsbyAuke2cYz02zPWc6JEotQqI0HwhzU0KHyoY4s", + "e": "AQAB", + "kid": "rsa01", + "alg": "RS256", + "use": "sig" + }"#; + const JWK_ED01: &str = r#"{ "kty": "OKP", "use": "sig", @@ -264,7 +281,7 @@ mod tests { fn keystore_can_refresh() { // FAIL, NO LOAD let ks = KeyStore { - jwks: jsonwebtoken::jwk::JwkSet { keys: vec![] }, + keys: vec![], fail_time: Instant::now().checked_sub(Duration::from_secs(5)), load_time: None, }; @@ -273,7 +290,7 @@ mod tests { // NO FAIL, LOAD let ks = KeyStore { - jwks: jsonwebtoken::jwk::JwkSet { keys: vec![] }, + keys: vec![], fail_time: None, load_time: Instant::now().checked_sub(Duration::from_secs(5)), }; @@ -282,7 +299,7 @@ mod tests { // FAIL, LOAD let ks = KeyStore { - jwks: jsonwebtoken::jwk::JwkSet { keys: vec![] }, + keys: vec![], fail_time: Instant::now().checked_sub(Duration::from_secs(5)), load_time: Instant::now().checked_sub(Duration::from_secs(10)), }; @@ -293,25 +310,28 @@ mod tests { #[test] fn find_kid() { - let jwk0: Jwk = serde_json::from_str(r#"{"kid":"1","kty":"RSA","alg":"RS256","n":"xxxx","e":"AQAB"}"#).unwrap(); - let jwk1: Jwk = serde_json::from_str(r#"{"kid":"2","kty":"RSA","alg":"RS256","n":"xxxx","e":"AQAB"}"#).unwrap(); + let jwk0: Jwk = serde_json::from_str(JWK_RSA01).unwrap(); + let jwk1: Jwk = serde_json::from_str(JWK_EC01).unwrap(); let ks = KeyStore { load_time: None, fail_time: None, - jwks: jsonwebtoken::jwk::JwkSet { keys: vec![jwk0, jwk1] }, + keys: vec![ + Arc::new(KeyData::from_jwk(&jwk0).unwrap()), + Arc::new(KeyData::from_jwk(&jwk1).unwrap()), + ], }; - assert!(ks.find_kid("1").is_some()); - assert!(ks.find_kid("2").is_some()); + assert!(ks.find_kid("rsa01").is_some()); + assert!(ks.find_kid("ec01").is_some()); assert!(ks.find_kid("3").is_none()); } #[test] fn find_alg() { - let jwk0: Jwk = serde_json::from_str(r#"{"kty": "RSA", "alg": "RS256", "n": "xxx","e": "yyy"}"#).unwrap(); + let jwk0: Jwk = serde_json::from_str(JWK_RSA01).unwrap(); let ks = KeyStore { load_time: None, fail_time: None, - jwks: jsonwebtoken::jwk::JwkSet { keys: vec![jwk0] }, + keys: vec![Arc::new(KeyData::from_jwk(&jwk0).unwrap())], }; assert!(ks.find_alg(&Algorithm::RS256).is_some()); assert!(ks.find_alg(&Algorithm::EdDSA).is_none()); diff --git a/jwt-authorizer/src/jwks/mod.rs b/jwt-authorizer/src/jwks/mod.rs index 82d124c..28d86e3 100644 --- a/jwt-authorizer/src/jwks/mod.rs +++ b/jwt-authorizer/src/jwks/mod.rs @@ -1,4 +1,6 @@ -use jsonwebtoken::{DecodingKey, Header}; +use std::sync::Arc; + +use jsonwebtoken::{jwk::Jwk, Algorithm, DecodingKey, Header}; use crate::error::AuthError; @@ -8,17 +10,34 @@ pub mod key_store_manager; #[derive(Clone)] pub enum KeySource { + /// KeyDataSource managing a refreshable key sets KeyStoreSource(KeyStoreManager), - DecodingKeySource(DecodingKey), + /// Manages one public key, initialized on startup + SingleKeySource(Arc), +} + +#[derive(Clone)] +pub struct KeyData { + pub kid: Option, + pub alg: Vec, + pub key: DecodingKey, +} + +impl KeyData { + pub fn from_jwk(key: &Jwk) -> Result { + Ok(KeyData { + kid: key.common.key_id.clone(), + alg: vec![key.common.algorithm.unwrap_or(Algorithm::RS256)], // TODO: is this good default? + key: DecodingKey::from_jwk(key)?, + }) + } } impl KeySource { - pub async fn get_key(&self, header: Header) -> Result { + pub async fn get_key(&self, header: Header) -> Result, AuthError> { match self { KeySource::KeyStoreSource(kstore) => kstore.get_key(&header).await, - KeySource::DecodingKeySource(key) => { - Ok(key.clone()) // TODO: clone -> & - } + KeySource::SingleKeySource(key) => Ok(key.clone()), } } } diff --git a/jwt-authorizer/src/validation.rs b/jwt-authorizer/src/validation.rs index d354708..f609efb 100644 --- a/jwt-authorizer/src/validation.rs +++ b/jwt-authorizer/src/validation.rs @@ -82,7 +82,7 @@ impl Validation { self } - pub(crate) fn to_jwt_validation(&self, alg: Algorithm) -> jsonwebtoken::Validation { + pub(crate) fn to_jwt_validation(&self, alg: Vec) -> jsonwebtoken::Validation { let required_claims = if self.validate_exp { let mut claims = HashSet::with_capacity(1); claims.insert("exp".to_owned()); @@ -103,7 +103,7 @@ impl Validation { jwt_validation.iss = iss; jwt_validation.aud = aud; jwt_validation.sub = None; - jwt_validation.algorithms = vec![alg]; + jwt_validation.algorithms = alg; if !self.validate_signature { jwt_validation.insecure_disable_signature_validation(); }