diff --git a/crates/client-api/src/routes/subscribe.rs b/crates/client-api/src/routes/subscribe.rs index 3add2056a9..71009a9bca 100644 --- a/crates/client-api/src/routes/subscribe.rs +++ b/crates/client-api/src/routes/subscribe.rs @@ -855,9 +855,14 @@ fn bytestring_to_utf8bytes(s: ByteString) -> Utf8Bytes { #[cfg(test)] mod tests { - use std::{future::Future, task::Poll}; + use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + }; - use future::FutureExt as _; + use anyhow::anyhow; + use future::{Either, FutureExt as _}; use futures::{sink, stream}; use pretty_assertions::assert_matches; use spacetimedb::client::ClientName; @@ -1090,7 +1095,204 @@ mod tests { assert!(messages_tx.is_closed()); } + #[tokio::test] + async fn send_loop_updates_idle_channel() { + let state = Arc::new(dummy_actor_state()); + let idle_deadline = Instant::now() + state.config.idle_timeout; + let (idle_tx, mut idle_rx) = watch::channel(idle_deadline); + let (messages_tx, messages_rx) = mpsc::channel(64); + let messages = MeteredReceiver::new(messages_rx); + let (unordered_tx, unordered_rx) = mpsc::unbounded_channel(); + + let send_loop = ws_send_loop( + state.clone(), + idle_tx, + ClientConfig::for_test(), + sink::drain(), + messages, + unordered_rx, + ); + pin_mut!(send_loop); + + let input = [ + Either::Left(UnorderedWsMessage::Ping(Bytes::new())), + Either::Left(UnorderedWsMessage::Error(MessageExecutionError { + reducer: None, + reducer_id: None, + caller_identity: Identity::ZERO, + caller_connection_id: None, + err: anyhow!("it did not work"), + })), + // TODO: This is the easiest to construct, + // but maybe we want other variants, too. + Either::Right(SerializableMessage::Identity(IdentityTokenMessage { + identity: Identity::ZERO, + token: "macaron".into(), + connection_id: ConnectionId::ZERO, + })), + Either::Left(UnorderedWsMessage::Close(CloseFrame { + code: CloseCode::Away, + reason: "bah!".into(), + })), + ]; + + let mut new_idle_deadline = idle_deadline; + for msg in input { + match msg { + Either::Left(unordered) => unordered_tx.send(unordered).unwrap(), + Either::Right(msg) => messages_tx.send(msg).await.unwrap(), + } + assert!(is_pending(&mut send_loop).await); + assert!(idle_rx.has_changed().unwrap()); + new_idle_deadline = *idle_rx.borrow_and_update(); + } + + assert!(new_idle_deadline > idle_deadline); + } + + #[tokio::test] + async fn send_loop_terminates_if_sink_cant_be_fed() { + let input = [ + Either::Left(UnorderedWsMessage::Close(CloseFrame { + code: CloseCode::Away, + reason: "bah!".into(), + })), + Either::Left(UnorderedWsMessage::Ping(Bytes::new())), + Either::Left(UnorderedWsMessage::Error(MessageExecutionError { + reducer: None, + reducer_id: None, + caller_identity: Identity::ZERO, + caller_connection_id: None, + err: anyhow!("it did not work"), + })), + // TODO: This is the easiest to construct, + // but maybe we want other variants, too. + Either::Right(SerializableMessage::Identity(IdentityTokenMessage { + identity: Identity::ZERO, + token: "macaron".into(), + connection_id: ConnectionId::ZERO, + })), + ]; + + for msg in input { + let state = Arc::new(dummy_actor_state()); + let (idle_tx, _idle_rx) = watch::channel(Instant::now() + state.config.idle_timeout); + let (messages_tx, messages_rx) = mpsc::channel(64); + let messages = MeteredReceiver::new(messages_rx); + let (unordered_tx, unordered_rx) = mpsc::unbounded_channel(); + + let send_loop = ws_send_loop( + state.clone(), + idle_tx, + ClientConfig::for_test(), + UnfeedableSink, + messages, + unordered_rx, + ); + pin_mut!(send_loop); + + match msg { + Either::Left(unordered) => unordered_tx.send(unordered).unwrap(), + Either::Right(msg) => messages_tx.send(msg).await.unwrap(), + } + send_loop.await; + } + } + + #[tokio::test] + async fn send_loop_terminates_if_sink_cant_be_flushed() { + let input = [ + Either::Left(UnorderedWsMessage::Close(CloseFrame { + code: CloseCode::Away, + reason: "bah!".into(), + })), + Either::Left(UnorderedWsMessage::Ping(Bytes::new())), + Either::Left(UnorderedWsMessage::Error(MessageExecutionError { + reducer: None, + reducer_id: None, + caller_identity: Identity::ZERO, + caller_connection_id: None, + err: anyhow!("it did not work"), + })), + // TODO: This is the easiest to construct, + // but maybe we want other variants, too. + Either::Right(SerializableMessage::Identity(IdentityTokenMessage { + identity: Identity::ZERO, + token: "macaron".into(), + connection_id: ConnectionId::ZERO, + })), + ]; + + for msg in input { + let state = Arc::new(dummy_actor_state()); + let (idle_tx, _idle_rx) = watch::channel(Instant::now() + state.config.idle_timeout); + let (messages_tx, messages_rx) = mpsc::channel(64); + let messages = MeteredReceiver::new(messages_rx); + let (unordered_tx, unordered_rx) = mpsc::unbounded_channel(); + + let send_loop = ws_send_loop( + state.clone(), + idle_tx, + ClientConfig::for_test(), + UnflushableSink, + messages, + unordered_rx, + ); + pin_mut!(send_loop); + + match msg { + Either::Left(unordered) => unordered_tx.send(unordered).unwrap(), + Either::Right(msg) => messages_tx.send(msg).await.unwrap(), + } + send_loop.await; + } + } + async fn is_pending(fut: &mut (impl Future + Unpin)) -> bool { poll_fn(|cx| Poll::Ready(fut.poll_unpin(cx).is_pending())).await } + + struct UnfeedableSink; + + impl Sink for UnfeedableSink { + type Error = &'static str; + + fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, _: T) -> Result<(), Self::Error> { + Err("don't feed the sink") + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + } + + struct UnflushableSink; + + impl Sink for UnflushableSink { + type Error = &'static str; + + fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, _: T) -> Result<(), Self::Error> { + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Err("don't flush the sink")) + } + + fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + } }