use axum::{ async_trait, extract::{FromRequestParts, TypedHeader}, headers::{authorization::Bearer, Authorization}, http::{request::Parts, StatusCode}, response::{IntoResponse, Response}, routing::{get, post}, Json, Router, }; use josekit::jwk::{ alg::{ec::EcCurve, ec::EcKeyPair, ed::EdKeyPair, rsa::RsaKeyPair}, Jwk, }; use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation}; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use std::{fmt::Display, net::SocketAddr, thread, time::Duration}; pub static KEYS: Lazy = Lazy::new(|| { //let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set"); // Keys::new("xxxxx".as_bytes()) Keys::load_rsa() }); const ISSUER_URI: &str = "http://localhost:3001"; pub struct Keys { pub alg: Algorithm, pub encoding: EncodingKey, pub decoding: DecodingKey, } impl Keys { fn load_rsa() -> Self { Self { alg: Algorithm::RS256, encoding: EncodingKey::from_rsa_pem(include_bytes!("../../../config/jwtRS256.key")).unwrap(), decoding: DecodingKey::from_rsa_pem(include_bytes!("../../../config/jwtRS256.key.pub")).unwrap(), } } } /// OpenId Connect discovery (simplified for test purposes) #[derive(Serialize, Clone)] pub struct OidcDiscovery { issuer: String, jwks_uri: String, authorization_endpoint: String, } pub async fn discovery() -> Json { let d = OidcDiscovery { issuer: ISSUER_URI.to_owned(), jwks_uri: format!("{}/jwks", ISSUER_URI), authorization_endpoint: format!("{}/authorize", ISSUER_URI), }; Json(json!(d)) } #[derive(Debug, Eq, PartialEq, Clone, Deserialize, Serialize)] struct JwkSet { keys: Vec, } pub async fn jwks() -> Json { // let mut ksmap = serde_json::Map::new(); let mut kset = JwkSet { keys: Vec::::new() }; let keypair = RsaKeyPair::from_pem(include_bytes!("../../../config/jwtRS256.key")).unwrap(); let mut pk = keypair.to_jwk_public_key(); pk.set_key_id("key-rsa"); pk.set_algorithm("RS256"); pk.set_key_use("sig"); kset.keys.push(pk); let keypair = RsaKeyPair::from_pem(include_bytes!("../../../config/private_rsa_key_pkcs8.pem")).unwrap(); let mut pk = keypair.to_jwk_public_key(); pk.set_key_id("rsa01"); pk.set_algorithm("RS256"); pk.set_key_use("sig"); kset.keys.push(pk); let keypair = EcKeyPair::from_pem(include_bytes!("../../../config/ec256-private.pem"), Some(EcCurve::P256)).unwrap(); let mut pk = keypair.to_jwk_public_key(); pk.set_key_id("key-ec"); pk.set_algorithm("ES256"); pk.set_key_use("sig"); kset.keys.push(pk); let keypair = EcKeyPair::from_pem(include_bytes!("../../../config/private_ecdsa_key.pem"), Some(EcCurve::P256)).unwrap(); let mut pk = keypair.to_jwk_public_key(); pk.set_key_id("ec01"); pk.set_algorithm("ES256"); pk.set_key_use("sig"); kset.keys.push(pk); let keypair = EdKeyPair::from_pem(include_bytes!("../../../config/ed25519-private.pem")).unwrap(); let mut pk = keypair.to_jwk_public_key(); pk.set_key_id("key-ed"); pk.set_algorithm("EdDSA"); pk.set_key_use("sig"); kset.keys.push(pk); let keypair = EdKeyPair::from_pem(include_bytes!("../../../config/private_ed25519_key.pem")).unwrap(); let mut pk = keypair.to_jwk_public_key(); pk.set_key_id("ed01"); pk.set_algorithm("EdDSA"); pk.set_key_use("sig"); kset.keys.push(pk); Json(json!(kset)) } fn build_header(alg: Algorithm, kid: &str) -> Header { Header { typ: Some("JWT".to_string()), alg, kid: Some(kid.to_owned()), cty: None, jku: None, jwk: None, x5u: None, x5c: None, x5t: None, x5t_s256: None, } } /// issues test tokens (this is not a standard endpoint) pub async fn tokens() -> Json { let claims = Claims { sub: "b@b.com".to_owned(), exp: 2000000000, // May 2033 }; let rsa_key = EncodingKey::from_rsa_pem(include_bytes!("../../../config/jwtRS256.key")).unwrap(); let ec_key = EncodingKey::from_ec_pem(include_bytes!("../../../config/ec256-private.pem")).unwrap(); let ed_key = EncodingKey::from_ed_pem(include_bytes!("../../../config/ed25519-private.pem")).unwrap(); let rsa_token = encode(&build_header(Algorithm::RS256, "key-rsa"), &claims, &rsa_key).unwrap(); let ec_token = encode(&build_header(Algorithm::ES256, "key-ec"), &claims, &ec_key).unwrap(); let ed_token = encode(&build_header(Algorithm::EdDSA, "key-ed"), &claims, &ed_key).unwrap(); Json(json!({ "rsa": rsa_token, "ec": ec_token, "ed": ed_token })) } pub async fn authorize(Json(payload): Json) -> Result, AuthError> { tracing::info!("authorizing ..."); if payload.client_id.is_empty() || payload.client_secret.is_empty() { return Err(AuthError::MissingCredentials); } // Here you can check the user credentials from a database if payload.client_id != "foo" || payload.client_secret != "bar" { return Err(AuthError::WrongCredentials); } let claims = Claims { sub: "b@b.com".to_owned(), // Mandatory expiry time as UTC timestamp exp: 2000000000, // May 2033 }; // Create the authorization token let token = encode(&Header::new(KEYS.alg), &claims, &KEYS.encoding).map_err(|_| AuthError::TokenCreation)?; // Send the authorized token Ok(Json(AuthBody::new(token))) } /// exposes oidc "like" endpoints (this is for test purposes) pub fn run_server() { // oidc "like" endpoints for test purposes let app = Router::new() .route("/.well-known/openid-configuration", get(discovery)) .route("/authorize", post(authorize)) .route("/jwks", get(jwks)) .route("/tokens", get(tokens)); tokio::spawn(async move { let addr = SocketAddr::from(([127, 0, 0, 1], 3001)); tracing::info!("oidc provider starting on: {}", addr); axum::Server::bind(&addr).serve(app.into_make_service()).await.unwrap(); }); thread::sleep(Duration::from_millis(200)); // waiting oidc to start } #[derive(Debug, Serialize, Deserialize)] struct Claims { sub: String, exp: usize, } #[derive(Debug, Serialize)] pub struct AuthBody { access_token: String, token_type: String, } #[derive(Debug, Deserialize)] pub struct AuthPayload { client_id: String, client_secret: String, } #[derive(Debug)] pub enum AuthError { WrongCredentials, MissingCredentials, TokenCreation, InvalidToken, } impl Display for Claims { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "sub: {}", self.sub) } } impl AuthBody { fn new(access_token: String) -> Self { Self { access_token, token_type: "Bearer".to_string(), } } } impl IntoResponse for AuthError { fn into_response(self) -> Response { let (status, error_message) = match self { AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"), AuthError::MissingCredentials => (StatusCode::BAD_REQUEST, "Missing credentials"), AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"), AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"), }; let body = Json(json!({ "error": error_message, })); (status, body).into_response() } } #[async_trait] impl FromRequestParts for Claims where S: Send + Sync, { type Rejection = AuthError; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { // Extract the token from the authorization header let TypedHeader(Authorization(bearer)) = TypedHeader::>::from_request_parts(parts, state) .await .map_err(|_| AuthError::InvalidToken)?; let token_data = decode::(bearer.token(), &KEYS.decoding, &Validation::default()).map_err(|_| AuthError::InvalidToken)?; Ok(token_data.claims) } }