Add support for reading keys from a static JWKS

Allow creating authorizer from JWKS files similar to other static
certificates.

Signed-off-by: Sjoerd Simons <sjoerd@collabora.com>
This commit is contained in:
Sjoerd Simons 2023-09-26 21:33:06 +02:00 committed by cduvray
parent 6e19f31c77
commit ef8ac07271
7 changed files with 175 additions and 26 deletions

View file

@ -41,6 +41,7 @@ pub enum KeySourceType {
EDString(String),
Secret(String),
Jwks(String),
JwksPath(String),
JwksString(String), // TODO: expose JwksString in JwtAuthorizer or remove it
Discovery(String),
}
@ -148,13 +149,36 @@ where
jwt_source,
}
}
KeySourceType::JwksPath(path) => {
let set: JwkSet = serde_json::from_slice(&read_data(path.as_str())?)?;
let keys = set
.keys
.iter()
.map(|k| match KeyData::from_jwk(k) {
Ok(kdata) => Ok(Arc::new(kdata)),
Err(err) => Err(InitError::KeyDecodingError(err)),
})
.collect::<Result<Vec<_>, _>>()?;
Authorizer {
key_source: KeySource::MultiKeySource(keys.into()),
claims_checker,
validation,
jwt_source,
}
}
KeySourceType::JwksString(jwks_str) => {
// TODO: expose it in JwtAuthorizer or remove
let set: JwkSet = serde_json::from_str(jwks_str.as_str())?;
// TODO: replace [0] by kid/alg search
let k = KeyData::from_jwk(&set.keys[0]).map_err(InitError::KeyDecodingError)?;
let keys = set
.keys
.iter()
.map(|k| match KeyData::from_jwk(k) {
Ok(kdata) => Ok(Arc::new(kdata)),
Err(err) => Err(InitError::KeyDecodingError(err)),
})
.collect::<Result<Vec<_>, _>>()?;
Authorizer {
key_source: KeySource::SingleKeySource(Arc::new(k)),
key_source: KeySource::MultiKeySource(keys.into()),
claims_checker,
validation,
jwt_source,
@ -363,6 +387,28 @@ mod tests {
.unwrap();
let k = a.key_source.get_key(Header::new(Algorithm::EdDSA));
assert!(k.await.is_ok());
let a = Authorizer::<Value>::build(
KeySourceType::JwksPath("../config/public1.jwks".to_owned()),
None,
None,
Validation::new(),
JwtSource::AuthorizationHeader,
)
.await
.unwrap();
a.key_source
.get_key(Header::new(Algorithm::RS256))
.await
.expect("Couldn't get RS256 key from jwk");
a.key_source
.get_key(Header::new(Algorithm::ES256))
.await
.expect("Couldn't get ES256 key from jwk");
a.key_source
.get_key(Header::new(Algorithm::EdDSA))
.await
.expect("Couldn't get EdDSA key from jwk");
}
#[tokio::test]

View file

@ -54,6 +54,26 @@ where
}
}
pub fn from_jwks(path: &str) -> AuthorizerBuilder<C> {
AuthorizerBuilder {
key_source_type: KeySourceType::JwksPath(path.to_owned()),
refresh: Default::default(),
claims_checker: None,
validation: None,
jwt_source: JwtSource::AuthorizationHeader,
}
}
pub fn from_jwks_text(text: &str) -> AuthorizerBuilder<C> {
AuthorizerBuilder {
key_source_type: KeySourceType::JwksString(text.to_owned()),
refresh: Default::default(),
claims_checker: None,
validation: None,
jwt_source: JwtSource::AuthorizationHeader,
}
}
/// Builds Authorizer Layer from a RSA PEM file
pub fn from_rsa_pem(path: &str) -> AuthorizerBuilder<C> {
AuthorizerBuilder {

View file

@ -8,7 +8,7 @@ use tokio::sync::Mutex;
use crate::error::AuthError;
use super::KeyData;
use super::{KeyData, KeySet};
/// Defines the strategy for the JWKS refresh.
#[derive(Clone)]
@ -59,7 +59,7 @@ pub struct KeyStoreManager {
pub struct KeyStore {
/// key set
keys: Vec<Arc<KeyData>>,
keys: KeySet,
/// time of the last successfully loaded jwkset
load_time: Option<Instant>,
/// time of the last failed load
@ -72,7 +72,7 @@ impl KeyStoreManager {
key_url,
refresh,
keystore: Arc::new(Mutex::new(KeyStore {
keys: vec![],
keys: KeySet::default(),
load_time: None,
fail_time: None,
})),
@ -87,11 +87,7 @@ impl KeyStoreManager {
if ks_gard.can_refresh(self.refresh.refresh_interval, self.refresh.retry_interval) {
ks_gard.refresh(&self.key_url, &[]).await?;
}
if let Some(ref kid) = header.kid {
ks_gard.find_kid(kid).ok_or_else(|| AuthError::InvalidKid(kid.to_owned()))?
} else {
ks_gard.find_alg(&header.alg).ok_or(AuthError::InvalidKeyAlg(header.alg))?
}
ks_gard.get_key(header)?
}
RefreshStrategy::KeyNotFound => {
if let Some(ref kid) = header.kid {
@ -133,11 +129,7 @@ impl KeyStoreManager {
{
ks_gard.refresh(&self.key_url, &[]).await?;
}
if let Some(ref kid) = header.kid {
ks_gard.find_kid(kid).ok_or_else(|| AuthError::InvalidKid(kid.to_owned()))?
} else {
ks_gard.find_alg(&header.alg).ok_or(AuthError::InvalidKeyAlg(header.alg))?
}
ks_gard.get_key(header)?
}
};
Ok(key.clone())
@ -186,7 +178,7 @@ impl KeyStore {
if keys.is_empty() {
Err(AuthError::JwksRefreshError("No valid keys in the Jwk Set!".to_owned()))
} else {
self.keys = keys;
self.keys = keys.into();
self.fail_time = None;
Ok(())
}
@ -199,17 +191,21 @@ impl KeyStore {
/// Find the key in the set that matches the given key id, if any.
pub fn find_kid(&self, kid: &str) -> Option<&Arc<KeyData>> {
self.keys.iter().find(|k| k.kid.is_some() && k.kid.as_ref().unwrap() == kid)
self.keys.find_kid(kid)
}
/// Find the key in the set that matches the given key id, if any.
pub fn find_alg(&self, alg: &Algorithm) -> Option<&Arc<KeyData>> {
self.keys.iter().find(|k| k.alg.contains(alg))
self.keys.find_alg(alg)
}
fn get_key(&self, header: &jsonwebtoken::Header) -> Result<&Arc<KeyData>, AuthError> {
self.keys.get_key(header)
}
/// Find first key.
pub fn find_first(&self) -> Option<&Arc<KeyData>> {
self.keys.get(0)
self.keys.first()
}
}
@ -227,7 +223,7 @@ mod tests {
};
use crate::jwks::key_store_manager::{KeyStore, KeyStoreManager};
use crate::jwks::KeyData;
use crate::jwks::{KeyData, KeySet};
use crate::{Refresh, RefreshStrategy};
const JWK_RSA01: &str = r#"{
@ -281,7 +277,7 @@ mod tests {
fn keystore_can_refresh() {
// FAIL, NO LOAD
let ks = KeyStore {
keys: vec![],
keys: KeySet::default(),
fail_time: Instant::now().checked_sub(Duration::from_secs(5)),
load_time: None,
};
@ -290,7 +286,7 @@ mod tests {
// NO FAIL, LOAD
let ks = KeyStore {
keys: vec![],
keys: KeySet::default(),
fail_time: None,
load_time: Instant::now().checked_sub(Duration::from_secs(5)),
};
@ -299,7 +295,7 @@ mod tests {
// FAIL, LOAD
let ks = KeyStore {
keys: vec![],
keys: KeySet::default(),
fail_time: Instant::now().checked_sub(Duration::from_secs(5)),
load_time: Instant::now().checked_sub(Duration::from_secs(10)),
};
@ -318,7 +314,8 @@ mod tests {
keys: vec![
Arc::new(KeyData::from_jwk(&jwk0).unwrap()),
Arc::new(KeyData::from_jwk(&jwk1).unwrap()),
],
]
.into(),
};
assert!(ks.find_kid("rsa01").is_some());
assert!(ks.find_kid("ec01").is_some());
@ -331,7 +328,7 @@ mod tests {
let ks = KeyStore {
load_time: None,
fail_time: None,
keys: vec![Arc::new(KeyData::from_jwk(&jwk0).unwrap())],
keys: vec![Arc::new(KeyData::from_jwk(&jwk0).unwrap())].into(),
};
assert!(ks.find_alg(&Algorithm::RS256).is_some());
assert!(ks.find_alg(&Algorithm::EdDSA).is_none());

View file

@ -12,6 +12,8 @@ pub mod key_store_manager;
pub enum KeySource {
/// KeyDataSource managing a refreshable key sets
KeyStoreSource(KeyStoreManager),
/// Manages public key sets, initialized on startup
MultiKeySource(KeySet),
/// Manages one public key, initialized on startup
SingleKeySource(Arc<KeyData>),
}
@ -33,10 +35,49 @@ impl KeyData {
}
}
#[derive(Clone, Default)]
pub struct KeySet(Vec<Arc<KeyData>>);
impl From<Vec<Arc<KeyData>>> for KeySet {
fn from(value: Vec<Arc<KeyData>>) -> Self {
KeySet(value)
}
}
impl KeySet {
/// Find the key in the set that matches the given key id, if any.
pub fn find_kid(&self, kid: &str) -> Option<&Arc<KeyData>> {
self.0.iter().find(|k| match &k.kid {
Some(k) => k == kid,
None => false,
})
}
/// Find the key in the set that matches the given key id, if any.
pub fn find_alg(&self, alg: &Algorithm) -> Option<&Arc<KeyData>> {
self.0.iter().find(|k| k.alg.contains(alg))
}
/// Find first key.
pub fn first(&self) -> Option<&Arc<KeyData>> {
self.0.first()
}
pub(crate) fn get_key(&self, header: &Header) -> Result<&Arc<KeyData>, AuthError> {
let key = if let Some(ref kid) = header.kid {
self.find_kid(kid).ok_or_else(|| AuthError::InvalidKid(kid.to_owned()))?
} else {
self.find_alg(&header.alg).ok_or(AuthError::InvalidKeyAlg(header.alg))?
};
Ok(key)
}
}
impl KeySource {
pub async fn get_key(&self, header: Header) -> Result<Arc<KeyData>, AuthError> {
match self {
KeySource::KeyStoreSource(kstore) => kstore.get_key(&header).await,
KeySource::MultiKeySource(keys) => keys.get_key(&header).cloned(),
KeySource::SingleKeySource(key) => Ok(key.clone()),
}
}

View file

@ -114,6 +114,21 @@ mod tests {
assert_eq!(response.status(), StatusCode::OK);
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
assert_eq!(&body[..], b"hello: b@b.com");
let response = make_proteced_request(JwtAuthorizer::from_jwks("../config/public1.jwks"), common::JWT_RSA1_OK).await;
assert_eq!(response.status(), StatusCode::OK);
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
assert_eq!(&body[..], b"hello: b@b.com");
let response = make_proteced_request(JwtAuthorizer::from_jwks("../config/public1.jwks"), common::JWT_EC1_OK).await;
assert_eq!(response.status(), StatusCode::OK);
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
assert_eq!(&body[..], b"hello: b@b.com");
let response = make_proteced_request(JwtAuthorizer::from_jwks("../config/public1.jwks"), common::JWT_ED1_OK).await;
assert_eq!(response.status(), StatusCode::OK);
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
assert_eq!(&body[..], b"hello: b@b.com");
}
#[tokio::test]