Merge pull request #42 from NotNorom/main

feat: Add support for custom http client in jwks discovery.
This commit is contained in:
cduvray 2023-11-15 07:35:03 +01:00 committed by GitHub
commit 8d9734bcd5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 61 additions and 6 deletions

View file

@ -3,7 +3,7 @@ use std::{io::Read, sync::Arc};
use headers::{authorization::Bearer, Authorization, HeaderMapExt}; use headers::{authorization::Bearer, Authorization, HeaderMapExt};
use http::HeaderMap; use http::HeaderMap;
use jsonwebtoken::{decode, decode_header, jwk::JwkSet, Algorithm, DecodingKey, TokenData}; use jsonwebtoken::{decode, decode_header, jwk::JwkSet, Algorithm, DecodingKey, TokenData};
use reqwest::Url; use reqwest::{Client, Url};
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use crate::{ use crate::{
@ -56,6 +56,7 @@ where
refresh: Option<Refresh>, refresh: Option<Refresh>,
validation: crate::validation::Validation, validation: crate::validation::Validation,
jwt_source: JwtSource, jwt_source: JwtSource,
http_client: Option<Client>,
) -> Result<Authorizer<C>, InitError> { ) -> Result<Authorizer<C>, InitError> {
Ok(match key_source_type { Ok(match key_source_type {
KeySourceType::RSA(path) => { KeySourceType::RSA(path) => {
@ -195,7 +196,7 @@ where
} }
} }
KeySourceType::Discovery(issuer_url) => { KeySourceType::Discovery(issuer_url) => {
let jwks_url = Url::parse(&oidc::discover_jwks(issuer_url.as_str()).await?) let jwks_url = Url::parse(&oidc::discover_jwks(issuer_url.as_str(), http_client).await?)
.map_err(|e| InitError::JwksUrlError(e.to_string()))?; .map_err(|e| InitError::JwksUrlError(e.to_string()))?;
let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default()); let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default());
@ -318,6 +319,7 @@ mod tests {
None, None,
Validation::new(), Validation::new(),
JwtSource::AuthorizationHeader, JwtSource::AuthorizationHeader,
None,
) )
.await .await
.unwrap(); .unwrap();
@ -343,6 +345,7 @@ mod tests {
None, None,
Validation::new(), Validation::new(),
JwtSource::AuthorizationHeader, JwtSource::AuthorizationHeader,
None,
) )
.await .await
.unwrap(); .unwrap();
@ -358,6 +361,7 @@ mod tests {
None, None,
Validation::new(), Validation::new(),
JwtSource::AuthorizationHeader, JwtSource::AuthorizationHeader,
None,
) )
.await .await
.unwrap(); .unwrap();
@ -370,6 +374,7 @@ mod tests {
None, None,
Validation::new(), Validation::new(),
JwtSource::AuthorizationHeader, JwtSource::AuthorizationHeader,
None,
) )
.await .await
.unwrap(); .unwrap();
@ -382,6 +387,7 @@ mod tests {
None, None,
Validation::new(), Validation::new(),
JwtSource::AuthorizationHeader, JwtSource::AuthorizationHeader,
None,
) )
.await .await
.unwrap(); .unwrap();
@ -394,6 +400,7 @@ mod tests {
None, None,
Validation::new(), Validation::new(),
JwtSource::AuthorizationHeader, JwtSource::AuthorizationHeader,
None,
) )
.await .await
.unwrap(); .unwrap();
@ -419,6 +426,7 @@ mod tests {
None, None,
Validation::new(), Validation::new(),
JwtSource::AuthorizationHeader, JwtSource::AuthorizationHeader,
None,
) )
.await .await
.unwrap(); .unwrap();
@ -431,6 +439,7 @@ mod tests {
None, None,
Validation::new(), Validation::new(),
JwtSource::AuthorizationHeader, JwtSource::AuthorizationHeader,
None,
) )
.await .await
.unwrap(); .unwrap();
@ -443,6 +452,7 @@ mod tests {
None, None,
Validation::new(), Validation::new(),
JwtSource::AuthorizationHeader, JwtSource::AuthorizationHeader,
None,
) )
.await .await
.unwrap(); .unwrap();
@ -458,6 +468,7 @@ mod tests {
None, None,
Validation::new(), Validation::new(),
JwtSource::AuthorizationHeader, JwtSource::AuthorizationHeader,
None,
) )
.await; .await;
println!("{:?}", a.as_ref().err()); println!("{:?}", a.as_ref().err());
@ -472,6 +483,7 @@ mod tests {
None, None,
Validation::default(), Validation::default(),
JwtSource::AuthorizationHeader, JwtSource::AuthorizationHeader,
None,
) )
.await; .await;
println!("{:?}", a.as_ref().err()); println!("{:?}", a.as_ref().err());
@ -486,6 +498,7 @@ mod tests {
None, None,
Validation::default(), Validation::default(),
JwtSource::AuthorizationHeader, JwtSource::AuthorizationHeader,
None,
) )
.await; .await;
println!("{:?}", a.as_ref().err()); println!("{:?}", a.as_ref().err());

View file

