refactor: refresh simplification

- difference between refresh_interval and minimal_refresh_interval was no clear,
- therfore they were merged
This commit is contained in:
cduvray 2023-02-12 09:06:18 +01:00
parent 9c45a43584
commit a8b510a03e
2 changed files with 142 additions and 126 deletions

View file

@ -11,15 +11,16 @@ use tokio::sync::Mutex;
use crate::error::AuthError; use crate::error::AuthError;
/// Defines the strategy for the JWKS refresh.
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
pub enum RefreshStrategy { pub enum RefreshStrategy {
/// refresh periodicaly /// refresh periodicaly
Interval, Interval,
/// when kid not found in the store /// refresh when kid not found in the store
KeyNotFound, KeyNotFound,
/// load once triggered by the first use /// loading is triggered only once by the first use
NoRefresh, NoRefresh,
} }
@ -27,11 +28,15 @@ pub enum RefreshStrategy {
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
pub struct Refresh { pub struct Refresh {
pub strategy: RefreshStrategy, pub strategy: RefreshStrategy,
// after the interval the store will be refreshed (before getting a new key - lazy behaviour) /// After the refresh interval the store will/can be refreshed.
///
/// - RefreshStrategy::KeyNotFound - refresh will be performed only if a kid is not found in the store
/// (if no kid is in the token header the alg is looked up)
/// - RefreshStrategy::Interval - refresh will be performed each time the refresh interval has elapsed
/// (before checking a new token -> lazy behaviour)
pub refresh_interval: Duration, pub refresh_interval: Duration,
// don't refresh before (counting from the last refresh, when the kid not found) /// don't refresh before (after an error or jwks is unawailable)
pub minimal_refresh_interval: Duration, /// (we let a little bit of time to the jwks endpoint to recover)
// don't refresh before (after an error or jwks unawailable)
pub retry_interval: Duration, pub retry_interval: Duration,
} }
@ -40,7 +45,6 @@ impl Default for Refresh {
Self { Self {
strategy: RefreshStrategy::KeyNotFound, strategy: RefreshStrategy::KeyNotFound,
refresh_interval: Duration::from_secs(600), refresh_interval: Duration::from_secs(600),
minimal_refresh_interval: Duration::from_secs(30),
retry_interval: Duration::from_secs(10), retry_interval: Duration::from_secs(10),
} }
} }
@ -81,9 +85,7 @@ impl KeyStoreManager {
let mut ks_gard = kstore.lock().await; let mut ks_gard = kstore.lock().await;
let key = match self.refresh.strategy { let key = match self.refresh.strategy {
RefreshStrategy::Interval => { RefreshStrategy::Interval => {
if ks_gard.should_refresh(self.refresh.refresh_interval) if ks_gard.can_refresh(self.refresh.refresh_interval, self.refresh.retry_interval) {
&& ks_gard.can_refresh(self.refresh.minimal_refresh_interval, self.refresh.retry_interval)
{
ks_gard.refresh(&self.key_url, &[]).await?; ks_gard.refresh(&self.key_url, &[]).await?;
} }
if let Some(ref kid) = header.kid { if let Some(ref kid) = header.kid {
@ -97,7 +99,7 @@ impl KeyStoreManager {
let jwk_opt = ks_gard.find_kid(kid); let jwk_opt = ks_gard.find_kid(kid);
if let Some(jwk) = jwk_opt { if let Some(jwk) = jwk_opt {
jwk jwk
} else if ks_gard.can_refresh(self.refresh.minimal_refresh_interval, self.refresh.retry_interval) { } else if ks_gard.can_refresh(self.refresh.refresh_interval, self.refresh.retry_interval) {
ks_gard.refresh(&self.key_url, &[("kid", kid)]).await?; ks_gard.refresh(&self.key_url, &[("kid", kid)]).await?;
ks_gard.find_kid(kid).ok_or_else(|| AuthError::InvalidKid(kid.to_owned()))? ks_gard.find_kid(kid).ok_or_else(|| AuthError::InvalidKid(kid.to_owned()))?
} else { } else {
@ -105,10 +107,9 @@ impl KeyStoreManager {
} }
} else { } else {
let jwk_opt = ks_gard.find_alg(&header.alg); let jwk_opt = ks_gard.find_alg(&header.alg);
// .ok_or(AuthError::InvalidKeyAlg(header.alg))?
if let Some(jwk) = jwk_opt { if let Some(jwk) = jwk_opt {
jwk jwk
} else if ks_gard.can_refresh(self.refresh.minimal_refresh_interval, self.refresh.retry_interval) { } else if ks_gard.can_refresh(self.refresh.refresh_interval, self.refresh.retry_interval) {
ks_gard ks_gard
.refresh( .refresh(
&self.key_url, &self.key_url,
@ -127,7 +128,10 @@ impl KeyStoreManager {
} }
} }
RefreshStrategy::NoRefresh => { RefreshStrategy::NoRefresh => {
if ks_gard.load_time.is_none() { if ks_gard.load_time.is_none()
// if jwks endpoint is down for the loading, respect retry_interval
&& ks_gard.can_refresh(self.refresh.refresh_interval, self.refresh.retry_interval)
{
ks_gard.refresh(&self.key_url, &[]).await?; ks_gard.refresh(&self.key_url, &[]).await?;
} }
if let Some(ref kid) = header.kid { if let Some(ref kid) = header.kid {
@ -143,23 +147,15 @@ impl KeyStoreManager {
} }
impl KeyStore { impl KeyStore {
fn should_refresh(&self, refresh_interval: Duration) -> bool { fn can_refresh(&self, refresh_interval: Duration, minimal_retry: Duration) -> bool {
if let Some(t) = self.load_time {
t.elapsed() > refresh_interval
} else {
true
}
}
fn can_refresh(&self, minimal_refresh_interval: Duration, minimal_retry: Duration) -> bool {
if let Some(fail_tm) = self.fail_time { if let Some(fail_tm) = self.fail_time {
if let Some(load_tm) = self.load_time { if let Some(load_tm) = self.load_time {
fail_tm.elapsed() > minimal_retry && load_tm.elapsed() > minimal_refresh_interval fail_tm.elapsed() > minimal_retry && load_tm.elapsed() > refresh_interval
} else { } else {
fail_tm.elapsed() > minimal_retry fail_tm.elapsed() > minimal_retry
} }
} else if let Some(load_tm) = self.load_time { } else if let Some(load_tm) = self.load_time {
load_tm.elapsed() > minimal_refresh_interval load_tm.elapsed() > refresh_interval
} else { } else {
true true
} }
@ -171,12 +167,16 @@ impl KeyStore {
.query(qparam) .query(qparam)
.send() .send()
.await .await
.map_err(AuthError::JwksRefreshError)? .map_err(|e| {
self.fail_time = Some(Instant::now());
AuthError::JwksRefreshError(e)
})?
.json::<JwkSet>() .json::<JwkSet>()
.await .await
.map(|jwks| { .map(|jwks| {
self.load_time = Some(Instant::now()); self.load_time = Some(Instant::now());
self.jwks = jwks; self.jwks = jwks;
self.fail_time = None;
Ok(()) Ok(())
}) })
.map_err(|e| { .map_err(|e| {
@ -222,24 +222,43 @@ mod tests {
use crate::jwks::key_store_manager::{KeyStore, KeyStoreManager}; use crate::jwks::key_store_manager::{KeyStore, KeyStoreManager};
use crate::{Refresh, RefreshStrategy}; use crate::{Refresh, RefreshStrategy};
#[test] const JWK_ED01: &str = r#"{
fn keystore_should_refresh() { "kty": "OKP",
let ks = KeyStore { "use": "sig",
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![] }, "crv": "Ed25519",
fail_time: None, "x": "uWtSkE-I9aTMYTTvuTE1rtu0rNdxp3DU33cJ_ksL1Gk",
load_time: Some(Instant::now()), "kid": "ed01",
}; "alg": "EdDSA"
}"#;
assert!(!ks.should_refresh(Duration::from_secs(5))); const JWK_ED02: &str = r#"{
"kty": "OKP",
"use": "sig",
"crv": "Ed25519",
"x": "uWtSkE-I9aTMYTTvuTE1rtu0rNdxp3DU33cJ_ksL1Gk",
"kid": "ed02",
"alg": "EdDSA"
}"#;
let ks = KeyStore { const JWK_EC01: &str = r#"{
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![] }, "kty": "EC",
fail_time: None, "crv": "P-256",
load_time: Some(Instant::now() - Duration::from_secs(6)), "x": "w7JAoU_gJbZJvV-zCOvU9yFJq0FNC_edCMRM78P8eQQ",
}; "y": "wQg1EytcsEmGrM70Gb53oluoDbVhCZ3Uq3hHMslHVb4",
"kid": "ec01",
"alg": "ES256",
"use": "sig"
}"#;
assert!(ks.should_refresh(Duration::from_secs(5))); const JWK_EC02: &str = r#"{
} "kty": "EC",
"crv": "P-256",
"x": "w7JAoU_gJbZJvV-zCOvU9yFJq0FNC_edCMRM78P8eQQ",
"y": "wQg1EytcsEmGrM70Gb53oluoDbVhCZ3Uq3hHMslHVb4",
"kid": "ec02",
"alg": "ES256",
"use": "sig"
}"#;
#[test] #[test]
fn keystore_can_refresh() { fn keystore_can_refresh() {
@ -249,9 +268,8 @@ mod tests {
fail_time: Some(Instant::now() - Duration::from_secs(5)), fail_time: Some(Instant::now() - Duration::from_secs(5)),
load_time: None, load_time: None,
}; };
assert!(ks.can_refresh(Duration::from_secs(4), Duration::from_secs(4))); assert!(ks.can_refresh(Duration::from_secs(0), Duration::from_secs(4)));
assert!(ks.can_refresh(Duration::from_secs(6), Duration::from_secs(4))); assert!(!ks.can_refresh(Duration::from_secs(0), Duration::from_secs(6)));
assert!(!ks.can_refresh(Duration::from_secs(6), Duration::from_secs(6)));
// NO FAIL, LOAD // NO FAIL, LOAD
let ks = KeyStore { let ks = KeyStore {
@ -259,8 +277,8 @@ mod tests {
fail_time: None, fail_time: None,
load_time: Some(Instant::now() - Duration::from_secs(5)), load_time: Some(Instant::now() - Duration::from_secs(5)),
}; };
assert!(ks.can_refresh(Duration::from_secs(4), Duration::from_secs(4))); assert!(ks.can_refresh(Duration::from_secs(4), Duration::from_secs(0)));
assert!(!ks.can_refresh(Duration::from_secs(6), Duration::from_secs(6))); assert!(!ks.can_refresh(Duration::from_secs(6), Duration::from_secs(0)));
// FAIL, LOAD // FAIL, LOAD
let ks = KeyStore { let ks = KeyStore {
@ -269,6 +287,7 @@ mod tests {
load_time: Some(Instant::now() - Duration::from_secs(10)), load_time: Some(Instant::now() - Duration::from_secs(10)),
}; };
assert!(ks.can_refresh(Duration::from_secs(6), Duration::from_secs(4))); assert!(ks.can_refresh(Duration::from_secs(6), Duration::from_secs(4)));
assert!(!ks.can_refresh(Duration::from_secs(11), Duration::from_secs(4)));
assert!(!ks.can_refresh(Duration::from_secs(6), Duration::from_secs(6))); assert!(!ks.can_refresh(Duration::from_secs(6), Duration::from_secs(6)));
} }
@ -309,6 +328,15 @@ mod tests {
.await; .await;
} }
async fn mock_jwks_response_fail_once(mock_server: &MockServer) {
Mock::given(method("GET"))
.and(path("/"))
.respond_with(ResponseTemplate::new(500))
.expect(1)
.mount(&mock_server)
.await;
}
fn build_header(kid: &str, alg: Algorithm) -> Header { fn build_header(kid: &str, alg: Algorithm) -> Header {
let mut header = Header::new(alg); let mut header = Header::new(alg);
header.kid = Some(kid.to_owned()); header.kid = Some(kid.to_owned());
@ -316,125 +344,115 @@ mod tests {
} }
#[tokio::test] #[tokio::test]
async fn keystore_manager_find_key_with_refresh_interval() { async fn strategy_interval() {
let mock_server = MockServer::start().await; let mock_server = MockServer::start().await;
mock_jwks_response_once( mock_jwks_response_once(&mock_server, JWK_ED01).await;
&mock_server,
r#"{
"kty": "OKP",
"use": "sig",
"crv": "Ed25519",
"x": "uWtSkE-I9aTMYTTvuTE1rtu0rNdxp3DU33cJ_ksL1Gk",
"kid": "key-ed",
"alg": "EdDSA"
}"#,
)
.await;
let ksm = KeyStoreManager::new( let ksm = KeyStoreManager::new(
Url::parse(&mock_server.uri()).unwrap(), Url::parse(&mock_server.uri()).unwrap(),
Refresh { Refresh {
strategy: RefreshStrategy::Interval, strategy: RefreshStrategy::Interval,
refresh_interval: Duration::from_secs(3000), refresh_interval: Duration::from_millis(10),
..Default::default() retry_interval: Duration::from_millis(5),
}, },
); );
// 1st RELOAD
let r = ksm.get_key(&Header::new(Algorithm::EdDSA)).await; let r = ksm.get_key(&Header::new(Algorithm::EdDSA)).await;
assert!(r.is_ok()); assert!(r.is_ok());
mock_server.verify().await; mock_server.verify().await;
// NO RELOAD - inteval not elapsed
assert!(ksm.get_key(&Header::new(Algorithm::EdDSA)).await.is_ok());
// RELOAD - interval elapsed
mock_server.reset().await;
tokio::time::sleep(Duration::from_millis(11)).await;
mock_jwks_response_once(&mock_server, JWK_ED01).await;
assert!(ksm.get_key(&Header::new(Algorithm::EdDSA)).await.is_ok());
mock_server.verify().await;
// RELOAD - with fail
mock_server.reset().await;
tokio::time::sleep(Duration::from_millis(11)).await;
mock_jwks_response_fail_once(&mock_server).await;
assert!(ksm.get_key(&Header::new(Algorithm::EdDSA)).await.is_err());
mock_server.verify().await;
// NO RELOAD - retry not ellapsed
assert!(ksm.get_key(&Header::new(Algorithm::EdDSA)).await.is_ok());
// RELOAD - retry elapsed
mock_server.reset().await;
tokio::time::sleep(Duration::from_millis(6)).await;
mock_jwks_response_once(&mock_server, JWK_ED01).await;
assert!(ksm.get_key(&Header::new(Algorithm::EdDSA)).await.is_ok());
mock_server.verify().await;
} }
#[tokio::test] #[tokio::test]
async fn keystore_manager_find_key_with_refresh() { async fn strategy_key_not_found_with_refresh() {
let mock_server = MockServer::start().await; let mock_server = MockServer::start().await;
mock_jwks_response_once( mock_jwks_response_once(&mock_server, JWK_ED01).await;
&mock_server,
r#"{
"kty": "OKP",
"use": "sig",
"crv": "Ed25519",
"x": "uWtSkE-I9aTMYTTvuTE1rtu0rNdxp3DU33cJ_ksL1Gk",
"kid": "key-ed",
"alg": "EdDSA"
}"#,
)
.await;
let mut ksm = KeyStoreManager::new( let ksm = KeyStoreManager::new(
Url::parse(&mock_server.uri()).unwrap(), Url::parse(&mock_server.uri()).unwrap(),
Refresh { Refresh {
strategy: RefreshStrategy::KeyNotFound, strategy: RefreshStrategy::KeyNotFound,
..Default::default() refresh_interval: Duration::from_millis(10),
retry_interval: Duration::from_millis(5),
}, },
); );
// STEP 1: initial (lazy) reloading // STEP 1: initial (lazy) reloading
let r = ksm.get_key(&build_header("key-ed", Algorithm::EdDSA)).await; let r = ksm.get_key(&build_header("ed01", Algorithm::EdDSA)).await;
assert!(r.is_ok()); assert!(r.is_ok());
mock_server.verify().await; mock_server.verify().await;
// STEP2: new kid -> reloading ksm // STEP2: new kid, < refresh_interval -> reloading ksm
mock_server.reset().await; mock_server.reset().await;
mock_jwks_response_once( mock_jwks_response_once(&mock_server, JWK_ED02).await;
&mock_server, let h = build_header("ed02", Algorithm::EdDSA);
r#"{
"kty": "OKP",
"use": "sig",
"crv": "Ed25519",
"x": "uWtSkE-I9aTMYTTvuTE1rtu0rNdxp3DU33cJ_ksL1Gk",
"kid": "key-ed02",
"alg": "EdDSA"
}"#,
)
.await;
let h = build_header("key-ed02", Algorithm::EdDSA);
assert!(ksm.get_key(&h).await.is_err()); assert!(ksm.get_key(&h).await.is_err());
ksm.refresh.minimal_refresh_interval = Duration::from_millis(100); // ksm.refresh.refresh_interval = Duration::from_millis(10);
tokio::time::sleep(Duration::from_millis(101)).await; tokio::time::sleep(Duration::from_millis(11)).await;
assert!(ksm.get_key(&h).await.is_ok()); assert!(ksm.get_key(&h).await.is_ok());
mock_server.verify().await; mock_server.verify().await;
// STEP3: new algorithm -> try to reload // STEP3: new algorithm -> try to reload
mock_server.reset().await; mock_server.reset().await;
mock_jwks_response_once( mock_jwks_response_once(&mock_server, JWK_EC01).await;
&mock_server,
r#"{
"kty": "EC",
"crv": "P-256",
"x": "w7JAoU_gJbZJvV-zCOvU9yFJq0FNC_edCMRM78P8eQQ",
"y": "wQg1EytcsEmGrM70Gb53oluoDbVhCZ3Uq3hHMslHVb4",
"kid": "ec01",
"alg": "ES256",
"use": "sig"
}"#,
)
.await;
let h = Header::new(Algorithm::ES256); let h = Header::new(Algorithm::ES256);
assert!(ksm.get_key(&h).await.is_err()); assert!(ksm.get_key(&h).await.is_err());
tokio::time::sleep(Duration::from_millis(101)).await; tokio::time::sleep(Duration::from_millis(11)).await;
assert!(ksm.get_key(&h).await.is_ok()); assert!(ksm.get_key(&h).await.is_ok());
mock_server.verify().await; mock_server.verify().await;
// STEP4: new key, refresh elapsed, FAIL
mock_server.reset().await;
tokio::time::sleep(Duration::from_millis(11)).await;
mock_jwks_response_fail_once(&mock_server).await;
let h = build_header("ec02", Algorithm::EdDSA);
assert!(ksm.get_key(&h).await.is_err());
mock_server.verify().await;
// STEP5: retry elapsed -> reload
mock_server.reset().await;
tokio::time::sleep(Duration::from_millis(6)).await;
mock_jwks_response_once(&mock_server, JWK_EC02).await;
let h = build_header("ec02", Algorithm::EdDSA);
assert!(ksm.get_key(&h).await.is_ok());
mock_server.verify().await;
} }
#[tokio::test] #[tokio::test]
async fn keystore_manager_find_key_with_no_refresh() { async fn strategy_no_refresh() {
let mock_server = MockServer::start().await; let mock_server = MockServer::start().await;
mock_jwks_response_once( mock_jwks_response_once(&mock_server, JWK_ED01).await;
&mock_server,
r#"{
"kty": "OKP",
"use": "sig",
"crv": "Ed25519",
"x": "uWtSkE-I9aTMYTTvuTE1rtu0rNdxp3DU33cJ_ksL1Gk",
"kid": "key-ed",
"alg": "EdDSA"
}"#,
)
.await;
let ksm = KeyStoreManager::new( let ksm = KeyStoreManager::new(
Url::parse(&mock_server.uri()).unwrap(), Url::parse(&mock_server.uri()).unwrap(),
@ -445,12 +463,12 @@ mod tests {
); );
// STEP 1: initial (lazy) reloading // STEP 1: initial (lazy) reloading
let r = ksm.get_key(&build_header("key-ed", Algorithm::EdDSA)).await; let r = ksm.get_key(&build_header("ed01", Algorithm::EdDSA)).await;
assert!(r.is_ok()); assert!(r.is_ok());
mock_server.verify().await; mock_server.verify().await;
// STEP2: new kid -> reloading ksm // STEP2: new kid -> reloading ksm
let h = build_header("key-ed02", Algorithm::EdDSA); let h = build_header("ed02", Algorithm::EdDSA);
assert!(ksm.get_key(&h).await.is_err()); assert!(ksm.get_key(&h).await.is_err());
mock_server.verify().await; mock_server.verify().await;

View file

@ -224,7 +224,6 @@ async fn scenario2() {
init_test(); init_test();
let url = run_jwks_server(); let url = run_jwks_server();
let refresh = Refresh { let refresh = Refresh {
minimal_refresh_interval: Duration::from_millis(20),
refresh_interval: Duration::from_millis(40), refresh_interval: Duration::from_millis(40),
retry_interval: Duration::from_millis(0), retry_interval: Duration::from_millis(0),
strategy: RefreshStrategy::Interval, strategy: RefreshStrategy::Interval,
@ -255,9 +254,8 @@ async fn scenario3() {
let url = run_jwks_server(); let url = run_jwks_server();
let refresh = Refresh { let refresh = Refresh {
strategy: RefreshStrategy::KeyNotFound, strategy: RefreshStrategy::KeyNotFound,
minimal_refresh_interval: Duration::from_millis(20), refresh_interval: Duration::from_millis(40),
retry_interval: Duration::from_millis(0), retry_interval: Duration::from_millis(0),
..Default::default()
}; };
let auth: JwtAuthorizer<User> = JwtAuthorizer::from_oidc(&url).refresh(refresh); let auth: JwtAuthorizer<User> = JwtAuthorizer::from_oidc(&url).refresh(refresh);
let mut app = app(auth).await; let mut app = app(auth).await;
@ -287,7 +285,7 @@ async fn scenario4() {
let url = run_jwks_server(); let url = run_jwks_server();
let refresh = Refresh { let refresh = Refresh {
strategy: RefreshStrategy::NoRefresh, strategy: RefreshStrategy::NoRefresh,
minimal_refresh_interval: Duration::from_millis(0), refresh_interval: Duration::from_millis(0),
retry_interval: Duration::from_millis(0), retry_interval: Duration::from_millis(0),
..Default::default() ..Default::default()
}; };