mirror of
https://github.com/TECHNOFAB11/jwt-authorizer.git
synced 2025-12-11 23:50:07 +01:00
Cloning the inner service to use in call can mean a not-ready clone gets used, which violates the tower service preconditions. Replace the cloned service with the ready service to ensure the right copy gets used. See https://docs.rs/tower/0.4.13/tower/trait.Service.html#be-careful-when-cloning-inner-services for more details. Signed-off-by: Sjoerd Simons <sjoerd@collabora.com>
415 lines
12 KiB
Rust
415 lines
12 KiB
Rust
use axum::http::Request;
|
|
use futures_core::ready;
|
|
use futures_util::future::BoxFuture;
|
|
use headers::authorization::Bearer;
|
|
use headers::{Authorization, HeaderMapExt};
|
|
use pin_project::pin_project;
|
|
use serde::de::DeserializeOwned;
|
|
use std::future::Future;
|
|
use std::pin::Pin;
|
|
use std::sync::Arc;
|
|
use std::task::{Context, Poll};
|
|
use tower_layer::Layer;
|
|
use tower_service::Service;
|
|
|
|
use crate::authorizer::{Authorizer, FnClaimsChecker, KeySourceType};
|
|
use crate::error::InitError;
|
|
use crate::jwks::key_store_manager::Refresh;
|
|
use crate::validation::Validation;
|
|
use crate::{layer, AuthError, RefreshStrategy};
|
|
|
|
/// Authorizer Layer builder
|
|
///
|
|
/// - initialisation of the Authorizer from jwks, rsa, ed, ec or secret
|
|
/// - can define a checker (jwt claims check)
|
|
pub struct JwtAuthorizer<C>
|
|
where
|
|
C: Clone + DeserializeOwned,
|
|
{
|
|
key_source_type: KeySourceType,
|
|
refresh: Option<Refresh>,
|
|
claims_checker: Option<FnClaimsChecker<C>>,
|
|
validation: Option<Validation>,
|
|
jwt_source: JwtSource,
|
|
}
|
|
|
|
/// authorization layer builder
|
|
impl<C> JwtAuthorizer<C>
|
|
where
|
|
C: Clone + DeserializeOwned + Send + Sync,
|
|
{
|
|
/// Builds Authorizer Layer from a OpenId Connect discover metadata
|
|
pub fn from_oidc(issuer: &str) -> JwtAuthorizer<C> {
|
|
JwtAuthorizer {
|
|
key_source_type: KeySourceType::Discovery(issuer.to_string()),
|
|
refresh: Default::default(),
|
|
claims_checker: None,
|
|
validation: None,
|
|
jwt_source: JwtSource::AuthorizationHeader,
|
|
}
|
|
}
|
|
|
|
/// Builds Authorizer Layer from a JWKS endpoint
|
|
pub fn from_jwks_url(url: &str) -> JwtAuthorizer<C> {
|
|
JwtAuthorizer {
|
|
key_source_type: KeySourceType::Jwks(url.to_owned()),
|
|
refresh: Default::default(),
|
|
claims_checker: None,
|
|
validation: None,
|
|
jwt_source: JwtSource::AuthorizationHeader,
|
|
}
|
|
}
|
|
|
|
/// Builds Authorizer Layer from a RSA PEM file
|
|
pub fn from_rsa_pem(path: &str) -> JwtAuthorizer<C> {
|
|
JwtAuthorizer {
|
|
key_source_type: KeySourceType::RSA(path.to_owned()),
|
|
refresh: Default::default(),
|
|
claims_checker: None,
|
|
validation: None,
|
|
jwt_source: JwtSource::AuthorizationHeader,
|
|
}
|
|
}
|
|
|
|
/// Builds Authorizer Layer from an RSA PEM raw text
|
|
pub fn from_rsa_pem_text(text: &str) -> JwtAuthorizer<C> {
|
|
JwtAuthorizer {
|
|
key_source_type: KeySourceType::RSAString(text.to_owned()),
|
|
refresh: Default::default(),
|
|
claims_checker: None,
|
|
validation: None,
|
|
jwt_source: JwtSource::AuthorizationHeader,
|
|
}
|
|
}
|
|
|
|
/// Builds Authorizer Layer from a EC PEM file
|
|
pub fn from_ec_pem(path: &str) -> JwtAuthorizer<C> {
|
|
JwtAuthorizer {
|
|
key_source_type: KeySourceType::EC(path.to_owned()),
|
|
refresh: Default::default(),
|
|
claims_checker: None,
|
|
validation: None,
|
|
jwt_source: JwtSource::AuthorizationHeader,
|
|
}
|
|
}
|
|
|
|
/// Builds Authorizer Layer from a EC PEM raw text
|
|
pub fn from_ec_pem_text(text: &str) -> JwtAuthorizer<C> {
|
|
JwtAuthorizer {
|
|
key_source_type: KeySourceType::ECString(text.to_owned()),
|
|
refresh: Default::default(),
|
|
claims_checker: None,
|
|
validation: None,
|
|
jwt_source: JwtSource::AuthorizationHeader,
|
|
}
|
|
}
|
|
|
|
/// Builds Authorizer Layer from a EC PEM file
|
|
pub fn from_ed_pem(path: &str) -> JwtAuthorizer<C> {
|
|
JwtAuthorizer {
|
|
key_source_type: KeySourceType::ED(path.to_owned()),
|
|
refresh: Default::default(),
|
|
claims_checker: None,
|
|
validation: None,
|
|
jwt_source: JwtSource::AuthorizationHeader,
|
|
}
|
|
}
|
|
|
|
/// Builds Authorizer Layer from a EC PEM raw text
|
|
pub fn from_ed_pem_text(text: &str) -> JwtAuthorizer<C> {
|
|
JwtAuthorizer {
|
|
key_source_type: KeySourceType::EDString(text.to_owned()),
|
|
refresh: Default::default(),
|
|
claims_checker: None,
|
|
validation: None,
|
|
jwt_source: JwtSource::AuthorizationHeader,
|
|
}
|
|
}
|
|
|
|
/// Builds Authorizer Layer from a secret phrase
|
|
pub fn from_secret(secret: &str) -> JwtAuthorizer<C> {
|
|
JwtAuthorizer {
|
|
key_source_type: KeySourceType::Secret(secret.to_owned()),
|
|
refresh: Default::default(),
|
|
claims_checker: None,
|
|
validation: None,
|
|
jwt_source: JwtSource::AuthorizationHeader,
|
|
}
|
|
}
|
|
|
|
/// Refreshes configuration for jwk store
|
|
pub fn refresh(mut self, refresh: Refresh) -> JwtAuthorizer<C> {
|
|
if self.refresh.is_some() {
|
|
tracing::warn!("More than one refresh configuration found!");
|
|
}
|
|
self.refresh = Some(refresh);
|
|
self
|
|
}
|
|
|
|
/// no refresh, jwks will be loaded juste once
|
|
pub fn no_refresh(mut self) -> JwtAuthorizer<C> {
|
|
if self.refresh.is_some() {
|
|
tracing::warn!("More than one refresh configuration found!");
|
|
}
|
|
self.refresh = Some(Refresh {
|
|
strategy: RefreshStrategy::NoRefresh,
|
|
..Default::default()
|
|
});
|
|
self
|
|
}
|
|
|
|
/// configures token content check (custom function), if false a 403 will be sent.
|
|
/// (AuthError::InvalidClaims())
|
|
pub fn check(mut self, checker_fn: fn(&C) -> bool) -> JwtAuthorizer<C> {
|
|
self.claims_checker = Some(FnClaimsChecker { checker_fn });
|
|
|
|
self
|
|
}
|
|
|
|
pub fn validation(mut self, validation: Validation) -> JwtAuthorizer<C> {
|
|
self.validation = Some(validation);
|
|
|
|
self
|
|
}
|
|
|
|
/// configures the source of the bearer token
|
|
///
|
|
/// (default: AuthorizationHeader)
|
|
pub fn jwt_source(mut self, src: JwtSource) -> JwtAuthorizer<C> {
|
|
self.jwt_source = src;
|
|
|
|
self
|
|
}
|
|
|
|
/// Build axum layer
|
|
pub async fn layer(self) -> Result<AsyncAuthorizationLayer<C>, InitError> {
|
|
let val = self.validation.unwrap_or_default();
|
|
let auth = Arc::new(Authorizer::build(&self.key_source_type, self.claims_checker, self.refresh, val).await?);
|
|
Ok(AsyncAuthorizationLayer::new(auth, self.jwt_source))
|
|
}
|
|
}
|
|
/// Trait for authorizing requests.
|
|
pub trait AsyncAuthorizer<B> {
|
|
type RequestBody;
|
|
type Future: Future<Output = Result<Request<Self::RequestBody>, AuthError>>;
|
|
|
|
/// Authorize the request.
|
|
///
|
|
/// If the future resolves to `Ok(request)` then the request is allowed through, otherwise not.
|
|
fn authorize(&self, request: Request<B>) -> Self::Future;
|
|
}
|
|
|
|
impl<B, S, C> AsyncAuthorizer<B> for AsyncAuthorizationService<S, C>
|
|
where
|
|
B: Send + Sync + 'static,
|
|
C: Clone + DeserializeOwned + Send + Sync + 'static,
|
|
{
|
|
type RequestBody = B;
|
|
type Future = BoxFuture<'static, Result<Request<B>, AuthError>>;
|
|
|
|
fn authorize(&self, mut request: Request<B>) -> Self::Future {
|
|
let authorizer = self.auth.clone();
|
|
let h = request.headers();
|
|
|
|
let token = match &self.jwt_source {
|
|
layer::JwtSource::AuthorizationHeader => {
|
|
let bearer_o: Option<Authorization<Bearer>> = h.typed_get();
|
|
bearer_o.map(|b| String::from(b.0.token()))
|
|
}
|
|
layer::JwtSource::Cookie(name) => h
|
|
.typed_get::<headers::Cookie>()
|
|
.and_then(|c| c.get(name.as_str()).map(String::from)),
|
|
};
|
|
Box::pin(async move {
|
|
if let Some(token) = token {
|
|
authorizer.check_auth(token.as_str()).await.map(|token_data| {
|
|
// Set `token_data` as a request extension so it can be accessed by other
|
|
// services down the stack.
|
|
request.extensions_mut().insert(token_data);
|
|
|
|
request
|
|
})
|
|
} else {
|
|
Err(AuthError::MissingToken())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// -------------- Layer -----------------
|
|
|
|
#[derive(Clone)]
|
|
pub struct AsyncAuthorizationLayer<C>
|
|
where
|
|
C: Clone + DeserializeOwned + Send,
|
|
{
|
|
auth: Arc<Authorizer<C>>,
|
|
jwt_source: JwtSource,
|
|
}
|
|
|
|
impl<C> AsyncAuthorizationLayer<C>
|
|
where
|
|
C: Clone + DeserializeOwned + Send,
|
|
{
|
|
pub fn new(auth: Arc<Authorizer<C>>, jwt_source: JwtSource) -> AsyncAuthorizationLayer<C> {
|
|
Self { auth, jwt_source }
|
|
}
|
|
}
|
|
|
|
impl<S, C> Layer<S> for AsyncAuthorizationLayer<C>
|
|
where
|
|
C: Clone + DeserializeOwned + Send + Sync,
|
|
{
|
|
type Service = AsyncAuthorizationService<S, C>;
|
|
|
|
fn layer(&self, inner: S) -> Self::Service {
|
|
AsyncAuthorizationService::new(inner, self.auth.clone(), self.jwt_source.clone())
|
|
}
|
|
}
|
|
|
|
// ---------- AsyncAuthorizationService --------
|
|
|
|
/// Source of the bearer token
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JwtSource {
    /// Storing the bearer token in Authorization header
    ///
    /// (default)
    AuthorizationHeader,
    /// Cookies; the payload is the name of the cookie holding the token
    ///
    /// (be careful when using cookies, some precautions must be taken, cf. RFC6750)
    Cookie(String),
    // TODO: "Form-Encoded Content Parameter" may be added in the future (OAuth 2.1 / 5.2.1.2)
    // FormParam,
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct AsyncAuthorizationService<S, C>
|
|
where
|
|
C: Clone + DeserializeOwned + Send + Sync,
|
|
{
|
|
pub inner: S,
|
|
pub auth: Arc<Authorizer<C>>,
|
|
pub jwt_source: JwtSource,
|
|
}
|
|
|
|
impl<S, C> AsyncAuthorizationService<S, C>
|
|
where
|
|
C: Clone + DeserializeOwned + Send + Sync,
|
|
{
|
|
pub fn get_ref(&self) -> &S {
|
|
&self.inner
|
|
}
|
|
|
|
/// Gets a mutable reference to the underlying service.
|
|
pub fn get_mut(&mut self) -> &mut S {
|
|
&mut self.inner
|
|
}
|
|
|
|
/// Consumes `self`, returning the underlying service.
|
|
pub fn into_inner(self) -> S {
|
|
self.inner
|
|
}
|
|
}
|
|
|
|
impl<S, C> AsyncAuthorizationService<S, C>
|
|
where
|
|
C: Clone + DeserializeOwned + Send + Sync,
|
|
{
|
|
/// Authorize requests using a custom scheme.
|
|
///
|
|
/// The `Authorization` header is required to have the value provided.
|
|
pub fn new(inner: S, auth: Arc<Authorizer<C>>, jwt_source: JwtSource) -> AsyncAuthorizationService<S, C> {
|
|
Self { inner, auth, jwt_source }
|
|
}
|
|
}
|
|
|
|
impl<ReqBody, S, C> Service<Request<ReqBody>> for AsyncAuthorizationService<S, C>
|
|
where
|
|
ReqBody: Send + Sync + 'static,
|
|
S: Service<Request<ReqBody>> + Clone,
|
|
S::Response: From<AuthError>,
|
|
C: Clone + DeserializeOwned + Send + Sync + 'static,
|
|
{
|
|
type Response = S::Response;
|
|
type Error = S::Error;
|
|
type Future = ResponseFuture<S, ReqBody, C>;
|
|
|
|
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
self.inner.poll_ready(cx)
|
|
}
|
|
|
|
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
|
|
let inner = self.inner.clone();
|
|
// take the service that was ready
|
|
let inner = std::mem::replace(&mut self.inner, inner);
|
|
|
|
let auth_fut = self.authorize(req);
|
|
|
|
ResponseFuture {
|
|
state: State::Authorize { auth_fut },
|
|
service: inner,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[pin_project]
|
|
/// Response future for [`AsyncAuthorizationService`].
|
|
pub struct ResponseFuture<S, ReqBody, C>
|
|
where
|
|
S: Service<Request<ReqBody>>,
|
|
ReqBody: Send + Sync + 'static,
|
|
C: Clone + DeserializeOwned + Send + Sync + 'static,
|
|
{
|
|
#[pin]
|
|
state: State<<AsyncAuthorizationService<S, C> as AsyncAuthorizer<ReqBody>>::Future, S::Future>,
|
|
service: S,
|
|
}
|
|
|
|
#[pin_project(project = StateProj)]
|
|
enum State<A, SFut> {
|
|
Authorize {
|
|
#[pin]
|
|
auth_fut: A,
|
|
},
|
|
Authorized {
|
|
#[pin]
|
|
svc_fut: SFut,
|
|
},
|
|
}
|
|
|
|
impl<S, ReqBody, C> Future for ResponseFuture<S, ReqBody, C>
|
|
where
|
|
S: Service<Request<ReqBody>>,
|
|
S::Response: From<AuthError>,
|
|
ReqBody: Send + Sync + 'static,
|
|
C: Clone + DeserializeOwned + Send + Sync,
|
|
{
|
|
type Output = Result<S::Response, S::Error>;
|
|
|
|
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
let mut this = self.project();
|
|
|
|
loop {
|
|
match this.state.as_mut().project() {
|
|
StateProj::Authorize { auth_fut } => {
|
|
let auth = ready!(auth_fut.poll(cx));
|
|
match auth {
|
|
Ok(req) => {
|
|
let svc_fut = this.service.call(req);
|
|
this.state.set(State::Authorized { svc_fut })
|
|
}
|
|
Err(res) => {
|
|
tracing::info!("err: {:?}", res);
|
|
return Poll::Ready(Ok(res.into()));
|
|
}
|
|
};
|
|
}
|
|
StateProj::Authorized { svc_fut } => {
|
|
return svc_fut.poll(cx);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|