From 2f5761ba06ae599147013049f004c5ed88f0258c Mon Sep 17 00:00:00 2001 From: Jason Larabie Date: Mon, 13 Apr 2026 15:19:22 -0700 Subject: [PATCH] Sync up with changes from the Unreal Query Builder work --- .../include/spacetimedb/query_builder.h | 32 +- .../include/spacetimedb/query_builder/expr.h | 45 ++- .../include/spacetimedb/query_builder/join.h | 322 +++++++++++------- .../include/spacetimedb/query_builder/table.h | 177 +++++++--- .../fail_event_lookup.cpp | 12 +- .../fail_non_index_join.cpp | 12 +- .../pass_query_integration.cpp | 10 +- .../run_query_builder_compile_tests.sh | 2 +- .../query_builder_sql_tests.cpp | 54 ++- 9 files changed, 441 insertions(+), 225 deletions(-) diff --git a/crates/bindings-cpp/include/spacetimedb/query_builder.h b/crates/bindings-cpp/include/spacetimedb/query_builder.h index 884b943fd..22cbba987 100644 --- a/crates/bindings-cpp/include/spacetimedb/query_builder.h +++ b/crates/bindings-cpp/include/spacetimedb/query_builder.h @@ -1,11 +1,12 @@ #pragma once #include "spacetimedb/query_builder/expr.h" -#include "spacetimedb/query_builder/table.h" #include "spacetimedb/query_builder/join.h" +#include "spacetimedb/query_builder/table.h" #include #include +#include #include namespace SpacetimeDB { @@ -68,29 +69,28 @@ constexpr const char* GetQuerySourceName(const TSourceTag& tag) { class QueryBuilder { public: - template - [[nodiscard]] constexpr query_builder::Table table(const char* table_name) const { - return query_builder::Table(table_name); + template + [[nodiscard]] constexpr query_builder::Table table(const char* table_name, TCols cols, TIxCols ix_cols) const { + return query_builder::Table(table_name, std::move(cols), std::move(ix_cols)); } template [[nodiscard]] constexpr auto table(TTableTag tag) const - -> query_builder::Table::type> { - using Tag = std::remove_cvref_t; - using TRow = typename Tag::type; - // Tags may refer to either base tables or published view relations. - return query_builder::Table(detail::GetQuerySourceName(tag)); - } - - template - [[nodiscard]] constexpr auto operator()(TTableTag tag) const - -> query_builder::Table::type> { - return table(tag); + -> query_builder::Table< + typename std::remove_cvref_t::type, + decltype(query_builder::HasCols::type>::get(std::declval())), + decltype(query_builder::HasIxCols::type>::get(std::declval()))> { + using TRow = typename std::remove_cvref_t::type; + const char* table_name = detail::GetQuerySourceName(tag); + return table( + table_name, + query_builder::HasCols::get(table_name), + query_builder::HasIxCols::get(table_name)); } template [[nodiscard]] constexpr auto operator[](TTableTag tag) const - -> query_builder::Table::type> { + -> decltype(table(tag)) { return table(tag); } }; diff --git a/crates/bindings-cpp/include/spacetimedb/query_builder/expr.h b/crates/bindings-cpp/include/spacetimedb/query_builder/expr.h index 91eba6ce4..e6ff2e0c3 100644 --- a/crates/bindings-cpp/include/spacetimedb/query_builder/expr.h +++ b/crates/bindings-cpp/include/spacetimedb/query_builder/expr.h @@ -2,7 +2,9 @@ #include "spacetimedb/bsatn/timestamp.h" #include "spacetimedb/bsatn/types.h" +#include #include +#include #include #include #include @@ -58,21 +60,23 @@ inline std::string quote_string(std::string_view value) { inline std::string trim_timestamp_fraction(std::string value) { // Keep this in sync with the current Timestamp::to_string() UTC form. - // If that representation changes away from a +00:00 suffix, revisit this trimming logic. + // If that representation changes away from a +00:00 / Z suffix, revisit this trimming logic. const std::size_t plus = value.rfind("+00:00"); + const std::size_t z = value.rfind('Z'); const std::size_t dot = value.find('.'); - if (plus == std::string::npos || dot == std::string::npos || dot > plus) { + const std::size_t suffix = plus != std::string::npos ? plus : z; + if (suffix == std::string::npos || dot == std::string::npos || dot > suffix) { return value; } - std::size_t trim = plus; + std::size_t trim = suffix; while (trim > dot + 1 && value[trim - 1] == '0') { --trim; } if (trim == dot + 1) { - value.erase(dot, plus - dot); + value.erase(dot, suffix - dot); } else { - value.erase(trim, plus - trim); + value.erase(trim, suffix - trim); } return value; } @@ -84,6 +88,7 @@ inline std::string literal_sql(bool value) { return value ? "TRUE" : "FALSE"; } inline std::string literal_sql(const Identity& value) { return "0x" + value.to_hex_string(); } inline std::string literal_sql(const ConnectionId& value) { return "0x" + value.to_string(); } inline std::string literal_sql(const Timestamp& value) { return quote_string(trim_timestamp_fraction(value.to_string())); } +inline std::string literal_sql(const TimeDuration&) = delete; inline std::string literal_sql(const std::vector& value) { std::ostringstream out; out << "0x" << std::hex << std::setfill('0'); @@ -97,15 +102,23 @@ inline std::string literal_sql(const i128& value) { return value.to_string(); } inline std::string literal_sql(const u256& value) { return value.to_string(); } inline std::string literal_sql(const i256& value) { return value.to_string(); } -inline std::string format_floating_point(double value) { +template +inline std::string format_floating_point(TFloat value) { + char buffer[64]; + const auto result = std::to_chars(buffer, buffer + sizeof(buffer), value, std::chars_format::general); + if (result.ec == std::errc{}) { + return std::string(buffer, result.ptr); + } + std::ostringstream out; out.imbue(std::locale::classic()); + out << std::setprecision(std::numeric_limits::max_digits10); out << value; return out.str(); } inline std::string literal_sql(float value) { - return format_floating_point(static_cast(value)); + return format_floating_point(value); } inline std::string literal_sql(double value) { @@ -178,20 +191,20 @@ public: return format_node(root_); } - // Trailing underscore to avoid conflicting with C++ keyword [[nodiscard]] BoolExpr and_(const BoolExpr& other) const { return BoolExpr(std::make_shared(Kind::And, root_, other.root_)); } + [[nodiscard]] BoolExpr And(const BoolExpr& other) const { return and_(other); } - // Trailing underscore to avoid conflicting with C++ keyword [[nodiscard]] BoolExpr or_(const BoolExpr& other) const { return BoolExpr(std::make_shared(Kind::Or, root_, other.root_)); } + [[nodiscard]] BoolExpr Or(const BoolExpr& other) const { return or_(other); } - // Trailing underscore to avoid conflicting with C++ keyword [[nodiscard]] BoolExpr not_() const { return BoolExpr(std::make_shared(Kind::Not, root_, nullptr)); } + [[nodiscard]] BoolExpr Not() const { return not_(); } private: struct Node; @@ -303,15 +316,27 @@ public: template [[nodiscard]] BoolExpr eq(const TRhs& rhs) const { return compare(BoolExpr::Kind::Eq, rhs); } template + [[nodiscard]] BoolExpr Eq(const TRhs& rhs) const { return eq(rhs); } + template [[nodiscard]] BoolExpr ne(const TRhs& rhs) const { return compare(BoolExpr::Kind::Ne, rhs); } template + [[nodiscard]] BoolExpr Neq(const TRhs& rhs) const { return ne(rhs); } + template [[nodiscard]] BoolExpr gt(const TRhs& rhs) const { return compare(BoolExpr::Kind::Gt, rhs); } template + [[nodiscard]] BoolExpr Gt(const TRhs& rhs) const { return gt(rhs); } + template [[nodiscard]] BoolExpr lt(const TRhs& rhs) const { return compare(BoolExpr::Kind::Lt, rhs); } template + [[nodiscard]] BoolExpr Lt(const TRhs& rhs) const { return lt(rhs); } + template [[nodiscard]] BoolExpr gte(const TRhs& rhs) const { return compare(BoolExpr::Kind::Gte, rhs); } template + [[nodiscard]] BoolExpr Gte(const TRhs& rhs) const { return gte(rhs); } + template [[nodiscard]] BoolExpr lte(const TRhs& rhs) const { return compare(BoolExpr::Kind::Lte, rhs); } + template + [[nodiscard]] BoolExpr Lte(const TRhs& rhs) const { return lte(rhs); } [[nodiscard]] constexpr const ColumnRef& column_ref() const { return column_; } diff --git a/crates/bindings-cpp/include/spacetimedb/query_builder/join.h b/crates/bindings-cpp/include/spacetimedb/query_builder/join.h index 243e6354a..1f7178e7b 100644 --- a/crates/bindings-cpp/include/spacetimedb/query_builder/join.h +++ b/crates/bindings-cpp/include/spacetimedb/query_builder/join.h @@ -31,7 +31,67 @@ public: return IxJoinEq{column_, rhs.column_}; } + template + [[nodiscard]] auto Eq(const IxCol& rhs) const { + return eq(rhs); + } + + template + [[nodiscard]] BoolExpr eq(const TRhs& rhs) const { + return compare(BoolExpr::Kind::Eq, rhs); + } + + template + [[nodiscard]] BoolExpr Eq(const TRhs& rhs) const { return eq(rhs); } + + template + [[nodiscard]] BoolExpr ne(const TRhs& rhs) const { + return compare(BoolExpr::Kind::Ne, rhs); + } + + template + [[nodiscard]] BoolExpr Neq(const TRhs& rhs) const { return ne(rhs); } + + template + [[nodiscard]] BoolExpr gt(const TRhs& rhs) const { + return compare(BoolExpr::Kind::Gt, rhs); + } + + template + [[nodiscard]] BoolExpr Gt(const TRhs& rhs) const { return gt(rhs); } + + template + [[nodiscard]] BoolExpr lt(const TRhs& rhs) const { + return compare(BoolExpr::Kind::Lt, rhs); + } + + template + [[nodiscard]] BoolExpr Lt(const TRhs& rhs) const { return lt(rhs); } + + template + [[nodiscard]] BoolExpr gte(const TRhs& rhs) const { + return compare(BoolExpr::Kind::Gte, rhs); + } + + template + [[nodiscard]] BoolExpr Gte(const TRhs& rhs) const { return gte(rhs); } + + template + [[nodiscard]] BoolExpr lte(const TRhs& rhs) const { + return compare(BoolExpr::Kind::Lte, rhs); + } + + template + [[nodiscard]] BoolExpr Lte(const TRhs& rhs) const { return lte(rhs); } + + [[nodiscard]] constexpr const ColumnRef& column_ref() const { return column_; } + private: + template + [[nodiscard]] BoolExpr compare(typename BoolExpr::Kind kind, const TRhs& rhs) const { + return BoolExpr::compare(kind, detail::Operand::column(column_), detail::to_operand(rhs)); + } + ColumnRef column_; template @@ -43,33 +103,6 @@ namespace detail { template inline constexpr bool delayed_is_indexed_member_v = is_indexed_member_v; -template -class MaybeIxCol { -public: - constexpr MaybeIxCol() = default; - constexpr MaybeIxCol(const char* table_name, const char* column_name) - : column_(table_name, column_name) {} - - template - [[nodiscard]] auto eq(const MaybeIxCol& rhs) const { - static_assert( - is_indexed_member_v && is_indexed_member_v, - "Semijoin predicates may only use single-column indexed fields."); - return IxJoinEq{column_, rhs.column_}; - } - -private: - ColumnRef column_{}; - - template - friend class MaybeIxCol; -}; - -template -using ix_col_member_t = MaybeIxCol; -// HasIxCols currently exposes all fields through MaybeIxCol so table/view macros can stay uniform. -// Non-indexed fields are rejected when .eq() is used in a semijoin predicate. - } // namespace detail template @@ -78,18 +111,42 @@ struct IxJoinEq { ColumnRef rhs; }; -template +template class LeftSemiJoin { public: - using row_type = TRow; + using row_type = TLeftRow; - LeftSemiJoin(ColumnRef lhs, const char* right_table, const char* right_column, std::optional> where_expr = std::nullopt) - : lhs_(lhs), right_table_(right_table), right_column_(right_column), where_expr_(std::move(where_expr)) {} + LeftSemiJoin( + Table left, + Table right, + ColumnRef left_join_ref, + ColumnRef right_join_ref, + std::optional> where_expr = std::nullopt) + : left_(std::move(left)) + , right_(std::move(right)) + , left_join_ref_(left_join_ref) + , right_join_ref_(right_join_ref) + , where_expr_(std::move(where_expr)) {} template [[nodiscard]] LeftSemiJoin where(TFn&& predicate) const { - auto extra = detail::make_bool_expr(std::forward(predicate)(HasCols::get(lhs_.table_name()))); - return LeftSemiJoin(lhs_, right_table_, right_column_, where_expr_ ? where_expr_->and_(extra) : std::optional>(std::move(extra))); + auto extra = detail::make_bool_expr(std::forward(predicate)(left_.cols())); + return LeftSemiJoin(left_, right_, left_join_ref_, right_join_ref_, where_expr_ ? where_expr_->and_(extra) : std::optional>(std::move(extra))); + } + + template + [[nodiscard]] LeftSemiJoin where_ix(TFn&& predicate) const { + auto extra = detail::make_bool_expr(std::forward(predicate)(left_.cols(), left_.ix_cols())); + return LeftSemiJoin(left_, right_, left_join_ref_, right_join_ref_, where_expr_ ? where_expr_->and_(extra) : std::optional>(std::move(extra))); + } + + template + [[nodiscard]] LeftSemiJoin Where(TFn&& predicate) const { + if constexpr (std::is_invocable_v) { + return where_ix(std::forward(predicate)); + } else { + return where(std::forward(predicate)); + } } template @@ -97,60 +154,83 @@ public: return where(std::forward(predicate)); } - [[nodiscard]] RawQuery build() const { + template + [[nodiscard]] LeftSemiJoin Filter(TFn&& predicate) const { + return Where(std::forward(predicate)); + } + + [[nodiscard]] RawQuery build() const { std::string sql; sql.reserve( 48 + - (std::char_traits::length(lhs_.table_name()) * 3) + - std::char_traits::length(right_table_) * 2 + - std::char_traits::length(lhs_.column_name()) + - std::char_traits::length(right_column_)); + (std::char_traits::length(left_.name()) * 3) + + std::char_traits::length(right_.name()) * 2 + + std::char_traits::length(left_join_ref_.column_name()) + + std::char_traits::length(right_join_ref_.column_name())); sql += "SELECT \""; - sql += lhs_.table_name(); + sql += left_.name(); sql += "\".* FROM \""; - sql += lhs_.table_name(); + sql += left_.name(); sql += "\" JOIN \""; - sql += right_table_; - sql += "\" ON \""; - sql += lhs_.table_name(); - sql += "\".\""; - sql += lhs_.column_name(); - sql += "\" = \""; - sql += right_table_; - sql += "\".\""; - sql += right_column_; - sql += "\""; + sql += right_.name(); + sql += "\" ON "; + sql += left_join_ref_.format(); + sql += " = "; + sql += right_join_ref_.format(); if (where_expr_) { sql += " WHERE " + where_expr_->format(); } - return RawQuery(std::move(sql)); + return RawQuery(std::move(sql)); } [[nodiscard]] std::string into_sql() const { return build().into_sql(); } private: - ColumnRef lhs_; - const char* right_table_; - const char* right_column_; - std::optional> where_expr_; + Table left_; + Table right_; + ColumnRef left_join_ref_; + ColumnRef right_join_ref_; + std::optional> where_expr_; }; -template +template class RightSemiJoin { public: using row_type = TRightRow; RightSemiJoin( - ColumnRef lhs, - ColumnRef rhs, + Table left, + Table right, + ColumnRef left_join_ref, + ColumnRef right_join_ref, std::optional> left_where_expr = std::nullopt, std::optional> right_where_expr = std::nullopt) - : lhs_(lhs), rhs_(rhs), left_where_expr_(std::move(left_where_expr)), right_where_expr_(std::move(right_where_expr)) {} + : left_(std::move(left)) + , right_(std::move(right)) + , left_join_ref_(left_join_ref) + , right_join_ref_(right_join_ref) + , left_where_expr_(std::move(left_where_expr)) + , right_where_expr_(std::move(right_where_expr)) {} template [[nodiscard]] RightSemiJoin where(TFn&& predicate) const { - auto extra = detail::make_bool_expr(std::forward(predicate)(HasCols::get(rhs_.table_name()))); - return RightSemiJoin(lhs_, rhs_, left_where_expr_, right_where_expr_ ? right_where_expr_->and_(extra) : std::optional>(std::move(extra))); + auto extra = detail::make_bool_expr(std::forward(predicate)(right_.cols())); + return RightSemiJoin(left_, right_, left_join_ref_, right_join_ref_, left_where_expr_, right_where_expr_ ? right_where_expr_->and_(extra) : std::optional>(std::move(extra))); + } + + template + [[nodiscard]] RightSemiJoin where_ix(TFn&& predicate) const { + auto extra = detail::make_bool_expr(std::forward(predicate)(right_.cols(), right_.ix_cols())); + return RightSemiJoin(left_, right_, left_join_ref_, right_join_ref_, left_where_expr_, right_where_expr_ ? right_where_expr_->and_(extra) : std::optional>(std::move(extra))); + } + + template + [[nodiscard]] RightSemiJoin Where(TFn&& predicate) const { + if constexpr (std::is_invocable_v) { + return where_ix(std::forward(predicate)); + } else { + return where(std::forward(predicate)); + } } template @@ -158,29 +238,29 @@ public: return where(std::forward(predicate)); } + template + [[nodiscard]] RightSemiJoin Filter(TFn&& predicate) const { + return Where(std::forward(predicate)); + } + [[nodiscard]] RawQuery build() const { std::string sql; sql.reserve( 48 + - (std::char_traits::length(lhs_.table_name()) * 2) + - (std::char_traits::length(rhs_.table_name()) * 3) + - std::char_traits::length(lhs_.column_name()) + - std::char_traits::length(rhs_.column_name())); + (std::char_traits::length(left_.name()) * 2) + + (std::char_traits::length(right_.name()) * 3) + + std::char_traits::length(left_join_ref_.column_name()) + + std::char_traits::length(right_join_ref_.column_name())); sql += "SELECT \""; - sql += rhs_.table_name(); + sql += right_.name(); sql += "\".* FROM \""; - sql += lhs_.table_name(); + sql += left_.name(); sql += "\" JOIN \""; - sql += rhs_.table_name(); - sql += "\" ON \""; - sql += lhs_.table_name(); - sql += "\".\""; - sql += lhs_.column_name(); - sql += "\" = \""; - sql += rhs_.table_name(); - sql += "\".\""; - sql += rhs_.column_name(); - sql += "\""; + sql += right_.name(); + sql += "\" ON "; + sql += left_join_ref_.format(); + sql += " = "; + sql += right_join_ref_.format(); if (left_where_expr_ && right_where_expr_) { sql += " WHERE "; @@ -201,83 +281,77 @@ public: [[nodiscard]] std::string into_sql() const { return build().into_sql(); } private: - ColumnRef lhs_; - ColumnRef rhs_; + Table left_; + Table right_; + ColumnRef left_join_ref_; + ColumnRef right_join_ref_; std::optional> left_where_expr_; std::optional> right_where_expr_; }; namespace detail { -template -[[nodiscard]] LeftSemiJoin left_semijoin_impl(const Table& left, const Table& right, TFn&& predicate) { - static_assert(can_be_lookup_table_v, "Lookup side of a semijoin must opt in via CanBeLookupTable."); - static_assert(requires { HasIxCols::get(left.name()); }, "Left side of a semijoin must provide HasIxCols."); - static_assert(requires { HasIxCols::get(right.name()); }, "Lookup side of a semijoin must provide HasIxCols."); - const auto join = std::forward(predicate)(HasIxCols::get(left.name()), HasIxCols::get(right.name())); - return LeftSemiJoin(join.lhs, right.name(), join.rhs.column_name()); +template +[[nodiscard]] auto left_semijoin_impl(const Table& left, const Table& right, TFn&& predicate) { + static_assert(can_be_lookup_table_v>, "Lookup side of a semijoin must opt in via CanBeLookupTable."); + const auto join = std::forward(predicate)(left.ix_cols(), right.ix_cols()); + return LeftSemiJoin(left, right, join.lhs, join.rhs); } -template -[[nodiscard]] LeftSemiJoin left_semijoin_impl(const FromWhere& left, const Table& right, TFn&& predicate) { - static_assert(can_be_lookup_table_v, "Lookup side of a semijoin must opt in via CanBeLookupTable."); - static_assert(requires { HasIxCols::get(left.table_name()); }, "Left side of a semijoin must provide HasIxCols."); - static_assert(requires { HasIxCols::get(right.name()); }, "Lookup side of a semijoin must provide HasIxCols."); - const auto join = std::forward(predicate)(HasIxCols::get(left.table_name()), HasIxCols::get(right.name())); - return LeftSemiJoin(join.lhs, right.name(), join.rhs.column_name(), left.expr()); +template +[[nodiscard]] auto left_semijoin_impl(const FromWhere& left, const Table& right, TFn&& predicate) { + static_assert(can_be_lookup_table_v>, "Lookup side of a semijoin must opt in via CanBeLookupTable."); + const auto join = std::forward(predicate)(left.table().ix_cols(), right.ix_cols()); + return LeftSemiJoin(left.table(), right, join.lhs, join.rhs, left.expr()); } -template -[[nodiscard]] RightSemiJoin right_semijoin_impl(const Table& left, const Table& right, TFn&& predicate) { - static_assert(can_be_lookup_table_v, "Lookup side of a semijoin must opt in via CanBeLookupTable."); - static_assert(requires { HasIxCols::get(left.name()); }, "Left side of a semijoin must provide HasIxCols."); - static_assert(requires { HasIxCols::get(right.name()); }, "Lookup side of a semijoin must provide HasIxCols."); - const auto join = std::forward(predicate)(HasIxCols::get(left.name()), HasIxCols::get(right.name())); - return RightSemiJoin(join.lhs, join.rhs); +template +[[nodiscard]] auto right_semijoin_impl(const Table& left, const Table& right, TFn&& predicate) { + static_assert(can_be_lookup_table_v>, "Lookup side of a semijoin must opt in via CanBeLookupTable."); + const auto join = std::forward(predicate)(left.ix_cols(), right.ix_cols()); + return RightSemiJoin(left, right, join.lhs, join.rhs); } -template -[[nodiscard]] RightSemiJoin right_semijoin_impl(const FromWhere& left, const Table& right, TFn&& predicate) { - static_assert(can_be_lookup_table_v, "Lookup side of a semijoin must opt in via CanBeLookupTable."); - static_assert(requires { HasIxCols::get(left.table_name()); }, "Left side of a semijoin must provide HasIxCols."); - static_assert(requires { HasIxCols::get(right.name()); }, "Lookup side of a semijoin must provide HasIxCols."); - const auto join = std::forward(predicate)(HasIxCols::get(left.table_name()), HasIxCols::get(right.name())); - return RightSemiJoin(join.lhs, join.rhs, left.expr()); +template +[[nodiscard]] auto right_semijoin_impl(const FromWhere& left, const Table& right, TFn&& predicate) { + static_assert(can_be_lookup_table_v>, "Lookup side of a semijoin must opt in via CanBeLookupTable."); + const auto join = std::forward(predicate)(left.table().ix_cols(), right.ix_cols()); + return RightSemiJoin(left.table(), right, join.lhs, join.rhs, left.expr()); } } // namespace detail -template -template -[[nodiscard]] LeftSemiJoin Table::left_semijoin(const Table& right, TFn&& predicate) const { +template +template +[[nodiscard]] auto Table::left_semijoin(const Table& right, TFn&& predicate) const { return detail::left_semijoin_impl(*this, right, std::forward(predicate)); } -template -template -[[nodiscard]] RightSemiJoin Table::right_semijoin(const Table& right, TFn&& predicate) const { +template +template +[[nodiscard]] auto Table::right_semijoin(const Table& right, TFn&& predicate) const { return detail::right_semijoin_impl(*this, right, std::forward(predicate)); } -template -template -[[nodiscard]] LeftSemiJoin FromWhere::left_semijoin(const Table& right, TFn&& predicate) const { +template +template +[[nodiscard]] auto FromWhere::left_semijoin(const Table& right, TFn&& predicate) const { return detail::left_semijoin_impl(*this, right, std::forward(predicate)); } -template -template -[[nodiscard]] RightSemiJoin FromWhere::right_semijoin(const Table& right, TFn&& predicate) const { +template +template +[[nodiscard]] auto FromWhere::right_semijoin(const Table& right, TFn&& predicate) const { return detail::right_semijoin_impl(*this, right, std::forward(predicate)); } -template -struct query_row_type> { - using type = TRow; +template +struct query_row_type> { + using type = TLeftRow; }; -template -struct query_row_type> { +template +struct query_row_type> { using type = TRightRow; }; diff --git a/crates/bindings-cpp/include/spacetimedb/query_builder/table.h b/crates/bindings-cpp/include/spacetimedb/query_builder/table.h index 0fd0192b6..38f01c41e 100644 --- a/crates/bindings-cpp/include/spacetimedb/query_builder/table.h +++ b/crates/bindings-cpp/include/spacetimedb/query_builder/table.h @@ -8,22 +8,29 @@ #include #include +#ifndef SPACETIMEDB_QUERY_BUILDER_ENABLE_BSATN +#define SPACETIMEDB_QUERY_BUILDER_ENABLE_BSATN 1 +#endif + namespace SpacetimeDB::query_builder { template struct query_row_type; -template +template class Table; -template +template class FromWhere; -template -class LeftSemiJoin; +template +struct HasCols; -template -class RightSemiJoin; +template +struct HasIxCols; + +template +struct CanBeLookupTable : std::false_type {}; template using query_row_type_t = typename query_row_type>::type; @@ -51,8 +58,8 @@ private: std::string sql_; }; -template -concept QueryLike = requires(const TRow& query) { +template +concept QueryLike = requires(const T& query) { { query.into_sql() } -> std::convertible_to; }; @@ -61,15 +68,6 @@ concept QueryBuilderReturn = requires { typename query_row_type_t; } && QueryLike>; -template -struct HasCols; - -template -struct HasIxCols; - -template -struct CanBeLookupTable : std::false_type {}; - namespace detail { template @@ -86,18 +84,28 @@ std::false_type adl_lookup_table_allowed(...); } // namespace detail template -inline constexpr bool can_be_lookup_table_v = +inline constexpr bool can_be_lookup_row_v = CanBeLookupTable::value || decltype(detail::adl_lookup_table_allowed(0))::value; -template +template +inline constexpr bool can_be_lookup_table_v = CanBeLookupTable>::value; + +template +struct CanBeLookupTable> : std::bool_constant> {}; + +template class Table { public: using row_type = TRow; + using cols_type = TCols; + using ix_cols_type = TIxCols; - explicit constexpr Table(const char* table_name) - : table_name_(table_name) {} + constexpr Table(const char* table_name, TCols cols, TIxCols ix_cols) + : table_name_(table_name), cols_(std::move(cols)), ix_cols_(std::move(ix_cols)) {} [[nodiscard]] constexpr const char* name() const { return table_name_; } + [[nodiscard]] constexpr const TCols& cols() const { return cols_; } + [[nodiscard]] constexpr const TIxCols& ix_cols() const { return ix_cols_; } [[nodiscard]] RawQuery build() const { std::string sql; @@ -108,58 +116,107 @@ public: return RawQuery(std::move(sql)); } - [[nodiscard]] std::string into_sql() const { - return build().into_sql(); + [[nodiscard]] std::string into_sql() const { return build().into_sql(); } + + template + [[nodiscard]] auto where(TFn&& predicate) const { + auto expr = detail::make_bool_expr(std::forward(predicate)(cols_)); + return FromWhere(*this, std::move(expr)); } template - [[nodiscard]] auto where(TFn&& predicate) const; + [[nodiscard]] auto where_ix(TFn&& predicate) const { + auto expr = detail::make_bool_expr(std::forward(predicate)(cols_, ix_cols_)); + return FromWhere(*this, std::move(expr)); + } + + template + [[nodiscard]] auto Where(TFn&& predicate) const { + if constexpr (std::is_invocable_v) { + return where_ix(std::forward(predicate)); + } else { + return where(std::forward(predicate)); + } + } template [[nodiscard]] auto filter(TFn&& predicate) const { return where(std::forward(predicate)); } - template - [[nodiscard]] LeftSemiJoin left_semijoin(const Table& right, TFn&& predicate) const; + template + [[nodiscard]] auto Filter(TFn&& predicate) const { + return Where(std::forward(predicate)); + } - template - [[nodiscard]] RightSemiJoin right_semijoin(const Table& right, TFn&& predicate) const; + template + [[nodiscard]] auto left_semijoin(const Table& right, TFn&& predicate) const; + + template + [[nodiscard]] auto LeftSemijoin(const Table& right, TFn&& predicate) const { + return left_semijoin(right, std::forward(predicate)); + } + + template + [[nodiscard]] auto right_semijoin(const Table& right, TFn&& predicate) const; + + template + [[nodiscard]] auto RightSemijoin(const Table& right, TFn&& predicate) const { + return right_semijoin(right, std::forward(predicate)); + } private: const char* table_name_; + TCols cols_; + TIxCols ix_cols_; }; -template +template class FromWhere { public: using row_type = TRow; + using cols_type = TCols; + using ix_cols_type = TIxCols; - constexpr FromWhere(const char* table_name, BoolExpr expr) - : table_name_(table_name), expr_(std::move(expr)) {} + constexpr FromWhere(Table table, BoolExpr expr) + : table_(std::move(table)), expr_(std::move(expr)) {} - [[nodiscard]] constexpr const char* table_name() const { return table_name_; } + [[nodiscard]] constexpr const char* table_name() const { return table_.name(); } + [[nodiscard]] constexpr const Table& table() const { return table_; } [[nodiscard]] const BoolExpr& expr() const { return expr_; } [[nodiscard]] RawQuery build() const { std::string predicate = expr_.format(); std::string sql; - sql.reserve(24 + std::char_traits::length(table_name_) + predicate.size()); + sql.reserve(24 + std::char_traits::length(table_.name()) + predicate.size()); sql += "SELECT * FROM \""; - sql += table_name_; + sql += table_.name(); sql += "\" WHERE "; sql += predicate; return RawQuery(std::move(sql)); } - [[nodiscard]] std::string into_sql() const { - return build().into_sql(); - } + [[nodiscard]] std::string into_sql() const { return build().into_sql(); } template [[nodiscard]] FromWhere where(TFn&& predicate) const { - auto extra = detail::make_bool_expr(std::forward(predicate)(HasCols::get(table_name_))); - return FromWhere(table_name_, expr_.and_(extra)); + auto extra = detail::make_bool_expr(std::forward(predicate)(table_.cols())); + return FromWhere(table_, expr_.and_(extra)); + } + + template + [[nodiscard]] FromWhere where_ix(TFn&& predicate) const { + auto extra = detail::make_bool_expr(std::forward(predicate)(table_.cols(), table_.ix_cols())); + return FromWhere(table_, expr_.and_(extra)); + } + + template + [[nodiscard]] FromWhere Where(TFn&& predicate) const { + if constexpr (std::is_invocable_v) { + return where_ix(std::forward(predicate)); + } else { + return where(std::forward(predicate)); + } } template @@ -167,41 +224,50 @@ public: return where(std::forward(predicate)); } - template - [[nodiscard]] LeftSemiJoin left_semijoin(const Table& right, TFn&& predicate) const; + template + [[nodiscard]] FromWhere Filter(TFn&& predicate) const { + return Where(std::forward(predicate)); + } - template - [[nodiscard]] RightSemiJoin right_semijoin(const Table& right, TFn&& predicate) const; + template + [[nodiscard]] auto left_semijoin(const Table& right, TFn&& predicate) const; + + template + [[nodiscard]] auto LeftSemijoin(const Table& right, TFn&& predicate) const { + return left_semijoin(right, std::forward(predicate)); + } + + template + [[nodiscard]] auto right_semijoin(const Table& right, TFn&& predicate) const; + + template + [[nodiscard]] auto RightSemijoin(const Table& right, TFn&& predicate) const { + return right_semijoin(right, std::forward(predicate)); + } private: - const char* table_name_; + Table table_; BoolExpr expr_; }; -template -template -[[nodiscard]] auto Table::where(TFn&& predicate) const { - auto expr = detail::make_bool_expr(std::forward(predicate)(HasCols::get(table_name_))); - return FromWhere(table_name_, std::move(expr)); -} - template struct query_row_type> { using type = TRow; }; -template -struct query_row_type> { +template +struct query_row_type> { using type = TRow; }; -template -struct query_row_type> { +template +struct query_row_type> { using type = TRow; }; } // namespace SpacetimeDB::query_builder +#if SPACETIMEDB_QUERY_BUILDER_ENABLE_BSATN namespace SpacetimeDB::bsatn { template @@ -235,3 +301,4 @@ struct bsatn_traits<::SpacetimeDB::query_builder::RawQuery> { }; } // namespace SpacetimeDB::bsatn +#endif diff --git a/crates/bindings-cpp/tests/query-builder-compile/fail_event_lookup.cpp b/crates/bindings-cpp/tests/query-builder-compile/fail_event_lookup.cpp index e76bcf4af..d2772a4c7 100644 --- a/crates/bindings-cpp/tests/query-builder-compile/fail_event_lookup.cpp +++ b/crates/bindings-cpp/tests/query-builder-compile/fail_event_lookup.cpp @@ -2,6 +2,14 @@ using namespace SpacetimeDB; +template +auto TableFor(const char* table_name) { + return QueryBuilder{}.table( + table_name, + query_builder::HasCols::get(table_name), + query_builder::HasIxCols::get(table_name)); +} + struct User { Identity identity; }; @@ -16,8 +24,8 @@ SPACETIMEDB_STRUCT(AuditEvent, identity) SPACETIMEDB_TABLE(AuditEvent, audit_event, Public, true) FIELD_PrimaryKey(audit_event, identity) -auto invalid_event_lookup = QueryBuilder{}[user].right_semijoin( - QueryBuilder{}[audit_event], +auto invalid_event_lookup = TableFor("user").right_semijoin( + TableFor("audit_event"), [](const auto& users, const auto& events) { return users.identity.eq(events.identity); }); diff --git a/crates/bindings-cpp/tests/query-builder-compile/fail_non_index_join.cpp b/crates/bindings-cpp/tests/query-builder-compile/fail_non_index_join.cpp index 27b2ed894..c38287b9b 100644 --- a/crates/bindings-cpp/tests/query-builder-compile/fail_non_index_join.cpp +++ b/crates/bindings-cpp/tests/query-builder-compile/fail_non_index_join.cpp @@ -2,6 +2,14 @@ using namespace SpacetimeDB; +template +auto TableFor(const char* table_name) { + return QueryBuilder{}.table( + table_name, + query_builder::HasCols::get(table_name), + query_builder::HasIxCols::get(table_name)); +} + struct User { Identity identity; uint64_t tenant_id; @@ -18,8 +26,8 @@ SPACETIMEDB_STRUCT(Membership, identity, tenant_id) SPACETIMEDB_TABLE(Membership, membership, Public) FIELD_PrimaryKey(membership, identity) -auto invalid_join = QueryBuilder{}[user].right_semijoin( - QueryBuilder{}[membership], +auto invalid_join = TableFor("user").right_semijoin( + TableFor("membership"), [](const auto& users, const auto& memberships) { return users.tenant_id.eq(memberships.tenant_id); }); diff --git a/crates/bindings-cpp/tests/query-builder-compile/pass_query_integration.cpp b/crates/bindings-cpp/tests/query-builder-compile/pass_query_integration.cpp index 3728b257e..602181538 100644 --- a/crates/bindings-cpp/tests/query-builder-compile/pass_query_integration.cpp +++ b/crates/bindings-cpp/tests/query-builder-compile/pass_query_integration.cpp @@ -2,6 +2,14 @@ using namespace SpacetimeDB; +template +auto TableFor(const char* table_name) { + return QueryBuilder{}.table( + table_name, + query_builder::HasCols::get(table_name), + query_builder::HasIxCols::get(table_name)); +} + struct User { Identity identity; bool online; @@ -38,7 +46,7 @@ FIELD_Index(auto_inc_membership, auto_inc_user_id) SPACETIMEDB_CLIENT_VISIBILITY_FILTER( online_users_filter, - QueryBuilder{}[user].where([](const auto& users) { + TableFor("user").where([](const auto& users) { return users.online; })) diff --git a/crates/bindings-cpp/tests/query-builder-compile/run_query_builder_compile_tests.sh b/crates/bindings-cpp/tests/query-builder-compile/run_query_builder_compile_tests.sh index d6cc40c6d..d64df72d7 100644 --- a/crates/bindings-cpp/tests/query-builder-compile/run_query_builder_compile_tests.sh +++ b/crates/bindings-cpp/tests/query-builder-compile/run_query_builder_compile_tests.sh @@ -95,7 +95,7 @@ compile_should_fail() { } compile_should_pass "$SCRIPT_DIR/pass_query_integration.cpp" -compile_should_fail "$SCRIPT_DIR/fail_non_index_join.cpp" "Semijoin predicates may only use single-column indexed fields." +compile_should_fail "$SCRIPT_DIR/fail_non_index_join.cpp" "no member named 'tenant_id'" compile_should_fail "$SCRIPT_DIR/fail_event_lookup.cpp" "Lookup side of a semijoin must opt in via CanBeLookupTable." echo "All query-builder compile tests passed" diff --git a/crates/bindings-cpp/tests/query-builder-sql/query_builder_sql_tests.cpp b/crates/bindings-cpp/tests/query-builder-sql/query_builder_sql_tests.cpp index dbfb75410..c75cbae9a 100644 --- a/crates/bindings-cpp/tests/query-builder-sql/query_builder_sql_tests.cpp +++ b/crates/bindings-cpp/tests/query-builder-sql/query_builder_sql_tests.cpp @@ -93,6 +93,10 @@ struct ConnectionRowCols { : connection_id(table_name, "connection_id") {} }; +struct ConnectionRowIxCols { + explicit ConnectionRowIxCols(const char*) {} +}; + struct LiteralRowCols { qb::Col score; qb::Col name; @@ -114,6 +118,10 @@ struct LiteralRowCols { bytes(table_name, "bytes") {} }; +struct LiteralRowIxCols { + explicit LiteralRowIxCols(const char*) {} +}; + } // namespace test_query_builder namespace SpacetimeDB::query_builder { @@ -149,11 +157,21 @@ struct HasCols { static test_query_builder::ConnectionRowCols get(const char* table_name) { return test_query_builder::ConnectionRowCols(table_name); } }; +template<> +struct HasIxCols { + static test_query_builder::ConnectionRowIxCols get(const char* table_name) { return test_query_builder::ConnectionRowIxCols(table_name); } +}; + template<> struct HasCols { static test_query_builder::LiteralRowCols get(const char* table_name) { return test_query_builder::LiteralRowCols(table_name); } }; +template<> +struct HasIxCols { + static test_query_builder::LiteralRowIxCols get(const char* table_name) { return test_query_builder::LiteralRowIxCols(table_name); } +}; + } // namespace SpacetimeDB::query_builder namespace SpacetimeDB::bsatn { @@ -176,6 +194,14 @@ struct algebraic_type_of { namespace test_query_builder { +template +auto TableFor(const char* table_name) { + return SpacetimeDB::QueryBuilder().table( + table_name, + qb::HasCols::get(table_name), + qb::HasIxCols::get(table_name)); +} + void ExpectEq(const std::string& actual, const std::string& expected, const std::string& label) { if (actual != expected) { std::ostringstream out; @@ -185,18 +211,18 @@ void ExpectEq(const std::string& actual, const std::string& expected, const std: } void TestSimpleSelect() { - qb::Table users("users"); + auto users = TableFor("users"); ExpectEq(users.build().sql(), "SELECT * FROM \"users\"", "simple select"); } void TestWhereLiteral() { - qb::Table users("users"); + auto users = TableFor("users"); const auto query = users.where([](const auto& user) { return user.id.eq(10); }).build(); ExpectEq(query.sql(), "SELECT * FROM \"users\" WHERE (\"users\".\"id\" = 10)", "where literal"); } void TestWhereMultiplePredicates() { - qb::Table users("users"); + auto users = TableFor("users"); const auto query = users .where([](const auto& user) { return user.id.eq(10); }) .where([](const auto& user) { return user.id.gt(3); }) @@ -208,7 +234,7 @@ void TestWhereMultiplePredicates() { } void TestWhereAndFilter() { - qb::Table users("users"); + auto users = TableFor("users"); const auto query = users .where([](const auto& user) { return user.online; }) .filter([](const auto& user) { return user.id.gt(10); }) @@ -220,7 +246,7 @@ void TestWhereAndFilter() { } void TestColumnComparisons() { - qb::Table users("users"); + auto users = TableFor("users"); ExpectEq( users.where([](const auto& user) { return user.id.eq(user.id); }).build().sql(), @@ -234,7 +260,7 @@ void TestColumnComparisons() { } void TestComparisonOperators() { - qb::Table users("users"); + auto users = TableFor("users"); ExpectEq( users.where([](const auto& user) { return user.name.ne("Shub"); }).build().sql(), @@ -251,7 +277,7 @@ void TestComparisonOperators() { } void TestLogicalComposition() { - qb::Table users("users"); + auto users = TableFor("users"); const auto query = users .where([](const auto& user) { return user.name.eq("Alice").not_().and_(user.online.eq(true).or_(user.id.gte(7))); @@ -264,7 +290,7 @@ void TestLogicalComposition() { } void TestNotAndOr() { - qb::Table users("users"); + auto users = TableFor("users"); ExpectEq( users.where([](const auto& user) { return user.name.eq("Alice").not_(); }).build().sql(), @@ -287,7 +313,7 @@ void TestNotAndOr() { } void TestFilterAlias() { - qb::Table users("users"); + auto users = TableFor("users"); const auto query = users .filter([](const auto& user) { return user.id.eq(5); }) .filter([](const auto& user) { return user.id.lt(30); }) @@ -299,7 +325,7 @@ void TestFilterAlias() { } void TestLiteralFormatting() { - qb::Table users("users"); + auto users = TableFor("users"); std::array identity_bytes{}; identity_bytes.front() = 1; @@ -337,7 +363,7 @@ void TestLiteralFormatting() { "SELECT * FROM \"users\" WHERE (\"users\".\"online\" = TRUE)", "bool literal formatting"); - qb::Table connections("player"); + auto connections = TableFor("player"); ExpectEq( connections.where([&](const auto& row) { return row.connection_id.eq(connection_id); }).build().sql(), "SELECT * FROM \"player\" WHERE (\"player\".\"connection_id\" = 0x00000000000000000000000000000000)", @@ -345,7 +371,7 @@ void TestLiteralFormatting() { } void TestLiteralMatrix() { - qb::Table table("player"); + auto table = TableFor("player"); ExpectEq( table.where([](const auto& row) { return row.score.eq(100); }).build().sql(), @@ -423,8 +449,8 @@ void TestQueryReturnWrapperShape() { } void TestSemiJoins() { - qb::Table users("users"); - qb::Table levels("player_level"); + auto users = TableFor("users"); + auto levels = TableFor("player_level"); const auto left = users.left_semijoin(levels, [](const auto& user, const auto& level) { return user.id.eq(level.entity_id);