chore: fmt

This commit is contained in:
cduvray 2023-01-08 13:45:21 +01:00
parent f77a7ce54f
commit b0667729a3
32 changed files with 3596 additions and 7 deletions

33
jwt-authorizer/Cargo.toml Normal file
View file

@ -0,0 +1,33 @@
[package]
name = "jwt-authorizer"
description = "jwt authorizer middleware for axum"
version = "0.2.0"
edition = "2021"
authors = ["cduvray <c_duvray@proton.me>"]
license = "MIT"
readme = "docs/README.md"
repository = "https://github.com/cduvray/jwt-authorizer"
keywords = ["jwt","axum","authorisation"]
[dependencies]
axum = { version = "0.6.1", features = ["headers"] }
futures-util = "0.3.25"
futures-core = "0.3.25"
headers = "0.3"
jsonwebtoken = "8.2.0"
http = "0.2.8"
# pin-project-lite = "0.2.9"
pin-project = "1.0.12"
reqwest = { version = "0.11.13", features = ["json"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0.37"
tokio = { version = "1.0", features = ["full"] }
tower-http = { version = "0.3.4", features = ["trace", "auth"] }
tower-layer = "0.3"
tower-service = "0.3"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
[dev-dependencies]
wiremock = "0.5"

View file

View file

@ -0,0 +1,36 @@
# jwt-authorizer
JWT authoriser Layer for Axum.
Example:
```rust
use jwt_authorizer::{AuthError, JwtAuthorizer, JwtClaims};
use axum::{routing::get, Router};
use serde::Deserialize;
// Authorized entity, struct deserializable from JWT claims
#[derive(Debug, Deserialize, Clone)]
struct User {
sub: String,
}
// let's create an authorizer builder from a JWKS Endpoint
let jwt_auth: JwtAuthorizer<User> = JwtAuthorizer::new()
.from_jwks_url("http://localhost:3000/oidc/jwks");
// adding the authorization layer
let app = Router::new().route("/protected", get(protected))
.layer(jwt_auth.layer());
// proteced handler with user injection (mapping some jwt claims)
async fn protected(JwtClaims(user): JwtClaims<User>) -> Result<String, AuthError> {
// Send the protected data to the user
Ok(format!("Welcome: {}", user.sub))
}
# async {
axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
.serve(app.into_make_service()).await.expect("server failed");
# };
```

View file

@ -0,0 +1,170 @@
use std::{io::Read, time::Duration};
use jsonwebtoken::{decode, decode_header, jwk::JwkSet, DecodingKey, TokenData, Validation};
use serde::de::DeserializeOwned;
use crate::{
error::AuthError,
jwks::{key_store_manager::KeyStoreManager, KeySource},
};
pub trait ClaimsChecker<C> {
fn check(&self, claims: &C) -> bool;
}
#[derive(Clone)]
pub struct FnClaimsChecker<C>
where
C: Clone,
{
pub checker_fn: fn(&C) -> bool,
}
impl<C> ClaimsChecker<C> for FnClaimsChecker<C>
where
C: Clone,
{
fn check(&self, claims: &C) -> bool {
(self.checker_fn)(claims)
}
}
pub struct Authorizer<C>
where
C: Clone,
{
pub key_source: KeySource,
pub claims_checker: Option<FnClaimsChecker<C>>,
}
fn read_data(path: &str) -> Result<Vec<u8>, AuthError> {
let mut data = Vec::<u8>::new();
let mut f = std::fs::File::open(path)?;
f.read_to_end(&mut data)?;
Ok(data)
}
impl<C> Authorizer<C>
where
C: DeserializeOwned + Clone + Send + Sync,
{
pub fn from_jwks(jwks: &str, claims_checker: Option<FnClaimsChecker<C>>) -> Result<Authorizer<C>, AuthError> {
let set: JwkSet = serde_json::from_str(jwks)?;
let k = DecodingKey::from_jwk(&set.keys[0])?;
Ok(Authorizer {
key_source: KeySource::DecodingKeySource(k),
claims_checker,
})
}
pub fn from_rsa_file(path: &str) -> Result<Authorizer<C>, AuthError> {
Ok(Authorizer {
key_source: KeySource::DecodingKeySource(DecodingKey::from_rsa_pem(&read_data(path)?)?),
claims_checker: None,
})
}
pub fn from_ec_file(path: &str) -> Result<Authorizer<C>, AuthError> {
let k = DecodingKey::from_ec_der(&read_data(path)?);
Ok(Authorizer {
key_source: KeySource::DecodingKeySource(k),
claims_checker: None,
})
}
pub fn from_ed_file(path: &str) -> Result<Authorizer<C>, AuthError> {
let k = DecodingKey::from_ed_der(&read_data(path)?);
Ok(Authorizer {
key_source: KeySource::DecodingKeySource(k),
claims_checker: None,
})
}
pub fn from_secret(secret: &str) -> Result<Authorizer<C>, AuthError> {
let k = DecodingKey::from_secret(secret.as_bytes());
Ok(Authorizer {
key_source: KeySource::DecodingKeySource(k),
claims_checker: None,
})
}
pub fn from_jwks_url(url: &str, claims_checker: Option<FnClaimsChecker<C>>) -> Result<Authorizer<C>, AuthError> {
let key_store_manager = KeyStoreManager::with_refresh_interval(url, Duration::from_secs(60));
Ok(Authorizer {
key_source: KeySource::KeyStoreSource(key_store_manager),
claims_checker,
})
}
pub async fn check_auth(&self, token: &str) -> Result<TokenData<C>, AuthError> {
let header = decode_header(token)?;
let validation = Validation::new(header.alg);
let decoding_key = self.key_source.get_key(header).await?;
let token_data = decode::<C>(token, &decoding_key, &validation)?;
if let Some(ref checker) = self.claims_checker {
if !checker.check(&token_data.claims) {
return Err(AuthError::InvalidClaims());
}
}
Ok(token_data)
}
}
#[cfg(test)]
mod tests {
use jsonwebtoken::{Algorithm, Header};
use serde_json::Value;
use super::Authorizer;
#[tokio::test]
async fn from_secret() {
let h = Header::new(Algorithm::HS256);
let a = Authorizer::<Value>::from_secret("xxxxxx").unwrap();
let k = a.key_source.get_key(h);
assert!(k.await.is_ok());
}
#[tokio::test]
async fn from_jwks() {
let jwks = r#"
{"keys": [{
"kid": "1",
"kty": "RSA",
"alg": "RS256",
"use": "sig",
"n": "2pQeZdxa7q093K7bj5h6-leIpxfTnuAxzXdhjfGEJHxmt2ekHyCBWWWXCBiDn2RTcEBcy6gZqOW45Uy_tw-5e-Px1xFj1PykGEkRlOpYSAeWsNaAWvvpGB9m4zQ0PgZeMDDXE5IIBrY6YAzmGQxV-fcGGLhJnXl0-5_z7tKC7RvBoT3SGwlc_AmJqpFtTpEBn_fDnyqiZbpcjXYLExFpExm41xDitRKHWIwfc3dV8_vlNntlxCPGy_THkjdXJoHv2IJmlhvmr5_h03iGMLWDKSywxOol_4Wc1BT7Hb6byMxW40GKwSJJ4p7W8eI5mqggRHc8jlwSsTN9LZ2VOvO-XiVShZRVg7JeraGAfWwaIgIJ1D8C1h5Pi0iFpp2suxpHAXHfyLMJXuVotpXbDh4NDX-A4KRMgaxcfAcui_x6gybksq6gF90-9nfQfmVMVJctZ6M-FvRr-itd1Nef5WAtwUp1qyZygAXU3cH3rarscajmurOsP6dE1OHl3grY_eZhQxk33VBK9lavqNKPg6Q_PLiq1ojbYBj3bcYifJrsNeQwxldQP83aWt5rGtgZTehKVJwa40Uy_Grae1iRnsDtdSy5sTJIJ6EiShnWAdMoGejdiI8vpkjrdU8SWH8lv1KXI54DsbyAuke2cYz02zPWc6JEotQqI0HwhzU0KHyoY4s",
"e": "AQAB"
}]}
"#;
let a = Authorizer::<Value>::from_jwks(jwks, None).unwrap();
let k = a.key_source.get_key(Header::new(Algorithm::RS256));
assert!(k.await.is_ok());
}
#[tokio::test]
async fn from_file() {
let a = Authorizer::<Value>::from_rsa_file("../config/jwtRS256.key.pub").unwrap();
let k = a.key_source.get_key(Header::new(Algorithm::RS256));
assert!(k.await.is_ok());
let a = Authorizer::<Value>::from_ec_file("../config/ec256-public.pem").unwrap();
let k = a.key_source.get_key(Header::new(Algorithm::ES256));
assert!(k.await.is_ok());
let a = Authorizer::<Value>::from_ed_file("../config/ed25519-public.pem").unwrap();
let k = a.key_source.get_key(Header::new(Algorithm::EdDSA));
assert!(k.await.is_ok());
}
#[tokio::test]
async fn from_file_errors() {
let a = Authorizer::<Value>::from_rsa_file("./config/does-not-exist.pem");
println!("{:?}", a.as_ref().err());
assert!(a.is_err());
}
}

View file

@ -0,0 +1,60 @@
use axum::{
extract::rejection::TypedHeaderRejection,
http::StatusCode,
response::{IntoResponse, Response},
};
use jsonwebtoken::Algorithm;
use thiserror::Error;
use tracing::log::warn;
#[derive(Debug, Error)]
pub enum AuthError {
#[error(transparent)]
JwksSerialisationError(#[from] serde_json::Error),
#[error(transparent)]
JwksRefreshError(#[from] reqwest::Error),
#[error(transparent)]
KeyFileError(#[from] std::io::Error),
#[error("InvalidKey {0}")]
InvalidKey(String),
#[error("Invalid Kid {0}")]
InvalidKid(String),
#[error("Invalid Key Algorithm {0:?}")]
InvalidKeyAlg(Algorithm),
#[error(transparent)]
InvalidTokenHeader(#[from] TypedHeaderRejection),
#[error(transparent)]
InvalidToken(#[from] jsonwebtoken::errors::Error),
#[error("Invalid Claim")]
InvalidClaims(),
}
impl IntoResponse for AuthError {
fn into_response(self) -> Response {
warn!("AuthError: {}", &self);
let (status, error_message) = match self {
AuthError::JwksRefreshError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
AuthError::KeyFileError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
AuthError::InvalidKid(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg),
AuthError::InvalidTokenHeader(_) => (StatusCode::BAD_REQUEST, self.to_string()),
AuthError::InvalidToken(_) => (StatusCode::BAD_REQUEST, self.to_string()),
AuthError::InvalidKey(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg),
AuthError::JwksSerialisationError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
AuthError::InvalidKeyAlg(_) => (StatusCode::BAD_REQUEST, self.to_string()),
AuthError::InvalidClaims() => (StatusCode::FORBIDDEN, self.to_string()),
};
let body = axum::Json(serde_json::json!({
"error": error_message,
}));
(status, body).into_response()
}
}

View file

@ -0,0 +1,373 @@
use jsonwebtoken::{
jwk::{Jwk, JwkSet},
Algorithm, DecodingKey,
};
use std::{
sync::Arc,
time::{Duration, Instant},
};
use tokio::sync::Mutex;
use crate::error::AuthError;
#[derive(Clone)]
pub enum RefreshStrategy {
/// refresh periodicaly
Interval(Duration),
/// when kid not found in the store
KeyNotFound,
// other strategies? KeyNotFoundOrInterval(Duration), Once,
}
#[derive(Clone)]
pub struct KeyStoreManager {
key_url: String,
refresh: RefreshStrategy,
/// in case of fail loading (error or key not found), minimal interval
minimal_refresh_interval: Duration,
keystore: Arc<Mutex<KeyStore>>,
}
pub struct KeyStore {
/// key set
jwks: JwkSet,
/// time of the last successfully loaded jwkset
load_time: Option<Instant>,
/// time of the last failed load
fail_time: Option<Instant>,
}
impl KeyStoreManager {
pub(crate) fn new(url: &str, refresh: RefreshStrategy) -> KeyStoreManager {
KeyStoreManager {
key_url: url.to_owned(),
refresh,
minimal_refresh_interval: Duration::from_secs(5), // TODO: make configurable
keystore: Arc::new(Mutex::new(KeyStore {
jwks: JwkSet { keys: vec![] },
load_time: None,
fail_time: None,
})),
}
}
pub(crate) fn with_refresh(url: &str) -> KeyStoreManager {
KeyStoreManager::new(url, RefreshStrategy::KeyNotFound)
}
pub(crate) fn with_refresh_interval(url: &str, interval: Duration) -> KeyStoreManager {
KeyStoreManager::new(url, RefreshStrategy::Interval(interval))
}
pub(crate) async fn get_key(&self, header: &jsonwebtoken::Header) -> Result<jsonwebtoken::DecodingKey, AuthError> {
let kstore = self.keystore.clone();
let mut ks_gard = kstore.lock().await;
let key = match self.refresh {
RefreshStrategy::Interval(refresh_interval) => {
if ks_gard.should_refresh(refresh_interval) && ks_gard.can_refresh(self.minimal_refresh_interval) {
ks_gard.refresh(&self.key_url, &[]).await?;
}
if let Some(ref kid) = header.kid {
ks_gard
.find_kid(kid)
.ok_or_else(|| AuthError::InvalidKid(kid.to_owned()))?
} else {
ks_gard
.find_alg(&header.alg)
.ok_or(AuthError::InvalidKeyAlg(header.alg))?
}
}
RefreshStrategy::KeyNotFound => {
if let Some(ref kid) = header.kid {
let jwk_opt = ks_gard.find_kid(kid);
if let Some(jwk) = jwk_opt {
jwk
} else if ks_gard.can_refresh(self.minimal_refresh_interval) {
ks_gard.refresh(&self.key_url, &[("kid", kid)]).await?;
ks_gard
.find_kid(kid)
.ok_or_else(|| AuthError::InvalidKid(kid.to_owned()))?
} else {
return Err(AuthError::InvalidKid(kid.to_owned()));
}
} else {
let jwk_opt = ks_gard.find_alg(&header.alg);
// .ok_or(AuthError::InvalidKeyAlg(header.alg))?
if let Some(jwk) = jwk_opt {
jwk
} else if ks_gard.can_refresh(self.minimal_refresh_interval) {
ks_gard
.refresh(
&self.key_url,
&[(
"alg",
&serde_json::to_string(&header.alg)
.map_err(|_| AuthError::InvalidKeyAlg(header.alg))?,
)],
)
.await?;
ks_gard
.find_alg(&header.alg)
.ok_or_else(|| AuthError::InvalidKeyAlg(header.alg))?
} else {
return Err(AuthError::InvalidKeyAlg(header.alg));
}
}
}
};
DecodingKey::from_jwk(key).map_err(|err| AuthError::InvalidKey(err.to_string()))
}
}
impl KeyStore {
fn should_refresh(&self, refresh_interval: Duration) -> bool {
if let Some(t) = self.load_time {
t.elapsed() > refresh_interval
} else {
true
}
}
fn can_refresh(&self, minimal_refresh_interval: Duration) -> bool {
if let Some(ft) = self.fail_time {
if let Some(lt) = self.load_time {
ft.elapsed() > minimal_refresh_interval && lt.elapsed() > minimal_refresh_interval
} else {
ft.elapsed() > minimal_refresh_interval
}
} else if let Some(lt) = self.load_time {
lt.elapsed() > minimal_refresh_interval
} else {
true
}
}
async fn refresh(&mut self, key_url: &str, qparam: &[(&str, &str)]) -> Result<(), AuthError> {
reqwest::Client::new()
.get(key_url)
.query(qparam)
.send()
.await
.map_err(AuthError::JwksRefreshError)?
.json::<JwkSet>()
.await
.map(|jwks| {
self.load_time = Some(Instant::now());
self.jwks = jwks;
Ok(())
})
.map_err(|e| {
self.fail_time = Some(Instant::now());
AuthError::JwksRefreshError(e)
})?
}
/// Find the key in the set that matches the given key id, if any.
pub fn find_kid(&self, kid: &str) -> Option<&Jwk> {
self.jwks.find(kid)
}
/// Find the key in the set that matches the given key id, if any.
pub fn find_alg(&self, alg: &Algorithm) -> Option<&Jwk> {
self.jwks.keys.iter().find(|jwk| {
if let Some(ref a) = jwk.common.algorithm {
alg == a
} else {
false
}
})
}
/// Find first key.
pub fn find_first(&self) -> Option<&Jwk> {
self.jwks.keys.get(0)
}
}
#[cfg(test)]
mod tests {
use std::time::{Duration, Instant};
use jsonwebtoken::Algorithm;
use jsonwebtoken::{jwk::Jwk, Header};
use wiremock::{
matchers::{method, path},
Mock, MockServer, ResponseTemplate,
};
use crate::jwks::key_store_manager::{KeyStore, KeyStoreManager};
#[test]
fn keystore_should_refresh() {
let ks = KeyStore {
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![] },
fail_time: None,
load_time: Some(Instant::now()),
};
assert!(!ks.should_refresh(Duration::from_secs(5)));
let ks = KeyStore {
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![] },
fail_time: None,
load_time: Some(Instant::now() - Duration::from_secs(6)),
};
assert!(ks.should_refresh(Duration::from_secs(5)));
}
#[test]
fn keystore_can_refresh() {
let ks = KeyStore {
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![] },
fail_time: Some(Instant::now() - Duration::from_secs(5)),
load_time: None,
};
assert!(ks.can_refresh(Duration::from_secs(4)));
assert!(!ks.can_refresh(Duration::from_secs(6)));
let ks = KeyStore {
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![] },
fail_time: None,
load_time: Some(Instant::now() - Duration::from_secs(5)),
};
assert!(ks.can_refresh(Duration::from_secs(4)));
assert!(!ks.can_refresh(Duration::from_secs(6)));
}
#[test]
fn find_kid() {
let jwk0: Jwk = serde_json::from_str(r#"{"kid":"1","kty":"RSA","alg":"RS256","n":"xxxx","e":"AQAB"}"#).unwrap();
let jwk1: Jwk = serde_json::from_str(r#"{"kid":"2","kty":"RSA","alg":"RS256","n":"xxxx","e":"AQAB"}"#).unwrap();
let ks = KeyStore {
load_time: None,
fail_time: None,
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![jwk0, jwk1] },
};
assert!(ks.find_kid("1").is_some());
assert!(ks.find_kid("2").is_some());
assert!(ks.find_kid("3").is_none());
}
#[test]
fn find_alg() {
let jwk0: Jwk = serde_json::from_str(r#"{"kty": "RSA", "alg": "RS256", "n": "xxx","e": "yyy"}"#).unwrap();
let ks = KeyStore {
load_time: None,
fail_time: None,
jwks: jsonwebtoken::jwk::JwkSet { keys: vec![jwk0] },
};
assert!(ks.find_alg(&Algorithm::RS256).is_some());
assert!(ks.find_alg(&Algorithm::EdDSA).is_none());
}
async fn mock_jwks_response_once(mock_server: &MockServer, jwk: &str) {
let jwk0: Jwk = serde_json::from_str(jwk).unwrap();
let jwks = jsonwebtoken::jwk::JwkSet { keys: vec![jwk0] };
Mock::given(method("GET"))
.and(path("/"))
.respond_with(ResponseTemplate::new(200).set_body_json(&jwks))
.expect(1)
.mount(&mock_server)
.await;
}
fn build_header(kid: &str, alg: Algorithm) -> Header {
let mut header = Header::new(alg);
header.kid = Some(kid.to_owned());
header
}
#[tokio::test]
async fn keystore_manager_find_key_with_refresh_interval() {
let mock_server = MockServer::start().await;
mock_jwks_response_once(
&mock_server,
r#"{
"kty": "OKP",
"use": "sig",
"crv": "Ed25519",
"x": "uWtSkE-I9aTMYTTvuTE1rtu0rNdxp3DU33cJ_ksL1Gk",
"kid": "key-ed",
"alg": "EdDSA"
}"#,
)
.await;
let ksm = KeyStoreManager::with_refresh_interval(&mock_server.uri(), Duration::from_secs(3000));
let r = ksm.get_key(&Header::new(Algorithm::EdDSA)).await;
assert!(r.is_ok());
mock_server.verify().await;
}
#[tokio::test]
async fn keystore_manager_find_key_with_refresh() {
let mock_server = MockServer::start().await;
mock_jwks_response_once(
&mock_server,
r#"{
"kty": "OKP",
"use": "sig",
"crv": "Ed25519",
"x": "uWtSkE-I9aTMYTTvuTE1rtu0rNdxp3DU33cJ_ksL1Gk",
"kid": "key-ed",
"alg": "EdDSA"
}"#,
)
.await;
let mut ksm = KeyStoreManager::with_refresh(&mock_server.uri());
// STEP 1: initial (lazy) reloading
let r = ksm.get_key(&build_header("key-ed", Algorithm::EdDSA)).await;
assert!(r.is_ok());
mock_server.verify().await;
// STEP2: new kid -> reloading ksm
mock_server.reset().await;
mock_jwks_response_once(
&mock_server,
r#"{
"kty": "OKP",
"use": "sig",
"crv": "Ed25519",
"x": "uWtSkE-I9aTMYTTvuTE1rtu0rNdxp3DU33cJ_ksL1Gk",
"kid": "key-ed02",
"alg": "EdDSA"
}"#,
)
.await;
let h = build_header("key-ed02", Algorithm::EdDSA);
assert!(ksm.get_key(&h).await.is_err());
ksm.minimal_refresh_interval = Duration::from_millis(100);
tokio::time::sleep(Duration::from_millis(101)).await;
assert!(ksm.get_key(&h).await.is_ok());
mock_server.verify().await;
// STEP3: new algorithm -> try to reload
mock_server.reset().await;
mock_jwks_response_once(
&mock_server,
r#"{
"kty": "EC",
"crv": "P-256",
"x": "w7JAoU_gJbZJvV-zCOvU9yFJq0FNC_edCMRM78P8eQQ",
"y": "wQg1EytcsEmGrM70Gb53oluoDbVhCZ3Uq3hHMslHVb4",
"kid": "ec01",
"alg": "ES256",
"use": "sig"
}"#,
)
.await;
let h = Header::new(Algorithm::ES256);
assert!(ksm.get_key(&h).await.is_err());
tokio::time::sleep(Duration::from_millis(101)).await;
assert!(ksm.get_key(&h).await.is_ok());
mock_server.verify().await;
}
}

View file

@ -0,0 +1,24 @@
use jsonwebtoken::{DecodingKey, Header};
use crate::error::AuthError;
use self::key_store_manager::KeyStoreManager;
pub mod key_store_manager;
#[derive(Clone)]
pub enum KeySource {
KeyStoreSource(KeyStoreManager),
DecodingKeySource(DecodingKey),
}
impl KeySource {
pub async fn get_key(&self, header: Header) -> Result<DecodingKey, AuthError> {
match self {
KeySource::KeyStoreSource(kstore) => kstore.get_key(&header).await,
KeySource::DecodingKeySource(key) => {
Ok(key.clone()) // TODO: clone -> &
}
}
}
}

286
jwt-authorizer/src/layer.rs Normal file
View file

@ -0,0 +1,286 @@
use axum::http::Request;
use axum::response::IntoResponse;
use axum::{body::Body, response::Response};
use futures_core::ready;
use futures_util::future::BoxFuture;
use headers::authorization::Bearer;
use headers::{Authorization, HeaderMapExt};
use http::StatusCode;
use pin_project::pin_project;
use serde::de::DeserializeOwned;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tower_layer::Layer;
use tower_service::Service;
use crate::authorizer::{Authorizer, FnClaimsChecker};
/// Authorizer Layer builder
///
/// - initialisation of the Authorizer from jwks, rsa, ed, ec or secret
/// - can define a checker (jwt claims check)
pub struct JwtAuthorizer<C>
where
C: Clone + DeserializeOwned,
{
url: Option<&'static str>,
claims_checker: Option<FnClaimsChecker<C>>,
}
/// layer builder
impl<C> JwtAuthorizer<C>
where
C: Clone + DeserializeOwned + Send + Sync,
{
pub fn new() -> Self {
JwtAuthorizer {
url: None,
claims_checker: None,
}
}
pub fn from_jwks_url(mut self, url: &'static str) -> JwtAuthorizer<C> {
self.url = Some(url);
self
}
pub fn from_rsa_pem(mut self, path: &'static str) -> JwtAuthorizer<C> {
// TODO
self
}
pub fn from_ec_der(mut self, path: &'static str) -> JwtAuthorizer<C> {
// TODO
self
}
pub fn from_ed_der(mut self, path: &'static str) -> JwtAuthorizer<C> {
// TODO
self
}
pub fn from_secret(mut self, path: &'static str) -> JwtAuthorizer<C> {
// TODO
self
}
/// layer that checks token validity and claim constraints (custom function)
pub fn with_check(mut self, checker_fn: fn(&C) -> bool) -> JwtAuthorizer<C> {
self.claims_checker = Some(FnClaimsChecker { checker_fn });
self
}
/// build axum layer
pub fn layer(&self) -> AsyncAuthorizationLayer<C> {
// TODO: replace unwrap
let auth = Arc::new(Authorizer::from_jwks_url(self.url.unwrap(), self.claims_checker.clone()).unwrap());
AsyncAuthorizationLayer::new(auth)
}
}
/// Trait for authorizing requests.
pub trait AsyncAuthorizer<B> {
type RequestBody;
type ResponseBody;
type Future: Future<Output = Result<Request<Self::RequestBody>, Response<Self::ResponseBody>>>;
/// Authorize the request.
///
/// If the future resolves to `Ok(request)` then the request is allowed through, otherwise not.
fn authorize(&self, request: Request<B>) -> Self::Future;
}
impl<B, S, C> AsyncAuthorizer<B> for AsyncAuthorizationService<S, C>
where
B: Send + Sync + 'static,
C: Clone + DeserializeOwned + Send + Sync + 'static,
{
type RequestBody = B;
type ResponseBody = Body;
type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;
fn authorize(&self, mut request: Request<B>) -> Self::Future {
let authorizer = self.auth.clone();
let h = request.headers();
let bearer: Authorization<Bearer> = h.typed_get().unwrap();
Box::pin(async move {
if let Ok(token_data) = authorizer.check_auth(bearer.token()).await {
// Set `token_data` as a request extension so it can be accessed by other
// services down the stack.
request.extensions_mut().insert(token_data);
Ok(request)
} else {
let unauthorized_response = Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(Body::empty())
.unwrap();
Err(unauthorized_response)
}
})
}
}
// -------------- Layer -----------------
#[derive(Clone)]
pub struct AsyncAuthorizationLayer<C>
where
C: Clone + DeserializeOwned + Send,
{
auth: Arc<Authorizer<C>>,
}
impl<C> AsyncAuthorizationLayer<C>
where
C: Clone + DeserializeOwned + Send,
{
pub fn new(auth: Arc<Authorizer<C>>) -> AsyncAuthorizationLayer<C> {
Self { auth }
}
}
impl<S, C> Layer<S> for AsyncAuthorizationLayer<C>
where
C: Clone + DeserializeOwned + Send + Sync,
{
type Service = AsyncAuthorizationService<S, C>;
fn layer(&self, inner: S) -> Self::Service {
AsyncAuthorizationService::new(inner, self.auth.clone())
}
}
// ---------- AsyncAuthorizationService --------
#[derive(Clone)]
pub struct AsyncAuthorizationService<S, C>
where
C: Clone + DeserializeOwned + Send + Sync,
{
pub inner: S,
pub auth: Arc<Authorizer<C>>,
}
impl<S, C> AsyncAuthorizationService<S, C>
where
C: Clone + DeserializeOwned + Send + Sync,
{
pub fn get_ref(&self) -> &S {
&self.inner
}
/// Gets a mutable reference to the underlying service.
pub fn get_mut(&mut self) -> &mut S {
&mut self.inner
}
/// Consumes `self`, returning the underlying service.
pub fn into_inner(self) -> S {
self.inner
}
}
impl<S, C> AsyncAuthorizationService<S, C>
where
C: Clone + DeserializeOwned + Send + Sync,
{
/// Authorize requests using a custom scheme.
///
/// The `Authorization` header is required to have the value provided.
pub fn new(inner: S, auth: Arc<Authorizer<C>>) -> AsyncAuthorizationService<S, C> {
Self { inner, auth }
}
}
impl<ReqBody, S, C> Service<Request<ReqBody>> for AsyncAuthorizationService<S, C>
where
ReqBody: Send + Sync + 'static,
S: Service<Request<ReqBody>, Response = Response> + Clone,
C: Clone + DeserializeOwned + Send + Sync + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = ResponseFuture<S, ReqBody, C>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
let inner = self.inner.clone();
let auth_fut = self.authorize(req);
ResponseFuture {
state: State::Authorize { auth_fut },
service: inner,
}
}
}
#[pin_project]
/// Response future for [`AsyncAuthorizationService`].
pub struct ResponseFuture<S, ReqBody, C>
where
S: Service<Request<ReqBody>, Response = Response>,
ReqBody: Send + Sync + 'static,
C: Clone + DeserializeOwned + Send + Sync + 'static,
{
#[pin]
state: State<<AsyncAuthorizationService<S, C> as AsyncAuthorizer<ReqBody>>::Future, S::Future>,
service: S,
}
#[pin_project(project = StateProj)]
enum State<A, SFut> {
Authorize {
#[pin]
auth_fut: A,
},
Authorized {
#[pin]
svc_fut: SFut,
},
}
impl<S, ReqBody, C> Future for ResponseFuture<S, ReqBody, C>
where
S: Service<Request<ReqBody>, Response = Response>,
ReqBody: Send + Sync + 'static,
C: Clone + DeserializeOwned + Send + Sync,
{
type Output = Result<S::Response, S::Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
loop {
match this.state.as_mut().project() {
StateProj::Authorize { auth_fut } => {
let auth = ready!(auth_fut.poll(cx));
match auth {
Ok(req) => {
let svc_fut = this.service.call(req);
this.state.set(State::Authorized { svc_fut })
}
Err(res) => {
tracing::info!("err: {:?}", res);
let r = (StatusCode::FORBIDDEN, format!("Unauthorized : {:?}", res)).into_response();
// TODO: replace r by res (type problems: res should be already a 403 error response)
return Poll::Ready(Ok(r));
}
};
}
StateProj::Authorized { svc_fut } => {
return svc_fut.poll(cx);
}
}
}
}
}

34
jwt-authorizer/src/lib.rs Normal file
View file

@ -0,0 +1,34 @@
#![doc = include_str!("../docs/README.md")]
use axum::{async_trait, extract::FromRequestParts, http::request::Parts};
use jsonwebtoken::TokenData;
use serde::de::DeserializeOwned;
pub use self::error::AuthError;
pub use layer::JwtAuthorizer;
pub mod authorizer;
pub mod error;
pub mod jwks;
pub mod layer;
/// Claims serialized using T
#[derive(Debug, Clone, Copy, Default)]
pub struct JwtClaims<T>(pub T);
#[async_trait]
impl<T, S> FromRequestParts<S> for JwtClaims<T>
where
T: DeserializeOwned + Send + Sync + Clone + 'static,
S: Send + Sync,
{
type Rejection = error::AuthError;
async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
let claims = parts.extensions.get::<TokenData<T>>().unwrap(); // TODO: unwrap -> err
Ok(JwtClaims(claims.claims.clone())) // TODO: unwrap -> err
}
}
#[cfg(test)]
mod tests;

View file

@ -0,0 +1 @@
// TODO: tests