// zep/spacetimedb/src/utils.rs

use crate::tables::*;
use spacetimedb::{Identity, Local, LocalReadOnly, Table};
use std::collections::{HashMap, HashSet};
/// Checks that a name contains at least one non-whitespace character.
///
/// # Errors
/// Returns `Err("Names must not be empty")` when `name` is empty or consists solely of
/// whitespace (as determined by [`str::trim`]).
pub fn validate_name(name: &str) -> Result<(), String> {
    match name.trim() {
        "" => Err("Names must not be empty".to_string()),
        _ => Ok(()),
    }
}
/// Rejects message text whose byte length exceeds the configured maximum.
///
/// The limit comes from the `max_message_length` row of `system_configuration`. When that
/// row is missing, or its value does not parse as a `usize`, 262144 bytes (256 KiB) is used.
///
/// # Errors
/// Returns a human-readable message stating the limit when `text.len()` is larger than it.
pub fn validate_message_length(db: &Local, text: &str) -> Result<(), String> {
    const DEFAULT_MAX_MESSAGE_LENGTH: usize = 262144;

    let configured = db
        .system_configuration()
        .key()
        .find("max_message_length".to_string());
    let max_length = match configured {
        Some(conf) => conf
            .value
            .parse::<usize>()
            .unwrap_or(DEFAULT_MAX_MESSAGE_LENGTH),
        None => DEFAULT_MAX_MESSAGE_LENGTH,
    };

    // `str::len` counts UTF-8 bytes, not characters.
    if text.len() <= max_length {
        return Ok(());
    }
    Err(format!(
        "Message exceeds maximum length of {} bytes ({}KB).",
        max_length,
        max_length / 1024
    ))
}
/// Reads the `recent_message_limit` setting from `system_configuration`.
///
/// Falls back to 50 when the row is missing or its value does not parse as a `u64`.
pub fn get_recent_message_limit(db: &Local) -> u64 {
    const DEFAULT_RECENT_MESSAGE_LIMIT: u64 = 50;

    match db
        .system_configuration()
        .key()
        .find("recent_message_limit".to_string())
    {
        Some(conf) => conf
            .value
            .parse::<u64>()
            .unwrap_or(DEFAULT_RECENT_MESSAGE_LIMIT),
        None => DEFAULT_RECENT_MESSAGE_LIMIT,
    }
}
/// Allocates and returns the next sequence number for `channel_id`.
///
/// The counter lives in `channel_high_water_mark`. If a row exists for the channel, its
/// `last_seq_id` is incremented in place. Otherwise a new row is inserted starting at 1.
pub fn get_next_seq_id(db: &Local, channel_id: u64) -> u64 {
    match db.channel_high_water_mark().channel_id().find(channel_id) {
        Some(mark) => {
            let next = mark.last_seq_id + 1;
            db.channel_high_water_mark()
                .channel_id()
                .update(ChannelHighWaterMark {
                    channel_id,
                    last_seq_id: next,
                });
            next
        }
        None => {
            // First message ever sequenced in this channel.
            db.channel_high_water_mark().insert(ChannelHighWaterMark {
                channel_id,
                last_seq_id: 1,
            });
            1
        }
    }
}
/// Builds the set of messages `identity` can currently see, as `message_id -> seq_id`.
///
/// Three sources are merged, in this order:
/// 1. Scrollback: `channel_message_sequence` entries of the identity's subscribed channel
///    (if any) whose `seq_id` is at least the subscription's `earliest_seq_id`.
/// 2. The `recent_message` cache of every server the identity is a member of.
/// 3. The `recent_message` cache of every DM channel the identity has open on their own side
///    (`is_open_sender` when they sent the DM, `is_open_recipient` when they received it).
///
/// Scrollback entries are inserted unconditionally. The two cache passes only add ids that
/// are not already present, so an earlier source's `seq_id` is kept.
pub fn get_visible_message_ids(db: &Local, identity: Identity) -> HashMap<u64, u64> {
    let mut visible: HashMap<u64, u64> = HashMap::new();

    if let Some(subscription) = db.channel_subscription().identity().find(identity) {
        visible.extend(
            db.channel_message_sequence()
                .channel_id()
                .filter(subscription.channel_id)
                .filter(|e| e.seq_id >= subscription.earliest_seq_id)
                .map(|e| (e.message_id, e.seq_id)),
        );
    }

    let server_ids: Vec<u64> = db
        .server_member()
        .identity()
        .filter(identity)
        .map(|membership| membership.server_id)
        .collect();
    for server_id in server_ids {
        for cached in db.recent_message().server_id().filter(server_id) {
            visible.entry(cached.id).or_insert(cached.seq_id);
        }
    }

    // Sender-side DMs first, then recipient-side DMs.
    let mut open_dm_channels: Vec<_> = db
        .direct_message()
        .sender()
        .filter(identity)
        .filter(|dm| dm.is_open_sender)
        .map(|dm| dm.channel_id)
        .collect();
    open_dm_channels.extend(
        db.direct_message()
            .recipient()
            .filter(identity)
            .filter(|dm| dm.is_open_recipient)
            .map(|dm| dm.channel_id),
    );
    for channel_id in open_dm_channels {
        for cached in db.recent_message().channel_id().filter(channel_id) {
            visible.entry(cached.id).or_insert(cached.seq_id);
        }
    }

    visible
}
/// Read-only counterpart of `get_visible_message_ids`, operating on a [`LocalReadOnly`] view.
///
/// Returns `message_id -> seq_id` for every message `identity` can currently see, merged in
/// this order:
/// 1. Scrollback: `channel_message_sequence` entries of the identity's subscribed channel
///    (if any) whose `seq_id` is at least the subscription's `earliest_seq_id`.
/// 2. The `recent_message` cache of every server the identity is a member of.
/// 3. The `recent_message` cache of every DM channel the identity has open on their own side.
///
/// Scrollback entries are inserted unconditionally. The cache passes never overwrite an
/// existing entry.
pub fn get_visible_message_ids_read_only(
    db: &LocalReadOnly,
    identity: Identity,
) -> HashMap<u64, u64> {
    let mut visible: HashMap<u64, u64> = HashMap::new();

    if let Some(subscription) = db.channel_subscription().identity().find(identity) {
        visible.extend(
            db.channel_message_sequence()
                .channel_id()
                .filter(subscription.channel_id)
                .filter(|e| e.seq_id >= subscription.earliest_seq_id)
                .map(|e| (e.message_id, e.seq_id)),
        );
    }

    let server_ids: Vec<u64> = db
        .server_member()
        .identity()
        .filter(identity)
        .map(|membership| membership.server_id)
        .collect();
    for server_id in server_ids {
        for cached in db.recent_message().server_id().filter(server_id) {
            visible.entry(cached.id).or_insert(cached.seq_id);
        }
    }

    // Sender-side DMs first, then recipient-side DMs.
    let mut open_dm_channels: Vec<_> = db
        .direct_message()
        .sender()
        .filter(identity)
        .filter(|dm| dm.is_open_sender)
        .map(|dm| dm.channel_id)
        .collect();
    open_dm_channels.extend(
        db.direct_message()
            .recipient()
            .filter(identity)
            .filter(|dm| dm.is_open_recipient)
            .map(|dm| dm.channel_id),
    );
    for channel_id in open_dm_channels {
        for cached in db.recent_message().channel_id().filter(channel_id) {
            visible.entry(cached.id).or_insert(cached.seq_id);
        }
    }

    visible
}
/// Collects the ids of every image `identity` may currently need to display.
///
/// Included are:
/// - the avatar of each server the identity belongs to, and the avatar and banner of every
///   user who is a member of any of those servers (the caller included);
/// - the `id` of every `custom_emoji` row;
/// - the `image_ids` of messages in the identity's subscribed channel (if any), taken from
///   both the `recent_message` cache and from scrollback messages whose `seq_id` is at least
///   the subscription's `earliest_seq_id`.
pub fn get_visible_image_ids(db: &Local, identity: Identity) -> HashSet<u64> {
    let mut visible: HashSet<u64> = HashSet::new();

    let server_ids: Vec<u64> = db
        .server_member()
        .identity()
        .filter(identity)
        .map(|membership| membership.server_id)
        .collect();
    for server_id in server_ids {
        // Server avatar, when the server row exists and has one.
        visible.extend(db.server().id().find(server_id).and_then(|server| server.avatar_id));
        for peer in db.server_member().server_id().filter(server_id) {
            if let Some(user) = db.user().identity().find(peer.identity) {
                visible.extend(user.avatar_id);
                visible.extend(user.banner_id);
            }
        }
    }

    // An open-ended range starting at "" covers every name, i.e. every emoji row.
    visible.extend(db.custom_emoji().name().filter(""..).map(|emoji| emoji.id));

    if let Some(subscription) = db.channel_subscription().identity().find(identity) {
        for cached in db.recent_message().channel_id().filter(subscription.channel_id) {
            visible.extend(cached.image_ids.iter().copied());
        }
        for entry in db
            .channel_message_sequence()
            .channel_id()
            .filter(subscription.channel_id)
            .filter(|e| e.seq_id >= subscription.earliest_seq_id)
        {
            if let Some(message) = db.message().id().find(entry.message_id) {
                visible.extend(message.image_ids.iter().copied());
            }
        }
    }

    visible
}
/// Read-only counterpart of `get_visible_image_ids`, operating on a [`LocalReadOnly`] view.
///
/// Collects server avatars and member avatars/banners for every server `identity` belongs
/// to, the id of every `custom_emoji` row, and the `image_ids` of cached and scrollback
/// messages (scrollback filtered by the subscription's `earliest_seq_id`) in the identity's
/// subscribed channel, if any.
pub fn get_visible_image_ids_read_only(db: &LocalReadOnly, identity: Identity) -> HashSet<u64> {
    let mut visible: HashSet<u64> = HashSet::new();

    let server_ids: Vec<u64> = db
        .server_member()
        .identity()
        .filter(identity)
        .map(|membership| membership.server_id)
        .collect();
    for server_id in server_ids {
        // Server avatar, when the server row exists and has one.
        visible.extend(db.server().id().find(server_id).and_then(|server| server.avatar_id));
        for peer in db.server_member().server_id().filter(server_id) {
            if let Some(user) = db.user().identity().find(peer.identity) {
                visible.extend(user.avatar_id);
                visible.extend(user.banner_id);
            }
        }
    }

    // An open-ended range starting at "" covers every name, i.e. every emoji row.
    visible.extend(db.custom_emoji().name().filter(""..).map(|emoji| emoji.id));

    if let Some(subscription) = db.channel_subscription().identity().find(identity) {
        for cached in db.recent_message().channel_id().filter(subscription.channel_id) {
            visible.extend(cached.image_ids.iter().copied());
        }
        for entry in db
            .channel_message_sequence()
            .channel_id()
            .filter(subscription.channel_id)
            .filter(|e| e.seq_id >= subscription.earliest_seq_id)
        {
            if let Some(message) = db.message().id().find(entry.message_id) {
                visible.extend(message.image_ids.iter().copied());
            }
        }
    }

    visible
}
/// Removes the voice-activity, watch and WebRTC signaling rows associated with `identity`.
///
/// Deletes, in order:
/// - the identity's `voice_activity` row, if any;
/// - every `watching` row where it is the watcher, then every row where it is the watchee;
/// - every `webrtc_signal` row it sent, then every row addressed to it.
pub fn clear_signaling_for_user(db: &Local, identity: Identity) {
    if let Some(activity) = db.voice_activity().identity().find(identity) {
        db.voice_activity().delete(activity);
    }

    // Each group is collected into a Vec before deleting, so the index scan itself is
    // finished before its table is modified.
    db.watching()
        .watcher()
        .filter(identity)
        .collect::<Vec<_>>()
        .into_iter()
        .for_each(|row| {
            db.watching().delete(row);
        });
    db.watching()
        .watchee()
        .filter(identity)
        .collect::<Vec<_>>()
        .into_iter()
        .for_each(|row| {
            db.watching().delete(row);
        });

    db.webrtc_signal()
        .sender()
        .filter(identity)
        .collect::<Vec<_>>()
        .into_iter()
        .for_each(|row| {
            db.webrtc_signal().delete(row);
        });
    db.webrtc_signal()
        .receiver()
        .filter(identity)
        .collect::<Vec<_>>()
        .into_iter()
        .for_each(|row| {
            db.webrtc_signal().delete(row);
        });
}
/// Adds `identity` to the community server: the first `server` row named "Zep".
///
/// The new `server_member` row copies the user's current `name`, `avatar_id` and `online`
/// status when a `user` row exists for `identity`. Otherwise it uses `None`, `None` and
/// `false`.
///
/// Does nothing if no "Zep" server exists, or if `identity` is already a member of it.
/// Calling this repeatedly therefore never creates duplicate membership rows.
pub fn auto_join_community_server(db: &Local, identity: Identity) {
    let community_server = db.server().name().filter(&"Zep".to_string()).next();
    if let Some(s) = community_server {
        // Guard against a second membership row if this runs more than once per identity.
        let already_member = db
            .server_member()
            .identity()
            .filter(identity)
            .any(|m| m.server_id == s.id);
        if already_member {
            return;
        }

        let user = db.user().identity().find(identity);
        db.server_member().insert(ServerMember {
            id: 0,
            identity,
            server_id: s.id,
            name: user.as_ref().and_then(|u| u.name.clone()),
            avatar_id: user.as_ref().and_then(|u| u.avatar_id),
            online: user.as_ref().map(|u| u.online).unwrap_or(false),
        });
    }
}
/// Opens (or reopens) the direct-message conversation between `sender` and `recipient`,
/// returning its channel id.
///
/// If a `direct_message` row already links the two identities in either direction, only the
/// caller's side is marked open, and the row is written back. That is `is_open_sender` when
/// `sender` is the row's sender, otherwise `is_open_recipient`.
///
/// Otherwise a new `channel` row is created (`server_id` 0, name "dm", `ChannelKind::Text`),
/// together with a `direct_message` row that is open on both sides.
pub fn internal_open_direct_message(db: &Local, sender: Identity, recipient: Identity) -> u64 {
    // Check if a DM already exists, in either direction.
    let existing = db
        .direct_message()
        .sender()
        .filter(sender)
        .find(|dm| dm.recipient == recipient)
        .or_else(|| {
            db.direct_message()
                .recipient()
                .filter(sender)
                .find(|dm| dm.sender == recipient)
        });

    if let Some(mut dm) = existing {
        if dm.sender == sender {
            dm.is_open_sender = true;
        } else {
            dm.is_open_recipient = true;
        }
        // Only the channel id is needed afterwards, so read it and move the row into
        // `update` instead of cloning the whole row.
        let channel_id = dm.channel_id;
        db.direct_message().id().update(dm);
        channel_id
    } else {
        // Create a new DM channel.
        let chan = db.channel().insert(Channel {
            id: 0,
            server_id: 0,
            name: "dm".to_string(),
            kind: ChannelKind::Text,
        });
        db.direct_message().insert(DirectMessage {
            id: 0,
            channel_id: chan.id,
            sender,
            recipient,
            is_open_sender: true,
            is_open_recipient: true,
        });
        chan.id
    }
}
/// Inserts a plain-text message into `channel_id` on behalf of `sender` and updates the
/// per-channel bookkeeping.
///
/// Steps:
/// 1. Insert the `message` row. It has no thread, reactions or images, and is neither
///    edited nor encrypted.
/// 2. Allocate the next channel sequence number with `get_next_seq_id` and record it in
///    `channel_message_sequence`.
/// 3. Mirror the message into the `recent_message` cache, tagged with the channel's
///    `server_id`. This is 0 for DM channels, and also 0 when the channel row is not found.
/// 4. Evict cache entries of this channel whose `seq_id` is `limit` or more behind the new
///    one, where `limit` comes from `get_recent_message_limit`.
pub fn internal_send_message(
    db: &Local,
    sender: Identity,
    channel_id: u64,
    text: String,
    timestamp: spacetimedb::Timestamp,
) {
    let msg = db.message().insert(Message {
        id: 0,
        sender,
        sent: timestamp,
        text,
        channel_id,
        thread_id: None,
        reactions: Vec::new(),
        image_ids: Vec::new(),
        thread_name: None,
        thread_reply_count: 0,
        edited: false,
        is_encrypted: false,
    });

    let seq_id = get_next_seq_id(db, channel_id);
    db.channel_message_sequence()
        .insert(ChannelMessageSequence {
            message_id: msg.id,
            channel_id,
            seq_id,
        });

    // The per-server fast path in `get_visible_message_ids` selects cached messages by
    // `server_id`, so the cache entry carries the channel's actual server. DM channels are
    // created with `server_id: 0`, and a missing channel falls back to 0, the value
    // previously hard-coded here.
    let server_id = db
        .channel()
        .id()
        .find(channel_id)
        .map(|chan| chan.server_id)
        .unwrap_or(0);

    db.recent_message().insert(RecentMessage {
        id: msg.id,
        sender: msg.sender,
        sent: msg.sent,
        text: msg.text,
        channel_id: msg.channel_id,
        thread_id: msg.thread_id,
        seq_id,
        reactions: msg.reactions,
        image_ids: msg.image_ids,
        thread_name: msg.thread_name,
        thread_reply_count: msg.thread_reply_count,
        edited: msg.edited,
        server_id,
        is_encrypted: msg.is_encrypted,
    });

    // Keep only the entries whose seq_id lies within the last `limit` sequence numbers.
    let limit = get_recent_message_limit(db);
    if seq_id > limit {
        let old_seq_id = seq_id - limit;
        let to_delete: Vec<_> = db
            .recent_message()
            .channel_id()
            .filter(channel_id)
            .filter(|m| m.seq_id <= old_seq_id)
            .map(|m| m.id)
            .collect();
        for id in to_delete {
            db.recent_message().id().delete(id);
        }
    }
}
pub fn sync_server_member_info(db: &Local, identity: Identity) {
if let Some(user) = db.user().identity().find(identity) {
let members: Vec<_> = db.server_member().identity().filter(identity).collect();
for mut member in members {
member.name = user.name.clone();
member.avatar_id = user.avatar_id;
member.online = user.online;
db.server_member().id().update(member);
}
}
}