mirror of
https://github.com/TECHNOFAB11/jwt-authorizer.git
synced 2025-12-12 16:10:06 +01:00
chore: Merge 'sjoerdsimons/use-ready-inner-service'
- containing 2 PRs: #21, #19
This commit is contained in:
commit
93325dce96
9 changed files with 915 additions and 265 deletions
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "jwt-authorizer"
|
||||
description = "jwt authorizer middleware for axum"
|
||||
description = "jwt authorizer middleware for axum and tonic"
|
||||
version = "0.9.0"
|
||||
edition = "2021"
|
||||
authors = ["cduvray <c_duvray@proton.me>"]
|
||||
|
|
@ -27,11 +27,13 @@ tower-layer = "0.3"
|
|||
tower-service = "0.3"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
tonic = { version = "0.9.2", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
hyper = { version = "0.14", features = ["full"] }
|
||||
lazy_static = "1.4.0"
|
||||
tower = { version = "0.4", features = ["util"] }
|
||||
prost = "0.11.9"
|
||||
tower = { version = "0.4", features = ["util", "buffer"] }
|
||||
wiremock = "0.5"
|
||||
|
||||
[features]
|
||||
|
|
@ -44,3 +46,7 @@ rustls-tls = ["reqwest/rustls-tls"]
|
|||
rustls-tls-manual-roots = ["reqwest/rustls-tls-manual-roots"]
|
||||
rustls-tls-webpki-roots = ["reqwest/rustls-tls-webpki-roots"]
|
||||
rustls-tls-native-roots = ["reqwest/rustls-tls-native-roots"]
|
||||
|
||||
[[test]]
|
||||
name = "tonic"
|
||||
required-features = [ "tonic" ]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# jwt-authorizer
|
||||
|
||||
JWT authoriser Layer for Axum.
|
||||
JWT authoriser Layer for Axum and Tonic.
|
||||
|
||||
## Features
|
||||
|
||||
|
|
|
|||
|
|
@ -77,6 +77,53 @@ fn response_500() -> Response<BoxBody> {
|
|||
res
|
||||
}
|
||||
|
||||
#[cfg(feature = "tonic")]
|
||||
impl From<AuthError> for Response<tonic::body::BoxBody> {
|
||||
fn from(e: AuthError) -> Self {
|
||||
match e {
|
||||
AuthError::JwksRefreshError(err) => {
|
||||
tracing::error!("AuthErrors::JwksRefreshError: {}", err);
|
||||
tonic::Status::internal("")
|
||||
}
|
||||
AuthError::InvalidKey(err) => {
|
||||
tracing::error!("AuthErrors::InvalidKey: {}", err);
|
||||
tonic::Status::internal("")
|
||||
}
|
||||
AuthError::JwksSerialisationError(err) => {
|
||||
tracing::error!("AuthErrors::JwksSerialisationError: {}", err);
|
||||
tonic::Status::internal("")
|
||||
}
|
||||
AuthError::InvalidKeyAlg(err) => {
|
||||
debug!("AuthErrors::InvalidKeyAlg: {:?}", err);
|
||||
tonic::Status::unauthenticated("error=\"invalid_token\", error_description=\"invalid key algorithm\"")
|
||||
}
|
||||
AuthError::InvalidKid(err) => {
|
||||
debug!("AuthErrors::InvalidKid: {}", err);
|
||||
tonic::Status::unauthenticated("error=\"invalid_token\", error_description=\"invalid kid\"")
|
||||
}
|
||||
AuthError::InvalidToken(err) => {
|
||||
debug!("AuthErrors::InvalidToken: {}", err);
|
||||
tonic::Status::unauthenticated("error=\"invalid_token\"")
|
||||
}
|
||||
AuthError::MissingToken() => {
|
||||
debug!("AuthErrors::MissingToken");
|
||||
tonic::Status::unauthenticated("")
|
||||
}
|
||||
AuthError::InvalidClaims() => {
|
||||
debug!("AuthErrors::InvalidClaims");
|
||||
tonic::Status::unauthenticated("error=\"insufficient_scope\"")
|
||||
}
|
||||
}
|
||||
.to_http()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<AuthError> for Response {
|
||||
fn from(e: AuthError) -> Self {
|
||||
e.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
/// (https://datatracker.ietf.org/doc/html/rfc6750#section-3.1)
|
||||
impl IntoResponse for AuthError {
|
||||
fn into_response(self) -> Response {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,4 @@
|
|||
use axum::body::BoxBody;
|
||||
use axum::http::Request;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use futures_core::ready;
|
||||
use futures_util::future::BoxFuture;
|
||||
use headers::authorization::Bearer;
|
||||
|
|
@ -193,8 +191,7 @@ where
|
|||
/// Trait for authorizing requests.
|
||||
pub trait AsyncAuthorizer<B> {
|
||||
type RequestBody;
|
||||
type ResponseBody;
|
||||
type Future: Future<Output = Result<Request<Self::RequestBody>, Response<Self::ResponseBody>>>;
|
||||
type Future: Future<Output = Result<Request<Self::RequestBody>, AuthError>>;
|
||||
|
||||
/// Authorize the request.
|
||||
///
|
||||
|
|
@ -208,8 +205,7 @@ where
|
|||
C: Clone + DeserializeOwned + Send + Sync + 'static,
|
||||
{
|
||||
type RequestBody = B;
|
||||
type ResponseBody = BoxBody;
|
||||
type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;
|
||||
type Future = BoxFuture<'static, Result<Request<B>, AuthError>>;
|
||||
|
||||
fn authorize(&self, mut request: Request<B>) -> Self::Future {
|
||||
let authorizer = self.auth.clone();
|
||||
|
|
@ -226,18 +222,15 @@ where
|
|||
};
|
||||
Box::pin(async move {
|
||||
if let Some(token) = token {
|
||||
match authorizer.check_auth(token.as_str()).await {
|
||||
Ok(token_data) => {
|
||||
// Set `token_data` as a request extension so it can be accessed by other
|
||||
// services down the stack.
|
||||
request.extensions_mut().insert(token_data);
|
||||
authorizer.check_auth(token.as_str()).await.map(|token_data| {
|
||||
// Set `token_data` as a request extension so it can be accessed by other
|
||||
// services down the stack.
|
||||
request.extensions_mut().insert(token_data);
|
||||
|
||||
Ok(request)
|
||||
}
|
||||
Err(err) => Err(err.into_response()),
|
||||
}
|
||||
request
|
||||
})
|
||||
} else {
|
||||
Err(AuthError::MissingToken().into_response())
|
||||
Err(AuthError::MissingToken())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -335,7 +328,8 @@ where
|
|||
impl<ReqBody, S, C> Service<Request<ReqBody>> for AsyncAuthorizationService<S, C>
|
||||
where
|
||||
ReqBody: Send + Sync + 'static,
|
||||
S: Service<Request<ReqBody>, Response = Response> + Clone,
|
||||
S: Service<Request<ReqBody>> + Clone,
|
||||
S::Response: From<AuthError>,
|
||||
C: Clone + DeserializeOwned + Send + Sync + 'static,
|
||||
{
|
||||
type Response = S::Response;
|
||||
|
|
@ -348,6 +342,9 @@ where
|
|||
|
||||
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
|
||||
let inner = self.inner.clone();
|
||||
// take the service that was ready
|
||||
let inner = std::mem::replace(&mut self.inner, inner);
|
||||
|
||||
let auth_fut = self.authorize(req);
|
||||
|
||||
ResponseFuture {
|
||||
|
|
@ -361,7 +358,7 @@ where
|
|||
/// Response future for [`AsyncAuthorizationService`].
|
||||
pub struct ResponseFuture<S, ReqBody, C>
|
||||
where
|
||||
S: Service<Request<ReqBody>, Response = Response>,
|
||||
S: Service<Request<ReqBody>>,
|
||||
ReqBody: Send + Sync + 'static,
|
||||
C: Clone + DeserializeOwned + Send + Sync + 'static,
|
||||
{
|
||||
|
|
@ -384,7 +381,8 @@ enum State<A, SFut> {
|
|||
|
||||
impl<S, ReqBody, C> Future for ResponseFuture<S, ReqBody, C>
|
||||
where
|
||||
S: Service<Request<ReqBody>, Response = Response>,
|
||||
S: Service<Request<ReqBody>>,
|
||||
S::Response: From<AuthError>,
|
||||
ReqBody: Send + Sync + 'static,
|
||||
C: Clone + DeserializeOwned + Send + Sync,
|
||||
{
|
||||
|
|
@ -404,7 +402,7 @@ where
|
|||
}
|
||||
Err(res) => {
|
||||
tracing::info!("err: {:?}", res);
|
||||
return Poll::Ready(Ok(res));
|
||||
return Poll::Ready(Ok(res.into()));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,17 +2,19 @@ mod common;
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::convert::Infallible;
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
http::{Request, StatusCode},
|
||||
response::Response,
|
||||
routing::get,
|
||||
Router,
|
||||
BoxError, Router,
|
||||
};
|
||||
use http::{header, HeaderValue};
|
||||
use jwt_authorizer::{layer::JwtSource, validation::Validation, JwtAuthorizer, JwtClaims};
|
||||
use serde::Deserialize;
|
||||
use tower::ServiceExt;
|
||||
use tower::{util::MapErrLayer, ServiceExt};
|
||||
|
||||
use crate::common;
|
||||
|
||||
|
|
@ -24,8 +26,15 @@ mod tests {
|
|||
async fn app(jwt_auth: JwtAuthorizer<User>) -> Router {
|
||||
Router::new().route("/public", get(|| async { "hello" })).route(
|
||||
"/protected",
|
||||
get(|JwtClaims(user): JwtClaims<User>| async move { format!("hello: {}", user.sub) })
|
||||
.layer(jwt_auth.layer().await.unwrap()),
|
||||
get(|JwtClaims(user): JwtClaims<User>| async move { format!("hello: {}", user.sub) }).layer(
|
||||
tower_layer::Stack::new(
|
||||
tower_layer::Stack::new(
|
||||
tower::buffer::BufferLayer::new(1),
|
||||
MapErrLayer::new(|e: BoxError| -> Infallible { panic!("{}", e) }),
|
||||
),
|
||||
jwt_auth.layer().await.unwrap(),
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
|||
208
jwt-authorizer/tests/tonic.rs
Normal file
208
jwt-authorizer/tests/tonic.rs
Normal file
|
|
@ -0,0 +1,208 @@
|
|||
use std::{sync::Once, task::Poll};
|
||||
|
||||
use axum::body::HttpBody;
|
||||
use futures_core::future::BoxFuture;
|
||||
use http::header::AUTHORIZATION;
|
||||
use jwt_authorizer::{layer::AsyncAuthorizationService, JwtAuthorizer};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tonic::{server::UnaryService, transport::NamedService, IntoRequest, Status};
|
||||
use tower::{buffer::Buffer, Service};
|
||||
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
use crate::common::{JWT_RSA1_OK, JWT_RSA2_OK};
|
||||
|
||||
mod common;
|
||||
|
||||
/// Static variable to ensure that logging is only initialized once.
|
||||
pub static INITIALIZED: Once = Once::new();
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
struct User {
|
||||
sub: String,
|
||||
}
|
||||
|
||||
#[derive(prost::Message)]
|
||||
struct HelloMessage {
|
||||
#[prost(string, tag = "1")]
|
||||
message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
struct SayHelloMethod {}
|
||||
impl UnaryService<HelloMessage> for SayHelloMethod {
|
||||
type Response = HelloMessage;
|
||||
type Future = BoxFuture<'static, Result<tonic::Response<Self::Response>, Status>>;
|
||||
|
||||
fn call(&mut self, request: tonic::Request<HelloMessage>) -> Self::Future {
|
||||
Box::pin(async move {
|
||||
let hi = request.into_inner();
|
||||
let reply = HelloMessage {
|
||||
message: format!("Hello, {}", hi.message),
|
||||
};
|
||||
Ok(tonic::Response::new(reply))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
struct GreeterServer {
|
||||
expected_sub: String,
|
||||
}
|
||||
|
||||
impl Service<http::Request<tonic::transport::Body>> for GreeterServer {
|
||||
type Response = http::Response<tonic::body::BoxBody>;
|
||||
type Error = std::convert::Infallible;
|
||||
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: http::Request<tonic::transport::Body>) -> Self::Future {
|
||||
let token = req.extensions().get::<jsonwebtoken::TokenData<User>>().unwrap();
|
||||
assert_eq!(token.claims.sub, self.expected_sub);
|
||||
match req.uri().path() {
|
||||
"/hello/SayHello" => Box::pin(async move {
|
||||
let mut grpc = tonic::server::Grpc::new(tonic::codec::ProstCodec::default());
|
||||
Ok(grpc.unary(SayHelloMethod::default(), req).await)
|
||||
}),
|
||||
p => {
|
||||
let p = p.to_string();
|
||||
Box::pin(async move { Ok(Status::unimplemented(p).to_http()) })
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl NamedService for GreeterServer {
|
||||
const NAME: &'static str = "hello";
|
||||
}
|
||||
|
||||
async fn app(
|
||||
jwt_auth: JwtAuthorizer<User>,
|
||||
expected_sub: String,
|
||||
) -> AsyncAuthorizationService<Buffer<tonic::transport::server::Routes, http::Request<tonic::transport::Body>>, User> {
|
||||
let layer = jwt_auth.layer().await.unwrap();
|
||||
tonic::transport::Server::builder()
|
||||
.layer(layer)
|
||||
.layer(tower::buffer::BufferLayer::new(1))
|
||||
.add_service(GreeterServer { expected_sub })
|
||||
.into_service()
|
||||
}
|
||||
|
||||
fn init_test() {
|
||||
INITIALIZED.call_once(|| {
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::EnvFilter::new(
|
||||
std::env::var("RUST_LOG").unwrap_or_else(|_| "info,jwt-authorizer=debug,tower_http=debug".into()),
|
||||
))
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
});
|
||||
}
|
||||
|
||||
// The grpc client produces a http request with a tonic boxbody that the transport is meant to sent out, while the server side
|
||||
// expects to receive a http request with a hyper body.. This simple wrapper converts from one to
|
||||
// the other.
|
||||
struct GrpcWrapper<S>
|
||||
where
|
||||
S: Service<http::Request<axum::body::Body>> + Clone,
|
||||
{
|
||||
inner: S,
|
||||
}
|
||||
|
||||
impl<S> Service<http::Request<tonic::body::BoxBody>> for GrpcWrapper<S>
|
||||
where
|
||||
S: Service<http::Request<axum::body::Body>> + Clone + Send + 'static,
|
||||
S::Future: Send,
|
||||
{
|
||||
type Response = S::Response;
|
||||
type Error = S::Error;
|
||||
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn call(&mut self, req: http::Request<tonic::body::BoxBody>) -> Self::Future {
|
||||
let inner = self.inner.clone();
|
||||
// take the service that was ready
|
||||
let mut inner = std::mem::replace(&mut self.inner, inner);
|
||||
Box::pin(async move {
|
||||
let (parts, mut body) = req.into_parts();
|
||||
let mut data = Vec::new();
|
||||
while let Some(d) = body.data().await {
|
||||
let d = d.unwrap();
|
||||
data.extend_from_slice(&d)
|
||||
}
|
||||
inner
|
||||
.call(http::Request::from_parts(parts, axum::body::Body::from(data)))
|
||||
.await
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn make_protected_request<S: Clone>(
|
||||
app: AsyncAuthorizationService<S, User>,
|
||||
bearer: Option<&str>,
|
||||
message: &str,
|
||||
) -> Result<tonic::Response<HelloMessage>, Status>
|
||||
where
|
||||
S: Service<
|
||||
http::Request<tonic::transport::Body>,
|
||||
Response = http::Response<tonic::body::BoxBody>,
|
||||
Error = tower::BoxError,
|
||||
> + Send
|
||||
+ 'static,
|
||||
S::Future: Send,
|
||||
{
|
||||
let mut grpc = tonic::client::Grpc::new(GrpcWrapper { inner: app });
|
||||
|
||||
let mut request = HelloMessage {
|
||||
message: message.to_string(),
|
||||
}
|
||||
.into_request();
|
||||
|
||||
if let Some(bearer) = bearer {
|
||||
let headers = request.metadata_mut();
|
||||
headers.insert(AUTHORIZATION.as_str(), format!("Bearer {bearer}").parse().unwrap());
|
||||
}
|
||||
|
||||
grpc.ready().await.unwrap();
|
||||
grpc.unary(
|
||||
request,
|
||||
http::uri::PathAndQuery::from_static("/hello/SayHello"),
|
||||
tonic::codec::ProstCodec::default(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn successfull_auth() {
|
||||
init_test();
|
||||
let auth: JwtAuthorizer<User> = JwtAuthorizer::from_rsa_pem("../config/rsa-public1.pem");
|
||||
let app = app(auth, "b@b.com".to_string()).await;
|
||||
let r = make_protected_request(app.clone(), Some(JWT_RSA1_OK), "world").await.unwrap();
|
||||
assert_eq!(r.get_ref().message, "Hello, world");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wrong_token() {
|
||||
init_test();
|
||||
let auth: JwtAuthorizer<User> = JwtAuthorizer::from_rsa_pem("../config/rsa-public1.pem");
|
||||
let app = app(auth, "b@b.com".to_string()).await;
|
||||
let status = make_protected_request(app.clone(), Some(JWT_RSA2_OK), "world")
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(status.code(), tonic::Code::Unauthenticated);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn no_token() {
|
||||
init_test();
|
||||
let auth: JwtAuthorizer<User> = JwtAuthorizer::from_rsa_pem("../config/rsa-public1.pem");
|
||||
let app = app(auth, "b@b.com".to_string()).await;
|
||||
let status = make_protected_request(app.clone(), None, "world").await.unwrap_err();
|
||||
assert_eq!(status.code(), tonic::Code::Unauthenticated);
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue