mirror of
https://github.com/clockworklabs/SpacetimeDB.git
synced 2026-05-06 07:26:43 -04:00
Support for the PG wire protocol (#2702)
# Description of Changes Closes [#2686](https://github.com/clockworklabs/SpacetimeDB/issues/2686). Add support for listening using the [PG wire protocol](https://www.postgresql.org/docs/current/protocol.html) so `pg` clients could be used against the database. # API and ABI breaking changes The output of `duration` is changed to `rfc3339`, instead of the way is made with `sats` because is what is done in `pg`, see note below. # Expected complexity level and risk 2 ~~There is open questions that are in the [ticket #2686](https://github.com/clockworklabs/SpacetimeDB/issues/2686). Also the crate used here require `RustTls`, so it could be good idea to decide if~~: * ~~Rewrite a big chunk of code to use `OpenSSL`~~ * ~~Move to `RustTls` https://github.com/clockworklabs/SpacetimeDB/pull/1700~~ * ~~Pay for the extra compilation cost~~. I open another port(`5433`) to listen for `pg` connections using `ssl`. Need to be decided if this is the way or instead try to multi-plex the current port for both protocols. # Testing Only manual testing so far. Solving the above questions allow me to implement some unit tests. Also, not yet integrated into cloud for the same reasons. - [x] Adding some test for the binary encoding of special and primitive types - [x] Smoke test using `psql` that connect to the db instance and run some queries - [x] Manually inspect using a UI database explorer how infer the types, some of this tools generate special widgets when displaying `json, duration, etc` --------- Co-authored-by: Noa <coolreader18@gmail.com>
This commit is contained in:
@@ -21,7 +21,7 @@ jobs:
|
||||
include:
|
||||
- { runner: spacetimedb-runner, smoketest_args: --docker }
|
||||
- { runner: windows-latest, smoketest_args: --no-build-cli }
|
||||
runner: [spacetimedb-runner, windows-latest]
|
||||
runner: [ spacetimedb-runner, windows-latest ]
|
||||
runs-on: ${{ matrix.runner }}
|
||||
steps:
|
||||
- name: Find Git ref
|
||||
@@ -44,6 +44,10 @@ jobs:
|
||||
- uses: actions/setup-dotnet@v4
|
||||
with:
|
||||
global-json-file: modules/global.json
|
||||
- name: Install psql (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
run: choco install psql -y --no-progress
|
||||
shell: powershell
|
||||
- name: Build and start database (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
run: docker compose up -d
|
||||
@@ -54,11 +58,13 @@ jobs:
|
||||
Start-Process target/debug/spacetimedb-cli.exe start
|
||||
cd modules
|
||||
# the sdk-manifests on windows-latest are messed up, so we need to update them
|
||||
dotnet workload config --update-mode workload-set
|
||||
dotnet workload config --update-mode manifests
|
||||
dotnet workload update
|
||||
- uses: actions/setup-python@v5
|
||||
with: { python-version: '3.12' }
|
||||
if: runner.os == 'Windows'
|
||||
- name: Install psycopg2
|
||||
run: python -m pip install psycopg2-binary
|
||||
- name: Run smoketests
|
||||
# Note: clear_database and replication only work in private
|
||||
run: python -m smoketests ${{ matrix.smoketest_args }} -x clear_database replication
|
||||
|
||||
Generated
+187
-7
@@ -195,6 +195,12 @@ dependencies = [
|
||||
"derive_arbitrary",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "array-init"
|
||||
version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3d62b7694a562cdf5a74227903507c56ab2cc8bdd1f781ed5cb4cf9c9f810bfc"
|
||||
|
||||
[[package]]
|
||||
name = "arrayref"
|
||||
version = "0.3.9"
|
||||
@@ -293,6 +299,30 @@ version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
|
||||
|
||||
[[package]]
|
||||
name = "aws-lc-rs"
|
||||
version = "1.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "19b756939cb2f8dc900aa6dcd505e6e2428e9cae7ff7b028c49e3946efa70878"
|
||||
dependencies = [
|
||||
"aws-lc-sys",
|
||||
"untrusted 0.7.1",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aws-lc-sys"
|
||||
version = "0.28.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bfa9b6986f250236c27e5a204062434a773a13243d2ffc2955f37bdba4c5c6a1"
|
||||
dependencies = [
|
||||
"bindgen 0.69.5",
|
||||
"cc",
|
||||
"cmake",
|
||||
"dunce",
|
||||
"fs_extra",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum"
|
||||
version = "0.7.9"
|
||||
@@ -429,6 +459,29 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bindgen"
|
||||
version = "0.69.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
|
||||
dependencies = [
|
||||
"bitflags 2.9.0",
|
||||
"cexpr",
|
||||
"clang-sys",
|
||||
"itertools 0.12.1",
|
||||
"lazy_static",
|
||||
"lazycell",
|
||||
"log",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"regex",
|
||||
"rustc-hash 1.1.0",
|
||||
"shlex",
|
||||
"syn 2.0.101",
|
||||
"which 4.4.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bindgen"
|
||||
version = "0.71.1"
|
||||
@@ -444,7 +497,7 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"regex",
|
||||
"rustc-hash",
|
||||
"rustc-hash 2.1.1",
|
||||
"shlex",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
@@ -925,6 +978,15 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cmake"
|
||||
version = "0.1.54"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cobs"
|
||||
version = "0.2.3"
|
||||
@@ -1092,7 +1154,7 @@ dependencies = [
|
||||
"hashbrown 0.14.5",
|
||||
"log",
|
||||
"regalloc2",
|
||||
"rustc-hash",
|
||||
"rustc-hash 2.1.1",
|
||||
"smallvec",
|
||||
"target-lexicon",
|
||||
]
|
||||
@@ -1485,6 +1547,17 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive-new"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_arbitrary"
|
||||
version = "1.4.1"
|
||||
@@ -1602,6 +1675,12 @@ dependencies = [
|
||||
"shared_child",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dunce"
|
||||
version = "1.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
|
||||
|
||||
[[package]]
|
||||
name = "educe"
|
||||
version = "0.4.23"
|
||||
@@ -3011,12 +3090,41 @@ dependencies = [
|
||||
"spacetimedb 1.4.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy-regex"
|
||||
version = "3.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "60c7310b93682b36b98fa7ea4de998d3463ccbebd94d935d6b48ba5b6ffa7126"
|
||||
dependencies = [
|
||||
"lazy-regex-proc_macros",
|
||||
"once_cell",
|
||||
"regex-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy-regex-proc_macros"
|
||||
version = "3.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ba01db5ef81e17eb10a5e0f2109d1b3a3e29bac3070fdbd7d156bf7dbd206a1"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"regex",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
|
||||
|
||||
[[package]]
|
||||
name = "lazycell"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
|
||||
|
||||
[[package]]
|
||||
name = "leb128"
|
||||
version = "0.2.5"
|
||||
@@ -3215,6 +3323,12 @@ dependencies = [
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "md5"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ae960838283323069879657ca3de837e9f7bbb4c7bf6ea7f1b290d5e9476d2e0"
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.7.4"
|
||||
@@ -3805,6 +3919,31 @@ dependencies = [
|
||||
"postgres-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pgwire"
|
||||
version = "0.32.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ddf403a6ee31cf7f2217b2bd8447cb13dbb6c268d7e81501bc78a4d3daafd294"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"aws-lc-rs",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"derive-new",
|
||||
"futures",
|
||||
"hex",
|
||||
"lazy-regex",
|
||||
"md5",
|
||||
"postgres-types",
|
||||
"rand 0.9.1",
|
||||
"rust_decimal",
|
||||
"rustls-pki-types",
|
||||
"thiserror 2.0.12",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tokio-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "phf"
|
||||
version = "0.11.3"
|
||||
@@ -3956,6 +4095,7 @@ version = "0.2.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "613283563cd90e1dfc3518d548caee47e0e725455ed619881f5cf21f36de4b48"
|
||||
dependencies = [
|
||||
"array-init",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"fallible-iterator 0.2.0",
|
||||
@@ -4445,7 +4585,7 @@ checksum = "12908dbeb234370af84d0579b9f68258a0f67e201412dd9a2814e6f45b2fc0f0"
|
||||
dependencies = [
|
||||
"hashbrown 0.14.5",
|
||||
"log",
|
||||
"rustc-hash",
|
||||
"rustc-hash 2.1.1",
|
||||
"slice-group-by",
|
||||
"smallvec",
|
||||
]
|
||||
@@ -4482,6 +4622,12 @@ dependencies = [
|
||||
"regex-syntax 0.8.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-lite"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a"
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.6.29"
|
||||
@@ -4623,7 +4769,7 @@ dependencies = [
|
||||
"cfg-if",
|
||||
"getrandom 0.2.16",
|
||||
"libc",
|
||||
"untrusted",
|
||||
"untrusted 0.9.0",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
@@ -4734,6 +4880,12 @@ version = "0.1.24"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
|
||||
|
||||
[[package]]
|
||||
name = "rustc-hash"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
|
||||
|
||||
[[package]]
|
||||
name = "rustc-hash"
|
||||
version = "2.1.1"
|
||||
@@ -4781,6 +4933,8 @@ version = "0.23.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321"
|
||||
dependencies = [
|
||||
"aws-lc-rs",
|
||||
"log",
|
||||
"once_cell",
|
||||
"rustls-pki-types",
|
||||
"rustls-webpki",
|
||||
@@ -4821,9 +4975,10 @@ version = "0.103.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7149975849f1abb3832b246010ef62ccc80d3a76169517ada7188252b9cfb437"
|
||||
dependencies = [
|
||||
"aws-lc-rs",
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"untrusted",
|
||||
"untrusted 0.9.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5674,7 +5829,7 @@ dependencies = [
|
||||
"regex",
|
||||
"reqwest 0.12.15",
|
||||
"rustc-demangle",
|
||||
"rustc-hash",
|
||||
"rustc-hash 2.1.1",
|
||||
"scopeguard",
|
||||
"semver",
|
||||
"serde",
|
||||
@@ -5960,6 +6115,24 @@ dependencies = [
|
||||
"xdg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spacetimedb-pg"
|
||||
version = "1.4.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"axum",
|
||||
"futures",
|
||||
"http 1.3.1",
|
||||
"log",
|
||||
"pgwire",
|
||||
"spacetimedb-client-api",
|
||||
"spacetimedb-client-api-messages",
|
||||
"spacetimedb-lib 1.4.0",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spacetimedb-physical-plan"
|
||||
version = "1.4.0"
|
||||
@@ -6207,6 +6380,7 @@ dependencies = [
|
||||
"spacetimedb-datastore",
|
||||
"spacetimedb-lib 1.4.0",
|
||||
"spacetimedb-paths",
|
||||
"spacetimedb-pg",
|
||||
"spacetimedb-schema",
|
||||
"spacetimedb-table",
|
||||
"tempfile",
|
||||
@@ -7397,6 +7571,12 @@ version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
|
||||
|
||||
[[package]]
|
||||
name = "untrusted"
|
||||
version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
|
||||
|
||||
[[package]]
|
||||
name = "untrusted"
|
||||
version = "0.9.0"
|
||||
@@ -7477,7 +7657,7 @@ version = "137.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1ca393e2032ddba2a57169e15cac5d0a81cdb3d872a8886f4468bc0f486098d2"
|
||||
dependencies = [
|
||||
"bindgen",
|
||||
"bindgen 0.71.1",
|
||||
"bitflags 2.9.0",
|
||||
"fslock",
|
||||
"gzip-header",
|
||||
|
||||
@@ -19,6 +19,7 @@ members = [
|
||||
"crates/lib",
|
||||
"crates/metrics",
|
||||
"crates/paths",
|
||||
"crates/pg",
|
||||
"crates/physical-plan",
|
||||
"crates/primitives",
|
||||
"crates/query",
|
||||
@@ -115,6 +116,7 @@ spacetimedb-lib = { path = "crates/lib", default-features = false, version = "1.
|
||||
spacetimedb-memory-usage = { path = "crates/memory-usage", version = "1.4.0", default-features = false }
|
||||
spacetimedb-metrics = { path = "crates/metrics", version = "1.4.0" }
|
||||
spacetimedb-paths = { path = "crates/paths", version = "1.4.0" }
|
||||
spacetimedb-pg = { path = "crates/pg", version = "1.4.0" }
|
||||
spacetimedb-physical-plan = { path = "crates/physical-plan", version = "1.4.0" }
|
||||
spacetimedb-primitives = { path = "crates/primitives", version = "1.4.0" }
|
||||
spacetimedb-query = { path = "crates/query", version = "1.4.0" }
|
||||
@@ -214,6 +216,7 @@ paste = "1.0"
|
||||
percent-encoding = "2.3"
|
||||
petgraph = { version = "0.6.5", default-features = false }
|
||||
pin-project-lite = "0.2.9"
|
||||
pgwire = { version = "0.32", features = ["server-api"] }
|
||||
postgres-types = "0.2.5"
|
||||
pretty_assertions = { version = "1.4", features = ["unstable"] }
|
||||
proc-macro2 = "1.0"
|
||||
|
||||
@@ -10,6 +10,7 @@ use anyhow::Context;
|
||||
use clap::{Arg, ArgAction, ArgMatches};
|
||||
use reqwest::RequestBuilder;
|
||||
use spacetimedb_lib::de::serde::SeedWrapper;
|
||||
use spacetimedb_lib::sats::satn::PsqlClient;
|
||||
use spacetimedb_lib::sats::{satn, ProductType, ProductValue, Typespace};
|
||||
|
||||
pub fn cli() -> clap::Command {
|
||||
@@ -111,7 +112,7 @@ fn print_stmt_result(
|
||||
for (pos, result) in if_empty
|
||||
.into_iter()
|
||||
.chain(stmt_results.iter().map(|stmt_result| {
|
||||
let (stats, table) = stmt_result_to_table(stmt_result)?;
|
||||
let (stats, table) = stmt_result_to_table(PsqlClient::SpacetimeDB, stmt_result)?;
|
||||
|
||||
anyhow::Ok(StmtResult {
|
||||
stats: with_stats.is_some().then_some(stats),
|
||||
@@ -157,12 +158,13 @@ pub(crate) async fn run_sql(builder: RequestBuilder, sql: &str, with_stats: bool
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn stmt_result_to_table(stmt_result: &SqlStmtResult) -> anyhow::Result<(StmtStats, tabled::Table)> {
|
||||
fn stmt_result_to_table(client: PsqlClient, stmt_result: &SqlStmtResult) -> anyhow::Result<(StmtStats, tabled::Table)> {
|
||||
let stats = StmtStats::from(stmt_result);
|
||||
let SqlStmtResult { schema, rows, .. } = stmt_result;
|
||||
let ty = Typespace::EMPTY.with_type(schema);
|
||||
|
||||
let table = build_table(
|
||||
client,
|
||||
schema,
|
||||
rows.iter().map(|row| from_json_seed(row.get(), SeedWrapper(ty))),
|
||||
)?;
|
||||
@@ -194,6 +196,7 @@ pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error
|
||||
|
||||
/// Generates a [`tabled::Table`] from a schema and rows, using the style of a psql table.
|
||||
fn build_table<E>(
|
||||
client: PsqlClient,
|
||||
schema: &ProductType,
|
||||
rows: impl Iterator<Item = Result<ProductValue, E>>,
|
||||
) -> Result<tabled::Table, E> {
|
||||
@@ -211,6 +214,7 @@ fn build_table<E>(
|
||||
let row = row?;
|
||||
builder.push_record(ty.with_values(&row).enumerate().map(|(idx, value)| {
|
||||
let ty = satn::PsqlType {
|
||||
client,
|
||||
tuple: ty.ty(),
|
||||
field: &ty.ty().elements[idx],
|
||||
idx,
|
||||
@@ -446,8 +450,10 @@ Roundtrip time: 1.00ms"#,
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn expect_psql_table(ty: &ProductType, rows: Vec<ProductValue>, expected: &str) {
|
||||
let table = build_table(ty, rows.into_iter().map(Ok::<_, ()>)).unwrap().to_string();
|
||||
fn expect_psql_table(client: PsqlClient, ty: &ProductType, rows: Vec<ProductValue>, expected: &str) {
|
||||
let table = build_table(client, ty, rows.into_iter().map(Ok::<_, ()>))
|
||||
.unwrap()
|
||||
.to_string();
|
||||
let mut table = table.split('\n').map(|x| x.trim_end()).join("\n");
|
||||
table.insert(0, '\n');
|
||||
assert_eq!(expected, table);
|
||||
@@ -476,14 +482,25 @@ Roundtrip time: 1.00ms"#,
|
||||
];
|
||||
|
||||
expect_psql_table(
|
||||
PsqlClient::SpacetimeDB,
|
||||
&kind,
|
||||
vec![value],
|
||||
vec![value.clone()],
|
||||
r#"
|
||||
column 0 | column 1 | column 2 | column 3 | column 4 | column 5
|
||||
----------+----------+--------------------------------------------------------------------+------------------------------------+---------------------------+-----------
|
||||
"a" | 0 | 0x0000000000000000000000000000000000000000000000000000000000000000 | 0x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | +0.000000"#,
|
||||
);
|
||||
|
||||
expect_psql_table(
|
||||
PsqlClient::Postgres,
|
||||
&kind,
|
||||
vec![value],
|
||||
r#"
|
||||
column 0 | column 1 | column 2 | column 3 | column 4 | column 5
|
||||
----------+----------+----------------------------------------------------------------------+--------------------------------------+-----------------------------+----------
|
||||
"a" | 0 | "0x0000000000000000000000000000000000000000000000000000000000000000" | "0x00000000000000000000000000000000" | "1970-01-01T00:00:00+00:00" | "P0D""#,
|
||||
);
|
||||
|
||||
// Check struct
|
||||
let kind: ProductType = [
|
||||
("bool", AlgebraicType::Bool),
|
||||
@@ -507,6 +524,7 @@ Roundtrip time: 1.00ms"#,
|
||||
];
|
||||
|
||||
expect_psql_table(
|
||||
PsqlClient::SpacetimeDB,
|
||||
&kind,
|
||||
vec![value.clone()],
|
||||
r#"
|
||||
@@ -515,12 +533,23 @@ Roundtrip time: 1.00ms"#,
|
||||
true | "This is spacetimedb" | 0x01020304050607 | 0x0000000000000000000000000000000000000000000000000000000000000000 | 0x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | +0.000000"#,
|
||||
);
|
||||
|
||||
expect_psql_table(
|
||||
PsqlClient::Postgres,
|
||||
&kind,
|
||||
vec![value.clone()],
|
||||
r#"
|
||||
bool | str | bytes | identity | connection_id | timestamp | duration
|
||||
------+-----------------------+--------------------+----------------------------------------------------------------------+--------------------------------------+-----------------------------+----------
|
||||
true | "This is spacetimedb" | "0x01020304050607" | "0x0000000000000000000000000000000000000000000000000000000000000000" | "0x00000000000000000000000000000000" | "1970-01-01T00:00:00+00:00" | "P0D""#,
|
||||
);
|
||||
|
||||
// Check nested struct, tuple...
|
||||
let kind: ProductType = [(None, AlgebraicType::product(kind))].into();
|
||||
|
||||
let value = product![value.clone()];
|
||||
|
||||
expect_psql_table(
|
||||
PsqlClient::SpacetimeDB,
|
||||
&kind,
|
||||
vec![value.clone()],
|
||||
r#"
|
||||
@@ -529,17 +558,38 @@ Roundtrip time: 1.00ms"#,
|
||||
(bool = true, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000000, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = +0.000000)"#,
|
||||
);
|
||||
|
||||
expect_psql_table(
|
||||
PsqlClient::Postgres,
|
||||
&kind,
|
||||
vec![value.clone()],
|
||||
r#"
|
||||
column 0
|
||||
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
{"bool": true, "str": "This is spacetimedb", "bytes": "0x01020304050607", "identity": "0x0000000000000000000000000000000000000000000000000000000000000000", "connection_id": "0x00000000000000000000000000000000", "timestamp": "1970-01-01T00:00:00+00:00", "duration": "P0D"}"#,
|
||||
);
|
||||
|
||||
let kind: ProductType = [("tuple", AlgebraicType::product(kind))].into();
|
||||
|
||||
let value = product![value];
|
||||
|
||||
expect_psql_table(
|
||||
PsqlClient::SpacetimeDB,
|
||||
&kind,
|
||||
vec![value.clone()],
|
||||
r#"
|
||||
tuple
|
||||
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
(col_0 = (bool = true, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000000, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = +0.000000))"#,
|
||||
);
|
||||
|
||||
expect_psql_table(
|
||||
PsqlClient::Postgres,
|
||||
&kind,
|
||||
vec![value],
|
||||
r#"
|
||||
tuple
|
||||
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
(0 = (bool = true, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000000, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = +0.000000))"#,
|
||||
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
{"col_0": {"bool": true, "str": "This is spacetimedb", "bytes": "0x01020304050607", "identity": "0x0000000000000000000000000000000000000000000000000000000000000000", "connection_id": "0x00000000000000000000000000000000", "timestamp": "1970-01-01T00:00:00+00:00", "duration": "P0D"}}"#,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -171,7 +171,7 @@ pub enum SetDefaultDomainResult {
|
||||
///
|
||||
/// Must match the regex `^[a-z0-9]+(-[a-z0-9]+)*$`
|
||||
#[derive(Clone, Debug, serde_with::DeserializeFromStr, serde_with::SerializeDisplay)]
|
||||
pub struct DatabaseName(String);
|
||||
pub struct DatabaseName(pub String);
|
||||
|
||||
impl AsRef<str> for DatabaseName {
|
||||
fn as_ref(&self) -> &str {
|
||||
|
||||
@@ -222,6 +222,10 @@ impl<TV: TokenValidator + Send + Sync> TokenSigner for JwtKeyAuthProvider<TV> {
|
||||
impl<TV: TokenValidator + Send + Sync> JwtAuthProvider for JwtKeyAuthProvider<TV> {
|
||||
type TV = TV;
|
||||
|
||||
fn validator(&self) -> &Self::TV {
|
||||
&self.validator
|
||||
}
|
||||
|
||||
fn local_issuer(&self) -> &str {
|
||||
&self.local_issuer
|
||||
}
|
||||
@@ -229,10 +233,6 @@ impl<TV: TokenValidator + Send + Sync> JwtAuthProvider for JwtKeyAuthProvider<TV
|
||||
fn public_key_bytes(&self) -> &[u8] {
|
||||
&self.keys.public_pem
|
||||
}
|
||||
|
||||
fn validator(&self) -> &Self::TV {
|
||||
&self.validator
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -260,6 +260,13 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn validate_token<S: NodeDelegate>(
|
||||
state: &S,
|
||||
token: &str,
|
||||
) -> Result<SpacetimeIdentityClaims, TokenValidationError> {
|
||||
state.jwt_auth_provider().validator().validate_token(token).await
|
||||
}
|
||||
|
||||
pub struct SpacetimeAuthHeader {
|
||||
auth: Option<SpacetimeAuth>,
|
||||
}
|
||||
@@ -272,10 +279,7 @@ impl<S: NodeDelegate + Send + Sync> axum::extract::FromRequestParts<S> for Space
|
||||
return Ok(Self { auth: None });
|
||||
};
|
||||
|
||||
let claims = state
|
||||
.jwt_auth_provider()
|
||||
.validator()
|
||||
.validate_token(&creds.token)
|
||||
let claims = validate_token(state, &creds.token)
|
||||
.await
|
||||
.map_err(AuthorizationRejection::Custom)?;
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ use crate::auth::{
|
||||
SpacetimeIdentityToken,
|
||||
};
|
||||
use crate::routes::subscribe::generate_random_connection_id;
|
||||
use crate::util::{ByteStringBody, NameOrIdentity};
|
||||
pub use crate::util::{ByteStringBody, NameOrIdentity};
|
||||
use crate::{log_and_500, ControlStateDelegate, DatabaseDef, NodeDelegate};
|
||||
use axum::body::{Body, Bytes};
|
||||
use axum::extract::{Path, Query, State};
|
||||
@@ -31,7 +31,7 @@ use spacetimedb_client_api_messages::name::{
|
||||
};
|
||||
use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9;
|
||||
use spacetimedb_lib::identity::AuthCtx;
|
||||
use spacetimedb_lib::{sats, Timestamp};
|
||||
use spacetimedb_lib::{sats, ProductValue, Timestamp};
|
||||
use spacetimedb_schema::auto_migrate::{
|
||||
MigrationPolicy as SchemaMigrationPolicy, MigrationToken, PrettyPrintStyle as AutoMigratePrettyPrintStyle,
|
||||
};
|
||||
@@ -383,7 +383,7 @@ pub(crate) async fn worker_ctx_find_database(
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct SqlParams {
|
||||
name_or_identity: NameOrIdentity,
|
||||
pub name_or_identity: NameOrIdentity,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -391,16 +391,16 @@ pub struct SqlQueryParams {
|
||||
/// If `true`, return the query result only after its transaction offset
|
||||
/// is confirmed to be durable.
|
||||
#[serde(default)]
|
||||
confirmed: bool,
|
||||
pub confirmed: bool,
|
||||
}
|
||||
|
||||
pub async fn sql<S>(
|
||||
State(worker_ctx): State<S>,
|
||||
Path(SqlParams { name_or_identity }): Path<SqlParams>,
|
||||
Query(SqlQueryParams { confirmed }): Query<SqlQueryParams>,
|
||||
Extension(auth): Extension<SpacetimeAuth>,
|
||||
body: String,
|
||||
) -> axum::response::Result<impl IntoResponse>
|
||||
pub async fn sql_direct<S>(
|
||||
worker_ctx: S,
|
||||
SqlParams { name_or_identity }: SqlParams,
|
||||
SqlQueryParams { confirmed }: SqlQueryParams,
|
||||
caller_identity: Identity,
|
||||
sql: String,
|
||||
) -> axum::response::Result<Vec<SqlStmtResult<ProductValue>>>
|
||||
where
|
||||
S: NodeDelegate + ControlStateDelegate,
|
||||
{
|
||||
@@ -412,7 +412,7 @@ where
|
||||
.await?
|
||||
.ok_or(NO_SUCH_DATABASE)?;
|
||||
|
||||
let auth = AuthCtx::new(database.owner_identity, auth.identity);
|
||||
let auth = AuthCtx::new(database.owner_identity, caller_identity);
|
||||
log::debug!("auth: {auth:?}");
|
||||
|
||||
let host = worker_ctx
|
||||
@@ -420,7 +420,21 @@ where
|
||||
.await
|
||||
.map_err(log_and_500)?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
let json = host.exec_sql(auth, database, confirmed, body).await?;
|
||||
|
||||
host.exec_sql(auth, database, confirmed, sql).await
|
||||
}
|
||||
|
||||
pub async fn sql<S>(
|
||||
State(worker_ctx): State<S>,
|
||||
Path(name_or_identity): Path<SqlParams>,
|
||||
Query(params): Query<SqlQueryParams>,
|
||||
Extension(auth): Extension<SpacetimeAuth>,
|
||||
body: String,
|
||||
) -> axum::response::Result<impl IntoResponse>
|
||||
where
|
||||
S: NodeDelegate + ControlStateDelegate,
|
||||
{
|
||||
let json = sql_direct(worker_ctx, name_or_identity, params, auth.identity, body).await?;
|
||||
|
||||
let total_duration = json.iter().fold(0, |acc, x| acc + x.total_duration_micros);
|
||||
|
||||
@@ -488,7 +502,9 @@ pub struct PublishDatabaseQueryParams {
|
||||
policy: MigrationPolicy,
|
||||
}
|
||||
|
||||
use spacetimedb_client_api_messages::http::SqlStmtResult;
|
||||
use std::env;
|
||||
|
||||
fn require_spacetime_auth_for_creation() -> bool {
|
||||
env::var("TEMP_REQUIRE_SPACETIME_AUTH").is_ok_and(|v| !v.is_empty())
|
||||
}
|
||||
@@ -526,7 +542,7 @@ pub async fn publish<S: NodeDelegate + ControlStateDelegate>(
|
||||
// so, unless you are the owner, this will fail.
|
||||
|
||||
let (database_identity, db_name) = match &name_or_identity {
|
||||
Some(noa) => match noa.try_resolve(&ctx).await? {
|
||||
Some(noa) => match noa.try_resolve(&ctx).await.map_err(log_and_500)? {
|
||||
Ok(resolved) => (resolved, noa.name()),
|
||||
Err(name) => {
|
||||
// `name_or_identity` was a `NameOrIdentity::Name`, but no record
|
||||
|
||||
@@ -97,10 +97,10 @@ impl NameOrIdentity {
|
||||
pub async fn try_resolve(
|
||||
&self,
|
||||
ctx: &(impl ControlStateReadAccess + ?Sized),
|
||||
) -> axum::response::Result<Result<Identity, &DatabaseName>> {
|
||||
) -> anyhow::Result<Result<Identity, &DatabaseName>> {
|
||||
Ok(match self {
|
||||
Self::Identity(identity) => Ok(Identity::from(*identity)),
|
||||
Self::Name(name) => ctx.lookup_identity(name.as_ref()).map_err(log_and_500)?.ok_or(name),
|
||||
Self::Name(name) => ctx.lookup_identity(name.as_ref())?.ok_or(name),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -108,7 +108,10 @@ impl NameOrIdentity {
|
||||
/// response if `self` is a [`NameOrIdentity::Name`] for which no
|
||||
/// corresponding [`Identity`] is found in the SpacetimeDB DNS.
|
||||
pub async fn resolve(&self, ctx: &(impl ControlStateReadAccess + ?Sized)) -> axum::response::Result<Identity> {
|
||||
self.try_resolve(ctx).await?.map_err(|_| StatusCode::NOT_FOUND.into())
|
||||
self.try_resolve(ctx)
|
||||
.await
|
||||
.map_err(log_and_500)?
|
||||
.map_err(|_| StatusCode::NOT_FOUND.into())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ pub struct JwtKeys {
|
||||
pub public: DecodingKey,
|
||||
pub public_pem: Box<[u8]>,
|
||||
pub private: EncodingKey,
|
||||
pub private_pem: Box<[u8]>,
|
||||
pub kid: Option<String>,
|
||||
}
|
||||
|
||||
@@ -23,15 +24,17 @@ impl JwtKeys {
|
||||
/// respectively.
|
||||
///
|
||||
/// The key files must be PEM encoded ECDSA P256 keys.
|
||||
pub fn new(public_pem: impl Into<Box<[u8]>>, private_pem: &[u8]) -> anyhow::Result<Self> {
|
||||
pub fn new(public_pem: impl Into<Box<[u8]>>, private_pem: impl Into<Box<[u8]>>) -> anyhow::Result<Self> {
|
||||
let public_pem = public_pem.into();
|
||||
let private_pem = private_pem.into();
|
||||
let public = DecodingKey::from_ec_pem(&public_pem)?;
|
||||
let private = EncodingKey::from_ec_pem(private_pem)?;
|
||||
let private = EncodingKey::from_ec_pem(&private_pem)?;
|
||||
|
||||
Ok(Self {
|
||||
public,
|
||||
private,
|
||||
public_pem,
|
||||
private_pem,
|
||||
kid: None,
|
||||
})
|
||||
}
|
||||
@@ -75,7 +78,7 @@ pub struct EcKeyPair {
|
||||
impl TryFrom<EcKeyPair> for JwtKeys {
|
||||
type Error = anyhow::Error;
|
||||
fn try_from(pair: EcKeyPair) -> anyhow::Result<Self> {
|
||||
JwtKeys::new(pair.public_key_bytes, &pair.private_key_bytes)
|
||||
JwtKeys::new(pair.public_key_bytes, pair.private_key_bytes)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -63,6 +63,10 @@ pub struct Node {
|
||||
///
|
||||
/// If `None`, the node is not currently live.
|
||||
pub advertise_addr: Option<String>,
|
||||
/// The address this node is running its postgres API at.
|
||||
///
|
||||
/// If `None`, the node is not currently live.
|
||||
pub pg_addr: Option<String>,
|
||||
}
|
||||
#[derive(Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct NodeStatus {
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
[package]
|
||||
name = "spacetimedb-pg"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
rust-version.workspace = true
|
||||
license-file = "LICENSE"
|
||||
description = "Postgres wire protocol Server support for SpacetimeDB"
|
||||
|
||||
[dependencies]
|
||||
spacetimedb-client-api-messages.workspace = true
|
||||
spacetimedb-client-api.workspace = true
|
||||
spacetimedb-lib.workspace = true
|
||||
|
||||
anyhow.workspace = true
|
||||
async-trait.workspace = true
|
||||
axum.workspace = true
|
||||
futures.workspace = true
|
||||
http.workspace = true
|
||||
log.workspace = true
|
||||
pgwire.workspace = true
|
||||
thiserror.workspace = true
|
||||
tokio.workspace = true
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../licenses/BSL.txt
|
||||
@@ -0,0 +1,3 @@
|
||||
> ⚠️ **Internal Crate** ⚠️
|
||||
>
|
||||
> This crate is intended for internal use only. It is **not** stable and may change without notice.
|
||||
@@ -0,0 +1,301 @@
|
||||
use crate::pg_server::PgError;
|
||||
use pgwire::api::portal::Format;
|
||||
use pgwire::api::results::{DataRowEncoder, FieldInfo};
|
||||
use pgwire::api::Type;
|
||||
use spacetimedb_lib::sats::satn::{PsqlChars, PsqlPrintFmt, PsqlType, TypedWriter};
|
||||
use spacetimedb_lib::sats::{satn, ValueWithType};
|
||||
use spacetimedb_lib::{
|
||||
ser, AlgebraicType, AlgebraicValue, ProductType, ProductTypeElement, ProductValue, TimeDuration, Timestamp,
|
||||
};
|
||||
use std::borrow::Cow;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub(crate) fn row_desc(schema: &ProductType, format: &Format) -> Arc<Vec<FieldInfo>> {
|
||||
Arc::new(
|
||||
schema
|
||||
.elements
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(pos, ty)| {
|
||||
let field_name = ty.name.clone().map(Into::into).unwrap_or_else(|| format!("col_{pos}"));
|
||||
let field_type = type_of(schema, ty);
|
||||
FieldInfo::new(field_name, None, None, field_type, format.format_for(pos))
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn type_of(schema: &ProductType, ty: &ProductTypeElement) -> Type {
|
||||
let format = PsqlPrintFmt::use_fmt(schema, ty, ty.name());
|
||||
match &ty.algebraic_type {
|
||||
AlgebraicType::String => Type::VARCHAR,
|
||||
AlgebraicType::Bool => Type::BOOL,
|
||||
AlgebraicType::U8 | AlgebraicType::I8 | AlgebraicType::I16 => Type::INT2,
|
||||
AlgebraicType::U16 | AlgebraicType::I32 => Type::INT4,
|
||||
AlgebraicType::U32 | AlgebraicType::I64 => Type::INT8,
|
||||
AlgebraicType::U64 | AlgebraicType::I128 | AlgebraicType::U128 | AlgebraicType::I256 | AlgebraicType::U256 => {
|
||||
Type::NUMERIC
|
||||
}
|
||||
AlgebraicType::F32 => Type::FLOAT4,
|
||||
AlgebraicType::F64 => Type::FLOAT8,
|
||||
AlgebraicType::Array(ty) => match *ty.elem_ty {
|
||||
AlgebraicType::String => Type::VARCHAR_ARRAY,
|
||||
AlgebraicType::Bool => Type::BOOL_ARRAY,
|
||||
AlgebraicType::U8 => Type::BYTEA,
|
||||
AlgebraicType::I8 | AlgebraicType::I16 => Type::INT2_ARRAY,
|
||||
AlgebraicType::U16 | AlgebraicType::I32 => Type::INT4_ARRAY,
|
||||
AlgebraicType::U32 | AlgebraicType::I64 => Type::INT8_ARRAY,
|
||||
AlgebraicType::U64
|
||||
| AlgebraicType::I128
|
||||
| AlgebraicType::U128
|
||||
| AlgebraicType::I256
|
||||
| AlgebraicType::U256 => Type::NUMERIC_ARRAY,
|
||||
_ => Type::ANYARRAY,
|
||||
},
|
||||
AlgebraicType::Product(_) => match format {
|
||||
PsqlPrintFmt::Hex => Type::BYTEA_ARRAY,
|
||||
PsqlPrintFmt::Timestamp => Type::TIMESTAMP,
|
||||
PsqlPrintFmt::Duration => Type::INTERVAL,
|
||||
_ => Type::JSON,
|
||||
},
|
||||
AlgebraicType::Sum(sum) if sum.is_simple_enum() => Type::ANYENUM,
|
||||
AlgebraicType::Sum(_) => Type::JSON,
|
||||
_ => Type::UNKNOWN,
|
||||
}
|
||||
}
|
||||
|
||||
impl ser::Error for PgError {
|
||||
fn custom<T: std::fmt::Display>(msg: T) -> Self {
|
||||
PgError::Other(anyhow::anyhow!(msg.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct PsqlFormatter<'a> {
|
||||
pub(crate) encoder: &'a mut DataRowEncoder,
|
||||
}
|
||||
|
||||
impl TypedWriter for PsqlFormatter<'_> {
|
||||
type Error = PgError;
|
||||
|
||||
fn write<W: std::fmt::Display>(&mut self, value: W) -> Result<(), Self::Error> {
|
||||
self.encoder.encode_field(&value.to_string())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_bool(&mut self, value: bool) -> Result<(), Self::Error> {
|
||||
self.encoder.encode_field(&value)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_string(&mut self, value: &str) -> Result<(), Self::Error> {
|
||||
self.encoder.encode_field(&value)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_bytes(&mut self, value: &[u8]) -> Result<(), Self::Error> {
|
||||
self.encoder.encode_field(&value)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_hex(&mut self, value: &[u8]) -> Result<(), Self::Error> {
|
||||
self.encoder.encode_field(&value)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_timestamp(&mut self, value: Timestamp) -> Result<(), Self::Error> {
|
||||
self.encoder.encode_field(&value.to_rfc3339()?)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_duration(&mut self, value: TimeDuration) -> Result<(), Self::Error> {
|
||||
self.encoder.encode_field(&value.to_iso8601())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_alt_record(
|
||||
&mut self,
|
||||
ty: &PsqlType,
|
||||
value: &ValueWithType<'_, ProductValue>,
|
||||
) -> Result<bool, Self::Error> {
|
||||
let json = satn::PsqlWrapper { ty: ty.clone(), value }.to_string();
|
||||
self.encoder.encode_field(&json)?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
fn write_record(
|
||||
&mut self,
|
||||
_fields: Vec<(Cow<str>, PsqlType, ValueWithType<AlgebraicValue>)>,
|
||||
) -> Result<(), Self::Error> {
|
||||
unreachable!("Use `write_alt_record` for records in PSQL format");
|
||||
}
|
||||
|
||||
fn write_variant(
|
||||
&mut self,
|
||||
tag: u8,
|
||||
ty: PsqlType,
|
||||
name: Option<&str>,
|
||||
value: ValueWithType<AlgebraicValue>,
|
||||
) -> Result<(), Self::Error> {
|
||||
// Is a simple enum?
|
||||
if let AlgebraicType::Sum(sum) = &ty.field.algebraic_type {
|
||||
if sum.is_simple_enum() {
|
||||
if let Some(variant_name) = name {
|
||||
self.encoder.encode_field(&variant_name)?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let PsqlChars { start, sep, end, quote } = ty.client.format_chars();
|
||||
let name = name.map(Cow::from).unwrap_or_else(|| Cow::from(tag.to_string()));
|
||||
let json = format!(
|
||||
"{start}{quote}{name}{quote}{sep} {}{end}",
|
||||
satn::PsqlWrapper { ty, value }
|
||||
);
|
||||
self.encoder.encode_field(&json)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::pg_server::to_rows;
|
||||
use futures::StreamExt;
|
||||
use spacetimedb_client_api_messages::http::SqlStmtResult;
|
||||
use spacetimedb_lib::sats::algebraic_value::Packed;
|
||||
use spacetimedb_lib::sats::{i256, product, u256, AlgebraicType, ProductType, SumTypeVariant};
|
||||
use spacetimedb_lib::{ConnectionId, Identity};
|
||||
|
||||
async fn run(schema: ProductType, row: ProductValue) -> String {
|
||||
let header = row_desc(&schema, &Format::UnifiedText);
|
||||
|
||||
let stmt = SqlStmtResult {
|
||||
schema,
|
||||
rows: vec![row],
|
||||
total_duration_micros: 0,
|
||||
stats: Default::default(),
|
||||
};
|
||||
let mut stream = to_rows(stmt, header).unwrap();
|
||||
let mut result = String::new();
|
||||
if let Some(row) = stream.next().await {
|
||||
result = String::from_utf8_lossy(row.unwrap().data.freeze().as_ref()).to_string();
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_primitives() {
|
||||
let schema = ProductType::from([
|
||||
AlgebraicType::U8,
|
||||
AlgebraicType::I8,
|
||||
AlgebraicType::I16,
|
||||
AlgebraicType::U16,
|
||||
AlgebraicType::I32,
|
||||
AlgebraicType::U32,
|
||||
AlgebraicType::I64,
|
||||
AlgebraicType::U64,
|
||||
AlgebraicType::I128,
|
||||
AlgebraicType::U128,
|
||||
AlgebraicType::I256,
|
||||
AlgebraicType::U256,
|
||||
AlgebraicType::F32,
|
||||
AlgebraicType::F64,
|
||||
AlgebraicType::String,
|
||||
AlgebraicType::Bool,
|
||||
]);
|
||||
let value = product![
|
||||
1u8,
|
||||
-1i8,
|
||||
-2i16,
|
||||
3u16,
|
||||
-4i32,
|
||||
5u32,
|
||||
-6i64,
|
||||
7u64,
|
||||
Packed::from(-8i128),
|
||||
Packed::from(9u128),
|
||||
i256::from(-10),
|
||||
u256::from(11u128),
|
||||
12.34f32,
|
||||
56.78f64,
|
||||
"test".to_string(),
|
||||
true,
|
||||
];
|
||||
|
||||
let row = run(schema, value).await;
|
||||
assert_eq!(row, "\0\0\0\u{1}1\0\0\0\u{2}-1\0\0\0\u{2}-2\0\0\0\u{1}3\0\0\0\u{2}-4\0\0\0\u{1}5\0\0\0\u{2}-6\0\0\0\u{1}7\0\0\0\u{2}-8\0\0\0\u{1}9\0\0\0\u{3}-10\0\0\0\u{2}11\0\0\0\u{5}12.34\0\0\0\u{5}56.78\0\0\0\u{4}test\0\0\0\u{1}t");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_enum() {
|
||||
let some = AlgebraicType::option(AlgebraicType::I64);
|
||||
let schema = ProductType::from([some.clone(), some]);
|
||||
let value = product![
|
||||
AlgebraicValue::sum(0, AlgebraicValue::I64(1)), // Some(1)
|
||||
AlgebraicValue::sum(1, AlgebraicValue::unit()), // None
|
||||
];
|
||||
|
||||
let row = run(schema, value).await;
|
||||
assert_eq!(row, "\0\0\0\u{b}{\"some\": 1}\0\0\0\u{c}{\"none\": {}}");
|
||||
|
||||
let color = AlgebraicType::Sum([SumTypeVariant::new_named(AlgebraicType::I64, "Gray")].into());
|
||||
let nested = AlgebraicType::option(color.clone());
|
||||
let schema = ProductType::from([color, nested]);
|
||||
// {"Gray": 1}, {"some": {"Gray": 2}}
|
||||
let value = product![
|
||||
AlgebraicValue::sum(0, AlgebraicValue::I64(1)), // Gray(1)
|
||||
AlgebraicValue::sum(0, AlgebraicValue::sum(0, AlgebraicValue::I64(2))), // Some(Gray(2))
|
||||
];
|
||||
let row = run(schema.clone(), value.clone()).await;
|
||||
assert_eq!(row, "\0\0\0\u{b}{\"Gray\": 1}\0\0\0\u{15}{\"some\": {\"Gray\": 2}}");
|
||||
|
||||
// Now nested product
|
||||
let product = AlgebraicType::product([
|
||||
ProductTypeElement::new(AlgebraicType::Product(schema), Some("x".into())),
|
||||
ProductTypeElement::new(AlgebraicType::String, Some("y".into())),
|
||||
]);
|
||||
let schema = ProductType::from([product.clone()]);
|
||||
let value = product![AlgebraicValue::product(vec![
|
||||
value.into(),
|
||||
AlgebraicValue::String("a".into()),
|
||||
])];
|
||||
let row = run(schema, value).await;
|
||||
assert_eq!(
|
||||
row,
|
||||
"\0\0\0G{\"x\": {\"col_0\": {\"Gray\": 1}, \"col_1\": {\"some\": {\"Gray\": 2}}}, \"y\": \"a\"}"
|
||||
);
|
||||
|
||||
// Now a simple enum
|
||||
let names = AlgebraicType::simple_enum(["A", "B", "C"].into_iter());
|
||||
let schema = ProductType::from([names.clone(), names.clone(), names]);
|
||||
let value = product![
|
||||
AlgebraicValue::enum_simple(0), // A
|
||||
AlgebraicValue::enum_simple(1), // B
|
||||
AlgebraicValue::enum_simple(2), // C
|
||||
];
|
||||
let row = run(schema, value).await;
|
||||
assert_eq!(row, "\0\0\0\u{1}A\0\0\0\u{1}B\0\0\0\u{1}C");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_special_types() {
|
||||
let schema = ProductType::from([
|
||||
AlgebraicType::identity(),
|
||||
AlgebraicType::connection_id(),
|
||||
AlgebraicType::time_duration(),
|
||||
AlgebraicType::timestamp(),
|
||||
AlgebraicType::bytes(),
|
||||
]);
|
||||
let value = product![
|
||||
Identity::ZERO,
|
||||
ConnectionId::ZERO,
|
||||
TimeDuration::from_micros(0),
|
||||
Timestamp::from_micros_since_unix_epoch(1622545800000),
|
||||
AlgebraicValue::Bytes("test".as_bytes().into()),
|
||||
];
|
||||
|
||||
let row = run(schema, value).await;
|
||||
assert_eq!(row, "\0\0\0B\\x0000000000000000000000000000000000000000000000000000000000000000\0\0\0\"\\x00000000000000000000000000000000\0\0\0\u{3}P0D\0\0\0\u{1d}1970-01-19T18:42:25.800+00:00\0\0\0\n\\x74657374");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
mod encoder;
|
||||
pub mod pg_server;
|
||||
@@ -0,0 +1,381 @@
|
||||
use std::fmt::Debug;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::encoder::{row_desc, PsqlFormatter};
|
||||
use async_trait::async_trait;
|
||||
use axum::body::to_bytes;
|
||||
use axum::response::IntoResponse;
|
||||
use futures::{stream, Sink};
|
||||
use futures::{SinkExt, Stream};
|
||||
use http::StatusCode;
|
||||
use pgwire::api::auth::{
|
||||
finish_authentication, save_startup_parameters_to_metadata, DefaultServerParameterProvider, LoginInfo,
|
||||
StartupHandler,
|
||||
};
|
||||
use pgwire::api::portal::Format;
|
||||
use pgwire::api::query::SimpleQueryHandler;
|
||||
use pgwire::api::results::{DataRowEncoder, FieldInfo, QueryResponse, Response, Tag};
|
||||
use pgwire::api::{ClientInfo, METADATA_DATABASE};
|
||||
use pgwire::api::{PgWireConnectionState, PgWireServerHandlers};
|
||||
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
|
||||
use pgwire::messages::data::DataRow;
|
||||
use pgwire::messages::startup::Authentication;
|
||||
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
|
||||
use pgwire::tokio::process_socket;
|
||||
use spacetimedb_client_api::auth::validate_token;
|
||||
use spacetimedb_client_api::routes::database;
|
||||
use spacetimedb_client_api::routes::database::{SqlParams, SqlQueryParams};
|
||||
use spacetimedb_client_api::{ControlStateReadAccess, ControlStateWriteAccess, NodeDelegate};
|
||||
use spacetimedb_client_api_messages::http::SqlStmtResult;
|
||||
use spacetimedb_client_api_messages::name::DatabaseName;
|
||||
use spacetimedb_lib::sats::satn::{PsqlClient, TypedSerializer};
|
||||
use spacetimedb_lib::sats::{satn, Serialize, Typespace};
|
||||
use spacetimedb_lib::version::spacetimedb_lib_version;
|
||||
use spacetimedb_lib::{Identity, ProductValue};
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::{Mutex, Notify};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub(crate) enum PgError {
|
||||
#[error("(metadata) {0}")]
|
||||
MetadataError(anyhow::Error),
|
||||
#[error("(Sql) {0}")]
|
||||
Sql(String),
|
||||
#[error("Database name is required")]
|
||||
DatabaseNameRequired,
|
||||
#[error(transparent)]
|
||||
Pg(#[from] PgWireError),
|
||||
#[error("SSL is not supported by SpacetimeDB")]
|
||||
SSLNotSupported,
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
impl From<PgError> for PgWireError {
|
||||
fn from(err: PgError) -> Self {
|
||||
if let PgError::Pg(err) = err {
|
||||
err
|
||||
} else {
|
||||
PgWireError::ApiError(Box::new(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Metadata {
|
||||
database: String,
|
||||
caller_identity: Identity,
|
||||
}
|
||||
|
||||
pub(crate) fn to_rows(
|
||||
stmt: SqlStmtResult<ProductValue>,
|
||||
header: Arc<Vec<FieldInfo>>,
|
||||
) -> Result<impl Stream<Item = PgWireResult<DataRow>>, PgError> {
|
||||
let mut results = Vec::with_capacity(stmt.rows.len());
|
||||
let ty = Typespace::EMPTY.with_type(&stmt.schema);
|
||||
|
||||
for row in stmt.rows {
|
||||
let mut encoder = DataRowEncoder::new(header.clone());
|
||||
|
||||
for (idx, value) in ty.with_values(&row).enumerate() {
|
||||
let ty = satn::PsqlType {
|
||||
client: PsqlClient::Postgres,
|
||||
tuple: ty.ty(),
|
||||
field: &ty.ty().elements[idx],
|
||||
idx,
|
||||
};
|
||||
let mut fmt = PsqlFormatter { encoder: &mut encoder };
|
||||
value.serialize(TypedSerializer { ty: &ty, f: &mut fmt })?;
|
||||
}
|
||||
results.push(encoder.finish());
|
||||
}
|
||||
Ok(stream::iter(results))
|
||||
}
|
||||
|
||||
fn stats(stmt: &SqlStmtResult<ProductValue>) -> String {
|
||||
let mut info = Vec::new();
|
||||
if stmt.stats.rows_inserted != 0 {
|
||||
info.push(format!("inserted: {}", stmt.stats.rows_inserted));
|
||||
}
|
||||
if stmt.stats.rows_deleted != 0 {
|
||||
info.push(format!("deleted: {}", stmt.stats.rows_deleted));
|
||||
}
|
||||
if stmt.stats.rows_updated != 0 {
|
||||
info.push(format!("updated: {}", stmt.stats.rows_updated));
|
||||
}
|
||||
info.push(format!(
|
||||
"server: {:.2?}",
|
||||
std::time::Duration::from_micros(stmt.total_duration_micros)
|
||||
));
|
||||
|
||||
info.join(", ")
|
||||
}
|
||||
|
||||
struct ResponseWrapper<T>(T);
|
||||
impl<T> IntoResponse for ResponseWrapper<T> {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
unreachable!("Blank impl to satisfy IntoResponse")
|
||||
}
|
||||
}
|
||||
|
||||
async fn response<T>(res: axum::response::Result<T>, database: &str) -> Result<T, PgError> {
|
||||
match res.map(ResponseWrapper) {
|
||||
Ok(sql) => Ok(sql.0),
|
||||
err => {
|
||||
let res = err.into_response();
|
||||
if res.status() == StatusCode::NOT_FOUND {
|
||||
log::error!("PG: Database not found: {database}");
|
||||
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
|
||||
"FATAL".to_string(),
|
||||
"3D000".to_string(),
|
||||
format!("database \"{database}\" does not exist"),
|
||||
)))
|
||||
.into());
|
||||
}
|
||||
let bytes = to_bytes(res.into_body(), usize::MAX)
|
||||
.await
|
||||
.map_err(|err| PgWireError::ApiError(Box::new(err)))?;
|
||||
let err = String::from_utf8_lossy(&bytes);
|
||||
log::error!("PG: Error for database {database}: {err}");
|
||||
Err(PgError::Sql(format!("{err}")))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct PgSpacetimeDB<T> {
|
||||
ctx: T,
|
||||
cached: Mutex<Option<Metadata>>,
|
||||
parameter_provider: DefaultServerParameterProvider,
|
||||
}
|
||||
|
||||
impl<T: ControlStateReadAccess + ControlStateWriteAccess + NodeDelegate + Clone> PgSpacetimeDB<T> {
|
||||
async fn exe_sql<'a>(&self, query: String) -> PgWireResult<Vec<Response<'a>>> {
|
||||
let params = self.cached.lock().await.clone().unwrap();
|
||||
let db = SqlParams {
|
||||
name_or_identity: database::NameOrIdentity::Name(DatabaseName(params.database.clone())),
|
||||
};
|
||||
|
||||
let sql = match response(
|
||||
database::sql_direct(
|
||||
self.ctx.clone(),
|
||||
db,
|
||||
SqlQueryParams { confirmed: true },
|
||||
params.caller_identity,
|
||||
query.to_string(),
|
||||
)
|
||||
.await,
|
||||
¶ms.database,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(sql) => sql,
|
||||
Err(PgError::Pg(PgWireError::UserError(err))) => {
|
||||
return Ok(vec![Response::Error(err)]);
|
||||
}
|
||||
Err(err) => {
|
||||
return Err(err.into());
|
||||
}
|
||||
};
|
||||
|
||||
let mut result = Vec::with_capacity(sql.len());
|
||||
for sql_result in sql {
|
||||
let header = row_desc(&sql_result.schema, &Format::UnifiedText);
|
||||
if sql_result.rows.is_empty() && !query.to_uppercase().contains("SELECT") {
|
||||
let tag = Tag::new(&stats(&sql_result));
|
||||
result.push(Response::Execution(tag));
|
||||
} else {
|
||||
let rows = to_rows(sql_result, header.clone())?;
|
||||
let q = QueryResponse::new(header, rows);
|
||||
result.push(Response::Query(q));
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
async fn close_client<C, E>(client: &mut C, err: E) -> PgWireResult<()>
|
||||
where
|
||||
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
|
||||
C::Error: Debug,
|
||||
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
|
||||
pgwire::messages::response::ErrorResponse: From<E>,
|
||||
{
|
||||
let err = pgwire::messages::response::ErrorResponse::from(err);
|
||||
client.feed(PgWireBackendMessage::ErrorResponse(err)).await?;
|
||||
client.close().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T: Sync + Send + ControlStateReadAccess + ControlStateWriteAccess + NodeDelegate> StartupHandler
|
||||
for PgSpacetimeDB<T>
|
||||
{
|
||||
async fn on_startup<C>(&self, client: &mut C, message: PgWireFrontendMessage) -> PgWireResult<()>
|
||||
where
|
||||
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
|
||||
C::Error: Debug,
|
||||
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
|
||||
{
|
||||
match message {
|
||||
PgWireFrontendMessage::Startup(ref startup) => {
|
||||
save_startup_parameters_to_metadata(client, startup);
|
||||
client.set_state(PgWireConnectionState::AuthenticationInProgress);
|
||||
|
||||
let login_info = LoginInfo::from_client_info(client);
|
||||
|
||||
if login_info.database().is_none() {
|
||||
return Err(PgError::DatabaseNameRequired.into());
|
||||
}
|
||||
|
||||
client
|
||||
.send(PgWireBackendMessage::Authentication(Authentication::CleartextPassword))
|
||||
.await?;
|
||||
}
|
||||
PgWireFrontendMessage::PasswordMessageFamily(pwd) => {
|
||||
let params = client.metadata();
|
||||
let param = |param: &str| {
|
||||
params
|
||||
.get(param)
|
||||
.map(String::from)
|
||||
.ok_or_else(|| PgError::MetadataError(anyhow::anyhow!("Missing parameter: {}", param)))
|
||||
};
|
||||
|
||||
// We don't support `METADATA_USER` because we don't have a user management system.
|
||||
let database = param(METADATA_DATABASE)?;
|
||||
let pwd = pwd.into_password()?;
|
||||
if let Ok(application_name) = param("application_name") {
|
||||
log::info!("PG: Connecting to database: {database}, by {application_name}",);
|
||||
} else {
|
||||
log::info!("PG: Connecting to database: {database}");
|
||||
}
|
||||
|
||||
let name = database::NameOrIdentity::Name(DatabaseName(database.clone()));
|
||||
match response(name.resolve(&self.ctx).await, &database).await {
|
||||
Ok(identity) => identity,
|
||||
Err(PgError::Pg(PgWireError::UserError(err))) => {
|
||||
return close_client(client, *err).await;
|
||||
}
|
||||
Err(err) => {
|
||||
return Err(err.into());
|
||||
}
|
||||
};
|
||||
|
||||
let caller_identity = match validate_token(&self.ctx, &pwd.password).await {
|
||||
Ok(claims) => claims.identity,
|
||||
Err(err) => {
|
||||
log::error!(
|
||||
"PG: Authentication failed for identity `{}` on database {database}: {err}",
|
||||
pwd.password
|
||||
);
|
||||
let err = ErrorInfo::new("FATAL".to_owned(), "28P01".to_owned(), err.to_string());
|
||||
return close_client(client, err).await;
|
||||
}
|
||||
};
|
||||
|
||||
log::info!("PG: Connected to database: {database} using identity `{caller_identity}`");
|
||||
|
||||
let metadata = Metadata {
|
||||
database,
|
||||
caller_identity,
|
||||
};
|
||||
self.cached.lock().await.clone_from(&Some(metadata));
|
||||
finish_authentication(client, &self.parameter_provider).await?;
|
||||
}
|
||||
PgWireFrontendMessage::SslRequest(_) => {
|
||||
let err = PgError::SSLNotSupported;
|
||||
log::error!("{err}");
|
||||
let err = ErrorInfo::new("FATAL".to_owned(), "28P01".to_owned(), err.to_string());
|
||||
return close_client(client, err).await;
|
||||
}
|
||||
// The other messages are for features not supported by SpacetimeDB, that are rejected by the parser.
|
||||
_ => {
|
||||
unreachable!("Unsupported startup message: {message:?}");
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T: Sync + Send + ControlStateReadAccess + ControlStateWriteAccess + NodeDelegate + Clone> SimpleQueryHandler
|
||||
for PgSpacetimeDB<T>
|
||||
{
|
||||
async fn do_query<'a, C>(&self, _client: &mut C, query: &str) -> PgWireResult<Vec<Response<'a>>>
|
||||
where
|
||||
C: ClientInfo + Unpin + Send + Sync,
|
||||
{
|
||||
self.exe_sql(query.to_string()).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct PgSpacetimeDBFactory<T> {
|
||||
handler: Arc<PgSpacetimeDB<T>>,
|
||||
}
|
||||
|
||||
impl<T> PgSpacetimeDBFactory<T> {
|
||||
pub fn new(ctx: T) -> Self {
|
||||
let mut parameter_provider = DefaultServerParameterProvider::default();
|
||||
parameter_provider.server_version = format!("spacetime {}", spacetimedb_lib_version());
|
||||
|
||||
Self {
|
||||
handler: Arc::new(PgSpacetimeDB {
|
||||
ctx,
|
||||
// This is a placeholder, it will be set in the startup handler
|
||||
cached: None.into(),
|
||||
parameter_provider,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Sync + Send + ControlStateReadAccess + ControlStateWriteAccess + NodeDelegate + Clone> PgWireServerHandlers
|
||||
for PgSpacetimeDBFactory<T>
|
||||
{
|
||||
fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
|
||||
self.handler.clone()
|
||||
}
|
||||
|
||||
// TODO: fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {}
|
||||
|
||||
fn startup_handler(&self) -> Arc<impl StartupHandler> {
|
||||
self.handler.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_pg<T: ControlStateReadAccess + ControlStateWriteAccess + NodeDelegate + Clone + 'static>(
|
||||
shutdown: Arc<Notify>,
|
||||
ctx: T,
|
||||
tcp: TcpListener,
|
||||
) {
|
||||
let factory = Arc::new(PgSpacetimeDBFactory::new(ctx));
|
||||
|
||||
log::debug!(
|
||||
"PG: Starting SpacetimeDB Protocol listening on {}",
|
||||
tcp.local_addr().unwrap()
|
||||
);
|
||||
loop {
|
||||
tokio::select! {
|
||||
accept_result = tcp.accept() => {
|
||||
match accept_result {
|
||||
Ok((stream, _addr)) => {
|
||||
let factory_ref = factory.clone();
|
||||
tokio::spawn(async move {
|
||||
process_socket(stream, None, factory_ref).await.inspect_err(|err|{
|
||||
log::error!("PG: Error processing socket: {err:?}");
|
||||
})
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("PG: Accept error: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = shutdown.notified() => {
|
||||
log::info!("PG: Shutting down PostgreSQL server.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
+398
-227
@@ -1,13 +1,14 @@
|
||||
use crate::de::DeserializeSeed;
|
||||
use crate::time_duration::TimeDuration;
|
||||
use crate::timestamp::Timestamp;
|
||||
use crate::{i256, u256, AlgebraicValue, WithTypespace};
|
||||
use crate::{i256, u256, AlgebraicType, AlgebraicValue, ProductValue, Serialize, SumValue, ValueWithType};
|
||||
use crate::{ser, ProductType, ProductTypeElement};
|
||||
use core::fmt;
|
||||
use core::fmt::Write as _;
|
||||
use derive_more::{From, Into};
|
||||
use derive_more::{Display, From, Into};
|
||||
use std::borrow::Cow;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// An extension trait for [`Serialize`](ser::Serialize) providing formatting methods.
|
||||
/// An extension trait for [`Serialize`] providing formatting methods.
|
||||
pub trait Satn: ser::Serialize {
|
||||
/// Formats the value using the SATN data format into the formatter `f`.
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
@@ -18,9 +19,12 @@ pub trait Satn: ser::Serialize {
|
||||
/// Formats the value using the postgres SATN(PsqlFormatter { f }, /* PsqlType */) formatter `f`.
|
||||
fn fmt_psql(&self, f: &mut fmt::Formatter, ty: &PsqlType<'_>) -> fmt::Result {
|
||||
Writer::with(f, |f| {
|
||||
self.serialize(PsqlFormatter {
|
||||
fmt: SatnFormatter { f },
|
||||
self.serialize(TypedSerializer {
|
||||
ty,
|
||||
f: &mut SqlFormatter {
|
||||
fmt: SatnFormatter { f },
|
||||
ty,
|
||||
},
|
||||
})
|
||||
})?;
|
||||
Ok(())
|
||||
@@ -229,9 +233,30 @@ struct SatnFormatter<'a, 'f> {
|
||||
f: Writer<'a, 'f>,
|
||||
}
|
||||
|
||||
impl SatnFormatter<'_, '_> {
|
||||
fn ser_variant<T: ser::Serialize + ?Sized>(
|
||||
&mut self,
|
||||
_tag: u8,
|
||||
name: Option<&str>,
|
||||
value: &T,
|
||||
) -> Result<(), SatnError> {
|
||||
write!(self, "(")?;
|
||||
EntryWrapper::<','>::new(self.f.as_mut()).entry(|mut f| {
|
||||
if let Some(name) = name {
|
||||
write!(f, "{name}")?;
|
||||
}
|
||||
write!(f, " = ")?;
|
||||
value.serialize(SatnFormatter { f })?;
|
||||
Ok(())
|
||||
})?;
|
||||
write!(self, ")")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
/// An error occurred during serialization to the SATS data format.
|
||||
#[derive(From, Into)]
|
||||
struct SatnError(fmt::Error);
|
||||
pub struct SatnError(fmt::Error);
|
||||
|
||||
impl ser::Error for SatnError {
|
||||
fn custom<T: fmt::Display>(_msg: T) -> Self {
|
||||
@@ -331,20 +356,11 @@ impl<'a, 'f> ser::Serializer for SatnFormatter<'a, 'f> {
|
||||
|
||||
fn serialize_variant<T: ser::Serialize + ?Sized>(
|
||||
mut self,
|
||||
_tag: u8,
|
||||
tag: u8,
|
||||
name: Option<&str>,
|
||||
value: &T,
|
||||
) -> Result<Self::Ok, Self::Error> {
|
||||
write!(self, "(")?;
|
||||
EntryWrapper::<','>::new(self.f.as_mut()).entry(|mut f| {
|
||||
if let Some(name) = name {
|
||||
write!(f, "{name}")?;
|
||||
}
|
||||
write!(f, " = ")?;
|
||||
value.serialize(SatnFormatter { f })?;
|
||||
Ok(())
|
||||
})?;
|
||||
write!(self, ")")
|
||||
self.ser_variant(tag, name, value)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -427,119 +443,41 @@ impl ser::SerializeNamedProduct for NamedFormatter<'_, '_> {
|
||||
}
|
||||
}
|
||||
|
||||
struct PsqlEntryWrapper<'a, 'f, const SEP: char> {
|
||||
entry: EntryWrapper<'a, 'f, SEP>,
|
||||
/// The index of the element.
|
||||
idx: usize,
|
||||
ty: &'a PsqlType<'a>,
|
||||
/// Which client is used to format the `SQL` output?
|
||||
#[derive(PartialEq, Copy, Clone, Debug)]
|
||||
pub enum PsqlClient {
|
||||
SpacetimeDB,
|
||||
Postgres,
|
||||
}
|
||||
|
||||
/// Provides the data format for named products for `SQL`.
|
||||
struct PsqlNamedFormatter<'a, 'f> {
|
||||
/// The formatter for each element separating elements by a `,`.
|
||||
f: PsqlEntryWrapper<'a, 'f, ','>,
|
||||
/// If is not [Self::is_special] to control if we start with `(`
|
||||
start: bool,
|
||||
/// Remember what format we are using
|
||||
use_fmt: PsqlPrintFmt,
|
||||
pub struct PsqlChars {
|
||||
pub start: char,
|
||||
pub sep: &'static str,
|
||||
pub end: char,
|
||||
pub quote: &'static str,
|
||||
}
|
||||
|
||||
impl<'a, 'f> PsqlNamedFormatter<'a, 'f> {
|
||||
pub fn new(ty: &'a PsqlType<'a>, f: Writer<'a, 'f>) -> Self {
|
||||
Self {
|
||||
start: true,
|
||||
f: PsqlEntryWrapper {
|
||||
entry: EntryWrapper::new(f),
|
||||
idx: 0,
|
||||
ty,
|
||||
impl PsqlClient {
|
||||
pub fn format_chars(&self) -> PsqlChars {
|
||||
match self {
|
||||
PsqlClient::SpacetimeDB => PsqlChars {
|
||||
start: '(',
|
||||
sep: " =",
|
||||
end: ')',
|
||||
quote: "",
|
||||
},
|
||||
PsqlClient::Postgres => PsqlChars {
|
||||
start: '{',
|
||||
sep: ":",
|
||||
end: '}',
|
||||
quote: "\"",
|
||||
},
|
||||
// Will set later
|
||||
use_fmt: PsqlPrintFmt::Satn,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ser::SerializeNamedProduct for PsqlNamedFormatter<'_, '_> {
|
||||
type Ok = ();
|
||||
type Error = SatnError;
|
||||
|
||||
fn serialize_element<T: Satn + ser::Serialize + ?Sized>(
|
||||
&mut self,
|
||||
name: Option<&str>,
|
||||
elem: &T,
|
||||
) -> Result<(), Self::Error> {
|
||||
// For binary data & special types, output in `hex` format and skip the tagging of each value
|
||||
// We need to check for both the enclosing(`self.f.ty`) type and the inner element(`name`) type.
|
||||
self.use_fmt = self.f.ty.use_fmt(name);
|
||||
let res = self.f.entry.entry(|mut f| {
|
||||
let PsqlType { tuple, field, idx } = self.f.ty;
|
||||
if !self.use_fmt.is_special() {
|
||||
if self.start {
|
||||
write!(f, "(")?;
|
||||
self.start = false;
|
||||
}
|
||||
// Format the name or use the index if unnamed.
|
||||
if let Some(name) = name {
|
||||
write!(f, "{name}")?;
|
||||
} else {
|
||||
write!(f, "{idx}")?;
|
||||
}
|
||||
write!(f, " = ")?;
|
||||
}
|
||||
//Is a nested product type?
|
||||
let (tuple, field, idx) = if let Some(product) = field.algebraic_type.as_product() {
|
||||
(product, &product.elements[self.f.idx], self.f.idx)
|
||||
} else {
|
||||
(*tuple, *field, *idx)
|
||||
};
|
||||
|
||||
elem.serialize(PsqlFormatter {
|
||||
fmt: SatnFormatter { f },
|
||||
ty: &PsqlType { tuple, field, idx },
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
});
|
||||
|
||||
// Advance to the next field.
|
||||
if !self.use_fmt.is_special() {
|
||||
self.f.idx += 1;
|
||||
}
|
||||
|
||||
res?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn end(mut self) -> Result<Self::Ok, Self::Error> {
|
||||
if !self.use_fmt.is_special() {
|
||||
write!(self.f.entry.fmt, ")")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides the data format for unnamed products for `SQL`.
|
||||
struct PsqlSeqFormatter<'a, 'f> {
|
||||
/// Delegates to the named format.
|
||||
inner: PsqlNamedFormatter<'a, 'f>,
|
||||
}
|
||||
|
||||
impl ser::SerializeSeqProduct for PsqlSeqFormatter<'_, '_> {
|
||||
type Ok = ();
|
||||
type Error = SatnError;
|
||||
|
||||
fn serialize_element<T: ser::Serialize + ?Sized>(&mut self, elem: &T) -> Result<(), Self::Error> {
|
||||
ser::SerializeNamedProduct::serialize_element(&mut self.inner, None, elem)
|
||||
}
|
||||
|
||||
fn end(self) -> Result<Self::Ok, Self::Error> {
|
||||
ser::SerializeNamedProduct::end(self.inner)
|
||||
}
|
||||
}
|
||||
|
||||
/// How format of the `SQL` output?
|
||||
#[derive(PartialEq)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Display)]
|
||||
pub enum PsqlPrintFmt {
|
||||
/// Print as `hex` format
|
||||
Hex,
|
||||
@@ -552,14 +490,46 @@ pub enum PsqlPrintFmt {
|
||||
}
|
||||
|
||||
impl PsqlPrintFmt {
|
||||
fn is_special(&self) -> bool {
|
||||
pub fn is_special(&self) -> bool {
|
||||
self != &PsqlPrintFmt::Satn
|
||||
}
|
||||
/// Returns if the type is a special type
|
||||
///
|
||||
/// Is required to check both the enclosing type and the inner element type
|
||||
pub fn use_fmt(tuple: &ProductType, field: &ProductTypeElement, name: Option<&str>) -> PsqlPrintFmt {
|
||||
if tuple.is_identity()
|
||||
|| tuple.is_connection_id()
|
||||
|| field.algebraic_type.is_identity()
|
||||
|| field.algebraic_type.is_connection_id()
|
||||
|| name.map(ProductType::is_identity_tag).unwrap_or_default()
|
||||
|| name.map(ProductType::is_connection_id_tag).unwrap_or_default()
|
||||
{
|
||||
return PsqlPrintFmt::Hex;
|
||||
};
|
||||
|
||||
if tuple.is_timestamp()
|
||||
|| field.algebraic_type.is_timestamp()
|
||||
|| name.map(ProductType::is_timestamp_tag).unwrap_or_default()
|
||||
{
|
||||
return PsqlPrintFmt::Timestamp;
|
||||
};
|
||||
|
||||
if tuple.is_time_duration()
|
||||
|| field.algebraic_type.is_time_duration()
|
||||
|| name.map(ProductType::is_time_duration_tag).unwrap_or_default()
|
||||
{
|
||||
return PsqlPrintFmt::Duration;
|
||||
};
|
||||
|
||||
PsqlPrintFmt::Satn
|
||||
}
|
||||
}
|
||||
|
||||
/// A wrapper that remember the `header` of the tuple/struct and the current field
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PsqlType<'a> {
|
||||
/// The client used to format the output
|
||||
pub client: PsqlClient,
|
||||
/// The header of the tuple/struct
|
||||
pub tuple: &'a ProductType,
|
||||
/// The current field
|
||||
@@ -572,168 +542,369 @@ impl PsqlType<'_> {
|
||||
/// Returns if the type is a special type
|
||||
///
|
||||
/// Is required to check both the enclosing type and the inner element type
|
||||
fn use_fmt(&self, name: Option<&str>) -> PsqlPrintFmt {
|
||||
if self.tuple.is_identity()
|
||||
|| self.tuple.is_connection_id()
|
||||
|| self.field.algebraic_type.is_identity()
|
||||
|| self.field.algebraic_type.is_connection_id()
|
||||
|| name.map(ProductType::is_identity_tag).unwrap_or_default()
|
||||
|| name.map(ProductType::is_connection_id_tag).unwrap_or_default()
|
||||
{
|
||||
return PsqlPrintFmt::Hex;
|
||||
};
|
||||
|
||||
if self.tuple.is_timestamp()
|
||||
|| self.field.algebraic_type.is_timestamp()
|
||||
|| name.map(ProductType::is_timestamp_tag).unwrap_or_default()
|
||||
{
|
||||
return PsqlPrintFmt::Timestamp;
|
||||
};
|
||||
|
||||
if self.tuple.is_time_duration()
|
||||
|| self.field.algebraic_type.is_time_duration()
|
||||
|| name.map(ProductType::is_time_duration_tag).unwrap_or_default()
|
||||
{
|
||||
return PsqlPrintFmt::Duration;
|
||||
};
|
||||
|
||||
PsqlPrintFmt::Satn
|
||||
pub fn use_fmt(&self) -> PsqlPrintFmt {
|
||||
PsqlPrintFmt::use_fmt(self.tuple, self.field, None)
|
||||
}
|
||||
}
|
||||
|
||||
/// An implementation of [`Serializer`](ser::Serializer) for `SQL` output.
|
||||
struct PsqlFormatter<'a, 'f> {
|
||||
pub struct SqlFormatter<'a, 'f> {
|
||||
fmt: SatnFormatter<'a, 'f>,
|
||||
ty: &'a PsqlType<'a>,
|
||||
}
|
||||
|
||||
impl<'a, 'f> ser::Serializer for PsqlFormatter<'a, 'f> {
|
||||
/// A trait for writing values, after the special types has been determined.
|
||||
///
|
||||
/// This is used to write values that could have different representations depending on the output format,
|
||||
/// as defined by [`PsqlClient`] and [`PsqlPrintFmt`].
|
||||
pub trait TypedWriter {
|
||||
type Error: ser::Error;
|
||||
|
||||
/// Writes a value using [`ser::Serializer`]
|
||||
fn write<W: fmt::Display>(&mut self, value: W) -> Result<(), Self::Error>;
|
||||
|
||||
// Values that need special handling:
|
||||
|
||||
fn write_bool(&mut self, value: bool) -> Result<(), Self::Error>;
|
||||
fn write_string(&mut self, value: &str) -> Result<(), Self::Error>;
|
||||
fn write_bytes(&mut self, value: &[u8]) -> Result<(), Self::Error>;
|
||||
fn write_hex(&mut self, value: &[u8]) -> Result<(), Self::Error>;
|
||||
fn write_timestamp(&mut self, value: Timestamp) -> Result<(), Self::Error>;
|
||||
fn write_duration(&mut self, value: TimeDuration) -> Result<(), Self::Error>;
|
||||
/// Writes a value as an alternative record format, e.g., for use `JSON` inside `SQL`.
|
||||
fn write_alt_record(
|
||||
&mut self,
|
||||
_ty: &PsqlType,
|
||||
_value: &ValueWithType<'_, ProductValue>,
|
||||
) -> Result<bool, Self::Error> {
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
fn write_record(
|
||||
&mut self,
|
||||
fields: Vec<(Cow<str>, PsqlType, ValueWithType<AlgebraicValue>)>,
|
||||
) -> Result<(), Self::Error>;
|
||||
|
||||
fn write_variant(
|
||||
&mut self,
|
||||
tag: u8,
|
||||
ty: PsqlType,
|
||||
name: Option<&str>,
|
||||
value: ValueWithType<AlgebraicValue>,
|
||||
) -> Result<(), Self::Error>;
|
||||
}
|
||||
|
||||
/// A formatter for arrays that uses the `TypedWriter` trait to write elements.
|
||||
pub struct TypedArrayFormatter<'a, 'f, F> {
|
||||
ty: &'a PsqlType<'a>,
|
||||
f: &'f mut F,
|
||||
}
|
||||
|
||||
impl<F: TypedWriter> ser::SerializeArray for TypedArrayFormatter<'_, '_, F> {
|
||||
type Ok = ();
|
||||
type Error = SatnError;
|
||||
type SerializeArray = ArrayFormatter<'a, 'f>;
|
||||
type SerializeSeqProduct = PsqlSeqFormatter<'a, 'f>;
|
||||
type SerializeNamedProduct = PsqlNamedFormatter<'a, 'f>;
|
||||
type Error = F::Error;
|
||||
|
||||
fn serialize_element<T: ser::Serialize + ?Sized>(&mut self, elem: &T) -> Result<(), Self::Error> {
|
||||
elem.serialize(TypedSerializer { ty: self.ty, f: self.f })?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn end(self) -> Result<Self::Ok, Self::Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A formatter for sequences that uses the `TypedWriter` trait to write elements.
|
||||
pub struct TypedSeqFormatter<'a, 'f, F> {
|
||||
ty: &'a PsqlType<'a>,
|
||||
f: &'f mut F,
|
||||
}
|
||||
|
||||
impl<F: TypedWriter> ser::SerializeSeqProduct for TypedSeqFormatter<'_, '_, F> {
|
||||
type Ok = ();
|
||||
type Error = F::Error;
|
||||
|
||||
fn serialize_element<T: ser::Serialize + ?Sized>(&mut self, elem: &T) -> Result<(), Self::Error> {
|
||||
elem.serialize(TypedSerializer { ty: self.ty, f: self.f })?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn end(self) -> Result<Self::Ok, Self::Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A formatter for named products that uses the `TypedWriter` trait to write elements.
|
||||
pub struct TypedNamedProductFormatter<F> {
|
||||
f: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: TypedWriter> ser::SerializeNamedProduct for TypedNamedProductFormatter<F> {
|
||||
type Ok = ();
|
||||
type Error = F::Error;
|
||||
|
||||
fn serialize_element<T: ser::Serialize + ?Sized>(
|
||||
&mut self,
|
||||
_name: Option<&str>,
|
||||
_elem: &T,
|
||||
) -> Result<(), Self::Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn end(self) -> Result<Self::Ok, Self::Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A serializer that uses the `TypedWriter` trait to serialize values
|
||||
pub struct TypedSerializer<'a, 'f, F> {
|
||||
pub ty: &'a PsqlType<'a>,
|
||||
pub f: &'f mut F,
|
||||
}
|
||||
|
||||
impl<'a, 'f, F: TypedWriter> ser::Serializer for TypedSerializer<'a, 'f, F> {
|
||||
type Ok = ();
|
||||
type Error = F::Error;
|
||||
type SerializeArray = TypedArrayFormatter<'a, 'f, F>;
|
||||
type SerializeSeqProduct = TypedSeqFormatter<'a, 'f, F>;
|
||||
type SerializeNamedProduct = TypedNamedProductFormatter<F>;
|
||||
|
||||
fn serialize_bool(self, v: bool) -> Result<Self::Ok, Self::Error> {
|
||||
self.fmt.serialize_bool(v)
|
||||
self.f.write_bool(v)
|
||||
}
|
||||
|
||||
fn serialize_u8(self, v: u8) -> Result<Self::Ok, Self::Error> {
|
||||
self.fmt.serialize_u8(v)
|
||||
self.f.write(v)
|
||||
}
|
||||
|
||||
fn serialize_u16(self, v: u16) -> Result<Self::Ok, Self::Error> {
|
||||
self.fmt.serialize_u16(v)
|
||||
self.f.write(v)
|
||||
}
|
||||
|
||||
fn serialize_u32(self, v: u32) -> Result<Self::Ok, Self::Error> {
|
||||
self.fmt.serialize_u32(v)
|
||||
self.f.write(v)
|
||||
}
|
||||
|
||||
fn serialize_u64(self, v: u64) -> Result<Self::Ok, Self::Error> {
|
||||
self.fmt.serialize_u64(v)
|
||||
self.f.write(v)
|
||||
}
|
||||
|
||||
fn serialize_u128(self, v: u128) -> Result<Self::Ok, Self::Error> {
|
||||
match self.ty.use_fmt(None) {
|
||||
PsqlPrintFmt::Hex => self.serialize_bytes(&v.to_be_bytes()),
|
||||
_ => self.fmt.serialize_u128(v),
|
||||
match self.ty.use_fmt() {
|
||||
PsqlPrintFmt::Hex => self.f.write_hex(&v.to_be_bytes()),
|
||||
_ => self.f.write(v),
|
||||
}
|
||||
}
|
||||
|
||||
fn serialize_u256(self, v: u256) -> Result<Self::Ok, Self::Error> {
|
||||
match self.ty.use_fmt(None) {
|
||||
PsqlPrintFmt::Hex => self.serialize_bytes(&v.to_be_bytes()),
|
||||
_ => self.fmt.serialize_u256(v),
|
||||
match self.ty.use_fmt() {
|
||||
PsqlPrintFmt::Hex => self.f.write_hex(&v.to_be_bytes()),
|
||||
_ => self.f.write(v),
|
||||
}
|
||||
}
|
||||
|
||||
fn serialize_i8(self, v: i8) -> Result<Self::Ok, Self::Error> {
|
||||
self.fmt.serialize_i8(v)
|
||||
self.f.write(v)
|
||||
}
|
||||
|
||||
fn serialize_i16(self, v: i16) -> Result<Self::Ok, Self::Error> {
|
||||
self.fmt.serialize_i16(v)
|
||||
self.f.write(v)
|
||||
}
|
||||
|
||||
fn serialize_i32(self, v: i32) -> Result<Self::Ok, Self::Error> {
|
||||
self.fmt.serialize_i32(v)
|
||||
self.f.write(v)
|
||||
}
|
||||
fn serialize_i64(mut self, v: i64) -> Result<Self::Ok, Self::Error> {
|
||||
match self.ty.use_fmt(None) {
|
||||
PsqlPrintFmt::Duration => {
|
||||
write!(self.fmt, "{}", TimeDuration::from_micros(v))?;
|
||||
Ok(())
|
||||
}
|
||||
PsqlPrintFmt::Timestamp => {
|
||||
write!(self.fmt, "{}", Timestamp::from_micros_since_unix_epoch(v))?;
|
||||
Ok(())
|
||||
}
|
||||
_ => self.fmt.serialize_i64(v),
|
||||
|
||||
fn serialize_i64(self, v: i64) -> Result<Self::Ok, Self::Error> {
|
||||
match self.ty.use_fmt() {
|
||||
PsqlPrintFmt::Duration => self.f.write_duration(TimeDuration::from_micros(v)),
|
||||
PsqlPrintFmt::Timestamp => self.f.write_timestamp(Timestamp::from_micros_since_unix_epoch(v)),
|
||||
_ => self.f.write(v),
|
||||
}
|
||||
}
|
||||
|
||||
fn serialize_i128(self, v: i128) -> Result<Self::Ok, Self::Error> {
|
||||
self.fmt.serialize_i128(v)
|
||||
self.f.write(v)
|
||||
}
|
||||
|
||||
fn serialize_i256(self, v: i256) -> Result<Self::Ok, Self::Error> {
|
||||
self.fmt.serialize_i256(v)
|
||||
self.f.write(v)
|
||||
}
|
||||
|
||||
fn serialize_f32(self, v: f32) -> Result<Self::Ok, Self::Error> {
|
||||
self.fmt.serialize_f32(v)
|
||||
self.f.write(v)
|
||||
}
|
||||
|
||||
fn serialize_f64(self, v: f64) -> Result<Self::Ok, Self::Error> {
|
||||
self.fmt.serialize_f64(v)
|
||||
self.f.write(v)
|
||||
}
|
||||
|
||||
fn serialize_str(self, v: &str) -> Result<Self::Ok, Self::Error> {
|
||||
self.fmt.serialize_str(v)
|
||||
self.f.write_string(v)
|
||||
}
|
||||
|
||||
fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
|
||||
self.fmt.serialize_bytes(v)
|
||||
if self.ty.use_fmt() == PsqlPrintFmt::Satn {
|
||||
self.f.write_hex(v)
|
||||
} else {
|
||||
self.f.write_bytes(v)
|
||||
}
|
||||
}
|
||||
|
||||
fn serialize_array(self, len: usize) -> Result<Self::SerializeArray, Self::Error> {
|
||||
self.fmt.serialize_array(len)
|
||||
fn serialize_array(self, _len: usize) -> Result<Self::SerializeArray, Self::Error> {
|
||||
Ok(TypedArrayFormatter { ty: self.ty, f: self.f })
|
||||
}
|
||||
|
||||
fn serialize_seq_product(self, len: usize) -> Result<Self::SerializeSeqProduct, Self::Error> {
|
||||
Ok(PsqlSeqFormatter {
|
||||
inner: self.serialize_named_product(len)?,
|
||||
})
|
||||
fn serialize_seq_product(self, _len: usize) -> Result<Self::SerializeSeqProduct, Self::Error> {
|
||||
Ok(TypedSeqFormatter { ty: self.ty, f: self.f })
|
||||
}
|
||||
|
||||
fn serialize_named_product(self, _len: usize) -> Result<Self::SerializeNamedProduct, Self::Error> {
|
||||
Ok(PsqlNamedFormatter::new(self.ty, self.fmt.f))
|
||||
unreachable!("This should never be called, use `serialize_named_product_raw` instead.");
|
||||
}
|
||||
|
||||
fn serialize_variant<T: ser::Serialize + ?Sized>(
|
||||
fn serialize_named_product_raw(self, value: &ValueWithType<'_, ProductValue>) -> Result<Self::Ok, Self::Error> {
|
||||
let val = &value.val.elements;
|
||||
assert_eq!(val.len(), value.ty().elements.len());
|
||||
// If the value is a special type, we can write it directly
|
||||
if self.ty.use_fmt().is_special() {
|
||||
// Is a nested product type?
|
||||
// We need to check for both the enclosing(`self.ty`) type and the inner element type.
|
||||
let (tuple, field) = if let Some(product) = self.ty.field.algebraic_type.as_product() {
|
||||
(product, &product.elements[0])
|
||||
} else {
|
||||
(self.ty.tuple, self.ty.field)
|
||||
};
|
||||
return value.val.serialize(TypedSerializer {
|
||||
ty: &PsqlType {
|
||||
client: self.ty.client,
|
||||
tuple,
|
||||
field,
|
||||
idx: self.ty.idx,
|
||||
},
|
||||
f: self.f,
|
||||
});
|
||||
}
|
||||
// Allow to switch to an alternative record format, for example to write a `JSON` record.
|
||||
if self.f.write_alt_record(self.ty, value)? {
|
||||
return Ok(());
|
||||
}
|
||||
let mut record = Vec::with_capacity(val.len());
|
||||
|
||||
for (idx, (val, field)) in val.iter().zip(&*value.ty().elements).enumerate() {
|
||||
let ty = PsqlType {
|
||||
client: self.ty.client,
|
||||
tuple: value.ty(),
|
||||
field,
|
||||
idx,
|
||||
};
|
||||
record.push((
|
||||
field
|
||||
.name()
|
||||
.map(Cow::from)
|
||||
.unwrap_or_else(|| Cow::from(format!("col_{idx}"))),
|
||||
ty,
|
||||
value.with(&field.algebraic_type, val),
|
||||
));
|
||||
}
|
||||
self.f.write_record(record)
|
||||
}
|
||||
|
||||
fn serialize_variant_raw(self, sum: &ValueWithType<'_, SumValue>) -> Result<Self::Ok, Self::Error> {
|
||||
let sv = sum.value();
|
||||
let (tag, val) = (sv.tag, &*sv.value);
|
||||
let var_ty = &sum.ty().variants[tag as usize]; // Extract the variant type by tag.
|
||||
let product = ProductType::from([AlgebraicType::sum(sum.ty().clone())]);
|
||||
let ty = PsqlType {
|
||||
client: self.ty.client,
|
||||
tuple: &product,
|
||||
field: &product.elements[0],
|
||||
idx: 0,
|
||||
};
|
||||
self.f
|
||||
.write_variant(tag, ty, var_ty.name(), sum.with(&var_ty.algebraic_type, val))
|
||||
}
|
||||
|
||||
fn serialize_variant<T: Serialize + ?Sized>(
|
||||
self,
|
||||
tag: u8,
|
||||
name: Option<&str>,
|
||||
value: &T,
|
||||
_tag: u8,
|
||||
_name: Option<&str>,
|
||||
_value: &T,
|
||||
) -> Result<Self::Ok, Self::Error> {
|
||||
self.fmt.serialize_variant(tag, name, value)
|
||||
}
|
||||
|
||||
unsafe fn serialize_bsatn<Ty>(self, ty: &Ty, bsatn: &[u8]) -> Result<Self::Ok, Self::Error>
|
||||
where
|
||||
for<'b, 'de> WithTypespace<'b, Ty>: DeserializeSeed<'de, Output: Into<AlgebraicValue>>,
|
||||
{
|
||||
// SAFETY: Forward caller requirements of this method to that we are calling.
|
||||
unsafe { self.fmt.serialize_bsatn(ty, bsatn) }
|
||||
}
|
||||
|
||||
unsafe fn serialize_bsatn_in_chunks<'c, Ty, I: Clone + Iterator<Item = &'c [u8]>>(
|
||||
self,
|
||||
ty: &Ty,
|
||||
total_bsatn_len: usize,
|
||||
bsatn: I,
|
||||
) -> Result<Self::Ok, Self::Error>
|
||||
where
|
||||
for<'b, 'de> WithTypespace<'b, Ty>: DeserializeSeed<'de, Output: Into<AlgebraicValue>>,
|
||||
{
|
||||
// SAFETY: Forward caller requirements of this method to that we are calling.
|
||||
unsafe { self.fmt.serialize_bsatn_in_chunks(ty, total_bsatn_len, bsatn) }
|
||||
}
|
||||
|
||||
unsafe fn serialize_str_in_chunks<'c, I: Clone + Iterator<Item = &'c [u8]>>(
|
||||
self,
|
||||
total_len: usize,
|
||||
string: I,
|
||||
) -> Result<Self::Ok, Self::Error> {
|
||||
// SAFETY: Forward caller requirements of this method to that we are calling.
|
||||
unsafe { self.fmt.serialize_str_in_chunks(total_len, string) }
|
||||
unreachable!("Use `serialize_variant_raw` instead.");
|
||||
}
|
||||
}
|
||||
|
||||
impl TypedWriter for SqlFormatter<'_, '_> {
|
||||
type Error = SatnError;
|
||||
|
||||
fn write<W: fmt::Display>(&mut self, value: W) -> Result<(), Self::Error> {
|
||||
write!(self.fmt, "{value}")
|
||||
}
|
||||
|
||||
fn write_bool(&mut self, value: bool) -> Result<(), Self::Error> {
|
||||
write!(self.fmt, "{value}")
|
||||
}
|
||||
|
||||
fn write_string(&mut self, value: &str) -> Result<(), Self::Error> {
|
||||
write!(self.fmt, "\"{value}\"")
|
||||
}
|
||||
|
||||
fn write_bytes(&mut self, value: &[u8]) -> Result<(), Self::Error> {
|
||||
self.write_hex(value)
|
||||
}
|
||||
|
||||
fn write_hex(&mut self, value: &[u8]) -> Result<(), Self::Error> {
|
||||
match self.ty.client {
|
||||
PsqlClient::SpacetimeDB => write!(self.fmt, "0x{}", hex::encode(value)),
|
||||
PsqlClient::Postgres => write!(self.fmt, "\"0x{}\"", hex::encode(value)),
|
||||
}
|
||||
}
|
||||
|
||||
fn write_timestamp(&mut self, value: Timestamp) -> Result<(), Self::Error> {
|
||||
match self.ty.client {
|
||||
PsqlClient::SpacetimeDB => write!(self.fmt, "{}", value.to_rfc3339().unwrap()),
|
||||
PsqlClient::Postgres => write!(self.fmt, "\"{}\"", value.to_rfc3339().unwrap()),
|
||||
}
|
||||
}
|
||||
|
||||
fn write_duration(&mut self, value: TimeDuration) -> Result<(), Self::Error> {
|
||||
match self.ty.client {
|
||||
PsqlClient::SpacetimeDB => write!(self.fmt, "{value}"),
|
||||
PsqlClient::Postgres => write!(self.fmt, "\"{}\"", value.to_iso8601()),
|
||||
}
|
||||
}
|
||||
|
||||
fn write_record(
|
||||
&mut self,
|
||||
fields: Vec<(Cow<str>, PsqlType<'_>, ValueWithType<AlgebraicValue>)>,
|
||||
) -> Result<(), Self::Error> {
|
||||
let PsqlChars { start, sep, end, quote } = self.ty.client.format_chars();
|
||||
write!(self.fmt, "{start}")?;
|
||||
for (idx, (name, ty, value)) in fields.into_iter().enumerate() {
|
||||
if idx > 0 {
|
||||
write!(self.fmt, ", ")?;
|
||||
}
|
||||
write!(self.fmt, "{quote}{name}{quote}{sep} ")?;
|
||||
|
||||
// Serialize the value
|
||||
value.serialize(TypedSerializer { ty: &ty, f: self })?;
|
||||
}
|
||||
write!(self.fmt, "{end}")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_variant(
|
||||
&mut self,
|
||||
tag: u8,
|
||||
ty: PsqlType,
|
||||
name: Option<&str>,
|
||||
value: ValueWithType<AlgebraicValue>,
|
||||
) -> Result<(), Self::Error> {
|
||||
self.write_record(vec![(
|
||||
name.map(Cow::from).unwrap_or_else(|| Cow::from(format!("col_{tag}"))),
|
||||
ty,
|
||||
value,
|
||||
)])
|
||||
}
|
||||
}
|
||||
|
||||
+26
-1
@@ -6,7 +6,7 @@ mod impls;
|
||||
pub mod serde;
|
||||
|
||||
use crate::de::DeserializeSeed;
|
||||
use crate::{algebraic_value::ser::ValueSerializer, bsatn, buffer::BufWriter};
|
||||
use crate::{algebraic_value::ser::ValueSerializer, bsatn, buffer::BufWriter, ProductValue, SumValue, ValueWithType};
|
||||
use crate::{AlgebraicValue, WithTypespace};
|
||||
use core::marker::PhantomData;
|
||||
use core::{convert::Infallible, fmt};
|
||||
@@ -117,6 +117,31 @@ pub trait Serializer: Sized {
|
||||
/// The argument is the number of fields in the product.
|
||||
fn serialize_named_product(self, len: usize) -> Result<Self::SerializeNamedProduct, Self::Error>;
|
||||
|
||||
/// Serialize a product with named fields.
|
||||
///
|
||||
/// Allow to override the default serialization for where we need to switch the output format,
|
||||
/// see [`crate::satn::TypedWriter`].
|
||||
fn serialize_named_product_raw(self, value: &ValueWithType<'_, ProductValue>) -> Result<Self::Ok, Self::Error> {
|
||||
let val = &value.val.elements;
|
||||
assert_eq!(val.len(), value.ty().elements.len());
|
||||
let mut prod = self.serialize_named_product(val.len())?;
|
||||
for (val, el_ty) in val.iter().zip(&*value.ty().elements) {
|
||||
prod.serialize_element(el_ty.name(), &value.with(&el_ty.algebraic_type, val))?
|
||||
}
|
||||
prod.end()
|
||||
}
|
||||
|
||||
/// Serialize a sum value
|
||||
///
|
||||
/// Allow to override the default serialization for where we need to switch the output format,
|
||||
/// see [`crate::satn::TypedWriter`].
|
||||
fn serialize_variant_raw(self, sum: &ValueWithType<'_, SumValue>) -> Result<Self::Ok, Self::Error> {
|
||||
let sv = sum.value();
|
||||
let (tag, val) = (sv.tag, &*sv.value);
|
||||
let var_ty = &sum.ty().variants[tag as usize]; // Extract the variant type by tag.
|
||||
self.serialize_variant(tag, var_ty.name(), &sum.with(&var_ty.algebraic_type, val))
|
||||
}
|
||||
|
||||
/// Serialize a sum value provided the chosen `tag`, `name`, and `value`.
|
||||
fn serialize_variant<T: Serialize + ?Sized>(
|
||||
self,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::{Serialize, SerializeArray, SerializeNamedProduct, SerializeSeqProduct, Serializer};
|
||||
use super::{Serialize, SerializeArray, SerializeSeqProduct, Serializer};
|
||||
use crate::{i256, u256};
|
||||
use crate::{AlgebraicType, AlgebraicValue, ArrayValue, ProductValue, SumValue, ValueWithType, F32, F64};
|
||||
use core::ops::Bound;
|
||||
@@ -190,19 +190,10 @@ impl_serialize!(
|
||||
}
|
||||
);
|
||||
impl_serialize!([] ValueWithType<'_, SumValue>, (self, ser) => {
|
||||
let sv = self.value();
|
||||
let (tag, val) = (sv.tag, &*sv.value);
|
||||
let var_ty = &self.ty().variants[tag as usize]; // Extract the variant type by tag.
|
||||
ser.serialize_variant(tag, var_ty.name(), &self.with(&var_ty.algebraic_type, val))
|
||||
ser.serialize_variant_raw(self)
|
||||
});
|
||||
impl_serialize!([] ValueWithType<'_, ProductValue>, (self, ser) => {
|
||||
let val = &self.value().elements;
|
||||
assert_eq!(val.len(), self.ty().elements.len());
|
||||
let mut prod = ser.serialize_named_product(val.len())?;
|
||||
for (val, el_ty) in val.iter().zip(&*self.ty().elements) {
|
||||
prod.serialize_element(el_ty.name(), &self.with(&el_ty.algebraic_type, val))?
|
||||
}
|
||||
prod.end()
|
||||
ser.serialize_named_product_raw(self)
|
||||
});
|
||||
impl_serialize!([] ValueWithType<'_, ArrayValue>, (self, ser) => {
|
||||
let mut ty = &*self.ty().elem_ty;
|
||||
|
||||
@@ -82,6 +82,22 @@ impl TimeDuration {
|
||||
pub fn checked_sub(self, other: Self) -> Option<Self> {
|
||||
self.to_micros().checked_sub(other.to_micros()).map(Self::from_micros)
|
||||
}
|
||||
|
||||
/// Generate an `iso8601` format string.
|
||||
///
|
||||
/// This is the better supported format for use for the `pg wire protocol`.
|
||||
///
|
||||
/// Example:
|
||||
/// ```rust
|
||||
/// use std::time::Duration;
|
||||
/// use spacetimedb_sats::time_duration::TimeDuration;
|
||||
/// assert_eq!( TimeDuration::from_micros(0).to_iso8601().as_str(), "P0D");
|
||||
/// assert_eq!( TimeDuration::from_micros(-1_000_000).to_iso8601().as_str(), "-PT1S");
|
||||
/// assert_eq!( TimeDuration::from_duration(Duration::from_secs(60 * 24)).to_iso8601().as_str(), "PT1440S");
|
||||
/// ```
|
||||
pub fn to_iso8601(self) -> String {
|
||||
chrono::Duration::microseconds(self.to_micros()).to_string()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Duration> for TimeDuration {
|
||||
|
||||
@@ -171,13 +171,17 @@ impl Timestamp {
|
||||
pub fn checked_sub_duration(&self, duration: Duration) -> Option<Self> {
|
||||
self.checked_sub(TimeDuration::from_duration(duration))
|
||||
}
|
||||
/// Returns an RFC 3339 and ISO 8601 date and time string such as `1996-12-19T16:39:57-08:00`.
|
||||
pub fn to_rfc3339(&self) -> anyhow::Result<String> {
|
||||
|
||||
pub fn to_chrono_date_time(&self) -> anyhow::Result<DateTime<chrono::Utc>> {
|
||||
DateTime::from_timestamp_micros(self.to_micros_since_unix_epoch())
|
||||
.map(|t| t.to_rfc3339())
|
||||
.ok_or_else(|| anyhow::anyhow!("Timestamp with i64 microseconds since Unix epoch overflows DateTime"))
|
||||
.with_context(|| self.to_micros_since_unix_epoch())
|
||||
}
|
||||
|
||||
/// Returns an RFC 3339 and ISO 8601 date and time string such as `1996-12-19T16:39:57-08:00`.
|
||||
pub fn to_rfc3339(&self) -> anyhow::Result<String> {
|
||||
Ok(self.to_chrono_date_time()?.to_rfc3339())
|
||||
}
|
||||
}
|
||||
|
||||
impl Add<TimeDuration> for Timestamp {
|
||||
|
||||
@@ -27,6 +27,7 @@ spacetimedb-core.workspace = true
|
||||
spacetimedb-datastore.workspace = true
|
||||
spacetimedb-lib.workspace = true
|
||||
spacetimedb-paths.workspace = true
|
||||
spacetimedb-pg.workspace = true
|
||||
spacetimedb-table.workspace = true
|
||||
spacetimedb-schema.workspace = true
|
||||
|
||||
|
||||
@@ -186,6 +186,7 @@ impl spacetimedb_client_api::ControlStateReadAccess for StandaloneEnv {
|
||||
id: 0,
|
||||
unschedulable: false,
|
||||
advertise_addr: Some("node:80".to_owned()),
|
||||
pg_addr: Some("node:5432".to_owned()),
|
||||
}));
|
||||
}
|
||||
Ok(None)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use spacetimedb_pg::pg_server;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{StandaloneEnv, StandaloneOptions};
|
||||
@@ -176,12 +177,27 @@ pub async fn exec(args: &ArgMatches, db_cores: JobCores) -> anyhow::Result<()> {
|
||||
db_routes.root_post = db_routes.root_post.layer(DefaultBodyLimit::disable());
|
||||
db_routes.db_put = db_routes.db_put.layer(DefaultBodyLimit::disable());
|
||||
let extra = axum::Router::new().nest("/health", spacetimedb_client_api::routes::health::router());
|
||||
let service = router(&ctx, db_routes, extra).with_state(ctx);
|
||||
let service = router(&ctx, db_routes, extra).with_state(ctx.clone());
|
||||
|
||||
let tcp = TcpListener::bind(listen_addr).await?;
|
||||
socket2::SockRef::from(&tcp).set_nodelay(true)?;
|
||||
log::debug!("Starting SpacetimeDB listening on {}", tcp.local_addr().unwrap());
|
||||
axum::serve(tcp, service).await?;
|
||||
log::debug!("Starting SpacetimeDB listening on {}", tcp.local_addr()?);
|
||||
let pg_server_addr = format!("{}:5432", listen_addr.split(':').next().unwrap());
|
||||
let tcp_pg = TcpListener::bind(pg_server_addr).await?;
|
||||
|
||||
let notify = Arc::new(tokio::sync::Notify::new());
|
||||
let shutdown_notify = notify.clone();
|
||||
tokio::select! {
|
||||
_ = pg_server::start_pg(notify.clone(), ctx, tcp_pg) => {},
|
||||
_ = axum::serve(tcp, service).with_graceful_shutdown(async move {
|
||||
shutdown_notify.notified().await;
|
||||
}) => {},
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
println!("Shutting down servers...");
|
||||
notify.notify_waiters(); // Notify all tasks
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -25,6 +25,8 @@ services:
|
||||
- /stdb
|
||||
ports:
|
||||
- "3000:3000"
|
||||
# Postgres
|
||||
- "5432:5432"
|
||||
# Tracy
|
||||
- "8086:8086"
|
||||
entrypoint: cargo watch -i flamegraphs -i log.conf --why -C crates/standalone -x 'run start --data-dir=/stdb/data --jwt-pub-key-path=/etc/spacetimedb/id_ecdsa.pub --jwt-priv-key-path=/etc/spacetimedb/id_ecdsa'
|
||||
|
||||
@@ -0,0 +1,293 @@
|
||||
from .. import Smoketest
|
||||
import subprocess
|
||||
import os
|
||||
import tomllib
|
||||
import psycopg2
|
||||
|
||||
|
||||
def psql(identity: str, sql: str, extra=None) -> str:
|
||||
"""Call `psql` and execute the given SQL statement."""
|
||||
if extra is None:
|
||||
extra = dict()
|
||||
result = subprocess.run(
|
||||
["psql", "-h", "127.0.0.1", "-p", "5432", "-U", "postgres", "-d", "quickstart", "--quiet", "-c", sql],
|
||||
encoding="utf8",
|
||||
env={**os.environ, **extra, "PGPASSWORD": identity},
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
)
|
||||
|
||||
if result.stderr:
|
||||
raise Exception(result.stderr.strip())
|
||||
return result.stdout.strip()
|
||||
|
||||
|
||||
def connect_db(identity: str):
|
||||
"""Connect to the database using `psycopg2`."""
|
||||
conn = psycopg2.connect(host="127.0.0.1", port=5432, user="postgres", password=identity, dbname="quickstart")
|
||||
conn.set_session(autocommit=True) # Disable automic transaction
|
||||
return conn
|
||||
|
||||
|
||||
class SqlFormat(Smoketest):
|
||||
AUTOPUBLISH = False
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::sats::{i256, u256};
|
||||
use spacetimedb::{ConnectionId, Identity, ReducerContext, SpacetimeType, Table, Timestamp, TimeDuration};
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[spacetimedb::table(name = t_ints, public)]
|
||||
pub struct TInts {
|
||||
i8: i8,
|
||||
i16: i16,
|
||||
i32: i32,
|
||||
i64: i64,
|
||||
i128: i128,
|
||||
i256: i256,
|
||||
}
|
||||
|
||||
#[spacetimedb::table(name = t_ints_tuple, public)]
|
||||
pub struct TIntsTuple {
|
||||
tuple: TInts,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[spacetimedb::table(name = t_uints, public)]
|
||||
pub struct TUints {
|
||||
u8: u8,
|
||||
u16: u16,
|
||||
u32: u32,
|
||||
u64: u64,
|
||||
u128: u128,
|
||||
u256: u256,
|
||||
}
|
||||
|
||||
#[spacetimedb::table(name = t_uints_tuple, public)]
|
||||
pub struct TUintsTuple {
|
||||
tuple: TUints,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[spacetimedb::table(name = t_others, public)]
|
||||
pub struct TOthers {
|
||||
bool: bool,
|
||||
f32: f32,
|
||||
f64: f64,
|
||||
str: String,
|
||||
bytes: Vec<u8>,
|
||||
identity: Identity,
|
||||
connection_id: ConnectionId,
|
||||
timestamp: Timestamp,
|
||||
duration: TimeDuration,
|
||||
}
|
||||
|
||||
#[spacetimedb::table(name = t_others_tuple, public)]
|
||||
pub struct TOthersTuple {
|
||||
tuple: TOthers
|
||||
}
|
||||
|
||||
#[derive(SpacetimeType, Debug, Clone, Copy)]
|
||||
pub enum Action {
|
||||
Inactive,
|
||||
Active,
|
||||
}
|
||||
|
||||
#[derive(SpacetimeType, Debug, Clone, Copy)]
|
||||
pub enum Color {
|
||||
Gray(u8),
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[spacetimedb::table(name = t_simple_enum, public)]
|
||||
pub struct TSimpleEnum {
|
||||
id : u32,
|
||||
action: Action,
|
||||
}
|
||||
|
||||
#[spacetimedb::table(name = t_enum, public)]
|
||||
pub struct TEnum {
|
||||
id : u32,
|
||||
color: Color,
|
||||
}
|
||||
|
||||
#[spacetimedb::table(name = t_nested, public)]
|
||||
pub struct TNested {
|
||||
en: TEnum,
|
||||
se: TSimpleEnum,
|
||||
ints: TInts,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn test(ctx: &ReducerContext) {
|
||||
let tuple = TInts {
|
||||
i8: -25,
|
||||
i16: -3224,
|
||||
i32: -23443,
|
||||
i64: -2344353,
|
||||
i128: -234434897853,
|
||||
i256: (-234434897853i128).into(),
|
||||
};
|
||||
let ints = tuple;
|
||||
ctx.db.t_ints().insert(tuple);
|
||||
ctx.db.t_ints_tuple().insert(TIntsTuple { tuple });
|
||||
|
||||
let tuple = TUints {
|
||||
u8: 105,
|
||||
u16: 1050,
|
||||
u32: 83892,
|
||||
u64: 48937498,
|
||||
u128: 4378528978889,
|
||||
u256: 4378528978889u128.into(),
|
||||
};
|
||||
ctx.db.t_uints().insert(tuple);
|
||||
ctx.db.t_uints_tuple().insert(TUintsTuple { tuple });
|
||||
|
||||
let tuple = TOthers {
|
||||
bool: true,
|
||||
f32: 594806.58906,
|
||||
f64: -3454353.345389043278459,
|
||||
str: "This is spacetimedb".to_string(),
|
||||
bytes: vec!(1, 2, 3, 4, 5, 6, 7),
|
||||
identity: Identity::ONE,
|
||||
connection_id: ConnectionId::ZERO,
|
||||
timestamp: Timestamp::UNIX_EPOCH,
|
||||
duration: TimeDuration::from_micros(1000 * 10000),
|
||||
};
|
||||
ctx.db.t_others().insert(tuple.clone());
|
||||
ctx.db.t_others_tuple().insert(TOthersTuple { tuple });
|
||||
|
||||
ctx.db.t_simple_enum().insert(TSimpleEnum { id: 1, action: Action::Inactive });
|
||||
ctx.db.t_simple_enum().insert(TSimpleEnum { id: 2, action: Action::Active });
|
||||
|
||||
ctx.db.t_enum().insert(TEnum { id: 1, color: Color::Gray(128) });
|
||||
|
||||
ctx.db.t_nested().insert(TNested {
|
||||
en: TEnum { id: 1, color: Color::Gray(128) },
|
||||
se: TSimpleEnum { id: 2, action: Action::Active },
|
||||
ints,
|
||||
});
|
||||
}
|
||||
"""
|
||||
|
||||
def assertSql(self, token: str, sql: str, expected):
|
||||
self.maxDiff = None
|
||||
sql_out = psql(token, sql)
|
||||
sql_out = "\n".join([line.rstrip() for line in sql_out.splitlines()])
|
||||
expected = "\n".join([line.rstrip() for line in expected.splitlines()])
|
||||
print(sql_out)
|
||||
self.assertMultiLineEqual(sql_out, expected)
|
||||
|
||||
def read_token(self):
|
||||
"""Read the token from the config file."""
|
||||
with open(self.config_path, "rb") as f:
|
||||
config = tomllib.load(f)
|
||||
return config['spacetimedb_token']
|
||||
|
||||
def test_sql_format(self):
|
||||
"""This test is designed to test calling `psql` to execute SQL statements"""
|
||||
token = self.read_token()
|
||||
self.publish_module("quickstart", clear=True)
|
||||
|
||||
self.call("test")
|
||||
|
||||
self.assertSql(token, "SELECT * FROM t_ints", """\
|
||||
i8 | i16 | i32 | i64 | i128 | i256
|
||||
-----+-------+--------+----------+---------------+---------------
|
||||
-25 | -3224 | -23443 | -2344353 | -234434897853 | -234434897853
|
||||
(1 row)""")
|
||||
self.assertSql(token, "SELECT * FROM t_ints_tuple", """\
|
||||
tuple
|
||||
---------------------------------------------------------------------------------------------------------
|
||||
{"i8": -25, "i16": -3224, "i32": -23443, "i64": -2344353, "i128": -234434897853, "i256": -234434897853}
|
||||
(1 row)""")
|
||||
self.assertSql(token, "SELECT * FROM t_uints", """\
|
||||
u8 | u16 | u32 | u64 | u128 | u256
|
||||
-----+------+-------+----------+---------------+---------------
|
||||
105 | 1050 | 83892 | 48937498 | 4378528978889 | 4378528978889
|
||||
(1 row)""")
|
||||
self.assertSql(token, "SELECT * FROM t_uints_tuple", """\
|
||||
tuple
|
||||
-------------------------------------------------------------------------------------------------------
|
||||
{"u8": 105, "u16": 1050, "u32": 83892, "u64": 48937498, "u128": 4378528978889, "u256": 4378528978889}
|
||||
(1 row)""")
|
||||
self.assertSql(token, "SELECT * FROM t_others", """\
|
||||
bool | f32 | f64 | str | bytes | identity | connection_id | timestamp | duration
|
||||
------+-----------+---------------------+---------------------+------------------+--------------------------------------------------------------------+------------------------------------+---------------------------+----------
|
||||
t | 594806.56 | -3454353.3453890434 | This is spacetimedb | \\x01020304050607 | \\x0000000000000000000000000000000000000000000000000000000000000001 | \\x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | PT10S
|
||||
(1 row)""")
|
||||
self.assertSql(token, "SELECT * FROM t_others_tuple", """\
|
||||
tuple
|
||||
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
{"bool": true, "f32": 594806.56, "f64": -3454353.3453890434, "str": "This is spacetimedb", "bytes": "0x01020304050607", "identity": "0x0000000000000000000000000000000000000000000000000000000000000001", "connection_id": "0x00000000000000000000000000000000", "timestamp": "1970-01-01T00:00:00+00:00", "duration": "PT10S"}
|
||||
(1 row)""")
|
||||
self.assertSql(token, "SELECT * FROM t_simple_enum", """\
|
||||
id | action
|
||||
----+----------
|
||||
1 | Inactive
|
||||
2 | Active
|
||||
(2 rows)""")
|
||||
self.assertSql(token, "SELECT * FROM t_enum", """\
|
||||
id | color
|
||||
----+---------------
|
||||
1 | {"Gray": 128}
|
||||
(1 row)""")
|
||||
self.assertSql(token, "SELECT * FROM t_nested", """\
|
||||
en | se | ints
|
||||
-----------------------------------+-------------------------------------+---------------------------------------------------------------------------------------------------------
|
||||
{"id": 1, "color": {"Gray": 128}} | {"id": 2, "action": {"Active": {}}} | {"i8": -25, "i16": -3224, "i32": -23443, "i64": -2344353, "i128": -234434897853, "i256": -234434897853}
|
||||
(1 row)""")
|
||||
|
||||
def test_sql_conn(self):
|
||||
"""This test is designed to test connecting to the database and executing queries using `psycopg2`"""
|
||||
token = self.read_token()
|
||||
self.publish_module("quickstart", clear=True)
|
||||
self.call("test")
|
||||
|
||||
conn = connect_db(token)
|
||||
# Check prepared statements (faked by `psycopg2`)
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("select * from t_uints where u8 = %s and u16 = %s", (105, 1050))
|
||||
rows = cur.fetchall()
|
||||
self.assertEqual(rows[0], (105, 1050, 83892, 48937498, 4378528978889, 4378528978889))
|
||||
# Check long-lived connection
|
||||
with conn.cursor() as cur:
|
||||
for _ in range(10):
|
||||
cur.execute("select count(*) as t from t_uints")
|
||||
rows = cur.fetchall()
|
||||
self.assertEqual(rows[0], (1,))
|
||||
conn.close()
|
||||
|
||||
def test_failures(self):
|
||||
"""This test is designed to test failure cases"""
|
||||
token = self.read_token()
|
||||
self.publish_module("quickstart", clear=True)
|
||||
|
||||
# Empty query
|
||||
sql_out = psql(token, "")
|
||||
self.assertEqual(sql_out, "")
|
||||
|
||||
# Connection fails when `ssl` is required
|
||||
for ssl_mode in ["require", "verify-ca", "verify-full"]:
|
||||
with self.assertRaises(Exception) as cm:
|
||||
psql(token, "SELECT * FROM t_uints", extra={"PGSSLMODE": ssl_mode})
|
||||
self.assertIn("not support SSL", str(cm.exception))
|
||||
|
||||
# But works with `ssl` is disabled or optional
|
||||
for ssl_mode in ["disable", "allow", "prefer"]:
|
||||
psql(token, "SELECT * FROM t_uints", extra={"PGSSLMODE": ssl_mode})
|
||||
|
||||
# Connection fails with invalid token
|
||||
with self.assertRaises(Exception) as cm:
|
||||
psql("invalid_token", "SELECT * FROM t_uints")
|
||||
self.assertIn("Invalid token", str(cm.exception))
|
||||
|
||||
# Returns error for unsupported `sql` statements
|
||||
with self.assertRaises(Exception) as cm:
|
||||
psql(token, "SELECT CASE a WHEN 1 THEN 'one' ELSE 'other' END FROM t_uints")
|
||||
self.assertIn("Unsupported", str(cm.exception))
|
||||
|
||||
# And prepared statements
|
||||
with self.assertRaises(Exception) as cm:
|
||||
psql(token, "SELECT * FROM t_uints where u8 = $1")
|
||||
self.assertIn("Unsupported", str(cm.exception))
|
||||
Reference in New Issue
Block a user