fix(claims): aud can a string or an array of strings

fixes #26
This commit is contained in:
cduvray 2023-07-06 07:47:22 +02:00 committed by cduvray
parent fc82bea5f4
commit 70ce996275
2 changed files with 21 additions and 38 deletions

View file

@ -1,7 +1,6 @@
use chrono::{DateTime, TimeZone, Utc}; use chrono::{DateTime, TimeZone, Utc};
use std::fmt;
use serde::{de, Deserialize, Deserializer}; use serde::Deserialize;
/// Seconds since the epoch /// Seconds since the epoch
#[derive(Deserialize, Clone, PartialEq, Eq, Debug)] #[derive(Deserialize, Clone, PartialEq, Eq, Debug)]
@ -13,8 +12,12 @@ impl From<NumericDate> for DateTime<Utc> {
} }
} }
#[derive(PartialEq, Debug, Clone)] #[derive(PartialEq, Debug, Clone, Deserialize)]
pub struct StringList(Vec<String>); #[serde(untagged)]
pub enum OneOrArray<T> {
One(T),
Array(Vec<T>),
}
/// Claims mentioned in the JWT specifications. /// Claims mentioned in the JWT specifications.
/// ///
@ -23,39 +26,13 @@ pub struct StringList(Vec<String>);
pub struct RegisteredClaims { pub struct RegisteredClaims {
pub iss: Option<String>, pub iss: Option<String>,
pub sub: Option<String>, pub sub: Option<String>,
pub aud: Option<StringList>, pub aud: Option<OneOrArray<String>>,
pub exp: Option<NumericDate>, pub exp: Option<NumericDate>,
pub nbf: Option<NumericDate>, pub nbf: Option<NumericDate>,
pub iat: Option<NumericDate>, pub iat: Option<NumericDate>,
pub jti: Option<String>, pub jti: Option<String>,
} }
impl<'de> Deserialize<'de> for StringList {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct StringListVisitor;
impl<'de> de::Visitor<'de> for StringListVisitor {
type Value = StringList;
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(formatter, "a space seperated strings")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
let auds: Vec<String> = v.split(' ').map(|s| s.to_owned()).collect();
Ok(StringList(auds))
}
}
deserializer.deserialize_string(StringListVisitor)
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
@ -63,17 +40,20 @@ mod tests {
use serde::Deserialize; use serde::Deserialize;
use serde_json::json; use serde_json::json;
use crate::claims::{NumericDate, RegisteredClaims, StringList}; use crate::claims::{NumericDate, OneOrArray, RegisteredClaims};
#[derive(Deserialize)] #[derive(Deserialize)]
struct TestStruct { struct TestStruct {
v: StringList, v: OneOrArray<String>,
} }
#[test] #[test]
fn rfc_claims_aud() { fn rfc_claims_aud() {
let a: TestStruct = serde_json::from_str(r#"{"v":"a b"}"#).unwrap(); let a: TestStruct = serde_json::from_str(r#"{"v":"a"}"#).unwrap();
assert_eq!(a.v, StringList(vec!["a".to_owned(), "b".to_owned()])); assert_eq!(a.v, OneOrArray::One("a".to_owned()));
let a: TestStruct = serde_json::from_str(r#"{"v":["a", "b"]}"#).unwrap();
assert_eq!(a.v, OneOrArray::Array(vec!["a".to_owned(), "b".to_owned()]));
} }
#[test] #[test]
@ -87,7 +67,7 @@ mod tests {
fn rfc_claims() { fn rfc_claims() {
let jwt_json = json!({ let jwt_json = json!({
"iss": "http://localhost:3001", "iss": "http://localhost:3001",
"aud": "aud1 aud2", "aud": ["aud1", "aud2"],
"sub": "bob", "sub": "bob",
"exp": 1516240122, "exp": 1516240122,
"iat": 1516239022, "iat": 1516239022,
@ -96,7 +76,10 @@ mod tests {
let claims: RegisteredClaims = serde_json::from_value(jwt_json).expect("Failed RfcClaims deserialisation"); let claims: RegisteredClaims = serde_json::from_value(jwt_json).expect("Failed RfcClaims deserialisation");
assert_eq!(claims.iss.unwrap(), "http://localhost:3001"); assert_eq!(claims.iss.unwrap(), "http://localhost:3001");
assert_eq!(claims.aud.unwrap(), StringList(vec!["aud1".to_owned(), "aud2".to_owned()])); assert_eq!(
claims.aud.unwrap(),
OneOrArray::Array(vec!["aud1".to_owned(), "aud2".to_owned()])
);
assert_eq!(claims.exp.unwrap(), NumericDate(1516240122)); assert_eq!(claims.exp.unwrap(), NumericDate(1516240122));
let dt: DateTime<Utc> = claims.iat.unwrap().into(); let dt: DateTime<Utc> = claims.iat.unwrap().into();

View file

@ -6,7 +6,7 @@ use jsonwebtoken::TokenData;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
pub use self::error::AuthError; pub use self::error::AuthError;
pub use claims::{NumericDate, RegisteredClaims, StringList}; pub use claims::{NumericDate, OneOrArray, RegisteredClaims};
pub use jwks::key_store_manager::{Refresh, RefreshStrategy}; pub use jwks::key_store_manager::{Refresh, RefreshStrategy};
pub use layer::JwtAuthorizer; pub use layer::JwtAuthorizer;
pub use validation::Validation; pub use validation::Validation;