fix: connection lifecycle callbacks (#4935)

# Description of Changes

This pull request improves the handling of connection lifecycle events
in the Rust client SDK for SpacetimeDB, particularly distinguishing
between connection failures and disconnections. It introduces a new
`ConnectionLifecycle` state machine to track connection progress and
ensures that the correct callback (`on_connect_error` or
`on_disconnect`) is invoked based on the connection state.

**Changes**

* `ConnectionLifecycle` enum to track the connection state
(`Connecting`, `Connected`, `Ended`)
* Refactored error handling so that if a connection fails before
establishment, the `on_connect_error` callback is invoked; if the
connection fails after establishment, the `on_disconnect` callback is
invoked. See `end_connection`.
* Updated the places where disconnections are handled
(`advance_one_message_blocking`, `advance_one_message_async`, and
message processing) to use `end_connection`
* Improved handling of user-initiated disconnects during the connection
process to avoid reporting them as connection errors and to ensure
proper cleanup.

# API and ABI breaking changes

Possibly: if people relied on a failed connection attempt firing
`on_disconnect` rather than `on_connect_error`, this changes that behavior.

# Expected complexity level and risk

Maybe a 2? Seems pretty low risk but I'm still new to the codebase,
please double check.

This doesn't fix the websocket issues, that'll be for another day. I
noticed websocket.rs has some places it just drops and the error isn't
handled properly. We could technically surface that information and run
our callbacks with more specific error messages.

# Testing

I had an agent build and run loads of tests for this but didn't commit
those since it would have made the PR massive. I was planning on testing
locally though to see if I could trigger a connection failure at some
point, maybe via an invalid access token.
This commit is contained in:
Jeff Rooks
2026-05-04 12:39:01 -04:00
committed by GitHub
parent 8cd2936931
commit 4a20c81d0b
2 changed files with 95 additions and 53 deletions
@@ -147,9 +147,7 @@ impl DbConnectionBuilder {
}
```
Chain a call to `.on_connect_error(callback)` to your builder to register a callback to run when your connection fails.
A known bug in the SpacetimeDB Rust client SDK currently causes this callback never to be invoked. [`on_disconnect`](#callback-on_disconnect) callbacks are invoked instead.
Chain a call to `.on_connect_error(callback)` to your builder to register a callback to run when a connection attempt fails asynchronously. Errors which prevent `build` from creating the connection are returned by `build` instead.
#### Callback `on_disconnect`
@@ -162,7 +160,7 @@ impl DbConnectionBuilder {
}
```
Chain a call to `.on_disconnect(callback)` to your builder to register a callback to run when your `DbConnection` disconnects from the remote database, either as a result of a call to [`disconnect`](#method-disconnect) or due to an error.
Chain a call to `.on_disconnect(callback)` to your builder to register a callback to run when your established `DbConnection` disconnects from the remote database, either as a result of a call to [`disconnect`](#method-disconnect) or due to an error.
#### Method `with_token`
+93 -49
View File
@@ -137,18 +137,25 @@ impl<M: SpacetimeModule> DbContextImpl<M> {
fn process_message(&self, msg: ParsedMessage<M>) -> crate::Result<()> {
self.debug_log(|out| writeln!(out, "`process_message`: {msg:?}"));
match msg {
// Error: treat this as an erroneous disconnect.
ParsedMessage::Error(e) => {
let disconnect_ctx = self.make_event_ctx(Some(e.clone()));
self.invoke_disconnected(&disconnect_ctx);
Err(e)
}
// Error: route as a connection error if we never finished connecting,
// otherwise treat it as an erroneous disconnect.
ParsedMessage::Error(e) => Err(self.end_connection(Some(e))),
// Initial `IdentityToken` message:
// confirm that the received identity and connection ID are what we expect,
// store them,
// then invoke the on_connect callback.
// store them, then invoke the on_connect callback.
ParsedMessage::IdentityToken(identity, token, conn_id) => {
let on_connect = {
let mut inner = self.inner.lock().unwrap();
match inner.connection_lifecycle {
ConnectionLifecycle::Connecting => {
inner.connection_lifecycle = ConnectionLifecycle::Connected;
inner.on_connect.take()
}
ConnectionLifecycle::Connected => None,
ConnectionLifecycle::Ended => return Ok(()),
}
};
{
// Don't hold the `self.identity` lock while running callbacks.
// Callbacks can (will) call [`DbContext::identity`], which acquires that lock,
@@ -170,8 +177,7 @@ impl<M: SpacetimeModule> DbContextImpl<M> {
}
*conn_id_store = Some(conn_id);
}
let mut inner = self.inner.lock().unwrap();
if let Some(on_connect) = inner.on_connect.take() {
if let Some(on_connect) = on_connect {
let ctx = <M::DbConnection as DbConnection>::new(self.clone());
on_connect(&ctx, identity, &token);
}
@@ -306,23 +312,47 @@ impl<M: SpacetimeModule> DbContextImpl<M> {
applied_diff.invoke_row_callbacks(&row_event_ctx, &mut inner.db_callbacks);
}
/// Invoke the on-disconnect callback, and mark [`Self::is_active`] false.
fn invoke_disconnected(&self, ctx: &M::ErrorContext) {
/// Mark the connection lifecycle as ended, route the terminal event to the
/// appropriate connection callback, and mark [`Self::is_active`] false.
///
/// Returns the terminal error that should be returned from `advance_*` methods.
fn end_connection(&self, callback_error: Option<crate::Error>) -> crate::Error {
let mut inner = self.inner.lock().unwrap();
// When we disconnect, we first call the on_disconnect method,
// then we call the `on_error` method for all subscriptions.
// We don't change the client cache at all.
let return_error = callback_error.clone().unwrap_or(crate::Error::Disconnected);
let lifecycle = inner.connection_lifecycle;
if lifecycle == ConnectionLifecycle::Ended {
return return_error;
}
inner.connection_lifecycle = ConnectionLifecycle::Ended;
// Set `send_chan` to `None`, since `Self::is_active` checks that.
*self.send_chan.lock().unwrap() = None;
// Grap the `on_disconnect` callback and invoke it.
if let Some(disconnect_callback) = inner.on_disconnect.take() {
disconnect_callback(ctx, ctx.event().clone());
}
match lifecycle {
ConnectionLifecycle::Connecting => {
let callback_error = callback_error.unwrap_or_else(|| crate::Error::FailedToConnect {
source: InternalError::new("Connection closed before receiving the initial connection message"),
});
let ctx: M::ErrorContext = self.make_event_ctx(Some(callback_error.clone()));
if let Some(connect_error_callback) = inner.on_connect_error.take() {
connect_error_callback(&ctx, callback_error.clone());
}
callback_error
}
ConnectionLifecycle::Connected => {
let ctx: M::ErrorContext = self.make_event_ctx(callback_error.clone());
if let Some(disconnect_callback) = inner.on_disconnect.take() {
disconnect_callback(&ctx, callback_error.clone());
}
// Call the `on_disconnect` method for all subscriptions.
inner.subscriptions.on_disconnect(ctx);
// Call the `on_disconnect` method for all subscriptions.
inner.subscriptions.on_disconnect(&ctx);
return_error
}
ConnectionLifecycle::Ended => return_error,
}
}
fn make_event_ctx<E, Ctx: AbstractEventContext<Module = M, Event = E>>(&self, event: E) -> Ctx {
@@ -447,10 +477,19 @@ impl<M: SpacetimeModule> DbContextImpl<M> {
// Disconnect: close the connection.
PendingMutation::Disconnect => {
{
let mut inner = self.inner.lock().unwrap();
if inner.connection_lifecycle == ConnectionLifecycle::Connecting {
// If the user cancels before the initial connection finishes,
// don't report that as a connection error.
inner.connection_lifecycle = ConnectionLifecycle::Ended;
}
}
// Set `send_chan` to `None`, since `Self::is_active` checks that.
// This will close the WebSocket loop in websocket.rs,
// sending a close frame to the server,
// eventually resulting in disconnect callbacks being called.
// eventually resulting in disconnect callbacks being called
// if the initial connection had completed.
*self.send_chan.lock().unwrap() = None;
}
@@ -540,11 +579,7 @@ impl<M: SpacetimeModule> DbContextImpl<M> {
// `Stream::poll_next`. No comment on whether this is a good mental
// model or not.
let res = match get_lock_sync(&self.recv).try_next() {
Ok(None) => {
let disconnect_ctx = self.make_event_ctx(None);
self.invoke_disconnected(&disconnect_ctx);
Err(crate::Error::Disconnected)
}
Ok(None) => Err(self.end_connection(None)),
Err(_) => Ok(false),
Ok(Some(msg)) => self.process_message(msg).map(|_| true),
};
@@ -599,11 +634,7 @@ impl<M: SpacetimeModule> DbContextImpl<M> {
pub fn advance_one_message_blocking(&self) -> crate::Result<()> {
match self.runtime.block_on(self.get_message()) {
Message::Local(pending) => self.apply_mutation(pending),
Message::Ws(None) => {
let disconnect_ctx = self.make_event_ctx(None);
self.invoke_disconnected(&disconnect_ctx);
Err(crate::Error::Disconnected)
}
Message::Ws(None) => Err(self.end_connection(None)),
Message::Ws(Some(msg)) => self.process_message(msg),
}
}
@@ -614,11 +645,7 @@ impl<M: SpacetimeModule> DbContextImpl<M> {
pub async fn advance_one_message_async(&self) -> crate::Result<()> {
match self.get_message().await {
Message::Local(pending) => self.apply_mutation(pending),
Message::Ws(None) => {
let disconnect_ctx = self.make_event_ctx(None);
self.invoke_disconnected(&disconnect_ctx);
Err(crate::Error::Disconnected)
}
Message::Ws(None) => Err(self.end_connection(None)),
Message::Ws(Some(msg)) => self.process_message(msg),
}
}
@@ -784,6 +811,16 @@ type OnConnectErrorCallback<M> = Box<dyn FnOnce(&<M as SpacetimeModule>::ErrorCo
type OnDisconnectCallback<M> =
Box<dyn FnOnce(&<M as SpacetimeModule>::ErrorContext, Option<crate::Error>) + Send + 'static>;
/// Tracks how far a connection has progressed, so that a terminal event can be
/// routed to the right callback.
///
/// A connection context starts in `Connecting`, moves to `Connected` when the
/// initial `IdentityToken` message is processed, and finishes in `Ended`.
/// `end_connection` reads this state: ending from `Connecting` invokes
/// `on_connect_error`, ending from `Connected` invokes `on_disconnect`, and
/// a connection that is already `Ended` invokes no callback.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ConnectionLifecycle {
/// Waiting for the server's initial connection message.
///
/// A user-requested disconnect in this state moves straight to `Ended`,
/// so the cancellation is not reported as a connection error.
Connecting,
/// The server has sent the initial connection message.
Connected,
/// The connection has already reached a terminal lifecycle state.
Ended,
}
/// All the stuff in a [`DbContextImpl`] which can safely be locked while invoking callbacks.
pub(crate) struct DbContextImplInner<M: SpacetimeModule> {
/// `Some` if not within the context of an outer runtime. The `Runtime` must
@@ -796,9 +833,8 @@ pub(crate) struct DbContextImplInner<M: SpacetimeModule> {
reducer_callbacks: ReducerCallbacks<M>,
pub(crate) subscriptions: SubscriptionManager<M>,
connection_lifecycle: ConnectionLifecycle,
on_connect: Option<OnConnectCallback<M>>,
#[allow(unused)]
// TODO: Make use of this to handle `ParsedMessage::Error` before receiving `IdentityToken`.
on_connect_error: Option<OnConnectErrorCallback<M>>,
on_disconnect: Option<OnDisconnectCallback<M>>,
@@ -1040,9 +1076,10 @@ but you must call one of them, or else the connection will never progress.
/// If this method is not invoked, or `None` is supplied,
/// the SpacetimeDB host will generate a new anonymous `Identity`.
///
/// If the passed token is invalid or rejected by the host,
/// the connection will fail asynchrnonously.
// FIXME: currently this causes `disconnect` to be called rather than `on_connect_error`.
/// If the token is rejected before a connection context is created, [`Self::build`]
/// returns an error. If the host reports the rejection after the WebSocket is
/// established but before the initial connection message, [`Self::on_connect_error`]
/// is invoked.
pub fn with_token(mut self, token: Option<impl Into<String>>) -> Self {
self.token = token.map(|token| token.into());
self
@@ -1095,9 +1132,10 @@ but you must call one of them, or else the connection will never progress.
self
}
/// Register a callback to run when the connection is successfully initiated.
/// Register a callback to run when the connection is successfully established.
///
/// The callback will receive three arguments:
/// The connection is established after the initial connection message is
/// received from the host. The callback will receive three arguments:
/// - The `DbConnection` which has successfully connected.
/// - The `Identity` of the successful connection.
/// - The private access token which can be used to later re-authenticate as the same `Identity`.
@@ -1116,9 +1154,11 @@ Instead of registering multiple `on_connect` callbacks, register a single callba
self
}
/// Register a callback to run when the connection fails asynchronously,
/// e.g. due to invalid credentials.
// FIXME: currently never called; `on_disconnect` is called instead.
/// Register a callback to run when a connection attempt fails asynchronously.
///
/// This callback is invoked only before the initial connection message is
/// received from the host. Errors which prevent [`Self::build`] from creating
/// a connection are returned by [`Self::build`] instead.
pub fn on_connect_error(mut self, callback: impl FnOnce(&M::ErrorContext, crate::Error) + Send + 'static) -> Self {
if self.on_connect_error.is_some() {
panic!(
@@ -1132,8 +1172,11 @@ Instead of registering multiple `on_connect_error` callbacks, register a single
self
}
/// Register a callback to run when the connection is closed.
// FIXME: currently also called when the connection fails asynchronously, instead of `on_connect_error`.
/// Register a callback to run when an established connection is closed.
///
/// The connection is established after the initial connection message is
/// received from the host. Connection failures before that point invoke
/// [`Self::on_connect_error`] instead.
pub fn on_disconnect(
mut self,
callback: impl FnOnce(&M::ErrorContext, Option<crate::Error>) + Send + 'static,
@@ -1166,6 +1209,7 @@ fn build_db_ctx_inner<M: SpacetimeModule>(
reducer_callbacks: ReducerCallbacks::default(),
subscriptions: SubscriptionManager::default(),
connection_lifecycle: ConnectionLifecycle::Connecting,
on_connect: on_connect_cb,
on_connect_error: on_connect_error_cb,
on_disconnect: on_disconnect_cb,