Support for the PG wire protocol (#2702)

# Description of Changes

Closes
[#2686](https://github.com/clockworklabs/SpacetimeDB/issues/2686).

Add support for listening using the [PG wire
protocol](https://www.postgresql.org/docs/current/protocol.html) so `pg`
clients could be used against the database.

# API and ABI breaking changes

The output of `duration` is changed to `rfc3339`, instead of the way is
made with `sats` because is what is done in `pg`, see note below.

# Expected complexity level and risk

2

~~There is open questions that are in the [ticket
#2686](https://github.com/clockworklabs/SpacetimeDB/issues/2686). Also
the crate used here require `RustTls`, so it could be good idea to
decide if~~:

* ~~Rewrite a big chunk of code to use `OpenSSL`~~
* ~~Move to `RustTls`
https://github.com/clockworklabs/SpacetimeDB/pull/1700~~
* ~~Pay for the extra compilation cost~~.

I open another port(`5433`) to listen for `pg` connections using `ssl`.
Need to be decided if this is the way or instead try to multi-plex the
current port for both protocols.

# Testing

Only manual testing so far. Solving the above questions allow me to
implement some unit tests. Also, not yet integrated into cloud for the
same reasons.

- [x] Adding some test for the binary encoding of special and primitive
types
- [x] Smoke test using `psql` that connect to the db instance and run
some queries
- [x] Manually inspect using a UI database explorer how infer the types,
some of this tools generate special widgets when displaying `json,
duration, etc`

---------

Co-authored-by: Noa <coolreader18@gmail.com>
This commit is contained in:
Mario Montoya
2025-09-10 14:58:03 -05:00
committed by GitHub
parent 2c74f73550
commit 8adef2b93b
26 changed files with 1790 additions and 291 deletions
+8 -2
View File
@@ -21,7 +21,7 @@ jobs:
include:
- { runner: spacetimedb-runner, smoketest_args: --docker }
- { runner: windows-latest, smoketest_args: --no-build-cli }
runner: [spacetimedb-runner, windows-latest]
runner: [ spacetimedb-runner, windows-latest ]
runs-on: ${{ matrix.runner }}
steps:
- name: Find Git ref
@@ -44,6 +44,10 @@ jobs:
- uses: actions/setup-dotnet@v4
with:
global-json-file: modules/global.json
- name: Install psql (Windows)
if: runner.os == 'Windows'
run: choco install psql -y --no-progress
shell: powershell
- name: Build and start database (Linux)
if: runner.os == 'Linux'
run: docker compose up -d
@@ -54,11 +58,13 @@ jobs:
Start-Process target/debug/spacetimedb-cli.exe start
cd modules
# the sdk-manifests on windows-latest are messed up, so we need to update them
dotnet workload config --update-mode workload-set
dotnet workload config --update-mode manifests
dotnet workload update
- uses: actions/setup-python@v5
with: { python-version: '3.12' }
if: runner.os == 'Windows'
- name: Install psycopg2
run: python -m pip install psycopg2-binary
- name: Run smoketests
# Note: clear_database and replication only work in private
run: python -m smoketests ${{ matrix.smoketest_args }} -x clear_database replication
Generated
+187 -7
View File
@@ -195,6 +195,12 @@ dependencies = [
"derive_arbitrary",
]
[[package]]
name = "array-init"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d62b7694a562cdf5a74227903507c56ab2cc8bdd1f781ed5cb4cf9c9f810bfc"
[[package]]
name = "arrayref"
version = "0.3.9"
@@ -293,6 +299,30 @@ version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
[[package]]
name = "aws-lc-rs"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19b756939cb2f8dc900aa6dcd505e6e2428e9cae7ff7b028c49e3946efa70878"
dependencies = [
"aws-lc-sys",
"untrusted 0.7.1",
"zeroize",
]
[[package]]
name = "aws-lc-sys"
version = "0.28.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa9b6986f250236c27e5a204062434a773a13243d2ffc2955f37bdba4c5c6a1"
dependencies = [
"bindgen 0.69.5",
"cc",
"cmake",
"dunce",
"fs_extra",
]
[[package]]
name = "axum"
version = "0.7.9"
@@ -429,6 +459,29 @@ dependencies = [
"serde",
]
[[package]]
name = "bindgen"
version = "0.69.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
dependencies = [
"bitflags 2.9.0",
"cexpr",
"clang-sys",
"itertools 0.12.1",
"lazy_static",
"lazycell",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash 1.1.0",
"shlex",
"syn 2.0.101",
"which 4.4.2",
]
[[package]]
name = "bindgen"
version = "0.71.1"
@@ -444,7 +497,7 @@ dependencies = [
"proc-macro2",
"quote",
"regex",
"rustc-hash",
"rustc-hash 2.1.1",
"shlex",
"syn 2.0.101",
]
@@ -925,6 +978,15 @@ dependencies = [
"winapi",
]
[[package]]
name = "cmake"
version = "0.1.54"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
dependencies = [
"cc",
]
[[package]]
name = "cobs"
version = "0.2.3"
@@ -1092,7 +1154,7 @@ dependencies = [
"hashbrown 0.14.5",
"log",
"regalloc2",
"rustc-hash",
"rustc-hash 2.1.1",
"smallvec",
"target-lexicon",
]
@@ -1485,6 +1547,17 @@ dependencies = [
"serde",
]
[[package]]
name = "derive-new"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.101",
]
[[package]]
name = "derive_arbitrary"
version = "1.4.1"
@@ -1602,6 +1675,12 @@ dependencies = [
"shared_child",
]
[[package]]
name = "dunce"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
[[package]]
name = "educe"
version = "0.4.23"
@@ -3011,12 +3090,41 @@ dependencies = [
"spacetimedb 1.4.0",
]
[[package]]
name = "lazy-regex"
version = "3.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60c7310b93682b36b98fa7ea4de998d3463ccbebd94d935d6b48ba5b6ffa7126"
dependencies = [
"lazy-regex-proc_macros",
"once_cell",
"regex-lite",
]
[[package]]
name = "lazy-regex-proc_macros"
version = "3.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ba01db5ef81e17eb10a5e0f2109d1b3a3e29bac3070fdbd7d156bf7dbd206a1"
dependencies = [
"proc-macro2",
"quote",
"regex",
"syn 2.0.101",
]
[[package]]
name = "lazy_static"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]]
name = "leb128"
version = "0.2.5"
@@ -3215,6 +3323,12 @@ dependencies = [
"digest",
]
[[package]]
name = "md5"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae960838283323069879657ca3de837e9f7bbb4c7bf6ea7f1b290d5e9476d2e0"
[[package]]
name = "memchr"
version = "2.7.4"
@@ -3805,6 +3919,31 @@ dependencies = [
"postgres-types",
]
[[package]]
name = "pgwire"
version = "0.32.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ddf403a6ee31cf7f2217b2bd8447cb13dbb6c268d7e81501bc78a4d3daafd294"
dependencies = [
"async-trait",
"aws-lc-rs",
"bytes",
"chrono",
"derive-new",
"futures",
"hex",
"lazy-regex",
"md5",
"postgres-types",
"rand 0.9.1",
"rust_decimal",
"rustls-pki-types",
"thiserror 2.0.12",
"tokio",
"tokio-rustls",
"tokio-util",
]
[[package]]
name = "phf"
version = "0.11.3"
@@ -3956,6 +4095,7 @@ version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613283563cd90e1dfc3518d548caee47e0e725455ed619881f5cf21f36de4b48"
dependencies = [
"array-init",
"bytes",
"chrono",
"fallible-iterator 0.2.0",
@@ -4445,7 +4585,7 @@ checksum = "12908dbeb234370af84d0579b9f68258a0f67e201412dd9a2814e6f45b2fc0f0"
dependencies = [
"hashbrown 0.14.5",
"log",
"rustc-hash",
"rustc-hash 2.1.1",
"slice-group-by",
"smallvec",
]
@@ -4482,6 +4622,12 @@ dependencies = [
"regex-syntax 0.8.5",
]
[[package]]
name = "regex-lite"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a"
[[package]]
name = "regex-syntax"
version = "0.6.29"
@@ -4623,7 +4769,7 @@ dependencies = [
"cfg-if",
"getrandom 0.2.16",
"libc",
"untrusted",
"untrusted 0.9.0",
"windows-sys 0.52.0",
]
@@ -4734,6 +4880,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustc-hash"
version = "2.1.1"
@@ -4781,6 +4933,8 @@ version = "0.23.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321"
dependencies = [
"aws-lc-rs",
"log",
"once_cell",
"rustls-pki-types",
"rustls-webpki",
@@ -4821,9 +4975,10 @@ version = "0.103.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7149975849f1abb3832b246010ef62ccc80d3a76169517ada7188252b9cfb437"
dependencies = [
"aws-lc-rs",
"ring",
"rustls-pki-types",
"untrusted",
"untrusted 0.9.0",
]
[[package]]
@@ -5674,7 +5829,7 @@ dependencies = [
"regex",
"reqwest 0.12.15",
"rustc-demangle",
"rustc-hash",
"rustc-hash 2.1.1",
"scopeguard",
"semver",
"serde",
@@ -5960,6 +6115,24 @@ dependencies = [
"xdg",
]
[[package]]
name = "spacetimedb-pg"
version = "1.4.0"
dependencies = [
"anyhow",
"async-trait",
"axum",
"futures",
"http 1.3.1",
"log",
"pgwire",
"spacetimedb-client-api",
"spacetimedb-client-api-messages",
"spacetimedb-lib 1.4.0",
"thiserror 1.0.69",
"tokio",
]
[[package]]
name = "spacetimedb-physical-plan"
version = "1.4.0"
@@ -6207,6 +6380,7 @@ dependencies = [
"spacetimedb-datastore",
"spacetimedb-lib 1.4.0",
"spacetimedb-paths",
"spacetimedb-pg",
"spacetimedb-schema",
"spacetimedb-table",
"tempfile",
@@ -7397,6 +7571,12 @@ version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]]
name = "untrusted"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
[[package]]
name = "untrusted"
version = "0.9.0"
@@ -7477,7 +7657,7 @@ version = "137.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ca393e2032ddba2a57169e15cac5d0a81cdb3d872a8886f4468bc0f486098d2"
dependencies = [
"bindgen",
"bindgen 0.71.1",
"bitflags 2.9.0",
"fslock",
"gzip-header",
+3
View File
@@ -19,6 +19,7 @@ members = [
"crates/lib",
"crates/metrics",
"crates/paths",
"crates/pg",
"crates/physical-plan",
"crates/primitives",
"crates/query",
@@ -115,6 +116,7 @@ spacetimedb-lib = { path = "crates/lib", default-features = false, version = "1.
spacetimedb-memory-usage = { path = "crates/memory-usage", version = "1.4.0", default-features = false }
spacetimedb-metrics = { path = "crates/metrics", version = "1.4.0" }
spacetimedb-paths = { path = "crates/paths", version = "1.4.0" }
spacetimedb-pg = { path = "crates/pg", version = "1.4.0" }
spacetimedb-physical-plan = { path = "crates/physical-plan", version = "1.4.0" }
spacetimedb-primitives = { path = "crates/primitives", version = "1.4.0" }
spacetimedb-query = { path = "crates/query", version = "1.4.0" }
@@ -214,6 +216,7 @@ paste = "1.0"
percent-encoding = "2.3"
petgraph = { version = "0.6.5", default-features = false }
pin-project-lite = "0.2.9"
pgwire = { version = "0.32", features = ["server-api"] }
postgres-types = "0.2.5"
pretty_assertions = { version = "1.4", features = ["unstable"] }
proc-macro2 = "1.0"
+57 -7
View File
@@ -10,6 +10,7 @@ use anyhow::Context;
use clap::{Arg, ArgAction, ArgMatches};
use reqwest::RequestBuilder;
use spacetimedb_lib::de::serde::SeedWrapper;
use spacetimedb_lib::sats::satn::PsqlClient;
use spacetimedb_lib::sats::{satn, ProductType, ProductValue, Typespace};
pub fn cli() -> clap::Command {
@@ -111,7 +112,7 @@ fn print_stmt_result(
for (pos, result) in if_empty
.into_iter()
.chain(stmt_results.iter().map(|stmt_result| {
let (stats, table) = stmt_result_to_table(stmt_result)?;
let (stats, table) = stmt_result_to_table(PsqlClient::SpacetimeDB, stmt_result)?;
anyhow::Ok(StmtResult {
stats: with_stats.is_some().then_some(stats),
@@ -157,12 +158,13 @@ pub(crate) async fn run_sql(builder: RequestBuilder, sql: &str, with_stats: bool
Ok(())
}
fn stmt_result_to_table(stmt_result: &SqlStmtResult) -> anyhow::Result<(StmtStats, tabled::Table)> {
fn stmt_result_to_table(client: PsqlClient, stmt_result: &SqlStmtResult) -> anyhow::Result<(StmtStats, tabled::Table)> {
let stats = StmtStats::from(stmt_result);
let SqlStmtResult { schema, rows, .. } = stmt_result;
let ty = Typespace::EMPTY.with_type(schema);
let table = build_table(
client,
schema,
rows.iter().map(|row| from_json_seed(row.get(), SeedWrapper(ty))),
)?;
@@ -194,6 +196,7 @@ pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error
/// Generates a [`tabled::Table`] from a schema and rows, using the style of a psql table.
fn build_table<E>(
client: PsqlClient,
schema: &ProductType,
rows: impl Iterator<Item = Result<ProductValue, E>>,
) -> Result<tabled::Table, E> {
@@ -211,6 +214,7 @@ fn build_table<E>(
let row = row?;
builder.push_record(ty.with_values(&row).enumerate().map(|(idx, value)| {
let ty = satn::PsqlType {
client,
tuple: ty.ty(),
field: &ty.ty().elements[idx],
idx,
@@ -446,8 +450,10 @@ Roundtrip time: 1.00ms"#,
Ok(())
}
fn expect_psql_table(ty: &ProductType, rows: Vec<ProductValue>, expected: &str) {
let table = build_table(ty, rows.into_iter().map(Ok::<_, ()>)).unwrap().to_string();
fn expect_psql_table(client: PsqlClient, ty: &ProductType, rows: Vec<ProductValue>, expected: &str) {
let table = build_table(client, ty, rows.into_iter().map(Ok::<_, ()>))
.unwrap()
.to_string();
let mut table = table.split('\n').map(|x| x.trim_end()).join("\n");
table.insert(0, '\n');
assert_eq!(expected, table);
@@ -476,14 +482,25 @@ Roundtrip time: 1.00ms"#,
];
expect_psql_table(
PsqlClient::SpacetimeDB,
&kind,
vec![value],
vec![value.clone()],
r#"
column 0 | column 1 | column 2 | column 3 | column 4 | column 5
----------+----------+--------------------------------------------------------------------+------------------------------------+---------------------------+-----------
"a" | 0 | 0x0000000000000000000000000000000000000000000000000000000000000000 | 0x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | +0.000000"#,
);
expect_psql_table(
PsqlClient::Postgres,
&kind,
vec![value],
r#"
column 0 | column 1 | column 2 | column 3 | column 4 | column 5
----------+----------+----------------------------------------------------------------------+--------------------------------------+-----------------------------+----------
"a" | 0 | "0x0000000000000000000000000000000000000000000000000000000000000000" | "0x00000000000000000000000000000000" | "1970-01-01T00:00:00+00:00" | "P0D""#,
);
// Check struct
let kind: ProductType = [
("bool", AlgebraicType::Bool),
@@ -507,6 +524,7 @@ Roundtrip time: 1.00ms"#,
];
expect_psql_table(
PsqlClient::SpacetimeDB,
&kind,
vec![value.clone()],
r#"
@@ -515,12 +533,23 @@ Roundtrip time: 1.00ms"#,
true | "This is spacetimedb" | 0x01020304050607 | 0x0000000000000000000000000000000000000000000000000000000000000000 | 0x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | +0.000000"#,
);
expect_psql_table(
PsqlClient::Postgres,
&kind,
vec![value.clone()],
r#"
bool | str | bytes | identity | connection_id | timestamp | duration
------+-----------------------+--------------------+----------------------------------------------------------------------+--------------------------------------+-----------------------------+----------
true | "This is spacetimedb" | "0x01020304050607" | "0x0000000000000000000000000000000000000000000000000000000000000000" | "0x00000000000000000000000000000000" | "1970-01-01T00:00:00+00:00" | "P0D""#,
);
// Check nested struct, tuple...
let kind: ProductType = [(None, AlgebraicType::product(kind))].into();
let value = product![value.clone()];
expect_psql_table(
PsqlClient::SpacetimeDB,
&kind,
vec![value.clone()],
r#"
@@ -529,17 +558,38 @@ Roundtrip time: 1.00ms"#,
(bool = true, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000000, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = +0.000000)"#,
);
expect_psql_table(
PsqlClient::Postgres,
&kind,
vec![value.clone()],
r#"
column 0
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
{"bool": true, "str": "This is spacetimedb", "bytes": "0x01020304050607", "identity": "0x0000000000000000000000000000000000000000000000000000000000000000", "connection_id": "0x00000000000000000000000000000000", "timestamp": "1970-01-01T00:00:00+00:00", "duration": "P0D"}"#,
);
let kind: ProductType = [("tuple", AlgebraicType::product(kind))].into();
let value = product![value];
expect_psql_table(
PsqlClient::SpacetimeDB,
&kind,
vec![value.clone()],
r#"
tuple
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(col_0 = (bool = true, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000000, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = +0.000000))"#,
);
expect_psql_table(
PsqlClient::Postgres,
&kind,
vec![value],
r#"
tuple
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(0 = (bool = true, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000000, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = +0.000000))"#,
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
{"col_0": {"bool": true, "str": "This is spacetimedb", "bytes": "0x01020304050607", "identity": "0x0000000000000000000000000000000000000000000000000000000000000000", "connection_id": "0x00000000000000000000000000000000", "timestamp": "1970-01-01T00:00:00+00:00", "duration": "P0D"}}"#,
);
Ok(())
+1 -1
View File
@@ -171,7 +171,7 @@ pub enum SetDefaultDomainResult {
///
/// Must match the regex `^[a-z0-9]+(-[a-z0-9]+)*$`
#[derive(Clone, Debug, serde_with::DeserializeFromStr, serde_with::SerializeDisplay)]
pub struct DatabaseName(String);
pub struct DatabaseName(pub String);
impl AsRef<str> for DatabaseName {
fn as_ref(&self) -> &str {
+12 -8
View File
@@ -222,6 +222,10 @@ impl<TV: TokenValidator + Send + Sync> TokenSigner for JwtKeyAuthProvider<TV> {
impl<TV: TokenValidator + Send + Sync> JwtAuthProvider for JwtKeyAuthProvider<TV> {
type TV = TV;
fn validator(&self) -> &Self::TV {
&self.validator
}
fn local_issuer(&self) -> &str {
&self.local_issuer
}
@@ -229,10 +233,6 @@ impl<TV: TokenValidator + Send + Sync> JwtAuthProvider for JwtKeyAuthProvider<TV
fn public_key_bytes(&self) -> &[u8] {
&self.keys.public_pem
}
fn validator(&self) -> &Self::TV {
&self.validator
}
}
#[cfg(test)]
@@ -260,6 +260,13 @@ mod tests {
}
}
pub async fn validate_token<S: NodeDelegate>(
state: &S,
token: &str,
) -> Result<SpacetimeIdentityClaims, TokenValidationError> {
state.jwt_auth_provider().validator().validate_token(token).await
}
pub struct SpacetimeAuthHeader {
auth: Option<SpacetimeAuth>,
}
@@ -272,10 +279,7 @@ impl<S: NodeDelegate + Send + Sync> axum::extract::FromRequestParts<S> for Space
return Ok(Self { auth: None });
};
let claims = state
.jwt_auth_provider()
.validator()
.validate_token(&creds.token)
let claims = validate_token(state, &creds.token)
.await
.map_err(AuthorizationRejection::Custom)?;
+30 -14
View File
@@ -7,7 +7,7 @@ use crate::auth::{
SpacetimeIdentityToken,
};
use crate::routes::subscribe::generate_random_connection_id;
use crate::util::{ByteStringBody, NameOrIdentity};
pub use crate::util::{ByteStringBody, NameOrIdentity};
use crate::{log_and_500, ControlStateDelegate, DatabaseDef, NodeDelegate};
use axum::body::{Body, Bytes};
use axum::extract::{Path, Query, State};
@@ -31,7 +31,7 @@ use spacetimedb_client_api_messages::name::{
};
use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9;
use spacetimedb_lib::identity::AuthCtx;
use spacetimedb_lib::{sats, Timestamp};
use spacetimedb_lib::{sats, ProductValue, Timestamp};
use spacetimedb_schema::auto_migrate::{
MigrationPolicy as SchemaMigrationPolicy, MigrationToken, PrettyPrintStyle as AutoMigratePrettyPrintStyle,
};
@@ -383,7 +383,7 @@ pub(crate) async fn worker_ctx_find_database(
#[derive(Deserialize)]
pub struct SqlParams {
name_or_identity: NameOrIdentity,
pub name_or_identity: NameOrIdentity,
}
#[derive(Deserialize)]
@@ -391,16 +391,16 @@ pub struct SqlQueryParams {
/// If `true`, return the query result only after its transaction offset
/// is confirmed to be durable.
#[serde(default)]
confirmed: bool,
pub confirmed: bool,
}
pub async fn sql<S>(
State(worker_ctx): State<S>,
Path(SqlParams { name_or_identity }): Path<SqlParams>,
Query(SqlQueryParams { confirmed }): Query<SqlQueryParams>,
Extension(auth): Extension<SpacetimeAuth>,
body: String,
) -> axum::response::Result<impl IntoResponse>
pub async fn sql_direct<S>(
worker_ctx: S,
SqlParams { name_or_identity }: SqlParams,
SqlQueryParams { confirmed }: SqlQueryParams,
caller_identity: Identity,
sql: String,
) -> axum::response::Result<Vec<SqlStmtResult<ProductValue>>>
where
S: NodeDelegate + ControlStateDelegate,
{
@@ -412,7 +412,7 @@ where
.await?
.ok_or(NO_SUCH_DATABASE)?;
let auth = AuthCtx::new(database.owner_identity, auth.identity);
let auth = AuthCtx::new(database.owner_identity, caller_identity);
log::debug!("auth: {auth:?}");
let host = worker_ctx
@@ -420,7 +420,21 @@ where
.await
.map_err(log_and_500)?
.ok_or(StatusCode::NOT_FOUND)?;
let json = host.exec_sql(auth, database, confirmed, body).await?;
host.exec_sql(auth, database, confirmed, sql).await
}
pub async fn sql<S>(
State(worker_ctx): State<S>,
Path(name_or_identity): Path<SqlParams>,
Query(params): Query<SqlQueryParams>,
Extension(auth): Extension<SpacetimeAuth>,
body: String,
) -> axum::response::Result<impl IntoResponse>
where
S: NodeDelegate + ControlStateDelegate,
{
let json = sql_direct(worker_ctx, name_or_identity, params, auth.identity, body).await?;
let total_duration = json.iter().fold(0, |acc, x| acc + x.total_duration_micros);
@@ -488,7 +502,9 @@ pub struct PublishDatabaseQueryParams {
policy: MigrationPolicy,
}
use spacetimedb_client_api_messages::http::SqlStmtResult;
use std::env;
fn require_spacetime_auth_for_creation() -> bool {
env::var("TEMP_REQUIRE_SPACETIME_AUTH").is_ok_and(|v| !v.is_empty())
}
@@ -526,7 +542,7 @@ pub async fn publish<S: NodeDelegate + ControlStateDelegate>(
// so, unless you are the owner, this will fail.
let (database_identity, db_name) = match &name_or_identity {
Some(noa) => match noa.try_resolve(&ctx).await? {
Some(noa) => match noa.try_resolve(&ctx).await.map_err(log_and_500)? {
Ok(resolved) => (resolved, noa.name()),
Err(name) => {
// `name_or_identity` was a `NameOrIdentity::Name`, but no record
+6 -3
View File
@@ -97,10 +97,10 @@ impl NameOrIdentity {
pub async fn try_resolve(
&self,
ctx: &(impl ControlStateReadAccess + ?Sized),
) -> axum::response::Result<Result<Identity, &DatabaseName>> {
) -> anyhow::Result<Result<Identity, &DatabaseName>> {
Ok(match self {
Self::Identity(identity) => Ok(Identity::from(*identity)),
Self::Name(name) => ctx.lookup_identity(name.as_ref()).map_err(log_and_500)?.ok_or(name),
Self::Name(name) => ctx.lookup_identity(name.as_ref())?.ok_or(name),
})
}
@@ -108,7 +108,10 @@ impl NameOrIdentity {
/// response if `self` is a [`NameOrIdentity::Name`] for which no
/// corresponding [`Identity`] is found in the SpacetimeDB DNS.
pub async fn resolve(&self, ctx: &(impl ControlStateReadAccess + ?Sized)) -> axum::response::Result<Identity> {
self.try_resolve(ctx).await?.map_err(|_| StatusCode::NOT_FOUND.into())
self.try_resolve(ctx)
.await
.map_err(log_and_500)?
.map_err(|_| StatusCode::NOT_FOUND.into())
}
}
+6 -3
View File
@@ -15,6 +15,7 @@ pub struct JwtKeys {
pub public: DecodingKey,
pub public_pem: Box<[u8]>,
pub private: EncodingKey,
pub private_pem: Box<[u8]>,
pub kid: Option<String>,
}
@@ -23,15 +24,17 @@ impl JwtKeys {
/// respectively.
///
/// The key files must be PEM encoded ECDSA P256 keys.
pub fn new(public_pem: impl Into<Box<[u8]>>, private_pem: &[u8]) -> anyhow::Result<Self> {
pub fn new(public_pem: impl Into<Box<[u8]>>, private_pem: impl Into<Box<[u8]>>) -> anyhow::Result<Self> {
let public_pem = public_pem.into();
let private_pem = private_pem.into();
let public = DecodingKey::from_ec_pem(&public_pem)?;
let private = EncodingKey::from_ec_pem(private_pem)?;
let private = EncodingKey::from_ec_pem(&private_pem)?;
Ok(Self {
public,
private,
public_pem,
private_pem,
kid: None,
})
}
@@ -75,7 +78,7 @@ pub struct EcKeyPair {
impl TryFrom<EcKeyPair> for JwtKeys {
type Error = anyhow::Error;
fn try_from(pair: EcKeyPair) -> anyhow::Result<Self> {
JwtKeys::new(pair.public_key_bytes, &pair.private_key_bytes)
JwtKeys::new(pair.public_key_bytes, pair.private_key_bytes)
}
}
+4
View File
@@ -63,6 +63,10 @@ pub struct Node {
///
/// If `None`, the node is not currently live.
pub advertise_addr: Option<String>,
/// The address this node is running its postgres API at.
///
/// If `None`, the node is not currently live.
pub pg_addr: Option<String>,
}
#[derive(Clone, PartialEq, Serialize, Deserialize)]
pub struct NodeStatus {
+22
View File
@@ -0,0 +1,22 @@
[package]
name = "spacetimedb-pg"
version.workspace = true
edition.workspace = true
rust-version.workspace = true
license-file = "LICENSE"
description = "Postgres wire protocol Server support for SpacetimeDB"
[dependencies]
spacetimedb-client-api-messages.workspace = true
spacetimedb-client-api.workspace = true
spacetimedb-lib.workspace = true
anyhow.workspace = true
async-trait.workspace = true
axum.workspace = true
futures.workspace = true
http.workspace = true
log.workspace = true
pgwire.workspace = true
thiserror.workspace = true
tokio.workspace = true
+1
View File
@@ -0,0 +1 @@
../../licenses/BSL.txt
+3
View File
@@ -0,0 +1,3 @@
> ⚠️ **Internal Crate** ⚠️
>
> This crate is intended for internal use only. It is **not** stable and may change without notice.
+301
View File
@@ -0,0 +1,301 @@
use crate::pg_server::PgError;
use pgwire::api::portal::Format;
use pgwire::api::results::{DataRowEncoder, FieldInfo};
use pgwire::api::Type;
use spacetimedb_lib::sats::satn::{PsqlChars, PsqlPrintFmt, PsqlType, TypedWriter};
use spacetimedb_lib::sats::{satn, ValueWithType};
use spacetimedb_lib::{
ser, AlgebraicType, AlgebraicValue, ProductType, ProductTypeElement, ProductValue, TimeDuration, Timestamp,
};
use std::borrow::Cow;
use std::sync::Arc;
pub(crate) fn row_desc(schema: &ProductType, format: &Format) -> Arc<Vec<FieldInfo>> {
Arc::new(
schema
.elements
.iter()
.enumerate()
.map(|(pos, ty)| {
let field_name = ty.name.clone().map(Into::into).unwrap_or_else(|| format!("col_{pos}"));
let field_type = type_of(schema, ty);
FieldInfo::new(field_name, None, None, field_type, format.format_for(pos))
})
.collect(),
)
}
pub(crate) fn type_of(schema: &ProductType, ty: &ProductTypeElement) -> Type {
let format = PsqlPrintFmt::use_fmt(schema, ty, ty.name());
match &ty.algebraic_type {
AlgebraicType::String => Type::VARCHAR,
AlgebraicType::Bool => Type::BOOL,
AlgebraicType::U8 | AlgebraicType::I8 | AlgebraicType::I16 => Type::INT2,
AlgebraicType::U16 | AlgebraicType::I32 => Type::INT4,
AlgebraicType::U32 | AlgebraicType::I64 => Type::INT8,
AlgebraicType::U64 | AlgebraicType::I128 | AlgebraicType::U128 | AlgebraicType::I256 | AlgebraicType::U256 => {
Type::NUMERIC
}
AlgebraicType::F32 => Type::FLOAT4,
AlgebraicType::F64 => Type::FLOAT8,
AlgebraicType::Array(ty) => match *ty.elem_ty {
AlgebraicType::String => Type::VARCHAR_ARRAY,
AlgebraicType::Bool => Type::BOOL_ARRAY,
AlgebraicType::U8 => Type::BYTEA,
AlgebraicType::I8 | AlgebraicType::I16 => Type::INT2_ARRAY,
AlgebraicType::U16 | AlgebraicType::I32 => Type::INT4_ARRAY,
AlgebraicType::U32 | AlgebraicType::I64 => Type::INT8_ARRAY,
AlgebraicType::U64
| AlgebraicType::I128
| AlgebraicType::U128
| AlgebraicType::I256
| AlgebraicType::U256 => Type::NUMERIC_ARRAY,
_ => Type::ANYARRAY,
},
AlgebraicType::Product(_) => match format {
PsqlPrintFmt::Hex => Type::BYTEA_ARRAY,
PsqlPrintFmt::Timestamp => Type::TIMESTAMP,
PsqlPrintFmt::Duration => Type::INTERVAL,
_ => Type::JSON,
},
AlgebraicType::Sum(sum) if sum.is_simple_enum() => Type::ANYENUM,
AlgebraicType::Sum(_) => Type::JSON,
_ => Type::UNKNOWN,
}
}
impl ser::Error for PgError {
fn custom<T: std::fmt::Display>(msg: T) -> Self {
PgError::Other(anyhow::anyhow!(msg.to_string()))
}
}
pub(crate) struct PsqlFormatter<'a> {
pub(crate) encoder: &'a mut DataRowEncoder,
}
impl TypedWriter for PsqlFormatter<'_> {
type Error = PgError;
fn write<W: std::fmt::Display>(&mut self, value: W) -> Result<(), Self::Error> {
self.encoder.encode_field(&value.to_string())?;
Ok(())
}
fn write_bool(&mut self, value: bool) -> Result<(), Self::Error> {
self.encoder.encode_field(&value)?;
Ok(())
}
fn write_string(&mut self, value: &str) -> Result<(), Self::Error> {
self.encoder.encode_field(&value)?;
Ok(())
}
fn write_bytes(&mut self, value: &[u8]) -> Result<(), Self::Error> {
self.encoder.encode_field(&value)?;
Ok(())
}
fn write_hex(&mut self, value: &[u8]) -> Result<(), Self::Error> {
self.encoder.encode_field(&value)?;
Ok(())
}
fn write_timestamp(&mut self, value: Timestamp) -> Result<(), Self::Error> {
self.encoder.encode_field(&value.to_rfc3339()?)?;
Ok(())
}
fn write_duration(&mut self, value: TimeDuration) -> Result<(), Self::Error> {
self.encoder.encode_field(&value.to_iso8601())?;
Ok(())
}
fn write_alt_record(
&mut self,
ty: &PsqlType,
value: &ValueWithType<'_, ProductValue>,
) -> Result<bool, Self::Error> {
let json = satn::PsqlWrapper { ty: ty.clone(), value }.to_string();
self.encoder.encode_field(&json)?;
Ok(true)
}
fn write_record(
&mut self,
_fields: Vec<(Cow<str>, PsqlType, ValueWithType<AlgebraicValue>)>,
) -> Result<(), Self::Error> {
unreachable!("Use `write_alt_record` for records in PSQL format");
}
fn write_variant(
&mut self,
tag: u8,
ty: PsqlType,
name: Option<&str>,
value: ValueWithType<AlgebraicValue>,
) -> Result<(), Self::Error> {
// Is a simple enum?
if let AlgebraicType::Sum(sum) = &ty.field.algebraic_type {
if sum.is_simple_enum() {
if let Some(variant_name) = name {
self.encoder.encode_field(&variant_name)?;
return Ok(());
}
}
}
let PsqlChars { start, sep, end, quote } = ty.client.format_chars();
let name = name.map(Cow::from).unwrap_or_else(|| Cow::from(tag.to_string()));
let json = format!(
"{start}{quote}{name}{quote}{sep} {}{end}",
satn::PsqlWrapper { ty, value }
);
self.encoder.encode_field(&json)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::pg_server::to_rows;
use futures::StreamExt;
use spacetimedb_client_api_messages::http::SqlStmtResult;
use spacetimedb_lib::sats::algebraic_value::Packed;
use spacetimedb_lib::sats::{i256, product, u256, AlgebraicType, ProductType, SumTypeVariant};
use spacetimedb_lib::{ConnectionId, Identity};
async fn run(schema: ProductType, row: ProductValue) -> String {
let header = row_desc(&schema, &Format::UnifiedText);
let stmt = SqlStmtResult {
schema,
rows: vec![row],
total_duration_micros: 0,
stats: Default::default(),
};
let mut stream = to_rows(stmt, header).unwrap();
let mut result = String::new();
if let Some(row) = stream.next().await {
result = String::from_utf8_lossy(row.unwrap().data.freeze().as_ref()).to_string();
}
result
}
#[tokio::test]
async fn test_primitives() {
let schema = ProductType::from([
AlgebraicType::U8,
AlgebraicType::I8,
AlgebraicType::I16,
AlgebraicType::U16,
AlgebraicType::I32,
AlgebraicType::U32,
AlgebraicType::I64,
AlgebraicType::U64,
AlgebraicType::I128,
AlgebraicType::U128,
AlgebraicType::I256,
AlgebraicType::U256,
AlgebraicType::F32,
AlgebraicType::F64,
AlgebraicType::String,
AlgebraicType::Bool,
]);
let value = product![
1u8,
-1i8,
-2i16,
3u16,
-4i32,
5u32,
-6i64,
7u64,
Packed::from(-8i128),
Packed::from(9u128),
i256::from(-10),
u256::from(11u128),
12.34f32,
56.78f64,
"test".to_string(),
true,
];
let row = run(schema, value).await;
assert_eq!(row, "\0\0\0\u{1}1\0\0\0\u{2}-1\0\0\0\u{2}-2\0\0\0\u{1}3\0\0\0\u{2}-4\0\0\0\u{1}5\0\0\0\u{2}-6\0\0\0\u{1}7\0\0\0\u{2}-8\0\0\0\u{1}9\0\0\0\u{3}-10\0\0\0\u{2}11\0\0\0\u{5}12.34\0\0\0\u{5}56.78\0\0\0\u{4}test\0\0\0\u{1}t");
}
#[tokio::test]
async fn test_enum() {
let some = AlgebraicType::option(AlgebraicType::I64);
let schema = ProductType::from([some.clone(), some]);
let value = product![
AlgebraicValue::sum(0, AlgebraicValue::I64(1)), // Some(1)
AlgebraicValue::sum(1, AlgebraicValue::unit()), // None
];
let row = run(schema, value).await;
assert_eq!(row, "\0\0\0\u{b}{\"some\": 1}\0\0\0\u{c}{\"none\": {}}");
let color = AlgebraicType::Sum([SumTypeVariant::new_named(AlgebraicType::I64, "Gray")].into());
let nested = AlgebraicType::option(color.clone());
let schema = ProductType::from([color, nested]);
// {"Gray": 1}, {"some": {"Gray": 2}}
let value = product![
AlgebraicValue::sum(0, AlgebraicValue::I64(1)), // Gray(1)
AlgebraicValue::sum(0, AlgebraicValue::sum(0, AlgebraicValue::I64(2))), // Some(Gray(2))
];
let row = run(schema.clone(), value.clone()).await;
assert_eq!(row, "\0\0\0\u{b}{\"Gray\": 1}\0\0\0\u{15}{\"some\": {\"Gray\": 2}}");
// Now nested product
let product = AlgebraicType::product([
ProductTypeElement::new(AlgebraicType::Product(schema), Some("x".into())),
ProductTypeElement::new(AlgebraicType::String, Some("y".into())),
]);
let schema = ProductType::from([product.clone()]);
let value = product![AlgebraicValue::product(vec![
value.into(),
AlgebraicValue::String("a".into()),
])];
let row = run(schema, value).await;
assert_eq!(
row,
"\0\0\0G{\"x\": {\"col_0\": {\"Gray\": 1}, \"col_1\": {\"some\": {\"Gray\": 2}}}, \"y\": \"a\"}"
);
// Now a simple enum
let names = AlgebraicType::simple_enum(["A", "B", "C"].into_iter());
let schema = ProductType::from([names.clone(), names.clone(), names]);
let value = product![
AlgebraicValue::enum_simple(0), // A
AlgebraicValue::enum_simple(1), // B
AlgebraicValue::enum_simple(2), // C
];
let row = run(schema, value).await;
assert_eq!(row, "\0\0\0\u{1}A\0\0\0\u{1}B\0\0\0\u{1}C");
}
#[tokio::test]
async fn test_special_types() {
let schema = ProductType::from([
AlgebraicType::identity(),
AlgebraicType::connection_id(),
AlgebraicType::time_duration(),
AlgebraicType::timestamp(),
AlgebraicType::bytes(),
]);
let value = product![
Identity::ZERO,
ConnectionId::ZERO,
TimeDuration::from_micros(0),
Timestamp::from_micros_since_unix_epoch(1622545800000),
AlgebraicValue::Bytes("test".as_bytes().into()),
];
let row = run(schema, value).await;
assert_eq!(row, "\0\0\0B\\x0000000000000000000000000000000000000000000000000000000000000000\0\0\0\"\\x00000000000000000000000000000000\0\0\0\u{3}P0D\0\0\0\u{1d}1970-01-19T18:42:25.800+00:00\0\0\0\n\\x74657374");
}
}
+2
View File
@@ -0,0 +1,2 @@
mod encoder;
pub mod pg_server;
+381
View File
@@ -0,0 +1,381 @@
use std::fmt::Debug;
use std::sync::Arc;
use crate::encoder::{row_desc, PsqlFormatter};
use async_trait::async_trait;
use axum::body::to_bytes;
use axum::response::IntoResponse;
use futures::{stream, Sink};
use futures::{SinkExt, Stream};
use http::StatusCode;
use pgwire::api::auth::{
finish_authentication, save_startup_parameters_to_metadata, DefaultServerParameterProvider, LoginInfo,
StartupHandler,
};
use pgwire::api::portal::Format;
use pgwire::api::query::SimpleQueryHandler;
use pgwire::api::results::{DataRowEncoder, FieldInfo, QueryResponse, Response, Tag};
use pgwire::api::{ClientInfo, METADATA_DATABASE};
use pgwire::api::{PgWireConnectionState, PgWireServerHandlers};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::messages::data::DataRow;
use pgwire::messages::startup::Authentication;
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
use pgwire::tokio::process_socket;
use spacetimedb_client_api::auth::validate_token;
use spacetimedb_client_api::routes::database;
use spacetimedb_client_api::routes::database::{SqlParams, SqlQueryParams};
use spacetimedb_client_api::{ControlStateReadAccess, ControlStateWriteAccess, NodeDelegate};
use spacetimedb_client_api_messages::http::SqlStmtResult;
use spacetimedb_client_api_messages::name::DatabaseName;
use spacetimedb_lib::sats::satn::{PsqlClient, TypedSerializer};
use spacetimedb_lib::sats::{satn, Serialize, Typespace};
use spacetimedb_lib::version::spacetimedb_lib_version;
use spacetimedb_lib::{Identity, ProductValue};
use thiserror::Error;
use tokio::net::TcpListener;
use tokio::sync::{Mutex, Notify};
#[derive(Error, Debug)]
pub(crate) enum PgError {
#[error("(metadata) {0}")]
MetadataError(anyhow::Error),
#[error("(Sql) {0}")]
Sql(String),
#[error("Database name is required")]
DatabaseNameRequired,
#[error(transparent)]
Pg(#[from] PgWireError),
#[error("SSL is not supported by SpacetimeDB")]
SSLNotSupported,
#[error(transparent)]
Other(#[from] anyhow::Error),
}
impl From<PgError> for PgWireError {
fn from(err: PgError) -> Self {
if let PgError::Pg(err) = err {
err
} else {
PgWireError::ApiError(Box::new(err))
}
}
}
#[derive(Clone)]
struct Metadata {
database: String,
caller_identity: Identity,
}
pub(crate) fn to_rows(
stmt: SqlStmtResult<ProductValue>,
header: Arc<Vec<FieldInfo>>,
) -> Result<impl Stream<Item = PgWireResult<DataRow>>, PgError> {
let mut results = Vec::with_capacity(stmt.rows.len());
let ty = Typespace::EMPTY.with_type(&stmt.schema);
for row in stmt.rows {
let mut encoder = DataRowEncoder::new(header.clone());
for (idx, value) in ty.with_values(&row).enumerate() {
let ty = satn::PsqlType {
client: PsqlClient::Postgres,
tuple: ty.ty(),
field: &ty.ty().elements[idx],
idx,
};
let mut fmt = PsqlFormatter { encoder: &mut encoder };
value.serialize(TypedSerializer { ty: &ty, f: &mut fmt })?;
}
results.push(encoder.finish());
}
Ok(stream::iter(results))
}
fn stats(stmt: &SqlStmtResult<ProductValue>) -> String {
let mut info = Vec::new();
if stmt.stats.rows_inserted != 0 {
info.push(format!("inserted: {}", stmt.stats.rows_inserted));
}
if stmt.stats.rows_deleted != 0 {
info.push(format!("deleted: {}", stmt.stats.rows_deleted));
}
if stmt.stats.rows_updated != 0 {
info.push(format!("updated: {}", stmt.stats.rows_updated));
}
info.push(format!(
"server: {:.2?}",
std::time::Duration::from_micros(stmt.total_duration_micros)
));
info.join(", ")
}
struct ResponseWrapper<T>(T);
impl<T> IntoResponse for ResponseWrapper<T> {
fn into_response(self) -> axum::response::Response {
unreachable!("Blank impl to satisfy IntoResponse")
}
}
async fn response<T>(res: axum::response::Result<T>, database: &str) -> Result<T, PgError> {
match res.map(ResponseWrapper) {
Ok(sql) => Ok(sql.0),
err => {
let res = err.into_response();
if res.status() == StatusCode::NOT_FOUND {
log::error!("PG: Database not found: {database}");
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
"FATAL".to_string(),
"3D000".to_string(),
format!("database \"{database}\" does not exist"),
)))
.into());
}
let bytes = to_bytes(res.into_body(), usize::MAX)
.await
.map_err(|err| PgWireError::ApiError(Box::new(err)))?;
let err = String::from_utf8_lossy(&bytes);
log::error!("PG: Error for database {database}: {err}");
Err(PgError::Sql(format!("{err}")))
}
}
}
struct PgSpacetimeDB<T> {
ctx: T,
cached: Mutex<Option<Metadata>>,
parameter_provider: DefaultServerParameterProvider,
}
impl<T: ControlStateReadAccess + ControlStateWriteAccess + NodeDelegate + Clone> PgSpacetimeDB<T> {
async fn exe_sql<'a>(&self, query: String) -> PgWireResult<Vec<Response<'a>>> {
let params = self.cached.lock().await.clone().unwrap();
let db = SqlParams {
name_or_identity: database::NameOrIdentity::Name(DatabaseName(params.database.clone())),
};
let sql = match response(
database::sql_direct(
self.ctx.clone(),
db,
SqlQueryParams { confirmed: true },
params.caller_identity,
query.to_string(),
)
.await,
&params.database,
)
.await
{
Ok(sql) => sql,
Err(PgError::Pg(PgWireError::UserError(err))) => {
return Ok(vec![Response::Error(err)]);
}
Err(err) => {
return Err(err.into());
}
};
let mut result = Vec::with_capacity(sql.len());
for sql_result in sql {
let header = row_desc(&sql_result.schema, &Format::UnifiedText);
if sql_result.rows.is_empty() && !query.to_uppercase().contains("SELECT") {
let tag = Tag::new(&stats(&sql_result));
result.push(Response::Execution(tag));
} else {
let rows = to_rows(sql_result, header.clone())?;
let q = QueryResponse::new(header, rows);
result.push(Response::Query(q));
}
}
Ok(result)
}
}
async fn close_client<C, E>(client: &mut C, err: E) -> PgWireResult<()>
where
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
pgwire::messages::response::ErrorResponse: From<E>,
{
let err = pgwire::messages::response::ErrorResponse::from(err);
client.feed(PgWireBackendMessage::ErrorResponse(err)).await?;
client.close().await?;
Ok(())
}
#[async_trait]
impl<T: Sync + Send + ControlStateReadAccess + ControlStateWriteAccess + NodeDelegate> StartupHandler
for PgSpacetimeDB<T>
{
async fn on_startup<C>(&self, client: &mut C, message: PgWireFrontendMessage) -> PgWireResult<()>
where
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
match message {
PgWireFrontendMessage::Startup(ref startup) => {
save_startup_parameters_to_metadata(client, startup);
client.set_state(PgWireConnectionState::AuthenticationInProgress);
let login_info = LoginInfo::from_client_info(client);
if login_info.database().is_none() {
return Err(PgError::DatabaseNameRequired.into());
}
client
.send(PgWireBackendMessage::Authentication(Authentication::CleartextPassword))
.await?;
}
PgWireFrontendMessage::PasswordMessageFamily(pwd) => {
let params = client.metadata();
let param = |param: &str| {
params
.get(param)
.map(String::from)
.ok_or_else(|| PgError::MetadataError(anyhow::anyhow!("Missing parameter: {}", param)))
};
// We don't support `METADATA_USER` because we don't have a user management system.
let database = param(METADATA_DATABASE)?;
let pwd = pwd.into_password()?;
if let Ok(application_name) = param("application_name") {
log::info!("PG: Connecting to database: {database}, by {application_name}",);
} else {
log::info!("PG: Connecting to database: {database}");
}
let name = database::NameOrIdentity::Name(DatabaseName(database.clone()));
match response(name.resolve(&self.ctx).await, &database).await {
Ok(identity) => identity,
Err(PgError::Pg(PgWireError::UserError(err))) => {
return close_client(client, *err).await;
}
Err(err) => {
return Err(err.into());
}
};
let caller_identity = match validate_token(&self.ctx, &pwd.password).await {
Ok(claims) => claims.identity,
Err(err) => {
log::error!(
"PG: Authentication failed for identity `{}` on database {database}: {err}",
pwd.password
);
let err = ErrorInfo::new("FATAL".to_owned(), "28P01".to_owned(), err.to_string());
return close_client(client, err).await;
}
};
log::info!("PG: Connected to database: {database} using identity `{caller_identity}`");
let metadata = Metadata {
database,
caller_identity,
};
self.cached.lock().await.clone_from(&Some(metadata));
finish_authentication(client, &self.parameter_provider).await?;
}
PgWireFrontendMessage::SslRequest(_) => {
let err = PgError::SSLNotSupported;
log::error!("{err}");
let err = ErrorInfo::new("FATAL".to_owned(), "28P01".to_owned(), err.to_string());
return close_client(client, err).await;
}
// The other messages are for features not supported by SpacetimeDB, that are rejected by the parser.
_ => {
unreachable!("Unsupported startup message: {message:?}");
}
}
Ok(())
}
}
#[async_trait]
impl<T: Sync + Send + ControlStateReadAccess + ControlStateWriteAccess + NodeDelegate + Clone> SimpleQueryHandler
for PgSpacetimeDB<T>
{
async fn do_query<'a, C>(&self, _client: &mut C, query: &str) -> PgWireResult<Vec<Response<'a>>>
where
C: ClientInfo + Unpin + Send + Sync,
{
self.exe_sql(query.to_string()).await
}
}
#[derive(Clone)]
pub struct PgSpacetimeDBFactory<T> {
handler: Arc<PgSpacetimeDB<T>>,
}
impl<T> PgSpacetimeDBFactory<T> {
pub fn new(ctx: T) -> Self {
let mut parameter_provider = DefaultServerParameterProvider::default();
parameter_provider.server_version = format!("spacetime {}", spacetimedb_lib_version());
Self {
handler: Arc::new(PgSpacetimeDB {
ctx,
// This is a placeholder, it will be set in the startup handler
cached: None.into(),
parameter_provider,
}),
}
}
}
impl<T: Sync + Send + ControlStateReadAccess + ControlStateWriteAccess + NodeDelegate + Clone> PgWireServerHandlers
for PgSpacetimeDBFactory<T>
{
fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
self.handler.clone()
}
// TODO: fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {}
fn startup_handler(&self) -> Arc<impl StartupHandler> {
self.handler.clone()
}
}
pub async fn start_pg<T: ControlStateReadAccess + ControlStateWriteAccess + NodeDelegate + Clone + 'static>(
shutdown: Arc<Notify>,
ctx: T,
tcp: TcpListener,
) {
let factory = Arc::new(PgSpacetimeDBFactory::new(ctx));
log::debug!(
"PG: Starting SpacetimeDB Protocol listening on {}",
tcp.local_addr().unwrap()
);
loop {
tokio::select! {
accept_result = tcp.accept() => {
match accept_result {
Ok((stream, _addr)) => {
let factory_ref = factory.clone();
tokio::spawn(async move {
process_socket(stream, None, factory_ref).await.inspect_err(|err|{
log::error!("PG: Error processing socket: {err:?}");
})
});
}
Err(e) => {
log::error!("PG: Accept error: {e}");
}
}
}
_ = shutdown.notified() => {
log::info!("PG: Shutting down PostgreSQL server.");
break;
}
}
}
}
+398 -227
View File
@@ -1,13 +1,14 @@
use crate::de::DeserializeSeed;
use crate::time_duration::TimeDuration;
use crate::timestamp::Timestamp;
use crate::{i256, u256, AlgebraicValue, WithTypespace};
use crate::{i256, u256, AlgebraicType, AlgebraicValue, ProductValue, Serialize, SumValue, ValueWithType};
use crate::{ser, ProductType, ProductTypeElement};
use core::fmt;
use core::fmt::Write as _;
use derive_more::{From, Into};
use derive_more::{Display, From, Into};
use std::borrow::Cow;
use std::marker::PhantomData;
/// An extension trait for [`Serialize`](ser::Serialize) providing formatting methods.
/// An extension trait for [`Serialize`] providing formatting methods.
pub trait Satn: ser::Serialize {
/// Formats the value using the SATN data format into the formatter `f`.
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
@@ -18,9 +19,12 @@ pub trait Satn: ser::Serialize {
/// Formats the value using the postgres SATN(PsqlFormatter { f }, /* PsqlType */) formatter `f`.
fn fmt_psql(&self, f: &mut fmt::Formatter, ty: &PsqlType<'_>) -> fmt::Result {
Writer::with(f, |f| {
self.serialize(PsqlFormatter {
fmt: SatnFormatter { f },
self.serialize(TypedSerializer {
ty,
f: &mut SqlFormatter {
fmt: SatnFormatter { f },
ty,
},
})
})?;
Ok(())
@@ -229,9 +233,30 @@ struct SatnFormatter<'a, 'f> {
f: Writer<'a, 'f>,
}
impl SatnFormatter<'_, '_> {
fn ser_variant<T: ser::Serialize + ?Sized>(
&mut self,
_tag: u8,
name: Option<&str>,
value: &T,
) -> Result<(), SatnError> {
write!(self, "(")?;
EntryWrapper::<','>::new(self.f.as_mut()).entry(|mut f| {
if let Some(name) = name {
write!(f, "{name}")?;
}
write!(f, " = ")?;
value.serialize(SatnFormatter { f })?;
Ok(())
})?;
write!(self, ")")?;
Ok(())
}
}
/// An error occurred during serialization to the SATS data format.
#[derive(From, Into)]
struct SatnError(fmt::Error);
pub struct SatnError(fmt::Error);
impl ser::Error for SatnError {
fn custom<T: fmt::Display>(_msg: T) -> Self {
@@ -331,20 +356,11 @@ impl<'a, 'f> ser::Serializer for SatnFormatter<'a, 'f> {
fn serialize_variant<T: ser::Serialize + ?Sized>(
mut self,
_tag: u8,
tag: u8,
name: Option<&str>,
value: &T,
) -> Result<Self::Ok, Self::Error> {
write!(self, "(")?;
EntryWrapper::<','>::new(self.f.as_mut()).entry(|mut f| {
if let Some(name) = name {
write!(f, "{name}")?;
}
write!(f, " = ")?;
value.serialize(SatnFormatter { f })?;
Ok(())
})?;
write!(self, ")")
self.ser_variant(tag, name, value)
}
}
@@ -427,119 +443,41 @@ impl ser::SerializeNamedProduct for NamedFormatter<'_, '_> {
}
}
struct PsqlEntryWrapper<'a, 'f, const SEP: char> {
entry: EntryWrapper<'a, 'f, SEP>,
/// The index of the element.
idx: usize,
ty: &'a PsqlType<'a>,
/// Which client is used to format the `SQL` output?
#[derive(PartialEq, Copy, Clone, Debug)]
pub enum PsqlClient {
SpacetimeDB,
Postgres,
}
/// Provides the data format for named products for `SQL`.
struct PsqlNamedFormatter<'a, 'f> {
/// The formatter for each element separating elements by a `,`.
f: PsqlEntryWrapper<'a, 'f, ','>,
/// If is not [Self::is_special] to control if we start with `(`
start: bool,
/// Remember what format we are using
use_fmt: PsqlPrintFmt,
pub struct PsqlChars {
pub start: char,
pub sep: &'static str,
pub end: char,
pub quote: &'static str,
}
impl<'a, 'f> PsqlNamedFormatter<'a, 'f> {
pub fn new(ty: &'a PsqlType<'a>, f: Writer<'a, 'f>) -> Self {
Self {
start: true,
f: PsqlEntryWrapper {
entry: EntryWrapper::new(f),
idx: 0,
ty,
impl PsqlClient {
pub fn format_chars(&self) -> PsqlChars {
match self {
PsqlClient::SpacetimeDB => PsqlChars {
start: '(',
sep: " =",
end: ')',
quote: "",
},
PsqlClient::Postgres => PsqlChars {
start: '{',
sep: ":",
end: '}',
quote: "\"",
},
// Will set later
use_fmt: PsqlPrintFmt::Satn,
}
}
}
impl ser::SerializeNamedProduct for PsqlNamedFormatter<'_, '_> {
type Ok = ();
type Error = SatnError;
fn serialize_element<T: Satn + ser::Serialize + ?Sized>(
&mut self,
name: Option<&str>,
elem: &T,
) -> Result<(), Self::Error> {
// For binary data & special types, output in `hex` format and skip the tagging of each value
// We need to check for both the enclosing(`self.f.ty`) type and the inner element(`name`) type.
self.use_fmt = self.f.ty.use_fmt(name);
let res = self.f.entry.entry(|mut f| {
let PsqlType { tuple, field, idx } = self.f.ty;
if !self.use_fmt.is_special() {
if self.start {
write!(f, "(")?;
self.start = false;
}
// Format the name or use the index if unnamed.
if let Some(name) = name {
write!(f, "{name}")?;
} else {
write!(f, "{idx}")?;
}
write!(f, " = ")?;
}
//Is a nested product type?
let (tuple, field, idx) = if let Some(product) = field.algebraic_type.as_product() {
(product, &product.elements[self.f.idx], self.f.idx)
} else {
(*tuple, *field, *idx)
};
elem.serialize(PsqlFormatter {
fmt: SatnFormatter { f },
ty: &PsqlType { tuple, field, idx },
})?;
Ok(())
});
// Advance to the next field.
if !self.use_fmt.is_special() {
self.f.idx += 1;
}
res?;
Ok(())
}
fn end(mut self) -> Result<Self::Ok, Self::Error> {
if !self.use_fmt.is_special() {
write!(self.f.entry.fmt, ")")?;
}
Ok(())
}
}
/// Provides the data format for unnamed products for `SQL`.
struct PsqlSeqFormatter<'a, 'f> {
/// Delegates to the named format.
inner: PsqlNamedFormatter<'a, 'f>,
}
impl ser::SerializeSeqProduct for PsqlSeqFormatter<'_, '_> {
type Ok = ();
type Error = SatnError;
fn serialize_element<T: ser::Serialize + ?Sized>(&mut self, elem: &T) -> Result<(), Self::Error> {
ser::SerializeNamedProduct::serialize_element(&mut self.inner, None, elem)
}
fn end(self) -> Result<Self::Ok, Self::Error> {
ser::SerializeNamedProduct::end(self.inner)
}
}
/// How format of the `SQL` output?
#[derive(PartialEq)]
#[derive(Debug, Copy, Clone, PartialEq, Display)]
pub enum PsqlPrintFmt {
/// Print as `hex` format
Hex,
@@ -552,14 +490,46 @@ pub enum PsqlPrintFmt {
}
impl PsqlPrintFmt {
fn is_special(&self) -> bool {
pub fn is_special(&self) -> bool {
self != &PsqlPrintFmt::Satn
}
/// Returns if the type is a special type
///
/// Is required to check both the enclosing type and the inner element type
pub fn use_fmt(tuple: &ProductType, field: &ProductTypeElement, name: Option<&str>) -> PsqlPrintFmt {
if tuple.is_identity()
|| tuple.is_connection_id()
|| field.algebraic_type.is_identity()
|| field.algebraic_type.is_connection_id()
|| name.map(ProductType::is_identity_tag).unwrap_or_default()
|| name.map(ProductType::is_connection_id_tag).unwrap_or_default()
{
return PsqlPrintFmt::Hex;
};
if tuple.is_timestamp()
|| field.algebraic_type.is_timestamp()
|| name.map(ProductType::is_timestamp_tag).unwrap_or_default()
{
return PsqlPrintFmt::Timestamp;
};
if tuple.is_time_duration()
|| field.algebraic_type.is_time_duration()
|| name.map(ProductType::is_time_duration_tag).unwrap_or_default()
{
return PsqlPrintFmt::Duration;
};
PsqlPrintFmt::Satn
}
}
/// A wrapper that remember the `header` of the tuple/struct and the current field
#[derive(Debug, Clone)]
pub struct PsqlType<'a> {
/// The client used to format the output
pub client: PsqlClient,
/// The header of the tuple/struct
pub tuple: &'a ProductType,
/// The current field
@@ -572,168 +542,369 @@ impl PsqlType<'_> {
/// Returns if the type is a special type
///
/// Is required to check both the enclosing type and the inner element type
fn use_fmt(&self, name: Option<&str>) -> PsqlPrintFmt {
if self.tuple.is_identity()
|| self.tuple.is_connection_id()
|| self.field.algebraic_type.is_identity()
|| self.field.algebraic_type.is_connection_id()
|| name.map(ProductType::is_identity_tag).unwrap_or_default()
|| name.map(ProductType::is_connection_id_tag).unwrap_or_default()
{
return PsqlPrintFmt::Hex;
};
if self.tuple.is_timestamp()
|| self.field.algebraic_type.is_timestamp()
|| name.map(ProductType::is_timestamp_tag).unwrap_or_default()
{
return PsqlPrintFmt::Timestamp;
};
if self.tuple.is_time_duration()
|| self.field.algebraic_type.is_time_duration()
|| name.map(ProductType::is_time_duration_tag).unwrap_or_default()
{
return PsqlPrintFmt::Duration;
};
PsqlPrintFmt::Satn
pub fn use_fmt(&self) -> PsqlPrintFmt {
PsqlPrintFmt::use_fmt(self.tuple, self.field, None)
}
}
/// An implementation of [`Serializer`](ser::Serializer) for `SQL` output.
struct PsqlFormatter<'a, 'f> {
pub struct SqlFormatter<'a, 'f> {
fmt: SatnFormatter<'a, 'f>,
ty: &'a PsqlType<'a>,
}
impl<'a, 'f> ser::Serializer for PsqlFormatter<'a, 'f> {
/// A trait for writing values, after the special types has been determined.
///
/// This is used to write values that could have different representations depending on the output format,
/// as defined by [`PsqlClient`] and [`PsqlPrintFmt`].
pub trait TypedWriter {
type Error: ser::Error;
/// Writes a value using [`ser::Serializer`]
fn write<W: fmt::Display>(&mut self, value: W) -> Result<(), Self::Error>;
// Values that need special handling:
fn write_bool(&mut self, value: bool) -> Result<(), Self::Error>;
fn write_string(&mut self, value: &str) -> Result<(), Self::Error>;
fn write_bytes(&mut self, value: &[u8]) -> Result<(), Self::Error>;
fn write_hex(&mut self, value: &[u8]) -> Result<(), Self::Error>;
fn write_timestamp(&mut self, value: Timestamp) -> Result<(), Self::Error>;
fn write_duration(&mut self, value: TimeDuration) -> Result<(), Self::Error>;
/// Writes a value as an alternative record format, e.g., for use `JSON` inside `SQL`.
fn write_alt_record(
&mut self,
_ty: &PsqlType,
_value: &ValueWithType<'_, ProductValue>,
) -> Result<bool, Self::Error> {
Ok(false)
}
fn write_record(
&mut self,
fields: Vec<(Cow<str>, PsqlType, ValueWithType<AlgebraicValue>)>,
) -> Result<(), Self::Error>;
fn write_variant(
&mut self,
tag: u8,
ty: PsqlType,
name: Option<&str>,
value: ValueWithType<AlgebraicValue>,
) -> Result<(), Self::Error>;
}
/// A formatter for arrays that uses the `TypedWriter` trait to write elements.
pub struct TypedArrayFormatter<'a, 'f, F> {
ty: &'a PsqlType<'a>,
f: &'f mut F,
}
impl<F: TypedWriter> ser::SerializeArray for TypedArrayFormatter<'_, '_, F> {
type Ok = ();
type Error = SatnError;
type SerializeArray = ArrayFormatter<'a, 'f>;
type SerializeSeqProduct = PsqlSeqFormatter<'a, 'f>;
type SerializeNamedProduct = PsqlNamedFormatter<'a, 'f>;
type Error = F::Error;
fn serialize_element<T: ser::Serialize + ?Sized>(&mut self, elem: &T) -> Result<(), Self::Error> {
elem.serialize(TypedSerializer { ty: self.ty, f: self.f })?;
Ok(())
}
fn end(self) -> Result<Self::Ok, Self::Error> {
Ok(())
}
}
/// A formatter for sequences that uses the `TypedWriter` trait to write elements.
pub struct TypedSeqFormatter<'a, 'f, F> {
ty: &'a PsqlType<'a>,
f: &'f mut F,
}
impl<F: TypedWriter> ser::SerializeSeqProduct for TypedSeqFormatter<'_, '_, F> {
type Ok = ();
type Error = F::Error;
fn serialize_element<T: ser::Serialize + ?Sized>(&mut self, elem: &T) -> Result<(), Self::Error> {
elem.serialize(TypedSerializer { ty: self.ty, f: self.f })?;
Ok(())
}
fn end(self) -> Result<Self::Ok, Self::Error> {
Ok(())
}
}
/// A formatter for named products that uses the `TypedWriter` trait to write elements.
pub struct TypedNamedProductFormatter<F> {
f: PhantomData<F>,
}
impl<F: TypedWriter> ser::SerializeNamedProduct for TypedNamedProductFormatter<F> {
type Ok = ();
type Error = F::Error;
fn serialize_element<T: ser::Serialize + ?Sized>(
&mut self,
_name: Option<&str>,
_elem: &T,
) -> Result<(), Self::Error> {
Ok(())
}
fn end(self) -> Result<Self::Ok, Self::Error> {
Ok(())
}
}
/// A serializer that uses the `TypedWriter` trait to serialize values
pub struct TypedSerializer<'a, 'f, F> {
pub ty: &'a PsqlType<'a>,
pub f: &'f mut F,
}
impl<'a, 'f, F: TypedWriter> ser::Serializer for TypedSerializer<'a, 'f, F> {
type Ok = ();
type Error = F::Error;
type SerializeArray = TypedArrayFormatter<'a, 'f, F>;
type SerializeSeqProduct = TypedSeqFormatter<'a, 'f, F>;
type SerializeNamedProduct = TypedNamedProductFormatter<F>;
fn serialize_bool(self, v: bool) -> Result<Self::Ok, Self::Error> {
self.fmt.serialize_bool(v)
self.f.write_bool(v)
}
fn serialize_u8(self, v: u8) -> Result<Self::Ok, Self::Error> {
self.fmt.serialize_u8(v)
self.f.write(v)
}
fn serialize_u16(self, v: u16) -> Result<Self::Ok, Self::Error> {
self.fmt.serialize_u16(v)
self.f.write(v)
}
fn serialize_u32(self, v: u32) -> Result<Self::Ok, Self::Error> {
self.fmt.serialize_u32(v)
self.f.write(v)
}
fn serialize_u64(self, v: u64) -> Result<Self::Ok, Self::Error> {
self.fmt.serialize_u64(v)
self.f.write(v)
}
fn serialize_u128(self, v: u128) -> Result<Self::Ok, Self::Error> {
match self.ty.use_fmt(None) {
PsqlPrintFmt::Hex => self.serialize_bytes(&v.to_be_bytes()),
_ => self.fmt.serialize_u128(v),
match self.ty.use_fmt() {
PsqlPrintFmt::Hex => self.f.write_hex(&v.to_be_bytes()),
_ => self.f.write(v),
}
}
fn serialize_u256(self, v: u256) -> Result<Self::Ok, Self::Error> {
match self.ty.use_fmt(None) {
PsqlPrintFmt::Hex => self.serialize_bytes(&v.to_be_bytes()),
_ => self.fmt.serialize_u256(v),
match self.ty.use_fmt() {
PsqlPrintFmt::Hex => self.f.write_hex(&v.to_be_bytes()),
_ => self.f.write(v),
}
}
fn serialize_i8(self, v: i8) -> Result<Self::Ok, Self::Error> {
self.fmt.serialize_i8(v)
self.f.write(v)
}
fn serialize_i16(self, v: i16) -> Result<Self::Ok, Self::Error> {
self.fmt.serialize_i16(v)
self.f.write(v)
}
fn serialize_i32(self, v: i32) -> Result<Self::Ok, Self::Error> {
self.fmt.serialize_i32(v)
self.f.write(v)
}
fn serialize_i64(mut self, v: i64) -> Result<Self::Ok, Self::Error> {
match self.ty.use_fmt(None) {
PsqlPrintFmt::Duration => {
write!(self.fmt, "{}", TimeDuration::from_micros(v))?;
Ok(())
}
PsqlPrintFmt::Timestamp => {
write!(self.fmt, "{}", Timestamp::from_micros_since_unix_epoch(v))?;
Ok(())
}
_ => self.fmt.serialize_i64(v),
fn serialize_i64(self, v: i64) -> Result<Self::Ok, Self::Error> {
match self.ty.use_fmt() {
PsqlPrintFmt::Duration => self.f.write_duration(TimeDuration::from_micros(v)),
PsqlPrintFmt::Timestamp => self.f.write_timestamp(Timestamp::from_micros_since_unix_epoch(v)),
_ => self.f.write(v),
}
}
fn serialize_i128(self, v: i128) -> Result<Self::Ok, Self::Error> {
self.fmt.serialize_i128(v)
self.f.write(v)
}
fn serialize_i256(self, v: i256) -> Result<Self::Ok, Self::Error> {
self.fmt.serialize_i256(v)
self.f.write(v)
}
fn serialize_f32(self, v: f32) -> Result<Self::Ok, Self::Error> {
self.fmt.serialize_f32(v)
self.f.write(v)
}
fn serialize_f64(self, v: f64) -> Result<Self::Ok, Self::Error> {
self.fmt.serialize_f64(v)
self.f.write(v)
}
fn serialize_str(self, v: &str) -> Result<Self::Ok, Self::Error> {
self.fmt.serialize_str(v)
self.f.write_string(v)
}
fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
self.fmt.serialize_bytes(v)
if self.ty.use_fmt() == PsqlPrintFmt::Satn {
self.f.write_hex(v)
} else {
self.f.write_bytes(v)
}
}
fn serialize_array(self, len: usize) -> Result<Self::SerializeArray, Self::Error> {
self.fmt.serialize_array(len)
fn serialize_array(self, _len: usize) -> Result<Self::SerializeArray, Self::Error> {
Ok(TypedArrayFormatter { ty: self.ty, f: self.f })
}
fn serialize_seq_product(self, len: usize) -> Result<Self::SerializeSeqProduct, Self::Error> {
Ok(PsqlSeqFormatter {
inner: self.serialize_named_product(len)?,
})
fn serialize_seq_product(self, _len: usize) -> Result<Self::SerializeSeqProduct, Self::Error> {
Ok(TypedSeqFormatter { ty: self.ty, f: self.f })
}
fn serialize_named_product(self, _len: usize) -> Result<Self::SerializeNamedProduct, Self::Error> {
Ok(PsqlNamedFormatter::new(self.ty, self.fmt.f))
unreachable!("This should never be called, use `serialize_named_product_raw` instead.");
}
fn serialize_variant<T: ser::Serialize + ?Sized>(
fn serialize_named_product_raw(self, value: &ValueWithType<'_, ProductValue>) -> Result<Self::Ok, Self::Error> {
let val = &value.val.elements;
assert_eq!(val.len(), value.ty().elements.len());
// If the value is a special type, we can write it directly
if self.ty.use_fmt().is_special() {
// Is a nested product type?
// We need to check for both the enclosing(`self.ty`) type and the inner element type.
let (tuple, field) = if let Some(product) = self.ty.field.algebraic_type.as_product() {
(product, &product.elements[0])
} else {
(self.ty.tuple, self.ty.field)
};
return value.val.serialize(TypedSerializer {
ty: &PsqlType {
client: self.ty.client,
tuple,
field,
idx: self.ty.idx,
},
f: self.f,
});
}
// Allow to switch to an alternative record format, for example to write a `JSON` record.
if self.f.write_alt_record(self.ty, value)? {
return Ok(());
}
let mut record = Vec::with_capacity(val.len());
for (idx, (val, field)) in val.iter().zip(&*value.ty().elements).enumerate() {
let ty = PsqlType {
client: self.ty.client,
tuple: value.ty(),
field,
idx,
};
record.push((
field
.name()
.map(Cow::from)
.unwrap_or_else(|| Cow::from(format!("col_{idx}"))),
ty,
value.with(&field.algebraic_type, val),
));
}
self.f.write_record(record)
}
fn serialize_variant_raw(self, sum: &ValueWithType<'_, SumValue>) -> Result<Self::Ok, Self::Error> {
let sv = sum.value();
let (tag, val) = (sv.tag, &*sv.value);
let var_ty = &sum.ty().variants[tag as usize]; // Extract the variant type by tag.
let product = ProductType::from([AlgebraicType::sum(sum.ty().clone())]);
let ty = PsqlType {
client: self.ty.client,
tuple: &product,
field: &product.elements[0],
idx: 0,
};
self.f
.write_variant(tag, ty, var_ty.name(), sum.with(&var_ty.algebraic_type, val))
}
fn serialize_variant<T: Serialize + ?Sized>(
self,
tag: u8,
name: Option<&str>,
value: &T,
_tag: u8,
_name: Option<&str>,
_value: &T,
) -> Result<Self::Ok, Self::Error> {
self.fmt.serialize_variant(tag, name, value)
}
unsafe fn serialize_bsatn<Ty>(self, ty: &Ty, bsatn: &[u8]) -> Result<Self::Ok, Self::Error>
where
for<'b, 'de> WithTypespace<'b, Ty>: DeserializeSeed<'de, Output: Into<AlgebraicValue>>,
{
// SAFETY: Forward caller requirements of this method to that we are calling.
unsafe { self.fmt.serialize_bsatn(ty, bsatn) }
}
unsafe fn serialize_bsatn_in_chunks<'c, Ty, I: Clone + Iterator<Item = &'c [u8]>>(
self,
ty: &Ty,
total_bsatn_len: usize,
bsatn: I,
) -> Result<Self::Ok, Self::Error>
where
for<'b, 'de> WithTypespace<'b, Ty>: DeserializeSeed<'de, Output: Into<AlgebraicValue>>,
{
// SAFETY: Forward caller requirements of this method to that we are calling.
unsafe { self.fmt.serialize_bsatn_in_chunks(ty, total_bsatn_len, bsatn) }
}
unsafe fn serialize_str_in_chunks<'c, I: Clone + Iterator<Item = &'c [u8]>>(
self,
total_len: usize,
string: I,
) -> Result<Self::Ok, Self::Error> {
// SAFETY: Forward caller requirements of this method to that we are calling.
unsafe { self.fmt.serialize_str_in_chunks(total_len, string) }
unreachable!("Use `serialize_variant_raw` instead.");
}
}
impl TypedWriter for SqlFormatter<'_, '_> {
type Error = SatnError;
fn write<W: fmt::Display>(&mut self, value: W) -> Result<(), Self::Error> {
write!(self.fmt, "{value}")
}
fn write_bool(&mut self, value: bool) -> Result<(), Self::Error> {
write!(self.fmt, "{value}")
}
fn write_string(&mut self, value: &str) -> Result<(), Self::Error> {
write!(self.fmt, "\"{value}\"")
}
fn write_bytes(&mut self, value: &[u8]) -> Result<(), Self::Error> {
self.write_hex(value)
}
fn write_hex(&mut self, value: &[u8]) -> Result<(), Self::Error> {
match self.ty.client {
PsqlClient::SpacetimeDB => write!(self.fmt, "0x{}", hex::encode(value)),
PsqlClient::Postgres => write!(self.fmt, "\"0x{}\"", hex::encode(value)),
}
}
fn write_timestamp(&mut self, value: Timestamp) -> Result<(), Self::Error> {
match self.ty.client {
PsqlClient::SpacetimeDB => write!(self.fmt, "{}", value.to_rfc3339().unwrap()),
PsqlClient::Postgres => write!(self.fmt, "\"{}\"", value.to_rfc3339().unwrap()),
}
}
fn write_duration(&mut self, value: TimeDuration) -> Result<(), Self::Error> {
match self.ty.client {
PsqlClient::SpacetimeDB => write!(self.fmt, "{value}"),
PsqlClient::Postgres => write!(self.fmt, "\"{}\"", value.to_iso8601()),
}
}
fn write_record(
&mut self,
fields: Vec<(Cow<str>, PsqlType<'_>, ValueWithType<AlgebraicValue>)>,
) -> Result<(), Self::Error> {
let PsqlChars { start, sep, end, quote } = self.ty.client.format_chars();
write!(self.fmt, "{start}")?;
for (idx, (name, ty, value)) in fields.into_iter().enumerate() {
if idx > 0 {
write!(self.fmt, ", ")?;
}
write!(self.fmt, "{quote}{name}{quote}{sep} ")?;
// Serialize the value
value.serialize(TypedSerializer { ty: &ty, f: self })?;
}
write!(self.fmt, "{end}")?;
Ok(())
}
fn write_variant(
&mut self,
tag: u8,
ty: PsqlType,
name: Option<&str>,
value: ValueWithType<AlgebraicValue>,
) -> Result<(), Self::Error> {
self.write_record(vec![(
name.map(Cow::from).unwrap_or_else(|| Cow::from(format!("col_{tag}"))),
ty,
value,
)])
}
}
+26 -1
View File
@@ -6,7 +6,7 @@ mod impls;
pub mod serde;
use crate::de::DeserializeSeed;
use crate::{algebraic_value::ser::ValueSerializer, bsatn, buffer::BufWriter};
use crate::{algebraic_value::ser::ValueSerializer, bsatn, buffer::BufWriter, ProductValue, SumValue, ValueWithType};
use crate::{AlgebraicValue, WithTypespace};
use core::marker::PhantomData;
use core::{convert::Infallible, fmt};
@@ -117,6 +117,31 @@ pub trait Serializer: Sized {
/// The argument is the number of fields in the product.
fn serialize_named_product(self, len: usize) -> Result<Self::SerializeNamedProduct, Self::Error>;
/// Serialize a product with named fields.
///
/// Allow to override the default serialization for where we need to switch the output format,
/// see [`crate::satn::TypedWriter`].
fn serialize_named_product_raw(self, value: &ValueWithType<'_, ProductValue>) -> Result<Self::Ok, Self::Error> {
let val = &value.val.elements;
assert_eq!(val.len(), value.ty().elements.len());
let mut prod = self.serialize_named_product(val.len())?;
for (val, el_ty) in val.iter().zip(&*value.ty().elements) {
prod.serialize_element(el_ty.name(), &value.with(&el_ty.algebraic_type, val))?
}
prod.end()
}
/// Serialize a sum value
///
/// Allow to override the default serialization for where we need to switch the output format,
/// see [`crate::satn::TypedWriter`].
fn serialize_variant_raw(self, sum: &ValueWithType<'_, SumValue>) -> Result<Self::Ok, Self::Error> {
let sv = sum.value();
let (tag, val) = (sv.tag, &*sv.value);
let var_ty = &sum.ty().variants[tag as usize]; // Extract the variant type by tag.
self.serialize_variant(tag, var_ty.name(), &sum.with(&var_ty.algebraic_type, val))
}
/// Serialize a sum value provided the chosen `tag`, `name`, and `value`.
fn serialize_variant<T: Serialize + ?Sized>(
self,
+3 -12
View File
@@ -1,4 +1,4 @@
use super::{Serialize, SerializeArray, SerializeNamedProduct, SerializeSeqProduct, Serializer};
use super::{Serialize, SerializeArray, SerializeSeqProduct, Serializer};
use crate::{i256, u256};
use crate::{AlgebraicType, AlgebraicValue, ArrayValue, ProductValue, SumValue, ValueWithType, F32, F64};
use core::ops::Bound;
@@ -190,19 +190,10 @@ impl_serialize!(
}
);
impl_serialize!([] ValueWithType<'_, SumValue>, (self, ser) => {
let sv = self.value();
let (tag, val) = (sv.tag, &*sv.value);
let var_ty = &self.ty().variants[tag as usize]; // Extract the variant type by tag.
ser.serialize_variant(tag, var_ty.name(), &self.with(&var_ty.algebraic_type, val))
ser.serialize_variant_raw(self)
});
impl_serialize!([] ValueWithType<'_, ProductValue>, (self, ser) => {
let val = &self.value().elements;
assert_eq!(val.len(), self.ty().elements.len());
let mut prod = ser.serialize_named_product(val.len())?;
for (val, el_ty) in val.iter().zip(&*self.ty().elements) {
prod.serialize_element(el_ty.name(), &self.with(&el_ty.algebraic_type, val))?
}
prod.end()
ser.serialize_named_product_raw(self)
});
impl_serialize!([] ValueWithType<'_, ArrayValue>, (self, ser) => {
let mut ty = &*self.ty().elem_ty;
+16
View File
@@ -82,6 +82,22 @@ impl TimeDuration {
pub fn checked_sub(self, other: Self) -> Option<Self> {
self.to_micros().checked_sub(other.to_micros()).map(Self::from_micros)
}
/// Generate an `iso8601` format string.
///
/// This is the better supported format for use for the `pg wire protocol`.
///
/// Example:
/// ```rust
/// use std::time::Duration;
/// use spacetimedb_sats::time_duration::TimeDuration;
/// assert_eq!( TimeDuration::from_micros(0).to_iso8601().as_str(), "P0D");
/// assert_eq!( TimeDuration::from_micros(-1_000_000).to_iso8601().as_str(), "-PT1S");
/// assert_eq!( TimeDuration::from_duration(Duration::from_secs(60 * 24)).to_iso8601().as_str(), "PT1440S");
/// ```
pub fn to_iso8601(self) -> String {
chrono::Duration::microseconds(self.to_micros()).to_string()
}
}
impl From<Duration> for TimeDuration {
+7 -3
View File
@@ -171,13 +171,17 @@ impl Timestamp {
pub fn checked_sub_duration(&self, duration: Duration) -> Option<Self> {
self.checked_sub(TimeDuration::from_duration(duration))
}
/// Returns an RFC 3339 and ISO 8601 date and time string such as `1996-12-19T16:39:57-08:00`.
pub fn to_rfc3339(&self) -> anyhow::Result<String> {
pub fn to_chrono_date_time(&self) -> anyhow::Result<DateTime<chrono::Utc>> {
DateTime::from_timestamp_micros(self.to_micros_since_unix_epoch())
.map(|t| t.to_rfc3339())
.ok_or_else(|| anyhow::anyhow!("Timestamp with i64 microseconds since Unix epoch overflows DateTime"))
.with_context(|| self.to_micros_since_unix_epoch())
}
/// Returns an RFC 3339 and ISO 8601 date and time string such as `1996-12-19T16:39:57-08:00`.
pub fn to_rfc3339(&self) -> anyhow::Result<String> {
Ok(self.to_chrono_date_time()?.to_rfc3339())
}
}
impl Add<TimeDuration> for Timestamp {
+1
View File
@@ -27,6 +27,7 @@ spacetimedb-core.workspace = true
spacetimedb-datastore.workspace = true
spacetimedb-lib.workspace = true
spacetimedb-paths.workspace = true
spacetimedb-pg.workspace = true
spacetimedb-table.workspace = true
spacetimedb-schema.workspace = true
+1
View File
@@ -186,6 +186,7 @@ impl spacetimedb_client_api::ControlStateReadAccess for StandaloneEnv {
id: 0,
unschedulable: false,
advertise_addr: Some("node:80".to_owned()),
pg_addr: Some("node:5432".to_owned()),
}));
}
Ok(None)
+19 -3
View File
@@ -1,3 +1,4 @@
use spacetimedb_pg::pg_server;
use std::sync::Arc;
use crate::{StandaloneEnv, StandaloneOptions};
@@ -176,12 +177,27 @@ pub async fn exec(args: &ArgMatches, db_cores: JobCores) -> anyhow::Result<()> {
db_routes.root_post = db_routes.root_post.layer(DefaultBodyLimit::disable());
db_routes.db_put = db_routes.db_put.layer(DefaultBodyLimit::disable());
let extra = axum::Router::new().nest("/health", spacetimedb_client_api::routes::health::router());
let service = router(&ctx, db_routes, extra).with_state(ctx);
let service = router(&ctx, db_routes, extra).with_state(ctx.clone());
let tcp = TcpListener::bind(listen_addr).await?;
socket2::SockRef::from(&tcp).set_nodelay(true)?;
log::debug!("Starting SpacetimeDB listening on {}", tcp.local_addr().unwrap());
axum::serve(tcp, service).await?;
log::debug!("Starting SpacetimeDB listening on {}", tcp.local_addr()?);
let pg_server_addr = format!("{}:5432", listen_addr.split(':').next().unwrap());
let tcp_pg = TcpListener::bind(pg_server_addr).await?;
let notify = Arc::new(tokio::sync::Notify::new());
let shutdown_notify = notify.clone();
tokio::select! {
_ = pg_server::start_pg(notify.clone(), ctx, tcp_pg) => {},
_ = axum::serve(tcp, service).with_graceful_shutdown(async move {
shutdown_notify.notified().await;
}) => {},
_ = tokio::signal::ctrl_c() => {
println!("Shutting down servers...");
notify.notify_waiters(); // Notify all tasks
}
}
Ok(())
}
+2
View File
@@ -25,6 +25,8 @@ services:
- /stdb
ports:
- "3000:3000"
# Postgres
- "5432:5432"
# Tracy
- "8086:8086"
entrypoint: cargo watch -i flamegraphs -i log.conf --why -C crates/standalone -x 'run start --data-dir=/stdb/data --jwt-pub-key-path=/etc/spacetimedb/id_ecdsa.pub --jwt-priv-key-path=/etc/spacetimedb/id_ecdsa'
+293
View File
@@ -0,0 +1,293 @@
from .. import Smoketest
import subprocess
import os
import tomllib
import psycopg2
def psql(identity: str, sql: str, extra=None) -> str:
"""Call `psql` and execute the given SQL statement."""
if extra is None:
extra = dict()
result = subprocess.run(
["psql", "-h", "127.0.0.1", "-p", "5432", "-U", "postgres", "-d", "quickstart", "--quiet", "-c", sql],
encoding="utf8",
env={**os.environ, **extra, "PGPASSWORD": identity},
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
if result.stderr:
raise Exception(result.stderr.strip())
return result.stdout.strip()
def connect_db(identity: str):
"""Connect to the database using `psycopg2`."""
conn = psycopg2.connect(host="127.0.0.1", port=5432, user="postgres", password=identity, dbname="quickstart")
conn.set_session(autocommit=True) # Disable automic transaction
return conn
class SqlFormat(Smoketest):
AUTOPUBLISH = False
MODULE_CODE = """
use spacetimedb::sats::{i256, u256};
use spacetimedb::{ConnectionId, Identity, ReducerContext, SpacetimeType, Table, Timestamp, TimeDuration};
#[derive(Copy, Clone)]
#[spacetimedb::table(name = t_ints, public)]
pub struct TInts {
i8: i8,
i16: i16,
i32: i32,
i64: i64,
i128: i128,
i256: i256,
}
#[spacetimedb::table(name = t_ints_tuple, public)]
pub struct TIntsTuple {
tuple: TInts,
}
#[derive(Copy, Clone)]
#[spacetimedb::table(name = t_uints, public)]
pub struct TUints {
u8: u8,
u16: u16,
u32: u32,
u64: u64,
u128: u128,
u256: u256,
}
#[spacetimedb::table(name = t_uints_tuple, public)]
pub struct TUintsTuple {
tuple: TUints,
}
#[derive(Clone)]
#[spacetimedb::table(name = t_others, public)]
pub struct TOthers {
bool: bool,
f32: f32,
f64: f64,
str: String,
bytes: Vec<u8>,
identity: Identity,
connection_id: ConnectionId,
timestamp: Timestamp,
duration: TimeDuration,
}
#[spacetimedb::table(name = t_others_tuple, public)]
pub struct TOthersTuple {
tuple: TOthers
}
#[derive(SpacetimeType, Debug, Clone, Copy)]
pub enum Action {
Inactive,
Active,
}
#[derive(SpacetimeType, Debug, Clone, Copy)]
pub enum Color {
Gray(u8),
}
#[derive(Copy, Clone)]
#[spacetimedb::table(name = t_simple_enum, public)]
pub struct TSimpleEnum {
id : u32,
action: Action,
}
#[spacetimedb::table(name = t_enum, public)]
pub struct TEnum {
id : u32,
color: Color,
}
#[spacetimedb::table(name = t_nested, public)]
pub struct TNested {
en: TEnum,
se: TSimpleEnum,
ints: TInts,
}
#[spacetimedb::reducer]
pub fn test(ctx: &ReducerContext) {
let tuple = TInts {
i8: -25,
i16: -3224,
i32: -23443,
i64: -2344353,
i128: -234434897853,
i256: (-234434897853i128).into(),
};
let ints = tuple;
ctx.db.t_ints().insert(tuple);
ctx.db.t_ints_tuple().insert(TIntsTuple { tuple });
let tuple = TUints {
u8: 105,
u16: 1050,
u32: 83892,
u64: 48937498,
u128: 4378528978889,
u256: 4378528978889u128.into(),
};
ctx.db.t_uints().insert(tuple);
ctx.db.t_uints_tuple().insert(TUintsTuple { tuple });
let tuple = TOthers {
bool: true,
f32: 594806.58906,
f64: -3454353.345389043278459,
str: "This is spacetimedb".to_string(),
bytes: vec!(1, 2, 3, 4, 5, 6, 7),
identity: Identity::ONE,
connection_id: ConnectionId::ZERO,
timestamp: Timestamp::UNIX_EPOCH,
duration: TimeDuration::from_micros(1000 * 10000),
};
ctx.db.t_others().insert(tuple.clone());
ctx.db.t_others_tuple().insert(TOthersTuple { tuple });
ctx.db.t_simple_enum().insert(TSimpleEnum { id: 1, action: Action::Inactive });
ctx.db.t_simple_enum().insert(TSimpleEnum { id: 2, action: Action::Active });
ctx.db.t_enum().insert(TEnum { id: 1, color: Color::Gray(128) });
ctx.db.t_nested().insert(TNested {
en: TEnum { id: 1, color: Color::Gray(128) },
se: TSimpleEnum { id: 2, action: Action::Active },
ints,
});
}
"""
def assertSql(self, token: str, sql: str, expected):
self.maxDiff = None
sql_out = psql(token, sql)
sql_out = "\n".join([line.rstrip() for line in sql_out.splitlines()])
expected = "\n".join([line.rstrip() for line in expected.splitlines()])
print(sql_out)
self.assertMultiLineEqual(sql_out, expected)
def read_token(self):
"""Read the token from the config file."""
with open(self.config_path, "rb") as f:
config = tomllib.load(f)
return config['spacetimedb_token']
def test_sql_format(self):
"""This test is designed to test calling `psql` to execute SQL statements"""
token = self.read_token()
self.publish_module("quickstart", clear=True)
self.call("test")
self.assertSql(token, "SELECT * FROM t_ints", """\
i8 | i16 | i32 | i64 | i128 | i256
-----+-------+--------+----------+---------------+---------------
-25 | -3224 | -23443 | -2344353 | -234434897853 | -234434897853
(1 row)""")
self.assertSql(token, "SELECT * FROM t_ints_tuple", """\
tuple
---------------------------------------------------------------------------------------------------------
{"i8": -25, "i16": -3224, "i32": -23443, "i64": -2344353, "i128": -234434897853, "i256": -234434897853}
(1 row)""")
self.assertSql(token, "SELECT * FROM t_uints", """\
u8 | u16 | u32 | u64 | u128 | u256
-----+------+-------+----------+---------------+---------------
105 | 1050 | 83892 | 48937498 | 4378528978889 | 4378528978889
(1 row)""")
self.assertSql(token, "SELECT * FROM t_uints_tuple", """\
tuple
-------------------------------------------------------------------------------------------------------
{"u8": 105, "u16": 1050, "u32": 83892, "u64": 48937498, "u128": 4378528978889, "u256": 4378528978889}
(1 row)""")
self.assertSql(token, "SELECT * FROM t_others", """\
bool | f32 | f64 | str | bytes | identity | connection_id | timestamp | duration
------+-----------+---------------------+---------------------+------------------+--------------------------------------------------------------------+------------------------------------+---------------------------+----------
t | 594806.56 | -3454353.3453890434 | This is spacetimedb | \\x01020304050607 | \\x0000000000000000000000000000000000000000000000000000000000000001 | \\x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | PT10S
(1 row)""")
self.assertSql(token, "SELECT * FROM t_others_tuple", """\
tuple
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
{"bool": true, "f32": 594806.56, "f64": -3454353.3453890434, "str": "This is spacetimedb", "bytes": "0x01020304050607", "identity": "0x0000000000000000000000000000000000000000000000000000000000000001", "connection_id": "0x00000000000000000000000000000000", "timestamp": "1970-01-01T00:00:00+00:00", "duration": "PT10S"}
(1 row)""")
self.assertSql(token, "SELECT * FROM t_simple_enum", """\
id | action
----+----------
1 | Inactive
2 | Active
(2 rows)""")
self.assertSql(token, "SELECT * FROM t_enum", """\
id | color
----+---------------
1 | {"Gray": 128}
(1 row)""")
self.assertSql(token, "SELECT * FROM t_nested", """\
en | se | ints
-----------------------------------+-------------------------------------+---------------------------------------------------------------------------------------------------------
{"id": 1, "color": {"Gray": 128}} | {"id": 2, "action": {"Active": {}}} | {"i8": -25, "i16": -3224, "i32": -23443, "i64": -2344353, "i128": -234434897853, "i256": -234434897853}
(1 row)""")
def test_sql_conn(self):
"""This test is designed to test connecting to the database and executing queries using `psycopg2`"""
token = self.read_token()
self.publish_module("quickstart", clear=True)
self.call("test")
conn = connect_db(token)
# Check prepared statements (faked by `psycopg2`)
with conn.cursor() as cur:
cur.execute("select * from t_uints where u8 = %s and u16 = %s", (105, 1050))
rows = cur.fetchall()
self.assertEqual(rows[0], (105, 1050, 83892, 48937498, 4378528978889, 4378528978889))
# Check long-lived connection
with conn.cursor() as cur:
for _ in range(10):
cur.execute("select count(*) as t from t_uints")
rows = cur.fetchall()
self.assertEqual(rows[0], (1,))
conn.close()
def test_failures(self):
"""This test is designed to test failure cases"""
token = self.read_token()
self.publish_module("quickstart", clear=True)
# Empty query
sql_out = psql(token, "")
self.assertEqual(sql_out, "")
# Connection fails when `ssl` is required
for ssl_mode in ["require", "verify-ca", "verify-full"]:
with self.assertRaises(Exception) as cm:
psql(token, "SELECT * FROM t_uints", extra={"PGSSLMODE": ssl_mode})
self.assertIn("not support SSL", str(cm.exception))
# But works with `ssl` is disabled or optional
for ssl_mode in ["disable", "allow", "prefer"]:
psql(token, "SELECT * FROM t_uints", extra={"PGSSLMODE": ssl_mode})
# Connection fails with invalid token
with self.assertRaises(Exception) as cm:
psql("invalid_token", "SELECT * FROM t_uints")
self.assertIn("Invalid token", str(cm.exception))
# Returns error for unsupported `sql` statements
with self.assertRaises(Exception) as cm:
psql(token, "SELECT CASE a WHEN 1 THEN 'one' ELSE 'other' END FROM t_uints")
self.assertIn("Unsupported", str(cm.exception))
# And prepared statements
with self.assertRaises(Exception) as cm:
psql(token, "SELECT * FROM t_uints where u8 = $1")
self.assertIn("Unsupported", str(cm.exception))