refacor: KeyStore, KeySource

- (performance) build DecodingKey once (per refresh)
- (security) store algorithm in KeyData
This commit is contained in:
cduvray 2023-03-15 08:21:04 +01:00
parent 8f55bf9d3e
commit ca14e15b67
6 changed files with 125 additions and 63 deletions

View file

@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased
## 0.8.1 (?-?-?)
No public API changes, no new features.
### Changed
- KeyStore, KeySource refactor for better performance and security
## 0.8.0 (2023-02-28)
### Added

View file

@ -1,12 +1,12 @@
use std::io::Read;
use std::{io::Read, sync::Arc};
use jsonwebtoken::{decode, decode_header, jwk::JwkSet, DecodingKey, TokenData};
use jsonwebtoken::{decode, decode_header, jwk::JwkSet, Algorithm, DecodingKey, TokenData};
use reqwest::Url;
use serde::de::DeserializeOwned;
use crate::{
error::{AuthError, InitError},
jwks::{key_store_manager::KeyStoreManager, KeySource},
jwks::{key_store_manager::KeyStoreManager, KeyData, KeySource},
oidc, Refresh,
};
@ -71,7 +71,11 @@ where
KeySourceType::RSA(path) => {
let key = DecodingKey::from_rsa_pem(&read_data(path.as_str())?)?;
Authorizer {
key_source: KeySource::DecodingKeySource(key),
key_source: KeySource::SingleKeySource(Arc::new(KeyData {
kid: None,
alg: vec![Algorithm::RS256, Algorithm::RS384, Algorithm::RS512],
key,
})),
claims_checker,
validation,
}
@ -79,7 +83,11 @@ where
KeySourceType::EC(path) => {
let key = DecodingKey::from_ec_pem(&read_data(path.as_str())?)?;
Authorizer {
key_source: KeySource::DecodingKeySource(key),
key_source: KeySource::SingleKeySource(Arc::new(KeyData {
kid: None,
alg: vec![Algorithm::ES256, Algorithm::ES384],
key,
})),
claims_checker,
validation,
}
@ -87,7 +95,11 @@ where
KeySourceType::ED(path) => {
let key = DecodingKey::from_ed_pem(&read_data(path.as_str())?)?;
Authorizer {
key_source: KeySource::DecodingKeySource(key),
key_source: KeySource::SingleKeySource(Arc::new(KeyData {
kid: None,
alg: vec![Algorithm::EdDSA],
key,
})),
claims_checker,
validation,
}
@ -95,7 +107,11 @@ where
KeySourceType::Secret(secret) => {
let key = DecodingKey::from_secret(secret.as_bytes());
Authorizer {
key_source: KeySource::DecodingKeySource(key),
key_source: KeySource::SingleKeySource(Arc::new(KeyData {
kid: None,
alg: vec![Algorithm::HS256, Algorithm::HS384, Algorithm::HS512],
key,
})),
claims_checker,
validation,
}
@ -104,9 +120,9 @@ where
// TODO: expose it in JwtAuthorizer or remove
let set: JwkSet = serde_json::from_str(jwks_str)?;
// TODO: replace [0] by kid/alg search
let k = DecodingKey::from_jwk(&set.keys[0])?;
let k = KeyData::from_jwk(&set.keys[0]).map_err(InitError::KeyDecodingError)?;
Authorizer {
key_source: KeySource::DecodingKeySource(k),
key_source: KeySource::SingleKeySource(Arc::new(k)),
claims_checker,
validation,
}
@ -136,11 +152,10 @@ where
pub async fn check_auth(&self, token: &str) -> Result<TokenData<C>, AuthError> {
let header = decode_header(token)?;
// TODO: build validation only once or cache it (store it in key_source?)
// (problem: alg family is checked in jsonwebtoken but may change with store refresh)
let jwt_validation = &self.validation.to_jwt_validation(header.alg);
let decoding_key = self.key_source.get_key(header).await?;
let token_data = decode::<C>(token, &decoding_key, jwt_validation)?;
// TODO: (optimisation) build & store jwt_validation in key data, to avoid rebuilding it for each check
let val_key = self.key_source.get_key(header).await?;
let jwt_validation = &self.validation.to_jwt_validation(val_key.alg.clone());
let token_data = decode::<C>(token, &val_key.key, jwt_validation)?;
if let Some(ref checker) = self.claims_checker {
if !checker.check(&token_data.claims) {

View file

@ -18,7 +18,7 @@ pub enum InitError {
KeyFileError(#[from] std::io::Error),
#[error(transparent)]
KeyFileDecodingError(#[from] jsonwebtoken::errors::Error),
KeyDecodingError(#[from] jsonwebtoken::errors::Error),
#[error("Builder Error {0}")]
DiscoveryError(String),
@ -35,8 +35,8 @@ pub enum AuthError {
#[error(transparent)]
JwksSerialisationError(#[from] serde_json::Error),
#[error(transparent)]
JwksRefreshError(#[from] reqwest::Error),
#[error("JwksRefreshError {0}")]
JwksRefreshError(String),
#[error("InvalidKey {0}")]
InvalidKey(String),

View file

@ -1,7 +1,4 @@
use jsonwebtoken::{
jwk::{Jwk, JwkSet},
Algorithm, DecodingKey,
};
use jsonwebtoken::{jwk::JwkSet, Algorithm};
use reqwest::Url;
use std::{
sync::Arc,
@ -11,8 +8,10 @@ use tokio::sync::Mutex;
use crate::error::AuthError;
use super::KeyData;
/// Defines the strategy for the JWKS refresh.
#[derive(Clone, Copy)]
#[derive(Clone)]
pub enum RefreshStrategy {
/// refresh periodicaly
Interval,
@ -25,7 +24,7 @@ pub enum RefreshStrategy {
}
/// JWKS Refresh configuration
#[derive(Clone, Copy)]
#[derive(Clone)]
pub struct Refresh {
pub strategy: RefreshStrategy,
/// After the refresh interval the store will/can be refreshed.
@ -60,7 +59,7 @@ pub struct KeyStoreManager {
pub struct KeyStore {
/// key set
jwks: JwkSet,
keys: Vec<Arc<KeyData>>,
/// time of the last successfully loaded jwkset
load_time: Option<Instant>,
/// time of the last failed load
@ -73,14 +72,14 @@ impl KeyStoreManager {
key_url,
refresh,
keystore: Arc::new(Mutex::new(KeyStore {
jwks: JwkSet { keys: vec![] },
keys: vec![],
load_time: None,
fail_time: None,
})),
}
}
pub(crate) async fn get_key(&self, header: &jsonwebtoken::Header) -> Result<jsonwebtoken::DecodingKey, AuthError> {
pub(crate) async fn get_key(&self, header: &jsonwebtoken::Header) -> Result<Arc<KeyData>, AuthError> {
let kstore = self.keystore.clone();
let mut ks_gard = kstore.lock().await;
let key = match self.refresh.strategy {
@ -141,8 +140,7 @@ impl KeyStoreManager {
}
}
};
DecodingKey::from_jwk(key).map_err(|err| AuthError::InvalidKey(err.to_string()))
Ok(key.clone())
}
}
@ -169,46 +167,55 @@ impl KeyStore {
.await
.map_err(|e| {
self.fail_time = Some(Instant::now());
AuthError::JwksRefreshError(e)
AuthError::JwksRefreshError(e.to_string())
})?
.json::<JwkSet>()
.await
.map(|jwks| {
self.load_time = Some(Instant::now());
self.jwks = jwks;
// self.jwks = jwks;
let mut keys: Vec<Arc<KeyData>> = Vec::with_capacity(jwks.keys.len());
for jwk in jwks.keys {
match KeyData::from_jwk(&jwk) {
Ok(kdata) => keys.push(Arc::new(kdata)),
Err(err) => {
tracing::warn!("Jwk decoding error, the key will be ignored! ({})", err);
}
};
}
if keys.is_empty() {
Err(AuthError::JwksRefreshError("No valid keys in the Jwk Set!".to_owned()))
} else {
self.keys = keys;
self.fail_time = None;
Ok(())
}
})
.map_err(|e| {
self.fail_time = Some(Instant::now());
AuthError::JwksRefreshError(e)
AuthError::JwksRefreshError(e.to_string())
})?
}
/// Find the key in the set that matches the given key id, if any.
pub fn find_kid(&self, kid: &str) -> Option<&Jwk> {
self.jwks.find(kid)
pub fn find_kid(&self, kid: &str) -> Option<&Arc<KeyData>> {
self.keys.iter().find(|k| k.kid.is_some() && k.kid.as_ref().unwrap() == kid)
}
/// Find the key in the set that matches the given key id, if any.
pub fn find_alg(&self, alg: &Algorithm) -> Option<&Jwk> {
self.jwks.keys.iter().find(|jwk| {
if let Some(ref a) = jwk.common.algorithm {
alg == a
} else {
false
}
})
pub fn find_alg(&self, alg: &Algorithm) -> Option<&Arc<KeyData>> {
self.keys.iter().find(|k| k.alg.contains(alg))
}
/// Find first key.
pub fn find_first(&self) -> Option<&Jwk> {
self.jwks.keys.get(0)
pub fn find_first(&self) -> Option<&Arc<KeyData>> {
self.keys.get(0)
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::time::{Duration, Instant};
use jsonwebtoken::Algorithm;
@ -220,8 +227,18 @@ mod tests {
};
use crate::jwks::key_store_manager::{KeyStore, KeyStoreManager};
use crate::jwks::KeyData;
use crate::{Refresh, RefreshStrategy};
const JWK_RSA01: &str = r#"{
"kty": "RSA",
"n": "2pQeZdxa7q093K7bj5h6-leIpxfTnuAxzXdhjfGEJHxmt2ekHyCBWWWXCBiDn2RTcEBcy6gZqOW45Uy_tw-5e-Px1xFj1PykGEkRlOpYSAeWsNaAWvvpGB9m4zQ0PgZeMDDXE5IIBrY6YAzmGQxV-fcGGLhJnXl0-5_z7tKC7RvBoT3SGwlc_AmJqpFtTpEBn_fDnyqiZbpcjXYLExFpExm41xDitRKHWIwfc3dV8_vlNntlxCPGy_THkjdXJoHv2IJmlhvmr5_h03iGMLWDKSywxOol_4Wc1BT7Hb6byMxW40GKwSJJ4p7W8eI5mqggRHc8jlwSsTN9LZ2VOvO-XiVShZRVg7JeraGAfWwaIgIJ1D8C1h5Pi0iFpp2suxpHAXHfyLMJXuVotpXbDh4NDX-A4KRMgaxcfAcui_x6gybksq6gF90-9nfQfmVMVJctZ6M-FvRr-itd1Nef5WAtwUp1qyZygAXU3cH3rarscajmurOsP6dE1OHl3grY_eZhQxk33VBK9lavqNKPg6Q_PLiq1ojbYBj3bcYifJrsNeQwxldQP83aWt5rGtgZTehKVJwa40Uy_Grae1iRnsDtdSy5sTJIJ6EiShnWAdMoGejdiI8vpkjrdU8SWH8lv1KXI54DsbyAuke2cYz02zPWc6JEotQqI0HwhzU0KHyoY4s",
"e": "AQAB",
"kid": "rsa01",
"alg": "RS256",
"use": "sig"
}"#;
const JWK_ED01: &str = r#"{
"kty": "OKP",
"use": "sig",
@ -264,7 +281,7 @@ mod tests {
fn keystore_can_refresh() {
// FAIL, NO LOAD
let ks = KeyStore {
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![] },
keys: vec![],
fail_time: Instant::now().checked_sub(Duration::from_secs(5)),
load_time: None,
};
@ -273,7 +290,7 @@ mod tests {
// NO FAIL, LOAD
let ks = KeyStore {
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![] },
keys: vec![],
fail_time: None,
load_time: Instant::now().checked_sub(Duration::from_secs(5)),
};
@ -282,7 +299,7 @@ mod tests {
// FAIL, LOAD
let ks = KeyStore {
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![] },
keys: vec![],
fail_time: Instant::now().checked_sub(Duration::from_secs(5)),
load_time: Instant::now().checked_sub(Duration::from_secs(10)),
};
@ -293,25 +310,28 @@ mod tests {
#[test]
fn find_kid() {
let jwk0: Jwk = serde_json::from_str(r#"{"kid":"1","kty":"RSA","alg":"RS256","n":"xxxx","e":"AQAB"}"#).unwrap();
let jwk1: Jwk = serde_json::from_str(r#"{"kid":"2","kty":"RSA","alg":"RS256","n":"xxxx","e":"AQAB"}"#).unwrap();
let jwk0: Jwk = serde_json::from_str(JWK_RSA01).unwrap();
let jwk1: Jwk = serde_json::from_str(JWK_EC01).unwrap();
let ks = KeyStore {
load_time: None,
fail_time: None,
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![jwk0, jwk1] },
keys: vec![
Arc::new(KeyData::from_jwk(&jwk0).unwrap()),
Arc::new(KeyData::from_jwk(&jwk1).unwrap()),
],
};
assert!(ks.find_kid("1").is_some());
assert!(ks.find_kid("2").is_some());
assert!(ks.find_kid("rsa01").is_some());
assert!(ks.find_kid("ec01").is_some());
assert!(ks.find_kid("3").is_none());
}
#[test]
fn find_alg() {
let jwk0: Jwk = serde_json::from_str(r#"{"kty": "RSA", "alg": "RS256", "n": "xxx","e": "yyy"}"#).unwrap();
let jwk0: Jwk = serde_json::from_str(JWK_RSA01).unwrap();
let ks = KeyStore {
load_time: None,
fail_time: None,
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![jwk0] },
keys: vec![Arc::new(KeyData::from_jwk(&jwk0).unwrap())],
};
assert!(ks.find_alg(&Algorithm::RS256).is_some());
assert!(ks.find_alg(&Algorithm::EdDSA).is_none());

View file

@ -1,4 +1,6 @@
use jsonwebtoken::{DecodingKey, Header};
use std::sync::Arc;
use jsonwebtoken::{jwk::Jwk, Algorithm, DecodingKey, Header};
use crate::error::AuthError;
@ -8,17 +10,34 @@ pub mod key_store_manager;
#[derive(Clone)]
pub enum KeySource {
/// KeyDataSource managing a refreshable key sets
KeyStoreSource(KeyStoreManager),
DecodingKeySource(DecodingKey),
/// Manages one public key, initialized on startup
SingleKeySource(Arc<KeyData>),
}
#[derive(Clone)]
pub struct KeyData {
pub kid: Option<String>,
pub alg: Vec<Algorithm>,
pub key: DecodingKey,
}
impl KeyData {
pub fn from_jwk(key: &Jwk) -> Result<KeyData, jsonwebtoken::errors::Error> {
Ok(KeyData {
kid: key.common.key_id.clone(),
alg: vec![key.common.algorithm.unwrap_or(Algorithm::RS256)], // TODO: is this good default?
key: DecodingKey::from_jwk(key)?,
})
}
}
impl KeySource {
pub async fn get_key(&self, header: Header) -> Result<DecodingKey, AuthError> {
pub async fn get_key(&self, header: Header) -> Result<Arc<KeyData>, AuthError> {
match self {
KeySource::KeyStoreSource(kstore) => kstore.get_key(&header).await,
KeySource::DecodingKeySource(key) => {
Ok(key.clone()) // TODO: clone -> &
}
KeySource::SingleKeySource(key) => Ok(key.clone()),
}
}
}

View file

@ -82,7 +82,7 @@ impl Validation {
self
}
pub(crate) fn to_jwt_validation(&self, alg: Algorithm) -> jsonwebtoken::Validation {
pub(crate) fn to_jwt_validation(&self, alg: Vec<Algorithm>) -> jsonwebtoken::Validation {
let required_claims = if self.validate_exp {
let mut claims = HashSet::with_capacity(1);
claims.insert("exp".to_owned());
@ -103,7 +103,7 @@ impl Validation {
jwt_validation.iss = iss;
jwt_validation.aud = aud;
jwt_validation.sub = None;
jwt_validation.algorithms = vec![alg];
jwt_validation.algorithms = alg;
if !self.validate_signature {
jwt_validation.insecure_disable_signature_validation();
}