refactor: better url error checking (jwks, oidc)

This commit is contained in:
cduvray 2023-02-05 09:40:53 +01:00
parent b189caaab8
commit f1b11ecf3b
4 changed files with 43 additions and 17 deletions

View file

@ -1,6 +1,7 @@
use std::io::Read; use std::io::Read;
use jsonwebtoken::{decode, decode_header, jwk::JwkSet, DecodingKey, TokenData, Validation}; use jsonwebtoken::{decode, decode_header, jwk::JwkSet, DecodingKey, TokenData, Validation};
use reqwest::Url;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use crate::{ use crate::{
@ -104,15 +105,18 @@ where
} }
} }
KeySourceType::Jwks(url) => { KeySourceType::Jwks(url) => {
let key_store_manager = KeyStoreManager::new(url, refresh.unwrap_or_default()); let jwks_url = Url::parse(url).map_err(|e| InitError::JwksUrlError(e.to_string()))?;
let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default());
Authorizer { Authorizer {
key_source: KeySource::KeyStoreSource(key_store_manager), key_source: KeySource::KeyStoreSource(key_store_manager),
claims_checker, claims_checker,
} }
} }
KeySourceType::Discovery(issuer_url) => { KeySourceType::Discovery(issuer_url) => {
let jwks_url = oidc::discover_jwks(issuer_url).await?; let jwks_url = Url::parse(&oidc::discover_jwks(issuer_url).await?)
let key_store_manager = KeyStoreManager::new(&jwks_url, refresh.unwrap_or_default()); .map_err(|e| InitError::JwksUrlError(e.to_string()))?;
let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default());
Authorizer { Authorizer {
key_source: KeySource::KeyStoreSource(key_store_manager), key_source: KeySource::KeyStoreSource(key_store_manager),
claims_checker, claims_checker,
@ -146,7 +150,7 @@ mod tests {
use super::{Authorizer, KeySourceType}; use super::{Authorizer, KeySourceType};
#[tokio::test] #[tokio::test]
async fn from_secret() { async fn build_from_secret() {
let h = Header::new(Algorithm::HS256); let h = Header::new(Algorithm::HS256);
let a = Authorizer::<Value>::build(&KeySourceType::Secret("xxxxxx"), None, None) let a = Authorizer::<Value>::build(&KeySourceType::Secret("xxxxxx"), None, None)
.await .await
@ -156,7 +160,7 @@ mod tests {
} }
#[tokio::test] #[tokio::test]
async fn from_jwks() { async fn build_from_jwks_string() {
let jwks = r#" let jwks = r#"
{"keys": [{ {"keys": [{
"kid": "1", "kid": "1",
@ -175,7 +179,7 @@ mod tests {
} }
#[tokio::test] #[tokio::test]
async fn from_file() { async fn build_from_file() {
let a = Authorizer::<Value>::build(&KeySourceType::RSA("../config/jwtRS256.key.pub".to_owned()), None, None) let a = Authorizer::<Value>::build(&KeySourceType::RSA("../config/jwtRS256.key.pub".to_owned()), None, None)
.await .await
.unwrap(); .unwrap();
@ -196,9 +200,23 @@ mod tests {
} }
#[tokio::test] #[tokio::test]
async fn from_file_errors() { async fn build_file_errors() {
let a = Authorizer::<Value>::build(&KeySourceType::RSA("./config/does-not-exist.pem".to_owned()), None, None).await; let a = Authorizer::<Value>::build(&KeySourceType::RSA("./config/does-not-exist.pem".to_owned()), None, None).await;
println!("{:?}", a.as_ref().err()); println!("{:?}", a.as_ref().err());
assert!(a.is_err()); assert!(a.is_err());
} }
#[tokio::test]
async fn build_jwks_url_error() {
let a = Authorizer::<Value>::build(&&KeySourceType::Jwks("://xxxx".to_owned()), None, None).await;
println!("{:?}", a.as_ref().err());
assert!(a.is_err());
}
#[tokio::test]
async fn build_discovery_url_error() {
let a = Authorizer::<Value>::build(&&KeySourceType::Discovery("://xxxx".to_owned()), None, None).await;
println!("{:?}", a.as_ref().err());
assert!(a.is_err());
}
} }

View file

@ -23,6 +23,9 @@ pub enum InitError {
#[error("Builder Error {0}")] #[error("Builder Error {0}")]
DiscoveryError(String), DiscoveryError(String),
#[error("Builder Error {0}")]
JwksUrlError(String),
#[error("Jwks Parsing Error {0}")] #[error("Jwks Parsing Error {0}")]
JwksParsingError(#[from] serde_json::Error), JwksParsingError(#[from] serde_json::Error),
} }

View file

@ -2,6 +2,7 @@ use jsonwebtoken::{
jwk::{Jwk, JwkSet}, jwk::{Jwk, JwkSet},
Algorithm, DecodingKey, Algorithm, DecodingKey,
}; };
use reqwest::Url;
use std::{ use std::{
sync::Arc, sync::Arc,
time::{Duration, Instant}, time::{Duration, Instant},
@ -47,7 +48,7 @@ impl Default for Refresh {
#[derive(Clone)] #[derive(Clone)]
pub struct KeyStoreManager { pub struct KeyStoreManager {
key_url: String, key_url: Url,
/// in case of fail loading (error or key not found), minimal interval /// in case of fail loading (error or key not found), minimal interval
refresh: Refresh, refresh: Refresh,
keystore: Arc<Mutex<KeyStore>>, keystore: Arc<Mutex<KeyStore>>,
@ -63,9 +64,9 @@ pub struct KeyStore {
} }
impl KeyStoreManager { impl KeyStoreManager {
pub(crate) fn new(url: &str, refresh: Refresh) -> KeyStoreManager { pub(crate) fn new(key_url: Url, refresh: Refresh) -> KeyStoreManager {
KeyStoreManager { KeyStoreManager {
key_url: url.to_owned(), key_url,
refresh, refresh,
keystore: Arc::new(Mutex::new(KeyStore { keystore: Arc::new(Mutex::new(KeyStore {
jwks: JwkSet { keys: vec![] }, jwks: JwkSet { keys: vec![] },
@ -164,9 +165,9 @@ impl KeyStore {
} }
} }
async fn refresh(&mut self, key_url: &str, qparam: &[(&str, &str)]) -> Result<(), AuthError> { async fn refresh(&mut self, key_url: &Url, qparam: &[(&str, &str)]) -> Result<(), AuthError> {
reqwest::Client::new() reqwest::Client::new()
.get(key_url) .get(key_url.as_ref())
.query(qparam) .query(qparam)
.send() .send()
.await .await
@ -212,6 +213,7 @@ mod tests {
use jsonwebtoken::Algorithm; use jsonwebtoken::Algorithm;
use jsonwebtoken::{jwk::Jwk, Header}; use jsonwebtoken::{jwk::Jwk, Header};
use reqwest::Url;
use wiremock::{ use wiremock::{
matchers::{method, path}, matchers::{method, path},
Mock, MockServer, ResponseTemplate, Mock, MockServer, ResponseTemplate,
@ -330,7 +332,7 @@ mod tests {
.await; .await;
let ksm = KeyStoreManager::new( let ksm = KeyStoreManager::new(
&mock_server.uri(), Url::parse(&mock_server.uri()).unwrap(),
Refresh { Refresh {
strategy: RefreshStrategy::Interval, strategy: RefreshStrategy::Interval,
refresh_interval: Duration::from_secs(3000), refresh_interval: Duration::from_secs(3000),
@ -359,7 +361,7 @@ mod tests {
.await; .await;
let mut ksm = KeyStoreManager::new( let mut ksm = KeyStoreManager::new(
&mock_server.uri(), Url::parse(&mock_server.uri()).unwrap(),
Refresh { Refresh {
strategy: RefreshStrategy::KeyNotFound, strategy: RefreshStrategy::KeyNotFound,
..Default::default() ..Default::default()
@ -435,7 +437,7 @@ mod tests {
.await; .await;
let ksm = KeyStoreManager::new( let ksm = KeyStoreManager::new(
&mock_server.uri(), Url::parse(&mock_server.uri()).unwrap(),
Refresh { Refresh {
strategy: RefreshStrategy::NoRefresh, strategy: RefreshStrategy::NoRefresh,
..Default::default() ..Default::default()

View file

@ -9,7 +9,10 @@ pub struct OidcDiscovery {
} }
pub async fn discover_jwks(issuer: &str) -> Result<String, InitError> { pub async fn discover_jwks(issuer: &str) -> Result<String, InitError> {
let discovery_url = format!("{issuer}/.well-known/openid-configuration"); let discovery_url = reqwest::Url::parse(issuer)
.map_err(|e| InitError::DiscoveryError(e.to_string()))?
.join("/.well-known/openid-configuration")
.map_err(|e| InitError::DiscoveryError(e.to_string()))?;
reqwest::Client::new() reqwest::Client::new()
.get(discovery_url) .get(discovery_url)
.send() .send()