mirror of
https://github.com/TECHNOFAB11/jwt-authorizer.git
synced 2025-12-11 23:50:07 +01:00
refactor: demo server (clean, refactor, docs)
This commit is contained in:
parent
43f2523ec6
commit
6ff5d88ae9
5 changed files with 56 additions and 170 deletions
|
|
@ -1,13 +1,20 @@
|
|||
use axum::{routing::get, Router};
|
||||
use jwt_authorizer::{error::InitError, AuthError, JwtAuthorizer, JwtClaims, Refresh, RefreshStrategy};
|
||||
use serde::Deserialize;
|
||||
use std::{fmt::Display, net::SocketAddr};
|
||||
use std::net::SocketAddr;
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tracing::info;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
mod oidc_provider;
|
||||
|
||||
/// Object representing claims
|
||||
/// (a subset of deserialized claims)
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
struct User {
|
||||
sub: String,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), InitError> {
|
||||
tracing_subscriber::registry()
|
||||
|
|
@ -17,6 +24,7 @@ async fn main() -> Result<(), InitError> {
|
|||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
|
||||
// claims checker function
|
||||
fn claim_checker(u: &User) -> bool {
|
||||
info!("checking claims: {} -> {}", u.sub, u.sub.contains('@'));
|
||||
|
||||
|
|
@ -24,12 +32,12 @@ async fn main() -> Result<(), InitError> {
|
|||
}
|
||||
|
||||
// starting oidc provider (discovery is needed by from_oidc())
|
||||
oidc_provider::run_server();
|
||||
let issuer_uri = oidc_provider::run_server();
|
||||
|
||||
// First let's create an authorizer builder from a JWKS Endpoint
|
||||
// First let's create an authorizer builder from a Oidc Discovery
|
||||
// User is a struct deserializable from JWT claims representing the authorized user
|
||||
// let jwt_auth: JwtAuthorizer<User> = JwtAuthorizer::from_oidc("https://accounts.google.com/")
|
||||
let jwt_auth: JwtAuthorizer<User> = JwtAuthorizer::from_oidc("http://localhost:3001")
|
||||
let jwt_auth: JwtAuthorizer<User> = JwtAuthorizer::from_oidc(issuer_uri)
|
||||
// .no_refresh()
|
||||
.refresh(Refresh {
|
||||
strategy: RefreshStrategy::Interval,
|
||||
|
|
@ -42,10 +50,11 @@ async fn main() -> Result<(), InitError> {
|
|||
.route("/protected", get(protected))
|
||||
// adding the authorizer layer
|
||||
.layer(jwt_auth.layer().await?);
|
||||
// .layer(jwt_auth.check_claims(|_: User| true));
|
||||
|
||||
let app = Router::new()
|
||||
// actual protected apis
|
||||
// public endpoint
|
||||
.route("/public", get(public_handler))
|
||||
// protected APIs
|
||||
.nest("/api", api)
|
||||
.layer(TraceLayer::new_for_http());
|
||||
|
||||
|
|
@ -57,18 +66,13 @@ async fn main() -> Result<(), InitError> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// handler with injected claims object
|
||||
async fn protected(JwtClaims(user): JwtClaims<User>) -> Result<String, AuthError> {
|
||||
// Send the protected data to the user
|
||||
Ok(format!("Welcome: {}", user.sub))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
struct User {
|
||||
sub: String,
|
||||
}
|
||||
|
||||
impl Display for User {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "User: {:?}", self.sub)
|
||||
}
|
||||
// public url handler
|
||||
async fn public_handler() -> &'static str {
|
||||
"Public URL!"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,59 +1,27 @@
|
|||
use axum::{
|
||||
async_trait,
|
||||
extract::{FromRequestParts, TypedHeader},
|
||||
headers::{authorization::Bearer, Authorization},
|
||||
http::{request::Parts, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use axum::{routing::get, Json, Router};
|
||||
use josekit::jwk::{
|
||||
alg::{ec::EcCurve, ec::EcKeyPair, ed::EdKeyPair, rsa::RsaKeyPair},
|
||||
Jwk,
|
||||
};
|
||||
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
|
||||
use once_cell::sync::Lazy;
|
||||
use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use std::{fmt::Display, net::SocketAddr, thread, time::Duration};
|
||||
|
||||
pub static KEYS: Lazy<Keys> = Lazy::new(|| {
|
||||
//let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
|
||||
// Keys::new("xxxxx".as_bytes())
|
||||
Keys::load_rsa()
|
||||
});
|
||||
use std::{net::SocketAddr, thread, time::Duration};
|
||||
|
||||
const ISSUER_URI: &str = "http://localhost:3001";
|
||||
|
||||
pub struct Keys {
|
||||
pub alg: Algorithm,
|
||||
pub encoding: EncodingKey,
|
||||
pub decoding: DecodingKey,
|
||||
}
|
||||
|
||||
impl Keys {
|
||||
fn load_rsa() -> Self {
|
||||
Self {
|
||||
alg: Algorithm::RS256,
|
||||
encoding: EncodingKey::from_rsa_pem(include_bytes!("../../../config/jwtRS256.key")).unwrap(),
|
||||
decoding: DecodingKey::from_rsa_pem(include_bytes!("../../../config/jwtRS256.key.pub")).unwrap(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// OpenId Connect discovery (simplified for test purposes)
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct OidcDiscovery {
|
||||
struct OidcDiscovery {
|
||||
issuer: String,
|
||||
jwks_uri: String,
|
||||
authorization_endpoint: String,
|
||||
}
|
||||
|
||||
pub async fn discovery() -> Json<Value> {
|
||||
/// discovery url handler
|
||||
async fn discovery() -> Json<Value> {
|
||||
let d = OidcDiscovery {
|
||||
issuer: ISSUER_URI.to_owned(),
|
||||
jwks_uri: format!("{}/jwks", ISSUER_URI),
|
||||
authorization_endpoint: format!("{}/authorize", ISSUER_URI),
|
||||
jwks_uri: format!("{ISSUER_URI}/jwks"),
|
||||
};
|
||||
Json(json!(d))
|
||||
}
|
||||
|
|
@ -63,9 +31,8 @@ struct JwkSet {
|
|||
keys: Vec<Jwk>,
|
||||
}
|
||||
|
||||
pub async fn jwks() -> Json<Value> {
|
||||
// let mut ksmap = serde_json::Map::new();
|
||||
|
||||
/// jwk set endpoint handler
|
||||
async fn jwks() -> Json<Value> {
|
||||
let mut kset = JwkSet { keys: Vec::<Jwk>::new() };
|
||||
|
||||
let keypair = RsaKeyPair::from_pem(include_bytes!("../../../config/jwtRS256.key")).unwrap();
|
||||
|
|
@ -113,6 +80,7 @@ pub async fn jwks() -> Json<Value> {
|
|||
Json(json!(kset))
|
||||
}
|
||||
|
||||
/// build a minimal JWT header
|
||||
fn build_header(alg: Algorithm, kid: &str) -> Header {
|
||||
Header {
|
||||
typ: Some("JWT".to_string()),
|
||||
|
|
@ -128,11 +96,22 @@ fn build_header(alg: Algorithm, kid: &str) -> Header {
|
|||
}
|
||||
}
|
||||
|
||||
/// issues test tokens (this is not a standard endpoint)
|
||||
/// token claims
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Claims {
|
||||
iss: &'static str,
|
||||
sub: &'static str,
|
||||
exp: usize,
|
||||
nbf: usize,
|
||||
}
|
||||
|
||||
/// handler issuing test tokens (this is not a standard endpoint)
|
||||
pub async fn tokens() -> Json<Value> {
|
||||
let claims = Claims {
|
||||
sub: "b@b.com".to_owned(),
|
||||
iss: ISSUER_URI,
|
||||
sub: "b@b.com",
|
||||
exp: 2000000000, // May 2033
|
||||
nbf: 1516239022, // Jan 2018
|
||||
};
|
||||
|
||||
let rsa_key = EncodingKey::from_rsa_pem(include_bytes!("../../../config/jwtRS256.key")).unwrap();
|
||||
|
|
@ -150,33 +129,10 @@ pub async fn tokens() -> Json<Value> {
|
|||
}))
|
||||
}
|
||||
|
||||
pub async fn authorize(Json(payload): Json<AuthPayload>) -> Result<Json<AuthBody>, AuthError> {
|
||||
tracing::info!("authorizing ...");
|
||||
if payload.client_id.is_empty() || payload.client_secret.is_empty() {
|
||||
return Err(AuthError::MissingCredentials);
|
||||
}
|
||||
// Here you can check the user credentials from a database
|
||||
if payload.client_id != "foo" || payload.client_secret != "bar" {
|
||||
return Err(AuthError::WrongCredentials);
|
||||
}
|
||||
let claims = Claims {
|
||||
sub: "b@b.com".to_owned(),
|
||||
// Mandatory expiry time as UTC timestamp
|
||||
exp: 2000000000, // May 2033
|
||||
};
|
||||
// Create the authorization token
|
||||
let token = encode(&Header::new(KEYS.alg), &claims, &KEYS.encoding).map_err(|_| AuthError::TokenCreation)?;
|
||||
|
||||
// Send the authorized token
|
||||
Ok(Json(AuthBody::new(token)))
|
||||
}
|
||||
|
||||
/// exposes oidc "like" endpoints (this is for test purposes)
|
||||
pub fn run_server() {
|
||||
// oidc "like" endpoints for test purposes
|
||||
/// exposes some oidc "like" endpoints for test purposes
|
||||
pub fn run_server() -> &'static str {
|
||||
let app = Router::new()
|
||||
.route("/.well-known/openid-configuration", get(discovery))
|
||||
.route("/authorize", post(authorize))
|
||||
.route("/jwks", get(jwks))
|
||||
.route("/tokens", get(tokens));
|
||||
|
||||
|
|
@ -187,79 +143,6 @@ pub fn run_server() {
|
|||
});
|
||||
|
||||
thread::sleep(Duration::from_millis(200)); // waiting oidc to start
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Claims {
|
||||
sub: String,
|
||||
exp: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct AuthBody {
|
||||
access_token: String,
|
||||
token_type: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AuthPayload {
|
||||
client_id: String,
|
||||
client_secret: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum AuthError {
|
||||
WrongCredentials,
|
||||
MissingCredentials,
|
||||
TokenCreation,
|
||||
InvalidToken,
|
||||
}
|
||||
|
||||
impl Display for Claims {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "sub: {}", self.sub)
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthBody {
|
||||
fn new(access_token: String) -> Self {
|
||||
Self {
|
||||
access_token,
|
||||
token_type: "Bearer".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for AuthError {
|
||||
fn into_response(self) -> Response {
|
||||
let (status, error_message) = match self {
|
||||
AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
|
||||
AuthError::MissingCredentials => (StatusCode::BAD_REQUEST, "Missing credentials"),
|
||||
AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
|
||||
AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"),
|
||||
};
|
||||
let body = Json(json!({
|
||||
"error": error_message,
|
||||
}));
|
||||
(status, body).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S> FromRequestParts<S> for Claims
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = AuthError;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||
// Extract the token from the authorization header
|
||||
let TypedHeader(Authorization(bearer)) = TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
|
||||
.await
|
||||
.map_err(|_| AuthError::InvalidToken)?;
|
||||
let token_data =
|
||||
decode::<Claims>(bearer.token(), &KEYS.decoding, &Validation::default()).map_err(|_| AuthError::InvalidToken)?;
|
||||
|
||||
Ok(token_data.claims)
|
||||
}
|
||||
|
||||
ISSUER_URI
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue