From 4e75fc2eb675913da200fd73565e4eb6d3ee4f12 Mon Sep 17 00:00:00 2001
From: Kim Altintop
Date: Mon, 30 Jun 2025 08:58:10 +0200
Subject: [PATCH] Rewrite to use more modular stream transformers for testability

Also fixes the actual resource hog, which is that the ws_actor never
terminated because all receive errors were ignored.
---
 Cargo.lock                                |   1 +
 crates/client-api/Cargo.toml              |   1 +
 crates/client-api/src/routes/subscribe.rs | 519 +++++++++++++---------
 3 files changed, 321 insertions(+), 200 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index c9bd4e0598..a773d1310b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5333,6 +5333,7 @@ name = "spacetimedb-client-api"
 version = "1.2.0"
 dependencies = [
  "anyhow",
+ "async-stream",
  "async-trait",
  "axum",
  "axum-extra",
diff --git a/crates/client-api/Cargo.toml b/crates/client-api/Cargo.toml
index bd30f31da8..2005226eb6 100644
--- a/crates/client-api/Cargo.toml
+++ b/crates/client-api/Cargo.toml
@@ -48,6 +48,7 @@ uuid.workspace = true
 jsonwebtoken.workspace = true
 scopeguard.workspace = true
 serde_with.workspace = true
+async-stream.workspace = true
 
 [target.'cfg(not(target_env = "msvc"))'.dependencies]
 jemalloc_pprof.workspace = true
diff --git a/crates/client-api/src/routes/subscribe.rs b/crates/client-api/src/routes/subscribe.rs
index 432110f2ef..9521804387 100644
--- a/crates/client-api/src/routes/subscribe.rs
+++ b/crates/client-api/src/routes/subscribe.rs
@@ -1,8 +1,10 @@
 use std::future::poll_fn;
-use std::mem;
-use std::pin::{pin, Pin};
+use std::pin::pin;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
 use std::time::Duration;
 
+use async_stream::stream;
 use axum::extract::{Path, Query, State};
 use axum::response::IntoResponse;
 use axum::Extension;
@@ -10,10 +12,10 @@ use axum_extra::TypedHeader;
 use bytes::Bytes;
 use bytestring::ByteString;
 use derive_more::From;
-use futures::future::MaybeDone;
-use futures::stream::SplitSink;
-use futures::{Future, FutureExt, SinkExt, StreamExt};
+use futures::future::FusedFuture as _;
+use futures::{pin_mut, FutureExt, Sink, SinkExt, Stream, StreamExt};
 use http::{HeaderValue, StatusCode};
+use prometheus::IntGauge;
 use scopeguard::ScopeGuard;
 use serde::Deserialize;
 use spacetimedb::client::messages::{
@@ -21,7 +23,7 @@ use spacetimedb::client::messages::{
 };
 use spacetimedb::client::{
     ClientActorId, ClientConfig, ClientConnection, DataMessage, MessageExecutionError, MessageHandleError,
-    MeteredDeque, MeteredReceiver, Protocol,
+    MeteredReceiver, Protocol,
 };
 use spacetimedb::execution_context::WorkloadType;
 use spacetimedb::host::module_host::ClientConnectedError;
@@ -33,6 +35,8 @@ use spacetimedb_client_api_messages::websocket::{self as ws_api, Compression};
 use spacetimedb_lib::connection_id::{ConnectionId, ConnectionIdForUrl};
 use std::time::Instant;
 use tokio::sync::mpsc;
+use tokio::time::timeout;
+use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_tungstenite::tungstenite::Utf8Bytes;
 
 use crate::auth::SpacetimeAuth;
@@ -189,6 +193,41 @@ where
 
 const LIVELINESS_TIMEOUT: Duration = Duration::from_secs(60);
 
+#[derive(Clone)]
+struct ActorState {
+    pub client_id: ClientActorId,
+    pub database: Identity,
+    closed: Arc<AtomicBool>,
+    got_pong: Arc<AtomicBool>,
+}
+
+impl ActorState {
+    fn new(database: Identity, client_id: ClientActorId) -> Self {
+        Self {
+            database,
+            client_id,
+            closed: Arc::new(AtomicBool::new(false)),
+            got_pong: Arc::new(AtomicBool::new(true)),
+        }
+    }
+
+    fn closed(&self) -> bool {
+        self.closed.load(Ordering::Relaxed)
+    }
+
+    fn close(&self) -> bool {
+        self.closed.swap(true, Ordering::Relaxed)
+    }
+
+    fn set_ponged(&self) {
+        self.got_pong.store(true, Ordering::Relaxed);
+    }
+
+    fn reset_ponged(&self) -> bool {
+        self.got_pong.swap(false, Ordering::Relaxed)
+    }
+}
+
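The `close()` and `reset_ponged()` methods above lean on `AtomicBool::swap` returning the previous value, so a caller both updates the flag and learns whether it was the first to do so. A standalone illustration of that pattern (not part of the patch, shown here only for clarity):

    use std::sync::atomic::{AtomicBool, Ordering};

    #[test]
    fn swap_reports_the_prior_value() {
        let closed = AtomicBool::new(false);
        // The first close observes `false`: this caller initiated the close
        // and may do one-time work such as bumping a metric.
        assert!(!closed.swap(true, Ordering::Relaxed));
        // Every later close observes `true` and can skip that work.
        assert!(closed.swap(true, Ordering::Relaxed));
    }

The same trick drives the liveness check below: `reset_ponged()` reports whether a pong arrived since the previous tick while also arming the flag for the next interval.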
 async fn ws_client_actor(client: ClientConnection, ws: WebSocketStream, sendrx: MeteredReceiver<SerializableMessage>) {
     // ensure that even if this task gets cancelled, we always cleanup the connection
     let mut client = scopeguard::guard(client, |client| {
@@ -200,244 +239,318 @@ async fn ws_client_actor(client: ClientConnection, ws: WebSocketStream, sendrx:
     ScopeGuard::into_inner(client).disconnect().await;
 }
 
-async fn make_progress<Fut: Future>(fut: &mut Pin<&mut MaybeDone<Fut>>) {
-    if let MaybeDone::Gone = **fut {
-        // nothing to do
-    } else {
-        fut.await
-    }
-}
-
 async fn ws_client_actor_inner(
     client: &mut ClientConnection,
     ws: WebSocketStream,
     sendrx: MeteredReceiver<SerializableMessage>,
 ) {
+    let database = client.module.info().database_identity;
+
+    let client_closed_metric = WORKER_METRICS.ws_clients_closed_connection.with_label_values(&database);
+    let incoming_queue_length = WORKER_METRICS.total_incoming_queue_length.with_label_values(&database);
+
+    let state = ActorState::new(database, client.id);
+
     let mut liveness_check_interval = tokio::time::interval(LIVELINESS_TIMEOUT);
-    let mut got_pong = true;
-
-    let addr = client.module.info().database_identity;
-
-    // Build a queue of incoming messages to handle, to be processed one at a time,
-    // in the order they're received.
-    //
-    // N.B. if you're refactoring this code: you must ensure the handle_queue is dropped before
-    // client.disconnect() is called. Otherwise, we can be left with a stale future that's never
-    // awaited, which can lead to bugs like:
-    // https://rust-lang.github.io/wg-async/vision/submitted_stories/status_quo/aws_engineer/solving_a_deadlock.html
-    //
-    // NOTE: never let this go unpolled while you're awaiting something; otherwise, it's possible
-    // to deadlock or delay for a long time. see usage of `also_poll()` in the branches of the
-    // `select!` for examples of how to do this.
-    //
-    // TODO: do we want this to have a fixed capacity? or should it be unbounded
-    let mut message_queue = MeteredDeque::<(DataMessage, Instant)>::new(
-        WORKER_METRICS.total_incoming_queue_length.with_label_values(&addr),
-    );
-    // Holds the future processing the current client message,
-    // the output of that future,
-    // or nothing at all (`MaybeDone::Gone`).
-    let mut current_message = pin!(MaybeDone::Gone);
-
-    // If true, we sent the client a close frame, and are waiting for a reply.
-    let mut closed = false;
-
-    let (ws_send, mut ws_recv) = ws.split();
+    // Channel for [`UnorderedWsMessage`]s.
     let (unordered_tx, unordered_rx) = mpsc::unbounded_channel();
-    tokio::spawn(ws_send_loop(addr, client.config, ws_send, sendrx, unordered_rx));
+    // Channel for submitting work to the [`ws_eval_handler`].
+    //
+    // Note that we buffer client messages unboundedly, so that we don't delay
+    // subscription updates in the `select!` loop while we're waiting for an
+    // evaluation result.
+    // Being able to observe the backlog (via `incoming_queue_length`) is useful
+    // to identify performance issues.
+    // Yet, we may consider spawning a task for the receive end instead, and not
+    // buffering at all, in order to apply backpressure to the client.
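The alternative mentioned in the comment above would trade the observable backlog for backpressure: a bounded channel whose `send().await` suspends the receiving task once the queue is full, so a slow evaluator eventually stops reading from the socket instead of buffering without limit. A self-contained sketch of that behaviour (not part of the patch; the capacity and the `Vec<u8>` stand-in for `DataMessage` are illustrative only):

    use std::time::Instant;
    use tokio::sync::mpsc;

    #[tokio::main]
    async fn main() {
        // Illustrative capacity; the real value would need tuning.
        let (eval_tx, mut eval_rx) = mpsc::channel::<(Vec<u8>, Instant)>(8);

        // Receive side: `send().await` parks this task once 8 evaluations are
        // in flight, which in turn stops pulling more data off the socket.
        let producer = tokio::spawn(async move {
            for i in 0..32u8 {
                if eval_tx.send((vec![i], Instant::now())).await.is_err() {
                    break; // evaluator is gone
                }
            }
        });

        // Evaluation side: drain the queue at its own pace.
        while let Some((_msg, _enqueued_at)) = eval_rx.recv().await {
            tokio::task::yield_now().await;
        }
        let _ = producer.await;
    }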
+    let (eval_tx, eval_rx) = mpsc::unbounded_channel();
+    let mut eval_rx = scopeguard::guard(UnboundedReceiverStream::new(eval_rx), |stream| {
+        incoming_queue_length.sub(stream.into_inner().len() as _);
+    });
+
+    // Split websocket into send and receive halves.
+    let (ws_send, ws_recv) = ws.split();
+    // Make a stream that reads from the socket and yields [`ClientMessage`]s.
+    let recv_loop = pin!(ws_recv_loop(state.clone(), ws_recv));
+    let recv_handler = ws_client_message_handler(state.clone(), client_closed_metric, recv_loop);
+    // Stream that consumes from `eval_rx` and evaluates the tasks.
+    // Yields `Result<(), MessageHandleError>`.
+    let eval_handler = ws_eval_handler(client.clone(), &mut *eval_rx);
+    // Sink that sends subscription updates and reducer results from `sendrx`,
+    // as well as [`UnorderedWsMessage`]s to the socket.
+    let send_loop = ws_send_loop(state.clone(), client.config, ws_send, sendrx, unordered_rx).fuse();
+
+    pin_mut!(recv_handler);
+    pin_mut!(eval_handler);
+    pin_mut!(send_loop);
 
     loop {
-        enum Item {
-            Message(ClientMessage),
-            HandleResult(Result<(), MessageHandleError>),
-        }
-        if let MaybeDone::Gone = *current_message {
-            if let Some((message, timer)) = message_queue.pop_front() {
-                let client = client.clone();
-                let fut = async move { client.handle_message(message, timer).await };
-                current_message.set(MaybeDone::Future(fut));
-            }
-        }
-        let message = tokio::select! {
-            // NOTE: all of the futures for these branches **must** be cancel safe. do not
-            // change this if you don't know what that means.
-
-            // If we have a result from handling a past message to report,
-            // grab it to handle in the next `match`.
-            Some(res) = async {
-                make_progress(&mut current_message).await;
-                current_message.as_mut().take_output()
-            } => {
-                Item::HandleResult(res)
-            }
-
-            // If we've received an incoming message,
-            // grab it to handle in the next `match`.
-            message = ws_recv.next() => match message {
-                // Drop incoming messages if we already sent a close frame.
-                // We're not supposed to send more data, so there's no point
-                // processing it.
-                Some(Ok(_)) if closed => {
-                    continue;
-                },
-                Some(Ok(m)) => {
-                    Item::Message(ClientMessage::from_message(m))
-                },
-                Some(Err(error)) => {
-                    log::warn!("Websocket receive error: {}", error);
-                    continue;
-                }
-                // the client sent us a close frame
-                None => {
-                    break
-                },
-            },
-
-            // Update the module host of the client connection if it was
-            // updated (hotswapped).
-            //
-            // If `closed` is true, we already sent a close frame and will exit
-            // the loop once the client acknowledges (`ws_recv.next()` returns
-            // `None`).
-            //
-            // If the module exited, we'll send a close frame and wait for an
-            // acknowledgement from the client.
-            // The branch is disabled if `closed == true` to avoid sending
-            // another close frame.
-            res = client.watch_module_host(), if !closed => {
+        tokio::select! {
+            // Get the next client message and submit it for evaluation.
+            res = recv_handler.next() => {
                 match res {
-                    Ok(()) => {}
-                    // If the module has exited, close the websocket.
-                    Err(NoSuchModule) => {
-                        let close = CloseFrame {
-                            code: CloseCode::Away,
-                            reason: "module exited".into()
-                        };
-                        // If the sender is already gone,
-                        // we won't be sending close, and not receive an ack,
-                        // so exit the loop here.
-                        if unordered_tx.send(close.into()).is_err() {
+                    Some(task) => {
+                        log::trace!("received new task");
+                        if eval_tx.send(task).is_err() {
+                            log::trace!("eval_tx already closed");
                             break;
                         };
-                        closed = true;
+                        incoming_queue_length.inc();
+                    },
+                    None => {
+                        log::trace!("recv handler exhausted");
+                        break;
                     }
                 }
-                continue;
-            }
-
-            // If it's time to send a ping...
-            _ = liveness_check_interval.tick() => {
-                // If we received a pong at some point, send a fresh ping.
-                if mem::take(&mut got_pong) {
-                    // If the sender is already gone,
-                    // we'll time out the connection eventually.
-                    let _ = unordered_tx.send(UnorderedWsMessage::Ping(Bytes::new()));
-                    continue;
-                } else {
-                    // the client never responded to our ping; drop them without trying to send them a Close
-                    log::warn!("client {} timed out", client.id);
-                    break;
-                }
-            }
-        };
-
-        // Handle the incoming message we grabbed in the previous `select!`.
-
-        // TODO: Data flow appears to not require `enum Item` or this distinct `match`,
-        // since `Item::HandleResult` comes from exactly one `select!` branch,
-        // and `Item::Message` comes from exactly one distinct `select!` branch.
-        // Consider merging this `match` with the previous `select!`.
-        match message {
-            Item::Message(ClientMessage::Message(message)) => {
-                let timer = Instant::now();
-                message_queue.push_back((message, timer))
-            }
-            Item::HandleResult(res) => {
-                if let Err(e) = res {
+            },
+            // Get the next evaluation result and handle errors.
+            Some(result) = eval_handler.next() => {
+                log::trace!("received task result");
+                incoming_queue_length.dec();
+                if let Err(e) = result {
                     if let MessageHandleError::Execution(err) = e {
                         log::error!("{err:#}");
-                        // Ignoring send errors is apparently fine in this case.
                         let _ = unordered_tx.send(err.into());
                         continue;
                     }
-                    log::debug!("Client caused error on text message: {}", e);
+                    log::debug!("Client caused error: {e}");
                     let close = CloseFrame {
                         code: CloseCode::Error,
-                        reason: format!("{e:#}").into(),
+                        reason: format!("{e:#}").into()
                     };
                     // If the sender is already gone,
                     // we won't be sending close, and not receive an ack,
                     // so exit the loop here.
                     if unordered_tx.send(close.into()).is_err() {
+                        log::trace!("unordered_tx already closed");
                         break;
                     }
-                    closed = true;
+                }
+            },
+            // Poll the send loop until it's done.
+            _ = &mut send_loop, if !send_loop.is_terminated() => {
+                log::trace!("send loop terminated");
+            },
+
+            // Update the client's module host if it was hotswapped,
+            // or close the session if the module exited.
+            //
+            // Branch is disabled if we already sent a close frame.
+            res = client.watch_module_host(), if !state.closed() => {
+                if let Err(NoSuchModule) = res {
+                    let close = CloseFrame {
+                        code: CloseCode::Away,
+                        reason: "module exited".into()
+                    };
+                    // If the sender is already gone,
+                    // we won't be sending close, and not receive an ack,
+                    // so exit the loop here.
+                    if unordered_tx.send(close.into()).is_err() {
+                        log::trace!("unordered_tx already closed");
+                        break;
+                    };
+                }
+            },
+
+            // Send ping or time out the client.
+            //
+            // Branch is disabled if we already sent a close frame.
+            _ = liveness_check_interval.tick(), if !state.closed() => {
+                let was_ponged = state.reset_ponged();
+                if was_ponged {
+                    // If the sender is already gone,
+                    // we expect to receive an error on the receiver stream,
+                    // but we can just as well exit here.
+                    if unordered_tx.send(UnorderedWsMessage::Ping(Bytes::new())).is_err() {
+                        log::trace!("unordered_tx already closed");
+                        break;
+                    };
+                } else {
+                    log::warn!("client {} timed out", client.id);
+                    break;
                 }
             }
-            Item::Message(ClientMessage::Ping(_message)) => {
-                log::trace!("Received ping from client {}", client.id);
-                // No need to explicitly respond with a `Pong`, as tungstenite handles this automatically.
-                // See [https://github.com/snapview/tokio-tungstenite/issues/88].
-            }
-            Item::Message(ClientMessage::Pong(_message)) => {
-                log::trace!("Received heartbeat from client {}", client.id);
-                got_pong = true;
-            }
-            Item::Message(ClientMessage::Close(close_frame)) => {
-                // This happens in 2 cases:
-                //
-                // a) We sent a Close frame and this is the ack.
-                // b) This is the client telling us they want to close.
-                //
-                // In either case, after the remaining messages in the queue
-                // are drained, `ws_recv.next()` will return `None` and we'll
-                // exit the loop.
-                //
-                // NOTE: No need to send a close frame, it is queued
-                // automatically by tungstenite.
-                // We need to continue polling the websocket, however,
-                // to have the close frame sent.
-                log::trace!("Close frame {:?}", close_frame);
-                if !closed {
-                    // This is the client telling us they want to close.
-                    WORKER_METRICS
-                        .ws_clients_closed_connection
-                        .with_label_values(&addr)
-                        .inc();
+
+            else => break,
             }
         }
-                closed = true;
-            }
-        }
-    }
-    log::debug!("Client connection ended");
-}
+    log::info!("Client connection ended: {}", client.id);
+}
+
+/// Stream that consumes a stream of [`WsMessage`]s and yields [`ClientMessage`]s.
+///
+/// Terminates if:
+///
+/// - the input stream is exhausted
+/// - the input stream yields an error
+///
+/// If `state.closed`, continues to poll the input stream in order for the
+/// websocket close handshake to complete. Any messages received while in this
+/// state are dropped.
+fn ws_recv_loop(
+    state: ActorState,
+    mut ws: impl Stream<Item = Result<WsMessage, WsError>> + Unpin,
+) -> impl Stream<Item = ClientMessage> {
+    stream! {
+        loop {
+            let Some(res) = async {
+                if state.closed() {
+                    log::trace!("await next client message with timeout");
+                    match timeout(Duration::from_millis(150), ws.next()).await {
+                        Err(_) => {
+                            log::warn!("timeout waiting for client close");
+                            None
+                        },
+                        Ok(item) => item
+                    }
+                } else {
+                    log::trace!("await next client message without timeout");
+                    ws.next().await
+                }
+            }.await else {
+                log::trace!("recv stream exhausted");
+                break;
+            };
+            match res {
+                Ok(m) => {
+                    if !state.closed() {
+                        yield ClientMessage::from_message(m);
+                    }
+                    // If closed, keep polling until either:
+                    //
+                    // - the client sends a close frame (`ws` returns `None`)
+                    // - or `ws` yields an error
+                    log::trace!("message received while already closed");
+                }
+                // None of the error cases can be meaningfully recovered from
+                // (and some can't even occur on the `ws` stream).
+                // Exit here but spell out an exhaustive match
+                // in order to bring any future library changes to our attention.
+                Err(e) => match e {
+                    e @ (WsError::ConnectionClosed
+                    | WsError::AlreadyClosed
+                    | WsError::Io(_)
+                    | WsError::Tls(_)
+                    | WsError::Capacity(_)
+                    | WsError::Protocol(_)
+                    | WsError::WriteBufferFull(_)
+                    | WsError::Utf8
+                    | WsError::AttackAttempt
+                    | WsError::Url(_)
+                    | WsError::Http(_)
+                    | WsError::HttpFormat(_)) => {
+                        log::warn!("Websocket receive error: {e}");
+                        break;
+                    }
+                },
+            }
+        }
+    }
+}
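`ws_recv_loop` is the first of several transformers built on the `async-stream` crate added in Cargo.toml above: `stream! { ... }` turns an async block containing `yield` expressions into an anonymous `impl Stream`, which is what makes the "stream in, stream out" shape possible without hand-writing a `Stream` impl. A minimal, self-contained illustration of the macro (not part of the patch):

    use async_stream::stream;
    use futures::{executor::block_on, Stream, StreamExt};

    // `yield` suspends the generator until the consumer polls for the next item.
    fn countdown(from: u32) -> impl Stream<Item = u32> {
        stream! {
            let mut n = from;
            while n > 0 {
                yield n;
                n -= 1;
            }
        }
    }

    fn main() {
        let collected: Vec<u32> = block_on(countdown(3).collect());
        assert_eq!(collected, vec![3, 2, 1]);
    }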
+
+/// Stream that consumes [`ClientMessage`]s and yields [`DataMessage`]s for
+/// evaluation.
+///
+/// Calls `state.set_ponged()` if and when the input yields a pong message.
+/// Calls `state.close()` if and when the input yields a close frame,
+/// i.e. the client initiated a close handshake, which we track using the
+/// `client_closed_metric`.
+///
+/// Terminates when the input stream terminates.
+fn ws_client_message_handler(
+    state: ActorState,
+    client_closed_metric: IntGauge,
+    mut messages: impl Stream<Item = ClientMessage> + Unpin,
+) -> impl Stream<Item = (DataMessage, Instant)> {
+    stream! {
+        while let Some(message) = messages.next().await {
+            match message {
+                ClientMessage::Message(message) => {
+                    log::trace!("Received client message");
+                    yield (message, Instant::now());
+                },
+                ClientMessage::Ping(_bytes) => {
+                    log::trace!("Received ping from client {}", state.client_id);
+                },
+                ClientMessage::Pong(_bytes) => {
+                    log::trace!("Received pong from client {}", state.client_id);
+                    state.set_ponged();
+                },
+                ClientMessage::Close(close_frame) => {
+                    log::trace!("Received Close frame from client {}: {:?}", state.client_id, close_frame);
+                    let was_closed = state.close();
+                    // This is the client telling us they want to close.
+                    if !was_closed {
+                        client_closed_metric.inc();
+                    }
+                }
+            }
+        }
+        log::trace!("client message handler done");
+    }
+}
+
+/// Stream that consumes [`DataMessage`]s, evaluates them, and yields the result.
+///
+/// Terminates when the input stream terminates.
+fn ws_eval_handler(
+    client: ClientConnection,
+    mut messages: impl Stream<Item = (DataMessage, Instant)> + Unpin,
+) -> impl Stream<Item = Result<(), MessageHandleError>> {
+    stream! {
+        while let Some((message, timer)) = messages.next().await {
+            let result = client.handle_message(message, timer).await;
+            yield result;
+        }
+    }
+}
 
 /// Outgoing messages that don't need to be ordered wrt subscription updates.
 #[derive(From)]
 enum UnorderedWsMessage {
+    /// Server-initiated close.
     Close(CloseFrame),
+    /// Server-initiated ping.
     Ping(Bytes),
+    /// Error calling a reducer.
+    ///
+    /// The error indicates that the reducer was **not** called,
+    /// and can thus be unordered wrt subscription updates.
     Error(MessageExecutionError),
 }
 
+/// Sink that sends outgoing messages to the `ws` sink.
+///
+/// Consumes `messages`, which yields subscription updates and reducer call
+/// results. Note that [`SerializableMessage`]s require serialization and
+/// potentially compression, which can be costly.
+/// Also consumes `unordered`, which yields [`UnorderedWsMessage`]s.
+///
+/// Terminates if:
+///
+/// - `unordered` is closed
+/// - an error occurs sending to the `ws` sink
+///
+/// If an [`UnorderedWsMessage::Close`] is encountered, a close frame is sent
+/// to the `ws` sink, and `state.close()` is called. When this happens,
+/// `messages` will no longer be polled (no data can be sent after a close
+/// frame anyway), so `messages.close()` will be called.
+///
+/// Keeps polling `unordered` if `state.closed()`, but discards all data.
+/// This is so `ws_client_actor_inner` keeps polling the receive end of the
+/// socket until the close handshake completes -- it would otherwise exit early
+/// when sending to `unordered` fails.
 async fn ws_send_loop(
-    database_identity: Identity,
+    state: ActorState,
     config: ClientConfig,
-    mut ws: SplitSink<WebSocketStream, WsMessage>,
+    mut ws: impl Sink<WsMessage, Error = WsError> + Unpin,
     mut messages: MeteredReceiver<SerializableMessage>,
     mut unordered: mpsc::UnboundedReceiver<UnorderedWsMessage>,
 ) {
     let mut messages_buf = Vec::with_capacity(32);
     let mut serialize_buf = SerializeBuffer::new(config);
 
-    // If true, we already sent a close frame.
-    //
-    // RFC 6455, Section 5.5.1:
-    //
-    // > The application MUST NOT send any more data frames after sending a
-    // > Close frame.
-    let mut closed = false;
-
     loop {
         tokio::select! {
             // `biased` towards the unordered queue,
@@ -448,34 +561,37 @@ async fn ws_send_loop(
                 // We shall not sent more data after a close frame,
                 // but keep polling `unordered` so that `ws_client_actor` keeps
                 // waiting for an acknowledgement from the client,
-                // event if it spuriously initiates another close itself.
-                if closed {
+                // even if it spuriously initiates another close itself.
+                if state.closed() {
                     continue;
                 }
                 match msg {
                     UnorderedWsMessage::Close(close_frame) => {
+                        log::trace!("sending close frame");
                         if let Err(e) = ws.send(WsMessage::Close(Some(close_frame))).await {
                             log::warn!("error sending close frame: {e:#}");
                             break;
                         }
-                        closed = true;
+                        state.close();
                         // We won't be polling `messages` anymore,
                         // so let senders know.
                         messages.close();
                     },
                     UnorderedWsMessage::Ping(bytes) => {
+                        log::trace!("sending ping");
                         let _ = ws
                             .feed(WsMessage::Ping(bytes))
                             .await
                             .inspect_err(|e| log::warn!("error sending ping: {e:#}"));
                     },
                     UnorderedWsMessage::Error(err) => {
+                        log::trace!("sending error result");
                         let (msg_alloc, res) = send_message(
-                            &mut ws,
-                            &database_identity,
+                            &state.database,
                             config,
                             serialize_buf,
                             None,
+                            &mut ws,
                             err
                         ).await;
                         serialize_buf = msg_alloc;
@@ -488,20 +604,22 @@ async fn ws_send_loop(
                 }
             },
 
-            Some(n) = messages.recv_many(&mut messages_buf, 32).map(|n| (n != 0).then_some(n)), if !closed => {
+            Some(n) = messages.recv_many(&mut messages_buf, 32).map(|n| (n != 0).then_some(n)), if !state.closed() => {
+                log::trace!("sending {n} outgoing messages");
                 for msg in messages_buf.drain(..n) {
                     let (msg_alloc, res) = send_message(
-                        &mut ws,
-                        &database_identity,
+                        &state.database,
                         config,
                         serialize_buf,
                         msg.workload().zip(msg.num_rows()),
+                        &mut ws,
                         msg
                     ).await;
                     serialize_buf = msg_alloc;
 
                     if let Err(e) = res {
                         log::warn!("websocket send error: {e}");
+                        messages.close();
                         break;
                     }
                 }
@@ -519,11 +637,11 @@ async fn ws_send_loop(
 
 /// Serialize and potentially compress `message`, and feed it to the `ws` sink.
 async fn send_message(
-    ws: &mut SplitSink<WebSocketStream, WsMessage>,
     database_identity: &Identity,
     config: ClientConfig,
     serialize_buf: SerializeBuffer,
     metrics_metadata: Option<(WorkloadType, usize)>,
+    ws: &mut (impl Sink<WsMessage, Error = WsError> + Unpin),
     message: impl ToProtocol + Send + 'static,
 ) -> (SerializeBuffer, Result<(), WsError>) {
     let (workload, num_rows) = metrics_metadata.unzip();
@@ -565,6 +683,7 @@ enum ClientMessage {
     Pong(Bytes),
     Close(Option<CloseFrame>),
 }
+
 impl ClientMessage {
     fn from_message(msg: WsMessage) -> Self {
         match msg {
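The testability claim in the commit message follows from this "stream in, stream out" shape: each transformer can be driven by an in-memory `futures::stream::iter` without opening a socket or standing up the full actor. A minimal sketch of such a test (not part of the patch; the toy `until_first_error` transformer and the module name are illustrative only, but the same pattern applies to `ws_recv_loop` and `ws_client_message_handler` given an `ActorState`):

    #[cfg(test)]
    mod transformer_shape_tests {
        use async_stream::stream;
        use futures::{executor::block_on, pin_mut, stream as input_stream, Stream, StreamExt};

        /// A toy transformer with the same shape as `ws_recv_loop`:
        /// forward items until the first receive error, then stop.
        fn until_first_error(
            input: impl Stream<Item = Result<u32, &'static str>>,
        ) -> impl Stream<Item = u32> {
            stream! {
                pin_mut!(input);
                while let Some(item) = input.next().await {
                    match item {
                        Ok(v) => yield v,
                        Err(e) => {
                            // A real transformer would log and let the actor shut down.
                            eprintln!("receive error: {e}");
                            break;
                        }
                    }
                }
            }
        }

        #[test]
        fn stops_at_first_error() {
            let scripted = input_stream::iter([Ok(1), Ok(2), Err("connection reset"), Ok(3)]);
            let seen: Vec<u32> = block_on(until_first_error(scripted).collect());
            assert_eq!(seen, vec![1, 2]);
        }
    }

Because the scripted input ends in an error, the test also exercises the termination path that the old actor silently ignored -- the "actual resource hog" named in the commit message.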