Add testing of resigned token and FullTokenValidator

This commit is contained in:
Jeffrey Dallatezza
2024-10-24 12:30:26 -07:00
parent 07d166ddbe
commit eed796f3c5
3 changed files with 122 additions and 78 deletions
+17 -4
View File
@@ -104,7 +104,7 @@ pub struct SpacetimeAuth {
use jsonwebtoken;
pub struct TokenClaims {
struct TokenClaims {
pub issuer: String,
pub subject: String,
pub audience: Vec<String>,
@@ -123,11 +123,15 @@ impl From<SpacetimeAuth> for TokenClaims {
impl TokenClaims {
// Compute the id from the issuer and subject.
pub fn id(&self) -> Identity {
fn id(&self) -> Identity {
Identity::from_claims(&self.issuer, &self.subject)
}
pub fn encode_and_sign_with_expiry(&self, private_key: &EncodingKey, expiry: Option<Duration>) -> Result<String, JwtError> {
fn encode_and_sign_with_expiry(
&self,
private_key: &EncodingKey,
expiry: Option<Duration>,
) -> Result<String, JwtError> {
let iat = SystemTime::now();
let exp = expiry.map(|dur| iat + dur);
let claims = SpacetimeIdentityClaims2 {
@@ -165,7 +169,12 @@ impl SpacetimeAuth {
SpacetimeCreds::from_signed_token(token)
};
Ok(Self { creds, identity, subject, issuer: ctx.local_issuer() })
Ok(Self {
creds,
identity,
subject,
issuer: ctx.local_issuer(),
})
}
/// Get the auth credentials as headers to be returned from an endpoint.
@@ -175,6 +184,10 @@ impl SpacetimeAuth {
TypedHeader(SpacetimeIdentityToken(self.creds)),
)
}
pub fn resign_with_expiry(&self, private_key: &EncodingKey, expiry: Duration) -> Result<String, JwtError> {
TokenClaims::from(self.clone()).encode_and_sign_with_expiry(private_key, Some(expiry))
}
}
#[cfg(test)]
+4 -4
View File
@@ -6,11 +6,10 @@ use http::header::CONTENT_TYPE;
use http::StatusCode;
use serde::{Deserialize, Serialize};
use spacetimedb::auth::identity::encode_token_with_expiry;
use spacetimedb_lib::de::serde::DeserializeWrapper;
use spacetimedb_lib::Identity;
use crate::auth::{SpacetimeAuth, SpacetimeAuthRequired, TokenClaims};
use crate::auth::{SpacetimeAuth, SpacetimeAuthRequired};
use crate::{log_and_500, ControlStateDelegate, NodeDelegate};
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -101,8 +100,9 @@ pub async fn create_websocket_token<S: NodeDelegate>(
SpacetimeAuthRequired(auth): SpacetimeAuthRequired,
) -> axum::response::Result<impl IntoResponse> {
let expiry = Duration::from_secs(60);
let claims: TokenClaims = TokenClaims::from(auth);
let token = claims.encode_and_sign_with_expiry(ctx.private_key(), Some(expiry)).map_err(log_and_500)?;
let token = auth
.resign_with_expiry(ctx.private_key(), expiry)
.map_err(log_and_500)?;
// let token = encode_token_with_expiry(ctx.private_key(), auth.identity, Some(expiry)).map_err(log_and_500)?;
Ok(axum::Json(WebsocketTokenResponse { token }))
}
+101 -70
View File
@@ -60,87 +60,53 @@ impl TokenValidator for UnimplementedTokenValidator {
}
}
/*
pub struct FullTokenValidator {
pub public_key: DecodingKey,
pub caching_validator: CachingOidcTokenValidator,
pub struct FullTokenValidator<T: TokenValidator + Send + Sync> {
pub local_key: DecodingKey,
pub local_issuer: String,
pub oidc_validator: T,
// pub caching_validator: CachingOidcTokenValidator,
}
#[async_trait]
impl TokenValidator for FullTokenValidator {
impl<T> TokenValidator for FullTokenValidator<T>
where
T: TokenValidator + Send + Sync,
{
async fn validate_token(&self, token: &str) -> Result<SpacetimeIdentityClaims2, TokenValidationError> {
let issuer = get_raw_issuer(token)?;
if issuer == "localhost" {
let claims = BasicTokenValidator {
public_key: self.public_key.clone(),
issuer,
let local_key_error = {
let first_validator = BasicTokenValidator {
public_key: self.local_key.clone(),
issuer: None,
};
match first_validator.validate_token(token).await {
Ok(claims) => return Ok(claims),
Err(e) => e,
}
.validate_token(token)
.await?;
return Ok(claims);
};
// If that fails, we try the OIDC validator.
let issuer = get_raw_issuer(token)?;
// If we are the issuer, then we should have already validated the token.
// TODO: "localhost" should not be hard-coded.
if issuer == self.local_issuer {
return Err(local_key_error);
}
self.caching_validator.validate_token(token).await
self.oidc_validator.validate_token(token).await
}
}
*/
// This is a helper function that uses a global JWK cache. We should remove this eventually, and make the server hold on to its own.
pub async fn validate_token(
local_key: DecodingKey,
local_issuer: &str,
token: &str,
) -> Result<SpacetimeIdentityClaims2, TokenValidationError> {
let local_key_error = {
let first_validator = BasicTokenValidator {
public_key: local_key.clone(),
issuer: None,
};
match first_validator.validate_token(token).await {
Ok(claims) => return Ok(claims),
Err(e) => e,
}
};
// If that fails, we try the OIDC validator.
let issuer = get_raw_issuer(token)?;
// If we are the issuer, then we should have already validated the token.
// TODO: "localhost" should not be hard-coded.
if issuer == local_issuer {
return Err(local_key_error);
}
GLOBAL_OIDC_VALIDATOR.clone().validate_token(token).await
}
pub struct InitialTestingTokenValidator {
pub public_key: DecodingKey,
}
#[async_trait]
impl TokenValidator for InitialTestingTokenValidator {
async fn validate_token(&self, token: &str) -> Result<SpacetimeIdentityClaims2, TokenValidationError> {
// Initially, we check if we signed the key.
let local_key_error = {
let first_validator = BasicTokenValidator {
public_key: self.public_key.clone(),
issuer: None,
};
match first_validator.validate_token(token).await {
Ok(claims) => return Ok(claims),
Err(e) => e,
}
};
// If that fails, we try the OIDC validator.
let issuer = get_raw_issuer(token)?;
// If we are the issuer, then we should have already validated the token.
// TODO: "localhost" should not be hard-coded.
if issuer == "localhost" {
return Err(local_key_error);
}
let validator = OidcTokenValidator;
validator.validate_token(token).await
}
let validator = FullTokenValidator {
local_key,
local_issuer: local_issuer.to_string(),
oidc_validator: GLOBAL_OIDC_VALIDATOR.clone(),
};
validator.validate_token(token).await
}
// This verifies against a given public key and expected issuer.
@@ -190,8 +156,7 @@ impl TokenValidator for BasicTokenValidator {
expected_issuer
)));
}
}
}
claims.try_into()
}
}
@@ -334,7 +299,7 @@ mod tests {
use crate::auth::identity::{IncomingClaims, SpacetimeIdentityClaims2};
use crate::auth::token_validation::{
CachingOidcTokenValidator, BasicTokenValidator, OidcTokenValidator, TokenValidator,
BasicTokenValidator, CachingOidcTokenValidator, FullTokenValidator, OidcTokenValidator, TokenValidator,
};
use jsonwebkey as jwk;
use jsonwebtoken::{DecodingKey, EncodingKey};
@@ -476,6 +441,61 @@ mod tests {
Ok(())
}
async fn assert_validation_fails<T: TokenValidator>(validator: &T, token: &str) -> anyhow::Result<()> {
let result = validator.validate_token(token).await;
if result.is_ok() {
let claims = result.unwrap();
anyhow::bail!("Validation succeeded when it should have failed: {:?}", claims);
}
Ok(())
}
#[tokio::test]
async fn resigned_token_ignores_issuer() -> anyhow::Result<()> {
// Test that the decoding key must work for LocalTokenValidator.
let kp = KeyPair::generate_p256()?;
let local_issuer = "test1";
let external_issuer = "other_issuer";
let subject = "test_subject";
let orig_claims = IncomingClaims {
identity: None,
subject: subject.to_string(),
issuer: external_issuer.to_string(),
audience: vec![],
iat: std::time::SystemTime::now(),
exp: None,
};
let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::ES256);
let token = jsonwebtoken::encode(&header, &orig_claims, &kp.private_key)?;
// First, try the successful case with the FullTokenValidator.
{
let validator = FullTokenValidator {
local_key: kp.public_key.clone(),
local_issuer: local_issuer.to_string(),
oidc_validator: OidcTokenValidator,
};
let parsed_claims: SpacetimeIdentityClaims2 = validator.validate_token(&token).await?;
assert_eq!(parsed_claims.issuer, external_issuer);
assert_eq!(parsed_claims.subject, subject);
assert_eq!(parsed_claims.identity, Identity::from_claims(external_issuer, subject));
}
// Double check that this token would fail with an OidcTokenValidator.
assert_validation_fails(&OidcTokenValidator, &token).await?;
// Double check that validation fails if we check the issuer.
assert_validation_fails(
&BasicTokenValidator {
public_key: kp.public_key.clone(),
issuer: Some(local_issuer.to_string()),
},
&token,
)
.await?;
Ok(())
}
use axum::routing::get;
use axum::Json;
use axum::Router;
@@ -601,4 +621,15 @@ mod tests {
let v = CachingOidcTokenValidator::get_default();
run_oidc_test(v).await
}
#[tokio::test]
async fn test_full_validator_fallback() -> anyhow::Result<()> {
let kp = KeyPair::generate_p256()?;
let v = FullTokenValidator {
local_key: kp.public_key.clone(),
local_issuer: "local_issuer".to_string(),
oidc_validator: OidcTokenValidator,
};
run_oidc_test(v).await
}
}