From eed796f3c5b41e4e166323d3685a493a10488e03 Mon Sep 17 00:00:00 2001 From: Jeffrey Dallatezza Date: Thu, 24 Oct 2024 12:30:26 -0700 Subject: [PATCH] Add testing of resigned token and FullTokenValidator --- crates/client-api/src/auth.rs | 21 ++- crates/client-api/src/routes/identity.rs | 8 +- crates/core/src/auth/token_validation.rs | 171 +++++++++++++---------- 3 files changed, 122 insertions(+), 78 deletions(-) diff --git a/crates/client-api/src/auth.rs b/crates/client-api/src/auth.rs index 8cfd3ad68b..952aea20b3 100644 --- a/crates/client-api/src/auth.rs +++ b/crates/client-api/src/auth.rs @@ -104,7 +104,7 @@ pub struct SpacetimeAuth { use jsonwebtoken; -pub struct TokenClaims { +struct TokenClaims { pub issuer: String, pub subject: String, pub audience: Vec, @@ -123,11 +123,15 @@ impl From for TokenClaims { impl TokenClaims { // Compute the id from the issuer and subject. - pub fn id(&self) -> Identity { + fn id(&self) -> Identity { Identity::from_claims(&self.issuer, &self.subject) } - pub fn encode_and_sign_with_expiry(&self, private_key: &EncodingKey, expiry: Option) -> Result { + fn encode_and_sign_with_expiry( + &self, + private_key: &EncodingKey, + expiry: Option, + ) -> Result { let iat = SystemTime::now(); let exp = expiry.map(|dur| iat + dur); let claims = SpacetimeIdentityClaims2 { @@ -165,7 +169,12 @@ impl SpacetimeAuth { SpacetimeCreds::from_signed_token(token) }; - Ok(Self { creds, identity, subject, issuer: ctx.local_issuer() }) + Ok(Self { + creds, + identity, + subject, + issuer: ctx.local_issuer(), + }) } /// Get the auth credentials as headers to be returned from an endpoint. @@ -175,6 +184,10 @@ impl SpacetimeAuth { TypedHeader(SpacetimeIdentityToken(self.creds)), ) } + + pub fn resign_with_expiry(&self, private_key: &EncodingKey, expiry: Duration) -> Result { + TokenClaims::from(self.clone()).encode_and_sign_with_expiry(private_key, Some(expiry)) + } } #[cfg(test)] diff --git a/crates/client-api/src/routes/identity.rs b/crates/client-api/src/routes/identity.rs index 717ebefaee..48c375fe63 100644 --- a/crates/client-api/src/routes/identity.rs +++ b/crates/client-api/src/routes/identity.rs @@ -6,11 +6,10 @@ use http::header::CONTENT_TYPE; use http::StatusCode; use serde::{Deserialize, Serialize}; -use spacetimedb::auth::identity::encode_token_with_expiry; use spacetimedb_lib::de::serde::DeserializeWrapper; use spacetimedb_lib::Identity; -use crate::auth::{SpacetimeAuth, SpacetimeAuthRequired, TokenClaims}; +use crate::auth::{SpacetimeAuth, SpacetimeAuthRequired}; use crate::{log_and_500, ControlStateDelegate, NodeDelegate}; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -101,8 +100,9 @@ pub async fn create_websocket_token( SpacetimeAuthRequired(auth): SpacetimeAuthRequired, ) -> axum::response::Result { let expiry = Duration::from_secs(60); - let claims: TokenClaims = TokenClaims::from(auth); - let token = claims.encode_and_sign_with_expiry(ctx.private_key(), Some(expiry)).map_err(log_and_500)?; + let token = auth + .resign_with_expiry(ctx.private_key(), expiry) + .map_err(log_and_500)?; // let token = encode_token_with_expiry(ctx.private_key(), auth.identity, Some(expiry)).map_err(log_and_500)?; Ok(axum::Json(WebsocketTokenResponse { token })) } diff --git a/crates/core/src/auth/token_validation.rs b/crates/core/src/auth/token_validation.rs index 710f6a6611..aee937c4e3 100644 --- a/crates/core/src/auth/token_validation.rs +++ b/crates/core/src/auth/token_validation.rs @@ -60,87 +60,53 @@ impl TokenValidator for UnimplementedTokenValidator { } } -/* -pub struct FullTokenValidator { - pub public_key: DecodingKey, - pub caching_validator: CachingOidcTokenValidator, +pub struct FullTokenValidator { + pub local_key: DecodingKey, + pub local_issuer: String, + pub oidc_validator: T, + // pub caching_validator: CachingOidcTokenValidator, } #[async_trait] -impl TokenValidator for FullTokenValidator { +impl TokenValidator for FullTokenValidator +where + T: TokenValidator + Send + Sync, +{ async fn validate_token(&self, token: &str) -> Result { - let issuer = get_raw_issuer(token)?; - if issuer == "localhost" { - let claims = BasicTokenValidator { - public_key: self.public_key.clone(), - issuer, + let local_key_error = { + let first_validator = BasicTokenValidator { + public_key: self.local_key.clone(), + issuer: None, + }; + match first_validator.validate_token(token).await { + Ok(claims) => return Ok(claims), + Err(e) => e, } - .validate_token(token) - .await?; - return Ok(claims); + }; + + // If that fails, we try the OIDC validator. + let issuer = get_raw_issuer(token)?; + // If we are the issuer, then we should have already validated the token. + // TODO: "localhost" should not be hard-coded. + if issuer == self.local_issuer { + return Err(local_key_error); } - self.caching_validator.validate_token(token).await + self.oidc_validator.validate_token(token).await } } - */ +// This is a helper function that uses a global JWK cache. We should remove this eventually, and make the server hold on to its own. pub async fn validate_token( local_key: DecodingKey, local_issuer: &str, token: &str, ) -> Result { - let local_key_error = { - - let first_validator = BasicTokenValidator { - public_key: local_key.clone(), - issuer: None, - }; - match first_validator.validate_token(token).await { - Ok(claims) => return Ok(claims), - Err(e) => e, - } - }; - - // If that fails, we try the OIDC validator. - let issuer = get_raw_issuer(token)?; - // If we are the issuer, then we should have already validated the token. - // TODO: "localhost" should not be hard-coded. - if issuer == local_issuer { - return Err(local_key_error); - } - GLOBAL_OIDC_VALIDATOR.clone().validate_token(token).await -} - -pub struct InitialTestingTokenValidator { - pub public_key: DecodingKey, -} - -#[async_trait] -impl TokenValidator for InitialTestingTokenValidator { - async fn validate_token(&self, token: &str) -> Result { - // Initially, we check if we signed the key. - let local_key_error = { - - let first_validator = BasicTokenValidator { - public_key: self.public_key.clone(), - issuer: None, - }; - match first_validator.validate_token(token).await { - Ok(claims) => return Ok(claims), - Err(e) => e, - } - }; - - // If that fails, we try the OIDC validator. - let issuer = get_raw_issuer(token)?; - // If we are the issuer, then we should have already validated the token. - // TODO: "localhost" should not be hard-coded. - if issuer == "localhost" { - return Err(local_key_error); - } - let validator = OidcTokenValidator; - validator.validate_token(token).await - } + let validator = FullTokenValidator { + local_key, + local_issuer: local_issuer.to_string(), + oidc_validator: GLOBAL_OIDC_VALIDATOR.clone(), + }; + validator.validate_token(token).await } // This verifies against a given public key and expected issuer. @@ -190,8 +156,7 @@ impl TokenValidator for BasicTokenValidator { expected_issuer ))); } - - } + } claims.try_into() } } @@ -334,7 +299,7 @@ mod tests { use crate::auth::identity::{IncomingClaims, SpacetimeIdentityClaims2}; use crate::auth::token_validation::{ - CachingOidcTokenValidator, BasicTokenValidator, OidcTokenValidator, TokenValidator, + BasicTokenValidator, CachingOidcTokenValidator, FullTokenValidator, OidcTokenValidator, TokenValidator, }; use jsonwebkey as jwk; use jsonwebtoken::{DecodingKey, EncodingKey}; @@ -476,6 +441,61 @@ mod tests { Ok(()) } + async fn assert_validation_fails(validator: &T, token: &str) -> anyhow::Result<()> { + let result = validator.validate_token(token).await; + if result.is_ok() { + let claims = result.unwrap(); + anyhow::bail!("Validation succeeded when it should have failed: {:?}", claims); + } + Ok(()) + } + + #[tokio::test] + async fn resigned_token_ignores_issuer() -> anyhow::Result<()> { + // Test that the decoding key must work for LocalTokenValidator. + let kp = KeyPair::generate_p256()?; + let local_issuer = "test1"; + let external_issuer = "other_issuer"; + let subject = "test_subject"; + + let orig_claims = IncomingClaims { + identity: None, + subject: subject.to_string(), + issuer: external_issuer.to_string(), + audience: vec![], + iat: std::time::SystemTime::now(), + exp: None, + }; + let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::ES256); + let token = jsonwebtoken::encode(&header, &orig_claims, &kp.private_key)?; + + // First, try the successful case with the FullTokenValidator. + { + let validator = FullTokenValidator { + local_key: kp.public_key.clone(), + local_issuer: local_issuer.to_string(), + oidc_validator: OidcTokenValidator, + }; + + let parsed_claims: SpacetimeIdentityClaims2 = validator.validate_token(&token).await?; + assert_eq!(parsed_claims.issuer, external_issuer); + assert_eq!(parsed_claims.subject, subject); + assert_eq!(parsed_claims.identity, Identity::from_claims(external_issuer, subject)); + } + // Double check that this token would fail with an OidcTokenValidator. + assert_validation_fails(&OidcTokenValidator, &token).await?; + // Double check that validation fails if we check the issuer. + assert_validation_fails( + &BasicTokenValidator { + public_key: kp.public_key.clone(), + issuer: Some(local_issuer.to_string()), + }, + &token, + ) + .await?; + Ok(()) + } + use axum::routing::get; use axum::Json; use axum::Router; @@ -601,4 +621,15 @@ mod tests { let v = CachingOidcTokenValidator::get_default(); run_oidc_test(v).await } + + #[tokio::test] + async fn test_full_validator_fallback() -> anyhow::Result<()> { + let kp = KeyPair::generate_p256()?; + let v = FullTokenValidator { + local_key: kp.public_key.clone(), + local_issuer: "local_issuer".to_string(), + oidc_validator: OidcTokenValidator, + }; + run_oidc_test(v).await + } }