diff --git a/Cargo.lock b/Cargo.lock index c2a6133..046f79a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -280,6 +280,12 @@ dependencies = [ "crypto-common", ] +[[package]] +name = "either" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" + [[package]] name = "encoding_rs" version = "0.8.32" @@ -734,6 +740,15 @@ version = "2.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "12b6ee2129af8d4fb011108c73d99a1b83a85977f23b82460c0ae2e25bb4b57f" +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.6" @@ -794,6 +809,7 @@ dependencies = [ "jsonwebtoken", "lazy_static", "pin-project", + "prost", "reqwest", "serde", "serde_json", @@ -1122,6 +1138,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd" dependencies = [ "bytes", + "prost-derive", +] + +[[package]] +name = "prost-derive" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4" +dependencies = [ + "anyhow", + "itertools", + "proc-macro2", + "quote", + "syn 1.0.109", ] [[package]] diff --git a/jwt-authorizer/Cargo.toml b/jwt-authorizer/Cargo.toml index 031f210..52a5313 100644 --- a/jwt-authorizer/Cargo.toml +++ b/jwt-authorizer/Cargo.toml @@ -32,5 +32,10 @@ tonic = { version = "0.9.2", optional = true } [dev-dependencies] hyper = { version = "0.14", features = ["full"] } lazy_static = "1.4.0" +prost = "0.11.9" tower = { version = "0.4", features = ["util"] } wiremock = "0.5" + +[[test]] +name = "tonic" +required-features = [ "tonic" ] diff --git a/jwt-authorizer/tests/tonic.rs b/jwt-authorizer/tests/tonic.rs new file mode 100644 index 0000000..fc838fe --- /dev/null +++ b/jwt-authorizer/tests/tonic.rs @@ -0,0 +1,207 @@ +use std::{sync::Once, task::Poll}; + +use axum::body::HttpBody; +use futures_core::future::BoxFuture; +use http::header::AUTHORIZATION; +use jwt_authorizer::{layer::AsyncAuthorizationService, JwtAuthorizer}; +use serde::{Deserialize, Serialize}; +use tonic::{server::UnaryService, transport::NamedService, IntoRequest, Status}; +use tower::Service; + +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +use crate::common::{JWT_RSA1_OK, JWT_RSA2_OK}; + +mod common; + +/// Static variable to ensure that logging is only initialized once. +pub static INITIALIZED: Once = Once::new(); + +#[derive(Debug, Deserialize, Serialize, Clone)] +struct User { + sub: String, +} + +#[derive(prost::Message)] +struct HelloMessage { + #[prost(string, tag = "1")] + message: String, +} + +#[derive(Debug, Default, Clone)] +struct SayHelloMethod {} +impl UnaryService for SayHelloMethod { + type Response = HelloMessage; + type Future = BoxFuture<'static, Result, Status>>; + + fn call(&mut self, request: tonic::Request) -> Self::Future { + Box::pin(async move { + let hi = request.into_inner(); + let reply = HelloMessage { + message: format!("Hello, {}", hi.message), + }; + Ok(tonic::Response::new(reply)) + }) + } +} + +#[derive(Debug, Default, Clone)] +struct GreeterServer { + expected_sub: String, +} + +impl Service> for GreeterServer { + type Response = http::Response; + type Error = std::convert::Infallible; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: http::Request) -> Self::Future { + let token = req.extensions().get::>().unwrap(); + assert_eq!(token.claims.sub, self.expected_sub); + match req.uri().path() { + "/hello/SayHello" => Box::pin(async move { + let mut grpc = tonic::server::Grpc::new(tonic::codec::ProstCodec::default()); + Ok(grpc.unary(SayHelloMethod::default(), req).await) + }), + p => { + let p = p.to_string(); + Box::pin(async move { Ok(Status::unimplemented(p).to_http()) }) + } + } + } +} + +impl NamedService for GreeterServer { + const NAME: &'static str = "hello"; +} + +async fn app( + jwt_auth: JwtAuthorizer, + expected_sub: String, +) -> AsyncAuthorizationService { + let layer = jwt_auth.layer().await.unwrap(); + tonic::transport::Server::builder() + .layer(layer) + .add_service(GreeterServer { expected_sub }) + .into_service() +} + +fn init_test() { + INITIALIZED.call_once(|| { + tracing_subscriber::registry() + .with(tracing_subscriber::EnvFilter::new( + std::env::var("RUST_LOG").unwrap_or_else(|_| "info,jwt-authorizer=debug,tower_http=debug".into()), + )) + .with(tracing_subscriber::fmt::layer()) + .init(); + }); +} + +// The grpc client produces a http request with a tonic boxbody that the transport is meant to sent out, while the server side +// expects to receive a http request with a hyper body.. This simple wrapper converts from one to +// the other. +struct GrpcWrapper +where + S: Service> + Clone, +{ + inner: S, +} + +impl Service> for GrpcWrapper +where + S: Service> + Clone + Send + 'static, + S::Future: Send, +{ + type Response = S::Response; + type Error = S::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: http::Request) -> Self::Future { + let inner = self.inner.clone(); + // take the service that was ready + let mut inner = std::mem::replace(&mut self.inner, inner); + Box::pin(async move { + let (parts, mut body) = req.into_parts(); + let mut data = Vec::new(); + while let Some(d) = body.data().await { + let d = d.unwrap(); + data.extend_from_slice(&d) + } + inner + .call(http::Request::from_parts(parts, axum::body::Body::from(data))) + .await + }) + } +} + +async fn make_protected_request( + app: AsyncAuthorizationService, + bearer: Option<&str>, + message: &str, +) -> Result, Status> +where + S: Service< + http::Request, + Response = http::Response, + Error = tower::BoxError, + > + Send + + 'static, + S::Future: Send, +{ + let mut grpc = tonic::client::Grpc::new(GrpcWrapper { inner: app }); + + let mut request = HelloMessage { + message: message.to_string(), + } + .into_request(); + + if let Some(bearer) = bearer { + let headers = request.metadata_mut(); + headers.insert(AUTHORIZATION.as_str(), format!("Bearer {bearer}").parse().unwrap()); + } + + grpc.ready().await.unwrap(); + grpc.unary( + request, + http::uri::PathAndQuery::from_static("/hello/SayHello"), + tonic::codec::ProstCodec::default(), + ) + .await +} + +#[tokio::test] +async fn successfull_auth() { + init_test(); + let auth: JwtAuthorizer = JwtAuthorizer::from_rsa_pem("../config/rsa-public1.pem"); + let app = app(auth, "b@b.com".to_string()).await; + let r = make_protected_request(app.clone(), Some(JWT_RSA1_OK), "world").await.unwrap(); + assert_eq!(r.get_ref().message, "Hello, world"); +} + +#[tokio::test] +async fn wrong_token() { + init_test(); + let auth: JwtAuthorizer = JwtAuthorizer::from_rsa_pem("../config/rsa-public1.pem"); + let app = app(auth, "b@b.com".to_string()).await; + let status = make_protected_request(app.clone(), Some(JWT_RSA2_OK), "world") + .await + .unwrap_err(); + assert_eq!(status.code(), tonic::Code::Unauthenticated); +} + +#[tokio::test] +async fn no_token() { + init_test(); + let auth: JwtAuthorizer = JwtAuthorizer::from_rsa_pem("../config/rsa-public1.pem"); + let app = app(auth, "b@b.com".to_string()).await; + let status = make_protected_request(app.clone(), None, "world").await.unwrap_err(); + assert_eq!(status.code(), tonic::Code::Unauthenticated); +}