mirror of
https://github.com/clockworklabs/SpacetimeDB.git
synced 2026-05-06 07:26:43 -04:00
Remove Python smoketests (#4896)
# Description of Changes Remove the Python smoketests and the CI check that tests for edits. # API and ABI breaking changes None. CI only. # Expected complexity level and risk 1 # Testing - [ ] All CI passes Co-authored-by: Zeke Foppa <bfops@users.noreply.github.com>
This commit is contained in:
@@ -1025,36 +1025,6 @@ jobs:
|
||||
- name: Check global.json policy
|
||||
run: cargo ci global-json-policy
|
||||
|
||||
warn-python-smoketests:
|
||||
name: Check for Python smoketest edits
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event_name == 'pull_request'
|
||||
permissions:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Checkout sources
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Fail if Python smoketests were modified
|
||||
run: |
|
||||
MERGE_BASE="$(git merge-base origin/${{ github.base_ref }} HEAD)"
|
||||
PYTHON_SMOKETEST_CHANGES="$(git diff --name-only "$MERGE_BASE" HEAD -- 'smoketests/**.py')"
|
||||
|
||||
if [ -n "$PYTHON_SMOKETEST_CHANGES" ]; then
|
||||
echo "::error::This PR modifies legacy Python smoketests. Please add new tests to the Rust smoketests in crates/smoketests/ instead."
|
||||
echo ""
|
||||
echo "Changed files:"
|
||||
echo "$PYTHON_SMOKETEST_CHANGES"
|
||||
echo ""
|
||||
echo "The Python smoketests are being replaced by Rust smoketests."
|
||||
echo "See crates/smoketests/DEVELOP.md for instructions on adding Rust smoketests."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "No Python smoketest changes detected."
|
||||
|
||||
smoketests_mod_rs_complete:
|
||||
name: Check smoketests/mod.rs is complete
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
# Python Smoketests (Legacy)
|
||||
|
||||
> **Note:** These Python smoketests are being replaced by Rust smoketests in `crates/smoketests/`.
|
||||
> Both test suites currently run in CI to ensure consistency during the transition.
|
||||
>
|
||||
> For new tests, please add them to the Rust smoketests. See `crates/smoketests/DEVELOP.md` for instructions.
|
||||
|
||||
---
|
||||
|
||||
## Running the Python Smoketests
|
||||
|
||||
To use the smoketests, you first need to install the dependencies:
|
||||
|
||||
```
|
||||
python -m venv smoketests/venv
|
||||
smoketests/venv/bin/pip install -r smoketests/requirements.txt
|
||||
```
|
||||
|
||||
Then, run the smoketests like so:
|
||||
```
|
||||
smoketests/venv/bin/python -m smoketests <args>
|
||||
```
|
||||
@@ -1,445 +0,0 @@
|
||||
from pathlib import Path
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import shutil
|
||||
import string
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import unittest
|
||||
import logging
|
||||
import http.client
|
||||
import tomllib
|
||||
import functools
|
||||
|
||||
# miscellaneous file paths
|
||||
TEST_DIR = Path(__file__).parent
|
||||
STDB_DIR = TEST_DIR.parent
|
||||
exe_suffix = ".exe" if sys.platform == "win32" else ""
|
||||
SPACETIME_BIN = STDB_DIR / ("target/debug/spacetime" + exe_suffix)
|
||||
TEMPLATE_TARGET_DIR = STDB_DIR / "target/_stdbsmoketests"
|
||||
BASE_STDB_CONFIG_PATH = TEST_DIR / "config.toml"
|
||||
|
||||
# the contents of files for the base smoketest project template
|
||||
TEMPLATE_LIB_RS = open(STDB_DIR / "templates/basic-rs/spacetimedb/src/lib.rs").read()
|
||||
TEMPLATE_CARGO_TOML = open(STDB_DIR / "templates/basic-rs/spacetimedb/Cargo.toml").read()
|
||||
bindings_path = (STDB_DIR / "crates/bindings").absolute()
|
||||
escaped_bindings_path = str(bindings_path).replace('\\', '\\\\\\\\') # double escape for re.sub + toml
|
||||
TYPESCRIPT_BINDINGS_PATH = (STDB_DIR / "crates/bindings-typescript").absolute()
|
||||
TEMPLATE_CARGO_TOML = (re.compile(r"^spacetimedb\s*=.*$", re.M) \
|
||||
.sub(f'spacetimedb = {{ path = "{escaped_bindings_path}", features = {{features}} }}', TEMPLATE_CARGO_TOML))
|
||||
|
||||
# this is set to true when the --docker flag is passed to the cli
|
||||
HAVE_DOCKER = False
|
||||
# this is set to true when the --skip-dotnet flag is not passed to the cli,
|
||||
# and a dotnet installation is detected
|
||||
HAVE_DOTNET = False
|
||||
|
||||
# When we pass --spacetime-login, we are running against a server that requires "real" spacetime logins (rather than `--server-issued-login`).
|
||||
# This is used to skip tests that don't work with that.
|
||||
USE_SPACETIME_LOGIN = False
|
||||
|
||||
# If we pass `--remote-server`, the server address will be something other than the default. This is used to skip tests that rely on use
|
||||
# having the default localhost server.
|
||||
REMOTE_SERVER = False
|
||||
|
||||
# default value can be overridden by `--compose-file` flag
|
||||
COMPOSE_FILE = ".github/docker-compose.yml"
|
||||
|
||||
# this will be initialized by main()
|
||||
STDB_CONFIG = ''
|
||||
|
||||
# we need to late-bind the output stream to allow unittests to capture stdout/stderr.
|
||||
class CapturableHandler(logging.StreamHandler):
|
||||
|
||||
@property
|
||||
def stream(self):
|
||||
return sys.stderr
|
||||
|
||||
@stream.setter
|
||||
def stream(self, value):
|
||||
pass
|
||||
|
||||
handler = CapturableHandler()
|
||||
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
|
||||
logging.getLogger().addHandler(handler)
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
|
||||
def requires_dotnet(item):
|
||||
if HAVE_DOTNET:
|
||||
return item
|
||||
return unittest.skip("dotnet 8.0 not available")(item)
|
||||
|
||||
def requires_anonymous_login(item):
|
||||
if USE_SPACETIME_LOGIN:
|
||||
return unittest.skip("using `spacetime login`")(item)
|
||||
return item
|
||||
|
||||
def requires_local_server(item):
|
||||
if REMOTE_SERVER:
|
||||
return unittest.skip("running against a remote server")(item)
|
||||
return item
|
||||
|
||||
def build_template_target():
|
||||
if not TEMPLATE_TARGET_DIR.exists():
|
||||
logging.info("Building base compilation artifacts")
|
||||
class BuildModule(Smoketest):
|
||||
AUTOPUBLISH = False
|
||||
|
||||
BuildModule.setUpClass()
|
||||
env = { **os.environ, "CARGO_TARGET_DIR": str(TEMPLATE_TARGET_DIR) }
|
||||
spacetime("build", "--module-path", BuildModule.project_path, env=env)
|
||||
BuildModule.tearDownClass()
|
||||
BuildModule.doClassCleanups()
|
||||
|
||||
|
||||
def requires_docker(item):
|
||||
if HAVE_DOCKER:
|
||||
return item
|
||||
return unittest.skip("docker not available")(item)
|
||||
|
||||
def random_string(k=20):
|
||||
return ''.join(random.choices(string.ascii_lowercase, k=k))
|
||||
|
||||
def extract_fields(cmd_output, field_name):
|
||||
"""
|
||||
parses output from the spacetime cli that's formatted in the "empty" style
|
||||
from tabled:
|
||||
FIELDNAME1 VALUE1
|
||||
THEFIELDNAME2 VALUE2
|
||||
field_name should be which field name you want to filter for
|
||||
"""
|
||||
out = []
|
||||
for line in cmd_output.splitlines():
|
||||
fields = line.split()
|
||||
if len(fields) < 2:
|
||||
continue
|
||||
label, val, *_ = fields
|
||||
if label == field_name:
|
||||
out.append(val)
|
||||
return out
|
||||
|
||||
def parse_sql_result(res: str) -> list[dict]:
|
||||
"""Parse tabular output from an SQL query into a list of dicts."""
|
||||
lines = res.splitlines()
|
||||
headers = lines[0].split('|') if '|' in lines[0] else [lines[0]]
|
||||
headers = [header.strip() for header in headers]
|
||||
rows = []
|
||||
for row in lines[2:]:
|
||||
cols = [col.strip() for col in row.split('|')]
|
||||
rows.append(dict(zip(headers, cols)))
|
||||
return rows
|
||||
|
||||
def extract_field(cmd_output, field_name):
|
||||
field, = extract_fields(cmd_output, field_name)
|
||||
return field
|
||||
|
||||
def log_cmd(args):
|
||||
logging.debug(f"$ {' '.join(str(arg) for arg in args)}")
|
||||
|
||||
|
||||
def run_cmd(*args, capture_stderr=True, check=True, full_output=False, cmd_name=None, log=True, **kwargs):
|
||||
if log:
|
||||
log_cmd(args if cmd_name is None else [cmd_name, *args[1:]])
|
||||
|
||||
needs_close = False
|
||||
if not capture_stderr:
|
||||
logging.debug("--- stderr ---")
|
||||
needs_close = True
|
||||
|
||||
output = subprocess.run(
|
||||
list(args),
|
||||
encoding="utf8",
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE if capture_stderr else None,
|
||||
**kwargs
|
||||
)
|
||||
if log:
|
||||
if capture_stderr and output.stderr.strip() != "":
|
||||
logging.debug(f"--- stderr ---\n{output.stderr.strip()}")
|
||||
needs_close = True
|
||||
if output.stdout.strip() != "":
|
||||
logging.debug(f"--- stdout ---\n{output.stdout.strip()}")
|
||||
needs_close = True
|
||||
if needs_close:
|
||||
logging.debug("--------------\n")
|
||||
|
||||
sys.stderr.flush()
|
||||
if check:
|
||||
if cmd_name is not None:
|
||||
output.args[0] = cmd_name
|
||||
output.check_returncode()
|
||||
return output if full_output else output.stdout
|
||||
|
||||
@functools.cache
|
||||
def pnpm_path():
|
||||
pnpm = shutil.which("pnpm")
|
||||
if not pnpm:
|
||||
raise Exception("pnpm not installed")
|
||||
return pnpm
|
||||
|
||||
def pnpm(*args, **kwargs):
|
||||
return run_cmd(pnpm_path(), *args, **kwargs)
|
||||
|
||||
@functools.cache
|
||||
def build_typescript_sdk():
|
||||
pnpm("install", cwd=TYPESCRIPT_BINDINGS_PATH)
|
||||
pnpm("build", cwd=TYPESCRIPT_BINDINGS_PATH)
|
||||
|
||||
def spacetime(*args, **kwargs):
|
||||
return run_cmd(SPACETIME_BIN, *args, cmd_name="spacetime", **kwargs)
|
||||
|
||||
def new_identity(config_path):
|
||||
spacetime("--config-path", str(config_path), "logout")
|
||||
spacetime("--config-path", str(config_path), "login", "--server-issued-login", "localhost", full_output=False)
|
||||
|
||||
class Smoketest(unittest.TestCase):
|
||||
MODULE_CODE = TEMPLATE_LIB_RS
|
||||
AUTOPUBLISH = True
|
||||
BINDINGS_FEATURES = ["unstable"]
|
||||
EXTRA_DEPS = ""
|
||||
|
||||
@classmethod
|
||||
def cargo_manifest(cls, manifest_text):
|
||||
return manifest_text.replace("{features}", repr(list(cls.BINDINGS_FEATURES))) + cls.EXTRA_DEPS
|
||||
|
||||
# helpers
|
||||
|
||||
@classmethod
|
||||
def spacetime(cls, *args, **kwargs):
|
||||
return spacetime("--config-path", str(cls.config_path), *args, **kwargs)
|
||||
|
||||
def _check_published(self):
|
||||
if not hasattr(self, "database_identity"):
|
||||
raise Exception("Cannot use this function without publishing a module")
|
||||
|
||||
def call(self, reducer, *args, anon=False, check=True, full_output = False):
|
||||
self._check_published()
|
||||
anon = ["--anonymous"] if anon else []
|
||||
return self.spacetime("call", *anon, "--", self.database_identity, reducer, *map(json.dumps, args), check = check, full_output=full_output)
|
||||
|
||||
def sql(self, sql):
|
||||
self._check_published()
|
||||
anon = ["--anonymous"]
|
||||
return self.spacetime("sql", *anon, "--", self.database_identity, sql)
|
||||
|
||||
def logs(self, n):
|
||||
return [log["message"] for log in self.log_records(n)]
|
||||
|
||||
def log_records(self, n):
|
||||
self._check_published()
|
||||
logs = self.spacetime("logs", "--format=json", "-n", str(n), "--", self.database_identity)
|
||||
return list(map(json.loads, logs.splitlines()))
|
||||
|
||||
def publish_module(self, domain=None, *, clear=True, capture_stderr=True,
|
||||
num_replicas=None, break_clients=False, organization=None):
|
||||
publish_output = self.spacetime(
|
||||
"publish",
|
||||
*[domain] if domain is not None else [],
|
||||
*["-c"] if clear and domain is not None else [],
|
||||
"--module-path", self.project_path,
|
||||
# This is required if -c is provided, but is also required for SpacetimeDBPrivate's tests,
|
||||
# because the server address is `node` which doesn't look like `localhost` or `127.0.0.1`
|
||||
# and so the publish step prompts for confirmation.
|
||||
"--yes",
|
||||
*["--num-replicas", f"{num_replicas}"] if num_replicas is not None else [],
|
||||
*["--break-clients"] if break_clients else [],
|
||||
*["--organization", f"{organization}"] if organization is not None else [],
|
||||
capture_stderr=capture_stderr,
|
||||
)
|
||||
self.resolved_identity = re.search(r"identity: ([0-9a-fA-F]+)", publish_output)[1]
|
||||
self.database_identity = self.resolved_identity
|
||||
|
||||
@classmethod
|
||||
def reset_config(cls):
|
||||
if not STDB_CONFIG:
|
||||
raise Exception("config toml has not been initialized yet")
|
||||
cls.config_path.write_text(STDB_CONFIG)
|
||||
|
||||
def fingerprint(self):
|
||||
# Fetch the server's fingerprint; required for `identity list`.
|
||||
self.spacetime("server", "fingerprint", "localhost", "-y")
|
||||
|
||||
def new_identity(self):
|
||||
new_identity(self.__class__.config_path)
|
||||
|
||||
def subscribe(self, *queries, n, confirmed = None, database = None):
|
||||
self._check_published()
|
||||
assert isinstance(n, int)
|
||||
|
||||
args = [
|
||||
SPACETIME_BIN,
|
||||
"--config-path", str(self.config_path),
|
||||
"subscribe",
|
||||
database if database is not None else self.database_identity,
|
||||
"-t", "600",
|
||||
"-n", str(n),
|
||||
"--print-initial-update",
|
||||
]
|
||||
if confirmed is not None:
|
||||
args.append(f"--confirmed={str(confirmed).lower()}")
|
||||
args.extend(["--", *queries])
|
||||
|
||||
fake_args = ["spacetime", *args[1:]]
|
||||
log_cmd(fake_args)
|
||||
|
||||
proc = subprocess.Popen(args, encoding="utf8", stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
|
||||
def stderr_task():
|
||||
sys.stderr.writelines(proc.stderr)
|
||||
threading.Thread(target=stderr_task).start()
|
||||
|
||||
init_update = proc.stdout.readline().strip()
|
||||
if init_update:
|
||||
print("initial update:", init_update)
|
||||
else:
|
||||
try:
|
||||
code = proc.wait()
|
||||
if code:
|
||||
raise subprocess.CalledProcessError(code, fake_args)
|
||||
print("no initial update, but no error code either")
|
||||
except subprocess.TimeoutExpired:
|
||||
print("no initial update, but process is still running")
|
||||
|
||||
def run():
|
||||
updates = list(map(json.loads, proc.stdout))
|
||||
code = proc.wait()
|
||||
if code:
|
||||
raise subprocess.CalledProcessError(code, fake_args)
|
||||
return updates
|
||||
# Note that we're returning `.join`, not `.join()`; this returns something that the caller can call in order to
|
||||
# join the thread and wait for the results.
|
||||
# If the caller does not invoke this returned value, the thread will just run in the background, not be awaited,
|
||||
# and **not raise any exceptions to the caller**.
|
||||
return ReturnThread(run).join
|
||||
|
||||
def get_server_address(self):
|
||||
with open(self.config_path, "rb") as f:
|
||||
config = tomllib.load(f)
|
||||
token = config['spacetimedb_token']
|
||||
server_name = config['default_server']
|
||||
server_config = next((c for c in config['server_configs'] if c['nickname'] == server_name), None)
|
||||
if server_config is None:
|
||||
raise Exception(f"Unable to find server in config with nickname {server_name}")
|
||||
address = server_config['host']
|
||||
host = address
|
||||
port = None
|
||||
if ":" in host:
|
||||
host, port = host.split(":", 1)
|
||||
protocol = server_config['protocol']
|
||||
|
||||
return dict(address=address, host=host, port=port, protocol=protocol, token=token)
|
||||
|
||||
# Make an HTTP call with `method` to `path`.
|
||||
#
|
||||
# If the response is 200, return the body.
|
||||
# Otherwise, throw an `Exception` constructed with two arguments, the response object and the body.
|
||||
def api_call(self, method, path, body=None, headers={}):
|
||||
server = self.get_server_address()
|
||||
host = server["address"]
|
||||
protocol = server["protocol"]
|
||||
token = server["token"]
|
||||
conn = None
|
||||
if protocol == "http":
|
||||
conn = http.client.HTTPConnection(host)
|
||||
elif protocol == "https":
|
||||
conn = http.client.HTTPSConnection(host)
|
||||
else:
|
||||
raise Exception(f"Unknown protocol: {protocol}")
|
||||
auth = {"Authorization": f'Bearer {token}'}
|
||||
headers.update(auth)
|
||||
log_cmd([method, path])
|
||||
conn.request(method, path, body, headers)
|
||||
resp = conn.getresponse()
|
||||
body = resp.read()
|
||||
logging.debug(f"{resp.status} {body}")
|
||||
if resp.status != 200:
|
||||
raise Exception(resp, body)
|
||||
return body
|
||||
|
||||
@classmethod
|
||||
def write_module_code(cls, module_code):
|
||||
open(cls.project_path / "src/lib.rs", "w").write(module_code)
|
||||
|
||||
# testcase initialization
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.project_path = Path(cls.enterClassContext(tempfile.TemporaryDirectory()))
|
||||
cls.config_path = cls.project_path / "config.toml"
|
||||
cls.reset_config()
|
||||
open(cls.project_path / "Cargo.toml", "w").write(cls.cargo_manifest(TEMPLATE_CARGO_TOML))
|
||||
shutil.copy2(STDB_DIR / "rust-toolchain.toml", cls.project_path)
|
||||
os.mkdir(cls.project_path / "src")
|
||||
cls.write_module_code(cls.MODULE_CODE)
|
||||
if TEMPLATE_TARGET_DIR.exists():
|
||||
shutil.copytree(TEMPLATE_TARGET_DIR, cls.project_path / "target")
|
||||
|
||||
if cls.AUTOPUBLISH:
|
||||
logging.info(f"Compiling module for {cls.__qualname__}...")
|
||||
cls.publish_module(cls, capture_stderr=True) # capture stderr because otherwise it clutters the top-level test logs for some reason.
|
||||
|
||||
def tearDown(self):
|
||||
# if this single test method published a database, clean it up now
|
||||
if "database_identity" in self.__dict__:
|
||||
try:
|
||||
# TODO: save the credentials in publish_module()
|
||||
self.spacetime("delete", "--yes", self.database_identity)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if hasattr(cls, "database_identity"):
|
||||
try:
|
||||
# TODO: save the credentials in publish_module()
|
||||
cls.spacetime("delete", "--yes", cls.database_identity)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
# polyfill; python 3.11 defines this classmethod on TestCase
|
||||
@classmethod
|
||||
def enterClassContext(cls, cm):
|
||||
result = cm.__enter__()
|
||||
cls.addClassCleanup(cm.__exit__, None, None, None)
|
||||
return result
|
||||
|
||||
def assertSql(self, sql: str, expected: str):
|
||||
"""Assert that executing `sql` produces the expected output."""
|
||||
self.maxDiff = None
|
||||
sql_out = self.spacetime("sql", self.database_identity, sql)
|
||||
sql_out = "\n".join([line.rstrip() for line in sql_out.splitlines()])
|
||||
expected = "\n".join([line.rstrip() for line in expected.splitlines()])
|
||||
self.assertMultiLineEqual(sql_out, expected)
|
||||
|
||||
# This is a custom thread class that will propagate an exception to the caller of `.join()`.
|
||||
# This is required because, by default, threads do not propagate exceptions to their callers,
|
||||
# even callers who have called `join`.
|
||||
class ReturnThread:
|
||||
def __init__(self, target):
|
||||
self._target = target
|
||||
self._exception = None
|
||||
self._thread = threading.Thread(target=self._task)
|
||||
self._thread.start()
|
||||
|
||||
def _task(self):
|
||||
# Wrap self._target()` with an exception handler, so we can return the exception
|
||||
# to the caller of `join` below.
|
||||
try:
|
||||
self._result = self._target()
|
||||
except BaseException as e:
|
||||
self._exception = e
|
||||
finally:
|
||||
del self._target
|
||||
|
||||
def join(self, timeout=None):
|
||||
self._thread.join(timeout)
|
||||
if self._exception is not None:
|
||||
raise self._exception
|
||||
return self._result
|
||||
|
||||
@@ -1,190 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import subprocess
|
||||
import unittest
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import fnmatch
|
||||
import json
|
||||
from . import TEST_DIR, SPACETIME_BIN, BASE_STDB_CONFIG_PATH, exe_suffix, build_template_target
|
||||
import smoketests
|
||||
import sys
|
||||
import logging
|
||||
import itertools
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import traceback
|
||||
|
||||
def check_docker():
|
||||
docker_ps = smoketests.run_cmd("docker", "ps", "--format=json")
|
||||
docker_ps = (json.loads(line) for line in docker_ps.splitlines())
|
||||
for docker_container in docker_ps:
|
||||
if "node" in docker_container["Image"] or "spacetime" in docker_container["Image"]:
|
||||
return docker_container["Names"]
|
||||
else:
|
||||
print("Docker container not found, is SpacetimeDB running?")
|
||||
exit(1)
|
||||
|
||||
def check_dotnet() -> bool:
|
||||
try:
|
||||
version = smoketests.run_cmd("dotnet", "--version", log=False).strip()
|
||||
if int(version.split(".")[0]) < 8:
|
||||
logging.info(f"dotnet version {version} not high enough (< 8.0), skipping dotnet smoketests")
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
return True
|
||||
|
||||
class ExclusionaryTestLoader(unittest.TestLoader):
|
||||
def __init__(self, excludelist=()):
|
||||
super().__init__()
|
||||
# build a regex that matches any of the elements of excludelist at a word boundary
|
||||
excludes = '|'.join(fnmatch.translate(exclude).removesuffix(r"\Z") for exclude in excludelist)
|
||||
self.excludepat = excludes and re.compile(rf"^(?:{excludes})\b")
|
||||
|
||||
def loadTestsFromName(self, name, module=None):
|
||||
if self.excludepat:
|
||||
qualname = name
|
||||
if module is not None:
|
||||
qualname = module.__name__ + "." + name
|
||||
if self.excludepat.match(qualname):
|
||||
return self.suiteClass([])
|
||||
return super().loadTestsFromName(name, module)
|
||||
|
||||
def _convert_select_pattern(pattern):
|
||||
return f'*{pattern}*' if '*' not in pattern else pattern
|
||||
|
||||
|
||||
TESTPREFIX = "smoketests.tests."
|
||||
|
||||
def _iter_all_tests(suite_or_case):
|
||||
"""Yield all individual tests from possibly nested TestSuite structures."""
|
||||
if isinstance(suite_or_case, unittest.TestSuite):
|
||||
for t in suite_or_case:
|
||||
yield from _iter_all_tests(t)
|
||||
else:
|
||||
yield suite_or_case
|
||||
|
||||
def main():
|
||||
tests = [fname.removesuffix(".py") for fname in os.listdir(TEST_DIR / "tests") if fname.endswith(".py") and fname != "__init__.py"]
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("test", nargs="*", default=tests)
|
||||
parser.add_argument("--docker", action="store_true")
|
||||
parser.add_argument("--compose-file")
|
||||
parser.add_argument("--no-docker-logs", action="store_true")
|
||||
parser.add_argument("--skip-dotnet", action="store_true", help="ignore tests which require dotnet")
|
||||
parser.add_argument("--show-all-output", action="store_true", help="show all stdout/stderr from the tests as they're running")
|
||||
parser.add_argument("--parallel", action="store_true", help="run test classes in parallel")
|
||||
parser.add_argument("-j", dest='jobs', help="Set number of jobs for parallel test runs. Default is `nproc`", type=int, default=0)
|
||||
parser.add_argument('-k', dest='testNamePatterns',
|
||||
action='append', type=_convert_select_pattern,
|
||||
help='Only run tests which match the given substring')
|
||||
parser.add_argument("-x", dest="exclude", nargs="*", default=[])
|
||||
parser.add_argument("--no-build-cli", action="store_true", help="don't cargo build the cli")
|
||||
parser.add_argument("--list", action="store_true", help="list the tests that would be run, but don't run them")
|
||||
parser.add_argument("--remote-server", action="store", help="Run against a remote server")
|
||||
parser.add_argument("--spacetime-login", action="store_true", help="Use `spacetime login` for these tests (and disable tests that don't work with that)")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.docker:
|
||||
# have docker logs print concurrently with the test output
|
||||
if args.compose_file:
|
||||
smoketests.COMPOSE_FILE = args.compose_file
|
||||
if not args.no_docker_logs:
|
||||
if args.compose_file:
|
||||
subprocess.Popen(["docker", "compose", "-f", args.compose_file, "logs", "-f"])
|
||||
else:
|
||||
docker_container = check_docker()
|
||||
subprocess.Popen(["docker", "logs", "-f", docker_container])
|
||||
smoketests.HAVE_DOCKER = True
|
||||
|
||||
if not args.skip_dotnet:
|
||||
smoketests.HAVE_DOTNET = check_dotnet()
|
||||
if not smoketests.HAVE_DOTNET:
|
||||
print("no suitable dotnet installation found")
|
||||
exit(1)
|
||||
|
||||
add_prefix = lambda testlist: [TESTPREFIX + test for test in testlist]
|
||||
import fnmatch
|
||||
excludelist = add_prefix(args.exclude)
|
||||
testlist = add_prefix(args.test)
|
||||
|
||||
loader = ExclusionaryTestLoader(excludelist)
|
||||
loader.testNamePatterns = args.testNamePatterns
|
||||
|
||||
tests = loader.loadTestsFromNames(testlist)
|
||||
if args.list:
|
||||
failed_cls = getattr(unittest.loader, "_FailedTest", None)
|
||||
any_failed = False
|
||||
for test in _iter_all_tests(tests):
|
||||
name = test.id()
|
||||
if isinstance(test, failed_cls):
|
||||
any_failed = True
|
||||
print('')
|
||||
print("Failed to construct %s:" % test.id())
|
||||
exc = getattr(test, "_exception", None)
|
||||
if exc is not None:
|
||||
tb = ''.join(traceback.format_exception(exc))
|
||||
print(tb.rstrip())
|
||||
print('')
|
||||
else:
|
||||
print(f"{name}")
|
||||
exit(1 if any_failed else 0)
|
||||
|
||||
if not args.no_build_cli:
|
||||
logging.info("Compiling spacetime cli...")
|
||||
smoketests.run_cmd("cargo", "build", cwd=TEST_DIR.parent, capture_stderr=False)
|
||||
|
||||
update_bin_name = "spacetimedb-update" + exe_suffix
|
||||
try:
|
||||
bin_is_symlink = SPACETIME_BIN.readlink() == update_bin_name
|
||||
except OSError:
|
||||
bin_is_symlink = False
|
||||
if not bin_is_symlink:
|
||||
try:
|
||||
os.remove(SPACETIME_BIN)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
try:
|
||||
os.symlink(update_bin_name, SPACETIME_BIN)
|
||||
except OSError:
|
||||
shutil.copyfile(SPACETIME_BIN.with_name(update_bin_name), SPACETIME_BIN)
|
||||
|
||||
os.environ["SPACETIME_SKIP_CLIPPY"] = "1"
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w+b", suffix=".toml", buffering=0, delete_on_close=False) as config_file:
|
||||
with BASE_STDB_CONFIG_PATH.open("rb") as src, config_file.file as dst:
|
||||
shutil.copyfileobj(src, dst)
|
||||
|
||||
if args.remote_server is not None:
|
||||
smoketests.spacetime("--config-path", config_file.name, "server", "edit", "localhost", "--url", args.remote_server, "--yes")
|
||||
smoketests.REMOTE_SERVER = True
|
||||
|
||||
if args.spacetime_login:
|
||||
smoketests.spacetime("--config-path", config_file.name, "logout")
|
||||
smoketests.spacetime("--config-path", config_file.name, "login")
|
||||
smoketests.USE_SPACETIME_LOGIN = True
|
||||
else:
|
||||
smoketests.new_identity(config_file.name)
|
||||
|
||||
smoketests.STDB_CONFIG = Path(config_file.name).read_text()
|
||||
|
||||
build_template_target()
|
||||
buffer = not args.show_all_output
|
||||
verbosity = 2
|
||||
|
||||
if args.parallel:
|
||||
print("parallel test running is under construction, this will probably not work correctly")
|
||||
from . import unittest_parallel
|
||||
unittest_parallel.main(buffer=buffer, verbose=verbosity, level="class", discovered_tests=tests, jobs=args.jobs)
|
||||
else:
|
||||
result = unittest.TextTestRunner(buffer=buffer, verbosity=verbosity).run(tests)
|
||||
if not result.wasSuccessful():
|
||||
parser.exit(status=1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,6 +0,0 @@
|
||||
default_server = "localhost"
|
||||
|
||||
[[server_configs]]
|
||||
nickname = "localhost"
|
||||
host = "127.0.0.1:3000"
|
||||
protocol = "http"
|
||||
@@ -1,208 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Callable
|
||||
from urllib.request import urlopen
|
||||
|
||||
from . import COMPOSE_FILE
|
||||
|
||||
|
||||
def restart_docker():
|
||||
"""
|
||||
Restart all containers defined in the current `COMPOSE_FILE`.
|
||||
|
||||
Checks that all spacetimedb containers are up and running after the restart.
|
||||
If they're not up after a couple of retries, throws an `Exception`.
|
||||
"""
|
||||
print("Restarting containers")
|
||||
|
||||
docker = DockerManager(COMPOSE_FILE)
|
||||
docker.compose("restart")
|
||||
containers = docker.list_spacetimedb_containers()
|
||||
if not containers:
|
||||
raise Exception("No spacetimedb containers found")
|
||||
|
||||
# Ensure all nodes are running.
|
||||
attempts = 0
|
||||
while attempts < 10:
|
||||
attempts += 1
|
||||
containers_alive = {
|
||||
container.name: container.is_running(docker, spacetimedb_ping_url)
|
||||
for container in containers
|
||||
}
|
||||
if all(containers_alive.values()):
|
||||
# sleep a bit more to allow for leader election etc
|
||||
# TODO: make ping endpoint consider all server state
|
||||
time.sleep(2)
|
||||
return
|
||||
else:
|
||||
time.sleep(1)
|
||||
|
||||
raise Exception(f"Not all containers are up and running: {containers_alive!r}")
|
||||
|
||||
def spacetimedb_ping_url(port: int) -> str:
|
||||
return f"http://127.0.0.1:{port}/v1/ping"
|
||||
|
||||
@dataclass
|
||||
class DockerContainer:
|
||||
"""Represents a Docker container with its basic properties."""
|
||||
id: str
|
||||
name: str
|
||||
|
||||
def host_ports(self, docker) -> set[int]:
|
||||
"""
|
||||
Collect all host ports of this container.
|
||||
|
||||
Host ports are ports on the host that are bound to ports of the
|
||||
container.
|
||||
If the container is not currently running, an empty set is returned.
|
||||
"""
|
||||
host_ports = set()
|
||||
info = docker.inspect_container(self)
|
||||
for ports in info.get('NetworkSettings', {}).get('Ports', {}).values():
|
||||
if ports:
|
||||
for ip_and_port in ports:
|
||||
host_port = ip_and_port.get("HostPort")
|
||||
if host_port:
|
||||
host_ports.add(host_port)
|
||||
return host_ports
|
||||
|
||||
def is_running(self, docker, ping_url: Callable[[int], str]) -> bool:
|
||||
"""
|
||||
Check if the container is running.
|
||||
|
||||
`ping_url` takes a port number and returns a URL string that can be used
|
||||
to determine if the host is running by returning a 200 status.
|
||||
|
||||
If `self.host_ports()` returns a non-empty set, and one `ping_url`
|
||||
request is successful, the container is considered running.
|
||||
"""
|
||||
host_ports = self.host_ports(docker)
|
||||
for port in host_ports:
|
||||
url = ping_url(port)
|
||||
print(f"Trying {url} ... ", end='', flush=True)
|
||||
try:
|
||||
with urlopen(url, timeout=0.2) as response:
|
||||
if response.status == 200:
|
||||
print("ok")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"error: {e}")
|
||||
continue
|
||||
|
||||
print(f"container {self.name} not running")
|
||||
return False
|
||||
|
||||
class DockerManager:
|
||||
"""Manages all Docker and Docker Compose operations."""
|
||||
|
||||
def __init__(self, compose_file: str, **config):
|
||||
self.compose_file = compose_file
|
||||
self.network_name = config.get('network_name') or \
|
||||
os.getenv('DOCKER_NETWORK_NAME', 'private_spacetime_cloud')
|
||||
self.control_db_container = config.get('control_db_container') or \
|
||||
os.getenv('CONTROL_DB_CONTAINER', 'node')
|
||||
self.spacetime_cli_bin = config.get('spacetime_cli_bin') or \
|
||||
os.getenv('SPACETIME_CLI_BIN', 'spacetimedb-cloud')
|
||||
|
||||
def _execute_command(self, *args: str) -> str:
|
||||
"""Execute a Docker command and return its output."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
args,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True
|
||||
)
|
||||
return result.stdout.strip()
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Command failed: {e.stderr}")
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"Unexpected error: {str(e)}")
|
||||
raise
|
||||
|
||||
def compose(self, *args: str) -> str:
|
||||
"""Execute a `docker compose` command."""
|
||||
return self._execute_command("docker", "compose", "-f", self.compose_file, *args)
|
||||
|
||||
def docker(self, *args: str) -> str:
|
||||
"""Execute a `docker` command."""
|
||||
return self._execute_command("docker", *args)
|
||||
|
||||
def list_containers(self, *filters) -> List[DockerContainer]:
|
||||
"""
|
||||
List the containers of the current compose file and return as DockerContainer objects.
|
||||
|
||||
All containers are considered, even if not running ('-a' flag).
|
||||
The containers may be filtered by 'filters' ('--filter' option).
|
||||
"""
|
||||
# Use -a so we don't miss a crashed or killed container
|
||||
# when checking for readiness.
|
||||
cmd = ["ps", "-a"]
|
||||
|
||||
# Restrict to the current compose file.
|
||||
compose_file = os.path.abspath(COMPOSE_FILE)
|
||||
cmd.extend(["--filter", f"label=com.docker.compose.project.config_files={compose_file}"])
|
||||
|
||||
# Apply additional filters.
|
||||
for f in filters:
|
||||
cmd.extend(["--filter", f])
|
||||
|
||||
# Output only the fields we need for `DockerContainer`.
|
||||
cmd.extend(["--format", "{{.ID}} {{.Names}}"])
|
||||
|
||||
output = self.docker(*cmd)
|
||||
containers = []
|
||||
for line in output.splitlines():
|
||||
if line.strip():
|
||||
container_id, name = line.split(maxsplit=1)
|
||||
containers.append(DockerContainer(id=container_id, name=name))
|
||||
return containers
|
||||
|
||||
def list_spacetimedb_containers(self) -> List[DockerContainer]:
|
||||
"""List all containers running spacetimedb."""
|
||||
return self.list_containers("label=app=spacetimedb")
|
||||
|
||||
def inspect_container(self, container: DockerContainer):
|
||||
"""Run the `inspect` command for `container`, returning the parsed JSON dict."""
|
||||
info = self.docker("inspect", container.name)
|
||||
return json.loads(info)[0]
|
||||
|
||||
def get_container_by_name(self, name: str) -> Optional[DockerContainer]:
|
||||
"""Find a container by name pattern."""
|
||||
return next(
|
||||
(c for c in self.list_containers() if name in c.name),
|
||||
None
|
||||
)
|
||||
|
||||
def kill_container(self, container_id: str):
|
||||
"""Kill a container by ID."""
|
||||
print(f"Killing container {container_id}")
|
||||
self.docker("kill", container_id)
|
||||
|
||||
def start_container(self, container_id: str):
|
||||
"""Start a container by ID."""
|
||||
print(f"Starting container {container_id}")
|
||||
self.docker("start", container_id)
|
||||
|
||||
def disconnect_container(self, container_id: str):
|
||||
"""Disconnect a container from the network."""
|
||||
print(f"Disconnecting container {container_id}")
|
||||
self.docker("network", "disconnect", self.network_name, container_id)
|
||||
print(f"Disconnected container {container_id}")
|
||||
|
||||
def connect_container(self, container_id: str):
|
||||
"""Connect a container to the network."""
|
||||
print(f"Connecting container {container_id}")
|
||||
self.docker("network", "connect", self.network_name, container_id)
|
||||
print(f"Connected container {container_id}")
|
||||
|
||||
def generate_root_token(self) -> str:
|
||||
"""Generate a root token using spacetimedb-cloud."""
|
||||
return self.compose(
|
||||
"exec", self.control_db_container, self.spacetime_cli_bin, "token", "gen",
|
||||
"--subject=placeholder-node-id",
|
||||
"--jwt-priv-key", "/etc/spacetimedb/keys/id_ecdsa").split('|')[1]
|
||||
@@ -1,3 +0,0 @@
|
||||
psycopg2-binary
|
||||
toml
|
||||
xmltodict
|
||||
@@ -1,92 +0,0 @@
|
||||
from .. import Smoketest, random_string
|
||||
|
||||
class AddRemoveIndex(Smoketest):
|
||||
AUTOPUBLISH = False
|
||||
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = t1)]
|
||||
pub struct T1 { id: u64 }
|
||||
|
||||
#[spacetimedb::table(accessor = t2)]
|
||||
pub struct T2 { id: u64 }
|
||||
|
||||
#[spacetimedb::reducer(init)]
|
||||
pub fn init(ctx: &ReducerContext) {
|
||||
for id in 0..1_000 {
|
||||
ctx.db.t1().insert(T1 { id });
|
||||
ctx.db.t2().insert(T2 { id });
|
||||
}
|
||||
}
|
||||
"""
|
||||
MODULE_CODE_INDEXED = """
|
||||
use spacetimedb::{ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = t1)]
|
||||
pub struct T1 { #[index(btree)] id: u64 }
|
||||
|
||||
#[spacetimedb::table(accessor = t2)]
|
||||
pub struct T2 { #[index(btree)] id: u64 }
|
||||
|
||||
#[spacetimedb::reducer(init)]
|
||||
pub fn init(ctx: &ReducerContext) {
|
||||
for id in 0..1_000 {
|
||||
ctx.db.t1().insert(T1 { id });
|
||||
ctx.db.t2().insert(T2 { id });
|
||||
}
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn add(ctx: &ReducerContext) {
|
||||
let id = 1_001;
|
||||
ctx.db.t1().insert(T1 { id });
|
||||
ctx.db.t2().insert(T2 { id });
|
||||
}
|
||||
"""
|
||||
|
||||
JOIN_QUERY = "select t_1.* from t_1 join t_2 on t_1.id = t_2.id where t_2.id = 1001"
|
||||
|
||||
def between_publishes(self):
|
||||
"""
|
||||
The test `AddRemoveIndexAfterRestart` in `zz_docker.py`
|
||||
overwrites this method to restart docker between each publish,
|
||||
otherwise reusing this test's code.
|
||||
"""
|
||||
pass
|
||||
|
||||
def test_add_then_remove_index(self):
|
||||
"""
|
||||
First publish without the indices,
|
||||
then add the indices, and publish,
|
||||
and finally remove the indices, and publish again.
|
||||
There should be no errors
|
||||
and the unindexed versions should reject subscriptions.
|
||||
"""
|
||||
|
||||
name = random_string()
|
||||
|
||||
# Publish and attempt a subscribing to a join query.
|
||||
# There are no indices, resulting in an unsupported unindexed join.
|
||||
self.publish_module(name, clear = False)
|
||||
with self.assertRaises(Exception):
|
||||
self.subscribe(self.JOIN_QUERY, n = 0)
|
||||
|
||||
self.between_publishes()
|
||||
|
||||
# Publish the indexed version.
|
||||
# Now we have indices, so the query should be accepted.
|
||||
self.write_module_code(self.MODULE_CODE_INDEXED)
|
||||
self.publish_module(name, clear = False)
|
||||
sub = self.subscribe(self.JOIN_QUERY, n = 1)
|
||||
self.call("add", anon = True)
|
||||
sub()
|
||||
|
||||
self.between_publishes()
|
||||
|
||||
# Publish the unindexed version again, removing the index.
|
||||
# The initial subscription should be rejected again.
|
||||
self.write_module_code(self.MODULE_CODE)
|
||||
self.publish_module(name, clear = False)
|
||||
with self.assertRaises(Exception):
|
||||
self.subscribe(self.JOIN_QUERY, n = 0)
|
||||
@@ -1,134 +0,0 @@
|
||||
|
||||
from .. import Smoketest
|
||||
import string
|
||||
import functools
|
||||
|
||||
|
||||
ints = "u8", "u16", "u32", "u64", "u128", "i8", "i16", "i32", "i64", "i128"
|
||||
|
||||
|
||||
def reducer_name(int_ty: str) -> str:
|
||||
# Convert "u8" -> "u_8", "i128" -> "i_128"
|
||||
return f"{int_ty[0]}_{int_ty[1:]}"
|
||||
|
||||
|
||||
class IntTests:
|
||||
make_func = lambda int_ty: lambda self: self.do_test_autoinc(int_ty)
|
||||
for int_ty in ints:
|
||||
locals()[f"test_autoinc_{int_ty}"] = make_func(int_ty)
|
||||
del int_ty, make_func
|
||||
|
||||
|
||||
autoinc1_template = string.Template("""
|
||||
#[spacetimedb::table(accessor = person_$KEY_TY)]
|
||||
pub struct Person_$KEY_TY {
|
||||
#[auto_inc]
|
||||
key_col: $KEY_TY,
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn add_$REDUCER_TY(ctx: &ReducerContext, name: String, expected_value: $KEY_TY) {
|
||||
let value = ctx.db.person_$KEY_TY().insert(Person_$KEY_TY { key_col: 0, name });
|
||||
assert_eq!(value.key_col, expected_value);
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn say_hello_$REDUCER_TY(ctx: &ReducerContext) {
|
||||
for person in ctx.db.person_$KEY_TY().iter() {
|
||||
log::info!("Hello, {}:{}!", person.key_col, person.name);
|
||||
}
|
||||
log::info!("Hello, World!");
|
||||
}
|
||||
""")
|
||||
|
||||
|
||||
class AutoincBasic(IntTests, Smoketest):
|
||||
"This tests the auto_inc functionality"
|
||||
|
||||
MODULE_CODE = f"""
|
||||
#![allow(non_camel_case_types)]
|
||||
use spacetimedb::{{log, ReducerContext, Table}};
|
||||
{"".join(
|
||||
autoinc1_template.substitute(
|
||||
KEY_TY=int_ty,
|
||||
REDUCER_TY=reducer_name(int_ty),
|
||||
)
|
||||
for int_ty in ints
|
||||
)}
|
||||
"""
|
||||
|
||||
def do_test_autoinc(self, int_ty):
|
||||
r = reducer_name(int_ty)
|
||||
self.call(f"add_{r}", "Robert", 1)
|
||||
self.call(f"add_{r}", "Julie", 2)
|
||||
self.call(f"add_{r}", "Samantha", 3)
|
||||
self.call(f"say_hello_{r}")
|
||||
logs = self.logs(4)
|
||||
self.assertIn("Hello, 3:Samantha!", logs)
|
||||
self.assertIn("Hello, 2:Julie!", logs)
|
||||
self.assertIn("Hello, 1:Robert!", logs)
|
||||
self.assertIn("Hello, World!", logs)
|
||||
|
||||
|
||||
autoinc2_template = string.Template("""
|
||||
#[spacetimedb::table(accessor = person_$KEY_TY)]
|
||||
pub struct Person_$KEY_TY {
|
||||
#[auto_inc]
|
||||
#[unique]
|
||||
key_col: $KEY_TY,
|
||||
#[unique]
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn add_new_$REDUCER_TY(ctx: &ReducerContext, name: String) -> Result<(), Box<dyn Error>> {
|
||||
let value = ctx.db.person_$KEY_TY().try_insert(Person_$KEY_TY { key_col: 0, name })?;
|
||||
log::info!("Assigned Value: {} -> {}", value.key_col, value.name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn update_$REDUCER_TY(ctx: &ReducerContext, name: String, new_id: $KEY_TY) {
|
||||
ctx.db.person_$KEY_TY().name().delete(&name);
|
||||
let _value = ctx.db.person_$KEY_TY().insert(Person_$KEY_TY { key_col: new_id, name });
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn say_hello_$REDUCER_TY(ctx: &ReducerContext) {
|
||||
for person in ctx.db.person_$KEY_TY().iter() {
|
||||
log::info!("Hello, {}:{}!", person.key_col, person.name);
|
||||
}
|
||||
log::info!("Hello, World!");
|
||||
}
|
||||
""")
|
||||
|
||||
|
||||
class AutoincUnique(IntTests, Smoketest):
|
||||
"""This tests unique constraints being violated during autoinc insertion"""
|
||||
|
||||
MODULE_CODE = f"""
|
||||
#![allow(non_camel_case_types)]
|
||||
use std::error::Error;
|
||||
use spacetimedb::{{log, ReducerContext, Table}};
|
||||
{"".join(
|
||||
autoinc2_template.substitute(
|
||||
KEY_TY=int_ty,
|
||||
REDUCER_TY=reducer_name(int_ty),
|
||||
)
|
||||
for int_ty in ints
|
||||
)}
|
||||
"""
|
||||
|
||||
def do_test_autoinc(self, int_ty):
|
||||
r = reducer_name(int_ty)
|
||||
self.call(f"update_{r}", "Robert", 2)
|
||||
self.call(f"add_new_{r}", "Success")
|
||||
with self.assertRaises(Exception):
|
||||
self.call(f"add_new_{r}", "Failure")
|
||||
|
||||
self.call(f"say_hello_{r}")
|
||||
logs = self.logs(4)
|
||||
self.assertIn("Hello, 2:Robert!", logs)
|
||||
self.assertIn("Hello, 1:Success!", logs)
|
||||
self.assertIn("Hello, World!", logs)
|
||||
@@ -1,361 +0,0 @@
|
||||
from .. import Smoketest
|
||||
import sys
|
||||
import logging
|
||||
|
||||
|
||||
class AddTableAutoMigration(Smoketest):
|
||||
MODULE_CODE_INIT = """
|
||||
use spacetimedb::{log, ReducerContext, Table, SpacetimeType};
|
||||
use PersonKind::*;
|
||||
|
||||
#[spacetimedb::table(accessor = person, public)]
|
||||
pub struct Person {
|
||||
name: String,
|
||||
kind: PersonKind,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn add_person(ctx: &ReducerContext, name: String, kind: String) {
|
||||
let kind = kind_from_string(kind);
|
||||
ctx.db.person().insert(Person { name, kind });
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn print_persons(ctx: &ReducerContext, prefix: String) {
|
||||
for person in ctx.db.person().iter() {
|
||||
let kind = kind_to_string(person.kind);
|
||||
log::info!("{prefix}: {} - {kind}", person.name);
|
||||
}
|
||||
}
|
||||
|
||||
#[spacetimedb::table(accessor = point_mass)]
|
||||
pub struct PointMass {
|
||||
mass: f64,
|
||||
/// This used to cause an error when check_compatible did not resolve types in a `ModuleDef`.
|
||||
position: Vector2,
|
||||
}
|
||||
|
||||
#[derive(SpacetimeType, Clone, Copy)]
|
||||
pub struct Vector2 {
|
||||
x: f64,
|
||||
y: f64,
|
||||
}
|
||||
"""
|
||||
|
||||
MODULE_CODE = MODULE_CODE_INIT + """
|
||||
|
||||
#[spacetimedb::table(accessor = person_info)]
|
||||
pub struct PersonInfo {
|
||||
#[primary_key]
|
||||
id: u64,
|
||||
}
|
||||
|
||||
#[derive(SpacetimeType, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum PersonKind {
|
||||
Student,
|
||||
}
|
||||
|
||||
fn kind_from_string(_: String) -> PersonKind {
|
||||
Student
|
||||
}
|
||||
|
||||
fn kind_to_string(Student: PersonKind) -> &'static str {
|
||||
"Student"
|
||||
}
|
||||
"""
|
||||
|
||||
MODULE_CODE_UPDATED = (
|
||||
MODULE_CODE_INIT
|
||||
+ """
|
||||
|
||||
#[spacetimedb::table(accessor = person_info)]
|
||||
pub struct PersonInfo {
|
||||
#[primary_key]
|
||||
#[auto_inc]
|
||||
id: u64,
|
||||
}
|
||||
|
||||
#[derive(SpacetimeType, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum PersonKind {
|
||||
Student,
|
||||
Professor,
|
||||
}
|
||||
|
||||
fn kind_from_string(kind: String) -> PersonKind {
|
||||
match &*kind {
|
||||
"Student" => Student,
|
||||
"Professor" => Professor,
|
||||
_ => panic!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn kind_to_string(kind: PersonKind) -> &'static str {
|
||||
match kind {
|
||||
Student => "Student",
|
||||
Professor => "Professor",
|
||||
}
|
||||
}
|
||||
|
||||
#[spacetimedb::table(accessor = book, public)]
|
||||
pub struct Book {
|
||||
isbn: String,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn add_book(ctx: &ReducerContext, isbn: String) {
|
||||
ctx.db.book().insert(Book { isbn });
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn print_books(ctx: &ReducerContext, prefix: String) {
|
||||
for book in ctx.db.book().iter() {
|
||||
log::info!("{}: {}", prefix, book.isbn);
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
def test_add_table_auto_migration(self):
|
||||
"""This tests uploading a module with a schema change that should not require clearing the database."""
|
||||
logging.info("Initial publish complete")
|
||||
|
||||
# Start a subscription before publishing the module, to test that the subscription remains intact after re-publishing.
|
||||
sub = self.subscribe("select * from person", n=4, confirmed=False)
|
||||
|
||||
# initial module code is already published by test framework
|
||||
self.call("add_person", "Robert", "Student")
|
||||
self.call("add_person", "Julie", "Student")
|
||||
self.call("add_person", "Samantha", "Student")
|
||||
self.call("print_persons", "BEFORE")
|
||||
logs = self.logs(100)
|
||||
self.assertIn("BEFORE: Samantha - Student", logs)
|
||||
self.assertIn("BEFORE: Julie - Student", logs)
|
||||
self.assertIn("BEFORE: Robert - Student", logs)
|
||||
|
||||
logging.info(
|
||||
"Initial operations complete, updating module without clear",
|
||||
)
|
||||
|
||||
self.write_module_code(self.MODULE_CODE_UPDATED)
|
||||
self.publish_module(self.database_identity, clear=False)
|
||||
|
||||
logging.info("Updated")
|
||||
self.call("add_person", "Husserl", "Student")
|
||||
|
||||
# If subscription, we should get 4 rows corresponding to 4 reducer calls (including before and after update)
|
||||
sub = sub();
|
||||
self.assertEqual(len(sub), 4)
|
||||
|
||||
self.logs(100)
|
||||
|
||||
self.call("add_person", "Husserl", "Professor")
|
||||
self.call("add_book", "1234567890")
|
||||
self.call("print_persons", "AFTER_PERSON")
|
||||
self.call("print_books", "AFTER_BOOK")
|
||||
|
||||
logs = self.logs(100)
|
||||
self.assertIn("AFTER_PERSON: Samantha - Student", logs)
|
||||
self.assertIn("AFTER_PERSON: Julie - Student", logs)
|
||||
self.assertIn("AFTER_PERSON: Robert - Student", logs)
|
||||
self.assertIn("AFTER_PERSON: Husserl - Professor", logs)
|
||||
self.assertIn("AFTER_BOOK: 1234567890", logs)
|
||||
|
||||
|
||||
class RejectTableChanges(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{log, ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = person)]
|
||||
pub struct Person {
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn add_person(ctx: &ReducerContext, name: String) {
|
||||
ctx.db.person().insert(Person { name });
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn print_persons(ctx: &ReducerContext, prefix: String) {
|
||||
for person in ctx.db.person().iter() {
|
||||
log::info!("{}: {}", prefix, person.name);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
MODULE_CODE_UPDATED = """
|
||||
use spacetimedb::{log, ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = person)]
|
||||
pub struct Person {
|
||||
name: String,
|
||||
age: u128,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn add_person(ctx: &ReducerContext, name: String) {
|
||||
ctx.db.person().insert(Person { name, age: 70 });
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn print_persons(ctx: &ReducerContext, prefix: String) {
|
||||
for person in ctx.db.person().iter() {
|
||||
log::info!("{}: {}", prefix, person.name);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
def test_reject_schema_changes(self):
|
||||
"""This tests that a module with invalid schema changes cannot be published without -c or a migration."""
|
||||
|
||||
logging.info("Initial publish complete, trying to do an invalid update.")
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
self.write_module_code(self.MODULE_CODE_UPDATED)
|
||||
self.publish_module(self.database_identity, clear=False)
|
||||
|
||||
logging.info("Rejected as expected.")
|
||||
|
||||
class AddTableColumns(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{log, ReducerContext, Table};
|
||||
|
||||
#[derive(Debug)]
|
||||
#[spacetimedb::table(accessor = person)]
|
||||
pub struct Person {
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn add_person(ctx: &ReducerContext, name: String) {
|
||||
ctx.db.person().insert(Person { name });
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn print_persons(ctx: &ReducerContext, prefix: String) {
|
||||
for person in ctx.db.person().iter() {
|
||||
log::info!("{}: {}", prefix, person.name);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
MODULE_UPDATED = """
|
||||
use spacetimedb::{log, ReducerContext, Table};
|
||||
|
||||
#[derive(Debug)]
|
||||
#[spacetimedb::table(accessor = person)]
|
||||
pub struct Person {
|
||||
// Add indexes to verify they are handled correctly during migration,
|
||||
// issue #3441
|
||||
#[index(btree)]
|
||||
name: String,
|
||||
#[default(0)]
|
||||
age: u16,
|
||||
#[default(19)]
|
||||
mass: u16,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn add_person(ctx: &ReducerContext, name: String) {
|
||||
ctx.db.person().insert(Person { name, age: 70, mass: 180 });
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn print_persons(ctx: &ReducerContext, prefix: String) {
|
||||
for person in ctx.db.person().iter() {
|
||||
log::info!("{}: {:?}", prefix, person);
|
||||
}
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer(client_disconnected)]
|
||||
pub fn identity_disconnected(ctx: &ReducerContext) {
|
||||
log::info!("FIRST_UPDATE: client disconnected");
|
||||
}
|
||||
"""
|
||||
|
||||
MODULE_UPDATED_AGAIN = """
|
||||
use spacetimedb::{log, ReducerContext, Table};
|
||||
|
||||
#[derive(Debug)]
|
||||
#[spacetimedb::table(accessor = person)]
|
||||
pub struct Person {
|
||||
name: String,
|
||||
age: u16,
|
||||
#[default(19)]
|
||||
mass: u16,
|
||||
#[default(160)]
|
||||
height: u32,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn add_person(ctx: &ReducerContext, name: String) {
|
||||
ctx.db.person().insert(Person { name, age: 70, mass: 180, height: 72 });
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn print_persons(ctx: &ReducerContext, prefix: String) {
|
||||
for person in ctx.db.person().iter() {
|
||||
log::info!("{}: {:?}", prefix, person);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
def test_add_table_columns(self):
|
||||
"""Verify schema upgrades that add columns with defaults (twice)."""
|
||||
|
||||
# Subscribe to person table changes multiple times to simulate active clients
|
||||
NUM_SUBSCRIBERS = 20
|
||||
subs = [None] * NUM_SUBSCRIBERS
|
||||
for i in range(NUM_SUBSCRIBERS):
|
||||
# We need unconfirmed reads for the updates to arrive properly.
|
||||
# Otherwise, there's a race between module teardown in publish, vs subscribers
|
||||
# getting the row deletion they expect.
|
||||
subs[i]= self.subscribe("select * from person", n=5, confirmed=False)
|
||||
|
||||
# Insert under initial schema
|
||||
self.call("add_person", "Robert")
|
||||
|
||||
# First upgrade: add age & mass columns
|
||||
self.write_module_code(self.MODULE_UPDATED)
|
||||
self.publish_module(self.database_identity, clear=False, break_clients=True)
|
||||
self.call("print_persons", "FIRST_UPDATE")
|
||||
|
||||
logs1 = self.logs(100)
|
||||
|
||||
# Validate disconnect + schema migration logs
|
||||
self.assertIn("Disconnecting all users", logs1)
|
||||
self.assertIn(
|
||||
'FIRST_UPDATE: Person { name: "Robert", age: 0, mass: 19 }',
|
||||
logs1,
|
||||
)
|
||||
disconnect_count = logs1.count("FIRST_UPDATE: client disconnected")
|
||||
|
||||
# Insert new data under upgraded schema
|
||||
self.call("add_person", "Robert2")
|
||||
|
||||
self.assertEqual(
|
||||
disconnect_count,
|
||||
# +1 is due to reducer call above
|
||||
NUM_SUBSCRIBERS + 1,
|
||||
msg=f"Unexpected disconnect counts: {disconnect_count}",
|
||||
)
|
||||
|
||||
# Validate all subscribers were disconnected after first upgrade
|
||||
# they should 2 updates: one for initial insertion and one for table drop during migration
|
||||
for i in range(NUM_SUBSCRIBERS):
|
||||
sub = subs[i]()
|
||||
self.assertEqual(len(sub), 2, msg=f"Subscriber {i} received unexpected rows: {sub}")
|
||||
|
||||
|
||||
# Second upgrade
|
||||
self.write_module_code(self.MODULE_UPDATED_AGAIN)
|
||||
self.publish_module(self.database_identity, clear=False, break_clients=True)
|
||||
self.call("print_persons", "UPDATE_2")
|
||||
|
||||
logs2 = self.logs(100)
|
||||
|
||||
# Validate new schema with height
|
||||
self.assertIn(
|
||||
'UPDATE_2: Person { name: "Robert2", age: 70, mass: 180, height: 160 }',
|
||||
logs2,
|
||||
)
|
||||
@@ -1,151 +0,0 @@
|
||||
import string
|
||||
|
||||
from .. import Smoketest
|
||||
|
||||
class CallReducerProcedure(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{log, ProcedureContext, ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = person)]
|
||||
pub struct Person {
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn say_hello(_ctx: &ReducerContext) {
|
||||
log::info!("Hello, World!");
|
||||
}
|
||||
|
||||
#[spacetimedb::procedure]
|
||||
pub fn return_person(_ctx: &mut ProcedureContext) -> Person {
|
||||
return Person { name: "World".to_owned() };
|
||||
}
|
||||
"""
|
||||
|
||||
def test_call_reducer_procedure(self):
|
||||
"""Check calling a reducer (not return) and procedure (return)"""
|
||||
msg = self.call("say_hello")
|
||||
self.assertEqual(msg, "")
|
||||
|
||||
msg = self.call("return_person")
|
||||
self.assertEqual(msg.strip(), '["World"]')
|
||||
|
||||
def test_call_errors(self):
|
||||
"""Check calling a non-existent reducer/procedure raises error"""
|
||||
out = self.call("non_existent_reducer", check= False, full_output=True).stderr
|
||||
identity = self.database_identity
|
||||
self.assertIn(out.strip(), f"""
|
||||
WARNING: This command is UNSTABLE and subject to breaking changes.
|
||||
|
||||
Error: No such reducer OR procedure `non_existent_reducer` for database `{identity}` resolving to identity `{identity}`.
|
||||
|
||||
Here are some existing reducers:
|
||||
- say_hello
|
||||
|
||||
Here are some existing procedures:
|
||||
- return_person""".strip())
|
||||
|
||||
out = self.call("non_existent_procedure", check= False, full_output=True).stderr
|
||||
self.assertIn(out.strip(), f"""
|
||||
WARNING: This command is UNSTABLE and subject to breaking changes.
|
||||
|
||||
Error: No such reducer OR procedure `non_existent_procedure` for database `{identity}` resolving to identity `{identity}`.
|
||||
|
||||
Here are some existing reducers:
|
||||
- say_hello
|
||||
|
||||
Here are some existing procedures:
|
||||
- return_person""".strip())
|
||||
|
||||
out = self.call("say_hell", check= False, full_output=True).stderr
|
||||
|
||||
self.assertIn(out.strip(), f"""
|
||||
WARNING: This command is UNSTABLE and subject to breaking changes.
|
||||
|
||||
Error: No such reducer OR procedure `say_hell` for database `{identity}` resolving to identity `{identity}`.
|
||||
|
||||
A reducer with a similar name exists: `say_hello`""".strip())
|
||||
|
||||
out = self.call("return_perso", check= False, full_output=True).stderr
|
||||
|
||||
self.assertIn(out.strip(), f"""
|
||||
WARNING: This command is UNSTABLE and subject to breaking changes.
|
||||
|
||||
Error: No such reducer OR procedure `return_perso` for database `{identity}` resolving to identity `{identity}`.
|
||||
|
||||
A procedure with a similar name exists: `return_person`""".strip())
|
||||
|
||||
class CallEmptyReducerProcedure(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{log, ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = person)]
|
||||
pub struct Person {
|
||||
name: String,
|
||||
}
|
||||
"""
|
||||
def test_call_empty_errors(self):
|
||||
"""Check calling into a database with no reducers/procedures raises error"""
|
||||
out = self.call("non_existent", check= False, full_output=True).stderr
|
||||
identity = self.database_identity
|
||||
self.assertIn(out.strip(), f"""
|
||||
WARNING: This command is UNSTABLE and subject to breaking changes.
|
||||
|
||||
Error: No such reducer OR procedure `non_existent` for database `{identity}` resolving to identity `{identity}`.
|
||||
|
||||
The database has no reducers.
|
||||
|
||||
The database has no procedures.""".strip())
|
||||
|
||||
module_template = string.Template("""
|
||||
#[spacetimedb::reducer]
|
||||
pub fn say_reducer_$NUM(_ctx: &ReducerContext) {
|
||||
log::info!("Hello from reducer $NUM!");
|
||||
}
|
||||
#[spacetimedb::procedure]
|
||||
pub fn say_procedure_$NUM(_ctx: &mut ProcedureContext) {
|
||||
log::info!("Hello from procedure $NUM!");
|
||||
}
|
||||
""")
|
||||
|
||||
class CallManyReducerProcedure(Smoketest):
|
||||
MODULE_CODE = f"""
|
||||
use spacetimedb::{{log, ProcedureContext, ReducerContext}};
|
||||
{"".join(module_template.substitute(NUM=i) for i in range(11))}
|
||||
"""
|
||||
|
||||
def test_call_many_errors(self):
|
||||
"""Check calling into a database with many reducers/procedures raises error with listing"""
|
||||
out = self.call("non_existent", check= False, full_output=True).stderr
|
||||
identity = self.database_identity
|
||||
self.assertIn(out.strip(), f"""
|
||||
WARNING: This command is UNSTABLE and subject to breaking changes.
|
||||
|
||||
Error: No such reducer OR procedure `non_existent` for database `{identity}` resolving to identity `{identity}`.
|
||||
|
||||
Here are some existing reducers:
|
||||
- say_reducer_0
|
||||
- say_reducer_1
|
||||
- say_reducer_2
|
||||
- say_reducer_3
|
||||
- say_reducer_4
|
||||
- say_reducer_5
|
||||
- say_reducer_6
|
||||
- say_reducer_7
|
||||
- say_reducer_8
|
||||
- say_reducer_9
|
||||
... (1 reducer not shown)
|
||||
|
||||
Here are some existing procedures:
|
||||
- say_procedure_0
|
||||
- say_procedure_1
|
||||
- say_procedure_2
|
||||
- say_procedure_3
|
||||
- say_procedure_4
|
||||
- say_procedure_5
|
||||
- say_procedure_6
|
||||
- say_procedure_7
|
||||
- say_procedure_8
|
||||
- say_procedure_9
|
||||
... (1 procedure not shown)
|
||||
""".strip())
|
||||
@@ -1,93 +0,0 @@
|
||||
from .. import Smoketest, random_string
|
||||
import time
|
||||
|
||||
|
||||
class ClearDatabase(Smoketest):
|
||||
AUTOPUBLISH = False
|
||||
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{ReducerContext, Table, duration};
|
||||
|
||||
#[spacetimedb::table(accessor = counter, public)]
|
||||
pub struct Counter {
|
||||
#[primary_key]
|
||||
id: u64,
|
||||
val: u64
|
||||
}
|
||||
|
||||
#[spacetimedb::table(accessor = scheduled_counter, public, scheduled(inc, at = sched_at))]
|
||||
pub struct ScheduledCounter {
|
||||
#[primary_key]
|
||||
#[auto_inc]
|
||||
scheduled_id: u64,
|
||||
sched_at: spacetimedb::ScheduleAt,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn inc(ctx: &ReducerContext, arg: ScheduledCounter) {
|
||||
if let Some(mut counter) = ctx.db.counter().id().find(arg.scheduled_id) {
|
||||
counter.val += 1;
|
||||
ctx.db.counter().id().update(counter);
|
||||
} else {
|
||||
ctx.db.counter().insert(Counter {
|
||||
id: arg.scheduled_id,
|
||||
val: 1,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer(init)]
|
||||
pub fn init(ctx: &ReducerContext) {
|
||||
ctx.db.scheduled_counter().insert(ScheduledCounter {
|
||||
scheduled_id: 0,
|
||||
sched_at: duration!(100ms).into(),
|
||||
});
|
||||
}
|
||||
"""
|
||||
|
||||
def test_publish_clear_database(self):
|
||||
"""
|
||||
Test that publishing with the clear flag stops the old module.
|
||||
This relies on private control database internals.
|
||||
"""
|
||||
|
||||
name = random_string()
|
||||
|
||||
# Initial publish
|
||||
replicas_1 = self.publish(name, clear = False)
|
||||
self.assertTrue(len(replicas_1) >= 1)
|
||||
|
||||
# Publish with clear = True, destroying `replicas_1`
|
||||
replicas_2 = self.publish(name, clear = True)
|
||||
self.assertTrue(len(replicas_2) >= 1)
|
||||
self.assertNotEqual(replicas_1, replicas_2)
|
||||
|
||||
# Delete the replicas created in the second publish
|
||||
self.spacetime("delete", name)
|
||||
|
||||
# State updates don't happen instantly
|
||||
time.sleep(0.25)
|
||||
|
||||
# Check that all replicas have state `Deleted`
|
||||
replicas = replicas_1 + replicas_2
|
||||
state_filter = f'replica_id = {" OR replica_id = ".join(replicas)}'
|
||||
states = self.query_control(f"select lifecycle from replica_state where {state_filter}")
|
||||
self.assertEqual(len(states), len(replicas))
|
||||
self.assertTrue(all([x == "(Deleted = ())" for x in states]))
|
||||
|
||||
def publish(self, name, clear):
|
||||
self.publish_module(name, clear = clear)
|
||||
replicas = self.query_control(f"""
|
||||
select replica.id from replica
|
||||
join database on database.id = replica.database_id
|
||||
where database.database_identity = '0x{self.resolved_identity}'
|
||||
""")
|
||||
|
||||
return replicas
|
||||
|
||||
|
||||
def query_control(self, sql):
|
||||
out = self.spacetime("sql", "spacetime-control", sql)
|
||||
out = [line.strip() for line in out.splitlines()]
|
||||
out = out[2:] # Remove header
|
||||
return out
|
||||
@@ -1,70 +0,0 @@
|
||||
from .. import Smoketest
|
||||
|
||||
MODULE_HEADER = """
|
||||
use spacetimedb::{ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = all_u8s, public)]
|
||||
pub struct AllU8s {
|
||||
number: u8,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer(init)]
|
||||
pub fn init(ctx: &ReducerContext) {
|
||||
// Here's a bunch of data that no one will be able to subscribe to.
|
||||
for i in u8::MIN..=u8::MAX {
|
||||
ctx.db.all_u8s().insert(AllU8s { number: i });
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
class ClientConnectedErrorRejectsConnection(Smoketest):
|
||||
MODULE_CODE = MODULE_HEADER + """
|
||||
|
||||
#[spacetimedb::reducer(client_connected)]
|
||||
pub fn identity_connected(ctx: &ReducerContext) -> Result<(), String> {
|
||||
Err("Rejecting connection from client".to_string())
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer(client_disconnected)]
|
||||
pub fn identity_disconnected(_ctx: &ReducerContext) {
|
||||
panic!("This should never be called, since we reject all connections!")
|
||||
}
|
||||
"""
|
||||
|
||||
def test_client_connected_error_rejects_connection(self):
|
||||
with self.assertRaises(Exception):
|
||||
self.subscribe("select * from all_u8s", n = 0)()
|
||||
|
||||
logs = self.logs(100)
|
||||
self.assertIn('Rejecting connection from client', logs)
|
||||
self.assertNotIn('This should never be called, since we reject all connections!', logs)
|
||||
|
||||
class ClientDisconnectedErrorStillDeletesStClient(Smoketest):
|
||||
MODULE_CODE = MODULE_HEADER + """
|
||||
#[spacetimedb::reducer(client_connected)]
|
||||
pub fn identity_connected(_ctx: &ReducerContext) -> Result<(), String> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer(client_disconnected)]
|
||||
pub fn identity_disconnected(_ctx: &ReducerContext) {
|
||||
panic!("This should be called, but the `st_client` row should still be deleted")
|
||||
}
|
||||
"""
|
||||
|
||||
def test_client_disconnected_error_still_deletes_st_client(self):
|
||||
self.subscribe("select * from all_u8s", n = 0)()
|
||||
|
||||
logs = self.logs(100)
|
||||
self.assertIn('This should be called, but the `st_client` row should still be deleted', logs)
|
||||
|
||||
sql_out = self.spacetime("sql", self.database_identity, "select * from st_client")
|
||||
|
||||
# The SQL query itself now creates a temporary connection, so we may
|
||||
# see exactly one row (the SQL connection's own). The websocket's row
|
||||
# should be gone. Count non-header, non-separator lines with content.
|
||||
lines = sql_out.strip().split('\n')
|
||||
# Data rows are those that are not the header and not the separator line
|
||||
data_rows = [l for l in lines if '|' in l and '-+-' not in l and 'identity' not in l.lower()]
|
||||
self.assertLessEqual(len(data_rows), 1,
|
||||
f"Expected at most 1 st_client row (the SQL connection itself), got: {sql_out}")
|
||||
@@ -1,54 +0,0 @@
|
||||
from .. import Smoketest, parse_sql_result
|
||||
|
||||
#
|
||||
# TODO: We only test that we can pass a --confirmed flag and that things
|
||||
# appear to works as if we hadn't. Without controlling the server, we can't
|
||||
# test that there is any difference in behavior.
|
||||
#
|
||||
|
||||
class ConfirmedReads(Smoketest):
|
||||
def test_confirmed_reads_receive_updates(self):
|
||||
"""Tests that subscribing with confirmed=true receives updates"""
|
||||
|
||||
sub = self.subscribe("select * from person", n = 2, confirmed = True)
|
||||
self.call("add", "Horst")
|
||||
self.spacetime(
|
||||
"sql",
|
||||
self.database_identity,
|
||||
"insert into person (name) values ('Egon')")
|
||||
|
||||
events = sub()
|
||||
self.assertEqual([
|
||||
{
|
||||
'person': {
|
||||
'deletes': [],
|
||||
'inserts': [{'name': 'Horst'}]
|
||||
}
|
||||
},
|
||||
{
|
||||
'person': {
|
||||
'deletes': [],
|
||||
'inserts': [{'name': 'Egon'}]
|
||||
}
|
||||
}
|
||||
], events)
|
||||
|
||||
class ConfirmedReadsSql(Smoketest):
|
||||
def test_sql_with_confirmed_reads_receives_result(self):
|
||||
"""Tests that an SQL operations with confirmed=true returns a result"""
|
||||
|
||||
self.spacetime(
|
||||
"sql",
|
||||
"--confirmed",
|
||||
"true",
|
||||
self.database_identity,
|
||||
"insert into person (name) values ('Horst')")
|
||||
|
||||
res = self.spacetime(
|
||||
"sql",
|
||||
"--confirmed",
|
||||
"true",
|
||||
self.database_identity,
|
||||
"select * from person")
|
||||
res = parse_sql_result(str(res))
|
||||
self.assertEqual([{'name': '"Horst"'}], res)
|
||||
@@ -1,33 +0,0 @@
|
||||
from .. import Smoketest
|
||||
|
||||
class ConnDisconnFromCli(Smoketest):
|
||||
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{log, ReducerContext};
|
||||
|
||||
#[spacetimedb::reducer(client_connected)]
|
||||
pub fn connected(_ctx: &ReducerContext) {
|
||||
log::info!("_connect called");
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer(client_disconnected)]
|
||||
pub fn disconnected(_ctx: &ReducerContext) {
|
||||
log::info!("disconnect called");
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn say_hello(_ctx: &ReducerContext) {
|
||||
log::info!("Hello, World!");
|
||||
}
|
||||
"""
|
||||
|
||||
def test_conn_disconn(self):
|
||||
"""
|
||||
Ensure that the connect and disconnect functions are called when invoking a reducer from the CLI
|
||||
"""
|
||||
|
||||
self.call("say_hello")
|
||||
logs = self.logs(10)
|
||||
self.assertIn('_connect called', logs)
|
||||
self.assertIn('disconnect called', logs)
|
||||
self.assertIn('Hello, World!', logs)
|
||||
@@ -1,18 +0,0 @@
|
||||
from .. import spacetime
|
||||
import unittest
|
||||
import tempfile
|
||||
|
||||
class CreateProject(unittest.TestCase):
|
||||
def test_create_project(self):
|
||||
"""
|
||||
Ensure that the CLI is able to create a local project. This test does not depend on a running spacetimedb instance.
|
||||
"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
with self.assertRaises(Exception):
|
||||
spacetime("init", "--non-interactive", "test-project")
|
||||
with self.assertRaises(Exception):
|
||||
spacetime("init", "--non-interactive", "--project-path", tmpdir, "test-project")
|
||||
spacetime("init", "--non-interactive", "--lang=rust", "--project-path", tmpdir, "test-project")
|
||||
with self.assertRaises(Exception):
|
||||
spacetime("init", "--non-interactive", "--lang=rust", "--project-path", tmpdir, "test-project")
|
||||
@@ -1,81 +0,0 @@
|
||||
from .. import run_cmd, STDB_DIR, requires_dotnet, spacetime
|
||||
import unittest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import subprocess
|
||||
import xml.etree.ElementTree as xml
|
||||
|
||||
|
||||
@requires_dotnet
|
||||
class CreateProject(unittest.TestCase):
|
||||
def test_build_csharp_module(self):
|
||||
"""
|
||||
Ensure that the CLI is able to create and compile a csharp project. This test does not depend on a running spacetimedb instance. Skips if dotnet 8.0 is not available
|
||||
"""
|
||||
|
||||
bindings = Path(STDB_DIR) / "crates" / "bindings-csharp"
|
||||
|
||||
try:
|
||||
|
||||
run_cmd("dotnet", "nuget", "locals", "all", "--clear", cwd=bindings, capture_stderr=True)
|
||||
run_cmd("dotnet", "workload", "install", "wasi-experimental", "--skip-manifest-update", cwd=STDB_DIR / "modules")
|
||||
run_cmd("dotnet", "pack", cwd=bindings, capture_stderr=True)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
spacetime(
|
||||
"init",
|
||||
"--non-interactive",
|
||||
"--lang=csharp",
|
||||
"--project-path",
|
||||
tmpdir,
|
||||
"csharp-project",
|
||||
)
|
||||
|
||||
server_path = Path(tmpdir) / "spacetimedb"
|
||||
|
||||
packed_projects = ["BSATN.Runtime", "Runtime"]
|
||||
|
||||
config = xml.Element("configuration")
|
||||
|
||||
sources = xml.SubElement(config, "packageSources")
|
||||
mappings = xml.SubElement(config, "packageSourceMapping")
|
||||
|
||||
def add_mapping(source, pattern):
|
||||
mapping = xml.SubElement(mappings, "packageSource", key=source)
|
||||
xml.SubElement(mapping, "package", pattern=pattern)
|
||||
|
||||
for project in packed_projects:
|
||||
# Add local build directories as NuGet repositories.
|
||||
path = bindings / project / "bin" / "Release"
|
||||
project = f"SpacetimeDB.{project}"
|
||||
xml.SubElement(sources, "add", key=project, value=str(path))
|
||||
|
||||
# Add strict package source mappings to ensure that
|
||||
# SpacetimeDB.* packages are used from those directories
|
||||
# and never from nuget.org.
|
||||
#
|
||||
# This prevents bugs where we silently used an outdated
|
||||
# version which led to tests passing when they shouldn't.
|
||||
add_mapping(project, project)
|
||||
|
||||
# Add fallback for other packages.
|
||||
add_mapping("nuget.org", "*")
|
||||
|
||||
xml.indent(config)
|
||||
config = xml.tostring(config, encoding="unicode", xml_declaration=True)
|
||||
|
||||
print("Writing `nuget.config` contents:")
|
||||
print(config)
|
||||
|
||||
config_path = server_path / "nuget.config"
|
||||
with open(config_path, "w") as f:
|
||||
f.write(config)
|
||||
|
||||
run_cmd("dotnet", "publish", cwd=server_path, capture_stderr=True)
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e)
|
||||
print("output:")
|
||||
print(e.output)
|
||||
raise e
|
||||
@@ -1,12 +0,0 @@
|
||||
import unittest
|
||||
import tempfile
|
||||
import subprocess
|
||||
from .. import Smoketest
|
||||
|
||||
class ClippyDefaultModule(Smoketest):
|
||||
AUTOPUBLISH = False
|
||||
|
||||
def test_default_module_clippy_check(self):
|
||||
"""Ensure that the default rust module has no clippy errors or warnings"""
|
||||
|
||||
subprocess.check_call(["cargo", "clippy", "--", "-Dwarnings"], cwd=self.project_path)
|
||||
@@ -1,63 +0,0 @@
|
||||
from .. import Smoketest, random_string
|
||||
import time
|
||||
|
||||
class DeleteDatabase(Smoketest):
|
||||
AUTOPUBLISH = False
|
||||
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{ReducerContext, Table, duration};
|
||||
|
||||
#[spacetimedb::table(accessor = counter, public)]
|
||||
pub struct Counter {
|
||||
#[primary_key]
|
||||
id: u64,
|
||||
val: u64
|
||||
}
|
||||
|
||||
#[spacetimedb::table(accessor = scheduled_counter, public, scheduled(inc, at = sched_at))]
|
||||
pub struct ScheduledCounter {
|
||||
#[primary_key]
|
||||
#[auto_inc]
|
||||
scheduled_id: u64,
|
||||
sched_at: spacetimedb::ScheduleAt,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn inc(ctx: &ReducerContext, arg: ScheduledCounter) {
|
||||
if let Some(mut counter) = ctx.db.counter().id().find(arg.scheduled_id) {
|
||||
counter.val += 1;
|
||||
ctx.db.counter().id().update(counter);
|
||||
} else {
|
||||
ctx.db.counter().insert(Counter {
|
||||
id: arg.scheduled_id,
|
||||
val: 1,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer(init)]
|
||||
pub fn init(ctx: &ReducerContext) {
|
||||
ctx.db.scheduled_counter().insert(ScheduledCounter {
|
||||
scheduled_id: 0,
|
||||
sched_at: duration!(100ms).into(),
|
||||
});
|
||||
}
|
||||
"""
|
||||
|
||||
def test_delete_database(self):
|
||||
"""
|
||||
Test that deleting a database stops the module.
|
||||
The module is considered stopped if its scheduled reducer stops
|
||||
producing update events.
|
||||
"""
|
||||
|
||||
name = random_string()
|
||||
self.publish_module(name, clear = False)
|
||||
sub = self.subscribe("select * from counter", n = 1000)
|
||||
time.sleep(2)
|
||||
self.spacetime("delete", "--yes", name)
|
||||
|
||||
updates = sub()
|
||||
# At a rate of 100ms, we shouldn't have more than 20 updates in 2secs.
|
||||
# But let's say 50, in case the delete gets delayed for some reason.
|
||||
assert len(updates) <= 50
|
||||
@@ -1,31 +0,0 @@
|
||||
from .. import Smoketest
|
||||
|
||||
class ModuleDescription(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{log, ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = person)]
|
||||
pub struct Person {
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn add(ctx: &ReducerContext, name: String) {
|
||||
ctx.db.person().insert(Person { name });
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn say_hello(ctx: &ReducerContext) {
|
||||
for person in ctx.db.person().iter() {
|
||||
log::info!("Hello, {}!", person.name);
|
||||
}
|
||||
log::info!("Hello, World!");
|
||||
}
|
||||
"""
|
||||
|
||||
def test_describe(self):
|
||||
"""Check describing a module"""
|
||||
|
||||
self.spacetime("describe", "--json", self.database_identity)
|
||||
self.spacetime("describe", "--json", self.database_identity, "reducer", "say_hello")
|
||||
self.spacetime("describe", "--json", self.database_identity, "table", "person")
|
||||
@@ -1,44 +0,0 @@
|
||||
from .. import Smoketest
|
||||
|
||||
class WasmBindgen(Smoketest):
|
||||
AUTOPUBLISH = False
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{log, ReducerContext};
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn test(_ctx: &ReducerContext) {
|
||||
log::info!("Hello! {}", now());
|
||||
}
|
||||
|
||||
#[wasm_bindgen::prelude::wasm_bindgen]
|
||||
extern "C" {
|
||||
fn now() -> i32;
|
||||
}
|
||||
"""
|
||||
EXTRA_DEPS = 'wasm-bindgen = "0.2"'
|
||||
|
||||
def test_detect_wasm_bindgen(self):
|
||||
"""Ensure that spacetime build properly catches wasm_bindgen imports"""
|
||||
|
||||
output = self.spacetime("build", "--module-path", self.project_path, full_output=True, check=False)
|
||||
self.assertTrue(output.returncode)
|
||||
self.assertIn("wasm-bindgen detected", output.stderr)
|
||||
|
||||
class Getrandom(Smoketest):
|
||||
AUTOPUBLISH = False
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{log, ReducerContext};
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn test(_ctx: &ReducerContext) {
|
||||
log::info!("Hello! {}", rand::random::<u8>());
|
||||
}
|
||||
"""
|
||||
EXTRA_DEPS = 'rand = "0.8"'
|
||||
|
||||
def test_detect_getrandom(self):
|
||||
"""Ensure that spacetime build properly catches getrandom"""
|
||||
|
||||
output = self.spacetime("build", "--module-path", self.project_path, full_output=True, check=False)
|
||||
self.assertTrue(output.returncode)
|
||||
self.assertIn("getrandom usage detected", output.stderr)
|
||||
@@ -1,28 +0,0 @@
|
||||
from .. import Smoketest
|
||||
|
||||
class Dml(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = t, public)]
|
||||
pub struct T {
|
||||
name: String,
|
||||
}
|
||||
"""
|
||||
|
||||
def test_subscribe(self):
|
||||
"""Test that we receive subscription updates from DML"""
|
||||
|
||||
# Subscribe to `t`
|
||||
sub = self.subscribe("SELECT * FROM t", n=2)
|
||||
|
||||
self.spacetime("sql", self.database_identity, "INSERT INTO t (name) VALUES ('Alice')")
|
||||
self.spacetime("sql", self.database_identity, "INSERT INTO t (name) VALUES ('Bob')")
|
||||
|
||||
self.assertEqual(
|
||||
sub(),
|
||||
[
|
||||
{"t": {"deletes": [], "inserts": [{"name": "Alice"}]}},
|
||||
{"t": {"deletes": [], "inserts": [{"name": "Bob"}]}},
|
||||
],
|
||||
)
|
||||
@@ -1,84 +0,0 @@
|
||||
from .. import Smoketest, random_string
|
||||
import json
|
||||
|
||||
class Domains(Smoketest):
|
||||
AUTOPUBLISH = False
|
||||
|
||||
def test_set_name(self):
|
||||
"""Tests the functionality of the set-name command"""
|
||||
|
||||
orig_name = random_string()
|
||||
self.publish_module(orig_name)
|
||||
|
||||
rand_name = random_string()
|
||||
|
||||
# This should throw an exception before there's a db with this name
|
||||
with self.assertRaises(Exception):
|
||||
self.spacetime("logs", rand_name)
|
||||
|
||||
self.spacetime("rename", "--to", rand_name, self.database_identity)
|
||||
|
||||
# Now we're essentially just testing that it *doesn't* throw an exception
|
||||
self.spacetime("logs", rand_name)
|
||||
|
||||
# This should throw an exception because the original name shouldn't exist anymore
|
||||
with self.assertRaises(Exception):
|
||||
self.spacetime("logs", orig_name)
|
||||
|
||||
def test_subdomain_behavior(self):
|
||||
"""Test how we treat the / character in published names"""
|
||||
|
||||
root_name = random_string()
|
||||
self.publish_module(root_name)
|
||||
|
||||
# TODO: This is valid in editions with the teams feature, but
|
||||
# smoketests don't know the target's edition.
|
||||
# self.publish_module(f"{root_name}/test")
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
self.publish_module(f"{root_name}//test")
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
self.publish_module(f"{root_name}/test/")
|
||||
|
||||
def test_set_to_existing_name(self):
|
||||
"""Test that we can't rename to a name already in use"""
|
||||
|
||||
self.publish_module()
|
||||
id_to_rename = self.database_identity
|
||||
|
||||
rename_to = random_string()
|
||||
self.publish_module(rename_to)
|
||||
|
||||
# This should throw an exception because there's a db with this name
|
||||
with self.assertRaises(Exception):
|
||||
self.spacetime("rename", "--to", rename_to, id_to_rename)
|
||||
|
||||
def test_replace_names(self):
|
||||
"""Test that we can rename to a list of names"""
|
||||
|
||||
orig_name = random_string()
|
||||
alt_name1 = random_string()
|
||||
alt_name2 = random_string()
|
||||
self.publish_module(orig_name)
|
||||
|
||||
self.api_call(
|
||||
"PUT",
|
||||
f'/v1/database/{orig_name}/names',
|
||||
json.dumps([alt_name1, alt_name2]),
|
||||
{"Content-type": "application/json"}
|
||||
)
|
||||
|
||||
# Use logs to check that name resolution works
|
||||
self.spacetime("logs", alt_name1)
|
||||
self.spacetime("logs", alt_name2)
|
||||
with self.assertRaises(Exception):
|
||||
self.spacetime("logs", orig_name)
|
||||
|
||||
# Restore orig name so the database gets deleted on clean up
|
||||
self.api_call(
|
||||
"PUT",
|
||||
f'/v1/database/{alt_name1}/names',
|
||||
json.dumps([orig_name]),
|
||||
{"Content-type": "application/json"}
|
||||
)
|
||||
@@ -1,75 +0,0 @@
|
||||
from .. import Smoketest, random_string
|
||||
import subprocess
|
||||
|
||||
class FailInitialPublish(Smoketest):
|
||||
AUTOPUBLISH = False
|
||||
|
||||
MODULE_CODE_BROKEN = """
|
||||
use spacetimedb::{client_visibility_filter, Filter};
|
||||
|
||||
#[spacetimedb::table(accessor = person, public)]
|
||||
pub struct Person {
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[client_visibility_filter]
|
||||
// Bug: `Person` is the wrong table name, should be `person`.
|
||||
const HIDE_PEOPLE_EXCEPT_ME: Filter = Filter::Sql("SELECT * FROM Person WHERE name = 'me'");
|
||||
"""
|
||||
|
||||
MODULE_CODE_FIXED = """
|
||||
use spacetimedb::{client_visibility_filter, Filter};
|
||||
|
||||
#[spacetimedb::table(accessor = person, public)]
|
||||
pub struct Person {
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[client_visibility_filter]
|
||||
const HIDE_PEOPLE_EXCEPT_ME: Filter = Filter::Sql("SELECT * FROM person WHERE name = 'me'");
|
||||
"""
|
||||
|
||||
FIXED_QUERY = '"sql": "SELECT * FROM person WHERE name = \'me\'"'
|
||||
|
||||
def test_fail_initial_publish(self):
|
||||
"""This tests that publishing an invalid module does not leave a broken entry in the control DB."""
|
||||
|
||||
name = random_string()
|
||||
|
||||
self.write_module_code(self.MODULE_CODE_BROKEN)
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
self.publish_module(name)
|
||||
|
||||
describe_output = self.spacetime("describe", "--json", name, full_output = True, check = False)
|
||||
|
||||
with self.assertRaises(subprocess.CalledProcessError):
|
||||
describe_output.check_returncode()
|
||||
|
||||
self.assertIn("Error: No such database.", describe_output.stderr)
|
||||
|
||||
# We can publish a fixed module under the same database name.
|
||||
# This used to be broken;
|
||||
# the failed initial publish would leave the control database in a bad state.
|
||||
self.write_module_code(self.MODULE_CODE_FIXED)
|
||||
|
||||
self.publish_module(name, clear = False)
|
||||
describe_output = self.spacetime("describe", "--json", name)
|
||||
|
||||
self.assertIn(
|
||||
self.FIXED_QUERY,
|
||||
[line.strip() for line in describe_output.splitlines()],
|
||||
)
|
||||
|
||||
# Publishing the broken code again fails, but the database still exists afterwards,
|
||||
# with the previous version of the module code.
|
||||
self.write_module_code(self.MODULE_CODE_BROKEN)
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
self.publish_module(name, clear = False)
|
||||
|
||||
describe_output = self.spacetime("describe", "--json", name)
|
||||
self.assertIn(
|
||||
self.FIXED_QUERY,
|
||||
[line.strip() for line in describe_output.splitlines()],
|
||||
)
|
||||
@@ -1,311 +0,0 @@
|
||||
from .. import Smoketest
|
||||
|
||||
class Filtering(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{log, Identity, ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = person)]
|
||||
pub struct Person {
|
||||
#[unique]
|
||||
id: i32,
|
||||
|
||||
name: String,
|
||||
|
||||
#[unique]
|
||||
nick: String,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn insert_person(ctx: &ReducerContext, id: i32, name: String, nick: String) {
|
||||
ctx.db.person().insert(Person { id, name, nick} );
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn insert_person_twice(ctx: &ReducerContext, id: i32, name: String, nick: String) {
|
||||
// We'd like to avoid an error due to a set-semantic error.
|
||||
let name2 = format!("{name}2");
|
||||
ctx.db.person().insert(Person { id, name, nick: nick.clone()} );
|
||||
match ctx.db.person().try_insert(Person { id, name: name2, nick: nick.clone()}) {
|
||||
Ok(_) => {},
|
||||
Err(_) => {
|
||||
log::info!("UNIQUE CONSTRAINT VIOLATION ERROR: id = {}, nick = {}", id, nick)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn delete_person(ctx: &ReducerContext, id: i32) {
|
||||
ctx.db.person().id().delete(&id);
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn find_person(ctx: &ReducerContext, id: i32) {
|
||||
match ctx.db.person().id().find(&id) {
|
||||
Some(person) => log::info!("UNIQUE FOUND: id {}: {}", id, person.name),
|
||||
None => log::info!("UNIQUE NOT FOUND: id {}", id),
|
||||
}
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn find_person_read_only(ctx: &ReducerContext, id: i32) {
|
||||
let ctx = ctx.as_read_only();
|
||||
match ctx.db.person().id().find(&id) {
|
||||
Some(person) => log::info!("UNIQUE FOUND: id {}: {}", id, person.name),
|
||||
None => log::info!("UNIQUE NOT FOUND: id {}", id),
|
||||
}
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn find_person_by_name(ctx: &ReducerContext, name: String) {
|
||||
for person in ctx.db.person().iter().filter(|p| p.name == name) {
|
||||
log::info!("UNIQUE FOUND: id {}: {} aka {}", person.id, person.name, person.nick);
|
||||
}
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn find_person_by_nick(ctx: &ReducerContext, nick: String) {
|
||||
match ctx.db.person().nick().find(&nick) {
|
||||
Some(person) => log::info!("UNIQUE FOUND: id {}: {}", person.id, person.nick),
|
||||
None => log::info!("UNIQUE NOT FOUND: nick {}", nick),
|
||||
}
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn find_person_by_nick_read_only(ctx: &ReducerContext, nick: String) {
|
||||
let ctx = ctx.as_read_only();
|
||||
match ctx.db.person().nick().find(&nick) {
|
||||
Some(person) => log::info!("UNIQUE FOUND: id {}: {}", person.id, person.nick),
|
||||
None => log::info!("UNIQUE NOT FOUND: nick {}", nick),
|
||||
}
|
||||
}
|
||||
|
||||
#[spacetimedb::table(accessor = nonunique_person)]
|
||||
pub struct NonuniquePerson {
|
||||
#[index(btree)]
|
||||
id: i32,
|
||||
name: String,
|
||||
is_human: bool,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn insert_nonunique_person(ctx: &ReducerContext, id: i32, name: String, is_human: bool) {
|
||||
ctx.db.nonunique_person().insert(NonuniquePerson { id, name, is_human } );
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn find_nonunique_person(ctx: &ReducerContext, id: i32) {
|
||||
for person in ctx.db.nonunique_person().id().filter(&id) {
|
||||
log::info!("NONUNIQUE FOUND: id {}: {}", id, person.name)
|
||||
}
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn find_nonunique_person_read_only(ctx: &ReducerContext, id: i32) {
|
||||
let ctx = ctx.as_read_only();
|
||||
for person in ctx.db.nonunique_person().id().filter(&id) {
|
||||
log::info!("NONUNIQUE FOUND: id {}: {}", id, person.name)
|
||||
}
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn find_nonunique_humans(ctx: &ReducerContext) {
|
||||
for person in ctx.db.nonunique_person().iter().filter(|p| p.is_human) {
|
||||
log::info!("HUMAN FOUND: id {}: {}", person.id, person.name);
|
||||
}
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn find_nonunique_non_humans(ctx: &ReducerContext) {
|
||||
for person in ctx.db.nonunique_person().iter().filter(|p| !p.is_human) {
|
||||
log::info!("NON-HUMAN FOUND: id {}: {}", person.id, person.name);
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that [Identity] is filterable and a legal unique column.
|
||||
#[spacetimedb::table(accessor = identified_person)]
|
||||
struct IdentifiedPerson {
|
||||
#[unique]
|
||||
identity: Identity,
|
||||
name: String,
|
||||
}
|
||||
|
||||
fn identify(id_number: u64) -> Identity {
|
||||
let mut bytes = [0u8; 32];
|
||||
bytes[..8].clone_from_slice(&id_number.to_le_bytes());
|
||||
Identity::from_byte_array(bytes)
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
fn insert_identified_person(ctx: &ReducerContext, id_number: u64, name: String) {
|
||||
let identity = identify(id_number);
|
||||
ctx.db.identified_person().insert(IdentifiedPerson { identity, name });
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
fn find_identified_person(ctx: &ReducerContext, id_number: u64) {
|
||||
let identity = identify(id_number);
|
||||
match ctx.db.identified_person().identity().find(&identity) {
|
||||
Some(person) => log::info!("IDENTIFIED FOUND: {}", person.name),
|
||||
None => log::info!("IDENTIFIED NOT FOUND"),
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that indices on non-unique columns behave as we expect.
|
||||
#[spacetimedb::table(accessor = indexed_person)]
|
||||
struct IndexedPerson {
|
||||
#[unique]
|
||||
id: i32,
|
||||
given_name: String,
|
||||
#[index(btree)]
|
||||
surname: String,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
fn insert_indexed_person(ctx: &ReducerContext, id: i32, given_name: String, surname: String) {
|
||||
ctx.db.indexed_person().insert(IndexedPerson { id, given_name, surname });
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
fn delete_indexed_person(ctx: &ReducerContext, id: i32) {
|
||||
ctx.db.indexed_person().id().delete(&id);
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
fn find_indexed_people(ctx: &ReducerContext, surname: String) {
|
||||
for person in ctx.db.indexed_person().surname().filter(&surname) {
|
||||
log::info!("INDEXED FOUND: id {}: {}, {}", person.id, person.surname, person.given_name);
|
||||
}
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
fn find_indexed_people_read_only(ctx: &ReducerContext, surname: String) {
|
||||
let ctx = ctx.as_read_only();
|
||||
for person in ctx.db.indexed_person().surname().filter(&surname) {
|
||||
log::info!("INDEXED FOUND: id {}: {}, {}", person.id, person.surname, person.given_name);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
# TODO: split this into multiple test functions
|
||||
def test_filtering(self):
|
||||
"""Test filtering reducers"""
|
||||
|
||||
self.call("insert_person", 23, "Alice", "al")
|
||||
self.call("insert_person", 42, "Bob", "bo")
|
||||
self.call("insert_person", 64, "Bob", "b2")
|
||||
|
||||
# Find a person who is there.
|
||||
self.call("find_person", 23)
|
||||
self.assertIn("UNIQUE FOUND: id 23: Alice", self.logs(2))
|
||||
|
||||
# Find persons with the same name.
|
||||
self.call("find_person_by_name", "Bob")
|
||||
logs = self.logs(4)
|
||||
self.assertIn("UNIQUE FOUND: id 42: Bob aka bo", logs)
|
||||
self.assertIn("UNIQUE FOUND: id 64: Bob aka b2", logs)
|
||||
|
||||
# Fail to find a person who is not there.
|
||||
self.call("find_person", 43)
|
||||
self.assertIn("UNIQUE NOT FOUND: id 43", self.logs(2))
|
||||
self.call("find_person_read_only", 43)
|
||||
self.assertIn("UNIQUE NOT FOUND: id 43", self.logs(2))
|
||||
|
||||
# Find a person by nickname.
|
||||
self.call("find_person_by_nick", "al")
|
||||
self.assertIn("UNIQUE FOUND: id 23: al", self.logs(2))
|
||||
self.call("find_person_by_nick_read_only", "al")
|
||||
self.assertIn("UNIQUE FOUND: id 23: al", self.logs(2))
|
||||
|
||||
# Remove a person, and then fail to find them.
|
||||
self.call("delete_person", 23)
|
||||
self.call("find_person", 23)
|
||||
self.assertIn("UNIQUE NOT FOUND: id 23", self.logs(2))
|
||||
self.call("find_person_read_only", 23)
|
||||
self.assertIn("UNIQUE NOT FOUND: id 23", self.logs(2))
|
||||
# Also fail by nickname
|
||||
self.call("find_person_by_nick", "al")
|
||||
self.assertIn("UNIQUE NOT FOUND: nick al", self.logs(2))
|
||||
self.call("find_person_by_nick_read_only", "al")
|
||||
self.assertIn("UNIQUE NOT FOUND: nick al", self.logs(2))
|
||||
|
||||
# Add some nonunique people.
|
||||
self.call("insert_nonunique_person", 23, "Alice", True)
|
||||
self.call("insert_nonunique_person", 42, "Bob", True)
|
||||
|
||||
# Find a nonunique person who is there.
|
||||
self.call("find_nonunique_person", 23)
|
||||
self.assertIn('NONUNIQUE FOUND: id 23: Alice', self.logs(2))
|
||||
self.call("find_nonunique_person_read_only", 23)
|
||||
self.assertIn('NONUNIQUE FOUND: id 23: Alice', self.logs(2))
|
||||
|
||||
# Fail to find a nonunique person who is not there.
|
||||
self.call("find_nonunique_person", 43)
|
||||
self.assertNotIn("NONUNIQUE NOT FOUND: id 43", self.logs(2))
|
||||
self.call("find_nonunique_person_read_only", 43)
|
||||
self.assertNotIn("NONUNIQUE NOT FOUND: id 43", self.logs(2))
|
||||
|
||||
# Insert a non-human, then find humans, then find non-humans
|
||||
self.call("insert_nonunique_person", 64, "Jibbitty", False)
|
||||
self.call("find_nonunique_humans")
|
||||
self.assertIn('HUMAN FOUND: id 23: Alice', self.logs(2))
|
||||
self.assertIn('HUMAN FOUND: id 42: Bob', self.logs(2))
|
||||
self.call("find_nonunique_non_humans")
|
||||
self.assertIn('NON-HUMAN FOUND: id 64: Jibbitty', self.logs(2))
|
||||
|
||||
# Add another person with the same id, and find them both.
|
||||
self.call("insert_nonunique_person", 23, "Claire", True)
|
||||
self.call("find_nonunique_person", 23)
|
||||
self.assertIn('NONUNIQUE FOUND: id 23: Alice', self.logs(2))
|
||||
self.assertIn('NONUNIQUE FOUND: id 23: Claire', self.logs(2))
|
||||
self.call("find_nonunique_person_read_only", 23)
|
||||
self.assertIn('NONUNIQUE FOUND: id 23: Alice', self.logs(2))
|
||||
self.assertIn('NONUNIQUE FOUND: id 23: Claire', self.logs(2))
|
||||
|
||||
# Check for issues with things present in index but not DB
|
||||
self.call("insert_person", 101, "Fee", "fee")
|
||||
self.call("insert_person", 102, "Fi", "fi")
|
||||
self.call("insert_person", 103, "Fo", "fo")
|
||||
self.call("insert_person", 104, "Fum", "fum")
|
||||
self.call("delete_person", 103)
|
||||
self.call("find_person", 104)
|
||||
self.assertIn('UNIQUE FOUND: id 104: Fum', self.logs(2))
|
||||
self.call("find_person_read_only", 104)
|
||||
self.assertIn('UNIQUE FOUND: id 104: Fum', self.logs(2))
|
||||
|
||||
# As above, but for non-unique indices: check for consistency between index and DB
|
||||
self.call("insert_indexed_person", 7, "James", "Bond")
|
||||
self.call("insert_indexed_person", 79, "Gold", "Bond")
|
||||
self.call("insert_indexed_person", 1, "Hydrogen", "Bond")
|
||||
self.call("insert_indexed_person", 100, "Whiskey", "Bond")
|
||||
self.call("delete_indexed_person", 100)
|
||||
self.call("find_indexed_people", "Bond")
|
||||
logs = self.logs(10)
|
||||
self.assertIn('INDEXED FOUND: id 7: Bond, James', logs)
|
||||
self.assertIn('INDEXED FOUND: id 79: Bond, Gold', logs)
|
||||
self.assertIn('INDEXED FOUND: id 1: Bond, Hydrogen', logs)
|
||||
self.assertNotIn('INDEXED FOUND: id 100: Bond, Whiskey', logs)
|
||||
self.call("find_indexed_people_read_only", "Bond")
|
||||
logs = self.logs(10)
|
||||
self.assertIn('INDEXED FOUND: id 7: Bond, James', logs)
|
||||
self.assertIn('INDEXED FOUND: id 79: Bond, Gold', logs)
|
||||
self.assertIn('INDEXED FOUND: id 1: Bond, Hydrogen', logs)
|
||||
self.assertNotIn('INDEXED FOUND: id 100: Bond, Whiskey', logs)
|
||||
|
||||
# Non-unique version; does not work yet, see db_delete codegen in SpacetimeDB\crates\bindings-macro\src\lib.rs
|
||||
# self.call("insert_nonunique_person", 101, "Fee")
|
||||
# self.call("insert_nonunique_person", 102, "Fi")
|
||||
# self.call("insert_nonunique_person", 103, "Fo")
|
||||
# self.call("insert_nonunique_person", 104, "Fum")
|
||||
# self.call("find_nonunique_person", 104)
|
||||
# self.assertIn('NONUNIQUE FOUND: id 104: Fum', self.logs(2))
|
||||
|
||||
# Filter by Identity
|
||||
self.call("insert_identified_person", 23, "Alice")
|
||||
self.call("find_identified_person", 23)
|
||||
self.assertIn('IDENTIFIED FOUND: Alice', self.logs(2))
|
||||
|
||||
# Inserting into a table with unique constraints fails
|
||||
# when the second row has the same value in the constrained columns as the first row.
|
||||
# In this case, the table has `#[unique] id` and `#[unique] nick` but not `#[unique] name`.
|
||||
self.call("insert_person_twice", 23, "Alice", "al")
|
||||
self.assertIn('UNIQUE CONSTRAINT VIOLATION ERROR: id = 23, nick = al', self.logs(2))
|
||||
@@ -1,53 +0,0 @@
|
||||
from .. import Smoketest
|
||||
|
||||
class ModuleNestedOp(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{log, ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = account)]
|
||||
pub struct Account {
|
||||
name: String,
|
||||
#[unique]
|
||||
id: i32,
|
||||
}
|
||||
|
||||
#[spacetimedb::table(accessor = friends)]
|
||||
pub struct Friends {
|
||||
friend_1: i32,
|
||||
friend_2: i32,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn create_account(ctx: &ReducerContext, account_id: i32, name: String) {
|
||||
ctx.db.account().insert(Account { id: account_id, name } );
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn add_friend(ctx: &ReducerContext, my_id: i32, their_id: i32) {
|
||||
// Make sure our friend exists
|
||||
for account in ctx.db.account().iter() {
|
||||
if account.id == their_id {
|
||||
ctx.db.friends().insert(Friends { friend_1: my_id, friend_2: their_id });
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn say_friends(ctx: &ReducerContext) {
|
||||
for friendship in ctx.db.friends().iter() {
|
||||
let friend1 = ctx.db.account().id().find(&friendship.friend_1).unwrap();
|
||||
let friend2 = ctx.db.account().id().find(&friendship.friend_2).unwrap();
|
||||
log::info!("{} is friends with {}", friend1.name, friend2.name);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
def test_module_nested_op(self):
|
||||
"""This tests uploading a basic module and calling some functions and checking logs afterwards."""
|
||||
|
||||
self.call("create_account", 1, "House")
|
||||
self.call("create_account", 2, "Wilson")
|
||||
self.call("add_friend", 1, 2)
|
||||
self.call("say_friends")
|
||||
self.assertIn("House is friends with Wilson", self.logs(2))
|
||||
@@ -1,247 +0,0 @@
|
||||
from .. import Smoketest, random_string
|
||||
from subprocess import CalledProcessError
|
||||
import time
|
||||
import itertools
|
||||
|
||||
class UpdateModule(Smoketest):
|
||||
AUTOPUBLISH = False
|
||||
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{log, ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = person)]
|
||||
pub struct Person {
|
||||
#[primary_key]
|
||||
#[auto_inc]
|
||||
id: u64,
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn add(ctx: &ReducerContext, name: String) {
|
||||
ctx.db.person().insert(Person { id: 0, name });
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn say_hello(ctx: &ReducerContext) {
|
||||
for person in ctx.db.person().iter() {
|
||||
log::info!("Hello, {}!", person.name);
|
||||
}
|
||||
log::info!("Hello, World!");
|
||||
}
|
||||
"""
|
||||
MODULE_CODE_B = """
|
||||
#[spacetimedb::table(accessor = person)]
|
||||
pub struct Person {
|
||||
#[primary_key]
|
||||
#[auto_inc]
|
||||
id: u64,
|
||||
name: String,
|
||||
age: u8,
|
||||
}
|
||||
"""
|
||||
|
||||
MODULE_CODE_C = """
|
||||
use spacetimedb::{log, ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = person)]
|
||||
pub struct Person {
|
||||
#[primary_key]
|
||||
#[auto_inc]
|
||||
id: u64,
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[spacetimedb::table(accessor = pets)]
|
||||
pub struct Pet {
|
||||
species: String,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn are_we_updated_yet(ctx: &ReducerContext) {
|
||||
log::info!("MODULE UPDATED");
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def test_module_update(self):
|
||||
"""Test publishing a module without the --delete-data option"""
|
||||
|
||||
name = random_string()
|
||||
|
||||
self.publish_module(name, clear=False)
|
||||
|
||||
self.call("add", "Robert")
|
||||
self.call("add", "Julie")
|
||||
self.call("add", "Samantha")
|
||||
self.call("say_hello")
|
||||
logs = self.logs(100)
|
||||
self.assertIn("Hello, Samantha!", logs)
|
||||
self.assertIn("Hello, Julie!", logs)
|
||||
self.assertIn("Hello, Robert!", logs)
|
||||
self.assertIn("Hello, World!", logs)
|
||||
|
||||
# Unchanged module is ok
|
||||
self.publish_module(name, clear=False)
|
||||
|
||||
# Changing an existing table isn't
|
||||
self.write_module_code(self.MODULE_CODE_B)
|
||||
with self.assertRaises(CalledProcessError) as cm:
|
||||
self.publish_module(name, clear=False)
|
||||
self.assertIn("Error: Aborting because publishing would require manual migration", cm.exception.stderr)
|
||||
|
||||
# Check that the old module is still running by calling say_hello
|
||||
self.call("say_hello")
|
||||
|
||||
# Adding a table is ok
|
||||
self.write_module_code(self.MODULE_CODE_C)
|
||||
self.publish_module(name, clear=False)
|
||||
self.call("are_we_updated_yet")
|
||||
self.assertIn("MODULE UPDATED", self.logs(2))
|
||||
|
||||
|
||||
class UploadModule1(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{log, ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = person)]
|
||||
pub struct Person {
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn add(ctx: &ReducerContext, name: String) {
|
||||
ctx.db.person().insert(Person { name });
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn say_hello(ctx: &ReducerContext) {
|
||||
for person in ctx.db.person().iter() {
|
||||
log::info!("Hello, {}!", person.name);
|
||||
}
|
||||
log::info!("Hello, World!");
|
||||
}
|
||||
"""
|
||||
|
||||
def test_upload_module_1(self):
|
||||
"""This tests uploading a basic module and calling some functions and checking logs afterwards."""
|
||||
|
||||
self.call("add", "Robert")
|
||||
self.call("add", "Julie")
|
||||
self.call("add", "Samantha")
|
||||
self.call("say_hello")
|
||||
logs = self.logs(100)
|
||||
self.assertIn("Hello, Samantha!", logs)
|
||||
self.assertIn("Hello, Julie!", logs)
|
||||
self.assertIn("Hello, Robert!", logs)
|
||||
self.assertIn("Hello, World!", logs)
|
||||
|
||||
|
||||
class UploadModule2(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{log, duration, ReducerContext, Table, Timestamp};
|
||||
|
||||
|
||||
#[spacetimedb::table(accessor = scheduled_message, public, scheduled(my_repeating_reducer))]
|
||||
pub struct ScheduledMessage {
|
||||
#[primary_key]
|
||||
#[auto_inc]
|
||||
scheduled_id: u64,
|
||||
scheduled_at: spacetimedb::ScheduleAt,
|
||||
prev: Timestamp,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer(init)]
|
||||
fn init(ctx: &ReducerContext) {
|
||||
ctx.db.scheduled_message().insert(ScheduledMessage { prev: ctx.timestamp, scheduled_id: 0, scheduled_at: duration!(100ms).into(), });
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn my_repeating_reducer(ctx: &ReducerContext, arg: ScheduledMessage) {
|
||||
log::info!("Invoked: ts={:?}, delta={:?}", ctx.timestamp, ctx.timestamp.duration_since(arg.prev));
|
||||
}
|
||||
"""
|
||||
def test_upload_module_2(self):
|
||||
"""This test deploys a module with a repeating reducer and checks the logs to make sure its running."""
|
||||
|
||||
time.sleep(2)
|
||||
lines = sum(1 for line in self.logs(100) if "Invoked" in line)
|
||||
time.sleep(4)
|
||||
new_lines = sum(1 for line in self.logs(100) if "Invoked" in line)
|
||||
self.assertLess(lines, new_lines)
|
||||
|
||||
|
||||
class HotswapModule(Smoketest):
|
||||
AUTOPUBLISH = False
|
||||
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = person)]
|
||||
pub struct Person {
|
||||
#[primary_key]
|
||||
#[auto_inc]
|
||||
id: u64,
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn add_person(ctx: &ReducerContext, name: String) {
|
||||
ctx.db.person().insert(Person { id: 0, name });
|
||||
}
|
||||
"""
|
||||
|
||||
MODULE_CODE_B = """
|
||||
use spacetimedb::{ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = person)]
|
||||
pub struct Person {
|
||||
#[primary_key]
|
||||
#[auto_inc]
|
||||
id: u64,
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn add_person(ctx: &ReducerContext, name: String) {
|
||||
ctx.db.person().insert(Person { id: 0, name });
|
||||
}
|
||||
|
||||
#[spacetimedb::table(accessor = pet)]
|
||||
pub struct Pet {
|
||||
#[primary_key]
|
||||
species: String,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn add_pet(ctx: &ReducerContext, species: String) {
|
||||
ctx.db.pet().insert(Pet { species });
|
||||
}
|
||||
"""
|
||||
|
||||
def test_hotswap_module(self):
|
||||
"""Tests hotswapping of modules."""
|
||||
|
||||
# Publish MODULE_CODE and subscribe to all
|
||||
name = random_string()
|
||||
self.publish_module(name, clear=False)
|
||||
sub = self.subscribe("SELECT * FROM *", n=2)
|
||||
|
||||
# Trigger event on the subscription
|
||||
self.call("add_person", "Horst")
|
||||
|
||||
# Update the module
|
||||
self.write_module_code(self.MODULE_CODE_B)
|
||||
self.publish_module(name, clear=False)
|
||||
|
||||
# Assert that the module was updated
|
||||
self.call("add_pet", "Turtle")
|
||||
# And trigger another event on the subscription
|
||||
self.call("add_person", "Cindy")
|
||||
|
||||
# Note that 'SELECT * FROM *' does NOT get refreshed to include the
|
||||
# new table (this is a known limitation).
|
||||
self.assertEqual(sub(), [
|
||||
{'person': {'deletes': [], 'inserts': [{'id': 1, 'name': 'Horst'}]}},
|
||||
{'person': {'deletes': [], 'inserts': [{'id': 2, 'name': 'Cindy'}]}}
|
||||
])
|
||||
@@ -1,38 +0,0 @@
|
||||
from .. import Smoketest, random_string
|
||||
import tempfile
|
||||
import os
|
||||
from glob import iglob
|
||||
|
||||
|
||||
def count_matches(dir, needle):
|
||||
count = 0
|
||||
for f in iglob(os.path.join(dir, "**/*.cs"), recursive=True):
|
||||
with open(f) as f:
|
||||
count += f.read().count(needle)
|
||||
return count
|
||||
|
||||
|
||||
class Namespaces(Smoketest):
|
||||
AUTOPUBLISH = False
|
||||
|
||||
def test_spacetimedb_ns_csharp(self):
|
||||
"""Ensure that the default namespace is working properly"""
|
||||
|
||||
namespace = "SpacetimeDB.Types"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
self.spacetime("generate", "--out-dir", tmpdir, "--lang=cs", "--module-path", self.project_path)
|
||||
|
||||
self.assertEqual(count_matches(tmpdir, f"namespace {namespace}"), 5)
|
||||
self.assertEqual(count_matches(tmpdir, "using SpacetimeDB;"), 0)
|
||||
|
||||
def test_custom_ns_csharp(self):
|
||||
"""Ensure that when a custom namespace is specified on the command line, it actually gets used in generation"""
|
||||
|
||||
namespace = random_string()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
self.spacetime("generate", "--out-dir", tmpdir, "--lang=cs", "--namespace", namespace, "--module-path", self.project_path)
|
||||
|
||||
self.assertEqual(count_matches(tmpdir, f"namespace {namespace}"), 5)
|
||||
self.assertEqual(count_matches(tmpdir, "using SpacetimeDB;"), 5)
|
||||
@@ -1,53 +0,0 @@
|
||||
from .. import Smoketest, requires_anonymous_login
|
||||
import time
|
||||
|
||||
class NewUserFlow(Smoketest):
|
||||
AUTOPUBLISH = False
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{log, ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = person)]
|
||||
pub struct Person {
|
||||
name: String
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn add(ctx: &ReducerContext, name: String) {
|
||||
ctx.db.person().insert(Person { name });
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn say_hello(ctx: &ReducerContext) {
|
||||
for person in ctx.db.person().iter() {
|
||||
log::info!("Hello, {}!", person.name);
|
||||
}
|
||||
log::info!("Hello, World!");
|
||||
}
|
||||
"""
|
||||
|
||||
@requires_anonymous_login
|
||||
def test_new_user_flow(self):
|
||||
"""Test the entirety of the new user flow."""
|
||||
|
||||
## Publish your module
|
||||
self.new_identity()
|
||||
|
||||
self.publish_module()
|
||||
|
||||
# Calling our database
|
||||
self.call("say_hello")
|
||||
self.assertIn("Hello, World!", self.logs(2))
|
||||
|
||||
## Calling functions with arguments
|
||||
self.call("add", "Tyler")
|
||||
self.call("say_hello")
|
||||
self.assertEqual(self.logs(5).count("Hello, World!"), 2)
|
||||
self.assertEqual(self.logs(5).count("Hello, Tyler!"), 1)
|
||||
|
||||
out = self.spacetime("sql", self.database_identity, "SELECT * FROM person")
|
||||
# The spaces after the name are important
|
||||
self.assertMultiLineEqual(out, """\
|
||||
name
|
||||
---------
|
||||
"Tyler"
|
||||
""")
|
||||
@@ -1,50 +0,0 @@
|
||||
from .. import Smoketest
|
||||
|
||||
class Panic(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{log, ReducerContext};
|
||||
use std::cell::RefCell;
|
||||
|
||||
thread_local! {
|
||||
static X: RefCell<u32> = RefCell::new(0);
|
||||
}
|
||||
#[spacetimedb::reducer]
|
||||
fn first(_ctx: &ReducerContext) {
|
||||
X.with(|x| {
|
||||
let _x = x.borrow_mut();
|
||||
panic!()
|
||||
})
|
||||
}
|
||||
#[spacetimedb::reducer]
|
||||
fn second(_ctx: &ReducerContext) {
|
||||
X.with(|x| *x.borrow_mut());
|
||||
log::info!("Test Passed");
|
||||
}
|
||||
"""
|
||||
|
||||
def test_panic(self):
|
||||
"""Tests to check if a SpacetimeDB module can handle a panic without corrupting"""
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
self.call("first")
|
||||
|
||||
self.call("second")
|
||||
self.assertIn("Test Passed", self.logs(2))
|
||||
|
||||
class ReducerError(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::ReducerContext;
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
fn fail(_ctx: &ReducerContext) -> Result<(), String> {
|
||||
Err("oopsie :(".into())
|
||||
}
|
||||
"""
|
||||
|
||||
def test_reducer_error_message(self):
|
||||
"""Tests to ensure an error message returned from a reducer gets printed to logs"""
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
self.call("fail")
|
||||
|
||||
self.assertIn("oopsie :(", self.logs(2))
|
||||
@@ -1,165 +0,0 @@
|
||||
from .. import Smoketest, random_string
|
||||
import json
|
||||
|
||||
class Permissions(Smoketest):
|
||||
AUTOPUBLISH = False
|
||||
|
||||
def setUp(self):
|
||||
self.reset_config()
|
||||
|
||||
def test_call(self):
|
||||
"""Ensure that anyone has the permission to call any standard reducer"""
|
||||
|
||||
self.publish_module()
|
||||
|
||||
self.call("say_hello", anon=True)
|
||||
|
||||
self.assertEqual("\n".join(self.logs(10000)).count("World"), 1)
|
||||
|
||||
def test_delete(self):
|
||||
"""Ensure that you cannot delete a database that you do not own"""
|
||||
|
||||
self.publish_module()
|
||||
|
||||
self.new_identity()
|
||||
with self.assertRaises(Exception):
|
||||
self.spacetime("delete", "--yes", self.database_identity)
|
||||
|
||||
def test_describe(self):
|
||||
"""Ensure that anyone can describe any database"""
|
||||
|
||||
self.publish_module()
|
||||
|
||||
self.spacetime("describe", "--anonymous", "--json", self.database_identity)
|
||||
|
||||
def test_logs(self):
|
||||
"""Ensure that we are not able to view the logs of a module that we don't have permission to view"""
|
||||
|
||||
self.publish_module()
|
||||
|
||||
self.reset_config()
|
||||
self.new_identity()
|
||||
self.call("say_hello")
|
||||
|
||||
self.reset_config()
|
||||
self.new_identity()
|
||||
with self.assertRaises(Exception):
|
||||
self.spacetime("logs", self.database_identity, "-n", "10000")
|
||||
|
||||
def test_publish(self):
|
||||
"""This test checks to make sure that you cannot publish to an identity that you do not own."""
|
||||
|
||||
self.publish_module()
|
||||
|
||||
self.new_identity()
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
self.spacetime("publish", self.database_identity, "--module-path", self.project_path, "--delete-data", "--yes")
|
||||
|
||||
# Check that this holds without `--delete-data`, too.
|
||||
with self.assertRaises(Exception):
|
||||
self.spacetime("publish", self.database_identity, "--module-path", self.project_path, "--yes")
|
||||
|
||||
def test_replace_names(self):
|
||||
"""Test that you can't replace names of a database you don't own"""
|
||||
|
||||
name = random_string()
|
||||
self.publish_module(name)
|
||||
|
||||
self.new_identity()
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
self.api_call(
|
||||
"PUT",
|
||||
f'/v1/database/{name}/names',
|
||||
json.dumps(["post", "gres"]),
|
||||
{"Content-type": "application/json"}
|
||||
)
|
||||
|
||||
class PrivateTablePermissions(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = secret, private)]
|
||||
pub struct Secret {
|
||||
answer: u8,
|
||||
}
|
||||
|
||||
#[spacetimedb::table(accessor = common_knowledge, public)]
|
||||
pub struct CommonKnowledge {
|
||||
thing: String,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer(init)]
|
||||
pub fn init(ctx: &ReducerContext) {
|
||||
ctx.db.secret().insert(Secret { answer: 42 });
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn do_thing(ctx: &ReducerContext, thing: String) {
|
||||
ctx.db.secret().insert(Secret { answer: 20 });
|
||||
ctx.db.common_knowledge().insert(CommonKnowledge { thing });
|
||||
}
|
||||
"""
|
||||
|
||||
def test_private_table(self):
|
||||
"""Ensure that a private table can only be queried by the database owner"""
|
||||
|
||||
out = self.spacetime("sql", self.database_identity, "select * from secret")
|
||||
answer = "\n".join([
|
||||
" answer ",
|
||||
"--------",
|
||||
" 42 ",
|
||||
""
|
||||
])
|
||||
self.assertMultiLineEqual(str(out), answer)
|
||||
|
||||
self.reset_config()
|
||||
self.new_identity()
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
self.spacetime("sql", self.database_identity, "select * from secret")
|
||||
|
||||
# Subscribing to the private table failes.
|
||||
with self.assertRaises(Exception):
|
||||
self.subscribe("SELECT * FROM secret", n=0)
|
||||
|
||||
# Subscribing to the public table works.
|
||||
sub = self.subscribe("SELECT * FROM common_knowledge", n = 1)
|
||||
self.call("do_thing", "godmorgon")
|
||||
self.assertEqual(sub(), [
|
||||
{
|
||||
'common_knowledge': {
|
||||
'deletes': [],
|
||||
'inserts': [{'thing': 'godmorgon'}]
|
||||
}
|
||||
}
|
||||
])
|
||||
|
||||
# Subscribing to both tables returns updates for the public one.
|
||||
sub = self.subscribe("SELECT * FROM *", n=1)
|
||||
self.call("do_thing", "howdy", anon=True)
|
||||
self.assertEqual(sub(), [
|
||||
{
|
||||
'common_knowledge': {
|
||||
'deletes': [],
|
||||
'inserts': [{'thing': 'howdy'}]
|
||||
}
|
||||
}
|
||||
])
|
||||
|
||||
|
||||
class LifecycleReducers(Smoketest):
|
||||
lifecycle_kinds = "init", "client_connected", "client_disconnected"
|
||||
|
||||
MODULE_CODE = "\n".join(f"""
|
||||
#[spacetimedb::reducer({kind})]
|
||||
fn lifecycle_{kind}(_ctx: &spacetimedb::ReducerContext) {{}}
|
||||
""" for kind in lifecycle_kinds)
|
||||
|
||||
def test_lifecycle_reducers_cant_be_called(self):
|
||||
"""Ensure that lifecycle reducers (init, on_connect, etc) can't be called"""
|
||||
|
||||
for kind in self.lifecycle_kinds:
|
||||
with self.assertRaises(Exception):
|
||||
self.call(f"lifecycle_{kind}")
|
||||
@@ -1,403 +0,0 @@
|
||||
import logging
|
||||
import re
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import xmltodict
|
||||
|
||||
import smoketests
|
||||
from .. import Smoketest, STDB_DIR, run_cmd, TEMPLATE_CARGO_TOML, TYPESCRIPT_BINDINGS_PATH, build_typescript_sdk, pnpm
|
||||
|
||||
|
||||
def _write_file(path: Path, content: str):
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(content)
|
||||
|
||||
|
||||
def _append_to_file(path: Path, content: str):
|
||||
with open(path, "a", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
def _parse_quickstart(doc_path: Path, language: str, module_name: str, server: bool) -> str:
|
||||
"""Extract code blocks from `quickstart.md` docs.
|
||||
This will replicate the steps in the quickstart guide, so if it fails the quickstart guide is broken.
|
||||
"""
|
||||
content = Path(doc_path).read_text()
|
||||
|
||||
# append " server" to the codeblock language if we're extracting server code
|
||||
if server:
|
||||
codeblock_lang = "ts server" if language == "typescript" else f"{language} server"
|
||||
else:
|
||||
codeblock_lang = "ts" if language == "typescript" else language
|
||||
|
||||
blocks = re.findall(rf"```{codeblock_lang}\n(.*?)\n```", content, re.DOTALL)
|
||||
|
||||
end = ""
|
||||
if language == "csharp":
|
||||
found = False
|
||||
filtered_blocks = []
|
||||
for block in blocks:
|
||||
# The doc first create an empty class Module, so we need to fixup the closing
|
||||
if "partial class Module" in block:
|
||||
block = block.replace("}", "")
|
||||
end = "\n}"
|
||||
# Remove the first `OnConnected` block, which body is later updated
|
||||
if "OnConnected(DbConnection conn" in block and not found:
|
||||
found = True
|
||||
continue
|
||||
filtered_blocks.append(block)
|
||||
blocks = filtered_blocks
|
||||
# So we could have a different db for each language
|
||||
return "\n".join(blocks).replace("quickstart-chat", module_name) + end
|
||||
|
||||
def load_nuget_config(p: Path):
|
||||
if p.exists():
|
||||
with p.open("rb") as f:
|
||||
return xmltodict.parse(f.read(), force_list=["add", "packageSource", "package"])
|
||||
return {}
|
||||
|
||||
|
||||
def _nuget_config_path(project_dir: Path) -> Path:
|
||||
p_upper = project_dir / "NuGet.Config"
|
||||
if p_upper.exists():
|
||||
return p_upper
|
||||
|
||||
p_lower = project_dir / "nuget.config"
|
||||
if p_lower.exists():
|
||||
return p_lower
|
||||
|
||||
return p_upper
|
||||
|
||||
def save_nuget_config(p: Path, doc: dict):
|
||||
# Write back (pretty, UTF-8, no BOM)
|
||||
xml = xmltodict.unparse(doc, pretty=True)
|
||||
p.write_text(xml, encoding="utf-8")
|
||||
|
||||
def add_source(doc: dict, *, key: str, path: str) -> None:
|
||||
cfg = doc.setdefault("configuration", {})
|
||||
sources = cfg.setdefault("packageSources", {})
|
||||
source_entries = sources.setdefault("add", [])
|
||||
|
||||
source_entries.append({"@key": key, "@value": str(path)})
|
||||
|
||||
def add_mapping(doc: dict, *, key: str, pattern: str) -> None:
|
||||
cfg = doc.setdefault("configuration", {})
|
||||
|
||||
psm = cfg.setdefault("packageSourceMapping", {})
|
||||
mapping_sources = psm.setdefault("packageSource", [])
|
||||
|
||||
# Find or create the target <packageSource key="...">
|
||||
target = next((s for s in mapping_sources if s.get("@key") == key), None)
|
||||
if target is None:
|
||||
target = {"@key": key, "package": []}
|
||||
mapping_sources.append(target)
|
||||
|
||||
pkgs = target.setdefault("package", [])
|
||||
|
||||
existing = {pkg.get("@pattern") for pkg in pkgs if "@pattern" in pkg}
|
||||
if pattern not in existing:
|
||||
pkgs.append({"@pattern": pattern})
|
||||
|
||||
def override_nuget_package(*, project_dir: Path, package: str, source_dir: Path, build_subdir: str):
|
||||
"""Override nuget config to use a local NuGet package on a .NET project"""
|
||||
# Make sure the local package is built
|
||||
repo_nuget_config = STDB_DIR / "NuGet.Config"
|
||||
if repo_nuget_config.exists():
|
||||
run_cmd(
|
||||
"dotnet",
|
||||
"restore",
|
||||
"--configfile",
|
||||
str(repo_nuget_config),
|
||||
cwd=source_dir,
|
||||
capture_stderr=True,
|
||||
)
|
||||
run_cmd("dotnet", "pack", "-c", "Release", "--no-restore", cwd=source_dir)
|
||||
else:
|
||||
run_cmd("dotnet", "pack", "-c", "Release", cwd=source_dir)
|
||||
|
||||
p = _nuget_config_path(Path(project_dir))
|
||||
doc = load_nuget_config(p)
|
||||
add_source(doc, key=package, path=source_dir/build_subdir)
|
||||
add_mapping(doc, key=package, pattern=package)
|
||||
add_source(doc, key="nuget.org", path="https://api.nuget.org/v3/index.json")
|
||||
# Fallback for other packages
|
||||
add_mapping(doc, key="nuget.org", pattern="*")
|
||||
save_nuget_config(p, doc)
|
||||
|
||||
# Clear any caches for nuget packages
|
||||
run_cmd("dotnet", "nuget", "locals", "--clear", "all", capture_stderr=True)
|
||||
|
||||
class BaseQuickstart(Smoketest):
|
||||
AUTOPUBLISH = False
|
||||
MODULE_CODE = ""
|
||||
|
||||
lang = None
|
||||
client_lang = None
|
||||
codeblock_langs = None
|
||||
server_doc = None
|
||||
client_doc = None
|
||||
server_file = None
|
||||
client_file = None
|
||||
module_bindings = None
|
||||
extra_code = None
|
||||
replacements = {}
|
||||
connected_str = None
|
||||
run_cmd = []
|
||||
build_cmd = []
|
||||
|
||||
def project_init(self, path: Path):
|
||||
raise NotImplementedError
|
||||
|
||||
def sdk_setup(self, path: Path):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _module_name(self):
|
||||
return f"quickstart-chat-{self.lang}"
|
||||
|
||||
def _publish(self) -> Path:
|
||||
base_path = Path(self.enterClassContext(tempfile.TemporaryDirectory()))
|
||||
server_path = base_path / "server"
|
||||
|
||||
self.generate_server(server_path)
|
||||
self.publish_module(self._module_name, capture_stderr=True, clear=True)
|
||||
return base_path / "client"
|
||||
|
||||
def generate_server(self, server_path: Path):
|
||||
"""Generate the server code from the quickstart documentation."""
|
||||
logging.info(f"Generating server code {self.lang}: {server_path}...")
|
||||
self.spacetime(
|
||||
"init",
|
||||
"--non-interactive",
|
||||
"--lang",
|
||||
self.lang,
|
||||
"--project-path",
|
||||
server_path,
|
||||
"spacetimedb-project",
|
||||
capture_stderr=True,
|
||||
)
|
||||
self.project_path = server_path / "spacetimedb"
|
||||
shutil.copy2(STDB_DIR / "rust-toolchain.toml", self.project_path)
|
||||
_write_file(self.project_path / self.server_file, _parse_quickstart(self.server_doc, self.lang, self._module_name, server=True))
|
||||
self.server_postprocess(self.project_path)
|
||||
self.spacetime("build", "-d", "-p", self.project_path, capture_stderr=True)
|
||||
|
||||
def server_postprocess(self, server_path: Path):
|
||||
"""Optional per-language hook."""
|
||||
pass
|
||||
|
||||
def check(self, input_cmd: str, client_path: Path, contains: str):
|
||||
"""Run the client command and check if the output contains the expected string."""
|
||||
output = run_cmd(*self.run_cmd, input=input_cmd, cwd=client_path, capture_stderr=True, text=True)
|
||||
print(f"Output for {self.lang} client:\n{output}")
|
||||
self.assertIn(contains, output)
|
||||
|
||||
def _test_quickstart(self):
|
||||
"""Run the quickstart client."""
|
||||
client_path = self._publish()
|
||||
self.project_init(client_path)
|
||||
self.sdk_setup(client_path)
|
||||
|
||||
run_cmd(*self.build_cmd, cwd=client_path, capture_stderr=True)
|
||||
|
||||
client_lang = self.client_lang or self.lang
|
||||
self.spacetime(
|
||||
"generate", "--lang", client_lang,
|
||||
"--out-dir", client_path / self.module_bindings,
|
||||
"--module-path", self.project_path, capture_stderr=True
|
||||
)
|
||||
# Replay the quickstart guide steps
|
||||
main = _parse_quickstart(self.client_doc, client_lang, self._module_name, server=False)
|
||||
for src, dst in self.replacements.items():
|
||||
main = main.replace(src, dst)
|
||||
main += "\n" + self.extra_code
|
||||
server = self.get_server_address()
|
||||
host = server["address"]
|
||||
protocol = server["protocol"]
|
||||
main = main.replace("http://localhost:3000", f"{protocol}://{host}")
|
||||
_write_file(client_path / self.client_file, main)
|
||||
|
||||
self.check("", client_path, self.connected_str)
|
||||
self.check("/name Alice", client_path, "Alice")
|
||||
self.check("Hello World", client_path, "Hello World")
|
||||
|
||||
|
||||
class Rust(BaseQuickstart):
|
||||
lang = "rust"
|
||||
server_doc = STDB_DIR / "docs/docs/00100-intro/00300-tutorials/00100-chat-app.md"
|
||||
client_doc = STDB_DIR / "docs/docs/00100-intro/00300-tutorials/00100-chat-app.md"
|
||||
server_file = "src/lib.rs"
|
||||
client_file = "src/main.rs"
|
||||
module_bindings = "src/module_bindings"
|
||||
run_cmd = ["cargo", "run"]
|
||||
build_cmd = ["cargo", "build"]
|
||||
|
||||
replacements = {
|
||||
# Replace the interactive user input to allow direct testing
|
||||
"user_input_loop(&ctx)": "user_input_direct(&ctx)",
|
||||
# Don't cache the token, because it will cause the test to fail if we run against a non-default server (because we don't cache the corresponding signing keypair)
|
||||
".with_token(creds_store()": "//.with_token(creds_store()"
|
||||
}
|
||||
extra_code = """
|
||||
fn user_input_direct(ctx: &DbConnection) {
|
||||
let mut line = String::new();
|
||||
std::io::stdin().read_line(&mut line).expect("Failed to read from stdin.");
|
||||
if let Some(name) = line.strip_prefix("/name ") {
|
||||
ctx.reducers.set_name(name.to_string()).unwrap();
|
||||
} else {
|
||||
ctx.reducers.send_message(line).unwrap();
|
||||
}
|
||||
std::thread::sleep(std::time::Duration::from_secs(1));
|
||||
std::process::exit(0);
|
||||
}
|
||||
"""
|
||||
connected_str = "connected"
|
||||
|
||||
def project_init(self, path: Path):
|
||||
run_cmd("cargo", "new", "--bin", "--name", "quickstart_chat_client", "client", cwd=path.parent,
|
||||
capture_stderr=True)
|
||||
|
||||
def sdk_setup(self, path: Path):
|
||||
sdk_rust_path = (STDB_DIR / "sdks/rust").absolute()
|
||||
sdk_rust_toml_escaped = str(sdk_rust_path).replace('\\', '\\\\\\\\') # double escape for re.sub + toml
|
||||
sdk_rust_toml = f'spacetimedb-sdk = {{ path = "{sdk_rust_toml_escaped}" }}\nlog = "0.4"\nhex = "0.4"\n'
|
||||
_append_to_file(path / "Cargo.toml", sdk_rust_toml)
|
||||
|
||||
def server_postprocess(self, server_path: Path):
|
||||
_write_file(server_path / "Cargo.toml", self.cargo_manifest(TEMPLATE_CARGO_TOML))
|
||||
|
||||
def test_quickstart(self):
|
||||
"""Run the Rust quickstart guides for server and client."""
|
||||
self._test_quickstart()
|
||||
|
||||
|
||||
class CSharp(BaseQuickstart):
|
||||
lang = "csharp"
|
||||
server_doc = STDB_DIR / "docs/docs/00100-intro/00300-tutorials/00100-chat-app.md"
|
||||
client_doc = STDB_DIR / "docs/docs/00100-intro/00300-tutorials/00100-chat-app.md"
|
||||
server_file = "Lib.cs"
|
||||
client_file = "Program.cs"
|
||||
module_bindings = "module_bindings"
|
||||
run_cmd = ["dotnet", "run"]
|
||||
build_cmd = ["dotnet", "build"]
|
||||
|
||||
# Replace the interactive user input to allow direct testing
|
||||
replacements = {
|
||||
"InputLoop();": "UserInputDirect();",
|
||||
".OnConnect(OnConnected)": ".OnConnect(OnConnectedSignal)",
|
||||
".OnConnectError(OnConnectError)": ".OnConnectError(OnConnectErrorSignal)",
|
||||
# Don't cache the token, because it will cause the test to fail if we run against a non-default server (because we don't cache the corresponding signing keypair)
|
||||
".WithToken(AuthToken.Token)": "//.WithToken(AuthToken.Token)",
|
||||
"Main();": "" # To put the main function at the end so it can see the new functions
|
||||
}
|
||||
# So we can wait for the connection to be established...
|
||||
extra_code = """
|
||||
var connectedEvent = new ManualResetEventSlim(false);
|
||||
var connectionFailed = new ManualResetEventSlim(false);
|
||||
void OnConnectErrorSignal(Exception e)
|
||||
{
|
||||
OnConnectError(e);
|
||||
connectionFailed.Set();
|
||||
}
|
||||
void OnConnectedSignal(DbConnection conn, Identity identity, string authToken)
|
||||
{
|
||||
OnConnected(conn, identity, authToken);
|
||||
connectedEvent.Set();
|
||||
}
|
||||
|
||||
void UserInputDirect() {
|
||||
string? line = Console.In.ReadToEnd()?.Trim();
|
||||
if (line == null) Environment.Exit(0);
|
||||
|
||||
if (!WaitHandle.WaitAny(
|
||||
new[] { connectedEvent.WaitHandle, connectionFailed.WaitHandle },
|
||||
TimeSpan.FromSeconds(5)
|
||||
).Equals(0))
|
||||
{
|
||||
Console.WriteLine("Failed to connect to server within timeout.");
|
||||
Environment.Exit(1);
|
||||
}
|
||||
|
||||
if (line.StartsWith("/name ")) {
|
||||
input_queue.Enqueue(("name", line[6..]));
|
||||
} else {
|
||||
input_queue.Enqueue(("message", line));
|
||||
}
|
||||
Thread.Sleep(1000);
|
||||
}
|
||||
Main();
|
||||
"""
|
||||
connected_str = "Connected"
|
||||
|
||||
def project_init(self, path: Path):
|
||||
run_cmd("dotnet", "new", "console", "--name", "QuickstartChatClient", "--output", path, capture_stderr=True)
|
||||
|
||||
def sdk_setup(self, path: Path):
|
||||
override_nuget_package(
|
||||
project_dir=STDB_DIR/"sdks/csharp",
|
||||
package="SpacetimeDB.BSATN.Runtime",
|
||||
source_dir=(STDB_DIR / "crates/bindings-csharp/BSATN.Runtime").absolute(),
|
||||
build_subdir="bin/Release"
|
||||
)
|
||||
# This one is only needed because the regression-tests subdir uses it
|
||||
override_nuget_package(
|
||||
project_dir=STDB_DIR/"sdks/csharp",
|
||||
package="SpacetimeDB.Runtime",
|
||||
source_dir=(STDB_DIR / "crates/bindings-csharp/Runtime").absolute(),
|
||||
build_subdir="bin/Release"
|
||||
)
|
||||
override_nuget_package(
|
||||
project_dir=path,
|
||||
package="SpacetimeDB.BSATN.Runtime",
|
||||
source_dir=(STDB_DIR / "crates/bindings-csharp/BSATN.Runtime").absolute(),
|
||||
build_subdir="bin/Release"
|
||||
)
|
||||
override_nuget_package(
|
||||
project_dir=path,
|
||||
package="SpacetimeDB.ClientSDK",
|
||||
source_dir=(STDB_DIR / "sdks/csharp").absolute(),
|
||||
build_subdir="bin~/Release"
|
||||
)
|
||||
run_cmd("dotnet", "add", "package", "SpacetimeDB.ClientSDK", cwd=path, capture_stderr=True)
|
||||
|
||||
def server_postprocess(self, server_path: Path):
|
||||
override_nuget_package(
|
||||
project_dir=server_path,
|
||||
package="SpacetimeDB.Runtime",
|
||||
source_dir=(STDB_DIR / "crates/bindings-csharp/Runtime").absolute(),
|
||||
build_subdir="bin/Release"
|
||||
)
|
||||
override_nuget_package(
|
||||
project_dir=server_path,
|
||||
package="SpacetimeDB.BSATN.Runtime",
|
||||
source_dir=(STDB_DIR / "crates/bindings-csharp/BSATN.Runtime").absolute(),
|
||||
build_subdir="bin/Release"
|
||||
)
|
||||
|
||||
def test_quickstart(self):
|
||||
"""Run the C# quickstart guides for server and client."""
|
||||
if not smoketests.HAVE_DOTNET:
|
||||
self.skipTest("C# SDK requires .NET to be installed.")
|
||||
self._test_quickstart()
|
||||
|
||||
# We use the Rust client for testing the TypeScript server quickstart because
|
||||
# the TypeScript client quickstart is a React app, which is difficult to
|
||||
# smoketest.
|
||||
class TypeScript(Rust):
|
||||
lang = "typescript"
|
||||
client_lang = "rust"
|
||||
server_doc = STDB_DIR / "docs/docs/00100-intro/00300-tutorials/00100-chat-app.md"
|
||||
server_file = "src/index.ts"
|
||||
|
||||
def server_postprocess(self, server_path: Path):
|
||||
build_typescript_sdk()
|
||||
# We already have spacetimedb as a depencency, but it's expecting to fetch from npm.
|
||||
# If we don't uninstall before installing, pnpm can panic because it can't find
|
||||
# the specified version on npm (even though we're about to override it anyway).
|
||||
pnpm("uninstall", 'spacetimedb', cwd=server_path)
|
||||
pnpm("install", TYPESCRIPT_BINDINGS_PATH, cwd=server_path)
|
||||
|
||||
def test_quickstart(self):
|
||||
"""Run the TypeScript quickstart guides for server."""
|
||||
self._test_quickstart()
|
||||
@@ -1,507 +0,0 @@
|
||||
import time
|
||||
import unittest
|
||||
from typing import Callable
|
||||
import json
|
||||
|
||||
from .. import COMPOSE_FILE, Smoketest, random_string, requires_docker, spacetime, parse_sql_result
|
||||
from ..docker import DockerManager
|
||||
|
||||
def retry(func: Callable, max_retries: int = 3, retry_delay: int = 2):
|
||||
"""Retry a function on failure with delay."""
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
return func()
|
||||
except Exception as e:
|
||||
if attempt < max_retries:
|
||||
print(f"Attempt {attempt} failed: {e}. Retrying in {retry_delay} seconds...")
|
||||
time.sleep(retry_delay)
|
||||
else:
|
||||
print("Max retries reached. Skipping the exception.")
|
||||
return False
|
||||
|
||||
def int_vals(rows: list[dict]) -> list[dict]:
|
||||
"""For all dicts in list, cast all values in dict to int."""
|
||||
return [{k: int(v) for k, v in row.items()} for row in rows]
|
||||
|
||||
class Cluster:
|
||||
"""Manages leader-related operations and state for SpaceTime database cluster."""
|
||||
|
||||
def __init__(self, docker_manager, smoketest: Smoketest):
|
||||
self.docker = docker_manager
|
||||
self.test = smoketest
|
||||
|
||||
# Ensure all containers are up.
|
||||
self.docker.compose("up", "-d")
|
||||
|
||||
def sql(self, sql: str) -> list[dict]:
|
||||
"""Query the test database."""
|
||||
res = self.test.sql(sql)
|
||||
return parse_sql_result(str(res))
|
||||
|
||||
def read_controldb(self, sql: str) -> list[dict]:
|
||||
"""Query the control database."""
|
||||
res = self.test.spacetime("sql", "spacetime-control", sql)
|
||||
return parse_sql_result(str(res))
|
||||
|
||||
def get_db_id(self):
|
||||
"""Query database ID."""
|
||||
sql = f"select id from database where database_identity=0x{self.test.database_identity}"
|
||||
res = self.read_controldb(sql)
|
||||
return int(res[0]['id'])
|
||||
|
||||
def get_all_replicas(self):
|
||||
"""Get all replica nodes in the cluster."""
|
||||
database_id = self.get_db_id()
|
||||
sql = f"select id, node_id from replica where database_id={database_id}"
|
||||
return int_vals(self.read_controldb(sql))
|
||||
|
||||
def get_leader_info(self):
|
||||
"""Get current leader's node information including ID, hostname, and container ID."""
|
||||
|
||||
database_id = self.get_db_id()
|
||||
sql = f""" \
|
||||
select node_v2.id, node_v2.network_addr from node_v2 \
|
||||
join replica on replica.node_id=node_v2.id \
|
||||
join replication_state on replication_state.leader=replica.id \
|
||||
where replication_state.database_id={database_id} \
|
||||
"""
|
||||
rows = self.read_controldb(sql)
|
||||
if not rows:
|
||||
raise Exception("Could not find current leader's node")
|
||||
|
||||
leader_node_id = int(rows[0]['id'])
|
||||
hostname = ""
|
||||
if "(some =" in rows[0]['network_addr']:
|
||||
address = rows[0]['network_addr'].split('"')[1]
|
||||
hostname = address.split(':')[0]
|
||||
|
||||
# Find container ID
|
||||
container_id = ""
|
||||
containers = self.docker.list_containers()
|
||||
for container in containers:
|
||||
if hostname in container.name:
|
||||
container_id = container.id
|
||||
break
|
||||
|
||||
return {
|
||||
'node_id': leader_node_id,
|
||||
'hostname': hostname,
|
||||
'container_id': container_id
|
||||
}
|
||||
|
||||
def wait_for_leader_change(self, previous_leader_node, max_attempts=10, delay=2):
|
||||
"""Wait for leader to change and return new leader node_id."""
|
||||
|
||||
for _ in range(max_attempts):
|
||||
try:
|
||||
current_leader_node = self.get_leader_info()['node_id']
|
||||
if current_leader_node != previous_leader_node:
|
||||
return current_leader_node
|
||||
except Exception:
|
||||
print("No current leader")
|
||||
|
||||
time.sleep(delay)
|
||||
return None
|
||||
|
||||
def ensure_leader_health(self, id):
|
||||
"""Verify leader is healthy by inserting a row."""
|
||||
|
||||
retry(lambda: self.test.call("start", id, 1))
|
||||
rows = self.sql(f"select id from counter where id={id}")
|
||||
if len(rows) < 1 or int(rows[0]['id']) != id:
|
||||
raise ValueError(f"Could not find {id} in counter table")
|
||||
# Wait for at least one tick to ensure buffers are flushed.
|
||||
# TODO: Replace with confirmed read.
|
||||
time.sleep(0.6)
|
||||
|
||||
def wait_counter_value(self, id, value, max_attempts=10, delay=1):
|
||||
"""Wait for the value for `id` in the counter table to reach `value`"""
|
||||
|
||||
for _ in range(max_attempts):
|
||||
rows = self.sql(f"select * from counter where id={id}")
|
||||
if len(rows) >= 1 and int(rows[0]['value']) >= value:
|
||||
return
|
||||
else:
|
||||
time.sleep(delay)
|
||||
|
||||
raise ValueError(f"Counter {id} below {value}")
|
||||
|
||||
|
||||
def fail_leader(self, action='kill'):
|
||||
"""Force leader failure through either killing or network disconnect."""
|
||||
leader_info = self.get_leader_info()
|
||||
container_id = leader_info['container_id']
|
||||
|
||||
if not container_id:
|
||||
raise ValueError("Could not find leader container")
|
||||
|
||||
if action == 'kill':
|
||||
self.docker.kill_container(container_id)
|
||||
elif action == 'disconnect':
|
||||
self.docker.disconnect_container(container_id)
|
||||
else:
|
||||
raise ValueError(f"Unknown action: {action}")
|
||||
|
||||
return container_id
|
||||
|
||||
def restore_leader(self, container_id, action='start'):
|
||||
"""Restore failed leader through either starting or network reconnect."""
|
||||
if action == 'start':
|
||||
self.docker.start_container(container_id)
|
||||
elif action == 'connect':
|
||||
self.docker.connect_container(container_id)
|
||||
else:
|
||||
raise ValueError(f"Unknown action: {action}")
|
||||
|
||||
@requires_docker
|
||||
class ReplicationTest(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{duration, ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = counter, public)]
|
||||
pub struct Counter {
|
||||
#[primary_key]
|
||||
#[auto_inc]
|
||||
id: u64,
|
||||
#[index(btree)]
|
||||
value: u64,
|
||||
}
|
||||
|
||||
#[spacetimedb::table(accessor = schedule_counter, public, scheduled(increment, at = sched_at))]
|
||||
pub struct ScheduledCounter {
|
||||
#[primary_key]
|
||||
#[auto_inc]
|
||||
scheduled_id: u64,
|
||||
sched_at: spacetimedb::ScheduleAt,
|
||||
count: u64,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
fn increment(ctx: &ReducerContext, arg: ScheduledCounter) {
|
||||
// if the counter exists, increment it
|
||||
if let Some(counter) = ctx.db.counter().id().find(arg.scheduled_id) {
|
||||
if counter.value == arg.count {
|
||||
ctx.db.schedule_counter().delete(arg);
|
||||
return;
|
||||
}
|
||||
// update counter
|
||||
ctx.db.counter().id().update(Counter {
|
||||
id: arg.scheduled_id,
|
||||
value: counter.value + 1,
|
||||
});
|
||||
} else {
|
||||
// insert fresh counter
|
||||
ctx.db.counter().insert(Counter {
|
||||
id: arg.scheduled_id,
|
||||
value: 1,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
fn start(ctx: &ReducerContext, id: u64, count: u64) {
|
||||
ctx.db.schedule_counter().insert(ScheduledCounter {
|
||||
scheduled_id: id,
|
||||
sched_at: duration!(0ms).into(),
|
||||
count,
|
||||
});
|
||||
}
|
||||
|
||||
#[spacetimedb::table(accessor = message, public)]
|
||||
pub struct Message {
|
||||
#[primary_key]
|
||||
#[auto_inc]
|
||||
id: u64,
|
||||
text: String
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
fn send_message(ctx: &ReducerContext, text: String) {
|
||||
ctx.db.message().insert(Message { id: 0, text });
|
||||
}
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.root_config = cls.project_path / "root_config"
|
||||
spacetime("--config-path", cls.root_config, "server", "set-default", "local")
|
||||
|
||||
def setUp(self):
|
||||
self.docker = DockerManager(COMPOSE_FILE)
|
||||
self.root_token = self.docker.generate_root_token()
|
||||
|
||||
self.cluster = Cluster(self.docker, self)
|
||||
|
||||
def tearDown(self):
|
||||
# Ensure containers that were brought down during a test are back up.
|
||||
self.docker.compose("up", "-d")
|
||||
super().tearDown()
|
||||
|
||||
def add_me_as_admin(self):
|
||||
"""Add the current user as an admin account"""
|
||||
db_owner_id = str(self.spacetime("login", "show")).split()[-1]
|
||||
spacetime("--config-path", self.root_config, "login", "--token", self.root_token)
|
||||
spacetime("--config-path", self.root_config, "call", "spacetime-control", "create_admin_account", f"0x{db_owner_id}")
|
||||
|
||||
def start(self, id: int, count: int):
|
||||
"""Send a message to the database."""
|
||||
retry(lambda: self.call("start", id, count))
|
||||
|
||||
def collect_counter_rows(self):
|
||||
return int_vals(self.cluster.sql("select * from counter"))
|
||||
|
||||
def call_control(self, reducer, *args):
|
||||
self.spacetime("call", "spacetime-control", reducer, *map(json.dumps, args))
|
||||
|
||||
|
||||
class LeaderElection(ReplicationTest):
|
||||
def test_leader_election_in_loop(self):
|
||||
"""This test fails a leader, wait for new leader to be elected and verify if commits replicated to new leader"""
|
||||
iterations = 5
|
||||
row_ids = [101 + i for i in range(iterations * 2)]
|
||||
for (first_id, second_id) in zip(row_ids[::2], row_ids[1::2]):
|
||||
cur_leader = self.cluster.wait_for_leader_change(None)
|
||||
print(f"ensure leader health {first_id}")
|
||||
self.cluster.ensure_leader_health(first_id)
|
||||
|
||||
print(f"killing current leader: {cur_leader}")
|
||||
container_id = self.cluster.fail_leader()
|
||||
|
||||
self.assertIsNotNone(container_id)
|
||||
|
||||
next_leader = self.cluster.wait_for_leader_change(cur_leader)
|
||||
self.assertNotEqual(cur_leader, next_leader)
|
||||
# this check if leader election happened
|
||||
print(f"ensure_leader_health {second_id}")
|
||||
self.cluster.ensure_leader_health(second_id)
|
||||
# restart the old leader, so that we can maintain quorum for next iteration
|
||||
print(f"reconnect leader {container_id}")
|
||||
self.cluster.restore_leader(container_id, 'start')
|
||||
|
||||
# Ensure we have a current leader
|
||||
last_row_id = row_ids[-1] + 1
|
||||
self.cluster.ensure_leader_health(row_ids[-1] + 1)
|
||||
row_ids.append(last_row_id)
|
||||
|
||||
# Verify that all inserted rows are present
|
||||
stored_row_ids = [row['id'] for row in self.collect_counter_rows()]
|
||||
self.assertEqual(set(stored_row_ids), set(row_ids))
|
||||
|
||||
class LeaderDisconnect(ReplicationTest):
|
||||
def test_leader_c_disconnect_in_loop(self):
|
||||
"""This test disconnects a leader, wait for new leader to be elected and verify if commits replicated to new leader"""
|
||||
|
||||
iterations = 5
|
||||
row_ids = [201 + i for i in range(iterations * 2)]
|
||||
|
||||
for (first_id, second_id) in zip(row_ids[::2], row_ids[1::2]):
|
||||
print(f"first={first_id} second={second_id}")
|
||||
cur_leader = self.cluster.wait_for_leader_change(None)
|
||||
print(f"ensure leader health {first_id}")
|
||||
self.cluster.ensure_leader_health(first_id)
|
||||
|
||||
print("disconnect current leader")
|
||||
container_id = self.cluster.fail_leader('disconnect')
|
||||
self.assertIsNotNone(container_id)
|
||||
print(f"disconnected leader's container is {container_id}")
|
||||
|
||||
next_leader = self.cluster.wait_for_leader_change(cur_leader)
|
||||
self.assertNotEqual(cur_leader, next_leader)
|
||||
# this check if leader election happened
|
||||
print(f"ensure_leader_health {second_id}")
|
||||
self.cluster.ensure_leader_health(second_id)
|
||||
|
||||
# restart the old leader, so that we can maintain quorum for next iteration
|
||||
print(f"reconnect leader {container_id}")
|
||||
self.cluster.restore_leader(container_id, 'connect')
|
||||
|
||||
# Ensure we have a current leader
|
||||
last_row_id = row_ids[-1] + 1
|
||||
self.cluster.ensure_leader_health(last_row_id)
|
||||
row_ids.append(last_row_id)
|
||||
|
||||
# Verify that all inserted rows are present
|
||||
stored_row_ids = [row['id'] for row in self.collect_counter_rows()]
|
||||
self.assertEqual(set(stored_row_ids), set(row_ids))
|
||||
|
||||
|
||||
@unittest.skip("drain_node not yet supported")
|
||||
class DrainLeader(ReplicationTest):
|
||||
def test_drain_leader_node(self):
|
||||
"""This test moves leader replica to different node"""
|
||||
self.add_me_as_admin()
|
||||
cur_leader_node_id = self.cluster.wait_for_leader_change(None)
|
||||
self.cluster.ensure_leader_health(301)
|
||||
|
||||
replicas = self.cluster.get_all_replicas()
|
||||
empty_node_id = 14
|
||||
for replica in replicas:
|
||||
empty_node_id = empty_node_id - replica['node_id']
|
||||
self.spacetime("call", "spacetime-control", "drain_node", f"{cur_leader_node_id}", f"{empty_node_id}")
|
||||
|
||||
time.sleep(5)
|
||||
self.cluster.ensure_leader_health(302)
|
||||
replicas = self.cluster.get_all_replicas()
|
||||
for replica in replicas:
|
||||
self.assertNotEqual(replica['node_id'], cur_leader_node_id)
|
||||
|
||||
|
||||
class PreferLeader(ReplicationTest):
|
||||
def test_prefer_leader(self):
|
||||
"""This test moves leader replica to different node"""
|
||||
self.add_me_as_admin()
|
||||
cur_leader_node_id = self.cluster.wait_for_leader_change(None)
|
||||
self.cluster.ensure_leader_health(401)
|
||||
|
||||
replicas = self.cluster.get_all_replicas()
|
||||
prefer_replica = {}
|
||||
for replica in replicas:
|
||||
if replica['node_id'] != cur_leader_node_id:
|
||||
prefer_replica = replica
|
||||
break
|
||||
prefer_replica_id = prefer_replica['id']
|
||||
self.spacetime("call", "spacetime-control", "prefer_leader", f"{prefer_replica_id}")
|
||||
|
||||
next_leader_node_id = self.cluster.wait_for_leader_change(cur_leader_node_id)
|
||||
self.cluster.ensure_leader_health(402)
|
||||
self.assertEqual(prefer_replica['node_id'], next_leader_node_id)
|
||||
|
||||
# verify if all past rows are present in new leader
|
||||
stored_row_ids = [row['id'] for row in self.collect_counter_rows()]
|
||||
self.assertEqual(set(stored_row_ids), set([401, 402]))
|
||||
|
||||
|
||||
class ManyTransactions(ReplicationTest):
|
||||
def test_a_many_transactions(self):
|
||||
"""This test sends many messages to the database and verifies that they are all present"""
|
||||
self.cluster.wait_for_leader_change(None)
|
||||
num_messages = 10000
|
||||
sub = self.subscribe("SELECT * FROM counter", n = num_messages)
|
||||
self.start(1, num_messages)
|
||||
|
||||
message_table = sub()[-1:]
|
||||
self.assertIn({
|
||||
'counter': {
|
||||
'deletes': [{'id': 1, 'value': num_messages - 1}],
|
||||
'inserts': [{'id': 1, 'value': num_messages}]
|
||||
}
|
||||
}, message_table)
|
||||
|
||||
|
||||
|
||||
class QuorumLoss(ReplicationTest):
|
||||
def test_quorum_loss(self):
|
||||
"""This test makes cluster to lose majority of followers to verify if leader eventually stop accepting writes"""
|
||||
|
||||
for i in range(11):
|
||||
self.call("send_message", f"{i}")
|
||||
|
||||
leader = self.cluster.get_leader_info()
|
||||
containers = self.docker.list_containers()
|
||||
for container in containers:
|
||||
if leader['container_id'] != container.id and "worker" in container.name:
|
||||
self.docker.kill_container(container.id)
|
||||
|
||||
time.sleep(2)
|
||||
with self.assertRaises(Exception):
|
||||
for i in range(1001):
|
||||
self.call("send_message", "terminal")
|
||||
|
||||
|
||||
class EnableReplicationTest(ReplicationTest):
|
||||
AUTOPUBLISH = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.expected_counter_rows = []
|
||||
|
||||
def run_counter(self, id, n = 100):
|
||||
self.start(id, n)
|
||||
self.cluster.wait_counter_value(id, n)
|
||||
self.expected_counter_rows.append({"id": id, "value": n})
|
||||
self.assertEqual(self.collect_counter_rows(), self.expected_counter_rows)
|
||||
|
||||
def subscribe_to_enable_replication_events(self):
|
||||
id = self.cluster.get_db_id()
|
||||
return self.subscribe(
|
||||
f"select * from staged_enable_replication_event where database_id={id}",
|
||||
n = 2,
|
||||
database = "spacetime-control"
|
||||
)
|
||||
|
||||
def assert_bootstrap_complete(self, sub):
|
||||
events = sub()
|
||||
self.assertEqual(
|
||||
events[-1]['staged_enable_replication_event']['inserts'][0]['message'],
|
||||
'bootstrap complete',
|
||||
)
|
||||
|
||||
def enable_replication(self, database_name):
|
||||
sub = self.subscribe_to_enable_replication_events()
|
||||
self.call_control("enable_replication", {"Name": database_name}, 3)
|
||||
self.assert_bootstrap_complete(sub)
|
||||
|
||||
class EnableReplicationUnsuspended(EnableReplicationTest):
|
||||
def test_enable_replication_fails_if_not_suspended(self):
|
||||
"""Tests that the database to enable replication on must be suspended"""
|
||||
|
||||
self.add_me_as_admin()
|
||||
name = random_string()
|
||||
|
||||
self.publish_module(name, num_replicas = 1)
|
||||
self.cluster.wait_for_leader_change(None)
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
self.call_control("enable_replication", {"Name": name}, 3)
|
||||
|
||||
|
||||
class EnableReplicationSuspended(EnableReplicationTest):
|
||||
def test_enable_replication_on_suspended_database(self):
|
||||
"""Tests that we can enable replication on a suspended database"""
|
||||
|
||||
self.add_me_as_admin()
|
||||
name = random_string()
|
||||
|
||||
self.publish_module(name, num_replicas = 1)
|
||||
self.cluster.wait_for_leader_change(None)
|
||||
self.cluster.ensure_leader_health(1)
|
||||
|
||||
self.call_control("suspend_database", {"Name": name})
|
||||
# Database is now unreachable.
|
||||
with self.assertRaises(Exception):
|
||||
self.call("send_message", "hi")
|
||||
|
||||
self.enable_replication(name)
|
||||
# Still unreachable until we call unsuspend.
|
||||
with self.assertRaises(Exception):
|
||||
self.call("send_message", "hi")
|
||||
|
||||
self.call_control("unsuspend_database", {"Name": name})
|
||||
self.cluster.wait_for_leader_change(None)
|
||||
self.cluster.ensure_leader_health(2)
|
||||
|
||||
class EnableDisableReplication(EnableReplicationTest):
|
||||
def test_enable_disable_replication(self):
|
||||
"""Tests that we can enable then disable replication"""
|
||||
|
||||
self.add_me_as_admin()
|
||||
name = random_string()
|
||||
|
||||
self.publish_module(name, num_replicas = 1)
|
||||
# ensure database is up and commitlog ends up non-empty
|
||||
self.run_counter(1, 100)
|
||||
|
||||
# suspend first
|
||||
self.call_control("suspend_database", {"Name": name})
|
||||
# enable replication and wait for it to complete
|
||||
self.enable_replication(name)
|
||||
# unsuspend
|
||||
self.call_control("unsuspend_database", {"Name": name})
|
||||
|
||||
self.cluster.wait_for_leader_change(None)
|
||||
self.run_counter(2, 100)
|
||||
|
||||
self.call_control("disable_replication", {"Name": name})
|
||||
self.run_counter(3, 100)
|
||||
@@ -1,158 +0,0 @@
|
||||
import logging
|
||||
|
||||
from .. import Smoketest, random_string
|
||||
|
||||
class Rls(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{Identity, ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = users, public)]
|
||||
pub struct Users {
|
||||
name: String,
|
||||
identity: Identity,
|
||||
}
|
||||
|
||||
#[spacetimedb::client_visibility_filter]
|
||||
const USER_FILTER: spacetimedb::Filter = spacetimedb::Filter::Sql(
|
||||
"SELECT * FROM users WHERE identity = :sender"
|
||||
);
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn add_user(ctx: &ReducerContext, name: String) {
|
||||
ctx.db.users().insert(Users { name, identity: ctx.sender() });
|
||||
}
|
||||
"""
|
||||
|
||||
def test_rls_rules(self):
|
||||
"""Tests for querying tables with RLS rules"""
|
||||
|
||||
# Insert an identity for Alice
|
||||
self.call("add_user", "Alice")
|
||||
|
||||
# Insert a new identity for Bob
|
||||
self.reset_config()
|
||||
self.new_identity()
|
||||
self.call("add_user", "Bob")
|
||||
|
||||
# Query the users table using Bob's identity
|
||||
self.assertSql("SELECT name FROM users", """\
|
||||
name
|
||||
-------
|
||||
"Bob"
|
||||
""")
|
||||
|
||||
# Query the users table using a new identity
|
||||
self.reset_config()
|
||||
self.new_identity()
|
||||
self.assertSql("SELECT name FROM users", """\
|
||||
name
|
||||
------
|
||||
""")
|
||||
|
||||
class BrokenRls(Smoketest):
|
||||
AUTOPUBLISH = False
|
||||
|
||||
MODULE_CODE_BROKEN = """
|
||||
use spacetimedb::{client_visibility_filter, Filter};
|
||||
|
||||
#[spacetimedb::table(accessor = user)]
|
||||
pub struct User {
|
||||
identity: Identity,
|
||||
}
|
||||
|
||||
#[client_visibility_filter]
|
||||
const PERSON_FILTER: Filter = Filter::Sql("SELECT * FROM \"user\" WHERE identity = :sender");
|
||||
"""
|
||||
|
||||
def test_publish_fails_for_rls_on_private_table(self):
|
||||
"""This tests that publishing an RLS rule on a private table fails"""
|
||||
|
||||
name = random_string()
|
||||
|
||||
self.write_module_code(self.MODULE_CODE_BROKEN)
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
self.publish_module(name)
|
||||
|
||||
class DisconnectRls(Smoketest):
|
||||
AUTOPUBLISH = False
|
||||
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{Identity, ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = users, public)]
|
||||
pub struct Users {
|
||||
name: String,
|
||||
identity: Identity,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn add_user(ctx: &ReducerContext, name: String) {
|
||||
ctx.db.users().insert(Users { name, identity: ctx.sender() });
|
||||
}
|
||||
"""
|
||||
|
||||
ADD_RLS = """
|
||||
#[spacetimedb::client_visibility_filter]
|
||||
const USER_FILTER: spacetimedb::Filter = spacetimedb::Filter::Sql(
|
||||
"SELECT * FROM users WHERE identity = :sender"
|
||||
);
|
||||
"""
|
||||
|
||||
def assertSql(self, sql, expected):
|
||||
self.maxDiff = None
|
||||
sql_out = self.spacetime("sql", self.database_identity, sql)
|
||||
sql_out = "\n".join([line.rstrip() for line in sql_out.splitlines()])
|
||||
expected = "\n".join([line.rstrip() for line in expected.splitlines()])
|
||||
self.assertMultiLineEqual(sql_out, expected)
|
||||
|
||||
def test_rls_disconnect_if_change(self):
|
||||
"""This tests that changing the RLS rules disconnects existing clients"""
|
||||
|
||||
name = random_string()
|
||||
|
||||
self.write_module_code(self.MODULE_CODE)
|
||||
|
||||
self.publish_module(name)
|
||||
logging.info("Initial publish complete")
|
||||
|
||||
# Now add the RLS rules
|
||||
self.write_module_code(self.MODULE_CODE + self.ADD_RLS)
|
||||
self.publish_module(name, clear=False, break_clients=True)
|
||||
|
||||
# Check the row-level SQL filter is added correctly
|
||||
self.assertSql(
|
||||
"SELECT sql FROM st_row_level_security",
|
||||
"""\
|
||||
sql
|
||||
------------------------------------------------
|
||||
"SELECT * FROM users WHERE identity = :sender"
|
||||
""",
|
||||
)
|
||||
|
||||
logging.info("Re-publish with RLS complete")
|
||||
|
||||
logs = self.logs(100)
|
||||
|
||||
# Validate disconnect + schema migration logs
|
||||
self.assertIn("Disconnecting all users", logs)
|
||||
|
||||
def test_rls_no_disconnect(self):
|
||||
"""This tests that not changing the RLS rules does not disconnect existing clients"""
|
||||
|
||||
name = random_string()
|
||||
|
||||
self.write_module_code(self.MODULE_CODE + self.ADD_RLS)
|
||||
|
||||
self.publish_module(name)
|
||||
logging.info("Initial publish complete")
|
||||
|
||||
# Now re-publish the same module code
|
||||
self.publish_module(name, clear=False, break_clients=False)
|
||||
|
||||
logging.info("Re-publish without RLS change complete")
|
||||
|
||||
logs = self.logs(100)
|
||||
|
||||
# Validate no disconnect logs
|
||||
self.assertNotIn("Disconnecting all users", logs)
|
||||
@@ -1,275 +0,0 @@
|
||||
from .. import Smoketest
|
||||
import time
|
||||
|
||||
|
||||
class CancelReducer(Smoketest):
|
||||
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{duration, log, ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::reducer(init)]
|
||||
fn init(ctx: &ReducerContext) {
|
||||
let schedule = ctx.db.scheduled_reducer_args().insert(ScheduledReducerArgs {
|
||||
num: 1,
|
||||
scheduled_id: 0,
|
||||
scheduled_at: duration!(100ms).into(),
|
||||
});
|
||||
ctx.db.scheduled_reducer_args().scheduled_id().delete(&schedule.scheduled_id);
|
||||
|
||||
let schedule = ctx.db.scheduled_reducer_args().insert(ScheduledReducerArgs {
|
||||
num: 2,
|
||||
scheduled_id: 0,
|
||||
scheduled_at: duration!(1000ms).into(),
|
||||
});
|
||||
do_cancel(ctx, schedule.scheduled_id);
|
||||
}
|
||||
|
||||
#[spacetimedb::table(accessor = scheduled_reducer_args, public, scheduled(reducer))]
|
||||
pub struct ScheduledReducerArgs {
|
||||
#[primary_key]
|
||||
#[auto_inc]
|
||||
scheduled_id: u64,
|
||||
scheduled_at: spacetimedb::ScheduleAt,
|
||||
num: i32,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
fn do_cancel(ctx: &ReducerContext, schedule_id: u64) {
|
||||
ctx.db.scheduled_reducer_args().scheduled_id().delete(&schedule_id);
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
fn reducer(_ctx: &ReducerContext, args: ScheduledReducerArgs) {
|
||||
log::info!("the reducer ran: {}", args.num);
|
||||
}
|
||||
"""
|
||||
|
||||
def test_cancel_reducer(self):
|
||||
"""Ensure cancelling a reducer works"""
|
||||
|
||||
time.sleep(2)
|
||||
logs = "\n".join(self.logs(5))
|
||||
self.assertNotIn("the reducer ran", logs)
|
||||
|
||||
|
||||
TIMESTAMP_ZERO = {"__timestamp_micros_since_unix_epoch__": 0}
|
||||
|
||||
|
||||
class SubscribeScheduledTable(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{log, duration, ReducerContext, Table, Timestamp};
|
||||
|
||||
#[spacetimedb::table(accessor = scheduled_table, public, scheduled(my_reducer, at = sched_at))]
|
||||
pub struct ScheduledTable {
|
||||
#[primary_key]
|
||||
#[auto_inc]
|
||||
scheduled_id: u64,
|
||||
sched_at: spacetimedb::ScheduleAt,
|
||||
prev: Timestamp,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
fn schedule_reducer(ctx: &ReducerContext) {
|
||||
ctx.db.scheduled_table().insert(ScheduledTable { prev: Timestamp::from_micros_since_unix_epoch(0), scheduled_id: 2, sched_at: Timestamp::from_micros_since_unix_epoch(0).into(), });
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
fn schedule_repeated_reducer(ctx: &ReducerContext) {
|
||||
ctx.db.scheduled_table().insert(ScheduledTable { prev: Timestamp::from_micros_since_unix_epoch(0), scheduled_id: 1, sched_at: duration!(100ms).into(), });
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn my_reducer(ctx: &ReducerContext, arg: ScheduledTable) {
|
||||
log::info!("Invoked: ts={:?}, delta={:?}", ctx.timestamp, ctx.timestamp.duration_since(arg.prev));
|
||||
}
|
||||
"""
|
||||
|
||||
def test_scheduled_table_subscription(self):
|
||||
"""This test deploys a module with a scheduled reducer and check if client receives subscription update for scheduled table entry and deletion of reducer once it ran"""
|
||||
# subscribe to empty scheduled_table
|
||||
sub = self.subscribe("SELECT * FROM scheduled_table", n=2)
|
||||
# call a reducer to schedule a reducer
|
||||
self.call("schedule_reducer")
|
||||
|
||||
time.sleep(2)
|
||||
lines = sum(1 for line in self.logs(100) if "Invoked:" in line)
|
||||
# scheduled reducer should be ran by now
|
||||
self.assertEqual(lines, 1)
|
||||
|
||||
row_entry = {
|
||||
"prev": TIMESTAMP_ZERO,
|
||||
"scheduled_id": 2,
|
||||
"sched_at": {"Time": TIMESTAMP_ZERO},
|
||||
}
|
||||
# subscription should have 2 updates, first for row insert in scheduled table and second for row deletion.
|
||||
self.assertEqual(
|
||||
sub(),
|
||||
[
|
||||
{"scheduled_table": {"deletes": [], "inserts": [row_entry]}},
|
||||
{"scheduled_table": {"deletes": [row_entry], "inserts": []}},
|
||||
],
|
||||
)
|
||||
|
||||
def test_scheduled_table_subscription_repeated_reducer(self):
|
||||
"""This test deploys a module with a repeated reducer and check if client receives subscription update for scheduled table entry and no delete entry"""
|
||||
# subscribe to empty scheduled_table
|
||||
sub = self.subscribe("SELECT * FROM scheduled_table", n=2)
|
||||
# call a reducer to schedule a reducer
|
||||
self.call("schedule_repeated_reducer")
|
||||
|
||||
time.sleep(2)
|
||||
lines = sum(1 for line in self.logs(100) if "Invoked:" in line)
|
||||
# repeated reducer should have run more than once.
|
||||
self.assertLess(2, lines)
|
||||
|
||||
# scheduling repeated reducer again just to get 2nd subscription update.
|
||||
self.call("schedule_reducer")
|
||||
|
||||
repeated_row_entry = {
|
||||
"prev": TIMESTAMP_ZERO,
|
||||
"scheduled_id": 1,
|
||||
"sched_at": {"Interval": {"__time_duration_micros__": 100000}},
|
||||
}
|
||||
row_entry = {
|
||||
"prev": TIMESTAMP_ZERO,
|
||||
"scheduled_id": 2,
|
||||
"sched_at": {"Time": TIMESTAMP_ZERO},
|
||||
}
|
||||
|
||||
# subscription should have 2 updates and should not have any deletes
|
||||
self.assertEqual(
|
||||
sub(),
|
||||
[
|
||||
{"scheduled_table": {"deletes": [], "inserts": [repeated_row_entry]}},
|
||||
{"scheduled_table": {"deletes": [], "inserts": [row_entry]}},
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class SubscribeScheduledProcedureTable(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{log, duration, ReducerContext, ProcedureContext, Table, Timestamp};
|
||||
|
||||
#[spacetimedb::table(accessor = scheduled_table, public, scheduled(my_procedure, at = sched_at))]
|
||||
pub struct ScheduledTable {
|
||||
#[primary_key]
|
||||
#[auto_inc]
|
||||
scheduled_id: u64,
|
||||
sched_at: spacetimedb::ScheduleAt,
|
||||
prev: Timestamp,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
fn schedule_procedure(ctx: &ReducerContext) {
|
||||
ctx.db.scheduled_table().insert(ScheduledTable { prev: Timestamp::from_micros_since_unix_epoch(0), scheduled_id: 2, sched_at: Timestamp::from_micros_since_unix_epoch(0).into(), });
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
fn schedule_repeated_procedure(ctx: &ReducerContext) {
|
||||
ctx.db.scheduled_table().insert(ScheduledTable { prev: Timestamp::from_micros_since_unix_epoch(0), scheduled_id: 1, sched_at: duration!(100ms).into(), });
|
||||
}
|
||||
|
||||
#[spacetimedb::procedure]
|
||||
pub fn my_procedure(ctx: &mut ProcedureContext, arg: ScheduledTable) {
|
||||
log::info!("Invoked: ts={:?}, delta={:?}", ctx.timestamp, ctx.timestamp.duration_since(arg.prev));
|
||||
}
|
||||
"""
|
||||
|
||||
def test_scheduled_table_subscription(self):
|
||||
"""This test deploys a module with a scheduled procedure and check if client receives subscription update for scheduled table entry and deletion of procedure once it ran"""
|
||||
# subscribe to empty scheduled_table
|
||||
sub = self.subscribe("SELECT * FROM scheduled_table", n=2)
|
||||
# call a reducer to schedule a procedure
|
||||
self.call("schedule_procedure")
|
||||
|
||||
time.sleep(2)
|
||||
lines = sum(1 for line in self.logs(100) if "Invoked:" in line)
|
||||
# scheduled procedure should be ran by now
|
||||
self.assertEqual(lines, 1)
|
||||
|
||||
row_entry = {
|
||||
"prev": TIMESTAMP_ZERO,
|
||||
"scheduled_id": 2,
|
||||
"sched_at": {"Time": TIMESTAMP_ZERO},
|
||||
}
|
||||
# subscription should have 2 updates, first for row insert in scheduled table and second for row deletion.
|
||||
self.assertEqual(
|
||||
sub(),
|
||||
[
|
||||
{"scheduled_table": {"deletes": [], "inserts": [row_entry]}},
|
||||
{"scheduled_table": {"deletes": [row_entry], "inserts": []}},
|
||||
],
|
||||
)
|
||||
|
||||
def test_scheduled_table_subscription_repeated_procedure(self):
|
||||
"""This test deploys a module with a repeated procedure and check if client receives subscription update for scheduled table entry and no delete entry"""
|
||||
# subscribe to empty scheduled_table
|
||||
sub = self.subscribe("SELECT * FROM scheduled_table", n=2)
|
||||
# call a reducer to schedule a procedure
|
||||
self.call("schedule_repeated_procedure")
|
||||
|
||||
time.sleep(2)
|
||||
lines = sum(1 for line in self.logs(100) if "Invoked:" in line)
|
||||
# repeated procedure should have run more than once.
|
||||
self.assertLess(2, lines)
|
||||
|
||||
# scheduling repeated procedure again just to get 2nd subscription update.
|
||||
self.call("schedule_procedure")
|
||||
|
||||
repeated_row_entry = {
|
||||
"prev": TIMESTAMP_ZERO,
|
||||
"scheduled_id": 1,
|
||||
"sched_at": {"Interval": {"__time_duration_micros__": 100000}},
|
||||
}
|
||||
row_entry = {
|
||||
"prev": TIMESTAMP_ZERO,
|
||||
"scheduled_id": 2,
|
||||
"sched_at": {"Time": TIMESTAMP_ZERO},
|
||||
}
|
||||
|
||||
# subscription should have 2 updates and should not have any deletes
|
||||
self.assertEqual(
|
||||
sub(),
|
||||
[
|
||||
{"scheduled_table": {"deletes": [], "inserts": [repeated_row_entry]}},
|
||||
{"scheduled_table": {"deletes": [], "inserts": [row_entry]}},
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class VolatileNonatomicScheduleImmediate(Smoketest):
|
||||
BINDINGS_FEATURES = ["unstable"]
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = my_table, public)]
|
||||
pub struct MyTable {
|
||||
x: String,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
fn do_schedule(_ctx: &ReducerContext) {
|
||||
spacetimedb::volatile_nonatomic_schedule_immediate!(do_insert("hello".to_owned()));
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
fn do_insert(ctx: &ReducerContext, x: String) {
|
||||
ctx.db.my_table().insert(MyTable { x });
|
||||
}
|
||||
"""
|
||||
|
||||
def test_volatile_nonatomic_schedule_immediate(self):
|
||||
"""Check that volatile_nonatomic_schedule_immediate works"""
|
||||
|
||||
sub = self.subscribe("SELECT * FROM my_table", n=2)
|
||||
|
||||
self.call("do_insert", "yay!")
|
||||
self.call("do_schedule")
|
||||
|
||||
self.assertEqual(
|
||||
sub(),
|
||||
[
|
||||
{"my_table": {"deletes": [], "inserts": [{"x": "yay!"}]}},
|
||||
{"my_table": {"deletes": [], "inserts": [{"x": "hello"}]}},
|
||||
],
|
||||
)
|
||||
@@ -1,36 +0,0 @@
|
||||
from .. import Smoketest, extract_field, requires_local_server
|
||||
import re
|
||||
|
||||
# We require a local server because these tests have hardcoded server addresses.
|
||||
@requires_local_server
|
||||
class Servers(Smoketest):
|
||||
AUTOPUBLISH = False
|
||||
|
||||
def test_servers(self):
|
||||
"""Verify that we can add and list server configurations"""
|
||||
|
||||
out = self.spacetime("server", "add", "--url", "https://testnet.spacetimedb.com", "testnet", "--no-fingerprint")
|
||||
self.assertEqual(extract_field(out, "Host:"), "testnet.spacetimedb.com")
|
||||
self.assertEqual(extract_field(out, "Protocol:"), "https")
|
||||
|
||||
servers = self.spacetime("server", "list")
|
||||
self.assertRegex(servers, re.compile(r"^\s*testnet\.spacetimedb\.com\s+https\s+testnet\s*$", re.M))
|
||||
self.assertRegex(servers, re.compile(r"^\s*\*\*\*\s+127\.0\.0\.1:3000\s+http\s+localhost\s*$", re.M))
|
||||
|
||||
out = self.spacetime("server", "fingerprint", "http://127.0.0.1:3000", "-y")
|
||||
self.assertIn("No saved fingerprint for server 127.0.0.1:3000.", out)
|
||||
|
||||
out = self.spacetime("server", "fingerprint", "http://127.0.0.1:3000")
|
||||
self.assertIn("Fingerprint is unchanged for server 127.0.0.1:3000", out)
|
||||
|
||||
out = self.spacetime("server", "fingerprint", "localhost")
|
||||
self.assertIn("Fingerprint is unchanged for server localhost", out)
|
||||
|
||||
def test_edit_server(self):
|
||||
"""Verify that we can edit server configurations"""
|
||||
|
||||
out = self.spacetime("server", "add", "--url", "https://foo.com", "foo", "--no-fingerprint")
|
||||
out = self.spacetime("server", "edit", "foo", "--url", "https://edited-testnet.spacetimedb.com", "--new-name", "edited-testnet", "--no-fingerprint", "--yes")
|
||||
|
||||
servers = self.spacetime("server", "list")
|
||||
self.assertRegex(servers, re.compile(r"^\s*edited-testnet\.spacetimedb\.com\s+https\s+edited-testnet\s*$", re.M))
|
||||
@@ -1,174 +0,0 @@
|
||||
from .. import Smoketest
|
||||
|
||||
|
||||
class SqlFormat(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::sats::{i256, u256};
|
||||
use spacetimedb::{ConnectionId, Identity, ReducerContext, Table, Timestamp, TimeDuration, SpacetimeType, Uuid};
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[spacetimedb::table(accessor = t_ints)]
|
||||
pub struct TInts {
|
||||
i8: i8,
|
||||
i16: i16,
|
||||
i32: i32,
|
||||
i64: i64,
|
||||
i128: i128,
|
||||
i256: i256,
|
||||
}
|
||||
|
||||
#[spacetimedb::table(accessor = t_ints_tuple)]
|
||||
pub struct TIntsTuple {
|
||||
tuple: TInts,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[spacetimedb::table(accessor = t_uints)]
|
||||
pub struct TUints {
|
||||
u8: u8,
|
||||
u16: u16,
|
||||
u32: u32,
|
||||
u64: u64,
|
||||
u128: u128,
|
||||
u256: u256,
|
||||
}
|
||||
|
||||
#[spacetimedb::table(accessor = t_uints_tuple)]
|
||||
pub struct TUintsTuple {
|
||||
tuple: TUints,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[spacetimedb::table(accessor = t_others)]
|
||||
pub struct TOthers {
|
||||
bool: bool,
|
||||
f32: f32,
|
||||
f64: f64,
|
||||
str: String,
|
||||
bytes: Vec<u8>,
|
||||
identity: Identity,
|
||||
connection_id: ConnectionId,
|
||||
timestamp: Timestamp,
|
||||
duration: TimeDuration,
|
||||
uuid: Uuid,
|
||||
}
|
||||
|
||||
#[spacetimedb::table(accessor = t_others_tuple)]
|
||||
pub struct TOthersTuple {
|
||||
tuple: TOthers
|
||||
}
|
||||
|
||||
#[derive(SpacetimeType, Debug, Clone, Copy)]
|
||||
pub enum Action {
|
||||
Inactive,
|
||||
Active,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[spacetimedb::table(accessor = t_enums)]
|
||||
pub struct TEnums {
|
||||
bool_opt: Option<bool>,
|
||||
bool_result: Result<bool, String>,
|
||||
action: Action,
|
||||
}
|
||||
|
||||
#[spacetimedb::table(accessor = t_enums_tuple)]
|
||||
pub struct TEnumsTuple {
|
||||
tuple: TEnums,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn test(ctx: &ReducerContext) {
|
||||
let tuple = TInts {
|
||||
i8: -25,
|
||||
i16: -3224,
|
||||
i32: -23443,
|
||||
i64: -2344353,
|
||||
i128: -234434897853,
|
||||
i256: (-234434897853i128).into(),
|
||||
};
|
||||
ctx.db.t_ints().insert(tuple);
|
||||
ctx.db.t_ints_tuple().insert(TIntsTuple { tuple });
|
||||
|
||||
let tuple = TUints {
|
||||
u8: 105,
|
||||
u16: 1050,
|
||||
u32: 83892,
|
||||
u64: 48937498,
|
||||
u128: 4378528978889,
|
||||
u256: 4378528978889u128.into(),
|
||||
};
|
||||
ctx.db.t_uints().insert(tuple);
|
||||
ctx.db.t_uints_tuple().insert(TUintsTuple { tuple });
|
||||
|
||||
let tuple = TOthers {
|
||||
bool: true,
|
||||
f32: 594806.58906,
|
||||
f64: -3454353.345389043278459,
|
||||
str: "This is spacetimedb".to_string(),
|
||||
bytes: vec!(1, 2, 3, 4, 5, 6, 7),
|
||||
identity: Identity::ONE,
|
||||
connection_id: ConnectionId::ZERO,
|
||||
timestamp: Timestamp::UNIX_EPOCH,
|
||||
duration: TimeDuration::ZERO,
|
||||
uuid: Uuid::NIL,
|
||||
};
|
||||
ctx.db.t_others().insert(tuple.clone());
|
||||
ctx.db.t_others_tuple().insert(TOthersTuple { tuple });
|
||||
|
||||
let tuple = TEnums {
|
||||
bool_opt: Some(true),
|
||||
bool_result: Ok(false),
|
||||
action: Action::Active,
|
||||
};
|
||||
|
||||
ctx.db.t_enums().insert(tuple.clone());
|
||||
ctx.db.t_enums_tuple().insert(TEnumsTuple { tuple });
|
||||
}
|
||||
"""
|
||||
|
||||
def test_sql_format(self):
|
||||
"""This test is designed to test the format of the output of sql queries"""
|
||||
|
||||
self.call("test")
|
||||
|
||||
self.assertSql("SELECT * FROM t_ints", """\
|
||||
i_8 | i_16 | i_32 | i_64 | i_128 | i_256
|
||||
-----+-------+--------+----------+---------------+---------------
|
||||
-25 | -3224 | -23443 | -2344353 | -234434897853 | -234434897853
|
||||
""")
|
||||
self.assertSql("SELECT * FROM t_ints_tuple", """\
|
||||
tuple
|
||||
---------------------------------------------------------------------------------------------------------
|
||||
(i_8 = -25, i_16 = -3224, i_32 = -23443, i_64 = -2344353, i_128 = -234434897853, i_256 = -234434897853)
|
||||
""")
|
||||
self.assertSql("SELECT * FROM t_uints", """\
|
||||
u_8 | u_16 | u_32 | u_64 | u_128 | u_256
|
||||
-----+------+-------+----------+---------------+---------------
|
||||
105 | 1050 | 83892 | 48937498 | 4378528978889 | 4378528978889
|
||||
""")
|
||||
self.assertSql("SELECT * FROM t_uints_tuple", """\
|
||||
tuple
|
||||
-------------------------------------------------------------------------------------------------------
|
||||
(u_8 = 105, u_16 = 1050, u_32 = 83892, u_64 = 48937498, u_128 = 4378528978889, u_256 = 4378528978889)
|
||||
""")
|
||||
self.assertSql("SELECT * FROM t_others", """\
|
||||
bool | f_32 | f_64 | str | bytes | identity | connection_id | timestamp | duration | uuid
|
||||
------+-----------+--------------------+-----------------------+------------------+--------------------------------------------------------------------+------------------------------------+---------------------------+-----------+----------------------------------------
|
||||
true | 594806.56 | -3454353.345389043 | "This is spacetimedb" | 0x01020304050607 | 0x0000000000000000000000000000000000000000000000000000000000000001 | 0x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | +0.000000 | "00000000-0000-0000-0000-000000000000"
|
||||
""")
|
||||
self.assertSql("SELECT * FROM t_others_tuple", """\
|
||||
tuple
|
||||
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
(bool = true, f_32 = 594806.56, f_64 = -3454353.345389043, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000001, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = +0.000000, uuid = "00000000-0000-0000-0000-000000000000")
|
||||
""")
|
||||
self.assertSql("SELECT * FROM t_enums", """\
|
||||
bool_opt | bool_result | action
|
||||
---------------+--------------+---------------
|
||||
(some = true) | (ok = false) | (active = ())
|
||||
""")
|
||||
self.assertSql("SELECT * FROM t_enums_tuple", """\
|
||||
tuple
|
||||
--------------------------------------------------------------------------------
|
||||
(bool_opt = (some = true), bool_result = (ok = false), action = (active = ()))
|
||||
""")
|
||||
@@ -1,558 +0,0 @@
|
||||
import json
|
||||
import toml
|
||||
|
||||
from .. import COMPOSE_FILE, Smoketest, parse_sql_result, random_string, spacetime
|
||||
from ..docker import DockerManager
|
||||
from ..tests.replication import Cluster
|
||||
|
||||
OWNER = "Owner"
|
||||
ADMIN = "Admin"
|
||||
DEVELOPER = "Developer"
|
||||
VIEWER = "Viewer"
|
||||
|
||||
ROLES = [OWNER, ADMIN, DEVELOPER, VIEWER]
|
||||
|
||||
def get(d: dict, k):
|
||||
return (k, d[k])
|
||||
|
||||
class CreateChildDatabase(Smoketest):
|
||||
AUTOPUBLISH = False
|
||||
|
||||
def test_create_child_database(self):
|
||||
"""
|
||||
Test that the owner can add a child database,
|
||||
and that deleting the parent also deletes the child.
|
||||
"""
|
||||
|
||||
parent_name = random_string()
|
||||
child_name = random_string()
|
||||
|
||||
self.publish_module(parent_name)
|
||||
parent_identity = self.database_identity
|
||||
self.publish_module(f"{parent_name}/{child_name}")
|
||||
child_identity = self.database_identity
|
||||
|
||||
databases = self.query_controldb(parent_identity, child_identity)
|
||||
self.assertEqual(2, len(databases))
|
||||
|
||||
self.spacetime("delete", "--yes", parent_name)
|
||||
|
||||
databases = self.query_controldb(parent_identity, child_identity)
|
||||
self.assertEqual(0, len(databases))
|
||||
|
||||
def query_controldb(self, parent, child):
|
||||
res = self.spacetime(
|
||||
"sql",
|
||||
"spacetime-control",
|
||||
f"select * from database where database_identity = 0x{parent} or database_identity = 0x{child}"
|
||||
)
|
||||
return parse_sql_result(str(res))
|
||||
|
||||
|
||||
class ChangeDatabaseHierarchy(Smoketest):
|
||||
AUTOPUBLISH = False
|
||||
|
||||
def test_change_database_hierarchy(self):
|
||||
"""
|
||||
Test that changing the hierarchy of an existing database is not
|
||||
supported.
|
||||
"""
|
||||
|
||||
parent_name = f"parent-{random_string()}"
|
||||
sibling_name = f"sibling-{random_string()}"
|
||||
child_name = f"child-{random_string()}"
|
||||
|
||||
self.publish_module(parent_name)
|
||||
self.publish_module(sibling_name)
|
||||
|
||||
# Publish as a child of 'parent_name'.
|
||||
self.publish_module(f"{parent_name}/{child_name}")
|
||||
|
||||
# Publishing again with a different parent is rejected...
|
||||
with self.assertRaises(Exception):
|
||||
self.publish_module(f"{sibling_name}/{child_name}", clear = False)
|
||||
# ..even if `clear = True`
|
||||
with self.assertRaises(Exception):
|
||||
self.publish_module(f"{sibling_name}/{child_name}", clear = True)
|
||||
|
||||
# Publishing again with the same parent is ok.
|
||||
self.publish_module(f"{parent_name}/{child_name}", clear = False)
|
||||
|
||||
|
||||
class TeamsPermissionsTest(Smoketest):
|
||||
AUTOPUBLISH = False
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.root_config = cls.project_path / "root_config"
|
||||
spacetime("--config-path", cls.root_config, "server", "set-default", "local")
|
||||
|
||||
def setUp(self):
|
||||
self.docker = DockerManager(COMPOSE_FILE)
|
||||
self.root_token = self.docker.generate_root_token()
|
||||
|
||||
self.cluster = Cluster(self.docker, self)
|
||||
|
||||
def create_identity(self):
|
||||
"""
|
||||
Obtain a fresh identity and token from the server.
|
||||
Doesn't alter the config.toml for this test instance.
|
||||
"""
|
||||
resp = self.api_call("POST", "/v1/identity")
|
||||
return json.loads(resp)
|
||||
|
||||
def create_collaborators(self, database):
|
||||
"""
|
||||
Create collaborators for the current database, one for each role.
|
||||
"""
|
||||
collaborators = {}
|
||||
for role in ROLES:
|
||||
identity_and_token = self.create_identity()
|
||||
self.call_controldb_reducer(
|
||||
"upsert_collaborator",
|
||||
{"Name": database},
|
||||
[f"0x{identity_and_token['identity']}"],
|
||||
{role: {}}
|
||||
)
|
||||
collaborators[role] = identity_and_token
|
||||
return collaborators
|
||||
|
||||
def create_organization(self):
|
||||
"""
|
||||
Create an organization with one member per role.
|
||||
"""
|
||||
members = {}
|
||||
organization_identity = self.create_identity()['identity']
|
||||
for role in ROLES:
|
||||
member = self.create_identity()
|
||||
self.call_controldb_reducer(
|
||||
"upsert_organization_member",
|
||||
[f"0x{organization_identity}"],
|
||||
[f"0x{member['identity']}"],
|
||||
{role: {}}
|
||||
)
|
||||
members[role] = member
|
||||
organization = {
|
||||
"organization": organization_identity,
|
||||
"members": members
|
||||
}
|
||||
self.organization = organization
|
||||
|
||||
return organization
|
||||
|
||||
def make_admin(self):
|
||||
"""
|
||||
Create an admin account for the currently logged-in identity.
|
||||
"""
|
||||
identity = str(self.spacetime("login", "show")).split()[-1]
|
||||
spacetime("--config-path", self.root_config, "login", "--token", self.root_token)
|
||||
spacetime("--config-path", self.root_config, "call",
|
||||
"spacetime-control", "create_admin_account", f"0x{identity}")
|
||||
|
||||
def call_controldb_reducer(self, reducer, *args):
|
||||
"""
|
||||
Call a controldb reducer.
|
||||
"""
|
||||
self.spacetime("call", "spacetime-control", reducer, *map(json.dumps, args))
|
||||
|
||||
def login_with(self, identity_and_token: dict):
|
||||
self.spacetime("logout")
|
||||
config = toml.load(self.config_path)
|
||||
config['spacetimedb_token'] = identity_and_token['token']
|
||||
with open(self.config_path, 'w') as f:
|
||||
toml.dump(config, f)
|
||||
|
||||
def publish_as(self, role_and_token, module, code = None, clear = False, org = None):
|
||||
print(f"publishing {module} with org {org} as {role_and_token[0]}:")
|
||||
code = self.MODULE_CODE if code is None else code
|
||||
print(f"{code}")
|
||||
self.login_with(role_and_token[1])
|
||||
self.write_module_code(code)
|
||||
self.publish_module(module, clear = clear, organization = org)
|
||||
return self.database_identity
|
||||
|
||||
def sql_as(self, role_and_token, database, sql):
|
||||
"""
|
||||
Log in as `token` and run an SQL statement against `database`
|
||||
"""
|
||||
print(f"running sql as {role_and_token[0]}: {sql}")
|
||||
self.login_with(role_and_token[1])
|
||||
res = self.spacetime("sql", database, sql)
|
||||
return parse_sql_result(str(res))
|
||||
|
||||
def subscribe_as(self, role_and_token, *queries, n):
|
||||
"""
|
||||
Log in as `token` and subscribe to the current database using `queries`.
|
||||
"""
|
||||
print(f"subscribe as {role_and_token[0]}: {queries}")
|
||||
self.login_with(role_and_token[1])
|
||||
return self.subscribe(*queries, n = n)
|
||||
|
||||
def tearDown(self):
|
||||
if "organization" in self.__dict__:
|
||||
# Log in as owner
|
||||
self.login_with(self.organization['members'][OWNER])
|
||||
# Delete database (requires org to still exist)
|
||||
super().tearDown()
|
||||
# Delete org
|
||||
try:
|
||||
self.call_controldb_reducer(
|
||||
"delete_organization",
|
||||
[f"0x{self.organization['organization']}"]
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
super().tearDown()
|
||||
|
||||
|
||||
class TeamsMutableSql(TeamsPermissionsTest):
|
||||
MODULE_CODE = """
|
||||
#[spacetimedb::table(accessor = person, public)]
|
||||
struct Person {
|
||||
name: String,
|
||||
}
|
||||
"""
|
||||
|
||||
def run_test(self, database, team):
|
||||
for role, token in team.items():
|
||||
self.login_with(token)
|
||||
dml = f"insert into person (name) values ('bob-the-{role}')"
|
||||
if role == OWNER or role == ADMIN:
|
||||
self.spacetime("sql", database, dml)
|
||||
else:
|
||||
with self.assertRaises(Exception):
|
||||
self.spacetime("sql", database, dml)
|
||||
|
||||
class CollaboratorsMutableSql(TeamsMutableSql):
|
||||
def test_permissions_mut_sql_collaborators(self):
|
||||
"""
|
||||
Tests that only owner and admin collaborators can perform mutable SQL
|
||||
transactions.
|
||||
"""
|
||||
|
||||
name = random_string()
|
||||
self.publish_module(name)
|
||||
team = self.create_collaborators(name)
|
||||
self.run_test(name, team)
|
||||
|
||||
|
||||
class OrgMutableSql(TeamsMutableSql):
|
||||
def test_org_permissions_mut_sql_org_members(self):
|
||||
"""
|
||||
Tests that only owner and admin organization members can perform mutable
|
||||
SQL transactions.
|
||||
"""
|
||||
|
||||
self.make_admin()
|
||||
org = self.create_organization()
|
||||
name = random_string()
|
||||
|
||||
self.login_with(org['members'][OWNER])
|
||||
self.publish_module(name, organization = f"0x{org['organization']}")
|
||||
|
||||
self.run_test(name, org['members'])
|
||||
|
||||
|
||||
class TeamsPublishDatabase(TeamsPermissionsTest):
|
||||
MODULE_CODE = """
|
||||
#[spacetimedb::table(accessor = person, public)]
|
||||
struct Person {
|
||||
name: String,
|
||||
}
|
||||
"""
|
||||
|
||||
MODULE_CODE_OWNER = MODULE_CODE + """
|
||||
#[spacetimedb::table(accessor = owner)]
|
||||
struct Owner {
|
||||
name: String,
|
||||
}
|
||||
"""
|
||||
|
||||
MODULE_CODE_ADMIN = MODULE_CODE_OWNER + """
|
||||
#[spacetimedb::table(accessor = admin)]
|
||||
struct Admin {
|
||||
name: String,
|
||||
}
|
||||
"""
|
||||
|
||||
MODULE_CODE_DEVELOPER = MODULE_CODE_ADMIN + """
|
||||
#[spacetimedb::table(accessor = developer)]
|
||||
struct Developer {
|
||||
name: String,
|
||||
}
|
||||
"""
|
||||
|
||||
MODULE_CODE_VIEWER = MODULE_CODE_DEVELOPER + """
|
||||
#[spacetimedb::table(accessor = viewer)]
|
||||
struct Viewer {
|
||||
name: String,
|
||||
}
|
||||
"""
|
||||
|
||||
MODULES = {
|
||||
OWNER: MODULE_CODE_OWNER,
|
||||
ADMIN: MODULE_CODE_ADMIN,
|
||||
DEVELOPER: MODULE_CODE_DEVELOPER,
|
||||
VIEWER: MODULE_CODE_VIEWER
|
||||
}
|
||||
|
||||
def run_test(self, parent, child, team, org):
|
||||
self.assert_all_except_viewer_can_update(parent, team, org = org)
|
||||
|
||||
# Create a child database.
|
||||
child_path = f"{parent}/{child}"
|
||||
|
||||
# Developer and viewer should not be able to create a child.
|
||||
for role in [DEVELOPER, VIEWER]:
|
||||
with self.assertRaises(Exception):
|
||||
self.publish_as(get(team, role), child_path, self.MODULE_CODE, org = org)
|
||||
# But admin should succeed.
|
||||
self.publish_as(get(team, ADMIN), child_path, self.MODULE_CODE, org = org)
|
||||
|
||||
# Once created, only viewer should be denied updating.
|
||||
self.assert_all_except_viewer_can_update(child_path, team, org)
|
||||
|
||||
def assert_all_except_viewer_can_update(self, database, team, org):
|
||||
for role in [OWNER, ADMIN, DEVELOPER]:
|
||||
self.publish_as(get(team, role), database, self.MODULES[role], org = org)
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
self.publish_as(get(team, VIEWER), database, self.MODULES[VIEWER], org = org)
|
||||
|
||||
|
||||
class CollaboratorsPublishDatabase(TeamsPublishDatabase):
|
||||
def test_permissions_publish_collaborators(self):
|
||||
"""
|
||||
Tests that only owner, admin and developer collaborators can publish a
|
||||
database.
|
||||
"""
|
||||
|
||||
parent = random_string()
|
||||
child = random_string()
|
||||
self.publish_module(parent)
|
||||
team = self.create_collaborators(parent)
|
||||
|
||||
self.run_test(parent, child, team, org = None)
|
||||
|
||||
|
||||
class OrgPublishDatabase(TeamsPublishDatabase):
|
||||
def test_permissions_publish_org_members(self):
|
||||
"""
|
||||
Tests that only owner, admin and developer organization members can
|
||||
publish a database.
|
||||
"""
|
||||
|
||||
self.make_admin()
|
||||
org = self.create_organization()
|
||||
parent = random_string()
|
||||
child = random_string()
|
||||
|
||||
self.login_with(org['members'][OWNER])
|
||||
self.publish_module(parent, organization = f"0x{org['organization']}")
|
||||
|
||||
self.run_test(parent, child, org['members'],
|
||||
org = org['organization'])
|
||||
|
||||
|
||||
class TeamsClearDatabase(TeamsPermissionsTest):
|
||||
def assert_can_clear(self, auth, database):
|
||||
self.publish_as(auth, database, clear = True)
|
||||
|
||||
def assert_cannot_clear(self, auth, database):
|
||||
with self.assertRaises(Exception):
|
||||
self.publish_as(auth, database, clear = True)
|
||||
|
||||
def assert_clear_permissions(self, team, database):
|
||||
for role in [OWNER, ADMIN]:
|
||||
self.assert_can_clear(get(team, role), database)
|
||||
|
||||
for role in [DEVELOPER, VIEWER]:
|
||||
self.assert_cannot_clear(get(team, role), database)
|
||||
|
||||
|
||||
class CollaboratorsClearDatabase(TeamsClearDatabase):
|
||||
def test_permissions_clear_collaborators(self):
|
||||
"""
|
||||
Tests that only owner and admin collaborators can clear a database.
|
||||
"""
|
||||
|
||||
parent = random_string()
|
||||
self.publish_module(parent)
|
||||
# First degree owner can clear.
|
||||
self.publish_module(parent, clear = True)
|
||||
|
||||
team = self.create_collaborators(parent)
|
||||
self.assert_clear_permissions(team, parent)
|
||||
|
||||
# Child databases cannot be cleared at all
|
||||
child = f"{parent}/{random_string()}"
|
||||
self.publish_as(get(team, OWNER), child)
|
||||
for auth in team.items():
|
||||
self.assert_cannot_clear(auth, child)
|
||||
|
||||
|
||||
class OrgClearDatabase(TeamsClearDatabase):
|
||||
def test_permissions_clear_org(self):
|
||||
"""
|
||||
Test that only owner or admin org members can clear a database.
|
||||
"""
|
||||
|
||||
self.make_admin()
|
||||
org = self.create_organization()
|
||||
team = org['members']
|
||||
|
||||
parent = random_string()
|
||||
|
||||
self.login_with(org['members'][OWNER])
|
||||
self.publish_module(parent, organization = f"0x{org['organization']}")
|
||||
self.assert_clear_permissions(team, parent)
|
||||
|
||||
# Child databases cannot be cleared at all
|
||||
child = f"{parent}/{random_string()}"
|
||||
self.publish_as(get(team, ADMIN), child)
|
||||
for auth in team.items():
|
||||
self.assert_cannot_clear(auth, child)
|
||||
|
||||
|
||||
class TeamsDeleteDatabase(TeamsPermissionsTest):
|
||||
def delete_as(self, role_and_token, database):
|
||||
print(f"delete {database} as {role_and_token[0]}")
|
||||
self.login_with(role_and_token[1])
|
||||
self.spacetime("delete", "--yes", database)
|
||||
|
||||
|
||||
class CollaboratorsDeleteDatabase(TeamsDeleteDatabase):
|
||||
def test_permissions_delete_collaborators(self):
|
||||
"""
|
||||
Tests that only owners can delete databases.
|
||||
"""
|
||||
|
||||
parent = random_string()
|
||||
child = random_string()
|
||||
self.publish_module(parent)
|
||||
self.spacetime("delete", "--yes", parent)
|
||||
|
||||
self.publish_module(parent)
|
||||
|
||||
team = self.create_collaborators(parent)
|
||||
for role in [ADMIN, DEVELOPER, VIEWER]:
|
||||
with self.assertRaises(Exception):
|
||||
self.delete_as(get(team, role), parent)
|
||||
|
||||
child_path = f"{parent}/{child}"
|
||||
|
||||
# If admin creates a child, they should also be able to delete it,
|
||||
# because they are the owner of the child.
|
||||
print("publish and delete as admin")
|
||||
self.publish_as(get(team, ADMIN), child_path)
|
||||
self.delete_as(get(team, ADMIN), child)
|
||||
|
||||
# The owner role should be able to delete.
|
||||
print("publish as admin, delete as owner")
|
||||
self.publish_as(get(team, ADMIN), child_path)
|
||||
self.delete_as(get(team, OWNER), child)
|
||||
|
||||
# Anyone else should be denied if not direct owner.
|
||||
print("publish as owner, deny deletion by admin, developer, viewer")
|
||||
self.publish_as(get(team, OWNER), child_path)
|
||||
for role in [ADMIN, DEVELOPER, VIEWER]:
|
||||
with self.assertRaises(Exception):
|
||||
self.delete_as(get(team, role), child)
|
||||
|
||||
print("delete child as owner")
|
||||
self.delete_as(get(team, OWNER), child)
|
||||
|
||||
print("delete parent as owner")
|
||||
self.delete_as(get(team, OWNER), parent)
|
||||
|
||||
|
||||
class OrgDeleteDatabase(TeamsDeleteDatabase):
|
||||
def test_permissions_delete_org(self):
|
||||
"""
|
||||
Tests that only organization owners can delete databases.
|
||||
"""
|
||||
|
||||
self.make_admin()
|
||||
org = self.create_organization()
|
||||
team = org['members']
|
||||
parent = random_string()
|
||||
child = random_string()
|
||||
|
||||
self.login_with(org['members'][OWNER])
|
||||
self.publish_module(parent, organization = f"0x{org['organization']}")
|
||||
self.publish_module(f"{parent}/{child}")
|
||||
|
||||
# Org databases can only be deleted by owners
|
||||
# because ownership is transferred to the org
|
||||
# and publisher attribution is lost.
|
||||
for database in [child, parent]:
|
||||
for role in [ADMIN, DEVELOPER, VIEWER]:
|
||||
with self.assertRaises(Exception):
|
||||
self.delete_as(get(team, role), database)
|
||||
|
||||
self.delete_as(get(team, OWNER), child)
|
||||
self.delete_as(get(team, OWNER), parent)
|
||||
|
||||
|
||||
class TeamsPrivateTables(TeamsPermissionsTest):
|
||||
def run_test(self, database, team):
|
||||
owner = get(team, OWNER)
|
||||
self.sql_as(owner, database, "insert into person (name) values ('horsti')")
|
||||
|
||||
for auth in team.items():
|
||||
rows = self.sql_as(auth, database, "select * from person")
|
||||
self.assertEqual(rows, [{ "name": '"horsti"' }])
|
||||
|
||||
for auth in team.items():
|
||||
sub = self.subscribe_as(auth, "select * from person", n = 2)
|
||||
self.sql_as(owner, database, "insert into person (name) values ('hansmans')")
|
||||
self.sql_as(owner, database, "delete from person where name = 'hansmans'")
|
||||
res = sub()
|
||||
self.assertEqual(
|
||||
res,
|
||||
[
|
||||
{
|
||||
'person': {
|
||||
'deletes': [],
|
||||
'inserts': [{'name': 'hansmans'}]
|
||||
}
|
||||
},
|
||||
{
|
||||
'person': {
|
||||
'deletes': [{'name': 'hansmans'}],
|
||||
'inserts': []
|
||||
}
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class CollaboratorsPrivateTables(TeamsPrivateTables):
|
||||
def test_permissions_private_tables(self):
|
||||
"""
|
||||
Test that all collaborators can read private tables.
|
||||
"""
|
||||
|
||||
database = random_string()
|
||||
self.publish_module(database)
|
||||
|
||||
team = self.create_collaborators(database)
|
||||
self.run_test(database, team)
|
||||
|
||||
|
||||
class OrgPrivateTables(TeamsPrivateTables):
|
||||
def test_org_permissions_private_tables(self):
|
||||
"""
|
||||
Test that all organization members can read private tables.
|
||||
"""
|
||||
|
||||
self.make_admin()
|
||||
org = self.create_organization()
|
||||
database = random_string()
|
||||
|
||||
self.login_with(org['members'][OWNER])
|
||||
self.publish_module(database, organization = f"0x{org['organization']}")
|
||||
|
||||
self.run_test(database, org['members'])
|
||||
@@ -1,35 +0,0 @@
|
||||
from .. import Smoketest, random_string
|
||||
import unittest
|
||||
import json
|
||||
import io
|
||||
|
||||
TIMESTAMP_TAG = "__timestamp_micros_since_unix_epoch__"
|
||||
|
||||
class TimestampRoute(Smoketest):
|
||||
AUTO_PUBLISH = False
|
||||
|
||||
def test_timestamp_route(self):
|
||||
name = random_string()
|
||||
|
||||
# A request for the timestamp at a non-existent database is an error...
|
||||
with self.assertRaises(Exception) as err:
|
||||
self.api_call(
|
||||
"GET",
|
||||
f"/v1/database/{name}/unstable/timestamp",
|
||||
)
|
||||
# ... with code 404.
|
||||
self.assertEqual(err.exception.args[0].status, 404)
|
||||
|
||||
self.publish_module(name)
|
||||
|
||||
# A request for the timestamp at an extant database is a success...
|
||||
resp = self.api_call(
|
||||
"GET",
|
||||
f"/v1/database/{name}/unstable/timestamp",
|
||||
)
|
||||
|
||||
# ... and the response body is a SATS-JSON encoded `Timestamp`.
|
||||
timestamp = json.load(io.BytesIO(resp))
|
||||
self.assertIsInstance(timestamp, dict)
|
||||
self.assertIn(TIMESTAMP_TAG, timestamp)
|
||||
self.assertIsInstance(timestamp[TIMESTAMP_TAG], int)
|
||||
@@ -1,822 +0,0 @@
|
||||
from .. import Smoketest, random_string
|
||||
|
||||
class Views(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::ViewContext;
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[spacetimedb::table(accessor = player_state)]
|
||||
pub struct PlayerState {
|
||||
#[primary_key]
|
||||
id: u64,
|
||||
#[index(btree)]
|
||||
level: u64,
|
||||
}
|
||||
|
||||
#[spacetimedb::view(accessor = player, public)]
|
||||
pub fn player(ctx: &ViewContext) -> Option<PlayerState> {
|
||||
ctx.db.player_state().id().find(0u64)
|
||||
}
|
||||
"""
|
||||
|
||||
def test_st_view_tables(self):
|
||||
"""This test asserts that views populate the st_view_* system tables"""
|
||||
|
||||
self.assertSql("SELECT * FROM st_view", """\
|
||||
view_id | view_name | table_id | is_public | is_anonymous
|
||||
---------+-----------+---------------+-----------+--------------
|
||||
4096 | "player" | (some = 4097) | true | false
|
||||
""")
|
||||
|
||||
self.assertSql("SELECT * FROM st_view_column", """\
|
||||
view_id | col_pos | col_name | col_type
|
||||
---------+---------+----------+----------
|
||||
4096 | 0 | "id" | 0x0d
|
||||
4096 | 1 | "level" | 0x0d
|
||||
""")
|
||||
|
||||
class FailPublish(Smoketest):
|
||||
AUTOPUBLISH = False
|
||||
|
||||
MODULE_CODE_BROKEN_NAMESPACE = """
|
||||
use spacetimedb::ViewContext;
|
||||
|
||||
#[spacetimedb::table(accessor = person, public)]
|
||||
pub struct Person {
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[spacetimedb::view(accessor = person, public)]
|
||||
pub fn person(ctx: &ViewContext) -> Option<Person> {
|
||||
None
|
||||
}
|
||||
"""
|
||||
|
||||
MODULE_CODE_BROKEN_RETURN_TYPE = """
|
||||
use spacetimedb::{SpacetimeType, ViewContext};
|
||||
|
||||
#[derive(SpacetimeType)]
|
||||
pub enum ABC {
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
}
|
||||
|
||||
#[spacetimedb::view(accessor = person, public)]
|
||||
pub fn person(ctx: &ViewContext) -> Option<ABC> {
|
||||
None
|
||||
}
|
||||
"""
|
||||
|
||||
def test_fail_publish_namespace_collision(self):
|
||||
"""Publishing a module should fail if a table and view have the same name"""
|
||||
|
||||
name = random_string()
|
||||
|
||||
self.write_module_code(self.MODULE_CODE_BROKEN_NAMESPACE)
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
self.publish_module(name)
|
||||
|
||||
def test_fail_publish_wrong_return_type(self):
|
||||
"""Publishing a module should fail if the inner return type is not a product type"""
|
||||
|
||||
name = random_string()
|
||||
|
||||
self.write_module_code(self.MODULE_CODE_BROKEN_RETURN_TYPE)
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
self.publish_module(name)
|
||||
|
||||
class SqlViews(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{AnonymousViewContext, ReducerContext, Table, ViewContext};
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[spacetimedb::table(accessor = player_state)]
|
||||
#[spacetimedb::table(accessor = player_level)]
|
||||
pub struct PlayerState {
|
||||
#[primary_key]
|
||||
id: u64,
|
||||
#[index(btree)]
|
||||
level: u64,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Clone)]
|
||||
#[spacetimedb::table(accessor = player_info, index(accessor=age_level_index, btree(columns = [age, level])))]
|
||||
pub struct PlayerInfo {
|
||||
#[primary_key]
|
||||
id: u64,
|
||||
age: u64,
|
||||
level: u64,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn add_player_level(ctx: &ReducerContext, id: u64, level: u64) {
|
||||
ctx.db.player_level().insert(PlayerState { id, level });
|
||||
}
|
||||
|
||||
#[spacetimedb::view(accessor = my_player_and_level, public)]
|
||||
pub fn my_player_and_level(ctx: &AnonymousViewContext) -> Option<PlayerState> {
|
||||
ctx.db.player_level().id().find(0)
|
||||
}
|
||||
|
||||
#[spacetimedb::view(accessor = player_and_level, public)]
|
||||
pub fn player_and_level(ctx: &AnonymousViewContext) -> Vec<PlayerState> {
|
||||
ctx.db.player_level().level().filter(2u64).collect()
|
||||
}
|
||||
|
||||
#[spacetimedb::view(accessor = player, public)]
|
||||
pub fn player(ctx: &ViewContext) -> Option<PlayerState> {
|
||||
log::info!("player view called");
|
||||
ctx.db.player_state().id().find(42)
|
||||
}
|
||||
|
||||
#[spacetimedb::view(accessor = player_none, public)]
|
||||
pub fn player_none(_ctx: &ViewContext) -> Option<PlayerState> {
|
||||
None
|
||||
}
|
||||
|
||||
#[spacetimedb::view(accessor = player_vec, public)]
|
||||
pub fn player_vec(ctx: &ViewContext) -> Vec<PlayerState> {
|
||||
let first = ctx.db.player_state().id().find(42).unwrap();
|
||||
let second = PlayerState { id: 7, level: 3 };
|
||||
vec![first, second]
|
||||
}
|
||||
|
||||
#[spacetimedb::view(accessor = player_info_multi_index, public)]
|
||||
pub fn player_info_view(ctx: &ViewContext) -> Option<PlayerInfo> {
|
||||
|
||||
log::info!("player_info called");
|
||||
ctx.db.player_info().age_level_index().filter((25u64, 7u64)).next()
|
||||
|
||||
}
|
||||
"""
|
||||
|
||||
def assertSql(self, sql, expected):
|
||||
self.maxDiff = None
|
||||
sql_out = self.spacetime("sql", self.database_identity, sql)
|
||||
sql_out = "\n".join([line.rstrip() for line in sql_out.splitlines()])
|
||||
expected = "\n".join([line.rstrip() for line in expected.splitlines()])
|
||||
|
||||
self.assertMultiLineEqual(sql_out, expected)
|
||||
|
||||
def insert_initial_data(self):
|
||||
self.spacetime(
|
||||
"sql",
|
||||
self.database_identity,
|
||||
"""\
|
||||
INSERT INTO player_state (id, level) VALUES (42, 7);
|
||||
""",
|
||||
)
|
||||
|
||||
def call_player_view(self):
|
||||
|
||||
self.assertSql("SELECT * FROM player", """\
|
||||
id | level
|
||||
----+-------
|
||||
42 | 7
|
||||
""")
|
||||
|
||||
def test_http_sql(self):
|
||||
"""This test asserts that views can be queried over HTTP SQL"""
|
||||
self.insert_initial_data()
|
||||
|
||||
self.call_player_view()
|
||||
|
||||
self.assertSql("SELECT * FROM player_none", """\
|
||||
id | level
|
||||
----+-------
|
||||
""")
|
||||
|
||||
self.assertSql("SELECT * FROM player_vec", """\
|
||||
id | level
|
||||
----+-------
|
||||
42 | 7
|
||||
7 | 3
|
||||
""")
|
||||
|
||||
# test is prefixed with 'a' to ensure it runs before any other tests,
|
||||
# since it relies on log capturing starting from an empty log.
|
||||
def test_a_view_materialization(self):
|
||||
"""This test asserts whether views are materialized correctly"""
|
||||
|
||||
player_called_log = "player view called"
|
||||
|
||||
# call view, with no data
|
||||
self.assertSql("SELECT * FROM player", """\
|
||||
id | level
|
||||
----+-------
|
||||
""")
|
||||
logs = self.logs(100)
|
||||
self.assertEqual(logs.count(player_called_log), 1)
|
||||
|
||||
self.insert_initial_data()
|
||||
|
||||
# Should invoke view as data is inserted
|
||||
self.call_player_view()
|
||||
|
||||
logs = self.logs(100)
|
||||
self.assertEqual(logs.count(player_called_log), 2)
|
||||
|
||||
self.call_player_view()
|
||||
# the view is cached
|
||||
logs = self.logs(100)
|
||||
self.assertEqual(logs.count(player_called_log), 2)
|
||||
|
||||
# inserting new row should not trigger view invocation due to readsets
|
||||
self.spacetime(
|
||||
"sql",
|
||||
self.database_identity,
|
||||
"""\
|
||||
INSERT INTO player_state (id, level) VALUES (22, 8);
|
||||
""",
|
||||
)
|
||||
|
||||
self.call_player_view()
|
||||
logs = self.logs(100)
|
||||
self.assertEqual(logs.count(player_called_log), 2)
|
||||
|
||||
# Updating the row that the view depends on should trigger re-evaluation
|
||||
self.spacetime(
|
||||
"sql",
|
||||
self.database_identity,
|
||||
"""
|
||||
UPDATE player_state SET level = 9 WHERE id = 42;
|
||||
""",
|
||||
)
|
||||
|
||||
# On fourth call, after updating the dependent row, the view is re-evaluated
|
||||
logs = self.logs(100)
|
||||
self.assertEqual(logs.count(player_called_log), 3)
|
||||
|
||||
|
||||
# Updating it back for other tests to work
|
||||
self.spacetime(
|
||||
"sql",
|
||||
self.database_identity,
|
||||
"""
|
||||
UPDATE player_state SET level = 7 WHERE id = 42;
|
||||
""",
|
||||
)
|
||||
|
||||
def test_view_multi_index_materialization(self):
|
||||
"""This test asserts whether views using multi-column indexes are materialized correctly"""
|
||||
|
||||
player_called_log = "player_info called"
|
||||
|
||||
# call view, with no data
|
||||
self.assertSql("SELECT * FROM player_info_multi_index", """\
|
||||
id | age | level
|
||||
----+-----+-------
|
||||
""")
|
||||
|
||||
logs = self.logs(100)
|
||||
self.assertEqual(logs.count(player_called_log), 1)
|
||||
|
||||
# Insert data
|
||||
self.spacetime(
|
||||
"sql",
|
||||
self.database_identity,
|
||||
"""\
|
||||
INSERT INTO player_info (id, age, level) VALUES (1, 25, 7);
|
||||
""",
|
||||
)
|
||||
|
||||
# Should invoke view as data is inserted
|
||||
self.assertSql("SELECT * FROM player_info_multi_index", """\
|
||||
id | age | level
|
||||
----+-----+-------
|
||||
1 | 25 | 7
|
||||
""")
|
||||
logs = self.logs(100)
|
||||
self.assertEqual(logs.count(player_called_log), 2)
|
||||
|
||||
|
||||
# Inserting a row that does not match should not trigger re-evaluation
|
||||
self.spacetime(
|
||||
"sql",
|
||||
self.database_identity,
|
||||
"""\
|
||||
INSERT INTO player_info (id, age, level) VALUES (2, 25, 8);
|
||||
""",
|
||||
)
|
||||
|
||||
logs = self.logs(100)
|
||||
self.assertEqual(logs.count(player_called_log), 2)
|
||||
|
||||
# Updating the row that the view depends on should trigger re-evaluation
|
||||
self.spacetime(
|
||||
"sql",
|
||||
self.database_identity,
|
||||
"""
|
||||
UPDATE player_info SET age = 26 WHERE id = 1;
|
||||
""",
|
||||
)
|
||||
logs = self.logs(100)
|
||||
self.assertEqual(logs.count(player_called_log), 3)
|
||||
self.assertSql("SELECT * FROM player_info_multi_index", """\
|
||||
id | age | level
|
||||
----+-----+-------
|
||||
""")
|
||||
|
||||
|
||||
def test_query_anonymous_view_reducer(self):
|
||||
"""Tests that anonymous views are updated for reducers"""
|
||||
self.call("add_player_level", 0, 1)
|
||||
self.call("add_player_level", 1, 2)
|
||||
|
||||
self.assertSql("SELECT * FROM my_player_and_level", """\
|
||||
id | level
|
||||
----+-------
|
||||
0 | 1
|
||||
""")
|
||||
|
||||
self.assertSql("SELECT * FROM player_and_level", """\
|
||||
id | level
|
||||
----+-------
|
||||
1 | 2
|
||||
""")
|
||||
|
||||
self.call("add_player_level", 2, 2)
|
||||
|
||||
self.assertSql("SELECT * FROM player_and_level", """\
|
||||
id | level
|
||||
----+-------
|
||||
1 | 2
|
||||
2 | 2
|
||||
""")
|
||||
|
||||
self.assertSql("SELECT * FROM player_and_level WHERE id = 2", """\
|
||||
id | level
|
||||
----+-------
|
||||
2 | 2
|
||||
""")
|
||||
|
||||
|
||||
class AutoMigrateViews(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::ViewContext;
|
||||
#[derive(Copy, Clone)]
|
||||
#[spacetimedb::table(accessor = player_state)]
|
||||
pub struct PlayerState {
|
||||
#[primary_key]
|
||||
id: u64,
|
||||
#[index(btree)]
|
||||
level: u64,
|
||||
}
|
||||
#[spacetimedb::view(accessor = player, public)]
|
||||
pub fn player(ctx: &ViewContext) -> Option<PlayerState> {
|
||||
ctx.db.player_state().id().find(1u64)
|
||||
}
|
||||
"""
|
||||
|
||||
MODULE_CODE_UPDATED = """
|
||||
use spacetimedb::ViewContext;
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[spacetimedb::table(accessor = player_state)]
|
||||
pub struct PlayerState {
|
||||
#[primary_key]
|
||||
id: u64,
|
||||
#[index(btree)]
|
||||
level: u64,
|
||||
}
|
||||
#[spacetimedb::view(accessor = player, public)]
|
||||
pub fn player(ctx: &ViewContext) -> Option<PlayerState> {
|
||||
ctx.db.player_state().id().find(2u64)
|
||||
}
|
||||
"""
|
||||
|
||||
def assertSql(self, sql, expected):
|
||||
self.maxDiff = None
|
||||
sql_out = self.spacetime("sql", self.database_identity, sql)
|
||||
sql_out = "\n".join([line.rstrip() for line in sql_out.splitlines()])
|
||||
expected = "\n".join([line.rstrip() for line in expected.splitlines()])
|
||||
self.assertMultiLineEqual(sql_out, expected)
|
||||
|
||||
def test_views_auto_migration(self):
|
||||
"""Assert that views are auto-migrated correctly"""
|
||||
|
||||
self.spacetime(
|
||||
"sql",
|
||||
self.database_identity,
|
||||
"INSERT INTO player_state (id, level) VALUES (1, 1);",
|
||||
)
|
||||
self.spacetime(
|
||||
"sql",
|
||||
self.database_identity,
|
||||
"INSERT INTO player_state (id, level) VALUES (2, 2);",
|
||||
)
|
||||
|
||||
self.assertSql("SELECT * FROM player", """\
|
||||
id | level
|
||||
----+-------
|
||||
1 | 1
|
||||
""")
|
||||
|
||||
self.write_module_code(self.MODULE_CODE_UPDATED)
|
||||
self.publish_module(self.database_identity, clear=False)
|
||||
|
||||
self.assertSql("SELECT * FROM player", """\
|
||||
id | level
|
||||
----+-------
|
||||
2 | 2
|
||||
""")
|
||||
|
||||
|
||||
class AutoMigrateDropView(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::ViewContext;
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[spacetimedb::table(accessor = player_state)]
|
||||
pub struct PlayerState {
|
||||
#[primary_key]
|
||||
id: u64,
|
||||
#[index(btree)]
|
||||
level: u64,
|
||||
}
|
||||
|
||||
#[spacetimedb::view(accessor = player, public)]
|
||||
pub fn player(ctx: &ViewContext) -> Option<PlayerState> {
|
||||
ctx.db.player_state().id().find(1u64)
|
||||
}
|
||||
"""
|
||||
|
||||
MODULE_CODE_DROP_VIEW = """
|
||||
#[derive(Copy, Clone)]
|
||||
#[spacetimedb::table(accessor = player_state)]
|
||||
pub struct PlayerState {
|
||||
#[primary_key]
|
||||
id: u64,
|
||||
#[index(btree)]
|
||||
level: u64,
|
||||
}
|
||||
"""
|
||||
|
||||
def test_auto_migration_drop_view(self):
|
||||
"""Assert that views can be dropped in an auto-migration"""
|
||||
|
||||
self.write_module_code(self.MODULE_CODE_DROP_VIEW)
|
||||
self.publish_module(self.database_identity, clear=False, break_clients=False)
|
||||
|
||||
|
||||
class AutoMigrateAddView(Smoketest):
|
||||
MODULE_CODE = """
|
||||
#[derive(Copy, Clone)]
|
||||
#[spacetimedb::table(accessor = player_state)]
|
||||
pub struct PlayerState {
|
||||
#[primary_key]
|
||||
id: u64,
|
||||
#[index(btree)]
|
||||
level: u64,
|
||||
}
|
||||
"""
|
||||
|
||||
MODULE_CODE_ADD_VIEW = """
|
||||
use spacetimedb::ViewContext;
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[spacetimedb::table(accessor = player_state)]
|
||||
pub struct PlayerState {
|
||||
#[primary_key]
|
||||
id: u64,
|
||||
#[index(btree)]
|
||||
level: u64,
|
||||
}
|
||||
|
||||
#[spacetimedb::view(accessor = player, public)]
|
||||
pub fn player(ctx: &ViewContext) -> Option<PlayerState> {
|
||||
ctx.db.player_state().id().find(1u64)
|
||||
}
|
||||
"""
|
||||
|
||||
def test_auto_migration_drop_view(self):
|
||||
"""Assert that views can be added in an auto-migration"""
|
||||
|
||||
self.write_module_code(self.MODULE_CODE_ADD_VIEW)
|
||||
self.publish_module(self.database_identity, clear=False)
|
||||
|
||||
|
||||
class AutoMigrateViewsTrapped(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::ViewContext;
|
||||
#[derive(Copy, Clone)]
|
||||
#[spacetimedb::table(accessor = player_state)]
|
||||
pub struct PlayerState {
|
||||
#[primary_key]
|
||||
id: u64,
|
||||
#[index(btree)]
|
||||
level: u64,
|
||||
}
|
||||
#[spacetimedb::view(accessor = player, public)]
|
||||
pub fn player(ctx: &ViewContext) -> Option<PlayerState> {
|
||||
ctx.db.player_state().id().find(1u64)
|
||||
}
|
||||
"""
|
||||
|
||||
TRAPPED_MODULE_CODE_UPDATED = """
|
||||
use spacetimedb::ViewContext;
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[spacetimedb::table(accessor = player_state)]
|
||||
pub struct PlayerState {
|
||||
#[primary_key]
|
||||
id: u64,
|
||||
#[index(btree)]
|
||||
level: u64,
|
||||
}
|
||||
#[spacetimedb::view(accessor = player, public)]
|
||||
pub fn player(_ctx: &ViewContext) -> Option<PlayerState> {
|
||||
panic!("This view is trapped")
|
||||
}
|
||||
"""
|
||||
|
||||
MODULE_CODE_RECOVERED = """
|
||||
use spacetimedb::ViewContext;
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[spacetimedb::table(accessor = player_state)]
|
||||
pub struct PlayerState {
|
||||
#[primary_key]
|
||||
id: u64,
|
||||
#[index(btree)]
|
||||
level: u64,
|
||||
}
|
||||
#[spacetimedb::view(accessor = player, public)]
|
||||
pub fn player(ctx: &ViewContext) -> Option<PlayerState> {
|
||||
ctx.db.player_state().id().find(2u64)
|
||||
}
|
||||
"""
|
||||
|
||||
def assertSql(self, sql, expected):
|
||||
self.maxDiff = None
|
||||
sql_out = self.spacetime("sql", self.database_identity, sql)
|
||||
sql_out = "\n".join([line.rstrip() for line in sql_out.splitlines()])
|
||||
expected = "\n".join([line.rstrip() for line in expected.splitlines()])
|
||||
self.assertMultiLineEqual(sql_out, expected)
|
||||
|
||||
def test_recovery_from_trapped_views_auto_migration(self):
|
||||
"""Assert that view auto-migration recovers correctly after trapped migration"""
|
||||
|
||||
self.spacetime(
|
||||
"sql",
|
||||
self.database_identity,
|
||||
"INSERT INTO player_state (id, level) VALUES (1, 1);",
|
||||
)
|
||||
|
||||
# Trigger initial materialization
|
||||
self.assertSql("SELECT * FROM player", """\
|
||||
id | level
|
||||
----+-------
|
||||
1 | 1
|
||||
""")
|
||||
|
||||
# Attempt to publish trapped module (should fail)
|
||||
self.write_module_code(self.TRAPPED_MODULE_CODE_UPDATED)
|
||||
with self.assertRaises(Exception):
|
||||
self.publish_module(self.database_identity, clear=False)
|
||||
|
||||
# Ensure old module still serves queries
|
||||
self.assertSql("SELECT * FROM player", """\
|
||||
id | level
|
||||
----+-------
|
||||
1 | 1
|
||||
""")
|
||||
|
||||
# Fix the module and publish again
|
||||
self.write_module_code(self.MODULE_CODE_RECOVERED)
|
||||
self.publish_module(self.database_identity, clear=False)
|
||||
|
||||
self.assertSql("SELECT * FROM player", """\
|
||||
id | level
|
||||
----+-------
|
||||
""")
|
||||
|
||||
class SubscribeViews(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{Identity, ReducerContext, Table, ViewContext};
|
||||
|
||||
#[spacetimedb::table(accessor = player_state)]
|
||||
pub struct PlayerState {
|
||||
#[primary_key]
|
||||
identity: Identity,
|
||||
#[unique]
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[spacetimedb::view(accessor = my_player, public)]
|
||||
pub fn my_player(ctx: &ViewContext) -> Option<PlayerState> {
|
||||
ctx.db.player_state().identity().find(ctx.sender())
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn insert_player(ctx: &ReducerContext, name: String) {
|
||||
ctx.db.player_state().insert(PlayerState { name, identity: ctx.sender() });
|
||||
}
|
||||
"""
|
||||
|
||||
def test_subscribing_with_different_identities(self):
|
||||
"""Tests different clients subscribing to a client-specific view"""
|
||||
|
||||
# Insert an identity for Alice
|
||||
self.call("insert_player", "Alice")
|
||||
|
||||
# Generate a new identity for Bob
|
||||
self.reset_config()
|
||||
self.new_identity()
|
||||
|
||||
# Subscribe to `my_player` as Bob
|
||||
sub = self.subscribe("select * from my_player", n=1)
|
||||
self.call("insert_player", "Bob")
|
||||
events = sub()
|
||||
|
||||
# Project out the identity field.
|
||||
# TODO: Eventually we should be able to do this directly in the sql.
|
||||
# But for now we implement it in python.
|
||||
projection = [
|
||||
{
|
||||
'my_player': {
|
||||
'deletes': [
|
||||
{'name': row['name']}
|
||||
for row in event['my_player']['deletes']
|
||||
],
|
||||
'inserts': [
|
||||
{'name': row['name']}
|
||||
for row in event['my_player']['inserts']
|
||||
],
|
||||
}
|
||||
}
|
||||
for event in events
|
||||
]
|
||||
|
||||
self.assertEqual(
|
||||
[
|
||||
{
|
||||
'my_player': {
|
||||
'deletes': [],
|
||||
'inserts': [{'name': 'Bob'}],
|
||||
}
|
||||
},
|
||||
],
|
||||
projection,
|
||||
)
|
||||
|
||||
|
||||
class QueryView(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{Query, ReducerContext, Table, ViewContext};
|
||||
|
||||
#[spacetimedb::table(accessor = user, public)]
|
||||
pub struct User {
|
||||
#[primary_key]
|
||||
identity: u8,
|
||||
name: String,
|
||||
online: bool,
|
||||
}
|
||||
|
||||
#[spacetimedb::table(accessor = person, public)]
|
||||
pub struct Person {
|
||||
#[primary_key]
|
||||
identity: u8,
|
||||
name: String,
|
||||
#[index(btree)]
|
||||
age: u8,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer(init)]
|
||||
fn init(ctx: &ReducerContext) {
|
||||
ctx.db.user().insert(User {
|
||||
identity: 1,
|
||||
name: "Alice".to_string(),
|
||||
online: true,
|
||||
});
|
||||
|
||||
ctx.db.user().insert(User {
|
||||
identity: 2,
|
||||
name: "BOB".to_string(),
|
||||
online: false,
|
||||
});
|
||||
|
||||
|
||||
ctx.db.user().insert(User {
|
||||
identity: 3,
|
||||
name: "POP".to_string(),
|
||||
online: false,
|
||||
});
|
||||
|
||||
ctx.db.person().insert(Person {
|
||||
identity: 1,
|
||||
name: "Alice".to_string(),
|
||||
age: 30,
|
||||
});
|
||||
|
||||
|
||||
ctx.db.person().insert(Person {
|
||||
identity: 2,
|
||||
name: "BOB".to_string(),
|
||||
|
||||
age: 20,
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
#[spacetimedb::view(accessor = online_users, public)]
|
||||
fn online_users(ctx: &ViewContext) -> impl Query<User> {
|
||||
ctx.from.user().r#where(|c| c.online.eq(true))
|
||||
}
|
||||
|
||||
#[spacetimedb::view(accessor = online_users_age, public)]
|
||||
fn online_users_age(ctx: &ViewContext) -> impl Query<Person> {
|
||||
ctx.from
|
||||
.user()
|
||||
.r#where(|u| u.online.eq(true))
|
||||
.right_semijoin(ctx.from.person(), |u, p| u.identity.eq(p.identity))
|
||||
}
|
||||
|
||||
#[spacetimedb::view(accessor = offline_user_20_years_old, public)]
|
||||
fn offline_user_in_twienties(ctx: &ViewContext) -> impl Query<User> {
|
||||
ctx.from
|
||||
.person()
|
||||
.filter(|p| p.age.eq(20))
|
||||
.right_semijoin(ctx.from.user(), |p, u| p.identity.eq(u.identity))
|
||||
.filter(|u| u.online.eq(false))
|
||||
}
|
||||
|
||||
#[spacetimedb::view(accessor = users_whos_age_is_known, public)]
|
||||
fn users_whos_age_is_known(ctx: &ViewContext) -> impl Query<User> {
|
||||
ctx.from
|
||||
.user()
|
||||
.left_semijoin(ctx.from.person(), |p, u| p.identity.eq(u.identity))
|
||||
}
|
||||
|
||||
#[spacetimedb::view(accessor = users_who_are_above_20_and_below_30, public)]
|
||||
fn users_who_are_above_20_and_below_30(ctx: &ViewContext) -> impl Query<Person> {
|
||||
ctx.from
|
||||
.person()
|
||||
.r#where(|p| p.age.gt(20).and(p.age.lt(30)))
|
||||
}
|
||||
|
||||
#[spacetimedb::view(accessor = users_who_are_above_eq_20_and_below_eq_30, public)]
|
||||
fn users_who_are_above_eq_20_and_below_eq_30(ctx: &ViewContext) -> impl Query<Person> {
|
||||
ctx.from
|
||||
.person()
|
||||
.r#where(|p| p.age.gte(20).and(p.age.lte(30)))
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def test_query_view(self):
|
||||
"""Tests that views returning Query types work as expected"""
|
||||
|
||||
self.assertSql("SELECT * FROM online_users", """\
|
||||
identity | name | online
|
||||
----------+---------+--------
|
||||
1 | "Alice" | true
|
||||
""")
|
||||
|
||||
def test_query_right_semijoin_view(self):
|
||||
"""Tests that views returning Query types with right semijoin work as expected"""
|
||||
|
||||
self.assertSql("SELECT * FROM online_users_age", """\
|
||||
identity | name | age
|
||||
----------+---------+-----
|
||||
1 | "Alice" | 30
|
||||
""")
|
||||
|
||||
def test_query_left_semijoin_view(self):
|
||||
"""Tests that views returning Query types with left semijoin work as expected"""
|
||||
|
||||
self.assertSql("SELECT * FROM users_whos_age_is_known", """\
|
||||
identity | name | online
|
||||
----------+---------+--------
|
||||
1 | "Alice" | true
|
||||
2 | "BOB" | false
|
||||
""")
|
||||
|
||||
def test_query_complex_right_semijoin_view(self):
|
||||
"""Tests that views returning Query types with right semijoin work as expected"""
|
||||
|
||||
self.assertSql("SELECT * FROM offline_user_20_years_old", """\
|
||||
identity | name | online
|
||||
----------+-------+--------
|
||||
2 | "BOB" | false
|
||||
""")
|
||||
|
||||
def test_where_expr_view(self):
|
||||
"""Tests that views with where expressions work as expected"""
|
||||
|
||||
self.assertSql("SELECT * FROM users_who_are_above_20_and_below_30", """\
|
||||
identity | name | age
|
||||
----------+------+-----
|
||||
""")
|
||||
|
||||
self.assertSql("SELECT * FROM users_who_are_above_eq_20_and_below_eq_30", """\
|
||||
identity | name | age
|
||||
----------+---------+-----
|
||||
1 | "Alice" | 30
|
||||
2 | "BOB" | 20
|
||||
""")
|
||||
|
||||
@@ -1,176 +0,0 @@
|
||||
from .. import Smoketest, requires_docker
|
||||
from ..docker import restart_docker
|
||||
from urllib.request import urlopen
|
||||
from .add_remove_index import AddRemoveIndex
|
||||
|
||||
|
||||
@requires_docker
|
||||
class DockerRestartModule(Smoketest):
|
||||
# Note: creating indexes on `Person`
|
||||
# exercises more possible failure cases when replaying after restart
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{log, ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = person, index(accessor = name_idx, btree(columns = [name])))]
|
||||
pub struct Person {
|
||||
#[primary_key]
|
||||
#[auto_inc]
|
||||
id: u32,
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn add(ctx: &ReducerContext, name: String) {
|
||||
ctx.db.person().insert(Person { id: 0, name });
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn say_hello(ctx: &ReducerContext) {
|
||||
for person in ctx.db.person().iter() {
|
||||
log::info!("Hello, {}!", person.name);
|
||||
}
|
||||
log::info!("Hello, World!");
|
||||
}
|
||||
"""
|
||||
|
||||
def test_restart_module(self):
|
||||
"""This tests to see if SpacetimeDB can be queried after a restart"""
|
||||
|
||||
self.call("add", "Robert")
|
||||
|
||||
restart_docker()
|
||||
|
||||
self.call("add", "Julie")
|
||||
self.call("add", "Samantha")
|
||||
self.call("say_hello")
|
||||
logs = self.logs(100)
|
||||
self.assertIn("Hello, Samantha!", logs)
|
||||
self.assertIn("Hello, Julie!", logs)
|
||||
self.assertIn("Hello, Robert!", logs)
|
||||
self.assertIn("Hello, World!", logs)
|
||||
|
||||
|
||||
@requires_docker
|
||||
class DockerRestartSql(Smoketest):
|
||||
# Note: creating indexes on `Person`
|
||||
# exercises more possible failure cases when replaying after restart
|
||||
MODULE_CODE = """
|
||||
use spacetimedb::{log, ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = person, index(accessor = name_idx, btree(columns = [name])))]
|
||||
pub struct Person {
|
||||
#[primary_key]
|
||||
#[auto_inc]
|
||||
id: u32,
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn add(ctx: &ReducerContext, name: String) {
|
||||
ctx.db.person().insert(Person { id: 0, name });
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
pub fn say_hello(ctx: &ReducerContext) {
|
||||
for person in ctx.db.person().iter() {
|
||||
log::info!("Hello, {}!", person.name);
|
||||
}
|
||||
log::info!("Hello, World!");
|
||||
}
|
||||
"""
|
||||
|
||||
def test_restart_module(self):
|
||||
"""This tests to see if SpacetimeDB can be queried after a restart"""
|
||||
|
||||
self.call("add", "Robert")
|
||||
self.call("add", "Julie")
|
||||
self.call("add", "Samantha")
|
||||
self.call("say_hello")
|
||||
logs = self.logs(100)
|
||||
self.assertIn("Hello, Samantha!", logs)
|
||||
self.assertIn("Hello, Julie!", logs)
|
||||
self.assertIn("Hello, Robert!", logs)
|
||||
self.assertIn("Hello, World!", logs)
|
||||
|
||||
restart_docker()
|
||||
|
||||
sql_out = self.spacetime("sql", self.database_identity, "SELECT name FROM person WHERE id = 3")
|
||||
self.assertMultiLineEqual(sql_out, """ name \n------------\n "Samantha" \n""")
|
||||
|
||||
@requires_docker
|
||||
class DockerRestartAutoDisconnect(Smoketest):
|
||||
MODULE_CODE = """
|
||||
use log::info;
|
||||
use spacetimedb::{ConnectionId, Identity, ReducerContext, Table};
|
||||
|
||||
#[spacetimedb::table(accessor = connected_client)]
|
||||
pub struct ConnectedClient {
|
||||
identity: Identity,
|
||||
connection_id: ConnectionId,
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer(client_connected)]
|
||||
fn on_connect(ctx: &ReducerContext) {
|
||||
ctx.db.connected_client().insert(ConnectedClient {
|
||||
identity: ctx.sender(),
|
||||
connection_id: ctx.connection_id().expect("sender connection id unset"),
|
||||
});
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer(client_disconnected)]
|
||||
fn on_disconnect(ctx: &ReducerContext) {
|
||||
let sender_identity = &ctx.sender();
|
||||
let connection_id = ctx.connection_id();
|
||||
let sender_connection_id = connection_id.as_ref().expect("sender connection id unset");
|
||||
let match_client = |row: &ConnectedClient| {
|
||||
&row.identity == sender_identity && &row.connection_id == sender_connection_id
|
||||
};
|
||||
if let Some(client) = ctx.db.connected_client().iter().find(match_client) {
|
||||
ctx.db.connected_client().delete(client);
|
||||
}
|
||||
}
|
||||
|
||||
#[spacetimedb::reducer]
|
||||
fn print_num_connected(ctx: &ReducerContext) {
|
||||
let n = ctx.db.connected_client().count();
|
||||
info!("CONNECTED CLIENTS: {n}")
|
||||
}
|
||||
"""
|
||||
|
||||
def test_restart_disconnects(self):
|
||||
"""Tests if clients are automatically disconnected after a restart"""
|
||||
|
||||
# Start two subscribers
|
||||
self.subscribe("SELECT * FROM connected_client", n=2)
|
||||
self.subscribe("SELECT * FROM connected_client", n=2)
|
||||
|
||||
# Assert that we have two clients + the reducer call
|
||||
self.call("print_num_connected")
|
||||
logs = self.logs(10)
|
||||
self.assertEqual("CONNECTED CLIENTS: 3", logs.pop())
|
||||
|
||||
restart_docker()
|
||||
|
||||
# After restart, only the current call should be connected
|
||||
self.call("print_num_connected")
|
||||
logs = self.logs(10)
|
||||
self.assertEqual("CONNECTED CLIENTS: 1", logs.pop())
|
||||
|
||||
@requires_docker
|
||||
class AddRemoveIndexAfterRestart(AddRemoveIndex):
|
||||
"""
|
||||
`AddRemoveIndex` from `add_remove_index.py`,
|
||||
but restarts docker between each publish.
|
||||
|
||||
This detects a bug we once had, hopefully fixed now,
|
||||
where the system autoinc sequences were borked after restart,
|
||||
leading newly-created database objects to re-use IDs.
|
||||
|
||||
First publish the module without the indices,
|
||||
then restart docker, then add the indices and publish.
|
||||
Then restart docker, and publish again.
|
||||
There should be no errors from publishing,
|
||||
and the unindexed versions should reject subscriptions.
|
||||
"""
|
||||
def between_publishes(self):
|
||||
restart_docker()
|
||||
@@ -1,376 +0,0 @@
|
||||
# vendored and modified from unittest-parallel by Craig Hobbs
|
||||
# TODO: upstream some of these changes? to make it usable as a library, maybe?
|
||||
|
||||
# Licensed under the MIT License
|
||||
# https://github.com/craigahobbs/unittest-parallel/blob/main/LICENSE
|
||||
# full text of license file below:
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Copyright (c) 2017 Craig A. Hobbs
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
"""
|
||||
unittest-parallel command-line script main module
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from contextlib import contextmanager
|
||||
from io import StringIO
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
import concurrent.futures
|
||||
import threading
|
||||
|
||||
import coverage
|
||||
|
||||
|
||||
def main(**kwargs):
|
||||
"""
|
||||
unittest-parallel command-line script main entry point
|
||||
"""
|
||||
|
||||
# Command line arguments
|
||||
parser = argparse.ArgumentParser(prog='unittest-parallel')
|
||||
parser.add_argument('-v', '--verbose', action='store_const', const=2, default=1,
|
||||
help='Verbose output')
|
||||
parser.add_argument('-q', '--quiet', dest='verbose', action='store_const', const=0, default=1,
|
||||
help='Quiet output')
|
||||
parser.add_argument('-f', '--failfast', action='store_true', default=False,
|
||||
help='Stop on first fail or error')
|
||||
parser.add_argument('-b', '--buffer', action='store_true', default=False,
|
||||
help='Buffer stdout and stderr during tests')
|
||||
parser.add_argument('-k', dest='testNamePatterns', action='append', type=_convert_select_pattern,
|
||||
help='Only run tests which match the given substring')
|
||||
parser.add_argument('-s', '--start-directory', metavar='START', default='.',
|
||||
help="Directory to start discovery ('.' default)")
|
||||
parser.add_argument('-p', '--pattern', metavar='PATTERN', default='test*.py',
|
||||
help="Pattern to match tests ('test*.py' default)")
|
||||
parser.add_argument('-t', '--top-level-directory', metavar='TOP',
|
||||
help='Top level directory of project (defaults to start directory)')
|
||||
group_parallel = parser.add_argument_group('parallelization options')
|
||||
group_parallel.add_argument('-j', '--jobs', metavar='COUNT', type=int, default=0,
|
||||
help='The number of test processes (default is 0, all cores)')
|
||||
group_parallel.add_argument('--level', choices=['module', 'class', 'test'], default='module',
|
||||
help="Set the test parallelism level (default is 'module')")
|
||||
group_parallel.add_argument('--disable-process-pooling', action='store_true', default=False,
|
||||
help='Do not reuse processes used to run test suites')
|
||||
group_coverage = parser.add_argument_group('coverage options')
|
||||
group_coverage.add_argument('--coverage', action='store_true',
|
||||
help='Run tests with coverage')
|
||||
group_coverage.add_argument('--coverage-branch', action='store_true',
|
||||
help='Run tests with branch coverage')
|
||||
group_coverage.add_argument('--coverage-rcfile', metavar='RCFILE',
|
||||
help='Specify coverage configuration file')
|
||||
group_coverage.add_argument('--coverage-include', metavar='PAT', action='append',
|
||||
help='Include only files matching one of these patterns. Accepts shell-style (quoted) wildcards.')
|
||||
group_coverage.add_argument('--coverage-omit', metavar='PAT', action='append',
|
||||
help='Omit files matching one of these patterns. Accepts shell-style (quoted) wildcards.')
|
||||
group_coverage.add_argument('--coverage-source', metavar='SRC', action='append',
|
||||
help='A list of packages or directories of code to be measured')
|
||||
group_coverage.add_argument('--coverage-html', metavar='DIR',
|
||||
help='Generate coverage HTML report')
|
||||
group_coverage.add_argument('--coverage-xml', metavar='FILE',
|
||||
help='Generate coverage XML report')
|
||||
group_coverage.add_argument('--coverage-fail-under', metavar='MIN', type=float,
|
||||
help='Fail if coverage percentage under min')
|
||||
args = parser.parse_args(args=[])
|
||||
args.__dict__.update(kwargs)
|
||||
if args.coverage_branch:
|
||||
args.coverage = args.coverage_branch
|
||||
|
||||
process_count = max(0, args.jobs)
|
||||
if process_count == 0:
|
||||
process_count = os.cpu_count()
|
||||
|
||||
# Create the temporary directory (for coverage files)
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
|
||||
# Discover tests
|
||||
# with _coverage(args, temp_dir):
|
||||
# test_loader = unittest.TestLoader()
|
||||
# if args.testNamePatterns:
|
||||
# test_loader.testNamePatterns = args.testNamePatterns
|
||||
# discover_suite = test_loader.discover(args.start_directory, pattern=args.pattern, top_level_dir=args.top_level_directory)
|
||||
discover_suite = args.discovered_tests
|
||||
|
||||
# Get the parallelizable test suites
|
||||
if args.level == 'test':
|
||||
test_suites = list(_iter_test_cases(discover_suite))
|
||||
elif args.level == 'class':
|
||||
test_suites = list(_iter_class_suites(discover_suite))
|
||||
else: # args.level == 'module'
|
||||
test_suites = list(_iter_module_suites(discover_suite))
|
||||
|
||||
# Don't use more processes than test suites
|
||||
process_count = max(1, min(len(test_suites), process_count))
|
||||
|
||||
# Report test suites and processes
|
||||
print(
|
||||
f'Running {len(test_suites)} test suites ({discover_suite.countTestCases()} total tests) across {process_count} threads',
|
||||
file=sys.stderr
|
||||
)
|
||||
if args.verbose > 1:
|
||||
print(file=sys.stderr)
|
||||
|
||||
# Run the tests in parallel
|
||||
start_time = time.perf_counter()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=process_count) as executor:
|
||||
test_manager = ParallelTestManager(args, temp_dir)
|
||||
|
||||
futures = [executor.submit(test_manager.run_tests, suite) for suite in discover_suite]
|
||||
|
||||
# Aggregate parallel test run results
|
||||
tests_run = 0
|
||||
errors = []
|
||||
failures = []
|
||||
skipped = 0
|
||||
expected_failures = 0
|
||||
unexpected_successes = 0
|
||||
|
||||
for fut in concurrent.futures.as_completed(futures):
|
||||
try:
|
||||
result, stream = fut.result()
|
||||
except concurrent.futures.CancelledError:
|
||||
continue
|
||||
tests_run += result.testsRun
|
||||
errors.extend(ParallelTestManager._format_error(result, error) for error in result.errors)
|
||||
failures.extend(ParallelTestManager._format_error(result, failure) for failure in result.failures)
|
||||
skipped += len(result.skipped)
|
||||
expected_failures += len(result.expectedFailures)
|
||||
unexpected_successes += len(result.unexpectedSuccesses)
|
||||
if result.shouldStop:
|
||||
for fut in futures:
|
||||
fut.cancel()
|
||||
|
||||
is_success = not(errors or failures or unexpected_successes)
|
||||
|
||||
stop_time = time.perf_counter()
|
||||
test_duration = stop_time - start_time
|
||||
|
||||
# Compute test info
|
||||
infos = []
|
||||
if failures:
|
||||
infos.append(f'failures={len(failures)}')
|
||||
if errors:
|
||||
infos.append(f'errors={len(errors)}')
|
||||
if skipped:
|
||||
infos.append(f'skipped={skipped}')
|
||||
if expected_failures:
|
||||
infos.append(f'expected failures={expected_failures}')
|
||||
if unexpected_successes:
|
||||
infos.append(f'unexpected successes={unexpected_successes}')
|
||||
|
||||
# Report test errors
|
||||
if errors or failures:
|
||||
print(file=sys.stderr)
|
||||
for error in errors:
|
||||
print(error, file=sys.stderr)
|
||||
for failure in failures:
|
||||
print(failure, file=sys.stderr)
|
||||
elif args.verbose > 0:
|
||||
print(file=sys.stderr)
|
||||
|
||||
# Test report
|
||||
print(unittest.TextTestResult.separator2, file=sys.stderr)
|
||||
print(f'Ran {tests_run} {"tests" if tests_run > 1 else "test"} in {test_duration:.3f}s', file=sys.stderr)
|
||||
print(file=sys.stderr)
|
||||
print(f'{"OK" if is_success else "FAILED"}{" (" + ", ".join(infos) + ")" if infos else ""}', file=sys.stderr)
|
||||
|
||||
# Return an error status on failure
|
||||
if not is_success:
|
||||
parser.exit(status=len(errors) + len(failures) + unexpected_successes)
|
||||
|
||||
# Coverage?
|
||||
if args.coverage:
|
||||
|
||||
# Combine the coverage files
|
||||
cov_options = {}
|
||||
if args.coverage_rcfile is not None:
|
||||
cov_options['config_file'] = args.coverage_rcfile
|
||||
cov = coverage.Coverage(**cov_options)
|
||||
cov.combine(data_paths=[os.path.join(temp_dir, x) for x in os.listdir(temp_dir)])
|
||||
|
||||
# Coverage report
|
||||
print(file=sys.stderr)
|
||||
percent_covered = cov.report(ignore_errors=True, file=sys.stderr)
|
||||
print(f'Total coverage is {percent_covered:.2f}%', file=sys.stderr)
|
||||
|
||||
# HTML coverage report
|
||||
if args.coverage_html:
|
||||
cov.html_report(directory=args.coverage_html, ignore_errors=True)
|
||||
|
||||
# XML coverage report
|
||||
if args.coverage_xml:
|
||||
cov.xml_report(outfile=args.coverage_xml, ignore_errors=True)
|
||||
|
||||
# Fail under
|
||||
if args.coverage_fail_under and percent_covered < args.coverage_fail_under:
|
||||
parser.exit(status=2)
|
||||
|
||||
|
||||
def _convert_select_pattern(pattern):
|
||||
if not '*' in pattern:
|
||||
return f'*{pattern}*'
|
||||
return pattern
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _coverage(args, temp_dir):
|
||||
# Running tests with coverage?
|
||||
if args.coverage:
|
||||
# Generate a random coverage data file name - file is deleted along with containing directory
|
||||
with tempfile.NamedTemporaryFile(dir=temp_dir, delete=False) as coverage_file:
|
||||
pass
|
||||
|
||||
# Create the coverage object
|
||||
cov_options = {
|
||||
'branch': args.coverage_branch,
|
||||
'data_file': coverage_file.name,
|
||||
'include': args.coverage_include,
|
||||
'omit': (args.coverage_omit if args.coverage_omit else []) + [__file__],
|
||||
'source': args.coverage_source
|
||||
}
|
||||
if args.coverage_rcfile is not None:
|
||||
cov_options['config_file'] = args.coverage_rcfile
|
||||
cov = coverage.Coverage(**cov_options)
|
||||
try:
|
||||
# Start measuring code coverage
|
||||
cov.start()
|
||||
|
||||
# Yield for unit test running
|
||||
yield cov
|
||||
finally:
|
||||
# Stop measuring code coverage
|
||||
cov.stop()
|
||||
|
||||
# Save the collected coverage data to the data file
|
||||
cov.save()
|
||||
else:
|
||||
# Not running tests with coverage - yield for unit test running
|
||||
yield None
|
||||
|
||||
|
||||
# Iterate module-level test suites - all top-level test suites returned from TestLoader.discover
|
||||
def _iter_module_suites(test_suite):
|
||||
for module_suite in test_suite:
|
||||
if module_suite.countTestCases():
|
||||
yield module_suite
|
||||
|
||||
|
||||
# Iterate class-level test suites - test suites that contains test cases
|
||||
def _iter_class_suites(test_suite):
|
||||
has_cases = any(isinstance(suite, unittest.TestCase) for suite in test_suite)
|
||||
if has_cases:
|
||||
yield test_suite
|
||||
else:
|
||||
for suite in test_suite:
|
||||
yield from _iter_class_suites(suite)
|
||||
|
||||
|
||||
# Iterate test cases (methods)
|
||||
def _iter_test_cases(test_suite):
|
||||
if isinstance(test_suite, unittest.TestCase):
|
||||
yield test_suite
|
||||
else:
|
||||
for suite in test_suite:
|
||||
yield from _iter_test_cases(suite)
|
||||
|
||||
|
||||
class ParallelTestManager:
|
||||
|
||||
def __init__(self, args, temp_dir):
|
||||
self.args = args
|
||||
self.temp_dir = temp_dir
|
||||
|
||||
def run_tests(self, test_suite):
|
||||
# Run unit tests
|
||||
with _coverage(self.args, self.temp_dir):
|
||||
stream = StringIO()
|
||||
runner = unittest.TextTestRunner(
|
||||
stream=stream,
|
||||
resultclass=ParallelTextTestResult,
|
||||
verbosity=self.args.verbose,
|
||||
failfast=self.args.failfast,
|
||||
buffer=self.args.buffer
|
||||
)
|
||||
result = runner.run(test_suite)
|
||||
|
||||
# Return (test_count, errors, failures, skipped_count, expected_failure_count, unexpected_success_count)
|
||||
return result, stream
|
||||
|
||||
@staticmethod
|
||||
def _format_error(result, error):
|
||||
return '\n'.join([
|
||||
unittest.TextTestResult.separator1,
|
||||
result.getDescription(error[0]),
|
||||
unittest.TextTestResult.separator2,
|
||||
error[1]
|
||||
])
|
||||
|
||||
|
||||
class ParallelTextTestResult(unittest.TextTestResult):
|
||||
|
||||
def __init__(self, stream, descriptions, verbosity):
|
||||
stream = type(stream)(sys.stderr)
|
||||
super().__init__(stream, descriptions, verbosity)
|
||||
|
||||
def startTest(self, test):
|
||||
if self.showAll:
|
||||
self.stream.writeln(f'{self.getDescription(test)} ...')
|
||||
self.stream.flush()
|
||||
super(unittest.TextTestResult, self).startTest(test)
|
||||
|
||||
def _add_helper(self, test, dots_message, show_all_message):
|
||||
if self.showAll:
|
||||
self.stream.writeln(f'{self.getDescription(test)} ... {show_all_message}')
|
||||
elif self.dots:
|
||||
self.stream.write(dots_message)
|
||||
self.stream.flush()
|
||||
|
||||
def addSuccess(self, test):
|
||||
super(unittest.TextTestResult, self).addSuccess(test)
|
||||
self._add_helper(test, '.', 'ok')
|
||||
|
||||
def addError(self, test, err):
|
||||
super(unittest.TextTestResult, self).addError(test, err)
|
||||
self._add_helper(test, 'E', 'ERROR')
|
||||
|
||||
def addFailure(self, test, err):
|
||||
super(unittest.TextTestResult, self).addFailure(test, err)
|
||||
self._add_helper(test, 'F', 'FAIL')
|
||||
|
||||
def addSkip(self, test, reason):
|
||||
super(unittest.TextTestResult, self).addSkip(test, reason)
|
||||
self._add_helper(test, 's', f'skipped {reason!r}')
|
||||
|
||||
def addExpectedFailure(self, test, err):
|
||||
super(unittest.TextTestResult, self).addExpectedFailure(test, err)
|
||||
self._add_helper(test, 'x', 'expected failure')
|
||||
|
||||
def addUnexpectedSuccess(self, test):
|
||||
super(unittest.TextTestResult, self).addUnexpectedSuccess(test)
|
||||
self._add_helper(test, 'u', 'unexpected success')
|
||||
|
||||
def printErrors(self):
|
||||
pass
|
||||
Reference in New Issue
Block a user