@ -9,6 +9,8 @@ use crate::{
Authorizer, Refresh, RefreshStrategy, RegisteredClaims, Validation, Authorizer, Refresh, RefreshStrategy, RegisteredClaims, Validation,
}; };
use reqwest::Client;
/// Authorizer Layer builder /// Authorizer Layer builder
/// ///
/// - initialisation of the Authorizer from jwks, rsa, ed, ec or secret /// - initialisation of the Authorizer from jwks, rsa, ed, ec or secret
@ -22,6 +24,7 @@ where
claims_checker: Option<ClaimsCheckerFn<C>>, claims_checker: Option<ClaimsCheckerFn<C>>,
validation: Option<Validation>, validation: Option<Validation>,
jwt_source: JwtSource, jwt_source: JwtSource,
http_client: Option<Client>,
} }
/// alias for AuthorizerBuidler (backwards compatibility) /// alias for AuthorizerBuidler (backwards compatibility)
@ -40,6 +43,7 @@ where
claims_checker: None, claims_checker: None,
validation: None, validation: None,
jwt_source: JwtSource::AuthorizationHeader, jwt_source: JwtSource::AuthorizationHeader,
http_client: None,
} }
} }
@ -51,6 +55,7 @@ where
claims_checker: None, claims_checker: None,
validation: None, validation: None,
jwt_source: JwtSource::AuthorizationHeader, jwt_source: JwtSource::AuthorizationHeader,
http_client: None,
} }
} }
@ -61,6 +66,7 @@ where
claims_checker: None, claims_checker: None,
validation: None, validation: None,
jwt_source: JwtSource::AuthorizationHeader, jwt_source: JwtSource::AuthorizationHeader,
http_client: None,
} }
} }
@ -71,6 +77,7 @@ where
claims_checker: None, claims_checker: None,
validation: None, validation: None,
jwt_source: JwtSource::AuthorizationHeader, jwt_source: JwtSource::AuthorizationHeader,
http_client: None,
} }
} }
@ -82,6 +89,7 @@ where
claims_checker: None, claims_checker: None,
validation: None, validation: None,
jwt_source: JwtSource::AuthorizationHeader, jwt_source: JwtSource::AuthorizationHeader,
http_client: None,
} }
} }
@ -93,6 +101,7 @@ where
claims_checker: None, claims_checker: None,
validation: None, validation: None,
jwt_source: JwtSource::AuthorizationHeader, jwt_source: JwtSource::AuthorizationHeader,
http_client: None,
} }
} }
@ -104,6 +113,7 @@ where
claims_checker: None, claims_checker: None,
validation: None, validation: None,
jwt_source: JwtSource::AuthorizationHeader, jwt_source: JwtSource::AuthorizationHeader,
http_client: None,
} }
} }
@ -115,6 +125,7 @@ where
claims_checker: None, claims_checker: None,
validation: None, validation: None,
jwt_source: JwtSource::AuthorizationHeader, jwt_source: JwtSource::AuthorizationHeader,
http_client: None,
} }
} }
@ -126,6 +137,7 @@ where
claims_checker: None, claims_checker: None,
validation: None, validation: None,
jwt_source: JwtSource::AuthorizationHeader, jwt_source: JwtSource::AuthorizationHeader,
http_client: None,
} }
} }
@ -137,6 +149,7 @@ where
claims_checker: None, claims_checker: None,
validation: None, validation: None,
jwt_source: JwtSource::AuthorizationHeader, jwt_source: JwtSource::AuthorizationHeader,
http_client: None,
} }
} }
@ -148,6 +161,7 @@ where
claims_checker: None, claims_checker: None,
validation: None, validation: None,
jwt_source: JwtSource::AuthorizationHeader, jwt_source: JwtSource::AuthorizationHeader,
http_client: None,
} }
} }
@ -198,12 +212,30 @@ where
self self
} }
/// provide a custom http client for oicd requests
/// if not called, uses a default configured client
///
/// (default: None)
pub fn http_client(mut self, http_client: Client) -> AuthorizerBuilder<C> {
self.http_client = Some(http_client);
self
}
/// Build axum layer /// Build axum layer
#[deprecated(since = "0.10.0", note = "please use `IntoLayer::into_layer()` instead")] #[deprecated(since = "0.10.0", note = "please use `IntoLayer::into_layer()` instead")]
pub async fn layer(self) -> Result<AuthorizationLayer<C>, InitError> { pub async fn layer(self) -> Result<AuthorizationLayer<C>, InitError> {
let val = self.validation.unwrap_or_default(); let val = self.validation.unwrap_or_default();
let auth = Arc::new( let auth = Arc::new(
Authorizer::build(self.key_source_type, self.claims_checker, self.refresh, val, self.jwt_source).await?, Authorizer::build(
self.key_source_type,
self.claims_checker,
self.refresh,
val,
self.jwt_source,
None,
)
.await?,
); );
Ok(AuthorizationLayer::new(vec![auth])) Ok(AuthorizationLayer::new(vec![auth]))
} }
@ -211,6 +243,14 @@ where
pub async fn build(self) -> Result<Authorizer<C>, InitError> { pub async fn build(self) -> Result<Authorizer<C>, InitError> {
let val = self.validation.unwrap_or_default(); let val = self.validation.unwrap_or_default();
Authorizer::build(self.key_source_type, self.claims_checker, self.refresh, val, self.jwt_source).await Authorizer::build(
self.key_source_type,
self.claims_checker,
self.refresh,
val,
self.jwt_source,
self.http_client,
)
.await
} }
} }

View file

@ -20,8 +20,10 @@ fn discovery_url(issuer: &str) -> Result<Url, InitError> {
Ok(url) Ok(url)
} }
pub async fn discover_jwks(issuer: &str) -> Result<String, InitError> { pub async fn discover_jwks(issuer: &str, client: Option<Client>) -> Result<String, InitError> {
Client::new() let client = client.unwrap_or_default();
client
.get(discovery_url(issuer)?) .get(discovery_url(issuer)?)
.send() .send()
.await .await