diff --git a/jwt-authorizer/src/authorizer.rs b/jwt-authorizer/src/authorizer.rs index c4c8a3f..00d0840 100644 --- a/jwt-authorizer/src/authorizer.rs +++ b/jwt-authorizer/src/authorizer.rs @@ -64,7 +64,14 @@ where Authorizer { key_source: KeySource::SingleKeySource(Arc::new(KeyData { kid: None, - alg: vec![Algorithm::RS256, Algorithm::RS384, Algorithm::RS512], + algs: vec![ + Algorithm::RS256, + Algorithm::RS384, + Algorithm::RS512, + Algorithm::PS256, + Algorithm::PS384, + Algorithm::PS512, + ], key, })), claims_checker, @@ -77,7 +84,14 @@ where Authorizer { key_source: KeySource::SingleKeySource(Arc::new(KeyData { kid: None, - alg: vec![Algorithm::RS256, Algorithm::RS384, Algorithm::RS512], + algs: vec![ + Algorithm::RS256, + Algorithm::RS384, + Algorithm::RS512, + Algorithm::PS256, + Algorithm::PS384, + Algorithm::PS512, + ], key, })), claims_checker, @@ -90,7 +104,7 @@ where Authorizer { key_source: KeySource::SingleKeySource(Arc::new(KeyData { kid: None, - alg: vec![Algorithm::ES256, Algorithm::ES384], + algs: vec![Algorithm::ES256, Algorithm::ES384], key, })), claims_checker, @@ -103,7 +117,7 @@ where Authorizer { key_source: KeySource::SingleKeySource(Arc::new(KeyData { kid: None, - alg: vec![Algorithm::ES256, Algorithm::ES384], + algs: vec![Algorithm::ES256, Algorithm::ES384], key, })), claims_checker, @@ -116,7 +130,7 @@ where Authorizer { key_source: KeySource::SingleKeySource(Arc::new(KeyData { kid: None, - alg: vec![Algorithm::EdDSA], + algs: vec![Algorithm::EdDSA], key, })), claims_checker, @@ -129,7 +143,7 @@ where Authorizer { key_source: KeySource::SingleKeySource(Arc::new(KeyData { kid: None, - alg: vec![Algorithm::EdDSA], + algs: vec![Algorithm::EdDSA], key, })), claims_checker, @@ -142,7 +156,7 @@ where Authorizer { key_source: KeySource::SingleKeySource(Arc::new(KeyData { kid: None, - alg: vec![Algorithm::HS256, Algorithm::HS384, Algorithm::HS512], + algs: vec![Algorithm::HS256, Algorithm::HS384, Algorithm::HS512], key, })), claims_checker, @@ -214,7 +228,7 @@ where let header = decode_header(token)?; // TODO: (optimisation) build & store jwt_validation in key data, to avoid rebuilding it for each check let val_key = self.key_source.get_key(header).await?; - let jwt_validation = &self.validation.to_jwt_validation(val_key.alg.clone()); + let jwt_validation = &self.validation.to_jwt_validation(val_key.algs.clone()); let token_data = decode::(token, &val_key.key, jwt_validation)?; if let Some(ref checker) = self.claims_checker { diff --git a/jwt-authorizer/src/jwks/mod.rs b/jwt-authorizer/src/jwks/mod.rs index 4d07f06..f745df6 100644 --- a/jwt-authorizer/src/jwks/mod.rs +++ b/jwt-authorizer/src/jwks/mod.rs @@ -1,6 +1,9 @@ use std::{str::FromStr, sync::Arc}; -use jsonwebtoken::{jwk::Jwk, Algorithm, DecodingKey, Header}; +use jsonwebtoken::{ + jwk::{AlgorithmParameters, Jwk}, + Algorithm, DecodingKey, Header, +}; use crate::error::AuthError; @@ -21,15 +24,40 @@ pub enum KeySource { #[derive(Clone)] pub struct KeyData { pub kid: Option, - pub alg: Vec, + /// valid algorithms + pub algs: Vec, pub key: DecodingKey, } +fn get_valid_algs(key: &Jwk) -> Vec { + if let Some(key_alg) = key.common.key_algorithm { + // if alg is not correct => no valid algs => empty array + Algorithm::from_str(key_alg.to_string().as_str()).map_or(vec![], |a| vec![a]) + } else { + // guessing valid algs from key structure + match key.algorithm { + AlgorithmParameters::EllipticCurve(_) => { + vec![Algorithm::ES256, Algorithm::ES384] + } + AlgorithmParameters::RSA(_) => vec![ + Algorithm::RS256, + Algorithm::RS384, + Algorithm::RS512, + Algorithm::PS256, + Algorithm::PS384, + Algorithm::PS512, + ], + AlgorithmParameters::OctetKey(_) => vec![Algorithm::EdDSA], + AlgorithmParameters::OctetKeyPair(_) => vec![Algorithm::HS256, Algorithm::HS384, Algorithm::HS512], + } + } +} + impl KeyData { pub fn from_jwk(key: &Jwk) -> Result { Ok(KeyData { kid: key.common.key_id.clone(), - alg: vec![Algorithm::from_str(key.common.key_algorithm.unwrap().to_string().as_str())?], + algs: get_valid_algs(key), key: DecodingKey::from_jwk(key)?, }) } @@ -55,7 +83,7 @@ impl KeySet { /// Find the key in the set that matches the given key id, if any. pub fn find_alg(&self, alg: &Algorithm) -> Option<&Arc> { - self.0.iter().find(|k| k.alg.contains(alg)) + self.0.iter().find(|k| k.algs.contains(alg)) } /// Find first key.