From 4a20c81d0bc1be7b194498e1a80aa69d1b9585e3 Mon Sep 17 00:00:00 2001 From: Jeff Rooks Date: Mon, 4 May 2026 12:39:01 -0400 Subject: [PATCH] fix: connection lifecycle callbacks (#4935) # Description of Changes This pull request improves the handling of connection lifecycle events in the Rust client SDK for SpacetimeDB, particularly distinguishing between connection failures and disconnections. It introduces a new `ConnectionLifecycle` state machine to track connection progress, ensures that the correct callback (`on_connect_error` or `on_disconnect`) is invoked based on the connection state. **Changes** * `ConnectionLifecycle` enum to track the connection state (`Connecting`, `Connected`, `Ended`) * Refactored error handling so that if a connection fails before establishment, the `on_connect_error` callback is invoked; if the connection fails after establishment, the `on_disconnect` callback is invoked. See `end_connection`. * Updated where disconnections are handled (`advance_one_message_blocking`, `advance_one_message_async`, and message processing) to use `finish_connection` * Improved handling of user-initiated disconnects during the connection process to avoid reporting them as connection errors and to ensure proper cleanup. # API and ABI breaking changes I guess maybe if people relied on the `on_connect_error` to actually fire the `on_disconnect` then this changes that behavior. # Expected complexity level and risk Maybe a 2? Seems pretty low risk but I'm still new to the codebase, please double check. This doesn't fix the websocket issues, that'll be for another day. I noticed websocket.rs has some places it just drops and the error isn't handled properly. We could technically surface that information and run our callbacks with more specific error messages. # Testing I had an agent build and run loads of tests for this but didn't commit those since it would have made the PR massive. I was planning on testing locally though to see if I could trigger a connection failure at some point, maybe via an invalid access token. --- .../00600-clients/00500-rust-reference.md | 6 +- sdks/rust/src/db_connection.rs | 142 ++++++++++++------ 2 files changed, 95 insertions(+), 53 deletions(-) diff --git a/docs/docs/00200-core-concepts/00600-clients/00500-rust-reference.md b/docs/docs/00200-core-concepts/00600-clients/00500-rust-reference.md index 72a291784..ecd0a02c8 100644 --- a/docs/docs/00200-core-concepts/00600-clients/00500-rust-reference.md +++ b/docs/docs/00200-core-concepts/00600-clients/00500-rust-reference.md @@ -147,9 +147,7 @@ impl DbConnectionBuilder { } ``` -Chain a call to `.on_connect_error(callback)` to your builder to register a callback to run when your connection fails. - -A known bug in the SpacetimeDB Rust client SDK currently causes this callback never to be invoked. [`on_disconnect`](#callback-on_disconnect) callbacks are invoked instead. +Chain a call to `.on_connect_error(callback)` to your builder to register a callback to run when a connection attempt fails asynchronously. Errors which prevent `build` from creating the connection are returned by `build` instead. #### Callback `on_disconnect` @@ -162,7 +160,7 @@ impl DbConnectionBuilder { } ``` -Chain a call to `.on_disconnect(callback)` to your builder to register a callback to run when your `DbConnection` disconnects from the remote database, either as a result of a call to [`disconnect`](#method-disconnect) or due to an error. +Chain a call to `.on_disconnect(callback)` to your builder to register a callback to run when your established `DbConnection` disconnects from the remote database, either as a result of a call to [`disconnect`](#method-disconnect) or due to an error. #### Method `with_token` diff --git a/sdks/rust/src/db_connection.rs b/sdks/rust/src/db_connection.rs index 838894522..332aac1b3 100644 --- a/sdks/rust/src/db_connection.rs +++ b/sdks/rust/src/db_connection.rs @@ -137,18 +137,25 @@ impl DbContextImpl { fn process_message(&self, msg: ParsedMessage) -> crate::Result<()> { self.debug_log(|out| writeln!(out, "`process_message`: {msg:?}")); match msg { - // Error: treat this as an erroneous disconnect. - ParsedMessage::Error(e) => { - let disconnect_ctx = self.make_event_ctx(Some(e.clone())); - self.invoke_disconnected(&disconnect_ctx); - Err(e) - } + // Error: route as a connection error if we never finished connecting, + // otherwise treat it as an erroneous disconnect. + ParsedMessage::Error(e) => Err(self.end_connection(Some(e))), // Initial `IdentityToken` message: // confirm that the received identity and connection ID are what we expect, - // store them, - // then invoke the on_connect callback. + // store them, then invoke the on_connect callback. ParsedMessage::IdentityToken(identity, token, conn_id) => { + let on_connect = { + let mut inner = self.inner.lock().unwrap(); + match inner.connection_lifecycle { + ConnectionLifecycle::Connecting => { + inner.connection_lifecycle = ConnectionLifecycle::Connected; + inner.on_connect.take() + } + ConnectionLifecycle::Connected => None, + ConnectionLifecycle::Ended => return Ok(()), + } + }; { // Don't hold the `self.identity` lock while running callbacks. // Callbacks can (will) call [`DbContext::identity`], which acquires that lock, @@ -170,8 +177,7 @@ impl DbContextImpl { } *conn_id_store = Some(conn_id); } - let mut inner = self.inner.lock().unwrap(); - if let Some(on_connect) = inner.on_connect.take() { + if let Some(on_connect) = on_connect { let ctx = ::new(self.clone()); on_connect(&ctx, identity, &token); } @@ -306,23 +312,47 @@ impl DbContextImpl { applied_diff.invoke_row_callbacks(&row_event_ctx, &mut inner.db_callbacks); } - /// Invoke the on-disconnect callback, and mark [`Self::is_active`] false. - fn invoke_disconnected(&self, ctx: &M::ErrorContext) { + /// Mark the connection lifecycle as ended, route the terminal event to the + /// appropriate connection callback, and mark [`Self::is_active`] false. + /// + /// Returns the terminal error that should be returned from `advance_*` methods. + fn end_connection(&self, callback_error: Option) -> crate::Error { let mut inner = self.inner.lock().unwrap(); - // When we disconnect, we first call the on_disconnect method, - // then we call the `on_error` method for all subscriptions. - // We don't change the client cache at all. + let return_error = callback_error.clone().unwrap_or(crate::Error::Disconnected); + + let lifecycle = inner.connection_lifecycle; + if lifecycle == ConnectionLifecycle::Ended { + return return_error; + } + inner.connection_lifecycle = ConnectionLifecycle::Ended; // Set `send_chan` to `None`, since `Self::is_active` checks that. *self.send_chan.lock().unwrap() = None; - // Grap the `on_disconnect` callback and invoke it. - if let Some(disconnect_callback) = inner.on_disconnect.take() { - disconnect_callback(ctx, ctx.event().clone()); - } + match lifecycle { + ConnectionLifecycle::Connecting => { + let callback_error = callback_error.unwrap_or_else(|| crate::Error::FailedToConnect { + source: InternalError::new("Connection closed before receiving the initial connection message"), + }); + let ctx: M::ErrorContext = self.make_event_ctx(Some(callback_error.clone())); + if let Some(connect_error_callback) = inner.on_connect_error.take() { + connect_error_callback(&ctx, callback_error.clone()); + } + callback_error + } + ConnectionLifecycle::Connected => { + let ctx: M::ErrorContext = self.make_event_ctx(callback_error.clone()); + if let Some(disconnect_callback) = inner.on_disconnect.take() { + disconnect_callback(&ctx, callback_error.clone()); + } - // Call the `on_disconnect` method for all subscriptions. - inner.subscriptions.on_disconnect(ctx); + // Call the `on_disconnect` method for all subscriptions. + inner.subscriptions.on_disconnect(&ctx); + + return_error + } + ConnectionLifecycle::Ended => return_error, + } } fn make_event_ctx>(&self, event: E) -> Ctx { @@ -447,10 +477,19 @@ impl DbContextImpl { // Disconnect: close the connection. PendingMutation::Disconnect => { + { + let mut inner = self.inner.lock().unwrap(); + if inner.connection_lifecycle == ConnectionLifecycle::Connecting { + // If the user cancels before the initial connection finishes, + // don't report that as a connection error. + inner.connection_lifecycle = ConnectionLifecycle::Ended; + } + } // Set `send_chan` to `None`, since `Self::is_active` checks that. // This will close the WebSocket loop in websocket.rs, // sending a close frame to the server, - // eventually resulting in disconnect callbacks being called. + // eventually resulting in disconnect callbacks being called + // if the initial connection had completed. *self.send_chan.lock().unwrap() = None; } @@ -540,11 +579,7 @@ impl DbContextImpl { // `Stream::poll_next`. No comment on whether this is a good mental // model or not. let res = match get_lock_sync(&self.recv).try_next() { - Ok(None) => { - let disconnect_ctx = self.make_event_ctx(None); - self.invoke_disconnected(&disconnect_ctx); - Err(crate::Error::Disconnected) - } + Ok(None) => Err(self.end_connection(None)), Err(_) => Ok(false), Ok(Some(msg)) => self.process_message(msg).map(|_| true), }; @@ -599,11 +634,7 @@ impl DbContextImpl { pub fn advance_one_message_blocking(&self) -> crate::Result<()> { match self.runtime.block_on(self.get_message()) { Message::Local(pending) => self.apply_mutation(pending), - Message::Ws(None) => { - let disconnect_ctx = self.make_event_ctx(None); - self.invoke_disconnected(&disconnect_ctx); - Err(crate::Error::Disconnected) - } + Message::Ws(None) => Err(self.end_connection(None)), Message::Ws(Some(msg)) => self.process_message(msg), } } @@ -614,11 +645,7 @@ impl DbContextImpl { pub async fn advance_one_message_async(&self) -> crate::Result<()> { match self.get_message().await { Message::Local(pending) => self.apply_mutation(pending), - Message::Ws(None) => { - let disconnect_ctx = self.make_event_ctx(None); - self.invoke_disconnected(&disconnect_ctx); - Err(crate::Error::Disconnected) - } + Message::Ws(None) => Err(self.end_connection(None)), Message::Ws(Some(msg)) => self.process_message(msg), } } @@ -784,6 +811,16 @@ type OnConnectErrorCallback = Box::ErrorCo type OnDisconnectCallback = Box::ErrorContext, Option) + Send + 'static>; +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum ConnectionLifecycle { + /// Waiting for the server's initial connection message. + Connecting, + /// The server has sent the initial connection message. + Connected, + /// The connection has already reached a terminal lifecycle state. + Ended, +} + /// All the stuff in a [`DbContextImpl`] which can safely be locked while invoking callbacks. pub(crate) struct DbContextImplInner { /// `Some` if not within the context of an outer runtime. The `Runtime` must @@ -796,9 +833,8 @@ pub(crate) struct DbContextImplInner { reducer_callbacks: ReducerCallbacks, pub(crate) subscriptions: SubscriptionManager, + connection_lifecycle: ConnectionLifecycle, on_connect: Option>, - #[allow(unused)] - // TODO: Make use of this to handle `ParsedMessage::Error` before receiving `IdentityToken`. on_connect_error: Option>, on_disconnect: Option>, @@ -1040,9 +1076,10 @@ but you must call one of them, or else the connection will never progress. /// If this method is not invoked, or `None` is supplied, /// the SpacetimeDB host will generate a new anonymous `Identity`. /// - /// If the passed token is invalid or rejected by the host, - /// the connection will fail asynchrnonously. - // FIXME: currently this causes `disconnect` to be called rather than `on_connect_error`. + /// If the token is rejected before a connection context is created, [`Self::build`] + /// returns an error. If the host reports the rejection after the WebSocket is + /// established but before the initial connection message, [`Self::on_connect_error`] + /// is invoked. pub fn with_token(mut self, token: Option>) -> Self { self.token = token.map(|token| token.into()); self @@ -1095,9 +1132,10 @@ but you must call one of them, or else the connection will never progress. self } - /// Register a callback to run when the connection is successfully initiated. + /// Register a callback to run when the connection is successfully established. /// - /// The callback will receive three arguments: + /// The connection is established after the initial connection message is + /// received from the host. The callback will receive three arguments: /// - The `DbConnection` which has successfully connected. /// - The `Identity` of the successful connection. /// - The private access token which can be used to later re-authenticate as the same `Identity`. @@ -1116,9 +1154,11 @@ Instead of registering multiple `on_connect` callbacks, register a single callba self } - /// Register a callback to run when the connection fails asynchronously, - /// e.g. due to invalid credentials. - // FIXME: currently never called; `on_disconnect` is called instead. + /// Register a callback to run when a connection attempt fails asynchronously. + /// + /// This callback is invoked only before the initial connection message is + /// received from the host. Errors which prevent [`Self::build`] from creating + /// a connection are returned by [`Self::build`] instead. pub fn on_connect_error(mut self, callback: impl FnOnce(&M::ErrorContext, crate::Error) + Send + 'static) -> Self { if self.on_connect_error.is_some() { panic!( @@ -1132,8 +1172,11 @@ Instead of registering multiple `on_connect_error` callbacks, register a single self } - /// Register a callback to run when the connection is closed. - // FIXME: currently also called when the connection fails asynchronously, instead of `on_connect_error`. + /// Register a callback to run when an established connection is closed. + /// + /// The connection is established after the initial connection message is + /// received from the host. Connection failures before that point invoke + /// [`Self::on_connect_error`] instead. pub fn on_disconnect( mut self, callback: impl FnOnce(&M::ErrorContext, Option) + Send + 'static, @@ -1166,6 +1209,7 @@ fn build_db_ctx_inner( reducer_callbacks: ReducerCallbacks::default(), subscriptions: SubscriptionManager::default(), + connection_lifecycle: ConnectionLifecycle::Connecting, on_connect: on_connect_cb, on_connect_error: on_connect_error_cb, on_disconnect: on_disconnect_cb,