mirror of
https://github.com/TECHNOFAB11/jwt-authorizer.git
synced 2025-12-15 01:13:52 +01:00
chore: fmt
This commit is contained in:
parent
f77a7ce54f
commit
b0667729a3
32 changed files with 3596 additions and 7 deletions
33
jwt-authorizer/Cargo.toml
Normal file
33
jwt-authorizer/Cargo.toml
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
[package]
|
||||
name = "jwt-authorizer"
|
||||
description = "jwt authorizer middleware for axum"
|
||||
version = "0.2.0"
|
||||
edition = "2021"
|
||||
authors = ["cduvray <c_duvray@proton.me>"]
|
||||
license = "MIT"
|
||||
readme = "docs/README.md"
|
||||
repository = "https://github.com/cduvray/jwt-authorizer"
|
||||
keywords = ["jwt","axum","authorisation"]
|
||||
|
||||
[dependencies]
|
||||
axum = { version = "0.6.1", features = ["headers"] }
|
||||
futures-util = "0.3.25"
|
||||
futures-core = "0.3.25"
|
||||
headers = "0.3"
|
||||
jsonwebtoken = "8.2.0"
|
||||
http = "0.2.8"
|
||||
# pin-project-lite = "0.2.9"
|
||||
pin-project = "1.0.12"
|
||||
reqwest = { version = "0.11.13", features = ["json"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
thiserror = "1.0.37"
|
||||
tokio = { version = "1.0", features = ["full"] }
|
||||
tower-http = { version = "0.3.4", features = ["trace", "auth"] }
|
||||
tower-layer = "0.3"
|
||||
tower-service = "0.3"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
|
||||
[dev-dependencies]
|
||||
wiremock = "0.5"
|
||||
0
jwt-authorizer/clippy.toml
Normal file
0
jwt-authorizer/clippy.toml
Normal file
36
jwt-authorizer/docs/README.md
Normal file
36
jwt-authorizer/docs/README.md
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
# jwt-authorizer
|
||||
|
||||
JWT authoriser Layer for Axum.
|
||||
|
||||
Example:
|
||||
|
||||
```rust
|
||||
use jwt_authorizer::{AuthError, JwtAuthorizer, JwtClaims};
|
||||
use axum::{routing::get, Router};
|
||||
use serde::Deserialize;
|
||||
|
||||
// Authorized entity, struct deserializable from JWT claims
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
struct User {
|
||||
sub: String,
|
||||
}
|
||||
|
||||
// let's create an authorizer builder from a JWKS Endpoint
|
||||
let jwt_auth: JwtAuthorizer<User> = JwtAuthorizer::new()
|
||||
.from_jwks_url("http://localhost:3000/oidc/jwks");
|
||||
|
||||
// adding the authorization layer
|
||||
let app = Router::new().route("/protected", get(protected))
|
||||
.layer(jwt_auth.layer());
|
||||
|
||||
// proteced handler with user injection (mapping some jwt claims)
|
||||
async fn protected(JwtClaims(user): JwtClaims<User>) -> Result<String, AuthError> {
|
||||
// Send the protected data to the user
|
||||
Ok(format!("Welcome: {}", user.sub))
|
||||
}
|
||||
|
||||
# async {
|
||||
axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
|
||||
.serve(app.into_make_service()).await.expect("server failed");
|
||||
# };
|
||||
```
|
||||
170
jwt-authorizer/src/authorizer.rs
Normal file
170
jwt-authorizer/src/authorizer.rs
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
use std::{io::Read, time::Duration};
|
||||
|
||||
use jsonwebtoken::{decode, decode_header, jwk::JwkSet, DecodingKey, TokenData, Validation};
|
||||
use serde::de::DeserializeOwned;
|
||||
|
||||
use crate::{
|
||||
error::AuthError,
|
||||
jwks::{key_store_manager::KeyStoreManager, KeySource},
|
||||
};
|
||||
|
||||
pub trait ClaimsChecker<C> {
|
||||
fn check(&self, claims: &C) -> bool;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FnClaimsChecker<C>
|
||||
where
|
||||
C: Clone,
|
||||
{
|
||||
pub checker_fn: fn(&C) -> bool,
|
||||
}
|
||||
|
||||
impl<C> ClaimsChecker<C> for FnClaimsChecker<C>
|
||||
where
|
||||
C: Clone,
|
||||
{
|
||||
fn check(&self, claims: &C) -> bool {
|
||||
(self.checker_fn)(claims)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Authorizer<C>
|
||||
where
|
||||
C: Clone,
|
||||
{
|
||||
pub key_source: KeySource,
|
||||
pub claims_checker: Option<FnClaimsChecker<C>>,
|
||||
}
|
||||
|
||||
fn read_data(path: &str) -> Result<Vec<u8>, AuthError> {
|
||||
let mut data = Vec::<u8>::new();
|
||||
let mut f = std::fs::File::open(path)?;
|
||||
f.read_to_end(&mut data)?;
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
impl<C> Authorizer<C>
|
||||
where
|
||||
C: DeserializeOwned + Clone + Send + Sync,
|
||||
{
|
||||
pub fn from_jwks(jwks: &str, claims_checker: Option<FnClaimsChecker<C>>) -> Result<Authorizer<C>, AuthError> {
|
||||
let set: JwkSet = serde_json::from_str(jwks)?;
|
||||
let k = DecodingKey::from_jwk(&set.keys[0])?;
|
||||
|
||||
Ok(Authorizer {
|
||||
key_source: KeySource::DecodingKeySource(k),
|
||||
claims_checker,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_rsa_file(path: &str) -> Result<Authorizer<C>, AuthError> {
|
||||
Ok(Authorizer {
|
||||
key_source: KeySource::DecodingKeySource(DecodingKey::from_rsa_pem(&read_data(path)?)?),
|
||||
claims_checker: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_ec_file(path: &str) -> Result<Authorizer<C>, AuthError> {
|
||||
let k = DecodingKey::from_ec_der(&read_data(path)?);
|
||||
Ok(Authorizer {
|
||||
key_source: KeySource::DecodingKeySource(k),
|
||||
claims_checker: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_ed_file(path: &str) -> Result<Authorizer<C>, AuthError> {
|
||||
let k = DecodingKey::from_ed_der(&read_data(path)?);
|
||||
Ok(Authorizer {
|
||||
key_source: KeySource::DecodingKeySource(k),
|
||||
claims_checker: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_secret(secret: &str) -> Result<Authorizer<C>, AuthError> {
|
||||
let k = DecodingKey::from_secret(secret.as_bytes());
|
||||
Ok(Authorizer {
|
||||
key_source: KeySource::DecodingKeySource(k),
|
||||
claims_checker: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_jwks_url(url: &str, claims_checker: Option<FnClaimsChecker<C>>) -> Result<Authorizer<C>, AuthError> {
|
||||
let key_store_manager = KeyStoreManager::with_refresh_interval(url, Duration::from_secs(60));
|
||||
Ok(Authorizer {
|
||||
key_source: KeySource::KeyStoreSource(key_store_manager),
|
||||
claims_checker,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn check_auth(&self, token: &str) -> Result<TokenData<C>, AuthError> {
|
||||
let header = decode_header(token)?;
|
||||
let validation = Validation::new(header.alg);
|
||||
let decoding_key = self.key_source.get_key(header).await?;
|
||||
let token_data = decode::<C>(token, &decoding_key, &validation)?;
|
||||
|
||||
if let Some(ref checker) = self.claims_checker {
|
||||
if !checker.check(&token_data.claims) {
|
||||
return Err(AuthError::InvalidClaims());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(token_data)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use jsonwebtoken::{Algorithm, Header};
|
||||
use serde_json::Value;
|
||||
|
||||
use super::Authorizer;
|
||||
|
||||
#[tokio::test]
|
||||
async fn from_secret() {
|
||||
let h = Header::new(Algorithm::HS256);
|
||||
let a = Authorizer::<Value>::from_secret("xxxxxx").unwrap();
|
||||
let k = a.key_source.get_key(h);
|
||||
assert!(k.await.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn from_jwks() {
|
||||
let jwks = r#"
|
||||
{"keys": [{
|
||||
"kid": "1",
|
||||
"kty": "RSA",
|
||||
"alg": "RS256",
|
||||
"use": "sig",
|
||||
"n": "2pQeZdxa7q093K7bj5h6-leIpxfTnuAxzXdhjfGEJHxmt2ekHyCBWWWXCBiDn2RTcEBcy6gZqOW45Uy_tw-5e-Px1xFj1PykGEkRlOpYSAeWsNaAWvvpGB9m4zQ0PgZeMDDXE5IIBrY6YAzmGQxV-fcGGLhJnXl0-5_z7tKC7RvBoT3SGwlc_AmJqpFtTpEBn_fDnyqiZbpcjXYLExFpExm41xDitRKHWIwfc3dV8_vlNntlxCPGy_THkjdXJoHv2IJmlhvmr5_h03iGMLWDKSywxOol_4Wc1BT7Hb6byMxW40GKwSJJ4p7W8eI5mqggRHc8jlwSsTN9LZ2VOvO-XiVShZRVg7JeraGAfWwaIgIJ1D8C1h5Pi0iFpp2suxpHAXHfyLMJXuVotpXbDh4NDX-A4KRMgaxcfAcui_x6gybksq6gF90-9nfQfmVMVJctZ6M-FvRr-itd1Nef5WAtwUp1qyZygAXU3cH3rarscajmurOsP6dE1OHl3grY_eZhQxk33VBK9lavqNKPg6Q_PLiq1ojbYBj3bcYifJrsNeQwxldQP83aWt5rGtgZTehKVJwa40Uy_Grae1iRnsDtdSy5sTJIJ6EiShnWAdMoGejdiI8vpkjrdU8SWH8lv1KXI54DsbyAuke2cYz02zPWc6JEotQqI0HwhzU0KHyoY4s",
|
||||
"e": "AQAB"
|
||||
}]}
|
||||
"#;
|
||||
let a = Authorizer::<Value>::from_jwks(jwks, None).unwrap();
|
||||
let k = a.key_source.get_key(Header::new(Algorithm::RS256));
|
||||
assert!(k.await.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn from_file() {
|
||||
let a = Authorizer::<Value>::from_rsa_file("../config/jwtRS256.key.pub").unwrap();
|
||||
let k = a.key_source.get_key(Header::new(Algorithm::RS256));
|
||||
assert!(k.await.is_ok());
|
||||
|
||||
let a = Authorizer::<Value>::from_ec_file("../config/ec256-public.pem").unwrap();
|
||||
let k = a.key_source.get_key(Header::new(Algorithm::ES256));
|
||||
assert!(k.await.is_ok());
|
||||
|
||||
let a = Authorizer::<Value>::from_ed_file("../config/ed25519-public.pem").unwrap();
|
||||
let k = a.key_source.get_key(Header::new(Algorithm::EdDSA));
|
||||
assert!(k.await.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn from_file_errors() {
|
||||
let a = Authorizer::<Value>::from_rsa_file("./config/does-not-exist.pem");
|
||||
println!("{:?}", a.as_ref().err());
|
||||
assert!(a.is_err());
|
||||
}
|
||||
}
|
||||
60
jwt-authorizer/src/error.rs
Normal file
60
jwt-authorizer/src/error.rs
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
use axum::{
|
||||
extract::rejection::TypedHeaderRejection,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use jsonwebtoken::Algorithm;
|
||||
use thiserror::Error;
|
||||
|
||||
use tracing::log::warn;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AuthError {
|
||||
#[error(transparent)]
|
||||
JwksSerialisationError(#[from] serde_json::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
JwksRefreshError(#[from] reqwest::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
KeyFileError(#[from] std::io::Error),
|
||||
|
||||
#[error("InvalidKey {0}")]
|
||||
InvalidKey(String),
|
||||
|
||||
#[error("Invalid Kid {0}")]
|
||||
InvalidKid(String),
|
||||
|
||||
#[error("Invalid Key Algorithm {0:?}")]
|
||||
InvalidKeyAlg(Algorithm),
|
||||
|
||||
#[error(transparent)]
|
||||
InvalidTokenHeader(#[from] TypedHeaderRejection),
|
||||
|
||||
#[error(transparent)]
|
||||
InvalidToken(#[from] jsonwebtoken::errors::Error),
|
||||
|
||||
#[error("Invalid Claim")]
|
||||
InvalidClaims(),
|
||||
}
|
||||
|
||||
impl IntoResponse for AuthError {
|
||||
fn into_response(self) -> Response {
|
||||
warn!("AuthError: {}", &self);
|
||||
let (status, error_message) = match self {
|
||||
AuthError::JwksRefreshError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
|
||||
AuthError::KeyFileError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
|
||||
AuthError::InvalidKid(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg),
|
||||
AuthError::InvalidTokenHeader(_) => (StatusCode::BAD_REQUEST, self.to_string()),
|
||||
AuthError::InvalidToken(_) => (StatusCode::BAD_REQUEST, self.to_string()),
|
||||
AuthError::InvalidKey(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg),
|
||||
AuthError::JwksSerialisationError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
|
||||
AuthError::InvalidKeyAlg(_) => (StatusCode::BAD_REQUEST, self.to_string()),
|
||||
AuthError::InvalidClaims() => (StatusCode::FORBIDDEN, self.to_string()),
|
||||
};
|
||||
let body = axum::Json(serde_json::json!({
|
||||
"error": error_message,
|
||||
}));
|
||||
(status, body).into_response()
|
||||
}
|
||||
}
|
||||
373
jwt-authorizer/src/jwks/key_store_manager.rs
Normal file
373
jwt-authorizer/src/jwks/key_store_manager.rs
Normal file
|
|
@ -0,0 +1,373 @@
|
|||
use jsonwebtoken::{
|
||||
jwk::{Jwk, JwkSet},
|
||||
Algorithm, DecodingKey,
|
||||
};
|
||||
use std::{
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::error::AuthError;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum RefreshStrategy {
|
||||
/// refresh periodicaly
|
||||
Interval(Duration),
|
||||
/// when kid not found in the store
|
||||
KeyNotFound,
|
||||
// other strategies? KeyNotFoundOrInterval(Duration), Once,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct KeyStoreManager {
|
||||
key_url: String,
|
||||
refresh: RefreshStrategy,
|
||||
/// in case of fail loading (error or key not found), minimal interval
|
||||
minimal_refresh_interval: Duration,
|
||||
keystore: Arc<Mutex<KeyStore>>,
|
||||
}
|
||||
|
||||
pub struct KeyStore {
|
||||
/// key set
|
||||
jwks: JwkSet,
|
||||
/// time of the last successfully loaded jwkset
|
||||
load_time: Option<Instant>,
|
||||
/// time of the last failed load
|
||||
fail_time: Option<Instant>,
|
||||
}
|
||||
|
||||
impl KeyStoreManager {
|
||||
pub(crate) fn new(url: &str, refresh: RefreshStrategy) -> KeyStoreManager {
|
||||
KeyStoreManager {
|
||||
key_url: url.to_owned(),
|
||||
refresh,
|
||||
minimal_refresh_interval: Duration::from_secs(5), // TODO: make configurable
|
||||
keystore: Arc::new(Mutex::new(KeyStore {
|
||||
jwks: JwkSet { keys: vec![] },
|
||||
load_time: None,
|
||||
fail_time: None,
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn with_refresh(url: &str) -> KeyStoreManager {
|
||||
KeyStoreManager::new(url, RefreshStrategy::KeyNotFound)
|
||||
}
|
||||
|
||||
pub(crate) fn with_refresh_interval(url: &str, interval: Duration) -> KeyStoreManager {
|
||||
KeyStoreManager::new(url, RefreshStrategy::Interval(interval))
|
||||
}
|
||||
|
||||
pub(crate) async fn get_key(&self, header: &jsonwebtoken::Header) -> Result<jsonwebtoken::DecodingKey, AuthError> {
|
||||
let kstore = self.keystore.clone();
|
||||
let mut ks_gard = kstore.lock().await;
|
||||
let key = match self.refresh {
|
||||
RefreshStrategy::Interval(refresh_interval) => {
|
||||
if ks_gard.should_refresh(refresh_interval) && ks_gard.can_refresh(self.minimal_refresh_interval) {
|
||||
ks_gard.refresh(&self.key_url, &[]).await?;
|
||||
}
|
||||
if let Some(ref kid) = header.kid {
|
||||
ks_gard
|
||||
.find_kid(kid)
|
||||
.ok_or_else(|| AuthError::InvalidKid(kid.to_owned()))?
|
||||
} else {
|
||||
ks_gard
|
||||
.find_alg(&header.alg)
|
||||
.ok_or(AuthError::InvalidKeyAlg(header.alg))?
|
||||
}
|
||||
}
|
||||
RefreshStrategy::KeyNotFound => {
|
||||
if let Some(ref kid) = header.kid {
|
||||
let jwk_opt = ks_gard.find_kid(kid);
|
||||
if let Some(jwk) = jwk_opt {
|
||||
jwk
|
||||
} else if ks_gard.can_refresh(self.minimal_refresh_interval) {
|
||||
ks_gard.refresh(&self.key_url, &[("kid", kid)]).await?;
|
||||
ks_gard
|
||||
.find_kid(kid)
|
||||
.ok_or_else(|| AuthError::InvalidKid(kid.to_owned()))?
|
||||
} else {
|
||||
return Err(AuthError::InvalidKid(kid.to_owned()));
|
||||
}
|
||||
} else {
|
||||
let jwk_opt = ks_gard.find_alg(&header.alg);
|
||||
// .ok_or(AuthError::InvalidKeyAlg(header.alg))?
|
||||
if let Some(jwk) = jwk_opt {
|
||||
jwk
|
||||
} else if ks_gard.can_refresh(self.minimal_refresh_interval) {
|
||||
ks_gard
|
||||
.refresh(
|
||||
&self.key_url,
|
||||
&[(
|
||||
"alg",
|
||||
&serde_json::to_string(&header.alg)
|
||||
.map_err(|_| AuthError::InvalidKeyAlg(header.alg))?,
|
||||
)],
|
||||
)
|
||||
.await?;
|
||||
ks_gard
|
||||
.find_alg(&header.alg)
|
||||
.ok_or_else(|| AuthError::InvalidKeyAlg(header.alg))?
|
||||
} else {
|
||||
return Err(AuthError::InvalidKeyAlg(header.alg));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
DecodingKey::from_jwk(key).map_err(|err| AuthError::InvalidKey(err.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
impl KeyStore {
|
||||
fn should_refresh(&self, refresh_interval: Duration) -> bool {
|
||||
if let Some(t) = self.load_time {
|
||||
t.elapsed() > refresh_interval
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
fn can_refresh(&self, minimal_refresh_interval: Duration) -> bool {
|
||||
if let Some(ft) = self.fail_time {
|
||||
if let Some(lt) = self.load_time {
|
||||
ft.elapsed() > minimal_refresh_interval && lt.elapsed() > minimal_refresh_interval
|
||||
} else {
|
||||
ft.elapsed() > minimal_refresh_interval
|
||||
}
|
||||
} else if let Some(lt) = self.load_time {
|
||||
lt.elapsed() > minimal_refresh_interval
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
async fn refresh(&mut self, key_url: &str, qparam: &[(&str, &str)]) -> Result<(), AuthError> {
|
||||
reqwest::Client::new()
|
||||
.get(key_url)
|
||||
.query(qparam)
|
||||
.send()
|
||||
.await
|
||||
.map_err(AuthError::JwksRefreshError)?
|
||||
.json::<JwkSet>()
|
||||
.await
|
||||
.map(|jwks| {
|
||||
self.load_time = Some(Instant::now());
|
||||
self.jwks = jwks;
|
||||
Ok(())
|
||||
})
|
||||
.map_err(|e| {
|
||||
self.fail_time = Some(Instant::now());
|
||||
AuthError::JwksRefreshError(e)
|
||||
})?
|
||||
}
|
||||
|
||||
/// Find the key in the set that matches the given key id, if any.
|
||||
pub fn find_kid(&self, kid: &str) -> Option<&Jwk> {
|
||||
self.jwks.find(kid)
|
||||
}
|
||||
|
||||
/// Find the key in the set that matches the given key id, if any.
|
||||
pub fn find_alg(&self, alg: &Algorithm) -> Option<&Jwk> {
|
||||
self.jwks.keys.iter().find(|jwk| {
|
||||
if let Some(ref a) = jwk.common.algorithm {
|
||||
alg == a
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Find first key.
|
||||
pub fn find_first(&self) -> Option<&Jwk> {
|
||||
self.jwks.keys.get(0)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use jsonwebtoken::Algorithm;
|
||||
use jsonwebtoken::{jwk::Jwk, Header};
|
||||
use wiremock::{
|
||||
matchers::{method, path},
|
||||
Mock, MockServer, ResponseTemplate,
|
||||
};
|
||||
|
||||
use crate::jwks::key_store_manager::{KeyStore, KeyStoreManager};
|
||||
|
||||
#[test]
|
||||
fn keystore_should_refresh() {
|
||||
let ks = KeyStore {
|
||||
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![] },
|
||||
fail_time: None,
|
||||
load_time: Some(Instant::now()),
|
||||
};
|
||||
|
||||
assert!(!ks.should_refresh(Duration::from_secs(5)));
|
||||
|
||||
let ks = KeyStore {
|
||||
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![] },
|
||||
fail_time: None,
|
||||
load_time: Some(Instant::now() - Duration::from_secs(6)),
|
||||
};
|
||||
|
||||
assert!(ks.should_refresh(Duration::from_secs(5)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keystore_can_refresh() {
|
||||
let ks = KeyStore {
|
||||
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![] },
|
||||
fail_time: Some(Instant::now() - Duration::from_secs(5)),
|
||||
load_time: None,
|
||||
};
|
||||
assert!(ks.can_refresh(Duration::from_secs(4)));
|
||||
assert!(!ks.can_refresh(Duration::from_secs(6)));
|
||||
|
||||
let ks = KeyStore {
|
||||
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![] },
|
||||
fail_time: None,
|
||||
load_time: Some(Instant::now() - Duration::from_secs(5)),
|
||||
};
|
||||
assert!(ks.can_refresh(Duration::from_secs(4)));
|
||||
assert!(!ks.can_refresh(Duration::from_secs(6)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_kid() {
|
||||
let jwk0: Jwk = serde_json::from_str(r#"{"kid":"1","kty":"RSA","alg":"RS256","n":"xxxx","e":"AQAB"}"#).unwrap();
|
||||
let jwk1: Jwk = serde_json::from_str(r#"{"kid":"2","kty":"RSA","alg":"RS256","n":"xxxx","e":"AQAB"}"#).unwrap();
|
||||
let ks = KeyStore {
|
||||
load_time: None,
|
||||
fail_time: None,
|
||||
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![jwk0, jwk1] },
|
||||
};
|
||||
assert!(ks.find_kid("1").is_some());
|
||||
assert!(ks.find_kid("2").is_some());
|
||||
assert!(ks.find_kid("3").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_alg() {
|
||||
let jwk0: Jwk = serde_json::from_str(r#"{"kty": "RSA", "alg": "RS256", "n": "xxx","e": "yyy"}"#).unwrap();
|
||||
let ks = KeyStore {
|
||||
load_time: None,
|
||||
fail_time: None,
|
||||
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![jwk0] },
|
||||
};
|
||||
assert!(ks.find_alg(&Algorithm::RS256).is_some());
|
||||
assert!(ks.find_alg(&Algorithm::EdDSA).is_none());
|
||||
}
|
||||
|
||||
async fn mock_jwks_response_once(mock_server: &MockServer, jwk: &str) {
|
||||
let jwk0: Jwk = serde_json::from_str(jwk).unwrap();
|
||||
let jwks = jsonwebtoken::jwk::JwkSet { keys: vec![jwk0] };
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(&jwks))
|
||||
.expect(1)
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
}
|
||||
|
||||
fn build_header(kid: &str, alg: Algorithm) -> Header {
|
||||
let mut header = Header::new(alg);
|
||||
header.kid = Some(kid.to_owned());
|
||||
header
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn keystore_manager_find_key_with_refresh_interval() {
|
||||
let mock_server = MockServer::start().await;
|
||||
mock_jwks_response_once(
|
||||
&mock_server,
|
||||
r#"{
|
||||
"kty": "OKP",
|
||||
"use": "sig",
|
||||
"crv": "Ed25519",
|
||||
"x": "uWtSkE-I9aTMYTTvuTE1rtu0rNdxp3DU33cJ_ksL1Gk",
|
||||
"kid": "key-ed",
|
||||
"alg": "EdDSA"
|
||||
}"#,
|
||||
)
|
||||
.await;
|
||||
|
||||
let ksm = KeyStoreManager::with_refresh_interval(&mock_server.uri(), Duration::from_secs(3000));
|
||||
let r = ksm.get_key(&Header::new(Algorithm::EdDSA)).await;
|
||||
assert!(r.is_ok());
|
||||
mock_server.verify().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn keystore_manager_find_key_with_refresh() {
|
||||
let mock_server = MockServer::start().await;
|
||||
mock_jwks_response_once(
|
||||
&mock_server,
|
||||
r#"{
|
||||
"kty": "OKP",
|
||||
"use": "sig",
|
||||
"crv": "Ed25519",
|
||||
"x": "uWtSkE-I9aTMYTTvuTE1rtu0rNdxp3DU33cJ_ksL1Gk",
|
||||
"kid": "key-ed",
|
||||
"alg": "EdDSA"
|
||||
}"#,
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut ksm = KeyStoreManager::with_refresh(&mock_server.uri());
|
||||
|
||||
// STEP 1: initial (lazy) reloading
|
||||
let r = ksm.get_key(&build_header("key-ed", Algorithm::EdDSA)).await;
|
||||
assert!(r.is_ok());
|
||||
mock_server.verify().await;
|
||||
|
||||
// STEP2: new kid -> reloading ksm
|
||||
mock_server.reset().await;
|
||||
mock_jwks_response_once(
|
||||
&mock_server,
|
||||
r#"{
|
||||
"kty": "OKP",
|
||||
"use": "sig",
|
||||
"crv": "Ed25519",
|
||||
"x": "uWtSkE-I9aTMYTTvuTE1rtu0rNdxp3DU33cJ_ksL1Gk",
|
||||
"kid": "key-ed02",
|
||||
"alg": "EdDSA"
|
||||
}"#,
|
||||
)
|
||||
.await;
|
||||
let h = build_header("key-ed02", Algorithm::EdDSA);
|
||||
assert!(ksm.get_key(&h).await.is_err());
|
||||
|
||||
ksm.minimal_refresh_interval = Duration::from_millis(100);
|
||||
tokio::time::sleep(Duration::from_millis(101)).await;
|
||||
assert!(ksm.get_key(&h).await.is_ok());
|
||||
|
||||
mock_server.verify().await;
|
||||
|
||||
// STEP3: new algorithm -> try to reload
|
||||
mock_server.reset().await;
|
||||
mock_jwks_response_once(
|
||||
&mock_server,
|
||||
r#"{
|
||||
"kty": "EC",
|
||||
"crv": "P-256",
|
||||
"x": "w7JAoU_gJbZJvV-zCOvU9yFJq0FNC_edCMRM78P8eQQ",
|
||||
"y": "wQg1EytcsEmGrM70Gb53oluoDbVhCZ3Uq3hHMslHVb4",
|
||||
"kid": "ec01",
|
||||
"alg": "ES256",
|
||||
"use": "sig"
|
||||
}"#,
|
||||
)
|
||||
.await;
|
||||
let h = Header::new(Algorithm::ES256);
|
||||
assert!(ksm.get_key(&h).await.is_err());
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(101)).await;
|
||||
assert!(ksm.get_key(&h).await.is_ok());
|
||||
|
||||
mock_server.verify().await;
|
||||
}
|
||||
}
|
||||
24
jwt-authorizer/src/jwks/mod.rs
Normal file
24
jwt-authorizer/src/jwks/mod.rs
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
use jsonwebtoken::{DecodingKey, Header};
|
||||
|
||||
use crate::error::AuthError;
|
||||
|
||||
use self::key_store_manager::KeyStoreManager;
|
||||
|
||||
pub mod key_store_manager;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum KeySource {
|
||||
KeyStoreSource(KeyStoreManager),
|
||||
DecodingKeySource(DecodingKey),
|
||||
}
|
||||
|
||||
impl KeySource {
|
||||
pub async fn get_key(&self, header: Header) -> Result<DecodingKey, AuthError> {
|
||||
match self {
|
||||
KeySource::KeyStoreSource(kstore) => kstore.get_key(&header).await,
|
||||
KeySource::DecodingKeySource(key) => {
|
||||
Ok(key.clone()) // TODO: clone -> &
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
286
jwt-authorizer/src/layer.rs
Normal file
286
jwt-authorizer/src/layer.rs
Normal file
|
|
@ -0,0 +1,286 @@
|
|||
use axum::http::Request;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::{body::Body, response::Response};
|
||||
use futures_core::ready;
|
||||
use futures_util::future::BoxFuture;
|
||||
use headers::authorization::Bearer;
|
||||
use headers::{Authorization, HeaderMapExt};
|
||||
use http::StatusCode;
|
||||
use pin_project::pin_project;
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use tower_layer::Layer;
|
||||
use tower_service::Service;
|
||||
|
||||
use crate::authorizer::{Authorizer, FnClaimsChecker};
|
||||
|
||||
/// Authorizer Layer builder
|
||||
///
|
||||
/// - initialisation of the Authorizer from jwks, rsa, ed, ec or secret
|
||||
/// - can define a checker (jwt claims check)
|
||||
pub struct JwtAuthorizer<C>
|
||||
where
|
||||
C: Clone + DeserializeOwned,
|
||||
{
|
||||
url: Option<&'static str>,
|
||||
claims_checker: Option<FnClaimsChecker<C>>,
|
||||
}
|
||||
|
||||
/// layer builder
|
||||
impl<C> JwtAuthorizer<C>
|
||||
where
|
||||
C: Clone + DeserializeOwned + Send + Sync,
|
||||
{
|
||||
pub fn new() -> Self {
|
||||
JwtAuthorizer {
|
||||
url: None,
|
||||
claims_checker: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_jwks_url(mut self, url: &'static str) -> JwtAuthorizer<C> {
|
||||
self.url = Some(url);
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn from_rsa_pem(mut self, path: &'static str) -> JwtAuthorizer<C> {
|
||||
// TODO
|
||||
self
|
||||
}
|
||||
|
||||
pub fn from_ec_der(mut self, path: &'static str) -> JwtAuthorizer<C> {
|
||||
// TODO
|
||||
self
|
||||
}
|
||||
|
||||
pub fn from_ed_der(mut self, path: &'static str) -> JwtAuthorizer<C> {
|
||||
// TODO
|
||||
self
|
||||
}
|
||||
|
||||
pub fn from_secret(mut self, path: &'static str) -> JwtAuthorizer<C> {
|
||||
// TODO
|
||||
self
|
||||
}
|
||||
|
||||
/// layer that checks token validity and claim constraints (custom function)
|
||||
pub fn with_check(mut self, checker_fn: fn(&C) -> bool) -> JwtAuthorizer<C> {
|
||||
self.claims_checker = Some(FnClaimsChecker { checker_fn });
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
/// build axum layer
|
||||
pub fn layer(&self) -> AsyncAuthorizationLayer<C> {
|
||||
// TODO: replace unwrap
|
||||
let auth = Arc::new(Authorizer::from_jwks_url(self.url.unwrap(), self.claims_checker.clone()).unwrap());
|
||||
|
||||
AsyncAuthorizationLayer::new(auth)
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for authorizing requests.
|
||||
pub trait AsyncAuthorizer<B> {
|
||||
type RequestBody;
|
||||
type ResponseBody;
|
||||
type Future: Future<Output = Result<Request<Self::RequestBody>, Response<Self::ResponseBody>>>;
|
||||
|
||||
/// Authorize the request.
|
||||
///
|
||||
/// If the future resolves to `Ok(request)` then the request is allowed through, otherwise not.
|
||||
fn authorize(&self, request: Request<B>) -> Self::Future;
|
||||
}
|
||||
|
||||
impl<B, S, C> AsyncAuthorizer<B> for AsyncAuthorizationService<S, C>
|
||||
where
|
||||
B: Send + Sync + 'static,
|
||||
C: Clone + DeserializeOwned + Send + Sync + 'static,
|
||||
{
|
||||
type RequestBody = B;
|
||||
type ResponseBody = Body;
|
||||
type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;
|
||||
|
||||
fn authorize(&self, mut request: Request<B>) -> Self::Future {
|
||||
let authorizer = self.auth.clone();
|
||||
let h = request.headers();
|
||||
let bearer: Authorization<Bearer> = h.typed_get().unwrap();
|
||||
Box::pin(async move {
|
||||
if let Ok(token_data) = authorizer.check_auth(bearer.token()).await {
|
||||
// Set `token_data` as a request extension so it can be accessed by other
|
||||
// services down the stack.
|
||||
request.extensions_mut().insert(token_data);
|
||||
|
||||
Ok(request)
|
||||
} else {
|
||||
let unauthorized_response = Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
Err(unauthorized_response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// -------------- Layer -----------------
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AsyncAuthorizationLayer<C>
|
||||
where
|
||||
C: Clone + DeserializeOwned + Send,
|
||||
{
|
||||
auth: Arc<Authorizer<C>>,
|
||||
}
|
||||
|
||||
impl<C> AsyncAuthorizationLayer<C>
|
||||
where
|
||||
C: Clone + DeserializeOwned + Send,
|
||||
{
|
||||
pub fn new(auth: Arc<Authorizer<C>>) -> AsyncAuthorizationLayer<C> {
|
||||
Self { auth }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, C> Layer<S> for AsyncAuthorizationLayer<C>
|
||||
where
|
||||
C: Clone + DeserializeOwned + Send + Sync,
|
||||
{
|
||||
type Service = AsyncAuthorizationService<S, C>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
AsyncAuthorizationService::new(inner, self.auth.clone())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- AsyncAuthorizationService --------
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AsyncAuthorizationService<S, C>
|
||||
where
|
||||
C: Clone + DeserializeOwned + Send + Sync,
|
||||
{
|
||||
pub inner: S,
|
||||
pub auth: Arc<Authorizer<C>>,
|
||||
}
|
||||
|
||||
impl<S, C> AsyncAuthorizationService<S, C>
|
||||
where
|
||||
C: Clone + DeserializeOwned + Send + Sync,
|
||||
{
|
||||
pub fn get_ref(&self) -> &S {
|
||||
&self.inner
|
||||
}
|
||||
|
||||
/// Gets a mutable reference to the underlying service.
|
||||
pub fn get_mut(&mut self) -> &mut S {
|
||||
&mut self.inner
|
||||
}
|
||||
|
||||
/// Consumes `self`, returning the underlying service.
|
||||
pub fn into_inner(self) -> S {
|
||||
self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, C> AsyncAuthorizationService<S, C>
|
||||
where
|
||||
C: Clone + DeserializeOwned + Send + Sync,
|
||||
{
|
||||
/// Authorize requests using a custom scheme.
|
||||
///
|
||||
/// The `Authorization` header is required to have the value provided.
|
||||
pub fn new(inner: S, auth: Arc<Authorizer<C>>) -> AsyncAuthorizationService<S, C> {
|
||||
Self { inner, auth }
|
||||
}
|
||||
}
|
||||
|
||||
impl<ReqBody, S, C> Service<Request<ReqBody>> for AsyncAuthorizationService<S, C>
|
||||
where
|
||||
ReqBody: Send + Sync + 'static,
|
||||
S: Service<Request<ReqBody>, Response = Response> + Clone,
|
||||
C: Clone + DeserializeOwned + Send + Sync + 'static,
|
||||
{
|
||||
type Response = S::Response;
|
||||
type Error = S::Error;
|
||||
type Future = ResponseFuture<S, ReqBody, C>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
|
||||
let inner = self.inner.clone();
|
||||
let auth_fut = self.authorize(req);
|
||||
|
||||
ResponseFuture {
|
||||
state: State::Authorize { auth_fut },
|
||||
service: inner,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
/// Response future for [`AsyncAuthorizationService`].
|
||||
pub struct ResponseFuture<S, ReqBody, C>
|
||||
where
|
||||
S: Service<Request<ReqBody>, Response = Response>,
|
||||
ReqBody: Send + Sync + 'static,
|
||||
C: Clone + DeserializeOwned + Send + Sync + 'static,
|
||||
{
|
||||
#[pin]
|
||||
state: State<<AsyncAuthorizationService<S, C> as AsyncAuthorizer<ReqBody>>::Future, S::Future>,
|
||||
service: S,
|
||||
}
|
||||
|
||||
#[pin_project(project = StateProj)]
|
||||
enum State<A, SFut> {
|
||||
Authorize {
|
||||
#[pin]
|
||||
auth_fut: A,
|
||||
},
|
||||
Authorized {
|
||||
#[pin]
|
||||
svc_fut: SFut,
|
||||
},
|
||||
}
|
||||
|
||||
impl<S, ReqBody, C> Future for ResponseFuture<S, ReqBody, C>
|
||||
where
|
||||
S: Service<Request<ReqBody>, Response = Response>,
|
||||
ReqBody: Send + Sync + 'static,
|
||||
C: Clone + DeserializeOwned + Send + Sync,
|
||||
{
|
||||
type Output = Result<S::Response, S::Error>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let mut this = self.project();
|
||||
|
||||
loop {
|
||||
match this.state.as_mut().project() {
|
||||
StateProj::Authorize { auth_fut } => {
|
||||
let auth = ready!(auth_fut.poll(cx));
|
||||
match auth {
|
||||
Ok(req) => {
|
||||
let svc_fut = this.service.call(req);
|
||||
this.state.set(State::Authorized { svc_fut })
|
||||
}
|
||||
Err(res) => {
|
||||
tracing::info!("err: {:?}", res);
|
||||
let r = (StatusCode::FORBIDDEN, format!("Unauthorized : {:?}", res)).into_response();
|
||||
// TODO: replace r by res (type problems: res should be already a 403 error response)
|
||||
return Poll::Ready(Ok(r));
|
||||
}
|
||||
};
|
||||
}
|
||||
StateProj::Authorized { svc_fut } => {
|
||||
return svc_fut.poll(cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
34
jwt-authorizer/src/lib.rs
Normal file
34
jwt-authorizer/src/lib.rs
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
#![doc = include_str!("../docs/README.md")]
|
||||
|
||||
use axum::{async_trait, extract::FromRequestParts, http::request::Parts};
|
||||
use jsonwebtoken::TokenData;
|
||||
use serde::de::DeserializeOwned;
|
||||
|
||||
pub use self::error::AuthError;
|
||||
pub use layer::JwtAuthorizer;
|
||||
|
||||
pub mod authorizer;
|
||||
pub mod error;
|
||||
pub mod jwks;
|
||||
pub mod layer;
|
||||
|
||||
/// Claims serialized using T
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct JwtClaims<T>(pub T);
|
||||
|
||||
#[async_trait]
|
||||
impl<T, S> FromRequestParts<S> for JwtClaims<T>
|
||||
where
|
||||
T: DeserializeOwned + Send + Sync + Clone + 'static,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = error::AuthError;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
|
||||
let claims = parts.extensions.get::<TokenData<T>>().unwrap(); // TODO: unwrap -> err
|
||||
Ok(JwtClaims(claims.claims.clone())) // TODO: unwrap -> err
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
1
jwt-authorizer/src/tests.rs
Normal file
1
jwt-authorizer/src/tests.rs
Normal file
|
|
@ -0,0 +1 @@
|
|||
// TODO: tests
|
||||
Loading…
Add table
Add a link
Reference in a new issue