Remove Python smoketests (#4896)

# Description of Changes

Remove the Python smoketests and the CI check that fails when they are edited.

# API and ABI breaking changes

None. CI only.

# Expected complexity level and risk

1

# Testing

- [ ] All CI passes

Co-authored-by: Zeke Foppa <bfops@users.noreply.github.com>
This commit is contained in:
Zeke Foppa
2026-04-27 21:34:24 -07:00
committed by GitHub
parent 3b28744938
commit 9e2946b2c0
43 changed files with 0 additions and 6765 deletions
-30
View File
@@ -1025,36 +1025,6 @@ jobs:
- name: Check global.json policy
run: cargo ci global-json-policy
  # PR-only gate: fail CI when a pull request touches the legacy Python
  # smoketests, pointing authors at the Rust smoketests instead.
  warn-python-smoketests:
    name: Check for Python smoketest edits
    runs-on: ubuntu-latest
    if: github.event_name == 'pull_request'
    permissions:
      contents: read
    steps:
      - name: Checkout sources
        uses: actions/checkout@v4
        with:
          # Full history so `git merge-base` below can resolve the base branch.
          fetch-depth: 0
      - name: Fail if Python smoketests were modified
        run: |
          MERGE_BASE="$(git merge-base origin/${{ github.base_ref }} HEAD)"
          PYTHON_SMOKETEST_CHANGES="$(git diff --name-only "$MERGE_BASE" HEAD -- 'smoketests/**.py')"
          if [ -n "$PYTHON_SMOKETEST_CHANGES" ]; then
            echo "::error::This PR modifies legacy Python smoketests. Please add new tests to the Rust smoketests in crates/smoketests/ instead."
            echo ""
            echo "Changed files:"
            echo "$PYTHON_SMOKETEST_CHANGES"
            echo ""
            echo "The Python smoketests are being replaced by Rust smoketests."
            echo "See crates/smoketests/DEVELOP.md for instructions on adding Rust smoketests."
            exit 1
          fi
          echo "No Python smoketest changes detected."
smoketests_mod_rs_complete:
name: Check smoketests/mod.rs is complete
runs-on: ubuntu-latest
-22
View File
@@ -1,22 +0,0 @@
# Python Smoketests (Legacy)
> **Note:** These Python smoketests are being replaced by Rust smoketests in `crates/smoketests/`.
> Both test suites currently run in CI to ensure consistency during the transition.
>
> For new tests, please add them to the Rust smoketests. See `crates/smoketests/DEVELOP.md` for instructions.
---
## Running the Python Smoketests
To use the smoketests, you first need to install the dependencies:
```
python -m venv smoketests/venv
smoketests/venv/bin/pip install -r smoketests/requirements.txt
```
Then, run the smoketests like so:
```
smoketests/venv/bin/python -m smoketests <args>
```
-445
View File
@@ -1,445 +0,0 @@
from pathlib import Path
import contextlib
import json
import os
import random
import re
import shutil
import string
import subprocess
import sys
import tempfile
import threading
import unittest
import logging
import http.client
import tomllib
import functools
# miscellaneous file paths
TEST_DIR = Path(__file__).parent
STDB_DIR = TEST_DIR.parent
exe_suffix = ".exe" if sys.platform == "win32" else ""
SPACETIME_BIN = STDB_DIR / ("target/debug/spacetime" + exe_suffix)
# shared cargo target dir used to warm the builds of every test project
TEMPLATE_TARGET_DIR = STDB_DIR / "target/_stdbsmoketests"
BASE_STDB_CONFIG_PATH = TEST_DIR / "config.toml"
# the contents of files for the base smoketest project template
TEMPLATE_LIB_RS = open(STDB_DIR / "templates/basic-rs/spacetimedb/src/lib.rs").read()
TEMPLATE_CARGO_TOML = open(STDB_DIR / "templates/basic-rs/spacetimedb/Cargo.toml").read()
bindings_path = (STDB_DIR / "crates/bindings").absolute()
escaped_bindings_path = str(bindings_path).replace('\\', '\\\\\\\\') # double escape for re.sub + toml
TYPESCRIPT_BINDINGS_PATH = (STDB_DIR / "crates/bindings-typescript").absolute()
# Point the template's `spacetimedb` dependency at the local bindings crate.
# The `{features}` placeholder is filled in later by Smoketest.cargo_manifest.
TEMPLATE_CARGO_TOML = (re.compile(r"^spacetimedb\s*=.*$", re.M) \
    .sub(f'spacetimedb = {{ path = "{escaped_bindings_path}", features = {{features}} }}', TEMPLATE_CARGO_TOML))
# this is set to true when the --docker flag is passed to the cli
HAVE_DOCKER = False
# this is set to true when the --skip-dotnet flag is not passed to the cli,
# and a dotnet installation is detected
HAVE_DOTNET = False
# When we pass --spacetime-login, we are running against a server that requires "real" spacetime logins (rather than `--server-issued-login`).
# This is used to skip tests that don't work with that.
USE_SPACETIME_LOGIN = False
# If we pass `--remote-server`, the server address will be something other than the default. This is used to skip tests that rely on use
# having the default localhost server.
REMOTE_SERVER = False
# default value can be overridden by `--compose-file` flag
COMPOSE_FILE = ".github/docker-compose.yml"
# this will be initialized by main()
STDB_CONFIG = ''
# we need to late-bind the output stream to allow unittests to capture stdout/stderr.
class CapturableHandler(logging.StreamHandler):
    """Stream handler whose output stream is always the *current* sys.stderr.

    The stream is looked up lazily on every access rather than captured once
    at construction time, so unittest's output buffering (which swaps
    sys.stderr) can capture our log output.
    """

    @property
    def stream(self):
        # Re-resolve on each access so a swapped-out sys.stderr is honored.
        return sys.stderr

    @stream.setter
    def stream(self, value):
        # StreamHandler.__init__ assigns self.stream; silently ignore it.
        pass
# Install the capturable handler on the root logger so all smoketest logging
# flows through the late-bound stderr stream at DEBUG verbosity.
handler = CapturableHandler()
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
logging.getLogger().addHandler(handler)
logging.getLogger().setLevel(logging.DEBUG)
def requires_dotnet(item):
    """Decorator: skip the decorated test unless dotnet 8.0 was detected."""
    if not HAVE_DOTNET:
        return unittest.skip("dotnet 8.0 not available")(item)
    return item
def requires_anonymous_login(item):
    """Decorator: skip the decorated test when running with real `spacetime login`."""
    return unittest.skip("using `spacetime login`")(item) if USE_SPACETIME_LOGIN else item
def requires_local_server(item):
    """Decorator: skip the decorated test when targeting a remote server."""
    return unittest.skip("running against a remote server")(item) if REMOTE_SERVER else item
def build_template_target():
    """Populate TEMPLATE_TARGET_DIR with base cargo artifacts (first run only)."""
    if TEMPLATE_TARGET_DIR.exists():
        return
    logging.info("Building base compilation artifacts")

    # A throwaway Smoketest subclass gives us a fully set-up template project
    # without publishing anything.
    class BuildModule(Smoketest):
        AUTOPUBLISH = False

    BuildModule.setUpClass()
    build_env = dict(os.environ, CARGO_TARGET_DIR=str(TEMPLATE_TARGET_DIR))
    spacetime("build", "--module-path", BuildModule.project_path, env=build_env)
    BuildModule.tearDownClass()
    BuildModule.doClassCleanups()
def requires_docker(item):
    """Decorator: skip the decorated test unless --docker was passed."""
    if not HAVE_DOCKER:
        return unittest.skip("docker not available")(item)
    return item
def random_string(k=20):
    """Return a random lowercase ASCII string of length *k*."""
    return "".join(random.choice(string.ascii_lowercase) for _ in range(k))
def extract_fields(cmd_output, field_name):
    """
    parses output from the spacetime cli that's formatted in the "empty" style
    from tabled:

        FIELDNAME1    VALUE1
        THEFIELDNAME2 VALUE2

    Returns every VALUE whose label equals *field_name*; lines with fewer
    than two whitespace-separated fields are ignored.
    """
    matches = []
    for row in cmd_output.splitlines():
        parts = row.split()
        if len(parts) >= 2 and parts[0] == field_name:
            matches.append(parts[1])
    return matches
def parse_sql_result(res: str) -> list[dict]:
    """Parse tabular output from an SQL query into a list of dicts.

    Row 0 holds the '|'-separated column names, row 1 is the separator rule,
    and each later row becomes a dict keyed by the column names.
    """
    lines = res.splitlines()
    header_line = lines[0]
    headers = header_line.split('|') if '|' in header_line else [header_line]
    headers = [h.strip() for h in headers]
    return [
        dict(zip(headers, (cell.strip() for cell in data_row.split('|'))))
        for data_row in lines[2:]
    ]
def extract_field(cmd_output, field_name):
    """Like extract_fields, but expect exactly one match and return it."""
    [value] = extract_fields(cmd_output, field_name)
    return value
def log_cmd(args):
    """Log a command line, shell-prompt style, at debug level."""
    logging.debug("$ " + " ".join(map(str, args)))
def run_cmd(*args, capture_stderr=True, check=True, full_output=False, cmd_name=None, log=True, **kwargs):
    """Run a subprocess, logging the command line and its output.

    Args:
        capture_stderr: capture (and log) stderr rather than inheriting it.
        check: raise CalledProcessError on a nonzero exit code.
        full_output: return the CompletedProcess instead of just stdout text.
        cmd_name: cosmetic name substituted for argv[0] in logs and errors.
        log: emit debug logging for the command and its output.
    """
    if log:
        log_cmd(args if cmd_name is None else [cmd_name, *args[1:]])
    needs_close = False
    if not capture_stderr:
        # stderr is inherited and will interleave with our log output.
        logging.debug("--- stderr ---")
        needs_close = True
    output = subprocess.run(
        list(args),
        encoding="utf8",
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE if capture_stderr else None,
        **kwargs
    )
    if log:
        if capture_stderr and output.stderr.strip() != "":
            logging.debug(f"--- stderr ---\n{output.stderr.strip()}")
            needs_close = True
        if output.stdout.strip() != "":
            logging.debug(f"--- stdout ---\n{output.stdout.strip()}")
            needs_close = True
    if needs_close:
        # Close the section opened by any of the headers above.
        logging.debug("--------------\n")
    sys.stderr.flush()
    if check:
        if cmd_name is not None:
            # Rewrite argv[0] so CalledProcessError shows the friendly name.
            output.args[0] = cmd_name
        output.check_returncode()
    return output if full_output else output.stdout
@functools.cache
def pnpm_path():
    """Locate the pnpm executable once (cached); raise if it is not installed."""
    located = shutil.which("pnpm")
    if located is None:
        raise Exception("pnpm not installed")
    return located
def pnpm(*args, **kwargs):
    """Run a pnpm subcommand via run_cmd, forwarding all keyword options."""
    return run_cmd(pnpm_path(), *args, **kwargs)
@functools.cache
def build_typescript_sdk():
    """Install and build the TypeScript bindings (cached: runs at most once)."""
    pnpm("install", cwd=TYPESCRIPT_BINDINGS_PATH)
    pnpm("build", cwd=TYPESCRIPT_BINDINGS_PATH)
def spacetime(*args, **kwargs):
    """Run the spacetime CLI binary; logs/errors show `spacetime` as argv[0]."""
    return run_cmd(SPACETIME_BIN, *args, cmd_name="spacetime", **kwargs)
def new_identity(config_path):
    """Log out, then log back in with a fresh server-issued identity for *config_path*."""
    spacetime("--config-path", str(config_path), "logout")
    spacetime("--config-path", str(config_path), "login", "--server-issued-login", "localhost", full_output=False)
class Smoketest(unittest.TestCase):
    """Base class for smoketests.

    Each test class gets a temporary cargo project containing MODULE_CODE,
    its own CLI config file, and (by default) an automatically published
    database whose identity is recorded in `database_identity`.
    """

    # Rust source of the module under test; subclasses override this.
    MODULE_CODE = TEMPLATE_LIB_RS
    # Publish the module automatically during setUpClass.
    AUTOPUBLISH = True
    # Feature list substituted into the template's spacetimedb dependency.
    BINDINGS_FEATURES = ["unstable"]
    # Extra text appended verbatim to the generated Cargo.toml.
    EXTRA_DEPS = ""

    @classmethod
    def cargo_manifest(cls, manifest_text):
        """Render the Cargo.toml template with this class's features and extra deps."""
        return manifest_text.replace("{features}", repr(list(cls.BINDINGS_FEATURES))) + cls.EXTRA_DEPS

    # helpers

    @classmethod
    def spacetime(cls, *args, **kwargs):
        """Run the spacetime CLI against this class's config file."""
        return spacetime("--config-path", str(cls.config_path), *args, **kwargs)

    def _check_published(self):
        # Guard for helpers that require a published database identity.
        if not hasattr(self, "database_identity"):
            raise Exception("Cannot use this function without publishing a module")

    def call(self, reducer, *args, anon=False, check=True, full_output = False):
        """Invoke *reducer* on the published database with JSON-encoded *args*."""
        self._check_published()
        anon = ["--anonymous"] if anon else []
        return self.spacetime("call", *anon, "--", self.database_identity, reducer, *map(json.dumps, args), check = check, full_output=full_output)

    def sql(self, sql):
        """Run an anonymous SQL query against the published database."""
        self._check_published()
        anon = ["--anonymous"]
        return self.spacetime("sql", *anon, "--", self.database_identity, sql)

    def logs(self, n):
        """Return the last *n* log messages (strings) from the database."""
        return [log["message"] for log in self.log_records(n)]

    def log_records(self, n):
        """Return the last *n* structured (JSON-decoded) log records."""
        self._check_published()
        logs = self.spacetime("logs", "--format=json", "-n", str(n), "--", self.database_identity)
        return list(map(json.loads, logs.splitlines()))

    def publish_module(self, domain=None, *, clear=True, capture_stderr=True,
                       num_replicas=None, break_clients=False, organization=None):
        """Publish the project's module and record the resulting database identity."""
        publish_output = self.spacetime(
            "publish",
            *[domain] if domain is not None else [],
            *["-c"] if clear and domain is not None else [],
            "--module-path", self.project_path,
            # This is required if -c is provided, but is also required for SpacetimeDBPrivate's tests,
            # because the server address is `node` which doesn't look like `localhost` or `127.0.0.1`
            # and so the publish step prompts for confirmation.
            "--yes",
            *["--num-replicas", f"{num_replicas}"] if num_replicas is not None else [],
            *["--break-clients"] if break_clients else [],
            *["--organization", f"{organization}"] if organization is not None else [],
            capture_stderr=capture_stderr,
        )
        # Pull the hex identity out of the CLI's human-readable output.
        self.resolved_identity = re.search(r"identity: ([0-9a-fA-F]+)", publish_output)[1]
        self.database_identity = self.resolved_identity

    @classmethod
    def reset_config(cls):
        """Write a fresh copy of the base CLI config to this class's config path."""
        if not STDB_CONFIG:
            raise Exception("config toml has not been initialized yet")
        cls.config_path.write_text(STDB_CONFIG)

    def fingerprint(self):
        # Fetch the server's fingerprint; required for `identity list`.
        self.spacetime("server", "fingerprint", "localhost", "-y")

    def new_identity(self):
        """Replace the current identity with a fresh server-issued one."""
        new_identity(self.__class__.config_path)

    def subscribe(self, *queries, n, confirmed = None, database = None):
        """Subscribe to *queries*, waiting for *n* updates on a background thread.

        Returns a callable; invoke it to join the reader thread and obtain the
        parsed updates (or re-raise any error from the subscriber process).
        """
        self._check_published()
        assert isinstance(n, int)
        args = [
            SPACETIME_BIN,
            "--config-path", str(self.config_path),
            "subscribe",
            database if database is not None else self.database_identity,
            "-t", "600",
            "-n", str(n),
            "--print-initial-update",
        ]
        if confirmed is not None:
            args.append(f"--confirmed={str(confirmed).lower()}")
        args.extend(["--", *queries])
        # Log with a friendly argv[0] instead of the full binary path.
        fake_args = ["spacetime", *args[1:]]
        log_cmd(fake_args)
        proc = subprocess.Popen(args, encoding="utf8", stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        def stderr_task():
            # Forward the subscriber's stderr to ours as it arrives.
            sys.stderr.writelines(proc.stderr)
        threading.Thread(target=stderr_task).start()
        init_update = proc.stdout.readline().strip()
        if init_update:
            print("initial update:", init_update)
        else:
            try:
                code = proc.wait()
                if code:
                    raise subprocess.CalledProcessError(code, fake_args)
                print("no initial update, but no error code either")
            except subprocess.TimeoutExpired:
                print("no initial update, but process is still running")
        def run():
            # Collect the remaining updates, then surface any nonzero exit.
            updates = list(map(json.loads, proc.stdout))
            code = proc.wait()
            if code:
                raise subprocess.CalledProcessError(code, fake_args)
            return updates
        # Note that we're returning `.join`, not `.join()`; this returns something that the caller can call in order to
        # join the thread and wait for the results.
        # If the caller does not invoke this returned value, the thread will just run in the background, not be awaited,
        # and **not raise any exceptions to the caller**.
        return ReturnThread(run).join

    def get_server_address(self):
        """Read the default server's host/port/protocol and token from the config file."""
        with open(self.config_path, "rb") as f:
            config = tomllib.load(f)
        token = config['spacetimedb_token']
        server_name = config['default_server']
        server_config = next((c for c in config['server_configs'] if c['nickname'] == server_name), None)
        if server_config is None:
            raise Exception(f"Unable to find server in config with nickname {server_name}")
        address = server_config['host']
        host = address
        port = None
        if ":" in host:
            host, port = host.split(":", 1)
        protocol = server_config['protocol']
        return dict(address=address, host=host, port=port, protocol=protocol, token=token)

    # Make an HTTP call with `method` to `path`.
    #
    # If the response is 200, return the body.
    # Otherwise, throw an `Exception` constructed with two arguments, the response object and the body.
    def api_call(self, method, path, body=None, headers={}):
        # NOTE(review): the mutable default `headers={}` is shared across calls
        # and mutated below — callers wanting a pristine dict should pass one.
        server = self.get_server_address()
        host = server["address"]
        protocol = server["protocol"]
        token = server["token"]
        conn = None
        if protocol == "http":
            conn = http.client.HTTPConnection(host)
        elif protocol == "https":
            conn = http.client.HTTPSConnection(host)
        else:
            raise Exception(f"Unknown protocol: {protocol}")
        auth = {"Authorization": f'Bearer {token}'}
        headers.update(auth)
        log_cmd([method, path])
        conn.request(method, path, body, headers)
        resp = conn.getresponse()
        body = resp.read()
        logging.debug(f"{resp.status} {body}")
        if resp.status != 200:
            raise Exception(resp, body)
        return body

    @classmethod
    def write_module_code(cls, module_code):
        """Overwrite the project's src/lib.rs with *module_code*."""
        open(cls.project_path / "src/lib.rs", "w").write(module_code)

    # testcase initialization

    @classmethod
    def setUpClass(cls):
        """Create the temp project, write config/manifest/module, optionally publish."""
        cls.project_path = Path(cls.enterClassContext(tempfile.TemporaryDirectory()))
        cls.config_path = cls.project_path / "config.toml"
        cls.reset_config()
        open(cls.project_path / "Cargo.toml", "w").write(cls.cargo_manifest(TEMPLATE_CARGO_TOML))
        shutil.copy2(STDB_DIR / "rust-toolchain.toml", cls.project_path)
        os.mkdir(cls.project_path / "src")
        cls.write_module_code(cls.MODULE_CODE)
        if TEMPLATE_TARGET_DIR.exists():
            # Seed the target dir with prebuilt artifacts to speed up the build.
            shutil.copytree(TEMPLATE_TARGET_DIR, cls.project_path / "target")
        if cls.AUTOPUBLISH:
            logging.info(f"Compiling module for {cls.__qualname__}...")
            cls.publish_module(cls, capture_stderr=True) # capture stderr because otherwise it clutters the top-level test logs for some reason.

    def tearDown(self):
        # if this single test method published a database, clean it up now
        if "database_identity" in self.__dict__:
            try:
                # TODO: save the credentials in publish_module()
                self.spacetime("delete", "--yes", self.database_identity)
            except Exception:
                pass

    @classmethod
    def tearDownClass(cls):
        """Best-effort deletion of the class-level database, if one was published."""
        if hasattr(cls, "database_identity"):
            try:
                # TODO: save the credentials in publish_module()
                cls.spacetime("delete", "--yes", cls.database_identity)
            except Exception:
                pass

    if sys.version_info < (3, 11):
        # polyfill; python 3.11 defines this classmethod on TestCase
        @classmethod
        def enterClassContext(cls, cm):
            result = cm.__enter__()
            cls.addClassCleanup(cm.__exit__, None, None, None)
            return result

    def assertSql(self, sql: str, expected: str):
        """Assert that executing `sql` produces the expected output."""
        self.maxDiff = None
        sql_out = self.spacetime("sql", self.database_identity, sql)
        # Compare with trailing whitespace stripped from every line.
        sql_out = "\n".join([line.rstrip() for line in sql_out.splitlines()])
        expected = "\n".join([line.rstrip() for line in expected.splitlines()])
        self.assertMultiLineEqual(sql_out, expected)
class ReturnThread:
    """Run *target* on a background thread, re-raising its exception on join().

    Plain threading.Thread never propagates exceptions to callers of join();
    this wrapper captures either the result or the raised exception and
    surfaces it to whoever calls join().
    """

    def __init__(self, target):
        self._target = target
        self._exception = None
        self._thread = threading.Thread(target=self._task)
        self._thread.start()

    def _task(self):
        # Capture the result or the exception for join() to surface later.
        try:
            self._result = self._target()
        except BaseException as exc:
            self._exception = exc
        finally:
            # Drop the reference so the callable can be collected promptly.
            del self._target

    def join(self, timeout=None):
        """Wait for completion; raise the captured exception or return the result."""
        self._thread.join(timeout)
        if self._exception is not None:
            raise self._exception
        return self._result
-190
View File
@@ -1,190 +0,0 @@
#!/usr/bin/env python
import subprocess
import unittest
import argparse
import os
import re
import fnmatch
import json
from . import TEST_DIR, SPACETIME_BIN, BASE_STDB_CONFIG_PATH, exe_suffix, build_template_target
import smoketests
import sys
import logging
import itertools
import tempfile
from pathlib import Path
import shutil
import traceback
def check_docker():
    """Return the name of the running SpacetimeDB docker container, or exit(1)."""
    ps_json = smoketests.run_cmd("docker", "ps", "--format=json")
    for line in ps_json.splitlines():
        container = json.loads(line)
        image = container["Image"]
        if "node" in image or "spacetime" in image:
            return container["Names"]
    print("Docker container not found, is SpacetimeDB running?")
    exit(1)
def check_dotnet() -> bool:
    """Return True when a dotnet toolchain of major version >= 8 is available."""
    try:
        version = smoketests.run_cmd("dotnet", "--version", log=False).strip()
        major = int(version.split(".")[0])
        if major < 8:
            logging.info(f"dotnet version {version} not high enough (< 8.0), skipping dotnet smoketests")
            return False
    except Exception:
        # Missing binary or unparsable version — treat as unavailable.
        return False
    return True
class ExclusionaryTestLoader(unittest.TestLoader):
    """TestLoader that silently skips any test name matching an exclude glob."""

    def __init__(self, excludelist=()):
        super().__init__()
        # Translate each glob to a regex (minus its \Z anchor) and OR them all
        # together, anchored at the start and ending on a word boundary.
        translated = [fnmatch.translate(pat).removesuffix(r"\Z") for pat in excludelist]
        joined = '|'.join(translated)
        self.excludepat = joined and re.compile(rf"^(?:{joined})\b")

    def loadTestsFromName(self, name, module=None):
        """Load tests for *name*, returning an empty suite when it is excluded."""
        if self.excludepat:
            qualname = name if module is None else module.__name__ + "." + name
            if self.excludepat.match(qualname):
                return self.suiteClass([])
        return super().loadTestsFromName(name, module)
def _convert_select_pattern(pattern):
return f'*{pattern}*' if '*' not in pattern else pattern
# Dotted-module prefix prepended to test names given on the command line.
TESTPREFIX = "smoketests.tests."
def _iter_all_tests(suite_or_case):
"""Yield all individual tests from possibly nested TestSuite structures."""
if isinstance(suite_or_case, unittest.TestSuite):
for t in suite_or_case:
yield from _iter_all_tests(t)
else:
yield suite_or_case
def main():
    """Entry point: parse CLI arguments, prepare the environment, run the suite."""
    # Default to every test module under tests/ (minus __init__.py).
    tests = [fname.removesuffix(".py") for fname in os.listdir(TEST_DIR / "tests") if fname.endswith(".py") and fname != "__init__.py"]
    parser = argparse.ArgumentParser()
    parser.add_argument("test", nargs="*", default=tests)
    parser.add_argument("--docker", action="store_true")
    parser.add_argument("--compose-file")
    parser.add_argument("--no-docker-logs", action="store_true")
    parser.add_argument("--skip-dotnet", action="store_true", help="ignore tests which require dotnet")
    parser.add_argument("--show-all-output", action="store_true", help="show all stdout/stderr from the tests as they're running")
    parser.add_argument("--parallel", action="store_true", help="run test classes in parallel")
    parser.add_argument("-j", dest='jobs', help="Set number of jobs for parallel test runs. Default is `nproc`", type=int, default=0)
    parser.add_argument('-k', dest='testNamePatterns',
                        action='append', type=_convert_select_pattern,
                        help='Only run tests which match the given substring')
    parser.add_argument("-x", dest="exclude", nargs="*", default=[])
    parser.add_argument("--no-build-cli", action="store_true", help="don't cargo build the cli")
    parser.add_argument("--list", action="store_true", help="list the tests that would be run, but don't run them")
    parser.add_argument("--remote-server", action="store", help="Run against a remote server")
    parser.add_argument("--spacetime-login", action="store_true", help="Use `spacetime login` for these tests (and disable tests that don't work with that)")
    args = parser.parse_args()
    if args.docker:
        # have docker logs print concurrently with the test output
        if args.compose_file:
            smoketests.COMPOSE_FILE = args.compose_file
        if not args.no_docker_logs:
            if args.compose_file:
                subprocess.Popen(["docker", "compose", "-f", args.compose_file, "logs", "-f"])
            else:
                docker_container = check_docker()
                subprocess.Popen(["docker", "logs", "-f", docker_container])
        smoketests.HAVE_DOCKER = True
    if not args.skip_dotnet:
        smoketests.HAVE_DOTNET = check_dotnet()
        if not smoketests.HAVE_DOTNET:
            print("no suitable dotnet installation found")
            exit(1)
    add_prefix = lambda testlist: [TESTPREFIX + test for test in testlist]
    import fnmatch  # NOTE(review): redundant — fnmatch is already imported at module top.
    excludelist = add_prefix(args.exclude)
    testlist = add_prefix(args.test)
    loader = ExclusionaryTestLoader(excludelist)
    loader.testNamePatterns = args.testNamePatterns
    tests = loader.loadTestsFromNames(testlist)
    if args.list:
        # Print the resolved test ids; report constructor failures verbosely.
        failed_cls = getattr(unittest.loader, "_FailedTest", None)
        any_failed = False
        for test in _iter_all_tests(tests):
            name = test.id()
            if isinstance(test, failed_cls):
                any_failed = True
                print('')
                print("Failed to construct %s:" % test.id())
                exc = getattr(test, "_exception", None)
                if exc is not None:
                    tb = ''.join(traceback.format_exception(exc))
                    print(tb.rstrip())
                print('')
            else:
                print(f"{name}")
        exit(1 if any_failed else 0)
    if not args.no_build_cli:
        logging.info("Compiling spacetime cli...")
        smoketests.run_cmd("cargo", "build", cwd=TEST_DIR.parent, capture_stderr=False)
        # The CLI binary is normally a symlink to spacetimedb-update;
        # recreate it (or fall back to a copy) if it isn't.
        update_bin_name = "spacetimedb-update" + exe_suffix
        try:
            bin_is_symlink = SPACETIME_BIN.readlink() == update_bin_name
        except OSError:
            bin_is_symlink = False
        if not bin_is_symlink:
            try:
                os.remove(SPACETIME_BIN)
            except FileNotFoundError:
                pass
            try:
                os.symlink(update_bin_name, SPACETIME_BIN)
            except OSError:
                # Symlinks may be unavailable (e.g. unprivileged Windows).
                shutil.copyfile(SPACETIME_BIN.with_name(update_bin_name), SPACETIME_BIN)
    os.environ["SPACETIME_SKIP_CLIPPY"] = "1"
    with tempfile.NamedTemporaryFile(mode="w+b", suffix=".toml", buffering=0, delete_on_close=False) as config_file:
        # Start from the checked-in base config, then adjust it for this run.
        with BASE_STDB_CONFIG_PATH.open("rb") as src, config_file.file as dst:
            shutil.copyfileobj(src, dst)
        if args.remote_server is not None:
            smoketests.spacetime("--config-path", config_file.name, "server", "edit", "localhost", "--url", args.remote_server, "--yes")
            smoketests.REMOTE_SERVER = True
        if args.spacetime_login:
            smoketests.spacetime("--config-path", config_file.name, "logout")
            smoketests.spacetime("--config-path", config_file.name, "login")
            smoketests.USE_SPACETIME_LOGIN = True
        else:
            smoketests.new_identity(config_file.name)
        smoketests.STDB_CONFIG = Path(config_file.name).read_text()
        build_template_target()
        buffer = not args.show_all_output
        verbosity = 2
        if args.parallel:
            print("parallel test running is under construction, this will probably not work correctly")
            from . import unittest_parallel
            unittest_parallel.main(buffer=buffer, verbose=verbosity, level="class", discovered_tests=tests, jobs=args.jobs)
        else:
            result = unittest.TextTestRunner(buffer=buffer, verbosity=verbosity).run(tests)
            if not result.wasSuccessful():
                parser.exit(status=1)
-6
View File
@@ -1,6 +0,0 @@
default_server = "localhost"
[[server_configs]]
nickname = "localhost"
host = "127.0.0.1:3000"
protocol = "http"
-208
View File
@@ -1,208 +0,0 @@
import json
import os
import subprocess
import time
from dataclasses import dataclass
from typing import List, Optional, Callable
from urllib.request import urlopen
from . import COMPOSE_FILE
def restart_docker():
    """
    Restart all containers defined in the current `COMPOSE_FILE`.

    Checks that all spacetimedb containers are up and running after the restart.
    If they're not up after a couple of retries, throws an `Exception`.
    """
    print("Restarting containers")
    docker = DockerManager(COMPOSE_FILE)
    docker.compose("restart")
    containers = docker.list_spacetimedb_containers()
    if not containers:
        raise Exception("No spacetimedb containers found")
    # Poll up to 10 times for every node to answer its ping endpoint.
    for _attempt in range(10):
        containers_alive = {
            container.name: container.is_running(docker, spacetimedb_ping_url)
            for container in containers
        }
        if all(containers_alive.values()):
            # sleep a bit more to allow for leader election etc
            # TODO: make ping endpoint consider all server state
            time.sleep(2)
            return
        time.sleep(1)
    raise Exception(f"Not all containers are up and running: {containers_alive!r}")
def spacetimedb_ping_url(port: int) -> str:
    """Build the local ping-endpoint URL for a SpacetimeDB node on *port*."""
    return "http://127.0.0.1:{}/v1/ping".format(port)
@dataclass
class DockerContainer:
    """Represents a Docker container with its basic properties."""
    id: str
    name: str

    def host_ports(self, docker) -> set[int]:
        """
        Collect all host ports of this container.

        Host ports are ports on the host that are bound to ports of the
        container. If the container is not currently running, an empty set
        is returned.
        """
        info = docker.inspect_container(self)
        port_map = info.get('NetworkSettings', {}).get('Ports', {})
        bound = set()
        for bindings in port_map.values():
            # A port with no host binding maps to None.
            for binding in bindings or []:
                host_port = binding.get("HostPort")
                if host_port:
                    bound.add(host_port)
        return bound

    def is_running(self, docker, ping_url: Callable[[int], str]) -> bool:
        """
        Check if the container is running.

        `ping_url` takes a port number and returns a URL string; the container
        counts as running when any bound host port answers that URL with a
        200 status.
        """
        for port in self.host_ports(docker):
            url = ping_url(port)
            print(f"Trying {url} ... ", end='', flush=True)
            try:
                with urlopen(url, timeout=0.2) as response:
                    if response.status == 200:
                        print("ok")
                        return True
            except Exception as e:
                print(f"error: {e}")
        print(f"container {self.name} not running")
        return False
class DockerManager:
    """Manages all Docker and Docker Compose operations."""

    def __init__(self, compose_file: str, **config):
        # Explicit keyword config wins; fall back to environment, then defaults.
        self.compose_file = compose_file
        self.network_name = config.get('network_name') or \
            os.getenv('DOCKER_NETWORK_NAME', 'private_spacetime_cloud')
        self.control_db_container = config.get('control_db_container') or \
            os.getenv('CONTROL_DB_CONTAINER', 'node')
        self.spacetime_cli_bin = config.get('spacetime_cli_bin') or \
            os.getenv('SPACETIME_CLI_BIN', 'spacetimedb-cloud')

    def _execute_command(self, *args: str) -> str:
        """Execute a Docker command and return its stripped stdout; log and re-raise failures."""
        try:
            result = subprocess.run(
                args,
                capture_output=True,
                text=True,
                check=True
            )
            return result.stdout.strip()
        except subprocess.CalledProcessError as e:
            print(f"Command failed: {e.stderr}")
            raise
        except Exception as e:
            print(f"Unexpected error: {str(e)}")
            raise

    def compose(self, *args: str) -> str:
        """Execute a `docker compose` command."""
        return self._execute_command("docker", "compose", "-f", self.compose_file, *args)

    def docker(self, *args: str) -> str:
        """Execute a `docker` command."""
        return self._execute_command("docker", *args)

    def list_containers(self, *filters) -> List[DockerContainer]:
        """
        List the containers of the current compose file and return as DockerContainer objects.

        All containers are considered, even if not running ('-a' flag).
        The containers may be filtered by 'filters' ('--filter' option).
        """
        # Use -a so we don't miss a crashed or killed container
        # when checking for readiness.
        cmd = ["ps", "-a"]
        # Restrict to the current compose file.
        compose_file = os.path.abspath(COMPOSE_FILE)
        cmd.extend(["--filter", f"label=com.docker.compose.project.config_files={compose_file}"])
        # Apply additional filters.
        for f in filters:
            cmd.extend(["--filter", f])
        # Output only the fields we need for `DockerContainer`.
        cmd.extend(["--format", "{{.ID}} {{.Names}}"])
        output = self.docker(*cmd)
        containers = []
        for line in output.splitlines():
            if line.strip():
                container_id, name = line.split(maxsplit=1)
                containers.append(DockerContainer(id=container_id, name=name))
        return containers

    def list_spacetimedb_containers(self) -> List[DockerContainer]:
        """List all containers running spacetimedb."""
        return self.list_containers("label=app=spacetimedb")

    def inspect_container(self, container: DockerContainer):
        """Run the `inspect` command for `container`, returning the parsed JSON dict."""
        info = self.docker("inspect", container.name)
        return json.loads(info)[0]

    def get_container_by_name(self, name: str) -> Optional[DockerContainer]:
        """Find a container by name pattern."""
        return next(
            (c for c in self.list_containers() if name in c.name),
            None
        )

    def kill_container(self, container_id: str):
        """Kill a container by ID."""
        print(f"Killing container {container_id}")
        self.docker("kill", container_id)

    def start_container(self, container_id: str):
        """Start a container by ID."""
        print(f"Starting container {container_id}")
        self.docker("start", container_id)

    def disconnect_container(self, container_id: str):
        """Disconnect a container from the network."""
        print(f"Disconnecting container {container_id}")
        self.docker("network", "disconnect", self.network_name, container_id)
        print(f"Disconnected container {container_id}")

    def connect_container(self, container_id: str):
        """Connect a container to the network."""
        print(f"Connecting container {container_id}")
        self.docker("network", "connect", self.network_name, container_id)
        print(f"Connected container {container_id}")

    def generate_root_token(self) -> str:
        """Generate a root token using spacetimedb-cloud."""
        # The CLI prints a '|'-delimited line; the token is the second field.
        return self.compose(
            "exec", self.control_db_container, self.spacetime_cli_bin, "token", "gen",
            "--subject=placeholder-node-id",
            "--jwt-priv-key", "/etc/spacetimedb/keys/id_ecdsa").split('|')[1]
-3
View File
@@ -1,3 +0,0 @@
psycopg2-binary
toml
xmltodict
View File
-92
View File
@@ -1,92 +0,0 @@
from .. import Smoketest, random_string
class AddRemoveIndex(Smoketest):
    """Publish a module without, with, and again without btree indices,
    checking that unindexed join subscriptions are rejected each time."""

    # The module is (re)published explicitly inside the test.
    AUTOPUBLISH = False

    MODULE_CODE = """
use spacetimedb::{ReducerContext, Table};

#[spacetimedb::table(accessor = t1)]
pub struct T1 { id: u64 }

#[spacetimedb::table(accessor = t2)]
pub struct T2 { id: u64 }

#[spacetimedb::reducer(init)]
pub fn init(ctx: &ReducerContext) {
    for id in 0..1_000 {
        ctx.db.t1().insert(T1 { id });
        ctx.db.t2().insert(T2 { id });
    }
}
"""

    MODULE_CODE_INDEXED = """
use spacetimedb::{ReducerContext, Table};

#[spacetimedb::table(accessor = t1)]
pub struct T1 { #[index(btree)] id: u64 }

#[spacetimedb::table(accessor = t2)]
pub struct T2 { #[index(btree)] id: u64 }

#[spacetimedb::reducer(init)]
pub fn init(ctx: &ReducerContext) {
    for id in 0..1_000 {
        ctx.db.t1().insert(T1 { id });
        ctx.db.t2().insert(T2 { id });
    }
}

#[spacetimedb::reducer]
pub fn add(ctx: &ReducerContext) {
    let id = 1_001;
    ctx.db.t1().insert(T1 { id });
    ctx.db.t2().insert(T2 { id });
}
"""

    # NOTE(review): the SQL table names appear as t_1/t_2 (presumably
    # snake_cased from the struct names T1/T2) — confirm against the codegen.
    JOIN_QUERY = "select t_1.* from t_1 join t_2 on t_1.id = t_2.id where t_2.id = 1001"

    def between_publishes(self):
        """
        The test `AddRemoveIndexAfterRestart` in `zz_docker.py`
        overwrites this method to restart docker between each publish,
        otherwise reusing this test's code.
        """
        pass

    def test_add_then_remove_index(self):
        """
        First publish without the indices,
        then add the indices, and publish,
        and finally remove the indices, and publish again.

        There should be no errors
        and the unindexed versions should reject subscriptions.
        """
        name = random_string()
        # Publish and attempt a subscribing to a join query.
        # There are no indices, resulting in an unsupported unindexed join.
        self.publish_module(name, clear = False)
        with self.assertRaises(Exception):
            self.subscribe(self.JOIN_QUERY, n = 0)
        self.between_publishes()
        # Publish the indexed version.
        # Now we have indices, so the query should be accepted.
        self.write_module_code(self.MODULE_CODE_INDEXED)
        self.publish_module(name, clear = False)
        sub = self.subscribe(self.JOIN_QUERY, n = 1)
        self.call("add", anon = True)
        sub()
        self.between_publishes()
        # Publish the unindexed version again, removing the index.
        # The initial subscription should be rejected again.
        self.write_module_code(self.MODULE_CODE)
        self.publish_module(name, clear = False)
        with self.assertRaises(Exception):
            self.subscribe(self.JOIN_QUERY, n = 0)
-134
View File
@@ -1,134 +0,0 @@
from .. import Smoketest
import string
import functools
# Every Rust integer type exercised by the autoinc tests.
ints = "u8", "u16", "u32", "u64", "u128", "i8", "i16", "i32", "i64", "i128"
def reducer_name(int_ty: str) -> str:
# Convert "u8" -> "u_8", "i128" -> "i_128"
return f"{int_ty[0]}_{int_ty[1:]}"
class IntTests:
make_func = lambda int_ty: lambda self: self.do_test_autoinc(int_ty)
for int_ty in ints:
locals()[f"test_autoinc_{int_ty}"] = make_func(int_ty)
del int_ty, make_func
autoinc1_template = string.Template("""
#[spacetimedb::table(accessor = person_$KEY_TY)]
pub struct Person_$KEY_TY {
#[auto_inc]
key_col: $KEY_TY,
name: String,
}
#[spacetimedb::reducer]
pub fn add_$REDUCER_TY(ctx: &ReducerContext, name: String, expected_value: $KEY_TY) {
let value = ctx.db.person_$KEY_TY().insert(Person_$KEY_TY { key_col: 0, name });
assert_eq!(value.key_col, expected_value);
}
#[spacetimedb::reducer]
pub fn say_hello_$REDUCER_TY(ctx: &ReducerContext) {
for person in ctx.db.person_$KEY_TY().iter() {
log::info!("Hello, {}:{}!", person.key_col, person.name);
}
log::info!("Hello, World!");
}
""")
class AutoincBasic(IntTests, Smoketest):
"This tests the auto_inc functionality"
MODULE_CODE = f"""
#![allow(non_camel_case_types)]
use spacetimedb::{{log, ReducerContext, Table}};
{"".join(
autoinc1_template.substitute(
KEY_TY=int_ty,
REDUCER_TY=reducer_name(int_ty),
)
for int_ty in ints
)}
"""
def do_test_autoinc(self, int_ty):
r = reducer_name(int_ty)
self.call(f"add_{r}", "Robert", 1)
self.call(f"add_{r}", "Julie", 2)
self.call(f"add_{r}", "Samantha", 3)
self.call(f"say_hello_{r}")
logs = self.logs(4)
self.assertIn("Hello, 3:Samantha!", logs)
self.assertIn("Hello, 2:Julie!", logs)
self.assertIn("Hello, 1:Robert!", logs)
self.assertIn("Hello, World!", logs)
autoinc2_template = string.Template("""
#[spacetimedb::table(accessor = person_$KEY_TY)]
pub struct Person_$KEY_TY {
#[auto_inc]
#[unique]
key_col: $KEY_TY,
#[unique]
name: String,
}
#[spacetimedb::reducer]
pub fn add_new_$REDUCER_TY(ctx: &ReducerContext, name: String) -> Result<(), Box<dyn Error>> {
let value = ctx.db.person_$KEY_TY().try_insert(Person_$KEY_TY { key_col: 0, name })?;
log::info!("Assigned Value: {} -> {}", value.key_col, value.name);
Ok(())
}
#[spacetimedb::reducer]
pub fn update_$REDUCER_TY(ctx: &ReducerContext, name: String, new_id: $KEY_TY) {
ctx.db.person_$KEY_TY().name().delete(&name);
let _value = ctx.db.person_$KEY_TY().insert(Person_$KEY_TY { key_col: new_id, name });
}
#[spacetimedb::reducer]
pub fn say_hello_$REDUCER_TY(ctx: &ReducerContext) {
for person in ctx.db.person_$KEY_TY().iter() {
log::info!("Hello, {}:{}!", person.key_col, person.name);
}
log::info!("Hello, World!");
}
""")
class AutoincUnique(IntTests, Smoketest):
"""This tests unique constraints being violated during autoinc insertion"""
MODULE_CODE = f"""
#![allow(non_camel_case_types)]
use std::error::Error;
use spacetimedb::{{log, ReducerContext, Table}};
{"".join(
autoinc2_template.substitute(
KEY_TY=int_ty,
REDUCER_TY=reducer_name(int_ty),
)
for int_ty in ints
)}
"""
def do_test_autoinc(self, int_ty):
r = reducer_name(int_ty)
self.call(f"update_{r}", "Robert", 2)
self.call(f"add_new_{r}", "Success")
with self.assertRaises(Exception):
self.call(f"add_new_{r}", "Failure")
self.call(f"say_hello_{r}")
logs = self.logs(4)
self.assertIn("Hello, 2:Robert!", logs)
self.assertIn("Hello, 1:Success!", logs)
self.assertIn("Hello, World!", logs)
-361
View File
@@ -1,361 +0,0 @@
from .. import Smoketest
import sys
import logging
class AddTableAutoMigration(Smoketest):
MODULE_CODE_INIT = """
use spacetimedb::{log, ReducerContext, Table, SpacetimeType};
use PersonKind::*;
#[spacetimedb::table(accessor = person, public)]
pub struct Person {
name: String,
kind: PersonKind,
}
#[spacetimedb::reducer]
pub fn add_person(ctx: &ReducerContext, name: String, kind: String) {
let kind = kind_from_string(kind);
ctx.db.person().insert(Person { name, kind });
}
#[spacetimedb::reducer]
pub fn print_persons(ctx: &ReducerContext, prefix: String) {
for person in ctx.db.person().iter() {
let kind = kind_to_string(person.kind);
log::info!("{prefix}: {} - {kind}", person.name);
}
}
#[spacetimedb::table(accessor = point_mass)]
pub struct PointMass {
mass: f64,
/// This used to cause an error when check_compatible did not resolve types in a `ModuleDef`.
position: Vector2,
}
#[derive(SpacetimeType, Clone, Copy)]
pub struct Vector2 {
x: f64,
y: f64,
}
"""
MODULE_CODE = MODULE_CODE_INIT + """
#[spacetimedb::table(accessor = person_info)]
pub struct PersonInfo {
#[primary_key]
id: u64,
}
#[derive(SpacetimeType, Clone, Copy, PartialEq, Eq)]
pub enum PersonKind {
Student,
}
fn kind_from_string(_: String) -> PersonKind {
Student
}
fn kind_to_string(Student: PersonKind) -> &'static str {
"Student"
}
"""
MODULE_CODE_UPDATED = (
MODULE_CODE_INIT
+ """
#[spacetimedb::table(accessor = person_info)]
pub struct PersonInfo {
#[primary_key]
#[auto_inc]
id: u64,
}
#[derive(SpacetimeType, Clone, Copy, PartialEq, Eq)]
pub enum PersonKind {
Student,
Professor,
}
fn kind_from_string(kind: String) -> PersonKind {
match &*kind {
"Student" => Student,
"Professor" => Professor,
_ => panic!(),
}
}
fn kind_to_string(kind: PersonKind) -> &'static str {
match kind {
Student => "Student",
Professor => "Professor",
}
}
#[spacetimedb::table(accessor = book, public)]
pub struct Book {
isbn: String,
}
#[spacetimedb::reducer]
pub fn add_book(ctx: &ReducerContext, isbn: String) {
ctx.db.book().insert(Book { isbn });
}
#[spacetimedb::reducer]
pub fn print_books(ctx: &ReducerContext, prefix: String) {
for book in ctx.db.book().iter() {
log::info!("{}: {}", prefix, book.isbn);
}
}
"""
)
def test_add_table_auto_migration(self):
"""This tests uploading a module with a schema change that should not require clearing the database."""
logging.info("Initial publish complete")
# Start a subscription before publishing the module, to test that the subscription remains intact after re-publishing.
sub = self.subscribe("select * from person", n=4, confirmed=False)
# initial module code is already published by test framework
self.call("add_person", "Robert", "Student")
self.call("add_person", "Julie", "Student")
self.call("add_person", "Samantha", "Student")
self.call("print_persons", "BEFORE")
logs = self.logs(100)
self.assertIn("BEFORE: Samantha - Student", logs)
self.assertIn("BEFORE: Julie - Student", logs)
self.assertIn("BEFORE: Robert - Student", logs)
logging.info(
"Initial operations complete, updating module without clear",
)
self.write_module_code(self.MODULE_CODE_UPDATED)
self.publish_module(self.database_identity, clear=False)
logging.info("Updated")
self.call("add_person", "Husserl", "Student")
# If subscription, we should get 4 rows corresponding to 4 reducer calls (including before and after update)
sub = sub();
self.assertEqual(len(sub), 4)
self.logs(100)
self.call("add_person", "Husserl", "Professor")
self.call("add_book", "1234567890")
self.call("print_persons", "AFTER_PERSON")
self.call("print_books", "AFTER_BOOK")
logs = self.logs(100)
self.assertIn("AFTER_PERSON: Samantha - Student", logs)
self.assertIn("AFTER_PERSON: Julie - Student", logs)
self.assertIn("AFTER_PERSON: Robert - Student", logs)
self.assertIn("AFTER_PERSON: Husserl - Professor", logs)
self.assertIn("AFTER_BOOK: 1234567890", logs)
class RejectTableChanges(Smoketest):
MODULE_CODE = """
use spacetimedb::{log, ReducerContext, Table};
#[spacetimedb::table(accessor = person)]
pub struct Person {
name: String,
}
#[spacetimedb::reducer]
pub fn add_person(ctx: &ReducerContext, name: String) {
ctx.db.person().insert(Person { name });
}
#[spacetimedb::reducer]
pub fn print_persons(ctx: &ReducerContext, prefix: String) {
for person in ctx.db.person().iter() {
log::info!("{}: {}", prefix, person.name);
}
}
"""
MODULE_CODE_UPDATED = """
use spacetimedb::{log, ReducerContext, Table};
#[spacetimedb::table(accessor = person)]
pub struct Person {
name: String,
age: u128,
}
#[spacetimedb::reducer]
pub fn add_person(ctx: &ReducerContext, name: String) {
ctx.db.person().insert(Person { name, age: 70 });
}
#[spacetimedb::reducer]
pub fn print_persons(ctx: &ReducerContext, prefix: String) {
for person in ctx.db.person().iter() {
log::info!("{}: {}", prefix, person.name);
}
}
"""
def test_reject_schema_changes(self):
"""This tests that a module with invalid schema changes cannot be published without -c or a migration."""
logging.info("Initial publish complete, trying to do an invalid update.")
with self.assertRaises(Exception):
self.write_module_code(self.MODULE_CODE_UPDATED)
self.publish_module(self.database_identity, clear=False)
logging.info("Rejected as expected.")
class AddTableColumns(Smoketest):
MODULE_CODE = """
use spacetimedb::{log, ReducerContext, Table};
#[derive(Debug)]
#[spacetimedb::table(accessor = person)]
pub struct Person {
name: String,
}
#[spacetimedb::reducer]
pub fn add_person(ctx: &ReducerContext, name: String) {
ctx.db.person().insert(Person { name });
}
#[spacetimedb::reducer]
pub fn print_persons(ctx: &ReducerContext, prefix: String) {
for person in ctx.db.person().iter() {
log::info!("{}: {}", prefix, person.name);
}
}
"""
MODULE_UPDATED = """
use spacetimedb::{log, ReducerContext, Table};
#[derive(Debug)]
#[spacetimedb::table(accessor = person)]
pub struct Person {
// Add indexes to verify they are handled correctly during migration,
// issue #3441
#[index(btree)]
name: String,
#[default(0)]
age: u16,
#[default(19)]
mass: u16,
}
#[spacetimedb::reducer]
pub fn add_person(ctx: &ReducerContext, name: String) {
ctx.db.person().insert(Person { name, age: 70, mass: 180 });
}
#[spacetimedb::reducer]
pub fn print_persons(ctx: &ReducerContext, prefix: String) {
for person in ctx.db.person().iter() {
log::info!("{}: {:?}", prefix, person);
}
}
#[spacetimedb::reducer(client_disconnected)]
pub fn identity_disconnected(ctx: &ReducerContext) {
log::info!("FIRST_UPDATE: client disconnected");
}
"""
MODULE_UPDATED_AGAIN = """
use spacetimedb::{log, ReducerContext, Table};
#[derive(Debug)]
#[spacetimedb::table(accessor = person)]
pub struct Person {
name: String,
age: u16,
#[default(19)]
mass: u16,
#[default(160)]
height: u32,
}
#[spacetimedb::reducer]
pub fn add_person(ctx: &ReducerContext, name: String) {
ctx.db.person().insert(Person { name, age: 70, mass: 180, height: 72 });
}
#[spacetimedb::reducer]
pub fn print_persons(ctx: &ReducerContext, prefix: String) {
for person in ctx.db.person().iter() {
log::info!("{}: {:?}", prefix, person);
}
}
"""
def test_add_table_columns(self):
"""Verify schema upgrades that add columns with defaults (twice)."""
# Subscribe to person table changes multiple times to simulate active clients
NUM_SUBSCRIBERS = 20
subs = [None] * NUM_SUBSCRIBERS
for i in range(NUM_SUBSCRIBERS):
# We need unconfirmed reads for the updates to arrive properly.
# Otherwise, there's a race between module teardown in publish, vs subscribers
# getting the row deletion they expect.
subs[i]= self.subscribe("select * from person", n=5, confirmed=False)
# Insert under initial schema
self.call("add_person", "Robert")
# First upgrade: add age & mass columns
self.write_module_code(self.MODULE_UPDATED)
self.publish_module(self.database_identity, clear=False, break_clients=True)
self.call("print_persons", "FIRST_UPDATE")
logs1 = self.logs(100)
# Validate disconnect + schema migration logs
self.assertIn("Disconnecting all users", logs1)
self.assertIn(
'FIRST_UPDATE: Person { name: "Robert", age: 0, mass: 19 }',
logs1,
)
disconnect_count = logs1.count("FIRST_UPDATE: client disconnected")
# Insert new data under upgraded schema
self.call("add_person", "Robert2")
self.assertEqual(
disconnect_count,
# +1 is due to reducer call above
NUM_SUBSCRIBERS + 1,
msg=f"Unexpected disconnect counts: {disconnect_count}",
)
# Validate all subscribers were disconnected after first upgrade
# they should 2 updates: one for initial insertion and one for table drop during migration
for i in range(NUM_SUBSCRIBERS):
sub = subs[i]()
self.assertEqual(len(sub), 2, msg=f"Subscriber {i} received unexpected rows: {sub}")
# Second upgrade
self.write_module_code(self.MODULE_UPDATED_AGAIN)
self.publish_module(self.database_identity, clear=False, break_clients=True)
self.call("print_persons", "UPDATE_2")
logs2 = self.logs(100)
# Validate new schema with height
self.assertIn(
'UPDATE_2: Person { name: "Robert2", age: 70, mass: 180, height: 160 }',
logs2,
)
-151
View File
@@ -1,151 +0,0 @@
import string
from .. import Smoketest
class CallReducerProcedure(Smoketest):
MODULE_CODE = """
use spacetimedb::{log, ProcedureContext, ReducerContext, Table};
#[spacetimedb::table(accessor = person)]
pub struct Person {
name: String,
}
#[spacetimedb::reducer]
pub fn say_hello(_ctx: &ReducerContext) {
log::info!("Hello, World!");
}
#[spacetimedb::procedure]
pub fn return_person(_ctx: &mut ProcedureContext) -> Person {
return Person { name: "World".to_owned() };
}
"""
def test_call_reducer_procedure(self):
"""Check calling a reducer (not return) and procedure (return)"""
msg = self.call("say_hello")
self.assertEqual(msg, "")
msg = self.call("return_person")
self.assertEqual(msg.strip(), '["World"]')
def test_call_errors(self):
"""Check calling a non-existent reducer/procedure raises error"""
out = self.call("non_existent_reducer", check= False, full_output=True).stderr
identity = self.database_identity
self.assertIn(out.strip(), f"""
WARNING: This command is UNSTABLE and subject to breaking changes.
Error: No such reducer OR procedure `non_existent_reducer` for database `{identity}` resolving to identity `{identity}`.
Here are some existing reducers:
- say_hello
Here are some existing procedures:
- return_person""".strip())
out = self.call("non_existent_procedure", check= False, full_output=True).stderr
self.assertIn(out.strip(), f"""
WARNING: This command is UNSTABLE and subject to breaking changes.
Error: No such reducer OR procedure `non_existent_procedure` for database `{identity}` resolving to identity `{identity}`.
Here are some existing reducers:
- say_hello
Here are some existing procedures:
- return_person""".strip())
out = self.call("say_hell", check= False, full_output=True).stderr
self.assertIn(out.strip(), f"""
WARNING: This command is UNSTABLE and subject to breaking changes.
Error: No such reducer OR procedure `say_hell` for database `{identity}` resolving to identity `{identity}`.
A reducer with a similar name exists: `say_hello`""".strip())
out = self.call("return_perso", check= False, full_output=True).stderr
self.assertIn(out.strip(), f"""
WARNING: This command is UNSTABLE and subject to breaking changes.
Error: No such reducer OR procedure `return_perso` for database `{identity}` resolving to identity `{identity}`.
A procedure with a similar name exists: `return_person`""".strip())
class CallEmptyReducerProcedure(Smoketest):
MODULE_CODE = """
use spacetimedb::{log, ReducerContext, Table};
#[spacetimedb::table(accessor = person)]
pub struct Person {
name: String,
}
"""
def test_call_empty_errors(self):
"""Check calling into a database with no reducers/procedures raises error"""
out = self.call("non_existent", check= False, full_output=True).stderr
identity = self.database_identity
self.assertIn(out.strip(), f"""
WARNING: This command is UNSTABLE and subject to breaking changes.
Error: No such reducer OR procedure `non_existent` for database `{identity}` resolving to identity `{identity}`.
The database has no reducers.
The database has no procedures.""".strip())
module_template = string.Template("""
#[spacetimedb::reducer]
pub fn say_reducer_$NUM(_ctx: &ReducerContext) {
log::info!("Hello from reducer $NUM!");
}
#[spacetimedb::procedure]
pub fn say_procedure_$NUM(_ctx: &mut ProcedureContext) {
log::info!("Hello from procedure $NUM!");
}
""")
class CallManyReducerProcedure(Smoketest):
MODULE_CODE = f"""
use spacetimedb::{{log, ProcedureContext, ReducerContext}};
{"".join(module_template.substitute(NUM=i) for i in range(11))}
"""
def test_call_many_errors(self):
"""Check calling into a database with many reducers/procedures raises error with listing"""
out = self.call("non_existent", check= False, full_output=True).stderr
identity = self.database_identity
self.assertIn(out.strip(), f"""
WARNING: This command is UNSTABLE and subject to breaking changes.
Error: No such reducer OR procedure `non_existent` for database `{identity}` resolving to identity `{identity}`.
Here are some existing reducers:
- say_reducer_0
- say_reducer_1
- say_reducer_2
- say_reducer_3
- say_reducer_4
- say_reducer_5
- say_reducer_6
- say_reducer_7
- say_reducer_8
- say_reducer_9
... (1 reducer not shown)
Here are some existing procedures:
- say_procedure_0
- say_procedure_1
- say_procedure_2
- say_procedure_3
- say_procedure_4
- say_procedure_5
- say_procedure_6
- say_procedure_7
- say_procedure_8
- say_procedure_9
... (1 procedure not shown)
""".strip())
-93
View File
@@ -1,93 +0,0 @@
from .. import Smoketest, random_string
import time
class ClearDatabase(Smoketest):
AUTOPUBLISH = False
MODULE_CODE = """
use spacetimedb::{ReducerContext, Table, duration};
#[spacetimedb::table(accessor = counter, public)]
pub struct Counter {
#[primary_key]
id: u64,
val: u64
}
#[spacetimedb::table(accessor = scheduled_counter, public, scheduled(inc, at = sched_at))]
pub struct ScheduledCounter {
#[primary_key]
#[auto_inc]
scheduled_id: u64,
sched_at: spacetimedb::ScheduleAt,
}
#[spacetimedb::reducer]
pub fn inc(ctx: &ReducerContext, arg: ScheduledCounter) {
if let Some(mut counter) = ctx.db.counter().id().find(arg.scheduled_id) {
counter.val += 1;
ctx.db.counter().id().update(counter);
} else {
ctx.db.counter().insert(Counter {
id: arg.scheduled_id,
val: 1,
});
}
}
#[spacetimedb::reducer(init)]
pub fn init(ctx: &ReducerContext) {
ctx.db.scheduled_counter().insert(ScheduledCounter {
scheduled_id: 0,
sched_at: duration!(100ms).into(),
});
}
"""
def test_publish_clear_database(self):
"""
Test that publishing with the clear flag stops the old module.
This relies on private control database internals.
"""
name = random_string()
# Initial publish
replicas_1 = self.publish(name, clear = False)
self.assertTrue(len(replicas_1) >= 1)
# Publish with clear = True, destroying `replicas_1`
replicas_2 = self.publish(name, clear = True)
self.assertTrue(len(replicas_2) >= 1)
self.assertNotEqual(replicas_1, replicas_2)
# Delete the replicas created in the second publish
self.spacetime("delete", name)
# State updates don't happen instantly
time.sleep(0.25)
# Check that all replicas have state `Deleted`
replicas = replicas_1 + replicas_2
state_filter = f'replica_id = {" OR replica_id = ".join(replicas)}'
states = self.query_control(f"select lifecycle from replica_state where {state_filter}")
self.assertEqual(len(states), len(replicas))
self.assertTrue(all([x == "(Deleted = ())" for x in states]))
def publish(self, name, clear):
self.publish_module(name, clear = clear)
replicas = self.query_control(f"""
select replica.id from replica
join database on database.id = replica.database_id
where database.database_identity = '0x{self.resolved_identity}'
""")
return replicas
def query_control(self, sql):
out = self.spacetime("sql", "spacetime-control", sql)
out = [line.strip() for line in out.splitlines()]
out = out[2:] # Remove header
return out
@@ -1,70 +0,0 @@
from .. import Smoketest
MODULE_HEADER = """
use spacetimedb::{ReducerContext, Table};
#[spacetimedb::table(accessor = all_u8s, public)]
pub struct AllU8s {
number: u8,
}
#[spacetimedb::reducer(init)]
pub fn init(ctx: &ReducerContext) {
// Here's a bunch of data that no one will be able to subscribe to.
for i in u8::MIN..=u8::MAX {
ctx.db.all_u8s().insert(AllU8s { number: i });
}
}
"""
class ClientConnectedErrorRejectsConnection(Smoketest):
MODULE_CODE = MODULE_HEADER + """
#[spacetimedb::reducer(client_connected)]
pub fn identity_connected(ctx: &ReducerContext) -> Result<(), String> {
Err("Rejecting connection from client".to_string())
}
#[spacetimedb::reducer(client_disconnected)]
pub fn identity_disconnected(_ctx: &ReducerContext) {
panic!("This should never be called, since we reject all connections!")
}
"""
def test_client_connected_error_rejects_connection(self):
with self.assertRaises(Exception):
self.subscribe("select * from all_u8s", n = 0)()
logs = self.logs(100)
self.assertIn('Rejecting connection from client', logs)
self.assertNotIn('This should never be called, since we reject all connections!', logs)
class ClientDisconnectedErrorStillDeletesStClient(Smoketest):
MODULE_CODE = MODULE_HEADER + """
#[spacetimedb::reducer(client_connected)]
pub fn identity_connected(_ctx: &ReducerContext) -> Result<(), String> {
Ok(())
}
#[spacetimedb::reducer(client_disconnected)]
pub fn identity_disconnected(_ctx: &ReducerContext) {
panic!("This should be called, but the `st_client` row should still be deleted")
}
"""
def test_client_disconnected_error_still_deletes_st_client(self):
self.subscribe("select * from all_u8s", n = 0)()
logs = self.logs(100)
self.assertIn('This should be called, but the `st_client` row should still be deleted', logs)
sql_out = self.spacetime("sql", self.database_identity, "select * from st_client")
# The SQL query itself now creates a temporary connection, so we may
# see exactly one row (the SQL connection's own). The websocket's row
# should be gone. Count non-header, non-separator lines with content.
lines = sql_out.strip().split('\n')
# Data rows are those that are not the header and not the separator line
data_rows = [l for l in lines if '|' in l and '-+-' not in l and 'identity' not in l.lower()]
self.assertLessEqual(len(data_rows), 1,
f"Expected at most 1 st_client row (the SQL connection itself), got: {sql_out}")
-54
View File
@@ -1,54 +0,0 @@
from .. import Smoketest, parse_sql_result
#
# TODO: We only test that we can pass a --confirmed flag and that things
# appear to works as if we hadn't. Without controlling the server, we can't
# test that there is any difference in behavior.
#
class ConfirmedReads(Smoketest):
def test_confirmed_reads_receive_updates(self):
"""Tests that subscribing with confirmed=true receives updates"""
sub = self.subscribe("select * from person", n = 2, confirmed = True)
self.call("add", "Horst")
self.spacetime(
"sql",
self.database_identity,
"insert into person (name) values ('Egon')")
events = sub()
self.assertEqual([
{
'person': {
'deletes': [],
'inserts': [{'name': 'Horst'}]
}
},
{
'person': {
'deletes': [],
'inserts': [{'name': 'Egon'}]
}
}
], events)
class ConfirmedReadsSql(Smoketest):
def test_sql_with_confirmed_reads_receives_result(self):
"""Tests that an SQL operations with confirmed=true returns a result"""
self.spacetime(
"sql",
"--confirmed",
"true",
self.database_identity,
"insert into person (name) values ('Horst')")
res = self.spacetime(
"sql",
"--confirmed",
"true",
self.database_identity,
"select * from person")
res = parse_sql_result(str(res))
self.assertEqual([{'name': '"Horst"'}], res)
@@ -1,33 +0,0 @@
from .. import Smoketest
class ConnDisconnFromCli(Smoketest):
MODULE_CODE = """
use spacetimedb::{log, ReducerContext};
#[spacetimedb::reducer(client_connected)]
pub fn connected(_ctx: &ReducerContext) {
log::info!("_connect called");
}
#[spacetimedb::reducer(client_disconnected)]
pub fn disconnected(_ctx: &ReducerContext) {
log::info!("disconnect called");
}
#[spacetimedb::reducer]
pub fn say_hello(_ctx: &ReducerContext) {
log::info!("Hello, World!");
}
"""
def test_conn_disconn(self):
"""
Ensure that the connect and disconnect functions are called when invoking a reducer from the CLI
"""
self.call("say_hello")
logs = self.logs(10)
self.assertIn('_connect called', logs)
self.assertIn('disconnect called', logs)
self.assertIn('Hello, World!', logs)
-18
View File
@@ -1,18 +0,0 @@
from .. import spacetime
import unittest
import tempfile
class CreateProject(unittest.TestCase):
def test_create_project(self):
"""
Ensure that the CLI is able to create a local project. This test does not depend on a running spacetimedb instance.
"""
with tempfile.TemporaryDirectory() as tmpdir:
with self.assertRaises(Exception):
spacetime("init", "--non-interactive", "test-project")
with self.assertRaises(Exception):
spacetime("init", "--non-interactive", "--project-path", tmpdir, "test-project")
spacetime("init", "--non-interactive", "--lang=rust", "--project-path", tmpdir, "test-project")
with self.assertRaises(Exception):
spacetime("init", "--non-interactive", "--lang=rust", "--project-path", tmpdir, "test-project")
-81
View File
@@ -1,81 +0,0 @@
from .. import run_cmd, STDB_DIR, requires_dotnet, spacetime
import unittest
import tempfile
from pathlib import Path
import shutil
import subprocess
import xml.etree.ElementTree as xml
@requires_dotnet
class CreateProject(unittest.TestCase):
def test_build_csharp_module(self):
"""
Ensure that the CLI is able to create and compile a csharp project. This test does not depend on a running spacetimedb instance. Skips if dotnet 8.0 is not available
"""
bindings = Path(STDB_DIR) / "crates" / "bindings-csharp"
try:
run_cmd("dotnet", "nuget", "locals", "all", "--clear", cwd=bindings, capture_stderr=True)
run_cmd("dotnet", "workload", "install", "wasi-experimental", "--skip-manifest-update", cwd=STDB_DIR / "modules")
run_cmd("dotnet", "pack", cwd=bindings, capture_stderr=True)
with tempfile.TemporaryDirectory() as tmpdir:
spacetime(
"init",
"--non-interactive",
"--lang=csharp",
"--project-path",
tmpdir,
"csharp-project",
)
server_path = Path(tmpdir) / "spacetimedb"
packed_projects = ["BSATN.Runtime", "Runtime"]
config = xml.Element("configuration")
sources = xml.SubElement(config, "packageSources")
mappings = xml.SubElement(config, "packageSourceMapping")
def add_mapping(source, pattern):
mapping = xml.SubElement(mappings, "packageSource", key=source)
xml.SubElement(mapping, "package", pattern=pattern)
for project in packed_projects:
# Add local build directories as NuGet repositories.
path = bindings / project / "bin" / "Release"
project = f"SpacetimeDB.{project}"
xml.SubElement(sources, "add", key=project, value=str(path))
# Add strict package source mappings to ensure that
# SpacetimeDB.* packages are used from those directories
# and never from nuget.org.
#
# This prevents bugs where we silently used an outdated
# version which led to tests passing when they shouldn't.
add_mapping(project, project)
# Add fallback for other packages.
add_mapping("nuget.org", "*")
xml.indent(config)
config = xml.tostring(config, encoding="unicode", xml_declaration=True)
print("Writing `nuget.config` contents:")
print(config)
config_path = server_path / "nuget.config"
with open(config_path, "w") as f:
f.write(config)
run_cmd("dotnet", "publish", cwd=server_path, capture_stderr=True)
except subprocess.CalledProcessError as e:
print(e)
print("output:")
print(e.output)
raise e
-12
View File
@@ -1,12 +0,0 @@
import unittest
import tempfile
import subprocess
from .. import Smoketest
class ClippyDefaultModule(Smoketest):
AUTOPUBLISH = False
def test_default_module_clippy_check(self):
"""Ensure that the default rust module has no clippy errors or warnings"""
subprocess.check_call(["cargo", "clippy", "--", "-Dwarnings"], cwd=self.project_path)
-63
View File
@@ -1,63 +0,0 @@
from .. import Smoketest, random_string
import time
class DeleteDatabase(Smoketest):
AUTOPUBLISH = False
MODULE_CODE = """
use spacetimedb::{ReducerContext, Table, duration};
#[spacetimedb::table(accessor = counter, public)]
pub struct Counter {
#[primary_key]
id: u64,
val: u64
}
#[spacetimedb::table(accessor = scheduled_counter, public, scheduled(inc, at = sched_at))]
pub struct ScheduledCounter {
#[primary_key]
#[auto_inc]
scheduled_id: u64,
sched_at: spacetimedb::ScheduleAt,
}
#[spacetimedb::reducer]
pub fn inc(ctx: &ReducerContext, arg: ScheduledCounter) {
if let Some(mut counter) = ctx.db.counter().id().find(arg.scheduled_id) {
counter.val += 1;
ctx.db.counter().id().update(counter);
} else {
ctx.db.counter().insert(Counter {
id: arg.scheduled_id,
val: 1,
});
}
}
#[spacetimedb::reducer(init)]
pub fn init(ctx: &ReducerContext) {
ctx.db.scheduled_counter().insert(ScheduledCounter {
scheduled_id: 0,
sched_at: duration!(100ms).into(),
});
}
"""
def test_delete_database(self):
"""
Test that deleting a database stops the module.
The module is considered stopped if its scheduled reducer stops
producing update events.
"""
name = random_string()
self.publish_module(name, clear = False)
sub = self.subscribe("select * from counter", n = 1000)
time.sleep(2)
self.spacetime("delete", "--yes", name)
updates = sub()
# At a rate of 100ms, we shouldn't have more than 20 updates in 2secs.
# But let's say 50, in case the delete gets delayed for some reason.
assert len(updates) <= 50
-31
View File
@@ -1,31 +0,0 @@
from .. import Smoketest
class ModuleDescription(Smoketest):
MODULE_CODE = """
use spacetimedb::{log, ReducerContext, Table};
#[spacetimedb::table(accessor = person)]
pub struct Person {
name: String,
}
#[spacetimedb::reducer]
pub fn add(ctx: &ReducerContext, name: String) {
ctx.db.person().insert(Person { name });
}
#[spacetimedb::reducer]
pub fn say_hello(ctx: &ReducerContext) {
for person in ctx.db.person().iter() {
log::info!("Hello, {}!", person.name);
}
log::info!("Hello, World!");
}
"""
def test_describe(self):
"""Check describing a module"""
self.spacetime("describe", "--json", self.database_identity)
self.spacetime("describe", "--json", self.database_identity, "reducer", "say_hello")
self.spacetime("describe", "--json", self.database_identity, "table", "person")
-44
View File
@@ -1,44 +0,0 @@
from .. import Smoketest
class WasmBindgen(Smoketest):
AUTOPUBLISH = False
MODULE_CODE = """
use spacetimedb::{log, ReducerContext};
#[spacetimedb::reducer]
pub fn test(_ctx: &ReducerContext) {
log::info!("Hello! {}", now());
}
#[wasm_bindgen::prelude::wasm_bindgen]
extern "C" {
fn now() -> i32;
}
"""
EXTRA_DEPS = 'wasm-bindgen = "0.2"'
def test_detect_wasm_bindgen(self):
"""Ensure that spacetime build properly catches wasm_bindgen imports"""
output = self.spacetime("build", "--module-path", self.project_path, full_output=True, check=False)
self.assertTrue(output.returncode)
self.assertIn("wasm-bindgen detected", output.stderr)
class Getrandom(Smoketest):
AUTOPUBLISH = False
MODULE_CODE = """
use spacetimedb::{log, ReducerContext};
#[spacetimedb::reducer]
pub fn test(_ctx: &ReducerContext) {
log::info!("Hello! {}", rand::random::<u8>());
}
"""
EXTRA_DEPS = 'rand = "0.8"'
def test_detect_getrandom(self):
"""Ensure that spacetime build properly catches getrandom"""
output = self.spacetime("build", "--module-path", self.project_path, full_output=True, check=False)
self.assertTrue(output.returncode)
self.assertIn("getrandom usage detected", output.stderr)
-28
View File
@@ -1,28 +0,0 @@
from .. import Smoketest
class Dml(Smoketest):
MODULE_CODE = """
use spacetimedb::{ReducerContext, Table};
#[spacetimedb::table(accessor = t, public)]
pub struct T {
name: String,
}
"""
def test_subscribe(self):
"""Test that we receive subscription updates from DML"""
# Subscribe to `t`
sub = self.subscribe("SELECT * FROM t", n=2)
self.spacetime("sql", self.database_identity, "INSERT INTO t (name) VALUES ('Alice')")
self.spacetime("sql", self.database_identity, "INSERT INTO t (name) VALUES ('Bob')")
self.assertEqual(
sub(),
[
{"t": {"deletes": [], "inserts": [{"name": "Alice"}]}},
{"t": {"deletes": [], "inserts": [{"name": "Bob"}]}},
],
)
-84
View File
@@ -1,84 +0,0 @@
from .. import Smoketest, random_string
import json
class Domains(Smoketest):
    """Tests for database naming: rename, '/' handling in names, and name replacement."""

    # Every test publishes under its own fresh random name.
    AUTOPUBLISH = False

    def test_set_name(self):
        """Tests the functionality of the set-name command"""
        orig_name = random_string()
        self.publish_module(orig_name)
        rand_name = random_string()
        # This should throw an exception before there's a db with this name
        with self.assertRaises(Exception):
            self.spacetime("logs", rand_name)
        self.spacetime("rename", "--to", rand_name, self.database_identity)
        # Now we're essentially just testing that it *doesn't* throw an exception
        self.spacetime("logs", rand_name)
        # This should throw an exception because the original name shouldn't exist anymore
        with self.assertRaises(Exception):
            self.spacetime("logs", orig_name)

    def test_subdomain_behavior(self):
        """Test how we treat the / character in published names"""
        root_name = random_string()
        self.publish_module(root_name)
        # TODO: This is valid in editions with the teams feature, but
        # smoketests don't know the target's edition.
        # self.publish_module(f"{root_name}/test")
        # Empty path segments and trailing slashes must be rejected.
        with self.assertRaises(Exception):
            self.publish_module(f"{root_name}//test")
        with self.assertRaises(Exception):
            self.publish_module(f"{root_name}/test/")

    def test_set_to_existing_name(self):
        """Test that we can't rename to a name already in use"""
        self.publish_module()
        id_to_rename = self.database_identity
        rename_to = random_string()
        # Publish a second database that claims the target name first.
        self.publish_module(rename_to)
        # This should throw an exception because there's a db with this name
        with self.assertRaises(Exception):
            self.spacetime("rename", "--to", rename_to, id_to_rename)

    def test_replace_names(self):
        """Test that we can rename to a list of names"""
        orig_name = random_string()
        alt_name1 = random_string()
        alt_name2 = random_string()
        self.publish_module(orig_name)
        # PUT /names replaces the database's full list of names.
        self.api_call(
            "PUT",
            f'/v1/database/{orig_name}/names',
            json.dumps([alt_name1, alt_name2]),
            {"Content-type": "application/json"}
        )
        # Use logs to check that name resolution works
        self.spacetime("logs", alt_name1)
        self.spacetime("logs", alt_name2)
        with self.assertRaises(Exception):
            self.spacetime("logs", orig_name)
        # Restore orig name so the database gets deleted on clean up
        self.api_call(
            "PUT",
            f'/v1/database/{alt_name1}/names',
            json.dumps([orig_name]),
            {"Content-type": "application/json"}
        )
-75
View File
@@ -1,75 +0,0 @@
from .. import Smoketest, random_string
import subprocess
class FailInitialPublish(Smoketest):
    """Checks that a failed *initial* publish leaves no broken entry in the control DB."""

    # Publishing is driven explicitly by the test, including the failing attempt.
    AUTOPUBLISH = False

    # Broken: the visibility filter references `Person`, but the accessor is `person`.
    MODULE_CODE_BROKEN = """
use spacetimedb::{client_visibility_filter, Filter};
#[spacetimedb::table(accessor = person, public)]
pub struct Person {
name: String,
}
#[client_visibility_filter]
// Bug: `Person` is the wrong table name, should be `person`.
const HIDE_PEOPLE_EXCEPT_ME: Filter = Filter::Sql("SELECT * FROM Person WHERE name = 'me'");
"""
    # Same module with the filter querying the correct table name.
    MODULE_CODE_FIXED = """
use spacetimedb::{client_visibility_filter, Filter};
#[spacetimedb::table(accessor = person, public)]
pub struct Person {
name: String,
}
#[client_visibility_filter]
const HIDE_PEOPLE_EXCEPT_ME: Filter = Filter::Sql("SELECT * FROM person WHERE name = 'me'");
"""
    # Line expected in `describe --json` output once the fixed module is live.
    FIXED_QUERY = '"sql": "SELECT * FROM person WHERE name = \'me\'"'

    def test_fail_initial_publish(self):
        """This tests that publishing an invalid module does not leave a broken entry in the control DB."""
        name = random_string()
        self.write_module_code(self.MODULE_CODE_BROKEN)
        with self.assertRaises(Exception):
            self.publish_module(name)
        describe_output = self.spacetime("describe", "--json", name, full_output = True, check = False)
        with self.assertRaises(subprocess.CalledProcessError):
            describe_output.check_returncode()
        self.assertIn("Error: No such database.", describe_output.stderr)
        # We can publish a fixed module under the same database name.
        # This used to be broken;
        # the failed initial publish would leave the control database in a bad state.
        self.write_module_code(self.MODULE_CODE_FIXED)
        self.publish_module(name, clear = False)
        describe_output = self.spacetime("describe", "--json", name)
        self.assertIn(
            self.FIXED_QUERY,
            [line.strip() for line in describe_output.splitlines()],
        )
        # Publishing the broken code again fails, but the database still exists afterwards,
        # with the previous version of the module code.
        self.write_module_code(self.MODULE_CODE_BROKEN)
        with self.assertRaises(Exception):
            self.publish_module(name, clear = False)
        describe_output = self.spacetime("describe", "--json", name)
        self.assertIn(
            self.FIXED_QUERY,
            [line.strip() for line in describe_output.splitlines()],
        )
-311
View File
@@ -1,311 +0,0 @@
from .. import Smoketest
class Filtering(Smoketest):
    """Exercises unique, non-unique, and btree-indexed lookups (and deletes)
    through reducers, checking results via log lines."""

    MODULE_CODE = """
use spacetimedb::{log, Identity, ReducerContext, Table};
#[spacetimedb::table(accessor = person)]
pub struct Person {
#[unique]
id: i32,
name: String,
#[unique]
nick: String,
}
#[spacetimedb::reducer]
pub fn insert_person(ctx: &ReducerContext, id: i32, name: String, nick: String) {
ctx.db.person().insert(Person { id, name, nick} );
}
#[spacetimedb::reducer]
pub fn insert_person_twice(ctx: &ReducerContext, id: i32, name: String, nick: String) {
// We'd like to avoid an error due to a set-semantic error.
let name2 = format!("{name}2");
ctx.db.person().insert(Person { id, name, nick: nick.clone()} );
match ctx.db.person().try_insert(Person { id, name: name2, nick: nick.clone()}) {
Ok(_) => {},
Err(_) => {
log::info!("UNIQUE CONSTRAINT VIOLATION ERROR: id = {}, nick = {}", id, nick)
}
}
}
#[spacetimedb::reducer]
pub fn delete_person(ctx: &ReducerContext, id: i32) {
ctx.db.person().id().delete(&id);
}
#[spacetimedb::reducer]
pub fn find_person(ctx: &ReducerContext, id: i32) {
match ctx.db.person().id().find(&id) {
Some(person) => log::info!("UNIQUE FOUND: id {}: {}", id, person.name),
None => log::info!("UNIQUE NOT FOUND: id {}", id),
}
}
#[spacetimedb::reducer]
pub fn find_person_read_only(ctx: &ReducerContext, id: i32) {
let ctx = ctx.as_read_only();
match ctx.db.person().id().find(&id) {
Some(person) => log::info!("UNIQUE FOUND: id {}: {}", id, person.name),
None => log::info!("UNIQUE NOT FOUND: id {}", id),
}
}
#[spacetimedb::reducer]
pub fn find_person_by_name(ctx: &ReducerContext, name: String) {
for person in ctx.db.person().iter().filter(|p| p.name == name) {
log::info!("UNIQUE FOUND: id {}: {} aka {}", person.id, person.name, person.nick);
}
}
#[spacetimedb::reducer]
pub fn find_person_by_nick(ctx: &ReducerContext, nick: String) {
match ctx.db.person().nick().find(&nick) {
Some(person) => log::info!("UNIQUE FOUND: id {}: {}", person.id, person.nick),
None => log::info!("UNIQUE NOT FOUND: nick {}", nick),
}
}
#[spacetimedb::reducer]
pub fn find_person_by_nick_read_only(ctx: &ReducerContext, nick: String) {
let ctx = ctx.as_read_only();
match ctx.db.person().nick().find(&nick) {
Some(person) => log::info!("UNIQUE FOUND: id {}: {}", person.id, person.nick),
None => log::info!("UNIQUE NOT FOUND: nick {}", nick),
}
}
#[spacetimedb::table(accessor = nonunique_person)]
pub struct NonuniquePerson {
#[index(btree)]
id: i32,
name: String,
is_human: bool,
}
#[spacetimedb::reducer]
pub fn insert_nonunique_person(ctx: &ReducerContext, id: i32, name: String, is_human: bool) {
ctx.db.nonunique_person().insert(NonuniquePerson { id, name, is_human } );
}
#[spacetimedb::reducer]
pub fn find_nonunique_person(ctx: &ReducerContext, id: i32) {
for person in ctx.db.nonunique_person().id().filter(&id) {
log::info!("NONUNIQUE FOUND: id {}: {}", id, person.name)
}
}
#[spacetimedb::reducer]
pub fn find_nonunique_person_read_only(ctx: &ReducerContext, id: i32) {
let ctx = ctx.as_read_only();
for person in ctx.db.nonunique_person().id().filter(&id) {
log::info!("NONUNIQUE FOUND: id {}: {}", id, person.name)
}
}
#[spacetimedb::reducer]
pub fn find_nonunique_humans(ctx: &ReducerContext) {
for person in ctx.db.nonunique_person().iter().filter(|p| p.is_human) {
log::info!("HUMAN FOUND: id {}: {}", person.id, person.name);
}
}
#[spacetimedb::reducer]
pub fn find_nonunique_non_humans(ctx: &ReducerContext) {
for person in ctx.db.nonunique_person().iter().filter(|p| !p.is_human) {
log::info!("NON-HUMAN FOUND: id {}: {}", person.id, person.name);
}
}
// Ensure that [Identity] is filterable and a legal unique column.
#[spacetimedb::table(accessor = identified_person)]
struct IdentifiedPerson {
#[unique]
identity: Identity,
name: String,
}
fn identify(id_number: u64) -> Identity {
let mut bytes = [0u8; 32];
bytes[..8].clone_from_slice(&id_number.to_le_bytes());
Identity::from_byte_array(bytes)
}
#[spacetimedb::reducer]
fn insert_identified_person(ctx: &ReducerContext, id_number: u64, name: String) {
let identity = identify(id_number);
ctx.db.identified_person().insert(IdentifiedPerson { identity, name });
}
#[spacetimedb::reducer]
fn find_identified_person(ctx: &ReducerContext, id_number: u64) {
let identity = identify(id_number);
match ctx.db.identified_person().identity().find(&identity) {
Some(person) => log::info!("IDENTIFIED FOUND: {}", person.name),
None => log::info!("IDENTIFIED NOT FOUND"),
}
}
// Ensure that indices on non-unique columns behave as we expect.
#[spacetimedb::table(accessor = indexed_person)]
struct IndexedPerson {
#[unique]
id: i32,
given_name: String,
#[index(btree)]
surname: String,
}
#[spacetimedb::reducer]
fn insert_indexed_person(ctx: &ReducerContext, id: i32, given_name: String, surname: String) {
ctx.db.indexed_person().insert(IndexedPerson { id, given_name, surname });
}
#[spacetimedb::reducer]
fn delete_indexed_person(ctx: &ReducerContext, id: i32) {
ctx.db.indexed_person().id().delete(&id);
}
#[spacetimedb::reducer]
fn find_indexed_people(ctx: &ReducerContext, surname: String) {
for person in ctx.db.indexed_person().surname().filter(&surname) {
log::info!("INDEXED FOUND: id {}: {}, {}", person.id, person.surname, person.given_name);
}
}
#[spacetimedb::reducer]
fn find_indexed_people_read_only(ctx: &ReducerContext, surname: String) {
let ctx = ctx.as_read_only();
for person in ctx.db.indexed_person().surname().filter(&surname) {
log::info!("INDEXED FOUND: id {}: {}, {}", person.id, person.surname, person.given_name);
}
}
"""

    # TODO: split this into multiple test functions
    def test_filtering(self):
        """Test filtering reducers"""
        # Each `call` below triggers one reducer; assertions read the most
        # recent log lines (`self.logs(n)` fetches the last n lines).
        self.call("insert_person", 23, "Alice", "al")
        self.call("insert_person", 42, "Bob", "bo")
        self.call("insert_person", 64, "Bob", "b2")
        # Find a person who is there.
        self.call("find_person", 23)
        self.assertIn("UNIQUE FOUND: id 23: Alice", self.logs(2))
        # Find persons with the same name.
        self.call("find_person_by_name", "Bob")
        logs = self.logs(4)
        self.assertIn("UNIQUE FOUND: id 42: Bob aka bo", logs)
        self.assertIn("UNIQUE FOUND: id 64: Bob aka b2", logs)
        # Fail to find a person who is not there.
        self.call("find_person", 43)
        self.assertIn("UNIQUE NOT FOUND: id 43", self.logs(2))
        self.call("find_person_read_only", 43)
        self.assertIn("UNIQUE NOT FOUND: id 43", self.logs(2))
        # Find a person by nickname.
        self.call("find_person_by_nick", "al")
        self.assertIn("UNIQUE FOUND: id 23: al", self.logs(2))
        self.call("find_person_by_nick_read_only", "al")
        self.assertIn("UNIQUE FOUND: id 23: al", self.logs(2))
        # Remove a person, and then fail to find them.
        self.call("delete_person", 23)
        self.call("find_person", 23)
        self.assertIn("UNIQUE NOT FOUND: id 23", self.logs(2))
        self.call("find_person_read_only", 23)
        self.assertIn("UNIQUE NOT FOUND: id 23", self.logs(2))
        # Also fail by nickname
        self.call("find_person_by_nick", "al")
        self.assertIn("UNIQUE NOT FOUND: nick al", self.logs(2))
        self.call("find_person_by_nick_read_only", "al")
        self.assertIn("UNIQUE NOT FOUND: nick al", self.logs(2))
        # Add some nonunique people.
        self.call("insert_nonunique_person", 23, "Alice", True)
        self.call("insert_nonunique_person", 42, "Bob", True)
        # Find a nonunique person who is there.
        self.call("find_nonunique_person", 23)
        self.assertIn('NONUNIQUE FOUND: id 23: Alice', self.logs(2))
        self.call("find_nonunique_person_read_only", 23)
        self.assertIn('NONUNIQUE FOUND: id 23: Alice', self.logs(2))
        # Fail to find a nonunique person who is not there.
        self.call("find_nonunique_person", 43)
        self.assertNotIn("NONUNIQUE NOT FOUND: id 43", self.logs(2))
        self.call("find_nonunique_person_read_only", 43)
        self.assertNotIn("NONUNIQUE NOT FOUND: id 43", self.logs(2))
        # Insert a non-human, then find humans, then find non-humans
        self.call("insert_nonunique_person", 64, "Jibbitty", False)
        self.call("find_nonunique_humans")
        self.assertIn('HUMAN FOUND: id 23: Alice', self.logs(2))
        self.assertIn('HUMAN FOUND: id 42: Bob', self.logs(2))
        self.call("find_nonunique_non_humans")
        self.assertIn('NON-HUMAN FOUND: id 64: Jibbitty', self.logs(2))
        # Add another person with the same id, and find them both.
        self.call("insert_nonunique_person", 23, "Claire", True)
        self.call("find_nonunique_person", 23)
        self.assertIn('NONUNIQUE FOUND: id 23: Alice', self.logs(2))
        self.assertIn('NONUNIQUE FOUND: id 23: Claire', self.logs(2))
        self.call("find_nonunique_person_read_only", 23)
        self.assertIn('NONUNIQUE FOUND: id 23: Alice', self.logs(2))
        self.assertIn('NONUNIQUE FOUND: id 23: Claire', self.logs(2))
        # Check for issues with things present in index but not DB
        self.call("insert_person", 101, "Fee", "fee")
        self.call("insert_person", 102, "Fi", "fi")
        self.call("insert_person", 103, "Fo", "fo")
        self.call("insert_person", 104, "Fum", "fum")
        self.call("delete_person", 103)
        self.call("find_person", 104)
        self.assertIn('UNIQUE FOUND: id 104: Fum', self.logs(2))
        self.call("find_person_read_only", 104)
        self.assertIn('UNIQUE FOUND: id 104: Fum', self.logs(2))
        # As above, but for non-unique indices: check for consistency between index and DB
        self.call("insert_indexed_person", 7, "James", "Bond")
        self.call("insert_indexed_person", 79, "Gold", "Bond")
        self.call("insert_indexed_person", 1, "Hydrogen", "Bond")
        self.call("insert_indexed_person", 100, "Whiskey", "Bond")
        self.call("delete_indexed_person", 100)
        self.call("find_indexed_people", "Bond")
        logs = self.logs(10)
        self.assertIn('INDEXED FOUND: id 7: Bond, James', logs)
        self.assertIn('INDEXED FOUND: id 79: Bond, Gold', logs)
        self.assertIn('INDEXED FOUND: id 1: Bond, Hydrogen', logs)
        self.assertNotIn('INDEXED FOUND: id 100: Bond, Whiskey', logs)
        self.call("find_indexed_people_read_only", "Bond")
        logs = self.logs(10)
        self.assertIn('INDEXED FOUND: id 7: Bond, James', logs)
        self.assertIn('INDEXED FOUND: id 79: Bond, Gold', logs)
        self.assertIn('INDEXED FOUND: id 1: Bond, Hydrogen', logs)
        self.assertNotIn('INDEXED FOUND: id 100: Bond, Whiskey', logs)
        # Non-unique version; does not work yet, see db_delete codegen in SpacetimeDB\crates\bindings-macro\src\lib.rs
        # self.call("insert_nonunique_person", 101, "Fee")
        # self.call("insert_nonunique_person", 102, "Fi")
        # self.call("insert_nonunique_person", 103, "Fo")
        # self.call("insert_nonunique_person", 104, "Fum")
        # self.call("find_nonunique_person", 104)
        # self.assertIn('NONUNIQUE FOUND: id 104: Fum', self.logs(2))
        # Filter by Identity
        self.call("insert_identified_person", 23, "Alice")
        self.call("find_identified_person", 23)
        self.assertIn('IDENTIFIED FOUND: Alice', self.logs(2))
        # Inserting into a table with unique constraints fails
        # when the second row has the same value in the constrained columns as the first row.
        # In this case, the table has `#[unique] id` and `#[unique] nick` but not `#[unique] name`.
        self.call("insert_person_twice", 23, "Alice", "al")
        self.assertIn('UNIQUE CONSTRAINT VIOLATION ERROR: id = 23, nick = al', self.logs(2))
-53
View File
@@ -1,53 +0,0 @@
from .. import Smoketest
class ModuleNestedOp(Smoketest):
    """Publishes a module whose reducers read one table while inserting into another."""

    MODULE_CODE = """
use spacetimedb::{log, ReducerContext, Table};
#[spacetimedb::table(accessor = account)]
pub struct Account {
name: String,
#[unique]
id: i32,
}
#[spacetimedb::table(accessor = friends)]
pub struct Friends {
friend_1: i32,
friend_2: i32,
}
#[spacetimedb::reducer]
pub fn create_account(ctx: &ReducerContext, account_id: i32, name: String) {
ctx.db.account().insert(Account { id: account_id, name } );
}
#[spacetimedb::reducer]
pub fn add_friend(ctx: &ReducerContext, my_id: i32, their_id: i32) {
// Make sure our friend exists
for account in ctx.db.account().iter() {
if account.id == their_id {
ctx.db.friends().insert(Friends { friend_1: my_id, friend_2: their_id });
return;
}
}
}
#[spacetimedb::reducer]
pub fn say_friends(ctx: &ReducerContext) {
for friendship in ctx.db.friends().iter() {
let friend1 = ctx.db.account().id().find(&friendship.friend_1).unwrap();
let friend2 = ctx.db.account().id().find(&friendship.friend_2).unwrap();
log::info!("{} is friends with {}", friend1.name, friend2.name);
}
}
"""

    def test_module_nested_op(self):
        """This tests uploading a basic module and calling some functions and checking logs afterwards."""
        # Create both accounts, link them, then have the module log the link.
        for account_id, account_name in ((1, "House"), (2, "Wilson")):
            self.call("create_account", account_id, account_name)
        self.call("add_friend", 1, 2)
        self.call("say_friends")
        self.assertIn("House is friends with Wilson", self.logs(2))
-247
View File
@@ -1,247 +0,0 @@
from .. import Smoketest, random_string
from subprocess import CalledProcessError
import time
import itertools
class UpdateModule(Smoketest):
    """Tests republishing a module without --delete-data: identical and additive
    schema changes succeed, incompatible table changes are rejected."""

    # Publishing happens explicitly inside the test so it can republish.
    AUTOPUBLISH = False

    MODULE_CODE = """
use spacetimedb::{log, ReducerContext, Table};
#[spacetimedb::table(accessor = person)]
pub struct Person {
#[primary_key]
#[auto_inc]
id: u64,
name: String,
}
#[spacetimedb::reducer]
pub fn add(ctx: &ReducerContext, name: String) {
ctx.db.person().insert(Person { id: 0, name });
}
#[spacetimedb::reducer]
pub fn say_hello(ctx: &ReducerContext) {
for person in ctx.db.person().iter() {
log::info!("Hello, {}!", person.name);
}
log::info!("Hello, World!");
}
"""
    # Variant B changes an existing table (adds `age`) -> requires manual migration.
    MODULE_CODE_B = """
#[spacetimedb::table(accessor = person)]
pub struct Person {
#[primary_key]
#[auto_inc]
id: u64,
name: String,
age: u8,
}
"""
    # Variant C only adds a new table and reducer -> compatible update.
    MODULE_CODE_C = """
use spacetimedb::{log, ReducerContext, Table};
#[spacetimedb::table(accessor = person)]
pub struct Person {
#[primary_key]
#[auto_inc]
id: u64,
name: String,
}
#[spacetimedb::table(accessor = pets)]
pub struct Pet {
species: String,
}
#[spacetimedb::reducer]
pub fn are_we_updated_yet(ctx: &ReducerContext) {
log::info!("MODULE UPDATED");
}
"""

    def test_module_update(self):
        """Test publishing a module without the --delete-data option"""
        name = random_string()
        self.publish_module(name, clear=False)
        self.call("add", "Robert")
        self.call("add", "Julie")
        self.call("add", "Samantha")
        self.call("say_hello")
        logs = self.logs(100)
        self.assertIn("Hello, Samantha!", logs)
        self.assertIn("Hello, Julie!", logs)
        self.assertIn("Hello, Robert!", logs)
        self.assertIn("Hello, World!", logs)
        # Unchanged module is ok
        self.publish_module(name, clear=False)
        # Changing an existing table isn't
        self.write_module_code(self.MODULE_CODE_B)
        with self.assertRaises(CalledProcessError) as cm:
            self.publish_module(name, clear=False)
        self.assertIn("Error: Aborting because publishing would require manual migration", cm.exception.stderr)
        # Check that the old module is still running by calling say_hello
        self.call("say_hello")
        # Adding a table is ok
        self.write_module_code(self.MODULE_CODE_C)
        self.publish_module(name, clear=False)
        self.call("are_we_updated_yet")
        self.assertIn("MODULE UPDATED", self.logs(2))
class UploadModule1(Smoketest):
    """Publishes a basic module, calls its reducers, and verifies the log output."""

    MODULE_CODE = """
use spacetimedb::{log, ReducerContext, Table};
#[spacetimedb::table(accessor = person)]
pub struct Person {
name: String,
}
#[spacetimedb::reducer]
pub fn add(ctx: &ReducerContext, name: String) {
ctx.db.person().insert(Person { name });
}
#[spacetimedb::reducer]
pub fn say_hello(ctx: &ReducerContext) {
for person in ctx.db.person().iter() {
log::info!("Hello, {}!", person.name);
}
log::info!("Hello, World!");
}
"""

    def test_upload_module_1(self):
        """This tests uploading a basic module and calling some functions and checking logs afterwards."""
        people = ["Robert", "Julie", "Samantha"]
        for person in people:
            self.call("add", person)
        self.call("say_hello")
        logs = self.logs(100)
        # One greeting per inserted row, plus the trailing "Hello, World!".
        for greeting in [f"Hello, {person}!" for person in people] + ["Hello, World!"]:
            self.assertIn(greeting, logs)
class UploadModule2(Smoketest):
    """Deploys a module with a scheduled (repeating) reducer and checks it keeps firing."""

    MODULE_CODE = """
use spacetimedb::{log, duration, ReducerContext, Table, Timestamp};
#[spacetimedb::table(accessor = scheduled_message, public, scheduled(my_repeating_reducer))]
pub struct ScheduledMessage {
#[primary_key]
#[auto_inc]
scheduled_id: u64,
scheduled_at: spacetimedb::ScheduleAt,
prev: Timestamp,
}
#[spacetimedb::reducer(init)]
fn init(ctx: &ReducerContext) {
ctx.db.scheduled_message().insert(ScheduledMessage { prev: ctx.timestamp, scheduled_id: 0, scheduled_at: duration!(100ms).into(), });
}
#[spacetimedb::reducer]
pub fn my_repeating_reducer(ctx: &ReducerContext, arg: ScheduledMessage) {
log::info!("Invoked: ts={:?}, delta={:?}", ctx.timestamp, ctx.timestamp.duration_since(arg.prev));
}
"""

    def test_upload_module_2(self):
        """This test deploys a module with a repeating reducer and checks the logs to make sure its running."""
        # Sample the "Invoked" count twice, a few seconds apart; the count
        # must grow if the 100ms schedule is actually running.
        time.sleep(2)
        lines = sum(1 for line in self.logs(100) if "Invoked" in line)
        time.sleep(4)
        new_lines = sum(1 for line in self.logs(100) if "Invoked" in line)
        self.assertLess(lines, new_lines)
class HotswapModule(Smoketest):
    """Tests updating a module under an active subscription (hotswap)."""

    # Publishing is driven by the test so it can publish twice.
    AUTOPUBLISH = False

    MODULE_CODE = """
use spacetimedb::{ReducerContext, Table};
#[spacetimedb::table(accessor = person)]
pub struct Person {
#[primary_key]
#[auto_inc]
id: u64,
name: String,
}
#[spacetimedb::reducer]
pub fn add_person(ctx: &ReducerContext, name: String) {
ctx.db.person().insert(Person { id: 0, name });
}
"""
    # Same module plus a new `pet` table and reducer (additive, hot-swappable).
    MODULE_CODE_B = """
use spacetimedb::{ReducerContext, Table};
#[spacetimedb::table(accessor = person)]
pub struct Person {
#[primary_key]
#[auto_inc]
id: u64,
name: String,
}
#[spacetimedb::reducer]
pub fn add_person(ctx: &ReducerContext, name: String) {
ctx.db.person().insert(Person { id: 0, name });
}
#[spacetimedb::table(accessor = pet)]
pub struct Pet {
#[primary_key]
species: String,
}
#[spacetimedb::reducer]
pub fn add_pet(ctx: &ReducerContext, species: String) {
ctx.db.pet().insert(Pet { species });
}
"""

    def test_hotswap_module(self):
        """Tests hotswapping of modules."""
        # Publish MODULE_CODE and subscribe to all
        name = random_string()
        self.publish_module(name, clear=False)
        sub = self.subscribe("SELECT * FROM *", n=2)
        # Trigger event on the subscription
        self.call("add_person", "Horst")
        # Update the module
        self.write_module_code(self.MODULE_CODE_B)
        self.publish_module(name, clear=False)
        # Assert that the module was updated
        self.call("add_pet", "Turtle")
        # And trigger another event on the subscription
        self.call("add_person", "Cindy")
        # Note that 'SELECT * FROM *' does NOT get refreshed to include the
        # new table (this is a known limitation).
        self.assertEqual(sub(), [
            {'person': {'deletes': [], 'inserts': [{'id': 1, 'name': 'Horst'}]}},
            {'person': {'deletes': [], 'inserts': [{'id': 2, 'name': 'Cindy'}]}}
        ])
-38
View File
@@ -1,38 +0,0 @@
from .. import Smoketest, random_string
import tempfile
import os
from glob import iglob
def count_matches(dir, needle):
    """Count total occurrences of `needle` across all `.cs` files under `dir`.

    Searches recursively. `dir` is kept as the parameter name for caller
    compatibility even though it shadows the builtin.
    """
    count = 0
    for path in iglob(os.path.join(dir, "**/*.cs"), recursive=True):
        # Bug fix: the original reused `f` for both the path and the file
        # handle. Also read as UTF-8 explicitly (generated sources are UTF-8)
        # so the result doesn't depend on the platform's locale encoding.
        with open(path, encoding="utf-8") as fh:
            count += fh.read().count(needle)
    return count
class Namespaces(Smoketest):
    """Checks the C# namespace emitted by `spacetime generate`."""

    AUTOPUBLISH = False

    def test_spacetimedb_ns_csharp(self):
        """Ensure that the default namespace is working properly"""
        with tempfile.TemporaryDirectory() as tmpdir:
            self.spacetime("generate", "--out-dir", tmpdir, "--lang=cs", "--module-path", self.project_path)
            # Default namespace is declared 5 times and never imported via `using`.
            self.assertEqual(count_matches(tmpdir, "namespace SpacetimeDB.Types"), 5)
            self.assertEqual(count_matches(tmpdir, "using SpacetimeDB;"), 0)

    def test_custom_ns_csharp(self):
        """Ensure that when a custom namespace is specified on the command line, it actually gets used in generation"""
        namespace = random_string()
        with tempfile.TemporaryDirectory() as tmpdir:
            self.spacetime("generate", "--out-dir", tmpdir, "--lang=cs", "--namespace", namespace, "--module-path", self.project_path)
            # The custom namespace is declared 5 times; SpacetimeDB itself is
            # then pulled in with `using` in each of those files.
            self.assertEqual(count_matches(tmpdir, f"namespace {namespace}"), 5)
            self.assertEqual(count_matches(tmpdir, "using SpacetimeDB;"), 5)
-53
View File
@@ -1,53 +0,0 @@
from .. import Smoketest, requires_anonymous_login
import time
class NewUserFlow(Smoketest):
    """Walks through the documented new-user flow: new identity, publish,
    call reducers, read logs, and query via SQL."""

    # Publishing happens inside the test, after creating a fresh identity.
    AUTOPUBLISH = False

    MODULE_CODE = """
use spacetimedb::{log, ReducerContext, Table};
#[spacetimedb::table(accessor = person)]
pub struct Person {
name: String
}
#[spacetimedb::reducer]
pub fn add(ctx: &ReducerContext, name: String) {
ctx.db.person().insert(Person { name });
}
#[spacetimedb::reducer]
pub fn say_hello(ctx: &ReducerContext) {
for person in ctx.db.person().iter() {
log::info!("Hello, {}!", person.name);
}
log::info!("Hello, World!");
}
"""

    @requires_anonymous_login
    def test_new_user_flow(self):
        """Test the entirety of the new user flow."""
        ## Publish your module
        self.new_identity()
        self.publish_module()
        # Calling our database
        self.call("say_hello")
        self.assertIn("Hello, World!", self.logs(2))
        ## Calling functions with arguments
        self.call("add", "Tyler")
        self.call("say_hello")
        # say_hello has now run twice, but only one row exists -> one "Hello, Tyler!".
        self.assertEqual(self.logs(5).count("Hello, World!"), 2)
        self.assertEqual(self.logs(5).count("Hello, Tyler!"), 1)
        out = self.spacetime("sql", self.database_identity, "SELECT * FROM person")
        # The spaces after the name are important
        self.assertMultiLineEqual(out, """\
 name
---------
 "Tyler"
""")
-50
View File
@@ -1,50 +0,0 @@
from .. import Smoketest
class Panic(Smoketest):
    """Verifies a module keeps working after one of its reducers panics."""

    MODULE_CODE = """
use spacetimedb::{log, ReducerContext};
use std::cell::RefCell;
thread_local! {
static X: RefCell<u32> = RefCell::new(0);
}
#[spacetimedb::reducer]
fn first(_ctx: &ReducerContext) {
X.with(|x| {
let _x = x.borrow_mut();
panic!()
})
}
#[spacetimedb::reducer]
fn second(_ctx: &ReducerContext) {
X.with(|x| *x.borrow_mut());
log::info!("Test Passed");
}
"""

    def test_panic(self):
        """Tests to check if a SpacetimeDB module can handle a panic without corrupting"""
        # `first` panics while holding the RefCell borrow, so the call fails...
        self.assertRaises(Exception, self.call, "first")
        # ...but `second`, which touches the same cell, must still succeed.
        self.call("second")
        self.assertIn("Test Passed", self.logs(2))
class ReducerError(Smoketest):
    """Verifies a reducer's returned error value shows up in the database logs."""

    MODULE_CODE = """
use spacetimedb::ReducerContext;
#[spacetimedb::reducer]
fn fail(_ctx: &ReducerContext) -> Result<(), String> {
Err("oopsie :(".into())
}
"""

    def test_reducer_error_message(self):
        """Tests to ensure an error message returned from a reducer gets printed to logs"""
        self.assertRaises(Exception, self.call, "fail")
        self.assertIn("oopsie :(", self.logs(2))
-165
View File
@@ -1,165 +0,0 @@
from .. import Smoketest, random_string
import json
class Permissions(Smoketest):
    """Checks which operations require ownership of a database (delete, logs,
    publish, rename) and which are open to anyone (call, describe)."""

    AUTOPUBLISH = False

    def setUp(self):
        # Start each test from a clean CLI config / identity state.
        self.reset_config()

    def test_call(self):
        """Ensure that anyone has the permission to call any standard reducer"""
        self.publish_module()
        self.call("say_hello", anon=True)
        self.assertEqual("\n".join(self.logs(10000)).count("World"), 1)

    def test_delete(self):
        """Ensure that you cannot delete a database that you do not own"""
        self.publish_module()
        # Switch to a different identity before attempting the delete.
        self.new_identity()
        with self.assertRaises(Exception):
            self.spacetime("delete", "--yes", self.database_identity)

    def test_describe(self):
        """Ensure that anyone can describe any database"""
        self.publish_module()
        self.spacetime("describe", "--anonymous", "--json", self.database_identity)

    def test_logs(self):
        """Ensure that we are not able to view the logs of a module that we don't have permission to view"""
        self.publish_module()
        self.reset_config()
        self.new_identity()
        self.call("say_hello")
        # A third, unrelated identity must not be able to read the logs.
        self.reset_config()
        self.new_identity()
        with self.assertRaises(Exception):
            self.spacetime("logs", self.database_identity, "-n", "10000")

    def test_publish(self):
        """This test checks to make sure that you cannot publish to an identity that you do not own."""
        self.publish_module()
        self.new_identity()
        with self.assertRaises(Exception):
            self.spacetime("publish", self.database_identity, "--module-path", self.project_path, "--delete-data", "--yes")
        # Check that this holds without `--delete-data`, too.
        with self.assertRaises(Exception):
            self.spacetime("publish", self.database_identity, "--module-path", self.project_path, "--yes")

    def test_replace_names(self):
        """Test that you can't replace names of a database you don't own"""
        name = random_string()
        self.publish_module(name)
        self.new_identity()
        with self.assertRaises(Exception):
            self.api_call(
                "PUT",
                f'/v1/database/{name}/names',
                json.dumps(["post", "gres"]),
                {"Content-type": "application/json"}
            )
class PrivateTablePermissions(Smoketest):
    """Checks that private tables are only visible to the database owner,
    while public tables remain queryable and subscribable by anyone."""

    MODULE_CODE = """
use spacetimedb::{ReducerContext, Table};
#[spacetimedb::table(accessor = secret, private)]
pub struct Secret {
answer: u8,
}
#[spacetimedb::table(accessor = common_knowledge, public)]
pub struct CommonKnowledge {
thing: String,
}
#[spacetimedb::reducer(init)]
pub fn init(ctx: &ReducerContext) {
ctx.db.secret().insert(Secret { answer: 42 });
}
#[spacetimedb::reducer]
pub fn do_thing(ctx: &ReducerContext, thing: String) {
ctx.db.secret().insert(Secret { answer: 20 });
ctx.db.common_knowledge().insert(CommonKnowledge { thing });
}
"""

    def test_private_table(self):
        """Ensure that a private table can only be queried by the database owner"""
        # As the owner, the private table is queryable.
        out = self.spacetime("sql", self.database_identity, "select * from secret")
        answer = "\n".join([
            " answer ",
            "--------",
            " 42     ",
            ""
        ])
        self.assertMultiLineEqual(str(out), answer)
        # Switch to a non-owner identity.
        self.reset_config()
        self.new_identity()
        with self.assertRaises(Exception):
            self.spacetime("sql", self.database_identity, "select * from secret")
        # Subscribing to the private table fails.
        with self.assertRaises(Exception):
            self.subscribe("SELECT * FROM secret", n=0)
        # Subscribing to the public table works.
        sub = self.subscribe("SELECT * FROM common_knowledge", n = 1)
        self.call("do_thing", "godmorgon")
        self.assertEqual(sub(), [
            {
                'common_knowledge': {
                    'deletes': [],
                    'inserts': [{'thing': 'godmorgon'}]
                }
            }
        ])
        # Subscribing to both tables returns updates for the public one.
        sub = self.subscribe("SELECT * FROM *", n=1)
        self.call("do_thing", "howdy", anon=True)
        self.assertEqual(sub(), [
            {
                'common_knowledge': {
                    'deletes': [],
                    'inserts': [{'thing': 'howdy'}]
                }
            }
        ])
class LifecycleReducers(Smoketest):
    """Checks that lifecycle reducers cannot be invoked directly by clients."""

    # The lifecycle hooks a module can declare.
    lifecycle_kinds = "init", "client_connected", "client_disconnected"

    # Generate one no-op reducer per lifecycle kind. (The genexp's iterable is
    # evaluated eagerly, so referencing `lifecycle_kinds` at class scope works.)
    MODULE_CODE = "\n".join(f"""
#[spacetimedb::reducer({kind})]
fn lifecycle_{kind}(_ctx: &spacetimedb::ReducerContext) {{}}
""" for kind in lifecycle_kinds)

    def test_lifecycle_reducers_cant_be_called(self):
        """Ensure that lifecycle reducers (init, on_connect, etc) can't be called"""
        for kind in self.lifecycle_kinds:
            with self.assertRaises(Exception):
                self.call(f"lifecycle_{kind}")
-403
View File
@@ -1,403 +0,0 @@
import logging
import re
import shutil
from pathlib import Path
import tempfile
import xmltodict
import smoketests
from .. import Smoketest, STDB_DIR, run_cmd, TEMPLATE_CARGO_TOML, TYPESCRIPT_BINDINGS_PATH, build_typescript_sdk, pnpm
def _write_file(path: Path, content: str):
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(content)
def _append_to_file(path: Path, content: str):
with open(path, "a", encoding="utf-8") as f:
f.write(content)
def _parse_quickstart(doc_path: Path, language: str, module_name: str, server: bool) -> str:
"""Extract code blocks from `quickstart.md` docs.
This will replicate the steps in the quickstart guide, so if it fails the quickstart guide is broken.
"""
content = Path(doc_path).read_text()
# append " server" to the codeblock language if we're extracting server code
if server:
codeblock_lang = "ts server" if language == "typescript" else f"{language} server"
else:
codeblock_lang = "ts" if language == "typescript" else language
blocks = re.findall(rf"```{codeblock_lang}\n(.*?)\n```", content, re.DOTALL)
end = ""
if language == "csharp":
found = False
filtered_blocks = []
for block in blocks:
# The doc first create an empty class Module, so we need to fixup the closing
if "partial class Module" in block:
block = block.replace("}", "")
end = "\n}"
# Remove the first `OnConnected` block, which body is later updated
if "OnConnected(DbConnection conn" in block and not found:
found = True
continue
filtered_blocks.append(block)
blocks = filtered_blocks
# So we could have a different db for each language
return "\n".join(blocks).replace("quickstart-chat", module_name) + end
def load_nuget_config(p: Path):
    """Parse the NuGet config at `p` into a dict; return {} if the file is absent.

    force_list keeps repeated elements (`add`, `packageSource`, `package`)
    as lists even when only one is present.
    """
    if not p.exists():
        return {}
    with p.open("rb") as fh:
        return xmltodict.parse(fh.read(), force_list=["add", "packageSource", "package"])
def _nuget_config_path(project_dir: Path) -> Path:
p_upper = project_dir / "NuGet.Config"
if p_upper.exists():
return p_upper
p_lower = project_dir / "nuget.config"
if p_lower.exists():
return p_lower
return p_upper
def save_nuget_config(p: Path, doc: dict):
    """Serialize `doc` back to XML at `p` (pretty-printed, UTF-8, no BOM)."""
    xml_text = xmltodict.unparse(doc, pretty=True)
    p.write_text(xml_text, encoding="utf-8")
def add_source(doc: dict, *, key: str, path: str) -> None:
    """Add a `<packageSources>/<add key=... value=...>` entry to a NuGet config dict.

    Idempotent: if an entry with `key` already exists its value is updated
    in place instead of appending a duplicate. (The original appended
    unconditionally, so re-running override_nuget_package accumulated
    duplicate source keys; this mirrors add_mapping's existing-entry check.)
    """
    cfg = doc.setdefault("configuration", {})
    sources = cfg.setdefault("packageSources", {})
    source_entries = sources.setdefault("add", [])
    value = str(path)
    for entry in source_entries:
        if entry.get("@key") == key:
            entry["@value"] = value
            return
    source_entries.append({"@key": key, "@value": value})
def add_mapping(doc: dict, *, key: str, pattern: str) -> None:
    """Map package name `pattern` to the package source `key`.

    Creates the `packageSourceMapping` section and the per-source entry on
    demand; adding the same pattern twice is a no-op.
    """
    mapping_sources = (
        doc.setdefault("configuration", {})
           .setdefault("packageSourceMapping", {})
           .setdefault("packageSource", [])
    )
    # Find or create the target <packageSource key="...">
    for source in mapping_sources:
        if source.get("@key") == key:
            target = source
            break
    else:
        target = {"@key": key, "package": []}
        mapping_sources.append(target)
    pkgs = target.setdefault("package", [])
    already_present = any(
        pkg.get("@pattern") == pattern for pkg in pkgs if "@pattern" in pkg
    )
    if not already_present:
        pkgs.append({"@pattern": pattern})
def override_nuget_package(*, project_dir: Path, package: str, source_dir: Path, build_subdir: str):
    """Override nuget config to use a local NuGet package on a .NET project.

    Packs the package from `source_dir`, then edits `project_dir`'s NuGet
    config so that `package` resolves from `source_dir/build_subdir` while
    everything else falls back to nuget.org. Finally clears the local NuGet
    caches so the freshly packed version is actually picked up.
    """
    # Make sure the local package is built
    repo_nuget_config = STDB_DIR / "NuGet.Config"
    if repo_nuget_config.exists():
        # Restore against the repo-level config first so pack can run offline
        # with --no-restore.
        run_cmd(
            "dotnet",
            "restore",
            "--configfile",
            str(repo_nuget_config),
            cwd=source_dir,
            capture_stderr=True,
        )
        run_cmd("dotnet", "pack", "-c", "Release", "--no-restore", cwd=source_dir)
    else:
        run_cmd("dotnet", "pack", "-c", "Release", cwd=source_dir)
    p = _nuget_config_path(Path(project_dir))
    doc = load_nuget_config(p)
    # Route `package` to the local build output.
    add_source(doc, key=package, path=source_dir/build_subdir)
    add_mapping(doc, key=package, pattern=package)
    add_source(doc, key="nuget.org", path="https://api.nuget.org/v3/index.json")
    # Fallback for other packages
    add_mapping(doc, key="nuget.org", pattern="*")
    save_nuget_config(p, doc)
    # Clear any caches for nuget packages
    run_cmd("dotnet", "nuget", "locals", "--clear", "all", capture_stderr=True)
class BaseQuickstart(Smoketest):
    """Shared driver for the per-language quickstart-guide smoketests.

    Subclasses provide the language-specific configuration attributes and the
    `project_init`/`sdk_setup` hooks; `_test_quickstart` then replays the
    quickstart guide end-to-end: generate the server module from the docs,
    publish it, generate client bindings, build the doc's client code, and
    drive it with chat input.
    """
    AUTOPUBLISH = False
    MODULE_CODE = ""
    # Per-language configuration; overridden by subclasses:
    lang = None             # language used for the server module
    client_lang = None      # client language, defaults to `lang` when None
    codeblock_langs = None
    server_doc = None       # path to the docs file holding server snippets
    client_doc = None       # path to the docs file holding client snippets
    server_file = None      # where the extracted server code is written
    client_file = None      # where the extracted client code is written
    module_bindings = None  # output dir for generated client bindings
    extra_code = None       # test harness code appended to the client
    replacements = {}       # source-text substitutions applied to the client
    connected_str = None    # output substring proving the client connected
    run_cmd = []            # command that runs the client
    build_cmd = []          # command that builds the client

    def project_init(self, path: Path):
        # Create a fresh client project at `path`; implemented per language.
        raise NotImplementedError

    def sdk_setup(self, path: Path):
        # Wire the client project to the local (in-repo) SDK; per language.
        raise NotImplementedError

    @property
    def _module_name(self):
        # Unique database name per language so the tests don't collide.
        return f"quickstart-chat-{self.lang}"

    def _publish(self) -> Path:
        """Generate and publish the server module; return the client dir."""
        base_path = Path(self.enterClassContext(tempfile.TemporaryDirectory()))
        server_path = base_path / "server"
        self.generate_server(server_path)
        self.publish_module(self._module_name, capture_stderr=True, clear=True)
        return base_path / "client"

    def generate_server(self, server_path: Path):
        """Generate the server code from the quickstart documentation."""
        logging.info(f"Generating server code {self.lang}: {server_path}...")
        self.spacetime(
            "init",
            "--non-interactive",
            "--lang",
            self.lang,
            "--project-path",
            server_path,
            "spacetimedb-project",
            capture_stderr=True,
        )
        self.project_path = server_path / "spacetimedb"
        # Pin the toolchain to the repo's so the module builds reproducibly.
        shutil.copy2(STDB_DIR / "rust-toolchain.toml", self.project_path)
        _write_file(self.project_path / self.server_file, _parse_quickstart(self.server_doc, self.lang, self._module_name, server=True))
        self.server_postprocess(self.project_path)
        self.spacetime("build", "-d", "-p", self.project_path, capture_stderr=True)

    def server_postprocess(self, server_path: Path):
        """Optional per-language hook."""
        pass

    def check(self, input_cmd: str, client_path: Path, contains: str):
        """Run the client command and check if the output contains the expected string."""
        output = run_cmd(*self.run_cmd, input=input_cmd, cwd=client_path, capture_stderr=True, text=True)
        print(f"Output for {self.lang} client:\n{output}")
        self.assertIn(contains, output)

    def _test_quickstart(self):
        """Run the quickstart client."""
        client_path = self._publish()
        self.project_init(client_path)
        self.sdk_setup(client_path)
        run_cmd(*self.build_cmd, cwd=client_path, capture_stderr=True)
        client_lang = self.client_lang or self.lang
        self.spacetime(
            "generate", "--lang", client_lang,
            "--out-dir", client_path / self.module_bindings,
            "--module-path", self.project_path, capture_stderr=True
        )
        # Replay the quickstart guide steps
        main = _parse_quickstart(self.client_doc, client_lang, self._module_name, server=False)
        for src, dst in self.replacements.items():
            main = main.replace(src, dst)
        main += "\n" + self.extra_code
        # Point the client at whatever server this test run targets.
        server = self.get_server_address()
        host = server["address"]
        protocol = server["protocol"]
        main = main.replace("http://localhost:3000", f"{protocol}://{host}")
        _write_file(client_path / self.client_file, main)
        # Drive the chat client: connect, set a name, send a message.
        self.check("", client_path, self.connected_str)
        self.check("/name Alice", client_path, "Alice")
        self.check("Hello World", client_path, "Hello World")
class Rust(BaseQuickstart):
    """Quickstart smoketest for the Rust server and client guides."""
    lang = "rust"
    server_doc = STDB_DIR / "docs/docs/00100-intro/00300-tutorials/00100-chat-app.md"
    client_doc = STDB_DIR / "docs/docs/00100-intro/00300-tutorials/00100-chat-app.md"
    server_file = "src/lib.rs"
    client_file = "src/main.rs"
    module_bindings = "src/module_bindings"
    run_cmd = ["cargo", "run"]
    build_cmd = ["cargo", "build"]
    replacements = {
        # Replace the interactive user input to allow direct testing
        "user_input_loop(&ctx)": "user_input_direct(&ctx)",
        # Don't cache the token, because it will cause the test to fail if we run against a non-default server (because we don't cache the corresponding signing keypair)
        ".with_token(creds_store()": "//.with_token(creds_store()"
    }
    # Non-interactive replacement for the guide's input loop: read one line
    # from stdin, dispatch it, give the server a second, then exit.
    extra_code = """
fn user_input_direct(ctx: &DbConnection) {
    let mut line = String::new();
    std::io::stdin().read_line(&mut line).expect("Failed to read from stdin.");
    if let Some(name) = line.strip_prefix("/name ") {
        ctx.reducers.set_name(name.to_string()).unwrap();
    } else {
        ctx.reducers.send_message(line).unwrap();
    }
    std::thread::sleep(std::time::Duration::from_secs(1));
    std::process::exit(0);
}
"""
    connected_str = "connected"

    def project_init(self, path: Path):
        # Fresh cargo binary crate for the client.
        run_cmd("cargo", "new", "--bin", "--name", "quickstart_chat_client", "client", cwd=path.parent,
                capture_stderr=True)

    def sdk_setup(self, path: Path):
        # Point the client at the in-repo Rust SDK instead of crates.io.
        sdk_rust_path = (STDB_DIR / "sdks/rust").absolute()
        sdk_rust_toml_escaped = str(sdk_rust_path).replace('\\', '\\\\\\\\')  # double escape for re.sub + toml
        sdk_rust_toml = f'spacetimedb-sdk = {{ path = "{sdk_rust_toml_escaped}" }}\nlog = "0.4"\nhex = "0.4"\n'
        _append_to_file(path / "Cargo.toml", sdk_rust_toml)

    def server_postprocess(self, server_path: Path):
        # Rewrite the module manifest to use the repo's template/bindings.
        _write_file(server_path / "Cargo.toml", self.cargo_manifest(TEMPLATE_CARGO_TOML))

    def test_quickstart(self):
        """Run the Rust quickstart guides for server and client."""
        self._test_quickstart()
class CSharp(BaseQuickstart):
    """Quickstart smoketest for the C# server and client guides."""
    lang = "csharp"
    server_doc = STDB_DIR / "docs/docs/00100-intro/00300-tutorials/00100-chat-app.md"
    client_doc = STDB_DIR / "docs/docs/00100-intro/00300-tutorials/00100-chat-app.md"
    server_file = "Lib.cs"
    client_file = "Program.cs"
    module_bindings = "module_bindings"
    run_cmd = ["dotnet", "run"]
    build_cmd = ["dotnet", "build"]
    # Replace the interactive user input to allow direct testing
    replacements = {
        "InputLoop();": "UserInputDirect();",
        ".OnConnect(OnConnected)": ".OnConnect(OnConnectedSignal)",
        ".OnConnectError(OnConnectError)": ".OnConnectError(OnConnectErrorSignal)",
        # Don't cache the token, because it will cause the test to fail if we run against a non-default server (because we don't cache the corresponding signing keypair)
        ".WithToken(AuthToken.Token)": "//.WithToken(AuthToken.Token)",
        "Main();": ""  # To put the main function at the end so it can see the new functions
    }
    # So we can wait for the connection to be established...
    extra_code = """
var connectedEvent = new ManualResetEventSlim(false);
var connectionFailed = new ManualResetEventSlim(false);
void OnConnectErrorSignal(Exception e)
{
    OnConnectError(e);
    connectionFailed.Set();
}
void OnConnectedSignal(DbConnection conn, Identity identity, string authToken)
{
    OnConnected(conn, identity, authToken);
    connectedEvent.Set();
}
void UserInputDirect() {
    string? line = Console.In.ReadToEnd()?.Trim();
    if (line == null) Environment.Exit(0);
    if (!WaitHandle.WaitAny(
        new[] { connectedEvent.WaitHandle, connectionFailed.WaitHandle },
        TimeSpan.FromSeconds(5)
    ).Equals(0))
    {
        Console.WriteLine("Failed to connect to server within timeout.");
        Environment.Exit(1);
    }
    if (line.StartsWith("/name ")) {
        input_queue.Enqueue(("name", line[6..]));
    } else {
        input_queue.Enqueue(("message", line));
    }
    Thread.Sleep(1000);
}
Main();
"""
    connected_str = "Connected"

    def project_init(self, path: Path):
        # Fresh dotnet console project for the client.
        run_cmd("dotnet", "new", "console", "--name", "QuickstartChatClient", "--output", path, capture_stderr=True)

    def sdk_setup(self, path: Path):
        # Route every SpacetimeDB package to its locally built copy, for both
        # the in-repo C# SDK and the client project itself.
        override_nuget_package(
            project_dir=STDB_DIR/"sdks/csharp",
            package="SpacetimeDB.BSATN.Runtime",
            source_dir=(STDB_DIR / "crates/bindings-csharp/BSATN.Runtime").absolute(),
            build_subdir="bin/Release"
        )
        # This one is only needed because the regression-tests subdir uses it
        override_nuget_package(
            project_dir=STDB_DIR/"sdks/csharp",
            package="SpacetimeDB.Runtime",
            source_dir=(STDB_DIR / "crates/bindings-csharp/Runtime").absolute(),
            build_subdir="bin/Release"
        )
        override_nuget_package(
            project_dir=path,
            package="SpacetimeDB.BSATN.Runtime",
            source_dir=(STDB_DIR / "crates/bindings-csharp/BSATN.Runtime").absolute(),
            build_subdir="bin/Release"
        )
        override_nuget_package(
            project_dir=path,
            package="SpacetimeDB.ClientSDK",
            source_dir=(STDB_DIR / "sdks/csharp").absolute(),
            build_subdir="bin~/Release"
        )
        run_cmd("dotnet", "add", "package", "SpacetimeDB.ClientSDK", cwd=path, capture_stderr=True)

    def server_postprocess(self, server_path: Path):
        # The server module also needs the locally built runtime packages.
        override_nuget_package(
            project_dir=server_path,
            package="SpacetimeDB.Runtime",
            source_dir=(STDB_DIR / "crates/bindings-csharp/Runtime").absolute(),
            build_subdir="bin/Release"
        )
        override_nuget_package(
            project_dir=server_path,
            package="SpacetimeDB.BSATN.Runtime",
            source_dir=(STDB_DIR / "crates/bindings-csharp/BSATN.Runtime").absolute(),
            build_subdir="bin/Release"
        )

    def test_quickstart(self):
        """Run the C# quickstart guides for server and client."""
        if not smoketests.HAVE_DOTNET:
            self.skipTest("C# SDK requires .NET to be installed.")
        self._test_quickstart()
# We use the Rust client for testing the TypeScript server quickstart because
# the TypeScript client quickstart is a React app, which is difficult to
# smoketest.
class TypeScript(Rust):
    """Quickstart smoketest for the TypeScript server guide (Rust client)."""
    lang = "typescript"
    client_lang = "rust"
    server_doc = STDB_DIR / "docs/docs/00100-intro/00300-tutorials/00100-chat-app.md"
    server_file = "src/index.ts"

    def server_postprocess(self, server_path: Path):
        # Swap the npm-sourced spacetimedb dependency for the local bindings.
        build_typescript_sdk()
        # We already have spacetimedb as a depencency, but it's expecting to fetch from npm.
        # If we don't uninstall before installing, pnpm can panic because it can't find
        # the specified version on npm (even though we're about to override it anyway).
        pnpm("uninstall", 'spacetimedb', cwd=server_path)
        pnpm("install", TYPESCRIPT_BINDINGS_PATH, cwd=server_path)

    def test_quickstart(self):
        """Run the TypeScript quickstart guides for server."""
        self._test_quickstart()
-507
View File
@@ -1,507 +0,0 @@
import time
import unittest
from typing import Callable
import json
from .. import COMPOSE_FILE, Smoketest, random_string, requires_docker, spacetime, parse_sql_result
from ..docker import DockerManager
def retry(func: Callable, max_retries: int = 3, retry_delay: int = 2):
    """Retry `func` on failure, sleeping `retry_delay` seconds between tries.

    Returns `func()`'s result on the first success. After `max_retries`
    failures the last exception is deliberately swallowed and False is
    returned — callers treat this as best-effort.
    """
    attempt = 0
    while True:
        attempt += 1
        try:
            return func()
        except Exception as e:
            if attempt >= max_retries:
                print("Max retries reached. Skipping the exception.")
                return False
            print(f"Attempt {attempt} failed: {e}. Retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
def int_vals(rows: list[dict]) -> list[dict]:
    """For all dicts in list, cast all values in dict to int."""
    converted = []
    for row in rows:
        converted.append({key: int(value) for key, value in row.items()})
    return converted
class Cluster:
    """Manages leader-related operations and state for SpaceTime database cluster."""

    def __init__(self, docker_manager, smoketest: Smoketest):
        self.docker = docker_manager
        self.test = smoketest
        # Ensure all containers are up.
        self.docker.compose("up", "-d")

    def sql(self, sql: str) -> list[dict]:
        """Query the test database."""
        res = self.test.sql(sql)
        return parse_sql_result(str(res))

    def read_controldb(self, sql: str) -> list[dict]:
        """Query the control database."""
        res = self.test.spacetime("sql", "spacetime-control", sql)
        return parse_sql_result(str(res))

    def get_db_id(self):
        """Query database ID."""
        sql = f"select id from database where database_identity=0x{self.test.database_identity}"
        res = self.read_controldb(sql)
        return int(res[0]['id'])

    def get_all_replicas(self):
        """Get all replica nodes in the cluster."""
        database_id = self.get_db_id()
        sql = f"select id, node_id from replica where database_id={database_id}"
        return int_vals(self.read_controldb(sql))

    def get_leader_info(self):
        """Get current leader's node information including ID, hostname, and container ID."""
        database_id = self.get_db_id()
        sql = f""" \
        select node_v2.id, node_v2.network_addr from node_v2 \
        join replica on replica.node_id=node_v2.id \
        join replication_state on replication_state.leader=replica.id \
        where replication_state.database_id={database_id} \
        """
        rows = self.read_controldb(sql)
        if not rows:
            raise Exception("Could not find current leader's node")
        leader_node_id = int(rows[0]['id'])
        hostname = ""
        # network_addr prints as an option value, e.g. `(some = "host:port")`;
        # take the quoted address and strip the port to get the hostname.
        if "(some =" in rows[0]['network_addr']:
            address = rows[0]['network_addr'].split('"')[1]
            hostname = address.split(':')[0]
        # Find container ID
        container_id = ""
        containers = self.docker.list_containers()
        for container in containers:
            if hostname in container.name:
                container_id = container.id
                break
        return {
            'node_id': leader_node_id,
            'hostname': hostname,
            'container_id': container_id
        }

    def wait_for_leader_change(self, previous_leader_node, max_attempts=10, delay=2):
        """Wait for leader to change and return new leader node_id.

        Pass None as `previous_leader_node` to simply wait until any leader
        exists. Returns None if no change is observed within the attempts.
        """
        for _ in range(max_attempts):
            try:
                current_leader_node = self.get_leader_info()['node_id']
                if current_leader_node != previous_leader_node:
                    return current_leader_node
            except Exception:
                # get_leader_info raises while the cluster has no leader.
                print("No current leader")
            time.sleep(delay)
        return None

    def ensure_leader_health(self, id):
        """Verify leader is healthy by inserting a row."""
        retry(lambda: self.test.call("start", id, 1))
        rows = self.sql(f"select id from counter where id={id}")
        if len(rows) < 1 or int(rows[0]['id']) != id:
            raise ValueError(f"Could not find {id} in counter table")
        # Wait for at least one tick to ensure buffers are flushed.
        # TODO: Replace with confirmed read.
        time.sleep(0.6)

    def wait_counter_value(self, id, value, max_attempts=10, delay=1):
        """Wait for the value for `id` in the counter table to reach `value`"""
        for _ in range(max_attempts):
            rows = self.sql(f"select * from counter where id={id}")
            if len(rows) >= 1 and int(rows[0]['value']) >= value:
                return
            else:
                time.sleep(delay)
        raise ValueError(f"Counter {id} below {value}")

    def fail_leader(self, action='kill'):
        """Force leader failure through either killing or network disconnect."""
        leader_info = self.get_leader_info()
        container_id = leader_info['container_id']
        if not container_id:
            raise ValueError("Could not find leader container")
        if action == 'kill':
            self.docker.kill_container(container_id)
        elif action == 'disconnect':
            self.docker.disconnect_container(container_id)
        else:
            raise ValueError(f"Unknown action: {action}")
        return container_id

    def restore_leader(self, container_id, action='start'):
        """Restore failed leader through either starting or network reconnect."""
        if action == 'start':
            self.docker.start_container(container_id)
        elif action == 'connect':
            self.docker.connect_container(container_id)
        else:
            raise ValueError(f"Unknown action: {action}")
@requires_docker
class ReplicationTest(Smoketest):
    """Base class for replication smoketests against a dockerized cluster.

    The module under test provides a `counter` table driven by a scheduled
    `increment` reducer (kicked off via `start`) and a simple `message`
    table, both used to generate replicated writes.
    """
    MODULE_CODE = """
use spacetimedb::{duration, ReducerContext, Table};

#[spacetimedb::table(accessor = counter, public)]
pub struct Counter {
    #[primary_key]
    #[auto_inc]
    id: u64,
    #[index(btree)]
    value: u64,
}

#[spacetimedb::table(accessor = schedule_counter, public, scheduled(increment, at = sched_at))]
pub struct ScheduledCounter {
    #[primary_key]
    #[auto_inc]
    scheduled_id: u64,
    sched_at: spacetimedb::ScheduleAt,
    count: u64,
}

#[spacetimedb::reducer]
fn increment(ctx: &ReducerContext, arg: ScheduledCounter) {
    // if the counter exists, increment it
    if let Some(counter) = ctx.db.counter().id().find(arg.scheduled_id) {
        if counter.value == arg.count {
            ctx.db.schedule_counter().delete(arg);
            return;
        }
        // update counter
        ctx.db.counter().id().update(Counter {
            id: arg.scheduled_id,
            value: counter.value + 1,
        });
    } else {
        // insert fresh counter
        ctx.db.counter().insert(Counter {
            id: arg.scheduled_id,
            value: 1,
        });
    }
}

#[spacetimedb::reducer]
fn start(ctx: &ReducerContext, id: u64, count: u64) {
    ctx.db.schedule_counter().insert(ScheduledCounter {
        scheduled_id: id,
        sched_at: duration!(0ms).into(),
        count,
    });
}

#[spacetimedb::table(accessor = message, public)]
pub struct Message {
    #[primary_key]
    #[auto_inc]
    id: u64,
    text: String
}

#[spacetimedb::reducer]
fn send_message(ctx: &ReducerContext, text: String) {
    ctx.db.message().insert(Message { id: 0, text });
}
"""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Separate CLI config for the privileged (root) login used by tests.
        cls.root_config = cls.project_path / "root_config"
        spacetime("--config-path", cls.root_config, "server", "set-default", "local")

    def setUp(self):
        self.docker = DockerManager(COMPOSE_FILE)
        self.root_token = self.docker.generate_root_token()
        self.cluster = Cluster(self.docker, self)

    def tearDown(self):
        # Ensure containers that were brought down during a test are back up.
        self.docker.compose("up", "-d")
        super().tearDown()

    def add_me_as_admin(self):
        """Add the current user as an admin account"""
        db_owner_id = str(self.spacetime("login", "show")).split()[-1]
        spacetime("--config-path", self.root_config, "login", "--token", self.root_token)
        spacetime("--config-path", self.root_config, "call", "spacetime-control", "create_admin_account", f"0x{db_owner_id}")

    def start(self, id: int, count: int):
        """Kick off the scheduled counter via the `start` reducer (retried)."""
        retry(lambda: self.call("start", id, count))

    def collect_counter_rows(self):
        # All counter rows with values normalized to ints.
        return int_vals(self.cluster.sql("select * from counter"))

    def call_control(self, reducer, *args):
        # Call a reducer on the control database, JSON-encoding each argument.
        self.spacetime("call", "spacetime-control", reducer, *map(json.dumps, args))
class LeaderElection(ReplicationTest):
    def test_leader_election_in_loop(self):
        """This test fails a leader, wait for new leader to be elected and verify if commits replicated to new leader"""
        iterations = 5
        # Two counter ids per iteration: one before the failover, one after.
        row_ids = [101 + i for i in range(iterations * 2)]
        for (first_id, second_id) in zip(row_ids[::2], row_ids[1::2]):
            cur_leader = self.cluster.wait_for_leader_change(None)
            print(f"ensure leader health {first_id}")
            self.cluster.ensure_leader_health(first_id)

            print(f"killing current leader: {cur_leader}")
            container_id = self.cluster.fail_leader()
            self.assertIsNotNone(container_id)

            next_leader = self.cluster.wait_for_leader_change(cur_leader)
            self.assertNotEqual(cur_leader, next_leader)
            # this check if leader election happened
            print(f"ensure_leader_health {second_id}")
            self.cluster.ensure_leader_health(second_id)

            # restart the old leader, so that we can maintain quorum for next iteration
            print(f"reconnect leader {container_id}")
            self.cluster.restore_leader(container_id, 'start')

        # Ensure we have a current leader
        last_row_id = row_ids[-1] + 1
        self.cluster.ensure_leader_health(row_ids[-1] + 1)
        row_ids.append(last_row_id)

        # Verify that all inserted rows are present
        stored_row_ids = [row['id'] for row in self.collect_counter_rows()]
        self.assertEqual(set(stored_row_ids), set(row_ids))
class LeaderDisconnect(ReplicationTest):
    def test_leader_c_disconnect_in_loop(self):
        """This test disconnects a leader, wait for new leader to be elected and verify if commits replicated to new leader"""
        iterations = 5
        # Two counter ids per iteration: one before the failover, one after.
        row_ids = [201 + i for i in range(iterations * 2)]
        for (first_id, second_id) in zip(row_ids[::2], row_ids[1::2]):
            print(f"first={first_id} second={second_id}")
            cur_leader = self.cluster.wait_for_leader_change(None)
            print(f"ensure leader health {first_id}")
            self.cluster.ensure_leader_health(first_id)

            print("disconnect current leader")
            container_id = self.cluster.fail_leader('disconnect')
            self.assertIsNotNone(container_id)
            print(f"disconnected leader's container is {container_id}")

            next_leader = self.cluster.wait_for_leader_change(cur_leader)
            self.assertNotEqual(cur_leader, next_leader)
            # this check if leader election happened
            print(f"ensure_leader_health {second_id}")
            self.cluster.ensure_leader_health(second_id)

            # restart the old leader, so that we can maintain quorum for next iteration
            print(f"reconnect leader {container_id}")
            self.cluster.restore_leader(container_id, 'connect')

        # Ensure we have a current leader
        last_row_id = row_ids[-1] + 1
        self.cluster.ensure_leader_health(last_row_id)
        row_ids.append(last_row_id)

        # Verify that all inserted rows are present
        stored_row_ids = [row['id'] for row in self.collect_counter_rows()]
        self.assertEqual(set(stored_row_ids), set(row_ids))
@unittest.skip("drain_node not yet supported")
class DrainLeader(ReplicationTest):
    def test_drain_leader_node(self):
        """This test moves leader replica to different node"""
        self.add_me_as_admin()
        cur_leader_node_id = self.cluster.wait_for_leader_change(None)
        self.cluster.ensure_leader_health(301)

        replicas = self.cluster.get_all_replicas()
        # With node ids summing to 14, subtracting the occupied ones leaves
        # the id of the node that currently holds no replica.
        # NOTE(review): assumes a fixed 3-node topology — verify against
        # the compose file.
        empty_node_id = 14
        for replica in replicas:
            empty_node_id = empty_node_id - replica['node_id']
        self.spacetime("call", "spacetime-control", "drain_node", f"{cur_leader_node_id}", f"{empty_node_id}")

        time.sleep(5)
        self.cluster.ensure_leader_health(302)
        replicas = self.cluster.get_all_replicas()
        for replica in replicas:
            self.assertNotEqual(replica['node_id'], cur_leader_node_id)
class PreferLeader(ReplicationTest):
    def test_prefer_leader(self):
        """This test moves leader replica to different node"""
        self.add_me_as_admin()
        cur_leader_node_id = self.cluster.wait_for_leader_change(None)
        self.cluster.ensure_leader_health(401)

        # Pick any replica that is not on the current leader's node.
        replicas = self.cluster.get_all_replicas()
        prefer_replica = {}
        for replica in replicas:
            if replica['node_id'] != cur_leader_node_id:
                prefer_replica = replica
                break
        prefer_replica_id = prefer_replica['id']
        self.spacetime("call", "spacetime-control", "prefer_leader", f"{prefer_replica_id}")

        next_leader_node_id = self.cluster.wait_for_leader_change(cur_leader_node_id)
        self.cluster.ensure_leader_health(402)
        self.assertEqual(prefer_replica['node_id'], next_leader_node_id)

        # verify if all past rows are present in new leader
        stored_row_ids = [row['id'] for row in self.collect_counter_rows()]
        self.assertEqual(set(stored_row_ids), set([401, 402]))
class ManyTransactions(ReplicationTest):
    def test_a_many_transactions(self):
        """This test sends many messages to the database and verifies that they are all present"""
        self.cluster.wait_for_leader_change(None)
        num_messages = 10000
        sub = self.subscribe("SELECT * FROM counter", n = num_messages)
        self.start(1, num_messages)

        # Only the last subscription update matters: it must show the counter
        # reaching num_messages.
        message_table = sub()[-1:]
        self.assertIn({
            'counter': {
                'deletes': [{'id': 1, 'value': num_messages - 1}],
                'inserts': [{'id': 1, 'value': num_messages}]
            }
        }, message_table)
class QuorumLoss(ReplicationTest):
    def test_quorum_loss(self):
        """This test makes cluster to lose majority of followers to verify if leader eventually stop accepting writes"""
        # Prime the database while the cluster is healthy.
        for i in range(11):
            self.call("send_message", f"{i}")

        # Kill every worker container except the leader's, destroying quorum.
        leader = self.cluster.get_leader_info()
        containers = self.docker.list_containers()
        for container in containers:
            if leader['container_id'] != container.id and "worker" in container.name:
                self.docker.kill_container(container.id)

        time.sleep(2)
        # Without quorum the leader must eventually reject writes.
        with self.assertRaises(Exception):
            for i in range(1001):
                self.call("send_message", "terminal")
class EnableReplicationTest(ReplicationTest):
    """Base class for tests toggling replication on an existing database."""
    AUTOPUBLISH = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Accumulates the counter rows each run_counter call is expected
        # to have produced so far.
        self.expected_counter_rows = []

    def run_counter(self, id, n = 100):
        """Run counter `id` up to `n` and assert the table matches history."""
        self.start(id, n)
        self.cluster.wait_counter_value(id, n)
        self.expected_counter_rows.append({"id": id, "value": n})
        self.assertEqual(self.collect_counter_rows(), self.expected_counter_rows)

    def subscribe_to_enable_replication_events(self):
        # Watch the control database for this database's bootstrap progress.
        id = self.cluster.get_db_id()
        return self.subscribe(
            f"select * from staged_enable_replication_event where database_id={id}",
            n = 2,
            database = "spacetime-control"
        )

    def assert_bootstrap_complete(self, sub):
        # The final event reported must be the 'bootstrap complete' message.
        events = sub()
        self.assertEqual(
            events[-1]['staged_enable_replication_event']['inserts'][0]['message'],
            'bootstrap complete',
        )

    def enable_replication(self, database_name):
        """Enable 3-replica replication and wait for bootstrap to finish."""
        sub = self.subscribe_to_enable_replication_events()
        self.call_control("enable_replication", {"Name": database_name}, 3)
        self.assert_bootstrap_complete(sub)
class EnableReplicationUnsuspended(EnableReplicationTest):
    def test_enable_replication_fails_if_not_suspended(self):
        """Tests that the database to enable replication on must be suspended"""
        self.add_me_as_admin()
        name = random_string()
        self.publish_module(name, num_replicas = 1)
        self.cluster.wait_for_leader_change(None)
        # enable_replication on a live (unsuspended) database must be rejected.
        with self.assertRaises(Exception):
            self.call_control("enable_replication", {"Name": name}, 3)
class EnableReplicationSuspended(EnableReplicationTest):
    def test_enable_replication_on_suspended_database(self):
        """Tests that we can enable replication on a suspended database"""
        self.add_me_as_admin()
        name = random_string()
        self.publish_module(name, num_replicas = 1)
        self.cluster.wait_for_leader_change(None)
        self.cluster.ensure_leader_health(1)

        self.call_control("suspend_database", {"Name": name})
        # Database is now unreachable.
        with self.assertRaises(Exception):
            self.call("send_message", "hi")

        self.enable_replication(name)
        # Still unreachable until we call unsuspend.
        with self.assertRaises(Exception):
            self.call("send_message", "hi")

        self.call_control("unsuspend_database", {"Name": name})
        self.cluster.wait_for_leader_change(None)
        self.cluster.ensure_leader_health(2)
class EnableDisableReplication(EnableReplicationTest):
    def test_enable_disable_replication(self):
        """Tests that we can enable then disable replication"""
        self.add_me_as_admin()
        name = random_string()
        self.publish_module(name, num_replicas = 1)
        # ensure database is up and commitlog ends up non-empty
        self.run_counter(1, 100)
        # suspend first
        self.call_control("suspend_database", {"Name": name})
        # enable replication and wait for it to complete
        self.enable_replication(name)
        # unsuspend
        self.call_control("unsuspend_database", {"Name": name})
        self.cluster.wait_for_leader_change(None)
        # Writes must work both while replicated and after disabling again.
        self.run_counter(2, 100)
        self.call_control("disable_replication", {"Name": name})
        self.run_counter(3, 100)
-158
View File
@@ -1,158 +0,0 @@
import logging
from .. import Smoketest, random_string
class Rls(Smoketest):
    """Tests row-level security: each client only sees its own `users` row."""
    MODULE_CODE = """
use spacetimedb::{Identity, ReducerContext, Table};

#[spacetimedb::table(accessor = users, public)]
pub struct Users {
    name: String,
    identity: Identity,
}

#[spacetimedb::client_visibility_filter]
const USER_FILTER: spacetimedb::Filter = spacetimedb::Filter::Sql(
    "SELECT * FROM users WHERE identity = :sender"
);

#[spacetimedb::reducer]
pub fn add_user(ctx: &ReducerContext, name: String) {
    ctx.db.users().insert(Users { name, identity: ctx.sender() });
}
"""

    def test_rls_rules(self):
        """Tests for querying tables with RLS rules"""
        # Insert an identity for Alice
        self.call("add_user", "Alice")

        # Insert a new identity for Bob
        self.reset_config()
        self.new_identity()
        self.call("add_user", "Bob")

        # Query the users table using Bob's identity
        self.assertSql("SELECT name FROM users", """\
name
-------
"Bob"
""")
        # Query the users table using a new identity
        self.reset_config()
        self.new_identity()
        self.assertSql("SELECT name FROM users", """\
name
------
""")
class BrokenRls(Smoketest):
    """Publishing must reject RLS filters declared on private tables."""
    AUTOPUBLISH = False

    # Deliberately invalid module: the `user` table is private (no `public`),
    # so attaching a client visibility filter to it must fail at publish.
    MODULE_CODE_BROKEN = """
use spacetimedb::{client_visibility_filter, Filter};

#[spacetimedb::table(accessor = user)]
pub struct User {
    identity: Identity,
}

#[client_visibility_filter]
const PERSON_FILTER: Filter = Filter::Sql("SELECT * FROM \"user\" WHERE identity = :sender");
"""

    def test_publish_fails_for_rls_on_private_table(self):
        """This tests that publishing an RLS rule on a private table fails"""
        name = random_string()
        self.write_module_code(self.MODULE_CODE_BROKEN)
        with self.assertRaises(Exception):
            self.publish_module(name)
class DisconnectRls(Smoketest):
    """Republishing with changed RLS rules must disconnect connected clients;
    republishing with unchanged rules must not."""
    AUTOPUBLISH = False

    MODULE_CODE = """
use spacetimedb::{Identity, ReducerContext, Table};

#[spacetimedb::table(accessor = users, public)]
pub struct Users {
    name: String,
    identity: Identity,
}

#[spacetimedb::reducer]
pub fn add_user(ctx: &ReducerContext, name: String) {
    ctx.db.users().insert(Users { name, identity: ctx.sender() });
}
"""

    # RLS filter appended to MODULE_CODE to simulate an RLS change between
    # publishes.
    ADD_RLS = """
#[spacetimedb::client_visibility_filter]
const USER_FILTER: spacetimedb::Filter = spacetimedb::Filter::Sql(
    "SELECT * FROM users WHERE identity = :sender"
);
"""

    def assertSql(self, sql, expected):
        # Compare SQL output with trailing whitespace stripped per line.
        self.maxDiff = None
        sql_out = self.spacetime("sql", self.database_identity, sql)
        sql_out = "\n".join([line.rstrip() for line in sql_out.splitlines()])
        expected = "\n".join([line.rstrip() for line in expected.splitlines()])
        self.assertMultiLineEqual(sql_out, expected)

    def test_rls_disconnect_if_change(self):
        """This tests that changing the RLS rules disconnects existing clients"""
        name = random_string()
        self.write_module_code(self.MODULE_CODE)
        self.publish_module(name)
        logging.info("Initial publish complete")

        # Now add the RLS rules
        self.write_module_code(self.MODULE_CODE + self.ADD_RLS)
        self.publish_module(name, clear=False, break_clients=True)

        # Check the row-level SQL filter is added correctly
        self.assertSql(
            "SELECT sql FROM st_row_level_security",
            """\
sql
------------------------------------------------
"SELECT * FROM users WHERE identity = :sender"
""",
        )
        logging.info("Re-publish with RLS complete")

        logs = self.logs(100)
        # Validate disconnect + schema migration logs
        self.assertIn("Disconnecting all users", logs)

    def test_rls_no_disconnect(self):
        """This tests that not changing the RLS rules does not disconnect existing clients"""
        name = random_string()
        self.write_module_code(self.MODULE_CODE + self.ADD_RLS)
        self.publish_module(name)
        logging.info("Initial publish complete")

        # Now re-publish the same module code
        self.publish_module(name, clear=False, break_clients=False)
        logging.info("Re-publish without RLS change complete")

        logs = self.logs(100)
        # Validate no disconnect logs
        self.assertNotIn("Disconnecting all users", logs)
-275
View File
@@ -1,275 +0,0 @@
from .. import Smoketest
import time
class CancelReducer(Smoketest):
    """A scheduled reducer deleted before its due time must never run.

    `init` schedules two reducer invocations and cancels both — one via a
    direct row delete, one via the `do_cancel` reducer — so no "the reducer
    ran" log line should ever appear.
    """
    MODULE_CODE = """
use spacetimedb::{duration, log, ReducerContext, Table};

#[spacetimedb::reducer(init)]
fn init(ctx: &ReducerContext) {
    let schedule = ctx.db.scheduled_reducer_args().insert(ScheduledReducerArgs {
        num: 1,
        scheduled_id: 0,
        scheduled_at: duration!(100ms).into(),
    });
    ctx.db.scheduled_reducer_args().scheduled_id().delete(&schedule.scheduled_id);

    let schedule = ctx.db.scheduled_reducer_args().insert(ScheduledReducerArgs {
        num: 2,
        scheduled_id: 0,
        scheduled_at: duration!(1000ms).into(),
    });
    do_cancel(ctx, schedule.scheduled_id);
}

#[spacetimedb::table(accessor = scheduled_reducer_args, public, scheduled(reducer))]
pub struct ScheduledReducerArgs {
    #[primary_key]
    #[auto_inc]
    scheduled_id: u64,
    scheduled_at: spacetimedb::ScheduleAt,
    num: i32,
}

#[spacetimedb::reducer]
fn do_cancel(ctx: &ReducerContext, schedule_id: u64) {
    ctx.db.scheduled_reducer_args().scheduled_id().delete(&schedule_id);
}

#[spacetimedb::reducer]
fn reducer(_ctx: &ReducerContext, args: ScheduledReducerArgs) {
    log::info!("the reducer ran: {}", args.num);
}
"""

    def test_cancel_reducer(self):
        """Ensure cancelling a reducer works"""
        # Sleep past both schedule times (100ms and 1000ms) before checking.
        time.sleep(2)
        logs = "\n".join(self.logs(5))
        self.assertNotIn("the reducer ran", logs)
# JSON encoding of a SpacetimeDB Timestamp at the unix epoch, as it appears
# in subscription updates.
TIMESTAMP_ZERO = {"__timestamp_micros_since_unix_epoch__": 0}
class SubscribeScheduledTable(Smoketest):
    """Smoketest: subscription updates for rows of a scheduled-reducer table."""

    MODULE_CODE = """
use spacetimedb::{log, duration, ReducerContext, Table, Timestamp};
#[spacetimedb::table(accessor = scheduled_table, public, scheduled(my_reducer, at = sched_at))]
pub struct ScheduledTable {
#[primary_key]
#[auto_inc]
scheduled_id: u64,
sched_at: spacetimedb::ScheduleAt,
prev: Timestamp,
}
#[spacetimedb::reducer]
fn schedule_reducer(ctx: &ReducerContext) {
ctx.db.scheduled_table().insert(ScheduledTable { prev: Timestamp::from_micros_since_unix_epoch(0), scheduled_id: 2, sched_at: Timestamp::from_micros_since_unix_epoch(0).into(), });
}
#[spacetimedb::reducer]
fn schedule_repeated_reducer(ctx: &ReducerContext) {
ctx.db.scheduled_table().insert(ScheduledTable { prev: Timestamp::from_micros_since_unix_epoch(0), scheduled_id: 1, sched_at: duration!(100ms).into(), });
}
#[spacetimedb::reducer]
pub fn my_reducer(ctx: &ReducerContext, arg: ScheduledTable) {
log::info!("Invoked: ts={:?}, delta={:?}", ctx.timestamp, ctx.timestamp.duration_since(arg.prev));
}
"""

    def test_scheduled_table_subscription(self):
        """This test deploys a module with a scheduled reducer and check if client receives subscription update for scheduled table entry and deletion of reducer once it ran"""
        # subscribe to empty scheduled_table
        sub = self.subscribe("SELECT * FROM scheduled_table", n=2)
        # call a reducer to schedule a reducer
        self.call("schedule_reducer")
        time.sleep(2)
        lines = sum(1 for line in self.logs(100) if "Invoked:" in line)
        # scheduled reducer should have run exactly once by now
        self.assertEqual(lines, 1)
        row_entry = {
            "prev": TIMESTAMP_ZERO,
            "scheduled_id": 2,
            "sched_at": {"Time": TIMESTAMP_ZERO},
        }
        # subscription should have 2 updates, first for row insert in scheduled table and second for row deletion.
        self.assertEqual(
            sub(),
            [
                {"scheduled_table": {"deletes": [], "inserts": [row_entry]}},
                {"scheduled_table": {"deletes": [row_entry], "inserts": []}},
            ],
        )

    def test_scheduled_table_subscription_repeated_reducer(self):
        """This test deploys a module with a repeated reducer and check if client receives subscription update for scheduled table entry and no delete entry"""
        # subscribe to empty scheduled_table
        sub = self.subscribe("SELECT * FROM scheduled_table", n=2)
        # call a reducer to schedule a reducer
        self.call("schedule_repeated_reducer")
        time.sleep(2)
        lines = sum(1 for line in self.logs(100) if "Invoked:" in line)
        # repeated reducer should have run more than twice (assertLess(2, lines) == 2 < lines).
        self.assertLess(2, lines)
        # scheduling repeated reducer again just to get 2nd subscription update.
        self.call("schedule_reducer")
        repeated_row_entry = {
            "prev": TIMESTAMP_ZERO,
            "scheduled_id": 1,
            "sched_at": {"Interval": {"__time_duration_micros__": 100000}},
        }
        row_entry = {
            "prev": TIMESTAMP_ZERO,
            "scheduled_id": 2,
            "sched_at": {"Time": TIMESTAMP_ZERO},
        }
        # subscription should have 2 updates and should not have any deletes
        self.assertEqual(
            sub(),
            [
                {"scheduled_table": {"deletes": [], "inserts": [repeated_row_entry]}},
                {"scheduled_table": {"deletes": [], "inserts": [row_entry]}},
            ],
        )
class SubscribeScheduledProcedureTable(Smoketest):
    """Same as SubscribeScheduledTable, but the scheduled callee is a procedure."""

    MODULE_CODE = """
use spacetimedb::{log, duration, ReducerContext, ProcedureContext, Table, Timestamp};
#[spacetimedb::table(accessor = scheduled_table, public, scheduled(my_procedure, at = sched_at))]
pub struct ScheduledTable {
#[primary_key]
#[auto_inc]
scheduled_id: u64,
sched_at: spacetimedb::ScheduleAt,
prev: Timestamp,
}
#[spacetimedb::reducer]
fn schedule_procedure(ctx: &ReducerContext) {
ctx.db.scheduled_table().insert(ScheduledTable { prev: Timestamp::from_micros_since_unix_epoch(0), scheduled_id: 2, sched_at: Timestamp::from_micros_since_unix_epoch(0).into(), });
}
#[spacetimedb::reducer]
fn schedule_repeated_procedure(ctx: &ReducerContext) {
ctx.db.scheduled_table().insert(ScheduledTable { prev: Timestamp::from_micros_since_unix_epoch(0), scheduled_id: 1, sched_at: duration!(100ms).into(), });
}
#[spacetimedb::procedure]
pub fn my_procedure(ctx: &mut ProcedureContext, arg: ScheduledTable) {
log::info!("Invoked: ts={:?}, delta={:?}", ctx.timestamp, ctx.timestamp.duration_since(arg.prev));
}
"""

    def test_scheduled_table_subscription(self):
        """This test deploys a module with a scheduled procedure and check if client receives subscription update for scheduled table entry and deletion of procedure once it ran"""
        # subscribe to empty scheduled_table
        sub = self.subscribe("SELECT * FROM scheduled_table", n=2)
        # call a reducer to schedule a procedure
        self.call("schedule_procedure")
        time.sleep(2)
        lines = sum(1 for line in self.logs(100) if "Invoked:" in line)
        # scheduled procedure should have run exactly once by now
        self.assertEqual(lines, 1)
        row_entry = {
            "prev": TIMESTAMP_ZERO,
            "scheduled_id": 2,
            "sched_at": {"Time": TIMESTAMP_ZERO},
        }
        # subscription should have 2 updates, first for row insert in scheduled table and second for row deletion.
        self.assertEqual(
            sub(),
            [
                {"scheduled_table": {"deletes": [], "inserts": [row_entry]}},
                {"scheduled_table": {"deletes": [row_entry], "inserts": []}},
            ],
        )

    def test_scheduled_table_subscription_repeated_procedure(self):
        """This test deploys a module with a repeated procedure and check if client receives subscription update for scheduled table entry and no delete entry"""
        # subscribe to empty scheduled_table
        sub = self.subscribe("SELECT * FROM scheduled_table", n=2)
        # call a reducer to schedule a procedure
        self.call("schedule_repeated_procedure")
        time.sleep(2)
        lines = sum(1 for line in self.logs(100) if "Invoked:" in line)
        # repeated procedure should have run more than twice (assertLess(2, lines) == 2 < lines).
        self.assertLess(2, lines)
        # scheduling repeated procedure again just to get 2nd subscription update.
        self.call("schedule_procedure")
        repeated_row_entry = {
            "prev": TIMESTAMP_ZERO,
            "scheduled_id": 1,
            "sched_at": {"Interval": {"__time_duration_micros__": 100000}},
        }
        row_entry = {
            "prev": TIMESTAMP_ZERO,
            "scheduled_id": 2,
            "sched_at": {"Time": TIMESTAMP_ZERO},
        }
        # subscription should have 2 updates and should not have any deletes
        self.assertEqual(
            sub(),
            [
                {"scheduled_table": {"deletes": [], "inserts": [repeated_row_entry]}},
                {"scheduled_table": {"deletes": [], "inserts": [row_entry]}},
            ],
        )
class VolatileNonatomicScheduleImmediate(Smoketest):
    """Smoketest for the `volatile_nonatomic_schedule_immediate!` macro."""

    # The macro is behind the bindings' "unstable" feature flag.
    BINDINGS_FEATURES = ["unstable"]

    MODULE_CODE = """
use spacetimedb::{ReducerContext, Table};
#[spacetimedb::table(accessor = my_table, public)]
pub struct MyTable {
x: String,
}
#[spacetimedb::reducer]
fn do_schedule(_ctx: &ReducerContext) {
spacetimedb::volatile_nonatomic_schedule_immediate!(do_insert("hello".to_owned()));
}
#[spacetimedb::reducer]
fn do_insert(ctx: &ReducerContext, x: String) {
ctx.db.my_table().insert(MyTable { x });
}
"""

    def test_volatile_nonatomic_schedule_immediate(self):
        """Check that volatile_nonatomic_schedule_immediate works"""
        sub = self.subscribe("SELECT * FROM my_table", n=2)
        # Direct insert first, then an insert scheduled from inside `do_schedule`;
        # both must arrive as separate subscription updates, in order.
        self.call("do_insert", "yay!")
        self.call("do_schedule")
        self.assertEqual(
            sub(),
            [
                {"my_table": {"deletes": [], "inserts": [{"x": "yay!"}]}},
                {"my_table": {"deletes": [], "inserts": [{"x": "hello"}]}},
            ],
        )
-36
View File
@@ -1,36 +0,0 @@
from .. import Smoketest, extract_field, requires_local_server
import re
# We require a local server because these tests have hardcoded server addresses.
@requires_local_server
class Servers(Smoketest):
    """CLI smoketests for `spacetime server` add / list / fingerprint / edit."""

    AUTOPUBLISH = False

    def test_servers(self):
        """Verify that we can add and list server configurations"""
        out = self.spacetime("server", "add", "--url", "https://testnet.spacetimedb.com", "testnet", "--no-fingerprint")
        self.assertEqual(extract_field(out, "Host:"), "testnet.spacetimedb.com")
        self.assertEqual(extract_field(out, "Protocol:"), "https")
        servers = self.spacetime("server", "list")
        self.assertRegex(servers, re.compile(r"^\s*testnet\.spacetimedb\.com\s+https\s+testnet\s*$", re.M))
        self.assertRegex(servers, re.compile(r"^\s*\*\*\*\s+127\.0\.0\.1:3000\s+http\s+localhost\s*$", re.M))
        # First fingerprint fetch saves it; subsequent fetches report no change,
        # whether the server is addressed by URL or by nickname.
        out = self.spacetime("server", "fingerprint", "http://127.0.0.1:3000", "-y")
        self.assertIn("No saved fingerprint for server 127.0.0.1:3000.", out)
        out = self.spacetime("server", "fingerprint", "http://127.0.0.1:3000")
        self.assertIn("Fingerprint is unchanged for server 127.0.0.1:3000", out)
        out = self.spacetime("server", "fingerprint", "localhost")
        self.assertIn("Fingerprint is unchanged for server localhost", out)

    def test_edit_server(self):
        """Verify that we can edit server configurations"""
        out = self.spacetime("server", "add", "--url", "https://foo.com", "foo", "--no-fingerprint")
        out = self.spacetime("server", "edit", "foo", "--url", "https://edited-testnet.spacetimedb.com", "--new-name", "edited-testnet", "--no-fingerprint", "--yes")
        servers = self.spacetime("server", "list")
        self.assertRegex(servers, re.compile(r"^\s*edited-testnet\.spacetimedb\.com\s+https\s+edited-testnet\s*$", re.M))
-174
View File
@@ -1,174 +0,0 @@
from .. import Smoketest
class SqlFormat(Smoketest):
    """Smoketest pinning the exact textual formatting of `spacetime sql` output
    for every SATS value category: signed/unsigned ints (up to 256-bit),
    floats, strings, bytes, identities, timestamps, durations, UUIDs, options,
    results, enums, and nested tuples of each.
    """

    MODULE_CODE = """
use spacetimedb::sats::{i256, u256};
use spacetimedb::{ConnectionId, Identity, ReducerContext, Table, Timestamp, TimeDuration, SpacetimeType, Uuid};
#[derive(Copy, Clone)]
#[spacetimedb::table(accessor = t_ints)]
pub struct TInts {
i8: i8,
i16: i16,
i32: i32,
i64: i64,
i128: i128,
i256: i256,
}
#[spacetimedb::table(accessor = t_ints_tuple)]
pub struct TIntsTuple {
tuple: TInts,
}
#[derive(Copy, Clone)]
#[spacetimedb::table(accessor = t_uints)]
pub struct TUints {
u8: u8,
u16: u16,
u32: u32,
u64: u64,
u128: u128,
u256: u256,
}
#[spacetimedb::table(accessor = t_uints_tuple)]
pub struct TUintsTuple {
tuple: TUints,
}
#[derive(Clone)]
#[spacetimedb::table(accessor = t_others)]
pub struct TOthers {
bool: bool,
f32: f32,
f64: f64,
str: String,
bytes: Vec<u8>,
identity: Identity,
connection_id: ConnectionId,
timestamp: Timestamp,
duration: TimeDuration,
uuid: Uuid,
}
#[spacetimedb::table(accessor = t_others_tuple)]
pub struct TOthersTuple {
tuple: TOthers
}
#[derive(SpacetimeType, Debug, Clone, Copy)]
pub enum Action {
Inactive,
Active,
}
#[derive(Clone)]
#[spacetimedb::table(accessor = t_enums)]
pub struct TEnums {
bool_opt: Option<bool>,
bool_result: Result<bool, String>,
action: Action,
}
#[spacetimedb::table(accessor = t_enums_tuple)]
pub struct TEnumsTuple {
tuple: TEnums,
}
#[spacetimedb::reducer]
pub fn test(ctx: &ReducerContext) {
let tuple = TInts {
i8: -25,
i16: -3224,
i32: -23443,
i64: -2344353,
i128: -234434897853,
i256: (-234434897853i128).into(),
};
ctx.db.t_ints().insert(tuple);
ctx.db.t_ints_tuple().insert(TIntsTuple { tuple });
let tuple = TUints {
u8: 105,
u16: 1050,
u32: 83892,
u64: 48937498,
u128: 4378528978889,
u256: 4378528978889u128.into(),
};
ctx.db.t_uints().insert(tuple);
ctx.db.t_uints_tuple().insert(TUintsTuple { tuple });
let tuple = TOthers {
bool: true,
f32: 594806.58906,
f64: -3454353.345389043278459,
str: "This is spacetimedb".to_string(),
bytes: vec!(1, 2, 3, 4, 5, 6, 7),
identity: Identity::ONE,
connection_id: ConnectionId::ZERO,
timestamp: Timestamp::UNIX_EPOCH,
duration: TimeDuration::ZERO,
uuid: Uuid::NIL,
};
ctx.db.t_others().insert(tuple.clone());
ctx.db.t_others_tuple().insert(TOthersTuple { tuple });
let tuple = TEnums {
bool_opt: Some(true),
bool_result: Ok(false),
action: Action::Active,
};
ctx.db.t_enums().insert(tuple.clone());
ctx.db.t_enums_tuple().insert(TEnumsTuple { tuple });
}
"""

    def test_sql_format(self):
        """This test is designed to test the format of the output of sql queries"""
        self.call("test")
        self.assertSql("SELECT * FROM t_ints", """\
i_8 | i_16 | i_32 | i_64 | i_128 | i_256
-----+-------+--------+----------+---------------+---------------
-25 | -3224 | -23443 | -2344353 | -234434897853 | -234434897853
""")
        self.assertSql("SELECT * FROM t_ints_tuple", """\
tuple
---------------------------------------------------------------------------------------------------------
(i_8 = -25, i_16 = -3224, i_32 = -23443, i_64 = -2344353, i_128 = -234434897853, i_256 = -234434897853)
""")
        self.assertSql("SELECT * FROM t_uints", """\
u_8 | u_16 | u_32 | u_64 | u_128 | u_256
-----+------+-------+----------+---------------+---------------
105 | 1050 | 83892 | 48937498 | 4378528978889 | 4378528978889
""")
        self.assertSql("SELECT * FROM t_uints_tuple", """\
tuple
-------------------------------------------------------------------------------------------------------
(u_8 = 105, u_16 = 1050, u_32 = 83892, u_64 = 48937498, u_128 = 4378528978889, u_256 = 4378528978889)
""")
        # Note the f32/f64 values here reflect float precision loss on round-trip
        # (594806.58906 -> 594806.56, etc.).
        self.assertSql("SELECT * FROM t_others", """\
bool | f_32 | f_64 | str | bytes | identity | connection_id | timestamp | duration | uuid
------+-----------+--------------------+-----------------------+------------------+--------------------------------------------------------------------+------------------------------------+---------------------------+-----------+----------------------------------------
true | 594806.56 | -3454353.345389043 | "This is spacetimedb" | 0x01020304050607 | 0x0000000000000000000000000000000000000000000000000000000000000001 | 0x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | +0.000000 | "00000000-0000-0000-0000-000000000000"
""")
        self.assertSql("SELECT * FROM t_others_tuple", """\
tuple
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(bool = true, f_32 = 594806.56, f_64 = -3454353.345389043, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000001, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = +0.000000, uuid = "00000000-0000-0000-0000-000000000000")
""")
        self.assertSql("SELECT * FROM t_enums", """\
bool_opt | bool_result | action
---------------+--------------+---------------
(some = true) | (ok = false) | (active = ())
""")
        self.assertSql("SELECT * FROM t_enums_tuple", """\
tuple
--------------------------------------------------------------------------------
(bool_opt = (some = true), bool_result = (ok = false), action = (active = ()))
""")
-558
View File
@@ -1,558 +0,0 @@
import json
import toml
from .. import COMPOSE_FILE, Smoketest, parse_sql_result, random_string, spacetime
from ..docker import DockerManager
from ..tests.replication import Cluster
# Collaborator / organization-member role names, ordered from most to least
# privileged, as understood by the `spacetime-control` database.
OWNER = "Owner"
ADMIN = "Admin"
DEVELOPER = "Developer"
VIEWER = "Viewer"
ROLES = [OWNER, ADMIN, DEVELOPER, VIEWER]
def get(d: dict, k):
    """Return the (key, value) pair for *k* in *d*; raises KeyError if absent."""
    value = d[k]
    return (k, value)
class CreateChildDatabase(Smoketest):
    """Smoketest for parent/child database lifetimes."""

    AUTOPUBLISH = False

    def test_create_child_database(self):
        """
        Test that the owner can add a child database,
        and that deleting the parent also deletes the child.
        """
        parent_name = random_string()
        child_name = random_string()
        self.publish_module(parent_name)
        parent_identity = self.database_identity
        # "parent/child" publishes `child_name` as a child of `parent_name`.
        self.publish_module(f"{parent_name}/{child_name}")
        child_identity = self.database_identity
        databases = self.query_controldb(parent_identity, child_identity)
        self.assertEqual(2, len(databases))
        self.spacetime("delete", "--yes", parent_name)
        # Deleting the parent must cascade to the child: no rows remain.
        databases = self.query_controldb(parent_identity, child_identity)
        self.assertEqual(0, len(databases))

    def query_controldb(self, parent, child):
        # Fetch the control-database rows for both identities (if they exist).
        res = self.spacetime(
            "sql",
            "spacetime-control",
            f"select * from database where database_identity = 0x{parent} or database_identity = 0x{child}"
        )
        return parse_sql_result(str(res))
class ChangeDatabaseHierarchy(Smoketest):
    """Smoketest: re-parenting an existing database is rejected."""

    AUTOPUBLISH = False

    def test_change_database_hierarchy(self):
        """
        Test that changing the hierarchy of an existing database is not
        supported.
        """
        parent_name = f"parent-{random_string()}"
        sibling_name = f"sibling-{random_string()}"
        child_name = f"child-{random_string()}"
        self.publish_module(parent_name)
        self.publish_module(sibling_name)
        # Publish as a child of 'parent_name'.
        self.publish_module(f"{parent_name}/{child_name}")
        # Publishing again with a different parent is rejected...
        with self.assertRaises(Exception):
            self.publish_module(f"{sibling_name}/{child_name}", clear = False)
        # ..even if `clear = True`
        with self.assertRaises(Exception):
            self.publish_module(f"{sibling_name}/{child_name}", clear = True)
        # Publishing again with the same parent is ok.
        self.publish_module(f"{parent_name}/{child_name}", clear = False)
class TeamsPermissionsTest(Smoketest):
    """Shared base for collaborator/organization permission tests.

    Provides helpers to mint identities, create collaborators and organization
    members for each role, switch the CLI login between identities, and run
    publish/SQL/subscribe operations as a given (role, token) pair.
    """

    AUTOPUBLISH = False

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # A second config file that stays logged in as root, so tests can
        # switch `self.config_path` between member identities freely.
        cls.root_config = cls.project_path / "root_config"
        spacetime("--config-path", cls.root_config, "server", "set-default", "local")

    def setUp(self):
        self.docker = DockerManager(COMPOSE_FILE)
        self.root_token = self.docker.generate_root_token()
        self.cluster = Cluster(self.docker, self)

    def create_identity(self):
        """
        Obtain a fresh identity and token from the server.
        Doesn't alter the config.toml for this test instance.
        """
        resp = self.api_call("POST", "/v1/identity")
        return json.loads(resp)

    def create_collaborators(self, database):
        """
        Create collaborators for the current database, one for each role.

        Returns a dict mapping role name -> {'identity', 'token'}.
        """
        collaborators = {}
        for role in ROLES:
            identity_and_token = self.create_identity()
            self.call_controldb_reducer(
                "upsert_collaborator",
                {"Name": database},
                [f"0x{identity_and_token['identity']}"],
                {role: {}}
            )
            collaborators[role] = identity_and_token
        return collaborators

    def create_organization(self):
        """
        Create an organization with one member per role.

        Returns {'organization': <identity>, 'members': {role: {'identity', 'token'}}}
        and stashes it on `self.organization` for tearDown.
        """
        members = {}
        organization_identity = self.create_identity()['identity']
        for role in ROLES:
            member = self.create_identity()
            self.call_controldb_reducer(
                "upsert_organization_member",
                [f"0x{organization_identity}"],
                [f"0x{member['identity']}"],
                {role: {}}
            )
            members[role] = member
        organization = {
            "organization": organization_identity,
            "members": members
        }
        self.organization = organization
        return organization

    def make_admin(self):
        """
        Create an admin account for the currently logged-in identity.
        """
        identity = str(self.spacetime("login", "show")).split()[-1]
        spacetime("--config-path", self.root_config, "login", "--token", self.root_token)
        spacetime("--config-path", self.root_config, "call",
                  "spacetime-control", "create_admin_account", f"0x{identity}")

    def call_controldb_reducer(self, reducer, *args):
        """
        Call a controldb reducer.
        """
        self.spacetime("call", "spacetime-control", reducer, *map(json.dumps, args))

    def login_with(self, identity_and_token: dict):
        # Replace the stored CLI token so subsequent commands act as this identity.
        self.spacetime("logout")
        config = toml.load(self.config_path)
        config['spacetimedb_token'] = identity_and_token['token']
        with open(self.config_path, 'w') as f:
            toml.dump(config, f)

    def publish_as(self, role_and_token, module, code = None, clear = False, org = None):
        # Log in as the given (role, token) pair and publish `code` to `module`.
        # Returns the resulting database identity.
        print(f"publishing {module} with org {org} as {role_and_token[0]}:")
        code = self.MODULE_CODE if code is None else code
        print(f"{code}")
        self.login_with(role_and_token[1])
        self.write_module_code(code)
        self.publish_module(module, clear = clear, organization = org)
        return self.database_identity

    def sql_as(self, role_and_token, database, sql):
        """
        Log in as `token` and run an SQL statement against `database`
        """
        print(f"running sql as {role_and_token[0]}: {sql}")
        self.login_with(role_and_token[1])
        res = self.spacetime("sql", database, sql)
        return parse_sql_result(str(res))

    def subscribe_as(self, role_and_token, *queries, n):
        """
        Log in as `token` and subscribe to the current database using `queries`.
        """
        print(f"subscribe as {role_and_token[0]}: {queries}")
        self.login_with(role_and_token[1])
        return self.subscribe(*queries, n = n)

    def tearDown(self):
        if "organization" in self.__dict__:
            # Log in as owner
            self.login_with(self.organization['members'][OWNER])
            # Delete database (requires org to still exist)
            super().tearDown()
            # Delete org
            try:
                self.call_controldb_reducer(
                    "delete_organization",
                    [f"0x{self.organization['organization']}"]
                )
            except Exception:
                # Best-effort cleanup; the org may already be gone.
                pass
        else:
            super().tearDown()
class TeamsMutableSql(TeamsPermissionsTest):
    """Shared body for tests that only OWNER/ADMIN may run mutating SQL."""

    MODULE_CODE = """
#[spacetimedb::table(accessor = person, public)]
struct Person {
name: String,
}
"""

    def run_test(self, database, team):
        # `team` maps role name -> identity/token; each role attempts an INSERT.
        for role, token in team.items():
            self.login_with(token)
            dml = f"insert into person (name) values ('bob-the-{role}')"
            if role == OWNER or role == ADMIN:
                # Privileged roles succeed.
                self.spacetime("sql", database, dml)
            else:
                # DEVELOPER and VIEWER are rejected.
                with self.assertRaises(Exception):
                    self.spacetime("sql", database, dml)
class CollaboratorsMutableSql(TeamsMutableSql):
    def test_permissions_mut_sql_collaborators(self):
        """
        Tests that only owner and admin collaborators can perform mutable SQL
        transactions.
        """
        name = random_string()
        self.publish_module(name)
        team = self.create_collaborators(name)
        self.run_test(name, team)
class OrgMutableSql(TeamsMutableSql):
    def test_org_permissions_mut_sql_org_members(self):
        """
        Tests that only owner and admin organization members can perform mutable
        SQL transactions.
        """
        self.make_admin()
        org = self.create_organization()
        name = random_string()
        # Publish under the organization as its OWNER member.
        self.login_with(org['members'][OWNER])
        self.publish_module(name, organization = f"0x{org['organization']}")
        self.run_test(name, org['members'])
class TeamsPublishDatabase(TeamsPermissionsTest):
    """Shared body for tests of who may publish/update parent and child databases."""

    MODULE_CODE = """
#[spacetimedb::table(accessor = person, public)]
struct Person {
name: String,
}
"""
    # Each role gets a strictly-larger module so every successful publish is a
    # genuine schema update.
    MODULE_CODE_OWNER = MODULE_CODE + """
#[spacetimedb::table(accessor = owner)]
struct Owner {
name: String,
}
"""
    MODULE_CODE_ADMIN = MODULE_CODE_OWNER + """
#[spacetimedb::table(accessor = admin)]
struct Admin {
name: String,
}
"""
    MODULE_CODE_DEVELOPER = MODULE_CODE_ADMIN + """
#[spacetimedb::table(accessor = developer)]
struct Developer {
name: String,
}
"""
    MODULE_CODE_VIEWER = MODULE_CODE_DEVELOPER + """
#[spacetimedb::table(accessor = viewer)]
struct Viewer {
name: String,
}
"""
    # Role -> module variant that role attempts to publish.
    MODULES = {
        OWNER: MODULE_CODE_OWNER,
        ADMIN: MODULE_CODE_ADMIN,
        DEVELOPER: MODULE_CODE_DEVELOPER,
        VIEWER: MODULE_CODE_VIEWER
    }

    def run_test(self, parent, child, team, org):
        self.assert_all_except_viewer_can_update(parent, team, org = org)
        # Create a child database.
        child_path = f"{parent}/{child}"
        # Developer and viewer should not be able to create a child.
        for role in [DEVELOPER, VIEWER]:
            with self.assertRaises(Exception):
                self.publish_as(get(team, role), child_path, self.MODULE_CODE, org = org)
        # But admin should succeed.
        self.publish_as(get(team, ADMIN), child_path, self.MODULE_CODE, org = org)
        # Once created, only viewer should be denied updating.
        self.assert_all_except_viewer_can_update(child_path, team, org)

    def assert_all_except_viewer_can_update(self, database, team, org):
        # OWNER/ADMIN/DEVELOPER may republish; VIEWER must be rejected.
        for role in [OWNER, ADMIN, DEVELOPER]:
            self.publish_as(get(team, role), database, self.MODULES[role], org = org)
        with self.assertRaises(Exception):
            self.publish_as(get(team, VIEWER), database, self.MODULES[VIEWER], org = org)
class CollaboratorsPublishDatabase(TeamsPublishDatabase):
    def test_permissions_publish_collaborators(self):
        """
        Tests that only owner, admin and developer collaborators can publish a
        database.
        """
        parent = random_string()
        child = random_string()
        self.publish_module(parent)
        team = self.create_collaborators(parent)
        self.run_test(parent, child, team, org = None)
class OrgPublishDatabase(TeamsPublishDatabase):
    def test_permissions_publish_org_members(self):
        """
        Tests that only owner, admin and developer organization members can
        publish a database.
        """
        self.make_admin()
        org = self.create_organization()
        parent = random_string()
        child = random_string()
        # Publish the parent under the organization as its OWNER member.
        self.login_with(org['members'][OWNER])
        self.publish_module(parent, organization = f"0x{org['organization']}")
        self.run_test(parent, child, org['members'],
                      org = org['organization'])
class TeamsClearDatabase(TeamsPermissionsTest):
    """Shared assertions for who may publish with `--clear`."""

    def assert_can_clear(self, auth, database):
        self.publish_as(auth, database, clear = True)

    def assert_cannot_clear(self, auth, database):
        with self.assertRaises(Exception):
            self.publish_as(auth, database, clear = True)

    def assert_clear_permissions(self, team, database):
        # Only OWNER and ADMIN may clear; DEVELOPER and VIEWER are rejected.
        for role in [OWNER, ADMIN]:
            self.assert_can_clear(get(team, role), database)
        for role in [DEVELOPER, VIEWER]:
            self.assert_cannot_clear(get(team, role), database)
class CollaboratorsClearDatabase(TeamsClearDatabase):
    def test_permissions_clear_collaborators(self):
        """
        Tests that only owner and admin collaborators can clear a database.
        """
        parent = random_string()
        self.publish_module(parent)
        # First degree owner can clear.
        self.publish_module(parent, clear = True)
        team = self.create_collaborators(parent)
        self.assert_clear_permissions(team, parent)
        # Child databases cannot be cleared at all
        child = f"{parent}/{random_string()}"
        self.publish_as(get(team, OWNER), child)
        for auth in team.items():
            self.assert_cannot_clear(auth, child)
class OrgClearDatabase(TeamsClearDatabase):
    def test_permissions_clear_org(self):
        """
        Test that only owner or admin org members can clear a database.
        """
        self.make_admin()
        org = self.create_organization()
        team = org['members']
        parent = random_string()
        self.login_with(org['members'][OWNER])
        self.publish_module(parent, organization = f"0x{org['organization']}")
        self.assert_clear_permissions(team, parent)
        # Child databases cannot be cleared at all
        child = f"{parent}/{random_string()}"
        self.publish_as(get(team, ADMIN), child)
        for auth in team.items():
            self.assert_cannot_clear(auth, child)
class TeamsDeleteDatabase(TeamsPermissionsTest):
    """Shared helper for deletion-permission tests."""

    def delete_as(self, role_and_token, database):
        # Log in as the given (role, token) pair and delete `database`.
        print(f"delete {database} as {role_and_token[0]}")
        self.login_with(role_and_token[1])
        self.spacetime("delete", "--yes", database)
class CollaboratorsDeleteDatabase(TeamsDeleteDatabase):
    def test_permissions_delete_collaborators(self):
        """
        Tests that only owners can delete databases.
        """
        parent = random_string()
        child = random_string()
        self.publish_module(parent)
        self.spacetime("delete", "--yes", parent)
        self.publish_module(parent)
        team = self.create_collaborators(parent)
        # Non-owner collaborators may not delete the parent.
        for role in [ADMIN, DEVELOPER, VIEWER]:
            with self.assertRaises(Exception):
                self.delete_as(get(team, role), parent)
        child_path = f"{parent}/{child}"
        # If admin creates a child, they should also be able to delete it,
        # because they are the owner of the child.
        print("publish and delete as admin")
        self.publish_as(get(team, ADMIN), child_path)
        self.delete_as(get(team, ADMIN), child)
        # The owner role should be able to delete.
        print("publish as admin, delete as owner")
        self.publish_as(get(team, ADMIN), child_path)
        self.delete_as(get(team, OWNER), child)
        # Anyone else should be denied if not direct owner.
        print("publish as owner, deny deletion by admin, developer, viewer")
        self.publish_as(get(team, OWNER), child_path)
        for role in [ADMIN, DEVELOPER, VIEWER]:
            with self.assertRaises(Exception):
                self.delete_as(get(team, role), child)
        print("delete child as owner")
        self.delete_as(get(team, OWNER), child)
        print("delete parent as owner")
        self.delete_as(get(team, OWNER), parent)
class OrgDeleteDatabase(TeamsDeleteDatabase):
    def test_permissions_delete_org(self):
        """
        Tests that only organization owners can delete databases.
        """
        self.make_admin()
        org = self.create_organization()
        team = org['members']
        parent = random_string()
        child = random_string()
        self.login_with(org['members'][OWNER])
        self.publish_module(parent, organization = f"0x{org['organization']}")
        self.publish_module(f"{parent}/{child}")
        # Org databases can only be deleted by owners
        # because ownership is transferred to the org
        # and publisher attribution is lost.
        for database in [child, parent]:
            for role in [ADMIN, DEVELOPER, VIEWER]:
                with self.assertRaises(Exception):
                    self.delete_as(get(team, role), database)
        self.delete_as(get(team, OWNER), child)
        self.delete_as(get(team, OWNER), parent)
class TeamsPrivateTables(TeamsPermissionsTest):
    """Shared body: every team member can read private tables via SQL and subscriptions."""

    def run_test(self, database, team):
        owner = get(team, OWNER)
        self.sql_as(owner, database, "insert into person (name) values ('horsti')")
        # Every role can read the (private) table via SQL.
        for auth in team.items():
            rows = self.sql_as(auth, database, "select * from person")
            self.assertEqual(rows, [{ "name": '"horsti"' }])
        # Every role can subscribe and sees both the insert and the delete.
        for auth in team.items():
            sub = self.subscribe_as(auth, "select * from person", n = 2)
            self.sql_as(owner, database, "insert into person (name) values ('hansmans')")
            self.sql_as(owner, database, "delete from person where name = 'hansmans'")
            res = sub()
            self.assertEqual(
                res,
                [
                    {
                        'person': {
                            'deletes': [],
                            'inserts': [{'name': 'hansmans'}]
                        }
                    },
                    {
                        'person': {
                            'deletes': [{'name': 'hansmans'}],
                            'inserts': []
                        }
                    }
                ],
            )
class CollaboratorsPrivateTables(TeamsPrivateTables):
    def test_permissions_private_tables(self):
        """
        Test that all collaborators can read private tables.
        """
        database = random_string()
        self.publish_module(database)
        team = self.create_collaborators(database)
        self.run_test(database, team)
class OrgPrivateTables(TeamsPrivateTables):
    def test_org_permissions_private_tables(self):
        """
        Test that all organization members can read private tables.
        """
        self.make_admin()
        org = self.create_organization()
        database = random_string()
        self.login_with(org['members'][OWNER])
        self.publish_module(database, organization = f"0x{org['organization']}")
        self.run_test(database, org['members'])
-35
View File
@@ -1,35 +0,0 @@
from .. import Smoketest, random_string
import unittest
import json
import io
TIMESTAMP_TAG = "__timestamp_micros_since_unix_epoch__"
class TimestampRoute(Smoketest):
    """Smoketest for the `/v1/database/{name}/unstable/timestamp` HTTP route."""

    # Fix: this was spelled `AUTO_PUBLISH`, which no other class in this suite
    # uses — every other manual-publish test sets `AUTOPUBLISH = False`
    # (e.g. Servers, CreateChildDatabase, TeamsPermissionsTest, FailPublish),
    # so the misspelled attribute would not suppress auto-publishing.
    AUTOPUBLISH = False

    def test_timestamp_route(self):
        """The route 404s for an unknown database and returns a SATS-JSON
        `Timestamp` once the database exists."""
        name = random_string()
        # A request for the timestamp at a non-existent database is an error...
        with self.assertRaises(Exception) as err:
            self.api_call(
                "GET",
                f"/v1/database/{name}/unstable/timestamp",
            )
        # ... with code 404.
        self.assertEqual(err.exception.args[0].status, 404)
        self.publish_module(name)
        # A request for the timestamp at an extant database is a success...
        resp = self.api_call(
            "GET",
            f"/v1/database/{name}/unstable/timestamp",
        )
        # ... and the response body is a SATS-JSON encoded `Timestamp`.
        timestamp = json.load(io.BytesIO(resp))
        self.assertIsInstance(timestamp, dict)
        self.assertIn(TIMESTAMP_TAG, timestamp)
        self.assertIsInstance(timestamp[TIMESTAMP_TAG], int)
-822
View File
@@ -1,822 +0,0 @@
from .. import Smoketest, random_string
class Views(Smoketest):
    """Smoketest: declaring a view populates the st_view / st_view_column system tables."""

    MODULE_CODE = """
use spacetimedb::ViewContext;
#[derive(Copy, Clone)]
#[spacetimedb::table(accessor = player_state)]
pub struct PlayerState {
#[primary_key]
id: u64,
#[index(btree)]
level: u64,
}
#[spacetimedb::view(accessor = player, public)]
pub fn player(ctx: &ViewContext) -> Option<PlayerState> {
ctx.db.player_state().id().find(0u64)
}
"""

    def test_st_view_tables(self):
        """This test asserts that views populate the st_view_* system tables"""
        self.assertSql("SELECT * FROM st_view", """\
view_id | view_name | table_id | is_public | is_anonymous
---------+-----------+---------------+-----------+--------------
4096 | "player" | (some = 4097) | true | false
""")
        # One st_view_column row per column of the view's row type (0x0d = u64).
        self.assertSql("SELECT * FROM st_view_column", """\
view_id | col_pos | col_name | col_type
---------+---------+----------+----------
4096 | 0 | "id" | 0x0d
4096 | 1 | "level" | 0x0d
""")
class FailPublish(Smoketest):
    """Smoketest: publishing modules with invalid view definitions must fail."""

    AUTOPUBLISH = False

    # A table and a view both using the accessor name `person`.
    MODULE_CODE_BROKEN_NAMESPACE = """
use spacetimedb::ViewContext;
#[spacetimedb::table(accessor = person, public)]
pub struct Person {
name: String,
}
#[spacetimedb::view(accessor = person, public)]
pub fn person(ctx: &ViewContext) -> Option<Person> {
None
}
"""

    # A view whose element type (`ABC`) is an enum, not a product/row type.
    MODULE_CODE_BROKEN_RETURN_TYPE = """
use spacetimedb::{SpacetimeType, ViewContext};
#[derive(SpacetimeType)]
pub enum ABC {
A,
B,
C,
}
#[spacetimedb::view(accessor = person, public)]
pub fn person(ctx: &ViewContext) -> Option<ABC> {
None
}
"""

    def test_fail_publish_namespace_collision(self):
        """Publishing a module should fail if a table and view have the same name"""
        name = random_string()
        self.write_module_code(self.MODULE_CODE_BROKEN_NAMESPACE)
        with self.assertRaises(Exception):
            self.publish_module(name)

    def test_fail_publish_wrong_return_type(self):
        """Publishing a module should fail if the inner return type is not a product type"""
        name = random_string()
        self.write_module_code(self.MODULE_CODE_BROKEN_RETURN_TYPE)
        with self.assertRaises(Exception):
            self.publish_module(name)
class SqlViews(Smoketest):
"""Querying views through SQL: materialization, caching, and invalidation."""
# Module under test: one row type exposed as two tables, a multi-column
# indexed table, and a mix of anonymous and per-caller views over them.
MODULE_CODE = """
use spacetimedb::{AnonymousViewContext, ReducerContext, Table, ViewContext};
#[derive(Copy, Clone)]
#[spacetimedb::table(accessor = player_state)]
#[spacetimedb::table(accessor = player_level)]
pub struct PlayerState {
#[primary_key]
id: u64,
#[index(btree)]
level: u64,
}
#[derive(Clone)]
#[spacetimedb::table(accessor = player_info, index(accessor=age_level_index, btree(columns = [age, level])))]
pub struct PlayerInfo {
#[primary_key]
id: u64,
age: u64,
level: u64,
}
#[spacetimedb::reducer]
pub fn add_player_level(ctx: &ReducerContext, id: u64, level: u64) {
ctx.db.player_level().insert(PlayerState { id, level });
}
#[spacetimedb::view(accessor = my_player_and_level, public)]
pub fn my_player_and_level(ctx: &AnonymousViewContext) -> Option<PlayerState> {
ctx.db.player_level().id().find(0)
}
#[spacetimedb::view(accessor = player_and_level, public)]
pub fn player_and_level(ctx: &AnonymousViewContext) -> Vec<PlayerState> {
ctx.db.player_level().level().filter(2u64).collect()
}
#[spacetimedb::view(accessor = player, public)]
pub fn player(ctx: &ViewContext) -> Option<PlayerState> {
log::info!("player view called");
ctx.db.player_state().id().find(42)
}
#[spacetimedb::view(accessor = player_none, public)]
pub fn player_none(_ctx: &ViewContext) -> Option<PlayerState> {
None
}
#[spacetimedb::view(accessor = player_vec, public)]
pub fn player_vec(ctx: &ViewContext) -> Vec<PlayerState> {
let first = ctx.db.player_state().id().find(42).unwrap();
let second = PlayerState { id: 7, level: 3 };
vec![first, second]
}
#[spacetimedb::view(accessor = player_info_multi_index, public)]
pub fn player_info_view(ctx: &ViewContext) -> Option<PlayerInfo> {
log::info!("player_info called");
ctx.db.player_info().age_level_index().filter((25u64, 7u64)).next()
}
"""
def assertSql(self, sql, expected):
"""Run `sql` against the database and compare the output to `expected`,
ignoring trailing whitespace on each line."""
self.maxDiff = None
sql_out = self.spacetime("sql", self.database_identity, sql)
sql_out = "\n".join([line.rstrip() for line in sql_out.splitlines()])
expected = "\n".join([line.rstrip() for line in expected.splitlines()])
self.assertMultiLineEqual(sql_out, expected)
def insert_initial_data(self):
"""Seed player_state with the row (42, 7) that the `player` view finds."""
self.spacetime(
"sql",
self.database_identity,
"""\
INSERT INTO player_state (id, level) VALUES (42, 7);
""",
)
def call_player_view(self):
"""Query the `player` view and assert it returns the seeded row."""
self.assertSql("SELECT * FROM player", """\
id | level
----+-------
42 | 7
""")
def test_http_sql(self):
"""This test asserts that views can be queried over HTTP SQL"""
self.insert_initial_data()
self.call_player_view()
# An Option-returning view yielding None produces an empty result set.
self.assertSql("SELECT * FROM player_none", """\
id | level
----+-------
""")
# A Vec-returning view yields one row per element, in order.
self.assertSql("SELECT * FROM player_vec", """\
id | level
----+-------
42 | 7
7 | 3
""")
# test is prefixed with 'a' to ensure it runs before any other tests,
# since it relies on log capturing starting from an empty log.
def test_a_view_materialization(self):
"""This test asserts whether views are materialized correctly"""
# Each evaluation of the `player` view writes this line to the module log,
# so counting occurrences counts (re-)materializations.
player_called_log = "player view called"
# call view, with no data
self.assertSql("SELECT * FROM player", """\
id | level
----+-------
""")
logs = self.logs(100)
self.assertEqual(logs.count(player_called_log), 1)
self.insert_initial_data()
# Should invoke view as data is inserted
self.call_player_view()
logs = self.logs(100)
self.assertEqual(logs.count(player_called_log), 2)
self.call_player_view()
# the view is cached
logs = self.logs(100)
self.assertEqual(logs.count(player_called_log), 2)
# inserting new row should not trigger view invocation due to readsets
self.spacetime(
"sql",
self.database_identity,
"""\
INSERT INTO player_state (id, level) VALUES (22, 8);
""",
)
self.call_player_view()
logs = self.logs(100)
self.assertEqual(logs.count(player_called_log), 2)
# Updating the row that the view depends on should trigger re-evaluation
self.spacetime(
"sql",
self.database_identity,
"""
UPDATE player_state SET level = 9 WHERE id = 42;
""",
)
# On fourth call, after updating the dependent row, the view is re-evaluated
logs = self.logs(100)
self.assertEqual(logs.count(player_called_log), 3)
# Updating it back for other tests to work
self.spacetime(
"sql",
self.database_identity,
"""
UPDATE player_state SET level = 7 WHERE id = 42;
""",
)
def test_view_multi_index_materialization(self):
"""This test asserts whether views using multi-column indexes are materialized correctly"""
player_called_log = "player_info called"
# call view, with no data
self.assertSql("SELECT * FROM player_info_multi_index", """\
id | age | level
----+-----+-------
""")
logs = self.logs(100)
self.assertEqual(logs.count(player_called_log), 1)
# Insert data
self.spacetime(
"sql",
self.database_identity,
"""\
INSERT INTO player_info (id, age, level) VALUES (1, 25, 7);
""",
)
# Should invoke view as data is inserted
self.assertSql("SELECT * FROM player_info_multi_index", """\
id | age | level
----+-----+-------
1 | 25 | 7
""")
logs = self.logs(100)
self.assertEqual(logs.count(player_called_log), 2)
# Inserting a row that does not match should not trigger re-evaluation
self.spacetime(
"sql",
self.database_identity,
"""\
INSERT INTO player_info (id, age, level) VALUES (2, 25, 8);
""",
)
logs = self.logs(100)
self.assertEqual(logs.count(player_called_log), 2)
# Updating the row that the view depends on should trigger re-evaluation
self.spacetime(
"sql",
self.database_identity,
"""
UPDATE player_info SET age = 26 WHERE id = 1;
""",
)
logs = self.logs(100)
self.assertEqual(logs.count(player_called_log), 3)
self.assertSql("SELECT * FROM player_info_multi_index", """\
id | age | level
----+-----+-------
""")
def test_query_anonymous_view_reducer(self):
"""Tests that anonymous views are updated for reducers"""
self.call("add_player_level", 0, 1)
self.call("add_player_level", 1, 2)
self.assertSql("SELECT * FROM my_player_and_level", """\
id | level
----+-------
0 | 1
""")
self.assertSql("SELECT * FROM player_and_level", """\
id | level
----+-------
1 | 2
""")
self.call("add_player_level", 2, 2)
self.assertSql("SELECT * FROM player_and_level", """\
id | level
----+-------
1 | 2
2 | 2
""")
# Filtering on top of a view result is supported too.
self.assertSql("SELECT * FROM player_and_level WHERE id = 2", """\
id | level
----+-------
2 | 2
""")
class AutoMigrateViews(Smoketest):
"""Auto-migration should replace a view's body when the module is republished."""
# Initial module: the `player` view finds id 1.
MODULE_CODE = """
use spacetimedb::ViewContext;
#[derive(Copy, Clone)]
#[spacetimedb::table(accessor = player_state)]
pub struct PlayerState {
#[primary_key]
id: u64,
#[index(btree)]
level: u64,
}
#[spacetimedb::view(accessor = player, public)]
pub fn player(ctx: &ViewContext) -> Option<PlayerState> {
ctx.db.player_state().id().find(1u64)
}
"""
# Updated module: identical schema, but the view now finds id 2.
MODULE_CODE_UPDATED = """
use spacetimedb::ViewContext;
#[derive(Copy, Clone)]
#[spacetimedb::table(accessor = player_state)]
pub struct PlayerState {
#[primary_key]
id: u64,
#[index(btree)]
level: u64,
}
#[spacetimedb::view(accessor = player, public)]
pub fn player(ctx: &ViewContext) -> Option<PlayerState> {
ctx.db.player_state().id().find(2u64)
}
"""
def assertSql(self, sql, expected):
"""Run `sql` and compare output to `expected`, ignoring trailing whitespace."""
self.maxDiff = None
sql_out = self.spacetime("sql", self.database_identity, sql)
sql_out = "\n".join([line.rstrip() for line in sql_out.splitlines()])
expected = "\n".join([line.rstrip() for line in expected.splitlines()])
self.assertMultiLineEqual(sql_out, expected)
def test_views_auto_migration(self):
"""Assert that views are auto-migrated correctly"""
self.spacetime(
"sql",
self.database_identity,
"INSERT INTO player_state (id, level) VALUES (1, 1);",
)
self.spacetime(
"sql",
self.database_identity,
"INSERT INTO player_state (id, level) VALUES (2, 2);",
)
# Old view body: returns the id=1 row.
self.assertSql("SELECT * FROM player", """\
id | level
----+-------
1 | 1
""")
# Republish without clearing; the view body should be swapped in place.
self.write_module_code(self.MODULE_CODE_UPDATED)
self.publish_module(self.database_identity, clear=False)
self.assertSql("SELECT * FROM player", """\
id | level
----+-------
2 | 2
""")
class AutoMigrateDropView(Smoketest):
"""Auto-migration should allow removing a view from a published module."""
# Initial module: table plus one public view.
MODULE_CODE = """
use spacetimedb::ViewContext;
#[derive(Copy, Clone)]
#[spacetimedb::table(accessor = player_state)]
pub struct PlayerState {
#[primary_key]
id: u64,
#[index(btree)]
level: u64,
}
#[spacetimedb::view(accessor = player, public)]
pub fn player(ctx: &ViewContext) -> Option<PlayerState> {
ctx.db.player_state().id().find(1u64)
}
"""
# Updated module: same table, view removed entirely.
MODULE_CODE_DROP_VIEW = """
#[derive(Copy, Clone)]
#[spacetimedb::table(accessor = player_state)]
pub struct PlayerState {
#[primary_key]
id: u64,
#[index(btree)]
level: u64,
}
"""
def test_auto_migration_drop_view(self):
"""Assert that views can be dropped in an auto-migration"""
self.write_module_code(self.MODULE_CODE_DROP_VIEW)
# break_clients=False: NOTE(review) — presumably asserts that dropping a
# view does not count as a client-breaking schema change; confirm against
# the publish_module helper.
self.publish_module(self.database_identity, clear=False, break_clients=False)
class AutoMigrateAddView(Smoketest):
    """Auto-migration should allow adding a new view to a published module."""

    # Initial module: just a table, no views.
    MODULE_CODE = """
#[derive(Copy, Clone)]
#[spacetimedb::table(accessor = player_state)]
pub struct PlayerState {
#[primary_key]
id: u64,
#[index(btree)]
level: u64,
}
"""

    # Updated module: the same table plus a newly added public view.
    MODULE_CODE_ADD_VIEW = """
use spacetimedb::ViewContext;
#[derive(Copy, Clone)]
#[spacetimedb::table(accessor = player_state)]
pub struct PlayerState {
#[primary_key]
id: u64,
#[index(btree)]
level: u64,
}
#[spacetimedb::view(accessor = player, public)]
pub fn player(ctx: &ViewContext) -> Option<PlayerState> {
ctx.db.player_state().id().find(1u64)
}
"""

    def test_auto_migration_add_view(self):
        """Assert that views can be added in an auto-migration"""
        # Renamed from test_auto_migration_drop_view — the old name was a
        # copy-paste from AutoMigrateDropView; this test *adds* a view.
        self.write_module_code(self.MODULE_CODE_ADD_VIEW)
        self.publish_module(self.database_identity, clear=False)
class AutoMigrateViewsTrapped(Smoketest):
"""A failed (trapping) view migration must not corrupt the running module,
and a subsequent good publish must succeed."""
# Initial module: the `player` view finds id 1.
MODULE_CODE = """
use spacetimedb::ViewContext;
#[derive(Copy, Clone)]
#[spacetimedb::table(accessor = player_state)]
pub struct PlayerState {
#[primary_key]
id: u64,
#[index(btree)]
level: u64,
}
#[spacetimedb::view(accessor = player, public)]
pub fn player(ctx: &ViewContext) -> Option<PlayerState> {
ctx.db.player_state().id().find(1u64)
}
"""
# Broken update: the view panics (traps) when evaluated during migration.
TRAPPED_MODULE_CODE_UPDATED = """
use spacetimedb::ViewContext;
#[derive(Copy, Clone)]
#[spacetimedb::table(accessor = player_state)]
pub struct PlayerState {
#[primary_key]
id: u64,
#[index(btree)]
level: u64,
}
#[spacetimedb::view(accessor = player, public)]
pub fn player(_ctx: &ViewContext) -> Option<PlayerState> {
panic!("This view is trapped")
}
"""
# Recovery update: a valid view body that finds id 2 (no such row inserted,
# so it materializes empty).
MODULE_CODE_RECOVERED = """
use spacetimedb::ViewContext;
#[derive(Copy, Clone)]
#[spacetimedb::table(accessor = player_state)]
pub struct PlayerState {
#[primary_key]
id: u64,
#[index(btree)]
level: u64,
}
#[spacetimedb::view(accessor = player, public)]
pub fn player(ctx: &ViewContext) -> Option<PlayerState> {
ctx.db.player_state().id().find(2u64)
}
"""
def assertSql(self, sql, expected):
"""Run `sql` and compare output to `expected`, ignoring trailing whitespace."""
self.maxDiff = None
sql_out = self.spacetime("sql", self.database_identity, sql)
sql_out = "\n".join([line.rstrip() for line in sql_out.splitlines()])
expected = "\n".join([line.rstrip() for line in expected.splitlines()])
self.assertMultiLineEqual(sql_out, expected)
def test_recovery_from_trapped_views_auto_migration(self):
"""Assert that view auto-migration recovers correctly after trapped migration"""
self.spacetime(
"sql",
self.database_identity,
"INSERT INTO player_state (id, level) VALUES (1, 1);",
)
# Trigger initial materialization
self.assertSql("SELECT * FROM player", """\
id | level
----+-------
1 | 1
""")
# Attempt to publish trapped module (should fail)
self.write_module_code(self.TRAPPED_MODULE_CODE_UPDATED)
with self.assertRaises(Exception):
self.publish_module(self.database_identity, clear=False)
# Ensure old module still serves queries
self.assertSql("SELECT * FROM player", """\
id | level
----+-------
1 | 1
""")
# Fix the module and publish again
self.write_module_code(self.MODULE_CODE_RECOVERED)
self.publish_module(self.database_identity, clear=False)
# The recovered view looks up id 2, which does not exist -> empty result.
self.assertSql("SELECT * FROM player", """\
id | level
----+-------
""")
class SubscribeViews(Smoketest):
"""Subscriptions over a sender-scoped view must only deliver the caller's rows."""
# Module under test: `my_player` is scoped to ctx.sender(), so each client
# subscribing to it sees only its own PlayerState row.
MODULE_CODE = """
use spacetimedb::{Identity, ReducerContext, Table, ViewContext};
#[spacetimedb::table(accessor = player_state)]
pub struct PlayerState {
#[primary_key]
identity: Identity,
#[unique]
name: String,
}
#[spacetimedb::view(accessor = my_player, public)]
pub fn my_player(ctx: &ViewContext) -> Option<PlayerState> {
ctx.db.player_state().identity().find(ctx.sender())
}
#[spacetimedb::reducer]
pub fn insert_player(ctx: &ReducerContext, name: String) {
ctx.db.player_state().insert(PlayerState { name, identity: ctx.sender() });
}
"""
def test_subscribing_with_different_identities(self):
"""Tests different clients subscribing to a client-specific view"""
# Insert an identity for Alice
self.call("insert_player", "Alice")
# Generate a new identity for Bob
self.reset_config()
self.new_identity()
# Subscribe to `my_player` as Bob
sub = self.subscribe("select * from my_player", n=1)
self.call("insert_player", "Bob")
events = sub()
# Project out the identity field.
# TODO: Eventually we should be able to do this directly in the sql.
# But for now we implement it in python.
projection = [
{
'my_player': {
'deletes': [
{'name': row['name']}
for row in event['my_player']['deletes']
],
'inserts': [
{'name': row['name']}
for row in event['my_player']['inserts']
],
}
}
for event in events
]
# Bob's subscription must see only Bob's insert, never Alice's row.
self.assertEqual(
[
{
'my_player': {
'deletes': [],
'inserts': [{'name': 'Bob'}],
}
},
],
projection,
)
class QueryView(Smoketest):
"""Views that return `impl Query<T>` (filters and semijoins) queried via SQL."""
# Module under test: `user` and `person` tables seeded in init, plus views
# built from the Query combinators (r#where / filter / semijoins).
MODULE_CODE = """
use spacetimedb::{Query, ReducerContext, Table, ViewContext};
#[spacetimedb::table(accessor = user, public)]
pub struct User {
#[primary_key]
identity: u8,
name: String,
online: bool,
}
#[spacetimedb::table(accessor = person, public)]
pub struct Person {
#[primary_key]
identity: u8,
name: String,
#[index(btree)]
age: u8,
}
#[spacetimedb::reducer(init)]
fn init(ctx: &ReducerContext) {
ctx.db.user().insert(User {
identity: 1,
name: "Alice".to_string(),
online: true,
});
ctx.db.user().insert(User {
identity: 2,
name: "BOB".to_string(),
online: false,
});
ctx.db.user().insert(User {
identity: 3,
name: "POP".to_string(),
online: false,
});
ctx.db.person().insert(Person {
identity: 1,
name: "Alice".to_string(),
age: 30,
});
ctx.db.person().insert(Person {
identity: 2,
name: "BOB".to_string(),
age: 20,
});
}
#[spacetimedb::view(accessor = online_users, public)]
fn online_users(ctx: &ViewContext) -> impl Query<User> {
ctx.from.user().r#where(|c| c.online.eq(true))
}
#[spacetimedb::view(accessor = online_users_age, public)]
fn online_users_age(ctx: &ViewContext) -> impl Query<Person> {
ctx.from
.user()
.r#where(|u| u.online.eq(true))
.right_semijoin(ctx.from.person(), |u, p| u.identity.eq(p.identity))
}
#[spacetimedb::view(accessor = offline_user_20_years_old, public)]
fn offline_user_in_twienties(ctx: &ViewContext) -> impl Query<User> {
ctx.from
.person()
.filter(|p| p.age.eq(20))
.right_semijoin(ctx.from.user(), |p, u| p.identity.eq(u.identity))
.filter(|u| u.online.eq(false))
}
#[spacetimedb::view(accessor = users_whos_age_is_known, public)]
fn users_whos_age_is_known(ctx: &ViewContext) -> impl Query<User> {
ctx.from
.user()
.left_semijoin(ctx.from.person(), |p, u| p.identity.eq(u.identity))
}
#[spacetimedb::view(accessor = users_who_are_above_20_and_below_30, public)]
fn users_who_are_above_20_and_below_30(ctx: &ViewContext) -> impl Query<Person> {
ctx.from
.person()
.r#where(|p| p.age.gt(20).and(p.age.lt(30)))
}
#[spacetimedb::view(accessor = users_who_are_above_eq_20_and_below_eq_30, public)]
fn users_who_are_above_eq_20_and_below_eq_30(ctx: &ViewContext) -> impl Query<Person> {
ctx.from
.person()
.r#where(|p| p.age.gte(20).and(p.age.lte(30)))
}
"""
def test_query_view(self):
"""Tests that views returning Query types work as expected"""
self.assertSql("SELECT * FROM online_users", """\
identity | name | online
----------+---------+--------
1 | "Alice" | true
""")
def test_query_right_semijoin_view(self):
"""Tests that views returning Query types with right semijoin work as expected"""
# Right semijoin projects the right-hand (person) columns.
self.assertSql("SELECT * FROM online_users_age", """\
identity | name | age
----------+---------+-----
1 | "Alice" | 30
""")
def test_query_left_semijoin_view(self):
"""Tests that views returning Query types with left semijoin work as expected"""
# Left semijoin projects the left-hand (user) columns.
self.assertSql("SELECT * FROM users_whos_age_is_known", """\
identity | name | online
----------+---------+--------
1 | "Alice" | true
2 | "BOB" | false
""")
def test_query_complex_right_semijoin_view(self):
"""Tests that views returning Query types with right semijoin work as expected"""
self.assertSql("SELECT * FROM offline_user_20_years_old", """\
identity | name | online
----------+-------+--------
2 | "BOB" | false
""")
def test_where_expr_view(self):
"""Tests that views with where expressions work as expected"""
# Strict bounds (20, 30) exclude both seeded ages (20 and 30).
self.assertSql("SELECT * FROM users_who_are_above_20_and_below_30", """\
identity | name | age
----------+------+-----
""")
# Inclusive bounds [20, 30] match both rows.
self.assertSql("SELECT * FROM users_who_are_above_eq_20_and_below_eq_30", """\
identity | name | age
----------+---------+-----
1 | "Alice" | 30
2 | "BOB" | 20
""")
-176
View File
@@ -1,176 +0,0 @@
from .. import Smoketest, requires_docker
from ..docker import restart_docker
from urllib.request import urlopen
from .add_remove_index import AddRemoveIndex
@requires_docker
class DockerRestartModule(Smoketest):
"""Reducers must still work after the SpacetimeDB container restarts."""
# Note: creating indexes on `Person`
# exercises more possible failure cases when replaying after restart
MODULE_CODE = """
use spacetimedb::{log, ReducerContext, Table};
#[spacetimedb::table(accessor = person, index(accessor = name_idx, btree(columns = [name])))]
pub struct Person {
#[primary_key]
#[auto_inc]
id: u32,
name: String,
}
#[spacetimedb::reducer]
pub fn add(ctx: &ReducerContext, name: String) {
ctx.db.person().insert(Person { id: 0, name });
}
#[spacetimedb::reducer]
pub fn say_hello(ctx: &ReducerContext) {
for person in ctx.db.person().iter() {
log::info!("Hello, {}!", person.name);
}
log::info!("Hello, World!");
}
"""
def test_restart_module(self):
"""This tests to see if SpacetimeDB can be queried after a restart"""
# One row before the restart, two after: replay must preserve the first.
self.call("add", "Robert")
restart_docker()
self.call("add", "Julie")
self.call("add", "Samantha")
self.call("say_hello")
logs = self.logs(100)
self.assertIn("Hello, Samantha!", logs)
self.assertIn("Hello, Julie!", logs)
self.assertIn("Hello, Robert!", logs)
self.assertIn("Hello, World!", logs)
@requires_docker
class DockerRestartSql(Smoketest):
"""SQL queries must still work after the SpacetimeDB container restarts."""
# Note: creating indexes on `Person`
# exercises more possible failure cases when replaying after restart
MODULE_CODE = """
use spacetimedb::{log, ReducerContext, Table};
#[spacetimedb::table(accessor = person, index(accessor = name_idx, btree(columns = [name])))]
pub struct Person {
#[primary_key]
#[auto_inc]
id: u32,
name: String,
}
#[spacetimedb::reducer]
pub fn add(ctx: &ReducerContext, name: String) {
ctx.db.person().insert(Person { id: 0, name });
}
#[spacetimedb::reducer]
pub fn say_hello(ctx: &ReducerContext) {
for person in ctx.db.person().iter() {
log::info!("Hello, {}!", person.name);
}
log::info!("Hello, World!");
}
"""
def test_restart_module(self):
"""This tests to see if SpacetimeDB can be queried after a restart"""
self.call("add", "Robert")
self.call("add", "Julie")
self.call("add", "Samantha")
self.call("say_hello")
logs = self.logs(100)
self.assertIn("Hello, Samantha!", logs)
self.assertIn("Hello, Julie!", logs)
self.assertIn("Hello, Robert!", logs)
self.assertIn("Hello, World!", logs)
restart_docker()
# id 3 is the auto_inc id assigned to the third insert ("Samantha");
# the row must survive the restart and be readable over SQL.
sql_out = self.spacetime("sql", self.database_identity, "SELECT name FROM person WHERE id = 3")
self.assertMultiLineEqual(sql_out, """ name \n------------\n "Samantha" \n""")
@requires_docker
class DockerRestartAutoDisconnect(Smoketest):
"""Clients connected before a restart must be auto-disconnected afterwards."""
# Module under test tracks connected clients via the client_connected /
# client_disconnected lifecycle reducers.
MODULE_CODE = """
use log::info;
use spacetimedb::{ConnectionId, Identity, ReducerContext, Table};
#[spacetimedb::table(accessor = connected_client)]
pub struct ConnectedClient {
identity: Identity,
connection_id: ConnectionId,
}
#[spacetimedb::reducer(client_connected)]
fn on_connect(ctx: &ReducerContext) {
ctx.db.connected_client().insert(ConnectedClient {
identity: ctx.sender(),
connection_id: ctx.connection_id().expect("sender connection id unset"),
});
}
#[spacetimedb::reducer(client_disconnected)]
fn on_disconnect(ctx: &ReducerContext) {
let sender_identity = &ctx.sender();
let connection_id = ctx.connection_id();
let sender_connection_id = connection_id.as_ref().expect("sender connection id unset");
let match_client = |row: &ConnectedClient| {
&row.identity == sender_identity && &row.connection_id == sender_connection_id
};
if let Some(client) = ctx.db.connected_client().iter().find(match_client) {
ctx.db.connected_client().delete(client);
}
}
#[spacetimedb::reducer]
fn print_num_connected(ctx: &ReducerContext) {
let n = ctx.db.connected_client().count();
info!("CONNECTED CLIENTS: {n}")
}
"""
def test_restart_disconnects(self):
"""Tests if clients are automatically disconnected after a restart"""
# Start two subscribers
self.subscribe("SELECT * FROM connected_client", n=2)
self.subscribe("SELECT * FROM connected_client", n=2)
# Assert that we have two clients + the reducer call
self.call("print_num_connected")
logs = self.logs(10)
self.assertEqual("CONNECTED CLIENTS: 3", logs.pop())
restart_docker()
# After restart, only the current call should be connected
self.call("print_num_connected")
logs = self.logs(10)
self.assertEqual("CONNECTED CLIENTS: 1", logs.pop())
@requires_docker
class AddRemoveIndexAfterRestart(AddRemoveIndex):
"""
`AddRemoveIndex` from `add_remove_index.py`,
but restarts docker between each publish.
This detects a bug we once had, hopefully fixed now,
where the system autoinc sequences were borked after restart,
leading newly-created database objects to re-use IDs.
First publish the module without the indices,
then restart docker, then add the indices and publish.
Then restart docker, and publish again.
There should be no errors from publishing,
and the unindexed versions should reject subscriptions.
"""
def between_publishes(self):
# Hook overridden from AddRemoveIndex: restart the container between
# successive publishes (see the class docstring).
restart_docker()
-376
View File
@@ -1,376 +0,0 @@
# vendored and modified from unittest-parallel by Craig Hobbs
# TODO: upstream some of these changes? to make it usable as a library, maybe?
# Licensed under the MIT License
# https://github.com/craigahobbs/unittest-parallel/blob/main/LICENSE
# full text of license file below:
#
# The MIT License (MIT)
#
# Copyright (c) 2017 Craig A. Hobbs
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
unittest-parallel command-line script main module
"""
import argparse
from contextlib import contextmanager
from io import StringIO
import os
import sys
import tempfile
import time
import unittest
import concurrent.futures
import threading
import coverage
def main(**kwargs):
    """
    unittest-parallel entry point, adapted to run pre-discovered tests.

    All CLI flags of the original tool remain available, but argv is not
    consulted: defaults are parsed from an empty argument list and then
    overridden by **kwargs. Callers must supply `discovered_tests`, a
    unittest.TestSuite of tests to run; `level`, `jobs`, `verbose`, and the
    coverage options behave as documented in the argparse help below.

    Exits the process (via parser.exit) with a non-zero status when any test
    fails or errors, or when coverage falls below --coverage-fail-under.
    """
    # Command line arguments (defaults only; overridden through kwargs below)
    parser = argparse.ArgumentParser(prog='unittest-parallel')
    parser.add_argument('-v', '--verbose', action='store_const', const=2, default=1,
                        help='Verbose output')
    parser.add_argument('-q', '--quiet', dest='verbose', action='store_const', const=0, default=1,
                        help='Quiet output')
    parser.add_argument('-f', '--failfast', action='store_true', default=False,
                        help='Stop on first fail or error')
    parser.add_argument('-b', '--buffer', action='store_true', default=False,
                        help='Buffer stdout and stderr during tests')
    parser.add_argument('-k', dest='testNamePatterns', action='append', type=_convert_select_pattern,
                        help='Only run tests which match the given substring')
    parser.add_argument('-s', '--start-directory', metavar='START', default='.',
                        help="Directory to start discovery ('.' default)")
    parser.add_argument('-p', '--pattern', metavar='PATTERN', default='test*.py',
                        help="Pattern to match tests ('test*.py' default)")
    parser.add_argument('-t', '--top-level-directory', metavar='TOP',
                        help='Top level directory of project (defaults to start directory)')
    group_parallel = parser.add_argument_group('parallelization options')
    group_parallel.add_argument('-j', '--jobs', metavar='COUNT', type=int, default=0,
                                help='The number of test processes (default is 0, all cores)')
    group_parallel.add_argument('--level', choices=['module', 'class', 'test'], default='module',
                                help="Set the test parallelism level (default is 'module')")
    group_parallel.add_argument('--disable-process-pooling', action='store_true', default=False,
                                help='Do not reuse processes used to run test suites')
    group_coverage = parser.add_argument_group('coverage options')
    group_coverage.add_argument('--coverage', action='store_true',
                                help='Run tests with coverage')
    group_coverage.add_argument('--coverage-branch', action='store_true',
                                help='Run tests with branch coverage')
    group_coverage.add_argument('--coverage-rcfile', metavar='RCFILE',
                                help='Specify coverage configuration file')
    group_coverage.add_argument('--coverage-include', metavar='PAT', action='append',
                                help='Include only files matching one of these patterns. Accepts shell-style (quoted) wildcards.')
    group_coverage.add_argument('--coverage-omit', metavar='PAT', action='append',
                                help='Omit files matching one of these patterns. Accepts shell-style (quoted) wildcards.')
    group_coverage.add_argument('--coverage-source', metavar='SRC', action='append',
                                help='A list of packages or directories of code to be measured')
    group_coverage.add_argument('--coverage-html', metavar='DIR',
                                help='Generate coverage HTML report')
    group_coverage.add_argument('--coverage-xml', metavar='FILE',
                                help='Generate coverage XML report')
    group_coverage.add_argument('--coverage-fail-under', metavar='MIN', type=float,
                                help='Fail if coverage percentage under min')
    args = parser.parse_args(args=[])
    args.__dict__.update(kwargs)
    if args.coverage_branch:
        args.coverage = args.coverage_branch

    process_count = max(0, args.jobs)
    if process_count == 0:
        process_count = os.cpu_count()

    # Create the temporary directory (for coverage files)
    with tempfile.TemporaryDirectory() as temp_dir:
        # Test discovery is done by the caller; the discovered suite is passed
        # in through kwargs as `discovered_tests`.
        discover_suite = args.discovered_tests

        # Get the parallelizable test suites
        if args.level == 'test':
            test_suites = list(_iter_test_cases(discover_suite))
        elif args.level == 'class':
            test_suites = list(_iter_class_suites(discover_suite))
        else:  # args.level == 'module'
            test_suites = list(_iter_module_suites(discover_suite))

        # Don't use more processes than test suites
        process_count = max(1, min(len(test_suites), process_count))

        # Report test suites and processes
        print(
            f'Running {len(test_suites)} test suites ({discover_suite.countTestCases()} total tests) across {process_count} threads',
            file=sys.stderr
        )
        if args.verbose > 1:
            print(file=sys.stderr)

        # Run the tests in parallel.
        # FIX: submit the suites partitioned above (`test_suites`), not
        # `discover_suite` — iterating the latter ignored the --level option
        # and the empty-suite filter.
        start_time = time.perf_counter()
        with concurrent.futures.ThreadPoolExecutor(max_workers=process_count) as executor:
            test_manager = ParallelTestManager(args, temp_dir)
            futures = [executor.submit(test_manager.run_tests, suite) for suite in test_suites]

            # Aggregate parallel test run results
            tests_run = 0
            errors = []
            failures = []
            skipped = 0
            expected_failures = 0
            unexpected_successes = 0
            for fut in concurrent.futures.as_completed(futures):
                try:
                    result, stream = fut.result()
                except concurrent.futures.CancelledError:
                    continue
                tests_run += result.testsRun
                errors.extend(ParallelTestManager._format_error(result, error) for error in result.errors)
                failures.extend(ParallelTestManager._format_error(result, failure) for failure in result.failures)
                skipped += len(result.skipped)
                expected_failures += len(result.expectedFailures)
                unexpected_successes += len(result.unexpectedSuccesses)
                if result.shouldStop:
                    # failfast: cancel every future that has not started yet
                    for pending in futures:
                        pending.cancel()

        is_success = not (errors or failures or unexpected_successes)
        stop_time = time.perf_counter()
        test_duration = stop_time - start_time

        # Compute test info
        infos = []
        if failures:
            infos.append(f'failures={len(failures)}')
        if errors:
            infos.append(f'errors={len(errors)}')
        if skipped:
            infos.append(f'skipped={skipped}')
        if expected_failures:
            infos.append(f'expected failures={expected_failures}')
        if unexpected_successes:
            infos.append(f'unexpected successes={unexpected_successes}')

        # Report test errors
        if errors or failures:
            print(file=sys.stderr)
            for error in errors:
                print(error, file=sys.stderr)
            for failure in failures:
                print(failure, file=sys.stderr)
        elif args.verbose > 0:
            print(file=sys.stderr)

        # Test report
        print(unittest.TextTestResult.separator2, file=sys.stderr)
        print(f'Ran {tests_run} {"tests" if tests_run > 1 else "test"} in {test_duration:.3f}s', file=sys.stderr)
        print(file=sys.stderr)
        print(f'{"OK" if is_success else "FAILED"}{" (" + ", ".join(infos) + ")" if infos else ""}', file=sys.stderr)

        # Return an error status on failure
        if not is_success:
            parser.exit(status=len(errors) + len(failures) + unexpected_successes)

        # Coverage?
        if args.coverage:
            # Combine the per-worker coverage data files written into temp_dir
            cov_options = {}
            if args.coverage_rcfile is not None:
                cov_options['config_file'] = args.coverage_rcfile
            cov = coverage.Coverage(**cov_options)
            cov.combine(data_paths=[os.path.join(temp_dir, x) for x in os.listdir(temp_dir)])

            # Coverage report
            print(file=sys.stderr)
            percent_covered = cov.report(ignore_errors=True, file=sys.stderr)
            print(f'Total coverage is {percent_covered:.2f}%', file=sys.stderr)

            # HTML coverage report
            if args.coverage_html:
                cov.html_report(directory=args.coverage_html, ignore_errors=True)

            # XML coverage report
            if args.coverage_xml:
                cov.xml_report(outfile=args.coverage_xml, ignore_errors=True)

            # Fail under
            if args.coverage_fail_under and percent_covered < args.coverage_fail_under:
                parser.exit(status=2)
def _convert_select_pattern(pattern):
if not '*' in pattern:
return f'*{pattern}*'
return pattern
@contextmanager
def _coverage(args, temp_dir):
"""Context manager that measures code coverage around the enclosed test run
when args.coverage is set; otherwise a no-op. Each entry writes its data to
a fresh file in temp_dir so parallel runs can be combined later."""
# Running tests with coverage?
if args.coverage:
# Generate a random coverage data file name - file is deleted along with containing directory
# (delete=False: the file must outlive this `with`; TemporaryDirectory cleanup removes it)
with tempfile.NamedTemporaryFile(dir=temp_dir, delete=False) as coverage_file:
pass
# Create the coverage object
cov_options = {
'branch': args.coverage_branch,
'data_file': coverage_file.name,
'include': args.coverage_include,
# Exclude this runner module itself from measurement
'omit': (args.coverage_omit if args.coverage_omit else []) + [__file__],
'source': args.coverage_source
}
if args.coverage_rcfile is not None:
cov_options['config_file'] = args.coverage_rcfile
cov = coverage.Coverage(**cov_options)
try:
# Start measuring code coverage
cov.start()
# Yield for unit test running
yield cov
finally:
# Stop measuring code coverage
cov.stop()
# Save the collected coverage data to the data file
cov.save()
else:
# Not running tests with coverage - yield for unit test running
yield None
# Iterate module-level test suites - all top-level test suites returned from TestLoader.discover
def _iter_module_suites(test_suite):
for module_suite in test_suite:
if module_suite.countTestCases():
yield module_suite
# Iterate class-level test suites - test suites that contains test cases
def _iter_class_suites(test_suite):
has_cases = any(isinstance(suite, unittest.TestCase) for suite in test_suite)
if has_cases:
yield test_suite
else:
for suite in test_suite:
yield from _iter_class_suites(suite)
# Iterate test cases (methods)
def _iter_test_cases(test_suite):
if isinstance(test_suite, unittest.TestCase):
yield test_suite
else:
for suite in test_suite:
yield from _iter_test_cases(suite)
class ParallelTestManager:
    """Runs a test suite in a worker, optionally collecting coverage data."""

    def __init__(self, args, temp_dir):
        # Parsed command-line arguments and the shared temporary directory
        self.args = args
        self.temp_dir = temp_dir

    def run_tests(self, test_suite):
        """Run test_suite and return (TestResult, captured output stream)."""
        args = self.args
        with _coverage(args, self.temp_dir):
            output = StringIO()
            runner = unittest.TextTestRunner(
                stream=output,
                resultclass=ParallelTextTestResult,
                verbosity=args.verbose,
                failfast=args.failfast,
                buffer=args.buffer
            )
            result = runner.run(test_suite)
        return result, output

    @staticmethod
    def _format_error(result, error):
        """Render a (test, traceback-text) pair in TextTestResult's error format."""
        test, traceback_text = error
        parts = [
            unittest.TextTestResult.separator1,
            result.getDescription(test),
            unittest.TextTestResult.separator2,
            traceback_text
        ]
        return '\n'.join(parts)
class ParallelTextTestResult(unittest.TextTestResult):
def __init__(self, stream, descriptions, verbosity):
stream = type(stream)(sys.stderr)
super().__init__(stream, descriptions, verbosity)
def startTest(self, test):
if self.showAll:
self.stream.writeln(f'{self.getDescription(test)} ...')
self.stream.flush()
super(unittest.TextTestResult, self).startTest(test)
def _add_helper(self, test, dots_message, show_all_message):
if self.showAll:
self.stream.writeln(f'{self.getDescription(test)} ... {show_all_message}')
elif self.dots:
self.stream.write(dots_message)
self.stream.flush()
def addSuccess(self, test):
super(unittest.TextTestResult, self).addSuccess(test)
self._add_helper(test, '.', 'ok')
def addError(self, test, err):
super(unittest.TextTestResult, self).addError(test, err)
self._add_helper(test, 'E', 'ERROR')
def addFailure(self, test, err):
super(unittest.TextTestResult, self).addFailure(test, err)
self._add_helper(test, 'F', 'FAIL')
def addSkip(self, test, reason):
super(unittest.TextTestResult, self).addSkip(test, reason)
self._add_helper(test, 's', f'skipped {reason!r}')
def addExpectedFailure(self, test, err):
super(unittest.TextTestResult, self).addExpectedFailure(test, err)
self._add_helper(test, 'x', 'expected failure')
def addUnexpectedSuccess(self, test):
super(unittest.TextTestResult, self).addUnexpectedSuccess(test)
self._add_helper(test, 'u', 'unexpected success')
def printErrors(self):
pass