diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 62edf2d64..84576981d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,7 +21,7 @@ jobs: include: - { runner: spacetimedb-runner, smoketest_args: --docker } - { runner: windows-latest, smoketest_args: --no-build-cli } - runner: [spacetimedb-runner, windows-latest] + runner: [ spacetimedb-runner, windows-latest ] runs-on: ${{ matrix.runner }} steps: - name: Find Git ref @@ -44,6 +44,10 @@ jobs: - uses: actions/setup-dotnet@v4 with: global-json-file: modules/global.json + - name: Install psql (Windows) + if: runner.os == 'Windows' + run: choco install psql -y --no-progress + shell: powershell - name: Build and start database (Linux) if: runner.os == 'Linux' run: docker compose up -d @@ -54,11 +58,13 @@ jobs: Start-Process target/debug/spacetimedb-cli.exe start cd modules # the sdk-manifests on windows-latest are messed up, so we need to update them - dotnet workload config --update-mode workload-set + dotnet workload config --update-mode manifests dotnet workload update - uses: actions/setup-python@v5 with: { python-version: '3.12' } if: runner.os == 'Windows' + - name: Install psycopg2 + run: python -m pip install psycopg2-binary - name: Run smoketests # Note: clear_database and replication only work in private run: python -m smoketests ${{ matrix.smoketest_args }} -x clear_database replication diff --git a/Cargo.lock b/Cargo.lock index 87c02c91a..92a4baf56 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -195,6 +195,12 @@ dependencies = [ "derive_arbitrary", ] +[[package]] +name = "array-init" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d62b7694a562cdf5a74227903507c56ab2cc8bdd1f781ed5cb4cf9c9f810bfc" + [[package]] name = "arrayref" version = "0.3.9" @@ -293,6 +299,30 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "aws-lc-rs" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b756939cb2f8dc900aa6dcd505e6e2428e9cae7ff7b028c49e3946efa70878" +dependencies = [ + "aws-lc-sys", + "untrusted 0.7.1", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.28.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa9b6986f250236c27e5a204062434a773a13243d2ffc2955f37bdba4c5c6a1" +dependencies = [ + "bindgen 0.69.5", + "cc", + "cmake", + "dunce", + "fs_extra", +] + [[package]] name = "axum" version = "0.7.9" @@ -429,6 +459,29 @@ dependencies = [ "serde", ] +[[package]] +name = "bindgen" +version = "0.69.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +dependencies = [ + "bitflags 2.9.0", + "cexpr", + "clang-sys", + "itertools 0.12.1", + "lazy_static", + "lazycell", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash 1.1.0", + "shlex", + "syn 2.0.101", + "which 4.4.2", +] + [[package]] name = "bindgen" version = "0.71.1" @@ -444,7 +497,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "rustc-hash", + "rustc-hash 2.1.1", "shlex", "syn 2.0.101", ] @@ -925,6 +978,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "cmake" +version = "0.1.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +dependencies = [ + "cc", +] + [[package]] name = "cobs" version = "0.2.3" @@ -1092,7 +1154,7 @@ dependencies = [ "hashbrown 0.14.5", "log", "regalloc2", - "rustc-hash", + "rustc-hash 2.1.1", "smallvec", "target-lexicon", ] @@ -1485,6 +1547,17 @@ dependencies = [ "serde", ] +[[package]] +name = "derive-new" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "derive_arbitrary" version = "1.4.1" @@ -1602,6 +1675,12 @@ dependencies = [ "shared_child", ] +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + [[package]] name = "educe" version = "0.4.23" @@ -3011,12 +3090,41 @@ dependencies = [ "spacetimedb 1.4.0", ] +[[package]] +name = "lazy-regex" +version = "3.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60c7310b93682b36b98fa7ea4de998d3463ccbebd94d935d6b48ba5b6ffa7126" +dependencies = [ + "lazy-regex-proc_macros", + "once_cell", + "regex-lite", +] + +[[package]] +name = "lazy-regex-proc_macros" +version = "3.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ba01db5ef81e17eb10a5e0f2109d1b3a3e29bac3070fdbd7d156bf7dbd206a1" +dependencies = [ + "proc-macro2", + "quote", + "regex", + "syn 2.0.101", +] + [[package]] name = "lazy_static" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + [[package]] name = "leb128" version = "0.2.5" @@ -3215,6 +3323,12 @@ dependencies = [ "digest", ] +[[package]] +name = "md5" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae960838283323069879657ca3de837e9f7bbb4c7bf6ea7f1b290d5e9476d2e0" + [[package]] name = "memchr" version = "2.7.4" @@ -3805,6 +3919,31 @@ dependencies = [ "postgres-types", ] +[[package]] +name = "pgwire" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddf403a6ee31cf7f2217b2bd8447cb13dbb6c268d7e81501bc78a4d3daafd294" +dependencies = [ + "async-trait", + "aws-lc-rs", + "bytes", + "chrono", + "derive-new", + "futures", + "hex", + "lazy-regex", + "md5", + "postgres-types", + "rand 0.9.1", + "rust_decimal", + "rustls-pki-types", + "thiserror 2.0.12", + "tokio", + "tokio-rustls", + "tokio-util", +] + [[package]] name = "phf" version = "0.11.3" @@ -3956,6 +4095,7 @@ version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613283563cd90e1dfc3518d548caee47e0e725455ed619881f5cf21f36de4b48" dependencies = [ + "array-init", "bytes", "chrono", "fallible-iterator 0.2.0", @@ -4445,7 +4585,7 @@ checksum = "12908dbeb234370af84d0579b9f68258a0f67e201412dd9a2814e6f45b2fc0f0" dependencies = [ "hashbrown 0.14.5", "log", - "rustc-hash", + "rustc-hash 2.1.1", "slice-group-by", "smallvec", ] @@ -4482,6 +4622,12 @@ dependencies = [ "regex-syntax 0.8.5", ] +[[package]] +name = "regex-lite" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" + [[package]] name = "regex-syntax" version = "0.6.29" @@ -4623,7 +4769,7 @@ dependencies = [ "cfg-if", "getrandom 0.2.16", "libc", - "untrusted", + "untrusted 0.9.0", "windows-sys 0.52.0", ] @@ -4734,6 +4880,12 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustc-hash" version = "2.1.1" @@ -4781,6 +4933,8 @@ version = "0.23.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321" dependencies = [ + "aws-lc-rs", + "log", "once_cell", "rustls-pki-types", "rustls-webpki", @@ -4821,9 +4975,10 @@ version = "0.103.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7149975849f1abb3832b246010ef62ccc80d3a76169517ada7188252b9cfb437" dependencies = [ + "aws-lc-rs", "ring", "rustls-pki-types", - "untrusted", + "untrusted 0.9.0", ] [[package]] @@ -5674,7 +5829,7 @@ dependencies = [ "regex", "reqwest 0.12.15", "rustc-demangle", - "rustc-hash", + "rustc-hash 2.1.1", "scopeguard", "semver", "serde", @@ -5960,6 +6115,24 @@ dependencies = [ "xdg", ] +[[package]] +name = "spacetimedb-pg" +version = "1.4.0" +dependencies = [ + "anyhow", + "async-trait", + "axum", + "futures", + "http 1.3.1", + "log", + "pgwire", + "spacetimedb-client-api", + "spacetimedb-client-api-messages", + "spacetimedb-lib 1.4.0", + "thiserror 1.0.69", + "tokio", +] + [[package]] name = "spacetimedb-physical-plan" version = "1.4.0" @@ -6207,6 +6380,7 @@ dependencies = [ "spacetimedb-datastore", "spacetimedb-lib 1.4.0", "spacetimedb-paths", + "spacetimedb-pg", "spacetimedb-schema", "spacetimedb-table", "tempfile", @@ -7397,6 +7571,12 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "untrusted" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" + [[package]] name = "untrusted" version = "0.9.0" @@ -7477,7 +7657,7 @@ version = "137.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ca393e2032ddba2a57169e15cac5d0a81cdb3d872a8886f4468bc0f486098d2" dependencies = [ - "bindgen", + "bindgen 0.71.1", "bitflags 2.9.0", "fslock", "gzip-header", diff --git a/Cargo.toml b/Cargo.toml index 76529fe1e..cb45e9d89 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ members = [ "crates/lib", "crates/metrics", "crates/paths", + "crates/pg", "crates/physical-plan", "crates/primitives", "crates/query", @@ -115,6 +116,7 @@ spacetimedb-lib = { path = "crates/lib", default-features = false, version = "1. spacetimedb-memory-usage = { path = "crates/memory-usage", version = "1.4.0", default-features = false } spacetimedb-metrics = { path = "crates/metrics", version = "1.4.0" } spacetimedb-paths = { path = "crates/paths", version = "1.4.0" } +spacetimedb-pg = { path = "crates/pg", version = "1.4.0" } spacetimedb-physical-plan = { path = "crates/physical-plan", version = "1.4.0" } spacetimedb-primitives = { path = "crates/primitives", version = "1.4.0" } spacetimedb-query = { path = "crates/query", version = "1.4.0" } @@ -214,6 +216,7 @@ paste = "1.0" percent-encoding = "2.3" petgraph = { version = "0.6.5", default-features = false } pin-project-lite = "0.2.9" +pgwire = { version = "0.32", features = ["server-api"] } postgres-types = "0.2.5" pretty_assertions = { version = "1.4", features = ["unstable"] } proc-macro2 = "1.0" diff --git a/crates/cli/src/subcommands/sql.rs b/crates/cli/src/subcommands/sql.rs index 12ad41a49..00a8d645e 100644 --- a/crates/cli/src/subcommands/sql.rs +++ b/crates/cli/src/subcommands/sql.rs @@ -10,6 +10,7 @@ use anyhow::Context; use clap::{Arg, ArgAction, ArgMatches}; use reqwest::RequestBuilder; use spacetimedb_lib::de::serde::SeedWrapper; +use spacetimedb_lib::sats::satn::PsqlClient; use spacetimedb_lib::sats::{satn, ProductType, ProductValue, Typespace}; pub fn cli() -> clap::Command { @@ -111,7 +112,7 @@ fn print_stmt_result( for (pos, result) in if_empty .into_iter() .chain(stmt_results.iter().map(|stmt_result| { - let (stats, table) = stmt_result_to_table(stmt_result)?; + let (stats, table) = stmt_result_to_table(PsqlClient::SpacetimeDB, stmt_result)?; anyhow::Ok(StmtResult { stats: with_stats.is_some().then_some(stats), @@ -157,12 +158,13 @@ pub(crate) async fn run_sql(builder: RequestBuilder, sql: &str, with_stats: bool Ok(()) } -fn stmt_result_to_table(stmt_result: &SqlStmtResult) -> anyhow::Result<(StmtStats, tabled::Table)> { +fn stmt_result_to_table(client: PsqlClient, stmt_result: &SqlStmtResult) -> anyhow::Result<(StmtStats, tabled::Table)> { let stats = StmtStats::from(stmt_result); let SqlStmtResult { schema, rows, .. } = stmt_result; let ty = Typespace::EMPTY.with_type(schema); let table = build_table( + client, schema, rows.iter().map(|row| from_json_seed(row.get(), SeedWrapper(ty))), )?; @@ -194,6 +196,7 @@ pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error /// Generates a [`tabled::Table`] from a schema and rows, using the style of a psql table. fn build_table( + client: PsqlClient, schema: &ProductType, rows: impl Iterator>, ) -> Result { @@ -211,6 +214,7 @@ fn build_table( let row = row?; builder.push_record(ty.with_values(&row).enumerate().map(|(idx, value)| { let ty = satn::PsqlType { + client, tuple: ty.ty(), field: &ty.ty().elements[idx], idx, @@ -446,8 +450,10 @@ Roundtrip time: 1.00ms"#, Ok(()) } - fn expect_psql_table(ty: &ProductType, rows: Vec, expected: &str) { - let table = build_table(ty, rows.into_iter().map(Ok::<_, ()>)).unwrap().to_string(); + fn expect_psql_table(client: PsqlClient, ty: &ProductType, rows: Vec, expected: &str) { + let table = build_table(client, ty, rows.into_iter().map(Ok::<_, ()>)) + .unwrap() + .to_string(); let mut table = table.split('\n').map(|x| x.trim_end()).join("\n"); table.insert(0, '\n'); assert_eq!(expected, table); @@ -476,14 +482,25 @@ Roundtrip time: 1.00ms"#, ]; expect_psql_table( + PsqlClient::SpacetimeDB, &kind, - vec![value], + vec![value.clone()], r#" column 0 | column 1 | column 2 | column 3 | column 4 | column 5 ----------+----------+--------------------------------------------------------------------+------------------------------------+---------------------------+----------- "a" | 0 | 0x0000000000000000000000000000000000000000000000000000000000000000 | 0x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | +0.000000"#, ); + expect_psql_table( + PsqlClient::Postgres, + &kind, + vec![value], + r#" + column 0 | column 1 | column 2 | column 3 | column 4 | column 5 +----------+----------+----------------------------------------------------------------------+--------------------------------------+-----------------------------+---------- + "a" | 0 | "0x0000000000000000000000000000000000000000000000000000000000000000" | "0x00000000000000000000000000000000" | "1970-01-01T00:00:00+00:00" | "P0D""#, + ); + // Check struct let kind: ProductType = [ ("bool", AlgebraicType::Bool), @@ -507,6 +524,7 @@ Roundtrip time: 1.00ms"#, ]; expect_psql_table( + PsqlClient::SpacetimeDB, &kind, vec![value.clone()], r#" @@ -515,12 +533,23 @@ Roundtrip time: 1.00ms"#, true | "This is spacetimedb" | 0x01020304050607 | 0x0000000000000000000000000000000000000000000000000000000000000000 | 0x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | +0.000000"#, ); + expect_psql_table( + PsqlClient::Postgres, + &kind, + vec![value.clone()], + r#" + bool | str | bytes | identity | connection_id | timestamp | duration +------+-----------------------+--------------------+----------------------------------------------------------------------+--------------------------------------+-----------------------------+---------- + true | "This is spacetimedb" | "0x01020304050607" | "0x0000000000000000000000000000000000000000000000000000000000000000" | "0x00000000000000000000000000000000" | "1970-01-01T00:00:00+00:00" | "P0D""#, + ); + // Check nested struct, tuple... let kind: ProductType = [(None, AlgebraicType::product(kind))].into(); let value = product![value.clone()]; expect_psql_table( + PsqlClient::SpacetimeDB, &kind, vec![value.clone()], r#" @@ -529,17 +558,38 @@ Roundtrip time: 1.00ms"#, (bool = true, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000000, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = +0.000000)"#, ); + expect_psql_table( + PsqlClient::Postgres, + &kind, + vec![value.clone()], + r#" + column 0 +--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + {"bool": true, "str": "This is spacetimedb", "bytes": "0x01020304050607", "identity": "0x0000000000000000000000000000000000000000000000000000000000000000", "connection_id": "0x00000000000000000000000000000000", "timestamp": "1970-01-01T00:00:00+00:00", "duration": "P0D"}"#, + ); + let kind: ProductType = [("tuple", AlgebraicType::product(kind))].into(); let value = product![value]; expect_psql_table( + PsqlClient::SpacetimeDB, + &kind, + vec![value.clone()], + r#" + tuple +-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + (col_0 = (bool = true, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000000, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = +0.000000))"#, + ); + + expect_psql_table( + PsqlClient::Postgres, &kind, vec![value], r#" tuple ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- - (0 = (bool = true, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000000, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = +0.000000))"#, +-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + {"col_0": {"bool": true, "str": "This is spacetimedb", "bytes": "0x01020304050607", "identity": "0x0000000000000000000000000000000000000000000000000000000000000000", "connection_id": "0x00000000000000000000000000000000", "timestamp": "1970-01-01T00:00:00+00:00", "duration": "P0D"}}"#, ); Ok(()) diff --git a/crates/client-api-messages/src/name.rs b/crates/client-api-messages/src/name.rs index 5fac594f9..49cee23a1 100644 --- a/crates/client-api-messages/src/name.rs +++ b/crates/client-api-messages/src/name.rs @@ -171,7 +171,7 @@ pub enum SetDefaultDomainResult { /// /// Must match the regex `^[a-z0-9]+(-[a-z0-9]+)*$` #[derive(Clone, Debug, serde_with::DeserializeFromStr, serde_with::SerializeDisplay)] -pub struct DatabaseName(String); +pub struct DatabaseName(pub String); impl AsRef for DatabaseName { fn as_ref(&self) -> &str { diff --git a/crates/client-api/src/auth.rs b/crates/client-api/src/auth.rs index 610316258..c616ff53e 100644 --- a/crates/client-api/src/auth.rs +++ b/crates/client-api/src/auth.rs @@ -222,6 +222,10 @@ impl TokenSigner for JwtKeyAuthProvider { impl JwtAuthProvider for JwtKeyAuthProvider { type TV = TV; + fn validator(&self) -> &Self::TV { + &self.validator + } + fn local_issuer(&self) -> &str { &self.local_issuer } @@ -229,10 +233,6 @@ impl JwtAuthProvider for JwtKeyAuthProvider &[u8] { &self.keys.public_pem } - - fn validator(&self) -> &Self::TV { - &self.validator - } } #[cfg(test)] @@ -260,6 +260,13 @@ mod tests { } } +pub async fn validate_token( + state: &S, + token: &str, +) -> Result { + state.jwt_auth_provider().validator().validate_token(token).await +} + pub struct SpacetimeAuthHeader { auth: Option, } @@ -272,10 +279,7 @@ impl axum::extract::FromRequestParts for Space return Ok(Self { auth: None }); }; - let claims = state - .jwt_auth_provider() - .validator() - .validate_token(&creds.token) + let claims = validate_token(state, &creds.token) .await .map_err(AuthorizationRejection::Custom)?; diff --git a/crates/client-api/src/routes/database.rs b/crates/client-api/src/routes/database.rs index 52b2458bf..fdee60f49 100644 --- a/crates/client-api/src/routes/database.rs +++ b/crates/client-api/src/routes/database.rs @@ -7,7 +7,7 @@ use crate::auth::{ SpacetimeIdentityToken, }; use crate::routes::subscribe::generate_random_connection_id; -use crate::util::{ByteStringBody, NameOrIdentity}; +pub use crate::util::{ByteStringBody, NameOrIdentity}; use crate::{log_and_500, ControlStateDelegate, DatabaseDef, NodeDelegate}; use axum::body::{Body, Bytes}; use axum::extract::{Path, Query, State}; @@ -31,7 +31,7 @@ use spacetimedb_client_api_messages::name::{ }; use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9; use spacetimedb_lib::identity::AuthCtx; -use spacetimedb_lib::{sats, Timestamp}; +use spacetimedb_lib::{sats, ProductValue, Timestamp}; use spacetimedb_schema::auto_migrate::{ MigrationPolicy as SchemaMigrationPolicy, MigrationToken, PrettyPrintStyle as AutoMigratePrettyPrintStyle, }; @@ -383,7 +383,7 @@ pub(crate) async fn worker_ctx_find_database( #[derive(Deserialize)] pub struct SqlParams { - name_or_identity: NameOrIdentity, + pub name_or_identity: NameOrIdentity, } #[derive(Deserialize)] @@ -391,16 +391,16 @@ pub struct SqlQueryParams { /// If `true`, return the query result only after its transaction offset /// is confirmed to be durable. #[serde(default)] - confirmed: bool, + pub confirmed: bool, } -pub async fn sql( - State(worker_ctx): State, - Path(SqlParams { name_or_identity }): Path, - Query(SqlQueryParams { confirmed }): Query, - Extension(auth): Extension, - body: String, -) -> axum::response::Result +pub async fn sql_direct( + worker_ctx: S, + SqlParams { name_or_identity }: SqlParams, + SqlQueryParams { confirmed }: SqlQueryParams, + caller_identity: Identity, + sql: String, +) -> axum::response::Result>> where S: NodeDelegate + ControlStateDelegate, { @@ -412,7 +412,7 @@ where .await? .ok_or(NO_SUCH_DATABASE)?; - let auth = AuthCtx::new(database.owner_identity, auth.identity); + let auth = AuthCtx::new(database.owner_identity, caller_identity); log::debug!("auth: {auth:?}"); let host = worker_ctx @@ -420,7 +420,21 @@ where .await .map_err(log_and_500)? .ok_or(StatusCode::NOT_FOUND)?; - let json = host.exec_sql(auth, database, confirmed, body).await?; + + host.exec_sql(auth, database, confirmed, sql).await +} + +pub async fn sql( + State(worker_ctx): State, + Path(name_or_identity): Path, + Query(params): Query, + Extension(auth): Extension, + body: String, +) -> axum::response::Result +where + S: NodeDelegate + ControlStateDelegate, +{ + let json = sql_direct(worker_ctx, name_or_identity, params, auth.identity, body).await?; let total_duration = json.iter().fold(0, |acc, x| acc + x.total_duration_micros); @@ -488,7 +502,9 @@ pub struct PublishDatabaseQueryParams { policy: MigrationPolicy, } +use spacetimedb_client_api_messages::http::SqlStmtResult; use std::env; + fn require_spacetime_auth_for_creation() -> bool { env::var("TEMP_REQUIRE_SPACETIME_AUTH").is_ok_and(|v| !v.is_empty()) } @@ -526,7 +542,7 @@ pub async fn publish( // so, unless you are the owner, this will fail. let (database_identity, db_name) = match &name_or_identity { - Some(noa) => match noa.try_resolve(&ctx).await? { + Some(noa) => match noa.try_resolve(&ctx).await.map_err(log_and_500)? { Ok(resolved) => (resolved, noa.name()), Err(name) => { // `name_or_identity` was a `NameOrIdentity::Name`, but no record diff --git a/crates/client-api/src/util.rs b/crates/client-api/src/util.rs index 91386986f..c38bf33c0 100644 --- a/crates/client-api/src/util.rs +++ b/crates/client-api/src/util.rs @@ -97,10 +97,10 @@ impl NameOrIdentity { pub async fn try_resolve( &self, ctx: &(impl ControlStateReadAccess + ?Sized), - ) -> axum::response::Result> { + ) -> anyhow::Result> { Ok(match self { Self::Identity(identity) => Ok(Identity::from(*identity)), - Self::Name(name) => ctx.lookup_identity(name.as_ref()).map_err(log_and_500)?.ok_or(name), + Self::Name(name) => ctx.lookup_identity(name.as_ref())?.ok_or(name), }) } @@ -108,7 +108,10 @@ impl NameOrIdentity { /// response if `self` is a [`NameOrIdentity::Name`] for which no /// corresponding [`Identity`] is found in the SpacetimeDB DNS. pub async fn resolve(&self, ctx: &(impl ControlStateReadAccess + ?Sized)) -> axum::response::Result { - self.try_resolve(ctx).await?.map_err(|_| StatusCode::NOT_FOUND.into()) + self.try_resolve(ctx) + .await + .map_err(log_and_500)? + .map_err(|_| StatusCode::NOT_FOUND.into()) } } diff --git a/crates/core/src/auth/mod.rs b/crates/core/src/auth/mod.rs index f9c381902..e1e38a667 100644 --- a/crates/core/src/auth/mod.rs +++ b/crates/core/src/auth/mod.rs @@ -15,6 +15,7 @@ pub struct JwtKeys { pub public: DecodingKey, pub public_pem: Box<[u8]>, pub private: EncodingKey, + pub private_pem: Box<[u8]>, pub kid: Option, } @@ -23,15 +24,17 @@ impl JwtKeys { /// respectively. /// /// The key files must be PEM encoded ECDSA P256 keys. - pub fn new(public_pem: impl Into>, private_pem: &[u8]) -> anyhow::Result { + pub fn new(public_pem: impl Into>, private_pem: impl Into>) -> anyhow::Result { let public_pem = public_pem.into(); + let private_pem = private_pem.into(); let public = DecodingKey::from_ec_pem(&public_pem)?; - let private = EncodingKey::from_ec_pem(private_pem)?; + let private = EncodingKey::from_ec_pem(&private_pem)?; Ok(Self { public, private, public_pem, + private_pem, kid: None, }) } @@ -75,7 +78,7 @@ pub struct EcKeyPair { impl TryFrom for JwtKeys { type Error = anyhow::Error; fn try_from(pair: EcKeyPair) -> anyhow::Result { - JwtKeys::new(pair.public_key_bytes, &pair.private_key_bytes) + JwtKeys::new(pair.public_key_bytes, pair.private_key_bytes) } } diff --git a/crates/core/src/messages/control_db.rs b/crates/core/src/messages/control_db.rs index 21a9155dc..8299875e3 100644 --- a/crates/core/src/messages/control_db.rs +++ b/crates/core/src/messages/control_db.rs @@ -63,6 +63,10 @@ pub struct Node { /// /// If `None`, the node is not currently live. pub advertise_addr: Option, + /// The address this node is running its postgres API at. + /// + /// If `None`, the node is not currently live. + pub pg_addr: Option, } #[derive(Clone, PartialEq, Serialize, Deserialize)] pub struct NodeStatus { diff --git a/crates/pg/Cargo.toml b/crates/pg/Cargo.toml new file mode 100644 index 000000000..dd49122de --- /dev/null +++ b/crates/pg/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "spacetimedb-pg" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license-file = "LICENSE" +description = "Postgres wire protocol Server support for SpacetimeDB" + +[dependencies] +spacetimedb-client-api-messages.workspace = true +spacetimedb-client-api.workspace = true +spacetimedb-lib.workspace = true + +anyhow.workspace = true +async-trait.workspace = true +axum.workspace = true +futures.workspace = true +http.workspace = true +log.workspace = true +pgwire.workspace = true +thiserror.workspace = true +tokio.workspace = true diff --git a/crates/pg/LICENSE b/crates/pg/LICENSE new file mode 120000 index 000000000..8540cf8a9 --- /dev/null +++ b/crates/pg/LICENSE @@ -0,0 +1 @@ +../../licenses/BSL.txt \ No newline at end of file diff --git a/crates/pg/README.md b/crates/pg/README.md new file mode 100644 index 000000000..fc5b684dd --- /dev/null +++ b/crates/pg/README.md @@ -0,0 +1,3 @@ +> ⚠️ **Internal Crate** ⚠️ +> +> This crate is intended for internal use only. It is **not** stable and may change without notice. diff --git a/crates/pg/src/encoder.rs b/crates/pg/src/encoder.rs new file mode 100644 index 000000000..f5a6ed990 --- /dev/null +++ b/crates/pg/src/encoder.rs @@ -0,0 +1,301 @@ +use crate::pg_server::PgError; +use pgwire::api::portal::Format; +use pgwire::api::results::{DataRowEncoder, FieldInfo}; +use pgwire::api::Type; +use spacetimedb_lib::sats::satn::{PsqlChars, PsqlPrintFmt, PsqlType, TypedWriter}; +use spacetimedb_lib::sats::{satn, ValueWithType}; +use spacetimedb_lib::{ + ser, AlgebraicType, AlgebraicValue, ProductType, ProductTypeElement, ProductValue, TimeDuration, Timestamp, +}; +use std::borrow::Cow; +use std::sync::Arc; + +pub(crate) fn row_desc(schema: &ProductType, format: &Format) -> Arc> { + Arc::new( + schema + .elements + .iter() + .enumerate() + .map(|(pos, ty)| { + let field_name = ty.name.clone().map(Into::into).unwrap_or_else(|| format!("col_{pos}")); + let field_type = type_of(schema, ty); + FieldInfo::new(field_name, None, None, field_type, format.format_for(pos)) + }) + .collect(), + ) +} + +pub(crate) fn type_of(schema: &ProductType, ty: &ProductTypeElement) -> Type { + let format = PsqlPrintFmt::use_fmt(schema, ty, ty.name()); + match &ty.algebraic_type { + AlgebraicType::String => Type::VARCHAR, + AlgebraicType::Bool => Type::BOOL, + AlgebraicType::U8 | AlgebraicType::I8 | AlgebraicType::I16 => Type::INT2, + AlgebraicType::U16 | AlgebraicType::I32 => Type::INT4, + AlgebraicType::U32 | AlgebraicType::I64 => Type::INT8, + AlgebraicType::U64 | AlgebraicType::I128 | AlgebraicType::U128 | AlgebraicType::I256 | AlgebraicType::U256 => { + Type::NUMERIC + } + AlgebraicType::F32 => Type::FLOAT4, + AlgebraicType::F64 => Type::FLOAT8, + AlgebraicType::Array(ty) => match *ty.elem_ty { + AlgebraicType::String => Type::VARCHAR_ARRAY, + AlgebraicType::Bool => Type::BOOL_ARRAY, + AlgebraicType::U8 => Type::BYTEA, + AlgebraicType::I8 | AlgebraicType::I16 => Type::INT2_ARRAY, + AlgebraicType::U16 | AlgebraicType::I32 => Type::INT4_ARRAY, + AlgebraicType::U32 | AlgebraicType::I64 => Type::INT8_ARRAY, + AlgebraicType::U64 + | AlgebraicType::I128 + | AlgebraicType::U128 + | AlgebraicType::I256 + | AlgebraicType::U256 => Type::NUMERIC_ARRAY, + _ => Type::ANYARRAY, + }, + AlgebraicType::Product(_) => match format { + PsqlPrintFmt::Hex => Type::BYTEA_ARRAY, + PsqlPrintFmt::Timestamp => Type::TIMESTAMP, + PsqlPrintFmt::Duration => Type::INTERVAL, + _ => Type::JSON, + }, + AlgebraicType::Sum(sum) if sum.is_simple_enum() => Type::ANYENUM, + AlgebraicType::Sum(_) => Type::JSON, + _ => Type::UNKNOWN, + } +} + +impl ser::Error for PgError { + fn custom(msg: T) -> Self { + PgError::Other(anyhow::anyhow!(msg.to_string())) + } +} + +pub(crate) struct PsqlFormatter<'a> { + pub(crate) encoder: &'a mut DataRowEncoder, +} + +impl TypedWriter for PsqlFormatter<'_> { + type Error = PgError; + + fn write(&mut self, value: W) -> Result<(), Self::Error> { + self.encoder.encode_field(&value.to_string())?; + Ok(()) + } + + fn write_bool(&mut self, value: bool) -> Result<(), Self::Error> { + self.encoder.encode_field(&value)?; + Ok(()) + } + + fn write_string(&mut self, value: &str) -> Result<(), Self::Error> { + self.encoder.encode_field(&value)?; + Ok(()) + } + + fn write_bytes(&mut self, value: &[u8]) -> Result<(), Self::Error> { + self.encoder.encode_field(&value)?; + Ok(()) + } + + fn write_hex(&mut self, value: &[u8]) -> Result<(), Self::Error> { + self.encoder.encode_field(&value)?; + Ok(()) + } + + fn write_timestamp(&mut self, value: Timestamp) -> Result<(), Self::Error> { + self.encoder.encode_field(&value.to_rfc3339()?)?; + Ok(()) + } + + fn write_duration(&mut self, value: TimeDuration) -> Result<(), Self::Error> { + self.encoder.encode_field(&value.to_iso8601())?; + Ok(()) + } + + fn write_alt_record( + &mut self, + ty: &PsqlType, + value: &ValueWithType<'_, ProductValue>, + ) -> Result { + let json = satn::PsqlWrapper { ty: ty.clone(), value }.to_string(); + self.encoder.encode_field(&json)?; + Ok(true) + } + + fn write_record( + &mut self, + _fields: Vec<(Cow, PsqlType, ValueWithType)>, + ) -> Result<(), Self::Error> { + unreachable!("Use `write_alt_record` for records in PSQL format"); + } + + fn write_variant( + &mut self, + tag: u8, + ty: PsqlType, + name: Option<&str>, + value: ValueWithType, + ) -> Result<(), Self::Error> { + // Is a simple enum? + if let AlgebraicType::Sum(sum) = &ty.field.algebraic_type { + if sum.is_simple_enum() { + if let Some(variant_name) = name { + self.encoder.encode_field(&variant_name)?; + return Ok(()); + } + } + } + + let PsqlChars { start, sep, end, quote } = ty.client.format_chars(); + let name = name.map(Cow::from).unwrap_or_else(|| Cow::from(tag.to_string())); + let json = format!( + "{start}{quote}{name}{quote}{sep} {}{end}", + satn::PsqlWrapper { ty, value } + ); + self.encoder.encode_field(&json)?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::pg_server::to_rows; + use futures::StreamExt; + use spacetimedb_client_api_messages::http::SqlStmtResult; + use spacetimedb_lib::sats::algebraic_value::Packed; + use spacetimedb_lib::sats::{i256, product, u256, AlgebraicType, ProductType, SumTypeVariant}; + use spacetimedb_lib::{ConnectionId, Identity}; + + async fn run(schema: ProductType, row: ProductValue) -> String { + let header = row_desc(&schema, &Format::UnifiedText); + + let stmt = SqlStmtResult { + schema, + rows: vec![row], + total_duration_micros: 0, + stats: Default::default(), + }; + let mut stream = to_rows(stmt, header).unwrap(); + let mut result = String::new(); + if let Some(row) = stream.next().await { + result = String::from_utf8_lossy(row.unwrap().data.freeze().as_ref()).to_string(); + } + result + } + + #[tokio::test] + async fn test_primitives() { + let schema = ProductType::from([ + AlgebraicType::U8, + AlgebraicType::I8, + AlgebraicType::I16, + AlgebraicType::U16, + AlgebraicType::I32, + AlgebraicType::U32, + AlgebraicType::I64, + AlgebraicType::U64, + AlgebraicType::I128, + AlgebraicType::U128, + AlgebraicType::I256, + AlgebraicType::U256, + AlgebraicType::F32, + AlgebraicType::F64, + AlgebraicType::String, + AlgebraicType::Bool, + ]); + let value = product![ + 1u8, + -1i8, + -2i16, + 3u16, + -4i32, + 5u32, + -6i64, + 7u64, + Packed::from(-8i128), + Packed::from(9u128), + i256::from(-10), + u256::from(11u128), + 12.34f32, + 56.78f64, + "test".to_string(), + true, + ]; + + let row = run(schema, value).await; + assert_eq!(row, "\0\0\0\u{1}1\0\0\0\u{2}-1\0\0\0\u{2}-2\0\0\0\u{1}3\0\0\0\u{2}-4\0\0\0\u{1}5\0\0\0\u{2}-6\0\0\0\u{1}7\0\0\0\u{2}-8\0\0\0\u{1}9\0\0\0\u{3}-10\0\0\0\u{2}11\0\0\0\u{5}12.34\0\0\0\u{5}56.78\0\0\0\u{4}test\0\0\0\u{1}t"); + } + + #[tokio::test] + async fn test_enum() { + let some = AlgebraicType::option(AlgebraicType::I64); + let schema = ProductType::from([some.clone(), some]); + let value = product![ + AlgebraicValue::sum(0, AlgebraicValue::I64(1)), // Some(1) + AlgebraicValue::sum(1, AlgebraicValue::unit()), // None + ]; + + let row = run(schema, value).await; + assert_eq!(row, "\0\0\0\u{b}{\"some\": 1}\0\0\0\u{c}{\"none\": {}}"); + + let color = AlgebraicType::Sum([SumTypeVariant::new_named(AlgebraicType::I64, "Gray")].into()); + let nested = AlgebraicType::option(color.clone()); + let schema = ProductType::from([color, nested]); + // {"Gray": 1}, {"some": {"Gray": 2}} + let value = product![ + AlgebraicValue::sum(0, AlgebraicValue::I64(1)), // Gray(1) + AlgebraicValue::sum(0, AlgebraicValue::sum(0, AlgebraicValue::I64(2))), // Some(Gray(2)) + ]; + let row = run(schema.clone(), value.clone()).await; + assert_eq!(row, "\0\0\0\u{b}{\"Gray\": 1}\0\0\0\u{15}{\"some\": {\"Gray\": 2}}"); + + // Now nested product + let product = AlgebraicType::product([ + ProductTypeElement::new(AlgebraicType::Product(schema), Some("x".into())), + ProductTypeElement::new(AlgebraicType::String, Some("y".into())), + ]); + let schema = ProductType::from([product.clone()]); + let value = product![AlgebraicValue::product(vec![ + value.into(), + AlgebraicValue::String("a".into()), + ])]; + let row = run(schema, value).await; + assert_eq!( + row, + "\0\0\0G{\"x\": {\"col_0\": {\"Gray\": 1}, \"col_1\": {\"some\": {\"Gray\": 2}}}, \"y\": \"a\"}" + ); + + // Now a simple enum + let names = AlgebraicType::simple_enum(["A", "B", "C"].into_iter()); + let schema = ProductType::from([names.clone(), names.clone(), names]); + let value = product![ + AlgebraicValue::enum_simple(0), // A + AlgebraicValue::enum_simple(1), // B + AlgebraicValue::enum_simple(2), // C + ]; + let row = run(schema, value).await; + assert_eq!(row, "\0\0\0\u{1}A\0\0\0\u{1}B\0\0\0\u{1}C"); + } + + #[tokio::test] + async fn test_special_types() { + let schema = ProductType::from([ + AlgebraicType::identity(), + AlgebraicType::connection_id(), + AlgebraicType::time_duration(), + AlgebraicType::timestamp(), + AlgebraicType::bytes(), + ]); + let value = product![ + Identity::ZERO, + ConnectionId::ZERO, + TimeDuration::from_micros(0), + Timestamp::from_micros_since_unix_epoch(1622545800000), + AlgebraicValue::Bytes("test".as_bytes().into()), + ]; + + let row = run(schema, value).await; + assert_eq!(row, "\0\0\0B\\x0000000000000000000000000000000000000000000000000000000000000000\0\0\0\"\\x00000000000000000000000000000000\0\0\0\u{3}P0D\0\0\0\u{1d}1970-01-19T18:42:25.800+00:00\0\0\0\n\\x74657374"); + } +} diff --git a/crates/pg/src/lib.rs b/crates/pg/src/lib.rs new file mode 100644 index 000000000..c4466bbc5 --- /dev/null +++ b/crates/pg/src/lib.rs @@ -0,0 +1,2 @@ +mod encoder; +pub mod pg_server; diff --git a/crates/pg/src/pg_server.rs b/crates/pg/src/pg_server.rs new file mode 100644 index 000000000..39b0fcaca --- /dev/null +++ b/crates/pg/src/pg_server.rs @@ -0,0 +1,381 @@ +use std::fmt::Debug; +use std::sync::Arc; + +use crate::encoder::{row_desc, PsqlFormatter}; +use async_trait::async_trait; +use axum::body::to_bytes; +use axum::response::IntoResponse; +use futures::{stream, Sink}; +use futures::{SinkExt, Stream}; +use http::StatusCode; +use pgwire::api::auth::{ + finish_authentication, save_startup_parameters_to_metadata, DefaultServerParameterProvider, LoginInfo, + StartupHandler, +}; +use pgwire::api::portal::Format; +use pgwire::api::query::SimpleQueryHandler; +use pgwire::api::results::{DataRowEncoder, FieldInfo, QueryResponse, Response, Tag}; +use pgwire::api::{ClientInfo, METADATA_DATABASE}; +use pgwire::api::{PgWireConnectionState, PgWireServerHandlers}; +use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; +use pgwire::messages::data::DataRow; +use pgwire::messages::startup::Authentication; +use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage}; +use pgwire::tokio::process_socket; +use spacetimedb_client_api::auth::validate_token; +use spacetimedb_client_api::routes::database; +use spacetimedb_client_api::routes::database::{SqlParams, SqlQueryParams}; +use spacetimedb_client_api::{ControlStateReadAccess, ControlStateWriteAccess, NodeDelegate}; +use spacetimedb_client_api_messages::http::SqlStmtResult; +use spacetimedb_client_api_messages::name::DatabaseName; +use spacetimedb_lib::sats::satn::{PsqlClient, TypedSerializer}; +use spacetimedb_lib::sats::{satn, Serialize, Typespace}; +use spacetimedb_lib::version::spacetimedb_lib_version; +use spacetimedb_lib::{Identity, ProductValue}; +use thiserror::Error; +use tokio::net::TcpListener; +use tokio::sync::{Mutex, Notify}; + +#[derive(Error, Debug)] +pub(crate) enum PgError { + #[error("(metadata) {0}")] + MetadataError(anyhow::Error), + #[error("(Sql) {0}")] + Sql(String), + #[error("Database name is required")] + DatabaseNameRequired, + #[error(transparent)] + Pg(#[from] PgWireError), + #[error("SSL is not supported by SpacetimeDB")] + SSLNotSupported, + #[error(transparent)] + Other(#[from] anyhow::Error), +} + +impl From for PgWireError { + fn from(err: PgError) -> Self { + if let PgError::Pg(err) = err { + err + } else { + PgWireError::ApiError(Box::new(err)) + } + } +} + +#[derive(Clone)] +struct Metadata { + database: String, + caller_identity: Identity, +} + +pub(crate) fn to_rows( + stmt: SqlStmtResult, + header: Arc>, +) -> Result>, PgError> { + let mut results = Vec::with_capacity(stmt.rows.len()); + let ty = Typespace::EMPTY.with_type(&stmt.schema); + + for row in stmt.rows { + let mut encoder = DataRowEncoder::new(header.clone()); + + for (idx, value) in ty.with_values(&row).enumerate() { + let ty = satn::PsqlType { + client: PsqlClient::Postgres, + tuple: ty.ty(), + field: &ty.ty().elements[idx], + idx, + }; + let mut fmt = PsqlFormatter { encoder: &mut encoder }; + value.serialize(TypedSerializer { ty: &ty, f: &mut fmt })?; + } + results.push(encoder.finish()); + } + Ok(stream::iter(results)) +} + +fn stats(stmt: &SqlStmtResult) -> String { + let mut info = Vec::new(); + if stmt.stats.rows_inserted != 0 { + info.push(format!("inserted: {}", stmt.stats.rows_inserted)); + } + if stmt.stats.rows_deleted != 0 { + info.push(format!("deleted: {}", stmt.stats.rows_deleted)); + } + if stmt.stats.rows_updated != 0 { + info.push(format!("updated: {}", stmt.stats.rows_updated)); + } + info.push(format!( + "server: {:.2?}", + std::time::Duration::from_micros(stmt.total_duration_micros) + )); + + info.join(", ") +} + +struct ResponseWrapper(T); +impl IntoResponse for ResponseWrapper { + fn into_response(self) -> axum::response::Response { + unreachable!("Blank impl to satisfy IntoResponse") + } +} + +async fn response(res: axum::response::Result, database: &str) -> Result { + match res.map(ResponseWrapper) { + Ok(sql) => Ok(sql.0), + err => { + let res = err.into_response(); + if res.status() == StatusCode::NOT_FOUND { + log::error!("PG: Database not found: {database}"); + return Err(PgWireError::UserError(Box::new(ErrorInfo::new( + "FATAL".to_string(), + "3D000".to_string(), + format!("database \"{database}\" does not exist"), + ))) + .into()); + } + let bytes = to_bytes(res.into_body(), usize::MAX) + .await + .map_err(|err| PgWireError::ApiError(Box::new(err)))?; + let err = String::from_utf8_lossy(&bytes); + log::error!("PG: Error for database {database}: {err}"); + Err(PgError::Sql(format!("{err}"))) + } + } +} + +struct PgSpacetimeDB { + ctx: T, + cached: Mutex>, + parameter_provider: DefaultServerParameterProvider, +} + +impl PgSpacetimeDB { + async fn exe_sql<'a>(&self, query: String) -> PgWireResult>> { + let params = self.cached.lock().await.clone().unwrap(); + let db = SqlParams { + name_or_identity: database::NameOrIdentity::Name(DatabaseName(params.database.clone())), + }; + + let sql = match response( + database::sql_direct( + self.ctx.clone(), + db, + SqlQueryParams { confirmed: true }, + params.caller_identity, + query.to_string(), + ) + .await, + ¶ms.database, + ) + .await + { + Ok(sql) => sql, + Err(PgError::Pg(PgWireError::UserError(err))) => { + return Ok(vec![Response::Error(err)]); + } + Err(err) => { + return Err(err.into()); + } + }; + + let mut result = Vec::with_capacity(sql.len()); + for sql_result in sql { + let header = row_desc(&sql_result.schema, &Format::UnifiedText); + if sql_result.rows.is_empty() && !query.to_uppercase().contains("SELECT") { + let tag = Tag::new(&stats(&sql_result)); + result.push(Response::Execution(tag)); + } else { + let rows = to_rows(sql_result, header.clone())?; + let q = QueryResponse::new(header, rows); + result.push(Response::Query(q)); + } + } + Ok(result) + } +} + +async fn close_client(client: &mut C, err: E) -> PgWireResult<()> +where + C: ClientInfo + Sink + Unpin + Send, + C::Error: Debug, + PgWireError: From<>::Error>, + pgwire::messages::response::ErrorResponse: From, +{ + let err = pgwire::messages::response::ErrorResponse::from(err); + client.feed(PgWireBackendMessage::ErrorResponse(err)).await?; + client.close().await?; + Ok(()) +} + +#[async_trait] +impl StartupHandler + for PgSpacetimeDB +{ + async fn on_startup(&self, client: &mut C, message: PgWireFrontendMessage) -> PgWireResult<()> + where + C: ClientInfo + Sink + Unpin + Send, + C::Error: Debug, + PgWireError: From<>::Error>, + { + match message { + PgWireFrontendMessage::Startup(ref startup) => { + save_startup_parameters_to_metadata(client, startup); + client.set_state(PgWireConnectionState::AuthenticationInProgress); + + let login_info = LoginInfo::from_client_info(client); + + if login_info.database().is_none() { + return Err(PgError::DatabaseNameRequired.into()); + } + + client + .send(PgWireBackendMessage::Authentication(Authentication::CleartextPassword)) + .await?; + } + PgWireFrontendMessage::PasswordMessageFamily(pwd) => { + let params = client.metadata(); + let param = |param: &str| { + params + .get(param) + .map(String::from) + .ok_or_else(|| PgError::MetadataError(anyhow::anyhow!("Missing parameter: {}", param))) + }; + + // We don't support `METADATA_USER` because we don't have a user management system. + let database = param(METADATA_DATABASE)?; + let pwd = pwd.into_password()?; + if let Ok(application_name) = param("application_name") { + log::info!("PG: Connecting to database: {database}, by {application_name}",); + } else { + log::info!("PG: Connecting to database: {database}"); + } + + let name = database::NameOrIdentity::Name(DatabaseName(database.clone())); + match response(name.resolve(&self.ctx).await, &database).await { + Ok(identity) => identity, + Err(PgError::Pg(PgWireError::UserError(err))) => { + return close_client(client, *err).await; + } + Err(err) => { + return Err(err.into()); + } + }; + + let caller_identity = match validate_token(&self.ctx, &pwd.password).await { + Ok(claims) => claims.identity, + Err(err) => { + log::error!( + "PG: Authentication failed for identity `{}` on database {database}: {err}", + pwd.password + ); + let err = ErrorInfo::new("FATAL".to_owned(), "28P01".to_owned(), err.to_string()); + return close_client(client, err).await; + } + }; + + log::info!("PG: Connected to database: {database} using identity `{caller_identity}`"); + + let metadata = Metadata { + database, + caller_identity, + }; + self.cached.lock().await.clone_from(&Some(metadata)); + finish_authentication(client, &self.parameter_provider).await?; + } + PgWireFrontendMessage::SslRequest(_) => { + let err = PgError::SSLNotSupported; + log::error!("{err}"); + let err = ErrorInfo::new("FATAL".to_owned(), "28P01".to_owned(), err.to_string()); + return close_client(client, err).await; + } + // The other messages are for features not supported by SpacetimeDB, that are rejected by the parser. + _ => { + unreachable!("Unsupported startup message: {message:?}"); + } + } + Ok(()) + } +} + +#[async_trait] +impl SimpleQueryHandler + for PgSpacetimeDB +{ + async fn do_query<'a, C>(&self, _client: &mut C, query: &str) -> PgWireResult>> + where + C: ClientInfo + Unpin + Send + Sync, + { + self.exe_sql(query.to_string()).await + } +} + +#[derive(Clone)] +pub struct PgSpacetimeDBFactory { + handler: Arc>, +} + +impl PgSpacetimeDBFactory { + pub fn new(ctx: T) -> Self { + let mut parameter_provider = DefaultServerParameterProvider::default(); + parameter_provider.server_version = format!("spacetime {}", spacetimedb_lib_version()); + + Self { + handler: Arc::new(PgSpacetimeDB { + ctx, + // This is a placeholder, it will be set in the startup handler + cached: None.into(), + parameter_provider, + }), + } + } +} + +impl PgWireServerHandlers + for PgSpacetimeDBFactory +{ + fn simple_query_handler(&self) -> Arc { + self.handler.clone() + } + + // TODO: fn extended_query_handler(&self) -> Arc {} + + fn startup_handler(&self) -> Arc { + self.handler.clone() + } +} + +pub async fn start_pg( + shutdown: Arc, + ctx: T, + tcp: TcpListener, +) { + let factory = Arc::new(PgSpacetimeDBFactory::new(ctx)); + + log::debug!( + "PG: Starting SpacetimeDB Protocol listening on {}", + tcp.local_addr().unwrap() + ); + loop { + tokio::select! { + accept_result = tcp.accept() => { + match accept_result { + Ok((stream, _addr)) => { + let factory_ref = factory.clone(); + tokio::spawn(async move { + process_socket(stream, None, factory_ref).await.inspect_err(|err|{ + log::error!("PG: Error processing socket: {err:?}"); + }) + }); + } + Err(e) => { + log::error!("PG: Accept error: {e}"); + } + } + } + _ = shutdown.notified() => { + log::info!("PG: Shutting down PostgreSQL server."); + break; + } + } + } +} diff --git a/crates/sats/src/satn.rs b/crates/sats/src/satn.rs index 446260952..2a5f12826 100644 --- a/crates/sats/src/satn.rs +++ b/crates/sats/src/satn.rs @@ -1,13 +1,14 @@ -use crate::de::DeserializeSeed; use crate::time_duration::TimeDuration; use crate::timestamp::Timestamp; -use crate::{i256, u256, AlgebraicValue, WithTypespace}; +use crate::{i256, u256, AlgebraicType, AlgebraicValue, ProductValue, Serialize, SumValue, ValueWithType}; use crate::{ser, ProductType, ProductTypeElement}; use core::fmt; use core::fmt::Write as _; -use derive_more::{From, Into}; +use derive_more::{Display, From, Into}; +use std::borrow::Cow; +use std::marker::PhantomData; -/// An extension trait for [`Serialize`](ser::Serialize) providing formatting methods. +/// An extension trait for [`Serialize`] providing formatting methods. pub trait Satn: ser::Serialize { /// Formats the value using the SATN data format into the formatter `f`. fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -18,9 +19,12 @@ pub trait Satn: ser::Serialize { /// Formats the value using the postgres SATN(PsqlFormatter { f }, /* PsqlType */) formatter `f`. fn fmt_psql(&self, f: &mut fmt::Formatter, ty: &PsqlType<'_>) -> fmt::Result { Writer::with(f, |f| { - self.serialize(PsqlFormatter { - fmt: SatnFormatter { f }, + self.serialize(TypedSerializer { ty, + f: &mut SqlFormatter { + fmt: SatnFormatter { f }, + ty, + }, }) })?; Ok(()) @@ -229,9 +233,30 @@ struct SatnFormatter<'a, 'f> { f: Writer<'a, 'f>, } +impl SatnFormatter<'_, '_> { + fn ser_variant( + &mut self, + _tag: u8, + name: Option<&str>, + value: &T, + ) -> Result<(), SatnError> { + write!(self, "(")?; + EntryWrapper::<','>::new(self.f.as_mut()).entry(|mut f| { + if let Some(name) = name { + write!(f, "{name}")?; + } + write!(f, " = ")?; + value.serialize(SatnFormatter { f })?; + Ok(()) + })?; + write!(self, ")")?; + + Ok(()) + } +} /// An error occurred during serialization to the SATS data format. #[derive(From, Into)] -struct SatnError(fmt::Error); +pub struct SatnError(fmt::Error); impl ser::Error for SatnError { fn custom(_msg: T) -> Self { @@ -331,20 +356,11 @@ impl<'a, 'f> ser::Serializer for SatnFormatter<'a, 'f> { fn serialize_variant( mut self, - _tag: u8, + tag: u8, name: Option<&str>, value: &T, ) -> Result { - write!(self, "(")?; - EntryWrapper::<','>::new(self.f.as_mut()).entry(|mut f| { - if let Some(name) = name { - write!(f, "{name}")?; - } - write!(f, " = ")?; - value.serialize(SatnFormatter { f })?; - Ok(()) - })?; - write!(self, ")") + self.ser_variant(tag, name, value) } } @@ -427,119 +443,41 @@ impl ser::SerializeNamedProduct for NamedFormatter<'_, '_> { } } -struct PsqlEntryWrapper<'a, 'f, const SEP: char> { - entry: EntryWrapper<'a, 'f, SEP>, - /// The index of the element. - idx: usize, - ty: &'a PsqlType<'a>, +/// Which client is used to format the `SQL` output? +#[derive(PartialEq, Copy, Clone, Debug)] +pub enum PsqlClient { + SpacetimeDB, + Postgres, } -/// Provides the data format for named products for `SQL`. -struct PsqlNamedFormatter<'a, 'f> { - /// The formatter for each element separating elements by a `,`. - f: PsqlEntryWrapper<'a, 'f, ','>, - /// If is not [Self::is_special] to control if we start with `(` - start: bool, - /// Remember what format we are using - use_fmt: PsqlPrintFmt, +pub struct PsqlChars { + pub start: char, + pub sep: &'static str, + pub end: char, + pub quote: &'static str, } -impl<'a, 'f> PsqlNamedFormatter<'a, 'f> { - pub fn new(ty: &'a PsqlType<'a>, f: Writer<'a, 'f>) -> Self { - Self { - start: true, - f: PsqlEntryWrapper { - entry: EntryWrapper::new(f), - idx: 0, - ty, +impl PsqlClient { + pub fn format_chars(&self) -> PsqlChars { + match self { + PsqlClient::SpacetimeDB => PsqlChars { + start: '(', + sep: " =", + end: ')', + quote: "", + }, + PsqlClient::Postgres => PsqlChars { + start: '{', + sep: ":", + end: '}', + quote: "\"", }, - // Will set later - use_fmt: PsqlPrintFmt::Satn, } } } -impl ser::SerializeNamedProduct for PsqlNamedFormatter<'_, '_> { - type Ok = (); - type Error = SatnError; - - fn serialize_element( - &mut self, - name: Option<&str>, - elem: &T, - ) -> Result<(), Self::Error> { - // For binary data & special types, output in `hex` format and skip the tagging of each value - // We need to check for both the enclosing(`self.f.ty`) type and the inner element(`name`) type. - self.use_fmt = self.f.ty.use_fmt(name); - let res = self.f.entry.entry(|mut f| { - let PsqlType { tuple, field, idx } = self.f.ty; - if !self.use_fmt.is_special() { - if self.start { - write!(f, "(")?; - self.start = false; - } - // Format the name or use the index if unnamed. - if let Some(name) = name { - write!(f, "{name}")?; - } else { - write!(f, "{idx}")?; - } - write!(f, " = ")?; - } - //Is a nested product type? - let (tuple, field, idx) = if let Some(product) = field.algebraic_type.as_product() { - (product, &product.elements[self.f.idx], self.f.idx) - } else { - (*tuple, *field, *idx) - }; - - elem.serialize(PsqlFormatter { - fmt: SatnFormatter { f }, - ty: &PsqlType { tuple, field, idx }, - })?; - - Ok(()) - }); - - // Advance to the next field. - if !self.use_fmt.is_special() { - self.f.idx += 1; - } - - res?; - - Ok(()) - } - - fn end(mut self) -> Result { - if !self.use_fmt.is_special() { - write!(self.f.entry.fmt, ")")?; - } - Ok(()) - } -} - -/// Provides the data format for unnamed products for `SQL`. -struct PsqlSeqFormatter<'a, 'f> { - /// Delegates to the named format. - inner: PsqlNamedFormatter<'a, 'f>, -} - -impl ser::SerializeSeqProduct for PsqlSeqFormatter<'_, '_> { - type Ok = (); - type Error = SatnError; - - fn serialize_element(&mut self, elem: &T) -> Result<(), Self::Error> { - ser::SerializeNamedProduct::serialize_element(&mut self.inner, None, elem) - } - - fn end(self) -> Result { - ser::SerializeNamedProduct::end(self.inner) - } -} - /// How format of the `SQL` output? -#[derive(PartialEq)] +#[derive(Debug, Copy, Clone, PartialEq, Display)] pub enum PsqlPrintFmt { /// Print as `hex` format Hex, @@ -552,14 +490,46 @@ pub enum PsqlPrintFmt { } impl PsqlPrintFmt { - fn is_special(&self) -> bool { + pub fn is_special(&self) -> bool { self != &PsqlPrintFmt::Satn } + /// Returns if the type is a special type + /// + /// Is required to check both the enclosing type and the inner element type + pub fn use_fmt(tuple: &ProductType, field: &ProductTypeElement, name: Option<&str>) -> PsqlPrintFmt { + if tuple.is_identity() + || tuple.is_connection_id() + || field.algebraic_type.is_identity() + || field.algebraic_type.is_connection_id() + || name.map(ProductType::is_identity_tag).unwrap_or_default() + || name.map(ProductType::is_connection_id_tag).unwrap_or_default() + { + return PsqlPrintFmt::Hex; + }; + + if tuple.is_timestamp() + || field.algebraic_type.is_timestamp() + || name.map(ProductType::is_timestamp_tag).unwrap_or_default() + { + return PsqlPrintFmt::Timestamp; + }; + + if tuple.is_time_duration() + || field.algebraic_type.is_time_duration() + || name.map(ProductType::is_time_duration_tag).unwrap_or_default() + { + return PsqlPrintFmt::Duration; + }; + + PsqlPrintFmt::Satn + } } /// A wrapper that remember the `header` of the tuple/struct and the current field #[derive(Debug, Clone)] pub struct PsqlType<'a> { + /// The client used to format the output + pub client: PsqlClient, /// The header of the tuple/struct pub tuple: &'a ProductType, /// The current field @@ -572,168 +542,369 @@ impl PsqlType<'_> { /// Returns if the type is a special type /// /// Is required to check both the enclosing type and the inner element type - fn use_fmt(&self, name: Option<&str>) -> PsqlPrintFmt { - if self.tuple.is_identity() - || self.tuple.is_connection_id() - || self.field.algebraic_type.is_identity() - || self.field.algebraic_type.is_connection_id() - || name.map(ProductType::is_identity_tag).unwrap_or_default() - || name.map(ProductType::is_connection_id_tag).unwrap_or_default() - { - return PsqlPrintFmt::Hex; - }; - - if self.tuple.is_timestamp() - || self.field.algebraic_type.is_timestamp() - || name.map(ProductType::is_timestamp_tag).unwrap_or_default() - { - return PsqlPrintFmt::Timestamp; - }; - - if self.tuple.is_time_duration() - || self.field.algebraic_type.is_time_duration() - || name.map(ProductType::is_time_duration_tag).unwrap_or_default() - { - return PsqlPrintFmt::Duration; - }; - - PsqlPrintFmt::Satn + pub fn use_fmt(&self) -> PsqlPrintFmt { + PsqlPrintFmt::use_fmt(self.tuple, self.field, None) } } /// An implementation of [`Serializer`](ser::Serializer) for `SQL` output. -struct PsqlFormatter<'a, 'f> { +pub struct SqlFormatter<'a, 'f> { fmt: SatnFormatter<'a, 'f>, ty: &'a PsqlType<'a>, } -impl<'a, 'f> ser::Serializer for PsqlFormatter<'a, 'f> { +/// A trait for writing values, after the special types has been determined. +/// +/// This is used to write values that could have different representations depending on the output format, +/// as defined by [`PsqlClient`] and [`PsqlPrintFmt`]. +pub trait TypedWriter { + type Error: ser::Error; + + /// Writes a value using [`ser::Serializer`] + fn write(&mut self, value: W) -> Result<(), Self::Error>; + + // Values that need special handling: + + fn write_bool(&mut self, value: bool) -> Result<(), Self::Error>; + fn write_string(&mut self, value: &str) -> Result<(), Self::Error>; + fn write_bytes(&mut self, value: &[u8]) -> Result<(), Self::Error>; + fn write_hex(&mut self, value: &[u8]) -> Result<(), Self::Error>; + fn write_timestamp(&mut self, value: Timestamp) -> Result<(), Self::Error>; + fn write_duration(&mut self, value: TimeDuration) -> Result<(), Self::Error>; + /// Writes a value as an alternative record format, e.g., for use `JSON` inside `SQL`. + fn write_alt_record( + &mut self, + _ty: &PsqlType, + _value: &ValueWithType<'_, ProductValue>, + ) -> Result { + Ok(false) + } + + fn write_record( + &mut self, + fields: Vec<(Cow, PsqlType, ValueWithType)>, + ) -> Result<(), Self::Error>; + + fn write_variant( + &mut self, + tag: u8, + ty: PsqlType, + name: Option<&str>, + value: ValueWithType, + ) -> Result<(), Self::Error>; +} + +/// A formatter for arrays that uses the `TypedWriter` trait to write elements. +pub struct TypedArrayFormatter<'a, 'f, F> { + ty: &'a PsqlType<'a>, + f: &'f mut F, +} + +impl ser::SerializeArray for TypedArrayFormatter<'_, '_, F> { type Ok = (); - type Error = SatnError; - type SerializeArray = ArrayFormatter<'a, 'f>; - type SerializeSeqProduct = PsqlSeqFormatter<'a, 'f>; - type SerializeNamedProduct = PsqlNamedFormatter<'a, 'f>; + type Error = F::Error; + + fn serialize_element(&mut self, elem: &T) -> Result<(), Self::Error> { + elem.serialize(TypedSerializer { ty: self.ty, f: self.f })?; + Ok(()) + } + + fn end(self) -> Result { + Ok(()) + } +} + +/// A formatter for sequences that uses the `TypedWriter` trait to write elements. +pub struct TypedSeqFormatter<'a, 'f, F> { + ty: &'a PsqlType<'a>, + f: &'f mut F, +} + +impl ser::SerializeSeqProduct for TypedSeqFormatter<'_, '_, F> { + type Ok = (); + type Error = F::Error; + + fn serialize_element(&mut self, elem: &T) -> Result<(), Self::Error> { + elem.serialize(TypedSerializer { ty: self.ty, f: self.f })?; + Ok(()) + } + + fn end(self) -> Result { + Ok(()) + } +} + +/// A formatter for named products that uses the `TypedWriter` trait to write elements. +pub struct TypedNamedProductFormatter { + f: PhantomData, +} + +impl ser::SerializeNamedProduct for TypedNamedProductFormatter { + type Ok = (); + type Error = F::Error; + + fn serialize_element( + &mut self, + _name: Option<&str>, + _elem: &T, + ) -> Result<(), Self::Error> { + Ok(()) + } + + fn end(self) -> Result { + Ok(()) + } +} + +/// A serializer that uses the `TypedWriter` trait to serialize values +pub struct TypedSerializer<'a, 'f, F> { + pub ty: &'a PsqlType<'a>, + pub f: &'f mut F, +} + +impl<'a, 'f, F: TypedWriter> ser::Serializer for TypedSerializer<'a, 'f, F> { + type Ok = (); + type Error = F::Error; + type SerializeArray = TypedArrayFormatter<'a, 'f, F>; + type SerializeSeqProduct = TypedSeqFormatter<'a, 'f, F>; + type SerializeNamedProduct = TypedNamedProductFormatter; fn serialize_bool(self, v: bool) -> Result { - self.fmt.serialize_bool(v) + self.f.write_bool(v) } + fn serialize_u8(self, v: u8) -> Result { - self.fmt.serialize_u8(v) + self.f.write(v) } + fn serialize_u16(self, v: u16) -> Result { - self.fmt.serialize_u16(v) + self.f.write(v) } + fn serialize_u32(self, v: u32) -> Result { - self.fmt.serialize_u32(v) + self.f.write(v) } + fn serialize_u64(self, v: u64) -> Result { - self.fmt.serialize_u64(v) + self.f.write(v) } + fn serialize_u128(self, v: u128) -> Result { - match self.ty.use_fmt(None) { - PsqlPrintFmt::Hex => self.serialize_bytes(&v.to_be_bytes()), - _ => self.fmt.serialize_u128(v), + match self.ty.use_fmt() { + PsqlPrintFmt::Hex => self.f.write_hex(&v.to_be_bytes()), + _ => self.f.write(v), } } + fn serialize_u256(self, v: u256) -> Result { - match self.ty.use_fmt(None) { - PsqlPrintFmt::Hex => self.serialize_bytes(&v.to_be_bytes()), - _ => self.fmt.serialize_u256(v), + match self.ty.use_fmt() { + PsqlPrintFmt::Hex => self.f.write_hex(&v.to_be_bytes()), + _ => self.f.write(v), } } + fn serialize_i8(self, v: i8) -> Result { - self.fmt.serialize_i8(v) + self.f.write(v) } + fn serialize_i16(self, v: i16) -> Result { - self.fmt.serialize_i16(v) + self.f.write(v) } + fn serialize_i32(self, v: i32) -> Result { - self.fmt.serialize_i32(v) + self.f.write(v) } - fn serialize_i64(mut self, v: i64) -> Result { - match self.ty.use_fmt(None) { - PsqlPrintFmt::Duration => { - write!(self.fmt, "{}", TimeDuration::from_micros(v))?; - Ok(()) - } - PsqlPrintFmt::Timestamp => { - write!(self.fmt, "{}", Timestamp::from_micros_since_unix_epoch(v))?; - Ok(()) - } - _ => self.fmt.serialize_i64(v), + + fn serialize_i64(self, v: i64) -> Result { + match self.ty.use_fmt() { + PsqlPrintFmt::Duration => self.f.write_duration(TimeDuration::from_micros(v)), + PsqlPrintFmt::Timestamp => self.f.write_timestamp(Timestamp::from_micros_since_unix_epoch(v)), + _ => self.f.write(v), } } + fn serialize_i128(self, v: i128) -> Result { - self.fmt.serialize_i128(v) + self.f.write(v) } + fn serialize_i256(self, v: i256) -> Result { - self.fmt.serialize_i256(v) + self.f.write(v) } + fn serialize_f32(self, v: f32) -> Result { - self.fmt.serialize_f32(v) + self.f.write(v) } + fn serialize_f64(self, v: f64) -> Result { - self.fmt.serialize_f64(v) + self.f.write(v) } fn serialize_str(self, v: &str) -> Result { - self.fmt.serialize_str(v) + self.f.write_string(v) } fn serialize_bytes(self, v: &[u8]) -> Result { - self.fmt.serialize_bytes(v) + if self.ty.use_fmt() == PsqlPrintFmt::Satn { + self.f.write_hex(v) + } else { + self.f.write_bytes(v) + } } - fn serialize_array(self, len: usize) -> Result { - self.fmt.serialize_array(len) + fn serialize_array(self, _len: usize) -> Result { + Ok(TypedArrayFormatter { ty: self.ty, f: self.f }) } - fn serialize_seq_product(self, len: usize) -> Result { - Ok(PsqlSeqFormatter { - inner: self.serialize_named_product(len)?, - }) + fn serialize_seq_product(self, _len: usize) -> Result { + Ok(TypedSeqFormatter { ty: self.ty, f: self.f }) } fn serialize_named_product(self, _len: usize) -> Result { - Ok(PsqlNamedFormatter::new(self.ty, self.fmt.f)) + unreachable!("This should never be called, use `serialize_named_product_raw` instead."); } - fn serialize_variant( + fn serialize_named_product_raw(self, value: &ValueWithType<'_, ProductValue>) -> Result { + let val = &value.val.elements; + assert_eq!(val.len(), value.ty().elements.len()); + // If the value is a special type, we can write it directly + if self.ty.use_fmt().is_special() { + // Is a nested product type? + // We need to check for both the enclosing(`self.ty`) type and the inner element type. + let (tuple, field) = if let Some(product) = self.ty.field.algebraic_type.as_product() { + (product, &product.elements[0]) + } else { + (self.ty.tuple, self.ty.field) + }; + return value.val.serialize(TypedSerializer { + ty: &PsqlType { + client: self.ty.client, + tuple, + field, + idx: self.ty.idx, + }, + f: self.f, + }); + } + // Allow to switch to an alternative record format, for example to write a `JSON` record. + if self.f.write_alt_record(self.ty, value)? { + return Ok(()); + } + let mut record = Vec::with_capacity(val.len()); + + for (idx, (val, field)) in val.iter().zip(&*value.ty().elements).enumerate() { + let ty = PsqlType { + client: self.ty.client, + tuple: value.ty(), + field, + idx, + }; + record.push(( + field + .name() + .map(Cow::from) + .unwrap_or_else(|| Cow::from(format!("col_{idx}"))), + ty, + value.with(&field.algebraic_type, val), + )); + } + self.f.write_record(record) + } + + fn serialize_variant_raw(self, sum: &ValueWithType<'_, SumValue>) -> Result { + let sv = sum.value(); + let (tag, val) = (sv.tag, &*sv.value); + let var_ty = &sum.ty().variants[tag as usize]; // Extract the variant type by tag. + let product = ProductType::from([AlgebraicType::sum(sum.ty().clone())]); + let ty = PsqlType { + client: self.ty.client, + tuple: &product, + field: &product.elements[0], + idx: 0, + }; + self.f + .write_variant(tag, ty, var_ty.name(), sum.with(&var_ty.algebraic_type, val)) + } + + fn serialize_variant( self, - tag: u8, - name: Option<&str>, - value: &T, + _tag: u8, + _name: Option<&str>, + _value: &T, ) -> Result { - self.fmt.serialize_variant(tag, name, value) - } - - unsafe fn serialize_bsatn(self, ty: &Ty, bsatn: &[u8]) -> Result - where - for<'b, 'de> WithTypespace<'b, Ty>: DeserializeSeed<'de, Output: Into>, - { - // SAFETY: Forward caller requirements of this method to that we are calling. - unsafe { self.fmt.serialize_bsatn(ty, bsatn) } - } - - unsafe fn serialize_bsatn_in_chunks<'c, Ty, I: Clone + Iterator>( - self, - ty: &Ty, - total_bsatn_len: usize, - bsatn: I, - ) -> Result - where - for<'b, 'de> WithTypespace<'b, Ty>: DeserializeSeed<'de, Output: Into>, - { - // SAFETY: Forward caller requirements of this method to that we are calling. - unsafe { self.fmt.serialize_bsatn_in_chunks(ty, total_bsatn_len, bsatn) } - } - - unsafe fn serialize_str_in_chunks<'c, I: Clone + Iterator>( - self, - total_len: usize, - string: I, - ) -> Result { - // SAFETY: Forward caller requirements of this method to that we are calling. - unsafe { self.fmt.serialize_str_in_chunks(total_len, string) } + unreachable!("Use `serialize_variant_raw` instead."); + } +} + +impl TypedWriter for SqlFormatter<'_, '_> { + type Error = SatnError; + + fn write(&mut self, value: W) -> Result<(), Self::Error> { + write!(self.fmt, "{value}") + } + + fn write_bool(&mut self, value: bool) -> Result<(), Self::Error> { + write!(self.fmt, "{value}") + } + + fn write_string(&mut self, value: &str) -> Result<(), Self::Error> { + write!(self.fmt, "\"{value}\"") + } + + fn write_bytes(&mut self, value: &[u8]) -> Result<(), Self::Error> { + self.write_hex(value) + } + + fn write_hex(&mut self, value: &[u8]) -> Result<(), Self::Error> { + match self.ty.client { + PsqlClient::SpacetimeDB => write!(self.fmt, "0x{}", hex::encode(value)), + PsqlClient::Postgres => write!(self.fmt, "\"0x{}\"", hex::encode(value)), + } + } + + fn write_timestamp(&mut self, value: Timestamp) -> Result<(), Self::Error> { + match self.ty.client { + PsqlClient::SpacetimeDB => write!(self.fmt, "{}", value.to_rfc3339().unwrap()), + PsqlClient::Postgres => write!(self.fmt, "\"{}\"", value.to_rfc3339().unwrap()), + } + } + + fn write_duration(&mut self, value: TimeDuration) -> Result<(), Self::Error> { + match self.ty.client { + PsqlClient::SpacetimeDB => write!(self.fmt, "{value}"), + PsqlClient::Postgres => write!(self.fmt, "\"{}\"", value.to_iso8601()), + } + } + + fn write_record( + &mut self, + fields: Vec<(Cow, PsqlType<'_>, ValueWithType)>, + ) -> Result<(), Self::Error> { + let PsqlChars { start, sep, end, quote } = self.ty.client.format_chars(); + write!(self.fmt, "{start}")?; + for (idx, (name, ty, value)) in fields.into_iter().enumerate() { + if idx > 0 { + write!(self.fmt, ", ")?; + } + write!(self.fmt, "{quote}{name}{quote}{sep} ")?; + + // Serialize the value + value.serialize(TypedSerializer { ty: &ty, f: self })?; + } + write!(self.fmt, "{end}")?; + Ok(()) + } + + fn write_variant( + &mut self, + tag: u8, + ty: PsqlType, + name: Option<&str>, + value: ValueWithType, + ) -> Result<(), Self::Error> { + self.write_record(vec![( + name.map(Cow::from).unwrap_or_else(|| Cow::from(format!("col_{tag}"))), + ty, + value, + )]) } } diff --git a/crates/sats/src/ser.rs b/crates/sats/src/ser.rs index a9fd0fe01..4d56fc5aa 100644 --- a/crates/sats/src/ser.rs +++ b/crates/sats/src/ser.rs @@ -6,7 +6,7 @@ mod impls; pub mod serde; use crate::de::DeserializeSeed; -use crate::{algebraic_value::ser::ValueSerializer, bsatn, buffer::BufWriter}; +use crate::{algebraic_value::ser::ValueSerializer, bsatn, buffer::BufWriter, ProductValue, SumValue, ValueWithType}; use crate::{AlgebraicValue, WithTypespace}; use core::marker::PhantomData; use core::{convert::Infallible, fmt}; @@ -117,6 +117,31 @@ pub trait Serializer: Sized { /// The argument is the number of fields in the product. fn serialize_named_product(self, len: usize) -> Result; + /// Serialize a product with named fields. + /// + /// Allow to override the default serialization for where we need to switch the output format, + /// see [`crate::satn::TypedWriter`]. + fn serialize_named_product_raw(self, value: &ValueWithType<'_, ProductValue>) -> Result { + let val = &value.val.elements; + assert_eq!(val.len(), value.ty().elements.len()); + let mut prod = self.serialize_named_product(val.len())?; + for (val, el_ty) in val.iter().zip(&*value.ty().elements) { + prod.serialize_element(el_ty.name(), &value.with(&el_ty.algebraic_type, val))? + } + prod.end() + } + + /// Serialize a sum value + /// + /// Allow to override the default serialization for where we need to switch the output format, + /// see [`crate::satn::TypedWriter`]. + fn serialize_variant_raw(self, sum: &ValueWithType<'_, SumValue>) -> Result { + let sv = sum.value(); + let (tag, val) = (sv.tag, &*sv.value); + let var_ty = &sum.ty().variants[tag as usize]; // Extract the variant type by tag. + self.serialize_variant(tag, var_ty.name(), &sum.with(&var_ty.algebraic_type, val)) + } + /// Serialize a sum value provided the chosen `tag`, `name`, and `value`. fn serialize_variant( self, diff --git a/crates/sats/src/ser/impls.rs b/crates/sats/src/ser/impls.rs index 914099e14..39d08d38d 100644 --- a/crates/sats/src/ser/impls.rs +++ b/crates/sats/src/ser/impls.rs @@ -1,4 +1,4 @@ -use super::{Serialize, SerializeArray, SerializeNamedProduct, SerializeSeqProduct, Serializer}; +use super::{Serialize, SerializeArray, SerializeSeqProduct, Serializer}; use crate::{i256, u256}; use crate::{AlgebraicType, AlgebraicValue, ArrayValue, ProductValue, SumValue, ValueWithType, F32, F64}; use core::ops::Bound; @@ -190,19 +190,10 @@ impl_serialize!( } ); impl_serialize!([] ValueWithType<'_, SumValue>, (self, ser) => { - let sv = self.value(); - let (tag, val) = (sv.tag, &*sv.value); - let var_ty = &self.ty().variants[tag as usize]; // Extract the variant type by tag. - ser.serialize_variant(tag, var_ty.name(), &self.with(&var_ty.algebraic_type, val)) + ser.serialize_variant_raw(self) }); impl_serialize!([] ValueWithType<'_, ProductValue>, (self, ser) => { - let val = &self.value().elements; - assert_eq!(val.len(), self.ty().elements.len()); - let mut prod = ser.serialize_named_product(val.len())?; - for (val, el_ty) in val.iter().zip(&*self.ty().elements) { - prod.serialize_element(el_ty.name(), &self.with(&el_ty.algebraic_type, val))? - } - prod.end() + ser.serialize_named_product_raw(self) }); impl_serialize!([] ValueWithType<'_, ArrayValue>, (self, ser) => { let mut ty = &*self.ty().elem_ty; diff --git a/crates/sats/src/time_duration.rs b/crates/sats/src/time_duration.rs index bd8d4543f..f6d932c4b 100644 --- a/crates/sats/src/time_duration.rs +++ b/crates/sats/src/time_duration.rs @@ -82,6 +82,22 @@ impl TimeDuration { pub fn checked_sub(self, other: Self) -> Option { self.to_micros().checked_sub(other.to_micros()).map(Self::from_micros) } + + /// Generate an `iso8601` format string. + /// + /// This is the better supported format for use for the `pg wire protocol`. + /// + /// Example: + /// ```rust + /// use std::time::Duration; + /// use spacetimedb_sats::time_duration::TimeDuration; + /// assert_eq!( TimeDuration::from_micros(0).to_iso8601().as_str(), "P0D"); + /// assert_eq!( TimeDuration::from_micros(-1_000_000).to_iso8601().as_str(), "-PT1S"); + /// assert_eq!( TimeDuration::from_duration(Duration::from_secs(60 * 24)).to_iso8601().as_str(), "PT1440S"); + /// ``` + pub fn to_iso8601(self) -> String { + chrono::Duration::microseconds(self.to_micros()).to_string() + } } impl From for TimeDuration { diff --git a/crates/sats/src/timestamp.rs b/crates/sats/src/timestamp.rs index 5da2e7396..50affcf60 100644 --- a/crates/sats/src/timestamp.rs +++ b/crates/sats/src/timestamp.rs @@ -171,13 +171,17 @@ impl Timestamp { pub fn checked_sub_duration(&self, duration: Duration) -> Option { self.checked_sub(TimeDuration::from_duration(duration)) } - /// Returns an RFC 3339 and ISO 8601 date and time string such as `1996-12-19T16:39:57-08:00`. - pub fn to_rfc3339(&self) -> anyhow::Result { + + pub fn to_chrono_date_time(&self) -> anyhow::Result> { DateTime::from_timestamp_micros(self.to_micros_since_unix_epoch()) - .map(|t| t.to_rfc3339()) .ok_or_else(|| anyhow::anyhow!("Timestamp with i64 microseconds since Unix epoch overflows DateTime")) .with_context(|| self.to_micros_since_unix_epoch()) } + + /// Returns an RFC 3339 and ISO 8601 date and time string such as `1996-12-19T16:39:57-08:00`. + pub fn to_rfc3339(&self) -> anyhow::Result { + Ok(self.to_chrono_date_time()?.to_rfc3339()) + } } impl Add for Timestamp { diff --git a/crates/standalone/Cargo.toml b/crates/standalone/Cargo.toml index 625f75145..e71972172 100644 --- a/crates/standalone/Cargo.toml +++ b/crates/standalone/Cargo.toml @@ -27,6 +27,7 @@ spacetimedb-core.workspace = true spacetimedb-datastore.workspace = true spacetimedb-lib.workspace = true spacetimedb-paths.workspace = true +spacetimedb-pg.workspace = true spacetimedb-table.workspace = true spacetimedb-schema.workspace = true diff --git a/crates/standalone/src/lib.rs b/crates/standalone/src/lib.rs index 9ebb9c230..860dddd68 100644 --- a/crates/standalone/src/lib.rs +++ b/crates/standalone/src/lib.rs @@ -186,6 +186,7 @@ impl spacetimedb_client_api::ControlStateReadAccess for StandaloneEnv { id: 0, unschedulable: false, advertise_addr: Some("node:80".to_owned()), + pg_addr: Some("node:5432".to_owned()), })); } Ok(None) diff --git a/crates/standalone/src/subcommands/start.rs b/crates/standalone/src/subcommands/start.rs index 10a528300..5615945e5 100644 --- a/crates/standalone/src/subcommands/start.rs +++ b/crates/standalone/src/subcommands/start.rs @@ -1,3 +1,4 @@ +use spacetimedb_pg::pg_server; use std::sync::Arc; use crate::{StandaloneEnv, StandaloneOptions}; @@ -176,12 +177,27 @@ pub async fn exec(args: &ArgMatches, db_cores: JobCores) -> anyhow::Result<()> { db_routes.root_post = db_routes.root_post.layer(DefaultBodyLimit::disable()); db_routes.db_put = db_routes.db_put.layer(DefaultBodyLimit::disable()); let extra = axum::Router::new().nest("/health", spacetimedb_client_api::routes::health::router()); - let service = router(&ctx, db_routes, extra).with_state(ctx); + let service = router(&ctx, db_routes, extra).with_state(ctx.clone()); let tcp = TcpListener::bind(listen_addr).await?; socket2::SockRef::from(&tcp).set_nodelay(true)?; - log::debug!("Starting SpacetimeDB listening on {}", tcp.local_addr().unwrap()); - axum::serve(tcp, service).await?; + log::debug!("Starting SpacetimeDB listening on {}", tcp.local_addr()?); + let pg_server_addr = format!("{}:5432", listen_addr.split(':').next().unwrap()); + let tcp_pg = TcpListener::bind(pg_server_addr).await?; + + let notify = Arc::new(tokio::sync::Notify::new()); + let shutdown_notify = notify.clone(); + tokio::select! { + _ = pg_server::start_pg(notify.clone(), ctx, tcp_pg) => {}, + _ = axum::serve(tcp, service).with_graceful_shutdown(async move { + shutdown_notify.notified().await; + }) => {}, + _ = tokio::signal::ctrl_c() => { + println!("Shutting down servers..."); + notify.notify_waiters(); // Notify all tasks + } + } + Ok(()) } diff --git a/docker-compose.yml b/docker-compose.yml index 8e2dd39fb..a31ed13d8 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -25,6 +25,8 @@ services: - /stdb ports: - "3000:3000" + # Postgres + - "5432:5432" # Tracy - "8086:8086" entrypoint: cargo watch -i flamegraphs -i log.conf --why -C crates/standalone -x 'run start --data-dir=/stdb/data --jwt-pub-key-path=/etc/spacetimedb/id_ecdsa.pub --jwt-priv-key-path=/etc/spacetimedb/id_ecdsa' diff --git a/smoketests/tests/pg_wire.py b/smoketests/tests/pg_wire.py new file mode 100644 index 000000000..89c7880c3 --- /dev/null +++ b/smoketests/tests/pg_wire.py @@ -0,0 +1,293 @@ +from .. import Smoketest +import subprocess +import os +import tomllib +import psycopg2 + + +def psql(identity: str, sql: str, extra=None) -> str: + """Call `psql` and execute the given SQL statement.""" + if extra is None: + extra = dict() + result = subprocess.run( + ["psql", "-h", "127.0.0.1", "-p", "5432", "-U", "postgres", "-d", "quickstart", "--quiet", "-c", sql], + encoding="utf8", + env={**os.environ, **extra, "PGPASSWORD": identity}, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + if result.stderr: + raise Exception(result.stderr.strip()) + return result.stdout.strip() + + +def connect_db(identity: str): + """Connect to the database using `psycopg2`.""" + conn = psycopg2.connect(host="127.0.0.1", port=5432, user="postgres", password=identity, dbname="quickstart") + conn.set_session(autocommit=True) # Disable automic transaction + return conn + + +class SqlFormat(Smoketest): + AUTOPUBLISH = False + MODULE_CODE = """ +use spacetimedb::sats::{i256, u256}; +use spacetimedb::{ConnectionId, Identity, ReducerContext, SpacetimeType, Table, Timestamp, TimeDuration}; + +#[derive(Copy, Clone)] +#[spacetimedb::table(name = t_ints, public)] +pub struct TInts { + i8: i8, + i16: i16, + i32: i32, + i64: i64, + i128: i128, + i256: i256, +} + +#[spacetimedb::table(name = t_ints_tuple, public)] +pub struct TIntsTuple { + tuple: TInts, +} + +#[derive(Copy, Clone)] +#[spacetimedb::table(name = t_uints, public)] +pub struct TUints { + u8: u8, + u16: u16, + u32: u32, + u64: u64, + u128: u128, + u256: u256, +} + +#[spacetimedb::table(name = t_uints_tuple, public)] +pub struct TUintsTuple { + tuple: TUints, +} + +#[derive(Clone)] +#[spacetimedb::table(name = t_others, public)] +pub struct TOthers { + bool: bool, + f32: f32, + f64: f64, + str: String, + bytes: Vec, + identity: Identity, + connection_id: ConnectionId, + timestamp: Timestamp, + duration: TimeDuration, +} + +#[spacetimedb::table(name = t_others_tuple, public)] +pub struct TOthersTuple { + tuple: TOthers +} + +#[derive(SpacetimeType, Debug, Clone, Copy)] +pub enum Action { + Inactive, + Active, +} + +#[derive(SpacetimeType, Debug, Clone, Copy)] +pub enum Color { + Gray(u8), +} + +#[derive(Copy, Clone)] +#[spacetimedb::table(name = t_simple_enum, public)] +pub struct TSimpleEnum { + id : u32, + action: Action, +} + +#[spacetimedb::table(name = t_enum, public)] +pub struct TEnum { + id : u32, + color: Color, +} + +#[spacetimedb::table(name = t_nested, public)] +pub struct TNested { + en: TEnum, + se: TSimpleEnum, + ints: TInts, +} + +#[spacetimedb::reducer] +pub fn test(ctx: &ReducerContext) { + let tuple = TInts { + i8: -25, + i16: -3224, + i32: -23443, + i64: -2344353, + i128: -234434897853, + i256: (-234434897853i128).into(), + }; + let ints = tuple; + ctx.db.t_ints().insert(tuple); + ctx.db.t_ints_tuple().insert(TIntsTuple { tuple }); + + let tuple = TUints { + u8: 105, + u16: 1050, + u32: 83892, + u64: 48937498, + u128: 4378528978889, + u256: 4378528978889u128.into(), + }; + ctx.db.t_uints().insert(tuple); + ctx.db.t_uints_tuple().insert(TUintsTuple { tuple }); + + let tuple = TOthers { + bool: true, + f32: 594806.58906, + f64: -3454353.345389043278459, + str: "This is spacetimedb".to_string(), + bytes: vec!(1, 2, 3, 4, 5, 6, 7), + identity: Identity::ONE, + connection_id: ConnectionId::ZERO, + timestamp: Timestamp::UNIX_EPOCH, + duration: TimeDuration::from_micros(1000 * 10000), + }; + ctx.db.t_others().insert(tuple.clone()); + ctx.db.t_others_tuple().insert(TOthersTuple { tuple }); + + ctx.db.t_simple_enum().insert(TSimpleEnum { id: 1, action: Action::Inactive }); + ctx.db.t_simple_enum().insert(TSimpleEnum { id: 2, action: Action::Active }); + + ctx.db.t_enum().insert(TEnum { id: 1, color: Color::Gray(128) }); + + ctx.db.t_nested().insert(TNested { + en: TEnum { id: 1, color: Color::Gray(128) }, + se: TSimpleEnum { id: 2, action: Action::Active }, + ints, + }); +} +""" + + def assertSql(self, token: str, sql: str, expected): + self.maxDiff = None + sql_out = psql(token, sql) + sql_out = "\n".join([line.rstrip() for line in sql_out.splitlines()]) + expected = "\n".join([line.rstrip() for line in expected.splitlines()]) + print(sql_out) + self.assertMultiLineEqual(sql_out, expected) + + def read_token(self): + """Read the token from the config file.""" + with open(self.config_path, "rb") as f: + config = tomllib.load(f) + return config['spacetimedb_token'] + + def test_sql_format(self): + """This test is designed to test calling `psql` to execute SQL statements""" + token = self.read_token() + self.publish_module("quickstart", clear=True) + + self.call("test") + + self.assertSql(token, "SELECT * FROM t_ints", """\ +i8 | i16 | i32 | i64 | i128 | i256 +-----+-------+--------+----------+---------------+--------------- + -25 | -3224 | -23443 | -2344353 | -234434897853 | -234434897853 +(1 row)""") + self.assertSql(token, "SELECT * FROM t_ints_tuple", """\ +tuple +--------------------------------------------------------------------------------------------------------- + {"i8": -25, "i16": -3224, "i32": -23443, "i64": -2344353, "i128": -234434897853, "i256": -234434897853} +(1 row)""") + self.assertSql(token, "SELECT * FROM t_uints", """\ +u8 | u16 | u32 | u64 | u128 | u256 +-----+------+-------+----------+---------------+--------------- + 105 | 1050 | 83892 | 48937498 | 4378528978889 | 4378528978889 +(1 row)""") + self.assertSql(token, "SELECT * FROM t_uints_tuple", """\ +tuple +------------------------------------------------------------------------------------------------------- + {"u8": 105, "u16": 1050, "u32": 83892, "u64": 48937498, "u128": 4378528978889, "u256": 4378528978889} +(1 row)""") + self.assertSql(token, "SELECT * FROM t_others", """\ +bool | f32 | f64 | str | bytes | identity | connection_id | timestamp | duration +------+-----------+---------------------+---------------------+------------------+--------------------------------------------------------------------+------------------------------------+---------------------------+---------- + t | 594806.56 | -3454353.3453890434 | This is spacetimedb | \\x01020304050607 | \\x0000000000000000000000000000000000000000000000000000000000000001 | \\x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | PT10S +(1 row)""") + self.assertSql(token, "SELECT * FROM t_others_tuple", """\ +tuple +--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + {"bool": true, "f32": 594806.56, "f64": -3454353.3453890434, "str": "This is spacetimedb", "bytes": "0x01020304050607", "identity": "0x0000000000000000000000000000000000000000000000000000000000000001", "connection_id": "0x00000000000000000000000000000000", "timestamp": "1970-01-01T00:00:00+00:00", "duration": "PT10S"} +(1 row)""") + self.assertSql(token, "SELECT * FROM t_simple_enum", """\ +id | action +----+---------- + 1 | Inactive + 2 | Active +(2 rows)""") + self.assertSql(token, "SELECT * FROM t_enum", """\ +id | color +----+--------------- + 1 | {"Gray": 128} +(1 row)""") + self.assertSql(token, "SELECT * FROM t_nested", """\ +en | se | ints +-----------------------------------+-------------------------------------+--------------------------------------------------------------------------------------------------------- + {"id": 1, "color": {"Gray": 128}} | {"id": 2, "action": {"Active": {}}} | {"i8": -25, "i16": -3224, "i32": -23443, "i64": -2344353, "i128": -234434897853, "i256": -234434897853} +(1 row)""") + + def test_sql_conn(self): + """This test is designed to test connecting to the database and executing queries using `psycopg2`""" + token = self.read_token() + self.publish_module("quickstart", clear=True) + self.call("test") + + conn = connect_db(token) + # Check prepared statements (faked by `psycopg2`) + with conn.cursor() as cur: + cur.execute("select * from t_uints where u8 = %s and u16 = %s", (105, 1050)) + rows = cur.fetchall() + self.assertEqual(rows[0], (105, 1050, 83892, 48937498, 4378528978889, 4378528978889)) + # Check long-lived connection + with conn.cursor() as cur: + for _ in range(10): + cur.execute("select count(*) as t from t_uints") + rows = cur.fetchall() + self.assertEqual(rows[0], (1,)) + conn.close() + + def test_failures(self): + """This test is designed to test failure cases""" + token = self.read_token() + self.publish_module("quickstart", clear=True) + + # Empty query + sql_out = psql(token, "") + self.assertEqual(sql_out, "") + + # Connection fails when `ssl` is required + for ssl_mode in ["require", "verify-ca", "verify-full"]: + with self.assertRaises(Exception) as cm: + psql(token, "SELECT * FROM t_uints", extra={"PGSSLMODE": ssl_mode}) + self.assertIn("not support SSL", str(cm.exception)) + + # But works with `ssl` is disabled or optional + for ssl_mode in ["disable", "allow", "prefer"]: + psql(token, "SELECT * FROM t_uints", extra={"PGSSLMODE": ssl_mode}) + + # Connection fails with invalid token + with self.assertRaises(Exception) as cm: + psql("invalid_token", "SELECT * FROM t_uints") + self.assertIn("Invalid token", str(cm.exception)) + + # Returns error for unsupported `sql` statements + with self.assertRaises(Exception) as cm: + psql(token, "SELECT CASE a WHEN 1 THEN 'one' ELSE 'other' END FROM t_uints") + self.assertIn("Unsupported", str(cm.exception)) + + # And prepared statements + with self.assertRaises(Exception) as cm: + psql(token, "SELECT * FROM t_uints where u8 = $1") + self.assertIn("Unsupported", str(cm.exception))