mirror of
https://github.com/clockworklabs/SpacetimeDB.git
synced 2026-05-20 14:42:10 -04:00
Add testing of resigned token and FullTokenValidator
This commit is contained in:
@@ -104,7 +104,7 @@ pub struct SpacetimeAuth {
|
||||
|
||||
use jsonwebtoken;
|
||||
|
||||
pub struct TokenClaims {
|
||||
struct TokenClaims {
|
||||
pub issuer: String,
|
||||
pub subject: String,
|
||||
pub audience: Vec<String>,
|
||||
@@ -123,11 +123,15 @@ impl From<SpacetimeAuth> for TokenClaims {
|
||||
|
||||
impl TokenClaims {
|
||||
// Compute the id from the issuer and subject.
|
||||
pub fn id(&self) -> Identity {
|
||||
fn id(&self) -> Identity {
|
||||
Identity::from_claims(&self.issuer, &self.subject)
|
||||
}
|
||||
|
||||
pub fn encode_and_sign_with_expiry(&self, private_key: &EncodingKey, expiry: Option<Duration>) -> Result<String, JwtError> {
|
||||
fn encode_and_sign_with_expiry(
|
||||
&self,
|
||||
private_key: &EncodingKey,
|
||||
expiry: Option<Duration>,
|
||||
) -> Result<String, JwtError> {
|
||||
let iat = SystemTime::now();
|
||||
let exp = expiry.map(|dur| iat + dur);
|
||||
let claims = SpacetimeIdentityClaims2 {
|
||||
@@ -165,7 +169,12 @@ impl SpacetimeAuth {
|
||||
SpacetimeCreds::from_signed_token(token)
|
||||
};
|
||||
|
||||
Ok(Self { creds, identity, subject, issuer: ctx.local_issuer() })
|
||||
Ok(Self {
|
||||
creds,
|
||||
identity,
|
||||
subject,
|
||||
issuer: ctx.local_issuer(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the auth credentials as headers to be returned from an endpoint.
|
||||
@@ -175,6 +184,10 @@ impl SpacetimeAuth {
|
||||
TypedHeader(SpacetimeIdentityToken(self.creds)),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn resign_with_expiry(&self, private_key: &EncodingKey, expiry: Duration) -> Result<String, JwtError> {
|
||||
TokenClaims::from(self.clone()).encode_and_sign_with_expiry(private_key, Some(expiry))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -6,11 +6,10 @@ use http::header::CONTENT_TYPE;
|
||||
use http::StatusCode;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use spacetimedb::auth::identity::encode_token_with_expiry;
|
||||
use spacetimedb_lib::de::serde::DeserializeWrapper;
|
||||
use spacetimedb_lib::Identity;
|
||||
|
||||
use crate::auth::{SpacetimeAuth, SpacetimeAuthRequired, TokenClaims};
|
||||
use crate::auth::{SpacetimeAuth, SpacetimeAuthRequired};
|
||||
use crate::{log_and_500, ControlStateDelegate, NodeDelegate};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -101,8 +100,9 @@ pub async fn create_websocket_token<S: NodeDelegate>(
|
||||
SpacetimeAuthRequired(auth): SpacetimeAuthRequired,
|
||||
) -> axum::response::Result<impl IntoResponse> {
|
||||
let expiry = Duration::from_secs(60);
|
||||
let claims: TokenClaims = TokenClaims::from(auth);
|
||||
let token = claims.encode_and_sign_with_expiry(ctx.private_key(), Some(expiry)).map_err(log_and_500)?;
|
||||
let token = auth
|
||||
.resign_with_expiry(ctx.private_key(), expiry)
|
||||
.map_err(log_and_500)?;
|
||||
// let token = encode_token_with_expiry(ctx.private_key(), auth.identity, Some(expiry)).map_err(log_and_500)?;
|
||||
Ok(axum::Json(WebsocketTokenResponse { token }))
|
||||
}
|
||||
|
||||
@@ -60,87 +60,53 @@ impl TokenValidator for UnimplementedTokenValidator {
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
pub struct FullTokenValidator {
|
||||
pub public_key: DecodingKey,
|
||||
pub caching_validator: CachingOidcTokenValidator,
|
||||
pub struct FullTokenValidator<T: TokenValidator + Send + Sync> {
|
||||
pub local_key: DecodingKey,
|
||||
pub local_issuer: String,
|
||||
pub oidc_validator: T,
|
||||
// pub caching_validator: CachingOidcTokenValidator,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TokenValidator for FullTokenValidator {
|
||||
impl<T> TokenValidator for FullTokenValidator<T>
|
||||
where
|
||||
T: TokenValidator + Send + Sync,
|
||||
{
|
||||
async fn validate_token(&self, token: &str) -> Result<SpacetimeIdentityClaims2, TokenValidationError> {
|
||||
let issuer = get_raw_issuer(token)?;
|
||||
if issuer == "localhost" {
|
||||
let claims = BasicTokenValidator {
|
||||
public_key: self.public_key.clone(),
|
||||
issuer,
|
||||
let local_key_error = {
|
||||
let first_validator = BasicTokenValidator {
|
||||
public_key: self.local_key.clone(),
|
||||
issuer: None,
|
||||
};
|
||||
match first_validator.validate_token(token).await {
|
||||
Ok(claims) => return Ok(claims),
|
||||
Err(e) => e,
|
||||
}
|
||||
.validate_token(token)
|
||||
.await?;
|
||||
return Ok(claims);
|
||||
};
|
||||
|
||||
// If that fails, we try the OIDC validator.
|
||||
let issuer = get_raw_issuer(token)?;
|
||||
// If we are the issuer, then we should have already validated the token.
|
||||
// TODO: "localhost" should not be hard-coded.
|
||||
if issuer == self.local_issuer {
|
||||
return Err(local_key_error);
|
||||
}
|
||||
self.caching_validator.validate_token(token).await
|
||||
self.oidc_validator.validate_token(token).await
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
// This is a helper function that uses a global JWK cache. We should remove this eventually, and make the server hold on to its own.
|
||||
pub async fn validate_token(
|
||||
local_key: DecodingKey,
|
||||
local_issuer: &str,
|
||||
token: &str,
|
||||
) -> Result<SpacetimeIdentityClaims2, TokenValidationError> {
|
||||
let local_key_error = {
|
||||
|
||||
let first_validator = BasicTokenValidator {
|
||||
public_key: local_key.clone(),
|
||||
issuer: None,
|
||||
};
|
||||
match first_validator.validate_token(token).await {
|
||||
Ok(claims) => return Ok(claims),
|
||||
Err(e) => e,
|
||||
}
|
||||
};
|
||||
|
||||
// If that fails, we try the OIDC validator.
|
||||
let issuer = get_raw_issuer(token)?;
|
||||
// If we are the issuer, then we should have already validated the token.
|
||||
// TODO: "localhost" should not be hard-coded.
|
||||
if issuer == local_issuer {
|
||||
return Err(local_key_error);
|
||||
}
|
||||
GLOBAL_OIDC_VALIDATOR.clone().validate_token(token).await
|
||||
}
|
||||
|
||||
pub struct InitialTestingTokenValidator {
|
||||
pub public_key: DecodingKey,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TokenValidator for InitialTestingTokenValidator {
|
||||
async fn validate_token(&self, token: &str) -> Result<SpacetimeIdentityClaims2, TokenValidationError> {
|
||||
// Initially, we check if we signed the key.
|
||||
let local_key_error = {
|
||||
|
||||
let first_validator = BasicTokenValidator {
|
||||
public_key: self.public_key.clone(),
|
||||
issuer: None,
|
||||
};
|
||||
match first_validator.validate_token(token).await {
|
||||
Ok(claims) => return Ok(claims),
|
||||
Err(e) => e,
|
||||
}
|
||||
};
|
||||
|
||||
// If that fails, we try the OIDC validator.
|
||||
let issuer = get_raw_issuer(token)?;
|
||||
// If we are the issuer, then we should have already validated the token.
|
||||
// TODO: "localhost" should not be hard-coded.
|
||||
if issuer == "localhost" {
|
||||
return Err(local_key_error);
|
||||
}
|
||||
let validator = OidcTokenValidator;
|
||||
validator.validate_token(token).await
|
||||
}
|
||||
let validator = FullTokenValidator {
|
||||
local_key,
|
||||
local_issuer: local_issuer.to_string(),
|
||||
oidc_validator: GLOBAL_OIDC_VALIDATOR.clone(),
|
||||
};
|
||||
validator.validate_token(token).await
|
||||
}
|
||||
|
||||
// This verifies against a given public key and expected issuer.
|
||||
@@ -190,8 +156,7 @@ impl TokenValidator for BasicTokenValidator {
|
||||
expected_issuer
|
||||
)));
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
claims.try_into()
|
||||
}
|
||||
}
|
||||
@@ -334,7 +299,7 @@ mod tests {
|
||||
|
||||
use crate::auth::identity::{IncomingClaims, SpacetimeIdentityClaims2};
|
||||
use crate::auth::token_validation::{
|
||||
CachingOidcTokenValidator, BasicTokenValidator, OidcTokenValidator, TokenValidator,
|
||||
BasicTokenValidator, CachingOidcTokenValidator, FullTokenValidator, OidcTokenValidator, TokenValidator,
|
||||
};
|
||||
use jsonwebkey as jwk;
|
||||
use jsonwebtoken::{DecodingKey, EncodingKey};
|
||||
@@ -476,6 +441,61 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn assert_validation_fails<T: TokenValidator>(validator: &T, token: &str) -> anyhow::Result<()> {
|
||||
let result = validator.validate_token(token).await;
|
||||
if result.is_ok() {
|
||||
let claims = result.unwrap();
|
||||
anyhow::bail!("Validation succeeded when it should have failed: {:?}", claims);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn resigned_token_ignores_issuer() -> anyhow::Result<()> {
|
||||
// Test that the decoding key must work for LocalTokenValidator.
|
||||
let kp = KeyPair::generate_p256()?;
|
||||
let local_issuer = "test1";
|
||||
let external_issuer = "other_issuer";
|
||||
let subject = "test_subject";
|
||||
|
||||
let orig_claims = IncomingClaims {
|
||||
identity: None,
|
||||
subject: subject.to_string(),
|
||||
issuer: external_issuer.to_string(),
|
||||
audience: vec![],
|
||||
iat: std::time::SystemTime::now(),
|
||||
exp: None,
|
||||
};
|
||||
let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::ES256);
|
||||
let token = jsonwebtoken::encode(&header, &orig_claims, &kp.private_key)?;
|
||||
|
||||
// First, try the successful case with the FullTokenValidator.
|
||||
{
|
||||
let validator = FullTokenValidator {
|
||||
local_key: kp.public_key.clone(),
|
||||
local_issuer: local_issuer.to_string(),
|
||||
oidc_validator: OidcTokenValidator,
|
||||
};
|
||||
|
||||
let parsed_claims: SpacetimeIdentityClaims2 = validator.validate_token(&token).await?;
|
||||
assert_eq!(parsed_claims.issuer, external_issuer);
|
||||
assert_eq!(parsed_claims.subject, subject);
|
||||
assert_eq!(parsed_claims.identity, Identity::from_claims(external_issuer, subject));
|
||||
}
|
||||
// Double check that this token would fail with an OidcTokenValidator.
|
||||
assert_validation_fails(&OidcTokenValidator, &token).await?;
|
||||
// Double check that validation fails if we check the issuer.
|
||||
assert_validation_fails(
|
||||
&BasicTokenValidator {
|
||||
public_key: kp.public_key.clone(),
|
||||
issuer: Some(local_issuer.to_string()),
|
||||
},
|
||||
&token,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
use axum::routing::get;
|
||||
use axum::Json;
|
||||
use axum::Router;
|
||||
@@ -601,4 +621,15 @@ mod tests {
|
||||
let v = CachingOidcTokenValidator::get_default();
|
||||
run_oidc_test(v).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_full_validator_fallback() -> anyhow::Result<()> {
|
||||
let kp = KeyPair::generate_p256()?;
|
||||
let v = FullTokenValidator {
|
||||
local_key: kp.public_key.clone(),
|
||||
local_issuer: "local_issuer".to_string(),
|
||||
oidc_validator: OidcTokenValidator,
|
||||
};
|
||||
run_oidc_test(v).await
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user