[TS] Allow brotli to be specified for compression and reorganize some websocket stuff (#4561)

# Description of Changes

[Some
runtimes](https://developer.mozilla.org/en-US/docs/Web/API/DecompressionStream#browser_compatibility)
support brotli for `DecompressionStream`, so I figure we may as
well allow it. Also reorganizes some of the websocket code for better
separation of concerns.

# Expected complexity level and risk

1

# Testing

- [ ] <!-- maybe a test you want to do -->
- [ ] <!-- maybe a test you want a reviewer to do, so they can check it
off when they're satisfied. -->
This commit is contained in:
Noa
2026-04-29 21:40:29 -05:00
committed by GitHub
parent afa212b9db
commit f22f80060f
10 changed files with 162 additions and 168 deletions
@@ -63,7 +63,7 @@ export default class BinaryWriter {
return fromByteArray(this.getBuffer());
}
getBuffer(): Uint8Array {
getBuffer(): Uint8Array<ArrayBuffer> {
return new Uint8Array(this.buffer.buffer, 0, this.offset);
}
@@ -8,6 +8,7 @@ import type {
} from '../';
import { ensureMinimumVersionOrThrow } from './version';
import { WebsocketDecompressAdapter } from './websocket_decompress_adapter';
import type { WebSocketFactory } from './ws';
/**
* The database client connection to a SpacetimeDB server.
@@ -23,10 +24,10 @@ export class DbConnectionBuilder<DbConnection extends DbConnectionImpl<any>> {
#identity?: Identity;
#token?: string;
#emitter: EventEmitter<ConnectionEvent> = new EventEmitter();
#compression: 'gzip' | 'none' = 'gzip';
#compression: 'gzip' | 'brotli' | 'none' = 'gzip';
#lightMode: boolean = false;
#confirmedReads?: boolean;
#createWSFn: typeof WebsocketDecompressAdapter.createWebSocketFn;
#createWSFn: WebSocketFactory;
/**
* Creates a new `DbConnectionBuilder` database client and set the initial parameters.
@@ -42,7 +43,7 @@ export class DbConnectionBuilder<DbConnection extends DbConnectionImpl<any>> {
config: DbConnectionConfig<RemoteModuleOf<DbConnection>>
) => DbConnection
) {
this.#createWSFn = WebsocketDecompressAdapter.createWebSocketFn;
this.#createWSFn = WebsocketDecompressAdapter.openWebSocket;
}
/**
@@ -82,9 +83,7 @@ export class DbConnectionBuilder<DbConnection extends DbConnectionImpl<any>> {
return this;
}
withWSFn(
createWSFn: typeof WebsocketDecompressAdapter.createWebSocketFn
): this {
withWSFn(createWSFn: WebSocketFactory): this {
this.#createWSFn = createWSFn;
return this;
}
@@ -94,7 +93,17 @@ export class DbConnectionBuilder<DbConnection extends DbConnectionImpl<any>> {
*
* @param compression The compression algorithm to use for the connection.
*/
withCompression(compression: 'gzip' | 'none'): this {
withCompression(compression: 'gzip' | 'brotli' | 'none'): this {
if (compression === 'brotli') {
try {
new DecompressionStream('brotli' as CompressionFormat);
} catch (e) {
throw new TypeError(
`Brotli compression is not supported by the runtime. Please choose a different compression method.`,
{ cause: e }
);
}
}
this.#compression = compression;
return this;
}
@@ -37,10 +37,6 @@ import {
type PendingCallback,
type TableUpdate as CacheTableUpdate,
} from './table_cache.ts';
import {
WebsocketDecompressAdapter,
type WebsocketAdapter,
} from './websocket_decompress_adapter.ts';
import {
SubscriptionBuilderImpl,
SubscriptionHandleImpl,
@@ -60,6 +56,7 @@ import type { ProceduresView } from './procedures.ts';
import type { Values } from '../lib/type_util.ts';
import type { TransactionUpdate } from './client_api/types.ts';
import { InternalError, SenderError } from '../lib/errors.ts';
import type { WebSocketAdapter, WebSocketFactory } from './ws.ts';
import {
normalizeWsProtocol,
PREFERRED_WS_PROTOCOLS,
@@ -101,8 +98,8 @@ export type DbConnectionConfig<RemoteModule extends UntypedRemoteModule> = {
identity?: Identity;
token?: string;
emitter: EventEmitter<ConnectionEvent>;
createWSFn: typeof WebsocketDecompressAdapter.createWebSocketFn;
compression: 'gzip' | 'none';
createWSFn: WebSocketFactory;
compression: 'gzip' | 'brotli' | 'none';
lightMode: boolean;
confirmedReads?: boolean;
remoteModule: RemoteModule;
@@ -186,7 +183,7 @@ export class DbConnectionImpl<RemoteModule extends UntypedRemoteModule>
#inboundQueue: Uint8Array[] = [];
#inboundQueueOffset = 0;
#isDrainingInboundQueue = false;
#outboundQueue: Uint8Array[] = [];
#outboundQueue: Uint8Array<ArrayBuffer>[] = [];
#isOutboundFlushScheduled = false;
#negotiatedWsProtocol: NegotiatedWsProtocol = V2_WS_PROTOCOL;
#subscriptionManager = new SubscriptionManager<RemoteModule>();
@@ -224,8 +221,8 @@ export class DbConnectionImpl<RemoteModule extends UntypedRemoteModule>
// private fields.
// We use them in testing.
private clientCache: ClientCache<RemoteModule>;
private ws?: WebsocketAdapter;
private wsPromise: Promise<WebsocketAdapter | undefined>;
private ws?: WebSocketAdapter;
private wsPromise: Promise<WebSocketAdapter | undefined>;
constructor({
uri,
@@ -612,7 +609,7 @@ export class DbConnectionImpl<RemoteModule extends UntypedRemoteModule>
return this.#mergeTableUpdates(updates);
}
#flushOutboundQueue(wsResolved: WebsocketAdapter): void {
#flushOutboundQueue(wsResolved: WebSocketAdapter): void {
if (this.#negotiatedWsProtocol === V3_WS_PROTOCOL) {
this.#flushOutboundQueueV3(wsResolved);
return;
@@ -620,14 +617,14 @@ export class DbConnectionImpl<RemoteModule extends UntypedRemoteModule>
this.#flushOutboundQueueV2(wsResolved);
}
#flushOutboundQueueV2(wsResolved: WebsocketAdapter): void {
#flushOutboundQueueV2(wsResolved: WebSocketAdapter): void {
const pending = this.#outboundQueue.splice(0);
for (const message of pending) {
wsResolved.send(message);
}
}
#flushOutboundQueueV3(wsResolved: WebsocketAdapter): void {
#flushOutboundQueueV3(wsResolved: WebSocketAdapter): void {
if (this.#outboundQueue.length === 0) {
return;
}
@@ -692,7 +689,10 @@ export class DbConnectionImpl<RemoteModule extends UntypedRemoteModule>
#reducerArgsEncoder = new BinaryWriter(1024);
#clientMessageEncoder = new BinaryWriter(1024);
#sendEncodedMessage(encoded: Uint8Array, describe: () => string): void {
#sendEncodedMessage(
encoded: Uint8Array<ArrayBuffer>,
describe: () => string
): void {
stdbLogger('trace', describe);
if (this.ws && this.isActive) {
if (this.#negotiatedWsProtocol === V2_WS_PROTOCOL) {
@@ -1,12 +1,11 @@
export async function decompress(
buffer: Uint8Array,
// Leaving it here to expand to brotli when it lands in the browsers and NodeJS
type: 'gzip',
buffer: Uint8Array<ArrayBuffer>,
type: CompressionFormat,
chunkSize: number = 128 * 1024 // 128KB
): Promise<Uint8Array> {
// Create a single ReadableStream to handle chunks
let offset = 0;
const readableStream = new ReadableStream({
const readableStream = new ReadableStream<BufferSource>({
pull(controller) {
if (offset < buffer.length) {
// Slice a chunk of the buffer and enqueue it
@@ -29,24 +28,9 @@ export async function decompress(
const decompressedStream = readableStream.pipeThrough(decompressionStream);
// Collect the decompressed chunks efficiently
const reader = decompressedStream.getReader();
const chunks: Uint8Array[] = [];
let totalLength = 0;
let result: any;
while (!(result = await reader.read()).done) {
chunks.push(result.value);
totalLength += result.value.length;
const chunks = [];
for await (const chunk of decompressedStream) {
chunks.push(chunk);
}
// Allocate a single Uint8Array for the decompressed data
const decompressedArray = new Uint8Array(totalLength);
let chunkOffset = 0;
for (const chunk of chunks) {
decompressedArray.set(chunk, chunkOffset);
chunkOffset += chunk.length;
}
return decompressedArray;
return new Blob(chunks).bytes();
}
@@ -1,22 +1,10 @@
import { decompress } from './decompress';
import { resolveWS } from './ws';
import { openWebSocket, type WebSocketAdapter, type WebSocketArgs } from './ws';
export interface WebsocketAdapter {
readonly protocol: string;
send(msg: Uint8Array): void;
close(): void;
set onclose(handler: (ev: CloseEvent) => void);
set onopen(handler: () => void);
set onmessage(handler: (msg: { data: Uint8Array }) => void);
set onerror(handler: (msg: ErrorEvent) => void);
}
export class WebsocketDecompressAdapter implements WebsocketAdapter {
export class WebsocketDecompressAdapter implements WebSocketAdapter {
get protocol(): string {
return this.#ws.protocol;
}
set onclose(handler: (ev: CloseEvent) => void) {
this.#ws.onclose = handler;
}
@@ -35,16 +23,17 @@ export class WebsocketDecompressAdapter implements WebsocketAdapter {
#ws: WebSocket;
async #decompress(buffer: Uint8Array): Promise<Uint8Array> {
async #decompress(buffer: Uint8Array<ArrayBuffer>): Promise<Uint8Array> {
const tag = buffer[0];
const data = buffer.subarray(1);
switch (tag) {
case 0:
return data;
case 1:
throw new Error(
'Brotli Compression not supported. Please use gzip or none compression in withCompression method on DbConnection.'
);
// Some runtimes support brotli, but it's not yet defined in `lib.dom.d.ts`.
// We assert runtime support in `DbConnectionBuilder.withCompression`, so
// this cast is safe.
return await decompress(data, 'brotli' as CompressionFormat);
case 2:
return await decompress(data, 'gzip');
default:
@@ -54,7 +43,7 @@ export class WebsocketDecompressAdapter implements WebsocketAdapter {
}
}
send(msg: Uint8Array): void {
send(msg: Uint8Array<ArrayBuffer>): void {
this.#ws.send(msg);
}
@@ -63,68 +52,12 @@ export class WebsocketDecompressAdapter implements WebsocketAdapter {
}
constructor(ws: WebSocket) {
ws.binaryType = 'arraybuffer';
this.#ws = ws;
}
static async createWebSocketFn({
url,
nameOrAddress,
wsProtocol,
authToken,
compression,
lightMode,
confirmedReads,
}: {
url: URL;
wsProtocol: string | string[];
nameOrAddress: string;
authToken?: string;
compression: 'gzip' | 'none';
lightMode: boolean;
confirmedReads?: boolean;
}): Promise<WebsocketDecompressAdapter> {
const headers = new Headers();
const WS = await resolveWS();
// We swap our original token to a shorter-lived token
// to avoid sending the original via query params.
let temporaryAuthToken: string | undefined = undefined;
if (authToken) {
headers.set('Authorization', `Bearer ${authToken}`);
const tokenUrl = new URL('v1/identity/websocket-token', url);
tokenUrl.protocol = url.protocol === 'wss:' ? 'https:' : 'http:';
const response = await fetch(tokenUrl, { method: 'POST', headers });
if (response.ok) {
const { token } = await response.json();
temporaryAuthToken = token;
} else {
return Promise.reject(
new Error(`Failed to verify token: ${response.statusText}`)
);
}
}
const databaseUrl = new URL(`v1/database/${nameOrAddress}/subscribe`, url);
if (temporaryAuthToken) {
databaseUrl.searchParams.set('token', temporaryAuthToken);
}
databaseUrl.searchParams.set(
'compression',
compression === 'gzip' ? 'Gzip' : 'None'
);
if (lightMode) {
databaseUrl.searchParams.set('light', 'true');
}
if (confirmedReads !== undefined) {
databaseUrl.searchParams.set('confirmed', confirmedReads.toString());
}
const ws = new WS(databaseUrl.toString(), wsProtocol);
return new WebsocketDecompressAdapter(ws);
static async openWebSocket(
args: WebSocketArgs
): Promise<WebsocketDecompressAdapter> {
return new this(await openWebSocket(args));
}
}
@@ -1,17 +1,17 @@
import BinaryReader from '../lib/binary_reader.ts';
import BinaryWriter from '../lib/binary_writer.ts';
import { ClientMessage, ServerMessage } from './client_api/types';
import type { WebsocketAdapter } from './websocket_decompress_adapter';
import type { WebSocketAdapter, WebSocketFactory } from './ws';
import { PREFERRED_WS_PROTOCOLS, V3_WS_PROTOCOL } from './websocket_protocols';
import {
decodeClientMessagesV3,
encodeServerMessagesV3,
} from './websocket_v3_frames.ts';
class WebsocketTestAdapter implements WebsocketAdapter {
class WebsocketTestAdapter implements WebSocketAdapter {
protocol: string = '';
messageQueue: Uint8Array[];
messageQueue: Uint8Array<ArrayBuffer>[];
outgoingMessages: ClientMessage[];
closed: boolean;
supportedProtocols: string[];
@@ -41,7 +41,7 @@ class WebsocketTestAdapter implements WebsocketAdapter {
set onerror(_handler: (msg: ErrorEvent) => void) {}
send(message: Uint8Array): void {
send(message: Uint8Array<ArrayBuffer>): void {
const rawMessage = message.slice();
const outgoingMessages =
this.protocol === V3_WS_PROTOCOL
@@ -85,28 +85,16 @@ class WebsocketTestAdapter implements WebsocketAdapter {
this.#onmessage({ data: outboundData });
}
async createWebSocketFn(_args: {
url: URL;
wsProtocol: string | string[];
nameOrAddress: string;
authToken?: string;
compression: 'gzip' | 'none';
lightMode: boolean;
confirmedReads?: boolean;
}): Promise<WebsocketTestAdapter> {
const requestedProtocols = Array.isArray(_args.wsProtocol)
? _args.wsProtocol
: [_args.wsProtocol];
const negotiatedProtocol = requestedProtocols.find(protocol =>
openWebSocket: WebSocketFactory = async ({ wsProtocol }) => {
const negotiatedProtocol = wsProtocol.find(protocol =>
this.supportedProtocols.includes(protocol)
);
if (!negotiatedProtocol) {
return Promise.reject(new Error('No compatible websocket protocol'));
throw new Error('No compatible websocket protocol');
}
this.protocol = negotiatedProtocol;
return this;
}
};
}
export type { WebsocketTestAdapter };
export default WebsocketTestAdapter;
@@ -28,9 +28,9 @@ function ensureMessageCount(
function concatenateMessagesV3(
writer: BinaryWriter,
messages: readonly Uint8Array[],
messages: readonly Uint8Array<ArrayBuffer>[],
messageCount: number = messages.length
): Uint8Array {
): Uint8Array<ArrayBuffer> {
ensureMessageCount(messages, messageCount);
writer.clear();
for (let i = 0; i < messageCount; i++) {
@@ -41,15 +41,15 @@ function concatenateMessagesV3(
function splitMessagesV3(
reader: BinaryReader,
data: Uint8Array,
data: Uint8Array<ArrayBuffer>,
deserialize: (reader: BinaryReader) => unknown
): Uint8Array[] {
): Uint8Array<ArrayBuffer>[] {
reader.reset(data);
if (reader.remaining === 0) {
throw new RangeError(EMPTY_V3_PAYLOAD_ERR);
}
const messages: Uint8Array[] = [];
const messages: Uint8Array<ArrayBuffer>[] = [];
while (reader.remaining > 0) {
const startOffset = reader.offset;
deserialize(reader);
@@ -60,7 +60,7 @@ function splitMessagesV3(
}
export function countClientMessagesForV3Frame(
messages: readonly Uint8Array[],
messages: readonly Uint8Array<ArrayBuffer>[],
maxFrameBytes: number
): number {
ensureMessages(messages);
@@ -86,13 +86,15 @@ export function countClientMessagesForV3Frame(
export function encodeClientMessagesV3(
writer: BinaryWriter,
messages: readonly Uint8Array[],
messages: readonly Uint8Array<ArrayBuffer>[],
messageCount: number = messages.length
): Uint8Array {
): Uint8Array<ArrayBuffer> {
return concatenateMessagesV3(writer, messages, messageCount);
}
export function decodeClientMessagesV3(data: Uint8Array): Uint8Array[] {
export function decodeClientMessagesV3(
data: Uint8Array<ArrayBuffer>
): Uint8Array<ArrayBuffer>[] {
return splitMessagesV3(new BinaryReader(data), data, reader =>
ClientMessage.deserialize(reader)
);
@@ -100,8 +102,8 @@ export function decodeClientMessagesV3(data: Uint8Array): Uint8Array[] {
export function encodeServerMessagesV3(
writer: BinaryWriter,
messages: readonly Uint8Array[]
): Uint8Array {
messages: readonly Uint8Array<ArrayBuffer>[]
): Uint8Array<ArrayBuffer> {
return concatenateMessagesV3(writer, messages);
}
+81 -3
View File
@@ -1,9 +1,9 @@
import { stdbLogger } from './logger';
export async function resolveWS(): Promise<typeof WebSocket> {
async function resolveWS(): Promise<typeof WebSocket> {
// Browser or Node >= 22 (or any env that exposes global WebSocket)
if (typeof (globalThis as any).WebSocket !== 'undefined') {
return (globalThis as any).WebSocket as typeof WebSocket;
if (typeof WebSocket !== 'undefined') {
return WebSocket;
}
// Node without a global WebSocket: lazily load undici's polyfill.
@@ -25,3 +25,81 @@ export async function resolveWS(): Promise<typeof WebSocket> {
throw err;
}
}
export interface WebSocketAdapter {
readonly protocol: string;
send(msg: Uint8Array<ArrayBuffer>): void;
close(): void;
set onclose(handler: (ev: CloseEvent) => void);
set onopen(handler: () => void);
set onmessage(handler: (msg: { data: Uint8Array }) => void);
set onerror(handler: (msg: ErrorEvent) => void);
}
export interface WebSocketArgs {
url: URL;
wsProtocol: string[];
nameOrAddress: string;
authToken?: string;
compression: 'gzip' | 'brotli' | 'none';
lightMode: boolean;
confirmedReads?: boolean;
}
export type WebSocketFactory = (
args: WebSocketArgs
) => Promise<WebSocketAdapter>;
/**
* Open a WebSocket to the database specified by the given `WebSocketArgs`.
* @returns a WebSocket with `binaryType` set to `arraybuffer`.
*/
export async function openWebSocket({
url,
nameOrAddress,
wsProtocol,
authToken,
compression,
lightMode,
confirmedReads,
}: WebSocketArgs): Promise<WebSocket> {
const headers = new Headers();
const WS = await resolveWS();
// We swap our original token to a shorter-lived token
// to avoid sending the original via query params.
let temporaryAuthToken: string | undefined;
if (authToken) {
headers.set('Authorization', `Bearer ${authToken}`);
const tokenUrl = new URL('v1/identity/websocket-token', url);
tokenUrl.protocol = url.protocol === 'wss:' ? 'https:' : 'http:';
const response = await fetch(tokenUrl, { method: 'POST', headers });
if (response.ok) {
const { token } = await response.json();
temporaryAuthToken = token;
} else {
throw new Error(`Failed to verify token: ${response.statusText}`);
}
}
const databaseUrl = new URL(`v1/database/${nameOrAddress}/subscribe`, url);
if (temporaryAuthToken) {
databaseUrl.searchParams.set('token', temporaryAuthToken);
}
databaseUrl.searchParams.set(
'compression',
{ gzip: 'Gzip', brotli: 'Brotli', none: 'None' }[compression] ?? 'None'
);
if (lightMode) {
databaseUrl.searchParams.set('light', 'true');
}
if (confirmedReads !== undefined) {
databaseUrl.searchParams.set('confirmed', confirmedReads.toString());
}
const ws = new WS(databaseUrl.toString(), wsProtocol);
ws.binaryType = 'arraybuffer';
return ws;
}
@@ -4,7 +4,7 @@
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo",
"target": "ES2020",
"useDefineForClassFields": true,
"lib": ["ESNext", "DOM", "DOM.Iterable"],
"lib": ["ESNext", "DOM", "DOM.Iterable", "DOM.AsyncIterable"],
"module": "ESNext",
"skipLibCheck": true,
@@ -174,7 +174,7 @@ describe('DbConnection', () => {
const client = DbConnection.builder()
.withUri('ws://127.0.0.1:1234')
.withDatabaseName('db')
.withWSFn(wsAdapter.createWebSocketFn.bind(wsAdapter) as any)
.withWSFn(wsAdapter.openWebSocket)
.onConnect(() => {
called = true;
onConnectPromise.resolve();
@@ -201,7 +201,7 @@ describe('DbConnection', () => {
const client = DbConnection.builder()
.withUri('ws://127.0.0.1:1234')
.withDatabaseName('db')
.withWSFn(wsAdapter.createWebSocketFn.bind(wsAdapter) as any)
.withWSFn(wsAdapter.openWebSocket)
.build();
await client['wsPromise'];
@@ -231,7 +231,7 @@ describe('DbConnection', () => {
const client = DbConnection.builder()
.withUri('ws://127.0.0.1:1234')
.withDatabaseName('db')
.withWSFn(wsAdapter.createWebSocketFn.bind(wsAdapter) as any)
.withWSFn(wsAdapter.openWebSocket)
.build();
await client['wsPromise'];
@@ -259,7 +259,7 @@ describe('DbConnection', () => {
const client = DbConnection.builder()
.withUri('ws://127.0.0.1:1234')
.withDatabaseName('db')
.withWSFn(wsAdapter.createWebSocketFn.bind(wsAdapter) as any)
.withWSFn(wsAdapter.openWebSocket)
.onDisconnect(() => {
onDisconnectPromise.resolve();
})
@@ -285,7 +285,7 @@ describe('DbConnection', () => {
const client = DbConnection.builder()
.withUri('ws://127.0.0.1:1234')
.withDatabaseName('db')
.withWSFn(wsAdapter.createWebSocketFn.bind(wsAdapter) as any)
.withWSFn(wsAdapter.openWebSocket)
.build();
await client['wsPromise'];
@@ -327,7 +327,7 @@ describe('DbConnection', () => {
const client = DbConnection.builder()
.withUri('ws://127.0.0.1:1234')
.withDatabaseName('db')
.withWSFn(wsAdapter.createWebSocketFn.bind(wsAdapter) as any)
.withWSFn(wsAdapter.openWebSocket)
.onConnect(() => {
onConnectPromise.resolve();
})
@@ -393,7 +393,7 @@ describe('DbConnection', () => {
const client = DbConnection.builder()
.withUri('ws://127.0.0.1:1234')
.withDatabaseName('db')
.withWSFn(wsAdapter.createWebSocketFn.bind(wsAdapter) as any)
.withWSFn(wsAdapter.openWebSocket)
.onConnect(() => {
onConnectPromise.resolve();
})
@@ -438,7 +438,7 @@ describe('DbConnection', () => {
const client = DbConnection.builder()
.withUri('ws://127.0.0.1:1234')
.withDatabaseName('db')
.withWSFn(wsAdapter.createWebSocketFn.bind(wsAdapter) as any)
.withWSFn(wsAdapter.openWebSocket)
.onConnect(() => {
onConnectPromise.resolve();
})
@@ -715,7 +715,7 @@ describe('DbConnection', () => {
const client = DbConnection.builder()
.withUri('ws://127.0.0.1:1234')
.withDatabaseName('db')
.withWSFn(wsAdapter.createWebSocketFn.bind(wsAdapter) as any)
.withWSFn(wsAdapter.openWebSocket)
.onConnect(() => {})
.build();
@@ -806,7 +806,7 @@ describe('DbConnection', () => {
const client = DbConnection.builder()
.withUri('ws://127.0.0.1:1234')
.withDatabaseName('db')
.withWSFn(wsAdapter.createWebSocketFn.bind(wsAdapter) as any)
.withWSFn(wsAdapter.openWebSocket)
.build();
await client['wsPromise'];
wsAdapter.acceptConnection();