Rewrite to use more modular stream transformers for testability

Also fixes the actual resource hog, which is that the ws_actor never
terminated because all receive errors were ignored.
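
For illustration only (not part of the diff below), a minimal sketch of that failure mode, using stand-in types rather than the actual websocket stream: when receive errors are merely logged and skipped, the loop can only exit once the stream yields `None`, so a transport that keeps producing errors keeps the actor alive indefinitely.

use futures::{executor::block_on, stream, StreamExt};

fn main() {
    block_on(async {
        // Stand-in for the websocket receive half.
        let mut ws = stream::iter(vec![Ok::<&str, &str>("msg"), Err("connection reset")]);
        loop {
            match ws.next().await {
                Some(Ok(msg)) => println!("handle {msg}"),
                Some(Err(e)) => {
                    // Old behaviour: log the error and `continue`, never terminating on errors.
                    // New behaviour (this commit): treat the receive error as fatal.
                    eprintln!("websocket receive error: {e}");
                    break;
                }
                None => break,
            }
        }
    });
}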
Kim Altintop
2025-06-30 08:58:10 +02:00
parent 7e6df498fb
commit 4e75fc2eb6
3 changed files with 321 additions and 200 deletions
Generated
+1
@@ -5333,6 +5333,7 @@ name = "spacetimedb-client-api"
version = "1.2.0"
dependencies = [
"anyhow",
"async-stream",
"async-trait",
"axum",
"axum-extra",
+1
@@ -48,6 +48,7 @@ uuid.workspace = true
jsonwebtoken.workspace = true
scopeguard.workspace = true
serde_with.workspace = true
async-stream.workspace = true
[target.'cfg(not(target_env = "msvc"))'.dependencies]
jemalloc_pprof.workspace = true
+319 -200
@@ -1,8 +1,10 @@
use std::future::poll_fn;
use std::mem;
use std::pin::{pin, Pin};
use std::pin::pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use async_stream::stream;
use axum::extract::{Path, Query, State};
use axum::response::IntoResponse;
use axum::Extension;
@@ -10,10 +12,10 @@ use axum_extra::TypedHeader;
use bytes::Bytes;
use bytestring::ByteString;
use derive_more::From;
use futures::future::MaybeDone;
use futures::stream::SplitSink;
use futures::{Future, FutureExt, SinkExt, StreamExt};
use futures::future::FusedFuture as _;
use futures::{pin_mut, FutureExt, Sink, SinkExt, Stream, StreamExt};
use http::{HeaderValue, StatusCode};
use prometheus::IntGauge;
use scopeguard::ScopeGuard;
use serde::Deserialize;
use spacetimedb::client::messages::{
@@ -21,7 +23,7 @@ use spacetimedb::client::messages::{
};
use spacetimedb::client::{
ClientActorId, ClientConfig, ClientConnection, DataMessage, MessageExecutionError, MessageHandleError,
MeteredDeque, MeteredReceiver, Protocol,
MeteredReceiver, Protocol,
};
use spacetimedb::execution_context::WorkloadType;
use spacetimedb::host::module_host::ClientConnectedError;
@@ -33,6 +35,8 @@ use spacetimedb_client_api_messages::websocket::{self as ws_api, Compression};
use spacetimedb_lib::connection_id::{ConnectionId, ConnectionIdForUrl};
use std::time::Instant;
use tokio::sync::mpsc;
use tokio::time::timeout;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_tungstenite::tungstenite::Utf8Bytes;
use crate::auth::SpacetimeAuth;
@@ -189,6 +193,41 @@ where
const LIVELINESS_TIMEOUT: Duration = Duration::from_secs(60);
#[derive(Clone)]
struct ActorState {
pub client_id: ClientActorId,
pub database: Identity,
closed: Arc<AtomicBool>,
got_pong: Arc<AtomicBool>,
}
impl ActorState {
fn new(database: Identity, client_id: ClientActorId) -> Self {
Self {
database,
client_id,
closed: Arc::new(AtomicBool::new(false)),
got_pong: Arc::new(AtomicBool::new(true)),
}
}
fn closed(&self) -> bool {
self.closed.load(Ordering::Relaxed)
}
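/// Marks the state closed, returning `true` if it was already closed.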
fn close(&self) -> bool {
self.closed.swap(true, Ordering::Relaxed)
}
fn set_ponged(&self) {
self.got_pong.store(true, Ordering::Relaxed);
}
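/// Clears the pong flag, returning `true` if a pong was recorded since the last reset.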
fn reset_ponged(&self) -> bool {
self.got_pong.swap(false, Ordering::Relaxed)
}
}
async fn ws_client_actor(client: ClientConnection, ws: WebSocketStream, sendrx: MeteredReceiver<SerializableMessage>) {
// ensure that even if this task gets cancelled, we always cleanup the connection
let mut client = scopeguard::guard(client, |client| {
@@ -200,244 +239,318 @@ async fn ws_client_actor(client: ClientConnection, ws: WebSocketStream, sendrx:
ScopeGuard::into_inner(client).disconnect().await;
}
async fn make_progress<Fut: Future>(fut: &mut Pin<&mut MaybeDone<Fut>>) {
if let MaybeDone::Gone = **fut {
// nothing to do
} else {
fut.await
}
}
async fn ws_client_actor_inner(
client: &mut ClientConnection,
ws: WebSocketStream,
sendrx: MeteredReceiver<SerializableMessage>,
) {
let database = client.module.info().database_identity;
let client_closed_metric = WORKER_METRICS.ws_clients_closed_connection.with_label_values(&database);
let incoming_queue_length = WORKER_METRICS.total_incoming_queue_length.with_label_values(&database);
let state = ActorState::new(database, client.id);
let mut liveness_check_interval = tokio::time::interval(LIVELINESS_TIMEOUT);
let mut got_pong = true;
let addr = client.module.info().database_identity;
// Build a queue of incoming messages to handle, to be processed one at a time,
// in the order they're received.
//
// N.B. if you're refactoring this code: you must ensure the handle_queue is dropped before
// client.disconnect() is called. Otherwise, we can be left with a stale future that's never
// awaited, which can lead to bugs like:
// https://rust-lang.github.io/wg-async/vision/submitted_stories/status_quo/aws_engineer/solving_a_deadlock.html
//
// NOTE: never let this go unpolled while you're awaiting something; otherwise, it's possible
// to deadlock or delay for a long time. see usage of `also_poll()` in the branches of the
// `select!` for examples of how to do this.
//
// TODO: do we want this to have a fixed capacity? or should it be unbounded
let mut message_queue = MeteredDeque::<(DataMessage, Instant)>::new(
WORKER_METRICS.total_incoming_queue_length.with_label_values(&addr),
);
// Holds the future processing the current client message,
// the output of that future,
// or nothing at all (`MaybeDone::Gone`).
let mut current_message = pin!(MaybeDone::Gone);
// If true, we sent the client a close frame, and are waiting for a reply.
let mut closed = false;
let (ws_send, mut ws_recv) = ws.split();
// Channel for [`UnorderedWsMessage`]s.
let (unordered_tx, unordered_rx) = mpsc::unbounded_channel();
tokio::spawn(ws_send_loop(addr, client.config, ws_send, sendrx, unordered_rx));
// Channel for submitting work to the [`ws_eval_handler`].
//
// Note that we buffer client messages unboundedly, so that we don't delay
// subscription updates in the `select!` loop while we're waiting for an
// evaluation result.
// Being able to observe the backlog (via `incoming_queue_length`) is useful
// to identify performance issues.
// Yet, we may instead consider spawning a task for the receive end and not
// buffering at all, in order to apply backpressure to the client.
let (eval_tx, eval_rx) = mpsc::unbounded_channel();
let mut eval_rx = scopeguard::guard(UnboundedReceiverStream::new(eval_rx), |stream| {
incoming_queue_length.sub(stream.into_inner().len() as _);
});
// Split websocket into send and receive halves.
let (ws_send, ws_recv) = ws.split();
// Make a stream that reads from the socket and yields [`ClientMessage`]s.
let recv_loop = pin!(ws_recv_loop(state.clone(), ws_recv));
let recv_handler = ws_client_message_handler(state.clone(), client_closed_metric, recv_loop);
// Stream that consumes from `eval_tx` and evaluates the tasks.
// Yields `Result<(), MessageHandleError>`.
let eval_handler = ws_eval_handler(client.clone(), &mut *eval_rx);
// Sink that sends subscription updates and reducer results from `sendrx`,
// as well as [`UnorderedWsMessage`]s to the socket.
let send_loop = ws_send_loop(state.clone(), client.config, ws_send, sendrx, unordered_rx).fuse();
pin_mut!(recv_handler);
pin_mut!(eval_handler);
pin_mut!(send_loop);
loop {
enum Item {
Message(ClientMessage),
HandleResult(Result<(), MessageHandleError>),
}
if let MaybeDone::Gone = *current_message {
if let Some((message, timer)) = message_queue.pop_front() {
let client = client.clone();
let fut = async move { client.handle_message(message, timer).await };
current_message.set(MaybeDone::Future(fut));
}
}
let message = tokio::select! {
// NOTE: all of the futures for these branches **must** be cancel safe. do not
// change this if you don't know what that means.
// If we have a result from handling a past message to report,
// grab it to handle in the next `match`.
Some(res) = async {
make_progress(&mut current_message).await;
current_message.as_mut().take_output()
} => {
Item::HandleResult(res)
}
// If we've received an incoming message,
// grab it to handle in the next `match`.
message = ws_recv.next() => match message {
// Drop incoming messages if we already sent a close frame.
// We're not supposed to send more data, so there's no point
// processing it.
Some(Ok(_)) if closed => {
continue;
},
Some(Ok(m)) => {
Item::Message(ClientMessage::from_message(m))
},
Some(Err(error)) => {
log::warn!("Websocket receive error: {}", error);
continue;
}
// the client sent us a close frame
None => {
break
},
},
// Update the module host of the client connection if it was
// updated (hotswapped).
//
// If `closed` is true, we already sent a close frame and will exit
// the loop once the client acknowledges (`ws_recv.next()` returns
// `None`).
//
// If the module exited, we'll send a close frame and wait for an
// acknowledgement from the client.
// The branch is disabled if `closed == true` to avoid sending
// another close frame.
res = client.watch_module_host(), if !closed => {
tokio::select! {
// Get the next client message and submit it for evaluation.
res = recv_handler.next() => {
match res {
Ok(()) => {}
// If the module has exited, close the websocket.
Err(NoSuchModule) => {
let close = CloseFrame {
code: CloseCode::Away,
reason: "module exited".into()
};
// If the sender is already gone,
// we won't be sending close, and not receive an ack,
// so exit the loop here.
if unordered_tx.send(close.into()).is_err() {
Some(task) => {
log::trace!("received new task");
if eval_tx.send(task).is_err() {
log::trace!("eval_tx already closed");
break;
};
closed = true;
incoming_queue_length.inc();
},
None => {
log::trace!("recv handler exhausted");
break;
}
}
continue;
}
// If it's time to send a ping...
_ = liveness_check_interval.tick() => {
// If we received a pong at some point, send a fresh ping.
if mem::take(&mut got_pong) {
// If the sender is already gone,
// we'll time out the connection eventually.
let _ = unordered_tx.send(UnorderedWsMessage::Ping(Bytes::new()));
continue;
} else {
// the client never responded to our ping; drop them without trying to send them a Close
log::warn!("client {} timed out", client.id);
break;
}
}
};
// Handle the incoming message we grabbed in the previous `select!`.
// TODO: Data flow appears to not require `enum Item` or this distinct `match`,
// since `Item::HandleResult` comes from exactly one `select!` branch,
// and `Item::Message` comes from exactly one distinct `select!` branch.
// Consider merging this `match` with the previous `select!`.
match message {
Item::Message(ClientMessage::Message(message)) => {
let timer = Instant::now();
message_queue.push_back((message, timer))
}
Item::HandleResult(res) => {
if let Err(e) = res {
},
// Get the next evaluation result and handle errors.
Some(result) = eval_handler.next() => {
log::trace!("received task result");
incoming_queue_length.dec();
if let Err(e) = result {
if let MessageHandleError::Execution(err) = e {
log::error!("{err:#}");
// Ignoring send errors is apparently fine in this case.
let _ = unordered_tx.send(err.into());
continue;
}
log::debug!("Client caused error on text message: {}", e);
log::debug!("Client caused error: {e}");
let close = CloseFrame {
code: CloseCode::Error,
reason: format!("{e:#}").into(),
reason: format!("{e:#}").into()
};
// If the sender is already gone,
// we won't be sending close, and not receive an ack,
// so exit the loop here.
if unordered_tx.send(close.into()).is_err() {
log::trace!("unordered_tx already closed");
break;
}
closed = true;
}
},
// Poll the send loop until it's done.
_ = &mut send_loop, if !send_loop.is_terminated() => {
log::trace!("send loop terminated");
},
// Update the client's module host if it was hotswapped,
// or close the session if the module exited.
//
// Branch is disabled if we already sent a close frame.
res = client.watch_module_host(), if !state.closed() => {
if let Err(NoSuchModule) = res {
let close = CloseFrame {
code: CloseCode::Away,
reason: "module exited".into()
};
// If the sender is already gone,
// we won't be sending close, and not receive an ack,
// so exit the loop here.
if unordered_tx.send(close.into()).is_err() {
log::trace!("unordered_tx already closed");
break;
};
}
},
// Send ping or time out the client.
//
// Branch is disabled if we already sent a close frame.
_ = liveness_check_interval.tick(), if !state.closed() => {
let was_ponged = state.reset_ponged();
if was_ponged {
// If the sender is already gone,
// we expect to receive an error on the receiver stream,
// but we can just as well exit here.
if unordered_tx.send(UnorderedWsMessage::Ping(Bytes::new())).is_err() {
log::trace!("unordered_tx already closed");
break;
};
} else {
log::warn!("client {} timed out", client.id);
break;
}
}
Item::Message(ClientMessage::Ping(_message)) => {
log::trace!("Received ping from client {}", client.id);
// No need to explicitly respond with a `Pong`, as tungstenite handles this automatically.
// See [https://github.com/snapview/tokio-tungstenite/issues/88].
}
Item::Message(ClientMessage::Pong(_message)) => {
log::trace!("Received heartbeat from client {}", client.id);
got_pong = true;
}
Item::Message(ClientMessage::Close(close_frame)) => {
// This happens in 2 cases:
//
// a) We sent a Close frame and this is the ack.
// b) This is the client telling us they want to close.
//
// In either case, after the remaining messages in the queue
// are drained, `ws_recv.next()` will return `None` and we'll
// exit the loop.
//
// NOTE: No need to send a close frame, it is queued
// automatically by tungstenite.
// We need to continue polling the websocket, however,
// to have the close frame sent.
log::trace!("Close frame {:?}", close_frame);
if !closed {
// This is the client telling us they want to close.
WORKER_METRICS
.ws_clients_closed_connection
.with_label_values(&addr)
.inc();
else => break,
}
}
log::info!("Client connection ended: {}", client.id);
}
/// Stream that consumes a stream of [`WsMessage`]s and yields [`ClientMessage`]s.
///
/// Terminates if:
///
/// - the input stream is exhausted
/// - the input stream yields an error
///
/// If `state.closed`, continues to poll the input stream in order for the
/// websocket close handshake to complete. Any messages received while in this
/// state are dropped.
fn ws_recv_loop(
state: ActorState,
mut ws: impl Stream<Item = Result<WsMessage, WsError>> + Unpin,
) -> impl Stream<Item = ClientMessage> {
stream! {
loop {
let Some(res) = async {
if state.closed() {
log::trace!("await next client message with timeout");
match timeout(Duration::from_millis(150), ws.next()).await {
Err(_) => {
log::warn!("timeout waiting for client close");
None
},
Ok(item) => item
}
} else {
log::trace!("await next client message without timeout");
ws.next().await
}
closed = true;
}.await else {
log::trace!("recv stream exhausted");
break;
};
match res {
Ok(m) => {
if !state.closed() {
yield ClientMessage::from_message(m);
}
// If closed, keep polling until either:
//
// - the client sends a close frame (`ws` returns `None`)
// - or `ws` yields an error
log::trace!("message received while already closed");
}
// None of the error cases can be meaningfully recovered from
// (and some can't even occur on the `ws` stream).
// Exit here but spell out an exhaustive match
// in order to bring any future library changes to our attention.
Err(e) => match e {
e @ (WsError::ConnectionClosed
| WsError::AlreadyClosed
| WsError::Io(_)
| WsError::Tls(_)
| WsError::Capacity(_)
| WsError::Protocol(_)
| WsError::WriteBufferFull(_)
| WsError::Utf8
| WsError::AttackAttempt
| WsError::Url(_)
| WsError::Http(_)
| WsError::HttpFormat(_)) => {
log::warn!("Websocket receive error: {e}");
break;
}
},
}
}
}
log::debug!("Client connection ended");
}
/// Stream that consumes [`ClientMessage`]s and yields [`DataMessage`]s for
/// evaluation.
///
/// Calls `state.set_ponged()` if and when the input yields a pong message.
/// Calls `state.close()` if and when the input yields a close frame,
/// i.e. the client initiated a close handshake, which we track using the
/// `client_closed_metric`.
///
/// Terminates when the input stream terminates.
fn ws_client_message_handler(
state: ActorState,
client_closed_metric: IntGauge,
mut messages: impl Stream<Item = ClientMessage> + Unpin,
) -> impl Stream<Item = (DataMessage, Instant)> {
stream! {
while let Some(message) = messages.next().await {
match message {
ClientMessage::Message(message) => {
log::trace!("Received client message");
yield (message, Instant::now());
},
ClientMessage::Ping(_bytes) => {
log::trace!("Received ping from client {}", state.client_id);
},
ClientMessage::Pong(_bytes) => {
log::trace!("Received pong from client {}", state.client_id);
state.set_ponged();
},
ClientMessage::Close(close_frame) => {
log::trace!("Received Close frame from client {}: {:?}", state.client_id, close_frame);
let was_closed = state.close();
// This is the client telling us they want to close.
if !was_closed {
client_closed_metric.inc();
}
}
}
}
log::trace!("client message handler done");
}
}
/// Stream that consumes [`DataMessage`]s, evaluates them, and yields the result.
///
/// Terminates when the input stream terminates.
fn ws_eval_handler(
client: ClientConnection,
mut messages: impl Stream<Item = (DataMessage, Instant)> + Unpin,
) -> impl Stream<Item = Result<(), MessageHandleError>> {
stream! {
while let Some((message, timer)) = messages.next().await {
let result = client.handle_message(message, timer).await;
yield result;
}
}
}
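A note on the testability rationale from the commit message: because each of the handlers above is just a function from an input `Stream` to an output `Stream`, it can be driven in a unit test from an in-memory stream, with no socket or client connection involved. The following stand-in transformer (`lengths` is hypothetical, not part of this crate) sketches that pattern using the same `async_stream` approach:

use async_stream::stream;
use futures::{executor::block_on, Stream, StreamExt};

/// Hypothetical stand-in with the same shape as the handlers above.
fn lengths(mut messages: impl Stream<Item = String> + Unpin) -> impl Stream<Item = usize> {
    stream! {
        while let Some(msg) = messages.next().await {
            yield msg.len();
        }
    }
}

fn main() {
    // Drive the transformer from an in-memory stream, as a unit test would.
    let input = futures::stream::iter(vec!["ping".to_string(), "pong!".to_string()]);
    let out: Vec<usize> = block_on(lengths(input).collect());
    assert_eq!(out, vec![4, 5]);
}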
/// Outgoing messages that don't need to be ordered wrt subscription updates.
#[derive(From)]
enum UnorderedWsMessage {
/// Server-initiated close.
Close(CloseFrame),
/// Server-initiated ping.
Ping(Bytes),
/// Error calling a reducer.
///
/// The error indicates that the reducer was **not** called,
/// and can thus be unordered wrt subscription updates.
Error(MessageExecutionError),
}
/// Sink that sends outgoing messages to the `ws` sink.
///
/// Consumes `messages`, which yields subscription updates and reducer call
/// results. Note that [`SerializableMessage`]s require serialization and
/// potentially compression, which can be costly.
/// Also consumes `unordered`, which yields [`UnorderedWsMessage`]s.
///
/// Terminates if:
///
/// - `unordered` is closed
/// - an error occurs sending to the `ws` sink
///
/// If an [`UnorderedWsMessage::Close`] is encountered, a close frame is sent
/// to the `ws` sink, and `state.close()` is called. When this happens,
/// `messages` will no longer be polled (no data can be sent after a close
/// frame anyways), so `messages.close()` will be called.
///
/// Keeps polling `unordered` if `state.closed()`, but discards all data.
/// This is so `ws_client_actor_inner` keeps polling the receive end of the
/// socket until the close handshake completes -- it would otherwise exit early
/// when sending to `unordered` fails.
async fn ws_send_loop(
database_identity: Identity,
state: ActorState,
config: ClientConfig,
mut ws: SplitSink<WebSocketStream, WsMessage>,
mut ws: impl Sink<WsMessage, Error = WsError> + Unpin,
mut messages: MeteredReceiver<SerializableMessage>,
mut unordered: mpsc::UnboundedReceiver<UnorderedWsMessage>,
) {
let mut messages_buf = Vec::with_capacity(32);
let mut serialize_buf = SerializeBuffer::new(config);
// If true, we already sent a close frame.
//
// RFC 6455, Section 5.5.1:
//
// > The application MUST NOT send any more data frames after sending a
// > Close frame.
let mut closed = false;
loop {
tokio::select! {
// `biased` towards the unordered queue,
@@ -448,34 +561,37 @@ async fn ws_send_loop(
// We shall not send more data after a close frame,
// but keep polling `unordered` so that `ws_client_actor` keeps
// waiting for an acknowledgement from the client,
// event if it spuriously initiates another close itself.
if closed {
// even if it spuriously initiates another close itself.
if state.closed() {
continue;
}
match msg {
UnorderedWsMessage::Close(close_frame) => {
log::trace!("sending close frame");
if let Err(e) = ws.send(WsMessage::Close(Some(close_frame))).await {
log::warn!("error sending close frame: {e:#}");
break;
}
closed = true;
state.close();
// We won't be polling `messages` anymore,
// so let senders know.
messages.close();
},
UnorderedWsMessage::Ping(bytes) => {
log::trace!("sending ping");
let _ = ws
.feed(WsMessage::Ping(bytes))
.await
.inspect_err(|e| log::warn!("error sending ping: {e:#}"));
},
UnorderedWsMessage::Error(err) => {
log::trace!("sending error result");
let (msg_alloc, res) = send_message(
&mut ws,
&database_identity,
&state.database,
config,
serialize_buf,
None,
&mut ws,
err
).await;
serialize_buf = msg_alloc;
@@ -488,20 +604,22 @@ async fn ws_send_loop(
}
},
Some(n) = messages.recv_many(&mut messages_buf, 32).map(|n| (n != 0).then_some(n)), if !closed => {
Some(n) = messages.recv_many(&mut messages_buf, 32).map(|n| (n != 0).then_some(n)), if !state.closed() => {
log::trace!("sending {n} outgoing messages");
for msg in messages_buf.drain(..n) {
let (msg_alloc, res) = send_message(
&mut ws,
&database_identity,
&state.database,
config,
serialize_buf,
msg.workload().zip(msg.num_rows()),
&mut ws,
msg
).await;
serialize_buf = msg_alloc;
if let Err(e) = res {
log::warn!("websocket send error: {e}");
messages.close();
break;
}
}
@@ -519,11 +637,11 @@ async fn ws_send_loop(
/// Serialize and potentially compress `message`, and feed it to the `ws` sink.
async fn send_message(
ws: &mut SplitSink<WebSocketStream, WsMessage>,
database_identity: &Identity,
config: ClientConfig,
serialize_buf: SerializeBuffer,
metrics_metadata: Option<(WorkloadType, usize)>,
ws: &mut (impl Sink<WsMessage, Error = WsError> + Unpin),
message: impl ToProtocol<Encoded = SwitchedServerMessage> + Send + 'static,
) -> (SerializeBuffer, Result<(), WsError>) {
let (workload, num_rows) = metrics_metadata.unzip();
@@ -565,6 +683,7 @@ enum ClientMessage {
Pong(Bytes),
Close(Option<CloseFrame>),
}
impl ClientMessage {
fn from_message(msg: WsMessage) -> Self {
match msg {