feat: Add support for tonic

Tonic and Axum are quite closely related; From a tower perspective the
main difference is in the Error type in the body for their Response.

This refactor the code a little bit and add conversions from AuthError
to a tonic's Response such that the exact same code can be used by both
Axum and tonic services

Signed-off-by: Sjoerd Simons <sjoerd@collabora.com>
This commit is contained in:
Sjoerd Simons 2023-04-17 21:23:39 +02:00
parent f45568a044
commit 5f3a08c4c7
6 changed files with 547 additions and 257 deletions

715
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,6 @@
# jwt-authorizer # jwt-authorizer
JWT authorizer Layer for Axum. JWT authorizer Layer for Axum and Tonic.
[![Build status](https://github.com/cduvray/jwt-authorizer/actions/workflows/ci.yml/badge.svg?branch=main)](https://github.com/tokio-rs/cduvray/jwt-authorizer/workflows/ci.yml) [![Build status](https://github.com/cduvray/jwt-authorizer/actions/workflows/ci.yml/badge.svg?branch=main)](https://github.com/tokio-rs/cduvray/jwt-authorizer/workflows/ci.yml)
[![Crates.io](https://img.shields.io/crates/v/jwt-authorizer)](https://crates.io/crates/jwt-authorizer) [![Crates.io](https://img.shields.io/crates/v/jwt-authorizer)](https://crates.io/crates/jwt-authorizer)

View file

@ -1,6 +1,6 @@
[package] [package]
name = "jwt-authorizer" name = "jwt-authorizer"
description = "jwt authorizer middleware for axum" description = "jwt authorizer middleware for axum and tonic"
version = "0.9.0" version = "0.9.0"
edition = "2021" edition = "2021"
authors = ["cduvray <c_duvray@proton.me>"] authors = ["cduvray <c_duvray@proton.me>"]
@ -27,6 +27,7 @@ tower-layer = "0.3"
tower-service = "0.3" tower-service = "0.3"
tracing = "0.1" tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tonic = { version = "0.9.2", optional = true }
[dev-dependencies] [dev-dependencies]
hyper = { version = "0.14", features = ["full"] } hyper = { version = "0.14", features = ["full"] }

View file

@ -1,6 +1,6 @@
# jwt-authorizer # jwt-authorizer
JWT authoriser Layer for Axum. JWT authoriser Layer for Axum and Tonic.
## Features ## Features

View file

@ -77,6 +77,53 @@ fn response_500() -> Response<BoxBody> {
res res
} }
#[cfg(feature = "tonic")]
impl From<AuthError> for Response<tonic::body::BoxBody> {
fn from(e: AuthError) -> Self {
match e {
AuthError::JwksRefreshError(err) => {
tracing::error!("AuthErrors::JwksRefreshError: {}", err);
tonic::Status::internal("")
}
AuthError::InvalidKey(err) => {
tracing::error!("AuthErrors::InvalidKey: {}", err);
tonic::Status::internal("")
}
AuthError::JwksSerialisationError(err) => {
tracing::error!("AuthErrors::JwksSerialisationError: {}", err);
tonic::Status::internal("")
}
AuthError::InvalidKeyAlg(err) => {
debug!("AuthErrors::InvalidKeyAlg: {:?}", err);
tonic::Status::unauthenticated("error=\"invalid_token\", error_description=\"invalid key algorithm\"")
}
AuthError::InvalidKid(err) => {
debug!("AuthErrors::InvalidKid: {}", err);
tonic::Status::unauthenticated("error=\"invalid_token\", error_description=\"invalid kid\"")
}
AuthError::InvalidToken(err) => {
debug!("AuthErrors::InvalidToken: {}", err);
tonic::Status::unauthenticated("error=\"invalid_token\"")
}
AuthError::MissingToken() => {
debug!("AuthErrors::MissingToken");
tonic::Status::unauthenticated("")
}
AuthError::InvalidClaims() => {
debug!("AuthErrors::InvalidClaims");
tonic::Status::unauthenticated("error=\"insufficient_scope\"")
}
}
.to_http()
}
}
impl From<AuthError> for Response {
fn from(e: AuthError) -> Self {
e.into_response()
}
}
/// (https://datatracker.ietf.org/doc/html/rfc6750#section-3.1) /// (https://datatracker.ietf.org/doc/html/rfc6750#section-3.1)
impl IntoResponse for AuthError { impl IntoResponse for AuthError {
fn into_response(self) -> Response { fn into_response(self) -> Response {

View file

@ -1,6 +1,4 @@
use axum::body::BoxBody;
use axum::http::Request; use axum::http::Request;
use axum::response::{IntoResponse, Response};
use futures_core::ready; use futures_core::ready;
use futures_util::future::BoxFuture; use futures_util::future::BoxFuture;
use headers::authorization::Bearer; use headers::authorization::Bearer;
@ -193,8 +191,7 @@ where
/// Trait for authorizing requests. /// Trait for authorizing requests.
pub trait AsyncAuthorizer<B> { pub trait AsyncAuthorizer<B> {
type RequestBody; type RequestBody;
type ResponseBody; type Future: Future<Output = Result<Request<Self::RequestBody>, AuthError>>;
type Future: Future<Output = Result<Request<Self::RequestBody>, Response<Self::ResponseBody>>>;
/// Authorize the request. /// Authorize the request.
/// ///
@ -208,8 +205,7 @@ where
C: Clone + DeserializeOwned + Send + Sync + 'static, C: Clone + DeserializeOwned + Send + Sync + 'static,
{ {
type RequestBody = B; type RequestBody = B;
type ResponseBody = BoxBody; type Future = BoxFuture<'static, Result<Request<B>, AuthError>>;
type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;
fn authorize(&self, mut request: Request<B>) -> Self::Future { fn authorize(&self, mut request: Request<B>) -> Self::Future {
let authorizer = self.auth.clone(); let authorizer = self.auth.clone();
@ -226,18 +222,15 @@ where
}; };
Box::pin(async move { Box::pin(async move {
if let Some(token) = token { if let Some(token) = token {
match authorizer.check_auth(token.as_str()).await { authorizer.check_auth(token.as_str()).await.map(|token_data| {
Ok(token_data) => { // Set `token_data` as a request extension so it can be accessed by other
// Set `token_data` as a request extension so it can be accessed by other // services down the stack.
// services down the stack. request.extensions_mut().insert(token_data);
request.extensions_mut().insert(token_data);
Ok(request) request
} })
Err(err) => Err(err.into_response()),
}
} else { } else {
Err(AuthError::MissingToken().into_response()) Err(AuthError::MissingToken())
} }
}) })
} }
@ -335,7 +328,8 @@ where
impl<ReqBody, S, C> Service<Request<ReqBody>> for AsyncAuthorizationService<S, C> impl<ReqBody, S, C> Service<Request<ReqBody>> for AsyncAuthorizationService<S, C>
where where
ReqBody: Send + Sync + 'static, ReqBody: Send + Sync + 'static,
S: Service<Request<ReqBody>, Response = Response> + Clone, S: Service<Request<ReqBody>> + Clone,
S::Response: From<AuthError>,
C: Clone + DeserializeOwned + Send + Sync + 'static, C: Clone + DeserializeOwned + Send + Sync + 'static,
{ {
type Response = S::Response; type Response = S::Response;
@ -361,7 +355,7 @@ where
/// Response future for [`AsyncAuthorizationService`]. /// Response future for [`AsyncAuthorizationService`].
pub struct ResponseFuture<S, ReqBody, C> pub struct ResponseFuture<S, ReqBody, C>
where where
S: Service<Request<ReqBody>, Response = Response>, S: Service<Request<ReqBody>>,
ReqBody: Send + Sync + 'static, ReqBody: Send + Sync + 'static,
C: Clone + DeserializeOwned + Send + Sync + 'static, C: Clone + DeserializeOwned + Send + Sync + 'static,
{ {
@ -384,7 +378,8 @@ enum State<A, SFut> {
impl<S, ReqBody, C> Future for ResponseFuture<S, ReqBody, C> impl<S, ReqBody, C> Future for ResponseFuture<S, ReqBody, C>
where where
S: Service<Request<ReqBody>, Response = Response>, S: Service<Request<ReqBody>>,
S::Response: From<AuthError>,
ReqBody: Send + Sync + 'static, ReqBody: Send + Sync + 'static,
C: Clone + DeserializeOwned + Send + Sync, C: Clone + DeserializeOwned + Send + Sync,
{ {
@ -404,7 +399,7 @@ where
} }
Err(res) => { Err(res) => {
tracing::info!("err: {:?}", res); tracing::info!("err: {:?}", res);
return Poll::Ready(Ok(res)); return Poll::Ready(Ok(res.into()));
} }
}; };
} }