refactor: demo server (clean, refactor, docs)

This commit is contained in:
cduvray 2023-01-30 23:44:25 +01:00
parent 43f2523ec6
commit 6ff5d88ae9
5 changed files with 56 additions and 170 deletions

View file

@ -1,8 +1,5 @@
### ### Public URL
POST http://localhost:3000/oidc/authorize GET http://localhost:3000/public
Content-Type: application/json
{"client_id":"foo","client_secret":"bar"}
### Protected RSA ### Protected RSA
GET http://localhost:3000/api/protected GET http://localhost:3000/api/protected
@ -19,7 +16,11 @@ GET http://localhost:3000/api/protected
Content-Type: application/json Content-Type: application/json
Authorization: Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJFZERTQSIsImtpZCI6ImtleS1lZCJ9.eyJzdWIiOiJiQGIuY29tIiwiZXhwIjoyMDAwMDAwMDAwfQ.XAx9msioheXEH1XUEIWMHGBg25JOpBHqcgL_ou_S3fwVht2TbKRiDZ4G6ZyEtn57hCbOy250zTD_g0EbaMGwAg Authorization: Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJFZERTQSIsImtpZCI6ImtleS1lZCJ9.eyJzdWIiOiJiQGIuY29tIiwiZXhwIjoyMDAwMDAwMDAwfQ.XAx9msioheXEH1XUEIWMHGBg25JOpBHqcgL_ou_S3fwVht2TbKRiDZ4G6ZyEtn57hCbOy250zTD_g0EbaMGwAg
### ### 401 (no token)
GET http://localhost:3000/api/protected
Content-Type: application/json
### 401 (invalid_token)
GET http://localhost:3000/api/protected GET http://localhost:3000/api/protected
Content-Type: application/json Content-Type: application/json
Authorization: Bearer blablabla.xxxx.xxxx Authorization: Bearer blablabla.xxxx.xxxx
@ -28,12 +29,10 @@ Authorization: Bearer blablabla.xxxx.xxxx
GET http://localhost:3001/.well-known/openid-configuration GET http://localhost:3001/.well-known/openid-configuration
Content-Type: application/json Content-Type: application/json
### jwks ### jwks
GET http://localhost:3001/jwks GET http://localhost:3001/jwks
Content-Type: application/json Content-Type: application/json
### Test tokens
###
GET http://localhost:3001/tokens GET http://localhost:3001/tokens
Content-Type: application/json Content-Type: application/json

View file

@ -1,13 +1,20 @@
use axum::{routing::get, Router}; use axum::{routing::get, Router};
use jwt_authorizer::{error::InitError, AuthError, JwtAuthorizer, JwtClaims, Refresh, RefreshStrategy}; use jwt_authorizer::{error::InitError, AuthError, JwtAuthorizer, JwtClaims, Refresh, RefreshStrategy};
use serde::Deserialize; use serde::Deserialize;
use std::{fmt::Display, net::SocketAddr}; use std::net::SocketAddr;
use tower_http::trace::TraceLayer; use tower_http::trace::TraceLayer;
use tracing::info; use tracing::info;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
mod oidc_provider; mod oidc_provider;
/// Object representing claims
/// (a subset of deserialized claims)
#[derive(Debug, Deserialize, Clone)]
struct User {
sub: String,
}
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), InitError> { async fn main() -> Result<(), InitError> {
tracing_subscriber::registry() tracing_subscriber::registry()
@ -17,6 +24,7 @@ async fn main() -> Result<(), InitError> {
.with(tracing_subscriber::fmt::layer()) .with(tracing_subscriber::fmt::layer())
.init(); .init();
// claims checker function
fn claim_checker(u: &User) -> bool { fn claim_checker(u: &User) -> bool {
info!("checking claims: {} -> {}", u.sub, u.sub.contains('@')); info!("checking claims: {} -> {}", u.sub, u.sub.contains('@'));
@ -24,12 +32,12 @@ async fn main() -> Result<(), InitError> {
} }
// starting oidc provider (discovery is needed by from_oidc()) // starting oidc provider (discovery is needed by from_oidc())
oidc_provider::run_server(); let issuer_uri = oidc_provider::run_server();
// First let's create an authorizer builder from a JWKS Endpoint // First let's create an authorizer builder from a Oidc Discovery
// User is a struct deserializable from JWT claims representing the authorized user // User is a struct deserializable from JWT claims representing the authorized user
// let jwt_auth: JwtAuthorizer<User> = JwtAuthorizer::from_oidc("https://accounts.google.com/") // let jwt_auth: JwtAuthorizer<User> = JwtAuthorizer::from_oidc("https://accounts.google.com/")
let jwt_auth: JwtAuthorizer<User> = JwtAuthorizer::from_oidc("http://localhost:3001") let jwt_auth: JwtAuthorizer<User> = JwtAuthorizer::from_oidc(issuer_uri)
// .no_refresh() // .no_refresh()
.refresh(Refresh { .refresh(Refresh {
strategy: RefreshStrategy::Interval, strategy: RefreshStrategy::Interval,
@ -42,10 +50,11 @@ async fn main() -> Result<(), InitError> {
.route("/protected", get(protected)) .route("/protected", get(protected))
// adding the authorizer layer // adding the authorizer layer
.layer(jwt_auth.layer().await?); .layer(jwt_auth.layer().await?);
// .layer(jwt_auth.check_claims(|_: User| true));
let app = Router::new() let app = Router::new()
// actual protected apis // public endpoint
.route("/public", get(public_handler))
// protected APIs
.nest("/api", api) .nest("/api", api)
.layer(TraceLayer::new_for_http()); .layer(TraceLayer::new_for_http());
@ -57,18 +66,13 @@ async fn main() -> Result<(), InitError> {
Ok(()) Ok(())
} }
/// handler with injected claims object
async fn protected(JwtClaims(user): JwtClaims<User>) -> Result<String, AuthError> { async fn protected(JwtClaims(user): JwtClaims<User>) -> Result<String, AuthError> {
// Send the protected data to the user // Send the protected data to the user
Ok(format!("Welcome: {}", user.sub)) Ok(format!("Welcome: {}", user.sub))
} }
#[derive(Debug, Deserialize, Clone)] // public url handler
struct User { async fn public_handler() -> &'static str {
sub: String, "Public URL!"
}
impl Display for User {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "User: {:?}", self.sub)
}
} }

View file

@ -1,59 +1,27 @@
use axum::{ use axum::{routing::get, Json, Router};
async_trait,
extract::{FromRequestParts, TypedHeader},
headers::{authorization::Bearer, Authorization},
http::{request::Parts, StatusCode},
response::{IntoResponse, Response},
routing::{get, post},
Json, Router,
};
use josekit::jwk::{ use josekit::jwk::{
alg::{ec::EcCurve, ec::EcKeyPair, ed::EdKeyPair, rsa::RsaKeyPair}, alg::{ec::EcCurve, ec::EcKeyPair, ed::EdKeyPair, rsa::RsaKeyPair},
Jwk, Jwk,
}; };
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation}; use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{json, Value}; use serde_json::{json, Value};
use std::{fmt::Display, net::SocketAddr, thread, time::Duration}; use std::{net::SocketAddr, thread, time::Duration};
pub static KEYS: Lazy<Keys> = Lazy::new(|| {
//let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
// Keys::new("xxxxx".as_bytes())
Keys::load_rsa()
});
const ISSUER_URI: &str = "http://localhost:3001"; const ISSUER_URI: &str = "http://localhost:3001";
pub struct Keys {
pub alg: Algorithm,
pub encoding: EncodingKey,
pub decoding: DecodingKey,
}
impl Keys {
fn load_rsa() -> Self {
Self {
alg: Algorithm::RS256,
encoding: EncodingKey::from_rsa_pem(include_bytes!("../../../config/jwtRS256.key")).unwrap(),
decoding: DecodingKey::from_rsa_pem(include_bytes!("../../../config/jwtRS256.key.pub")).unwrap(),
}
}
}
/// OpenId Connect discovery (simplified for test purposes) /// OpenId Connect discovery (simplified for test purposes)
#[derive(Serialize, Clone)] #[derive(Serialize, Clone)]
pub struct OidcDiscovery { struct OidcDiscovery {
issuer: String, issuer: String,
jwks_uri: String, jwks_uri: String,
authorization_endpoint: String,
} }
pub async fn discovery() -> Json<Value> { /// discovery url handler
async fn discovery() -> Json<Value> {
let d = OidcDiscovery { let d = OidcDiscovery {
issuer: ISSUER_URI.to_owned(), issuer: ISSUER_URI.to_owned(),
jwks_uri: format!("{}/jwks", ISSUER_URI), jwks_uri: format!("{ISSUER_URI}/jwks"),
authorization_endpoint: format!("{}/authorize", ISSUER_URI),
}; };
Json(json!(d)) Json(json!(d))
} }
@ -63,9 +31,8 @@ struct JwkSet {
keys: Vec<Jwk>, keys: Vec<Jwk>,
} }
pub async fn jwks() -> Json<Value> { /// jwk set endpoint handler
// let mut ksmap = serde_json::Map::new(); async fn jwks() -> Json<Value> {
let mut kset = JwkSet { keys: Vec::<Jwk>::new() }; let mut kset = JwkSet { keys: Vec::<Jwk>::new() };
let keypair = RsaKeyPair::from_pem(include_bytes!("../../../config/jwtRS256.key")).unwrap(); let keypair = RsaKeyPair::from_pem(include_bytes!("../../../config/jwtRS256.key")).unwrap();
@ -113,6 +80,7 @@ pub async fn jwks() -> Json<Value> {
Json(json!(kset)) Json(json!(kset))
} }
/// build a minimal JWT header
fn build_header(alg: Algorithm, kid: &str) -> Header { fn build_header(alg: Algorithm, kid: &str) -> Header {
Header { Header {
typ: Some("JWT".to_string()), typ: Some("JWT".to_string()),
@ -128,11 +96,22 @@ fn build_header(alg: Algorithm, kid: &str) -> Header {
} }
} }
/// issues test tokens (this is not a standard endpoint) /// token claims
#[derive(Debug, Serialize, Deserialize)]
struct Claims {
iss: &'static str,
sub: &'static str,
exp: usize,
nbf: usize,
}
/// handler issuing test tokens (this is not a standard endpoint)
pub async fn tokens() -> Json<Value> { pub async fn tokens() -> Json<Value> {
let claims = Claims { let claims = Claims {
sub: "b@b.com".to_owned(), iss: ISSUER_URI,
sub: "b@b.com",
exp: 2000000000, // May 2033 exp: 2000000000, // May 2033
nbf: 1516239022, // Jan 2018
}; };
let rsa_key = EncodingKey::from_rsa_pem(include_bytes!("../../../config/jwtRS256.key")).unwrap(); let rsa_key = EncodingKey::from_rsa_pem(include_bytes!("../../../config/jwtRS256.key")).unwrap();
@ -150,33 +129,10 @@ pub async fn tokens() -> Json<Value> {
})) }))
} }
pub async fn authorize(Json(payload): Json<AuthPayload>) -> Result<Json<AuthBody>, AuthError> { /// exposes some oidc "like" endpoints for test purposes
tracing::info!("authorizing ..."); pub fn run_server() -> &'static str {
if payload.client_id.is_empty() || payload.client_secret.is_empty() {
return Err(AuthError::MissingCredentials);
}
// Here you can check the user credentials from a database
if payload.client_id != "foo" || payload.client_secret != "bar" {
return Err(AuthError::WrongCredentials);
}
let claims = Claims {
sub: "b@b.com".to_owned(),
// Mandatory expiry time as UTC timestamp
exp: 2000000000, // May 2033
};
// Create the authorization token
let token = encode(&Header::new(KEYS.alg), &claims, &KEYS.encoding).map_err(|_| AuthError::TokenCreation)?;
// Send the authorized token
Ok(Json(AuthBody::new(token)))
}
/// exposes oidc "like" endpoints (this is for test purposes)
pub fn run_server() {
// oidc "like" endpoints for test purposes
let app = Router::new() let app = Router::new()
.route("/.well-known/openid-configuration", get(discovery)) .route("/.well-known/openid-configuration", get(discovery))
.route("/authorize", post(authorize))
.route("/jwks", get(jwks)) .route("/jwks", get(jwks))
.route("/tokens", get(tokens)); .route("/tokens", get(tokens));
@ -187,79 +143,6 @@ pub fn run_server() {
}); });
thread::sleep(Duration::from_millis(200)); // waiting oidc to start thread::sleep(Duration::from_millis(200)); // waiting oidc to start
}
ISSUER_URI
#[derive(Debug, Serialize, Deserialize)]
struct Claims {
sub: String,
exp: usize,
}
#[derive(Debug, Serialize)]
pub struct AuthBody {
access_token: String,
token_type: String,
}
#[derive(Debug, Deserialize)]
pub struct AuthPayload {
client_id: String,
client_secret: String,
}
#[derive(Debug)]
pub enum AuthError {
WrongCredentials,
MissingCredentials,
TokenCreation,
InvalidToken,
}
impl Display for Claims {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "sub: {}", self.sub)
}
}
impl AuthBody {
fn new(access_token: String) -> Self {
Self {
access_token,
token_type: "Bearer".to_string(),
}
}
}
impl IntoResponse for AuthError {
fn into_response(self) -> Response {
let (status, error_message) = match self {
AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
AuthError::MissingCredentials => (StatusCode::BAD_REQUEST, "Missing credentials"),
AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"),
};
let body = Json(json!({
"error": error_message,
}));
(status, body).into_response()
}
}
#[async_trait]
impl<S> FromRequestParts<S> for Claims
where
S: Send + Sync,
{
type Rejection = AuthError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
// Extract the token from the authorization header
let TypedHeader(Authorization(bearer)) = TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
.await
.map_err(|_| AuthError::InvalidToken)?;
let token_data =
decode::<Claims>(bearer.token(), &KEYS.decoding, &Validation::default()).map_err(|_| AuthError::InvalidToken)?;
Ok(token_data.claims)
}
} }

View file

@ -57,7 +57,7 @@ fn response_wwwauth(status: StatusCode, bearer: &str) -> Response<BoxBody> {
let h = if bearer.is_empty() { let h = if bearer.is_empty() {
"Bearer".to_owned() "Bearer".to_owned()
} else { } else {
format!("Bearer {}", bearer) format!("Bearer {bearer}")
}; };
res.headers_mut().insert(header::WWW_AUTHENTICATE, h.parse().unwrap()); res.headers_mut().insert(header::WWW_AUTHENTICATE, h.parse().unwrap());

View file

@ -9,7 +9,7 @@ pub struct OidcDiscovery {
} }
pub async fn discover_jwks(issuer: &str) -> Result<String, InitError> { pub async fn discover_jwks(issuer: &str) -> Result<String, InitError> {
let discovery_url = format!("{}/.well-known/openid-configuration", issuer); let discovery_url = format!("{issuer}/.well-known/openid-configuration");
reqwest::Client::new() reqwest::Client::new()
.get(discovery_url) .get(discovery_url)
.send() .send()