diff --git a/crates/bindings-cpp/include/spacetimedb/query_builder/expr.h b/crates/bindings-cpp/include/spacetimedb/query_builder/expr.h index 381b5fd67..b586b560c 100644 --- a/crates/bindings-cpp/include/spacetimedb/query_builder/expr.h +++ b/crates/bindings-cpp/include/spacetimedb/query_builder/expr.h @@ -21,6 +21,14 @@ namespace SpacetimeDB::query_builder { template class Col; +template +struct is_col : std::false_type {}; + +template +inline constexpr bool is_rhs_for_value_v = + std::is_same_v> || + (std::is_same_v && std::is_convertible_v); + template class ColumnRef { public: @@ -314,30 +322,194 @@ public: : column_(table_name, column_name) {} template + requires(is_rhs_for_value_v) [[nodiscard]] BoolExpr eq(const TRhs& rhs) const { return compare(BoolExpr::Kind::Eq, rhs); } template + requires(is_rhs_for_value_v) [[nodiscard]] BoolExpr Eq(const TRhs& rhs) const { return eq(rhs); } template + requires(is_rhs_for_value_v) [[nodiscard]] BoolExpr ne(const TRhs& rhs) const { return compare(BoolExpr::Kind::Ne, rhs); } template + requires(is_rhs_for_value_v) [[nodiscard]] BoolExpr Ne(const TRhs& rhs) const { return ne(rhs); } template + requires(is_rhs_for_value_v) [[nodiscard]] BoolExpr gt(const TRhs& rhs) const { return compare(BoolExpr::Kind::Gt, rhs); } template + requires(is_rhs_for_value_v) [[nodiscard]] BoolExpr Gt(const TRhs& rhs) const { return gt(rhs); } template + requires(is_rhs_for_value_v) [[nodiscard]] BoolExpr lt(const TRhs& rhs) const { return compare(BoolExpr::Kind::Lt, rhs); } template + requires(is_rhs_for_value_v) [[nodiscard]] BoolExpr Lt(const TRhs& rhs) const { return lt(rhs); } template + requires(is_rhs_for_value_v) [[nodiscard]] BoolExpr gte(const TRhs& rhs) const { return compare(BoolExpr::Kind::Gte, rhs); } template + requires(is_rhs_for_value_v) [[nodiscard]] BoolExpr Gte(const TRhs& rhs) const { return gte(rhs); } template + requires(is_rhs_for_value_v) [[nodiscard]] BoolExpr lte(const TRhs& rhs) const { return compare(BoolExpr::Kind::Lte, rhs); } template + requires(is_rhs_for_value_v) [[nodiscard]] BoolExpr Lte(const TRhs& rhs) const { return lte(rhs); } + // Keep incompatible non-column RHS values on a dedicated overload so they + // fail with the same diagnostic shape as mismatched column comparisons. + template + requires(!is_rhs_for_value_v && !is_col>::value) + [[nodiscard]] auto eq(const TRhs&) const { + static_assert(is_rhs_for_value_v, "Column comparison requires both sides to have the same value type."); + return BoolExpr::always(false); + } + template + requires(!is_rhs_for_value_v && !is_col>::value) + [[nodiscard]] auto Eq(const TRhs& rhs) const { return eq(rhs); } + template + requires(!is_rhs_for_value_v && !is_col>::value) + [[nodiscard]] auto ne(const TRhs&) const { + static_assert(is_rhs_for_value_v, "Column comparison requires both sides to have the same value type."); + return BoolExpr::always(false); + } + template + requires(!is_rhs_for_value_v && !is_col>::value) + [[nodiscard]] auto Ne(const TRhs& rhs) const { return ne(rhs); } + template + requires(!is_rhs_for_value_v && !is_col>::value) + [[nodiscard]] auto gt(const TRhs&) const { + static_assert(is_rhs_for_value_v, "Column comparison requires both sides to have the same value type."); + return BoolExpr::always(false); + } + template + requires(!is_rhs_for_value_v && !is_col>::value) + [[nodiscard]] auto Gt(const TRhs& rhs) const { return gt(rhs); } + template + requires(!is_rhs_for_value_v && !is_col>::value) + [[nodiscard]] auto lt(const TRhs&) const { + static_assert(is_rhs_for_value_v, "Column comparison requires both sides to have the same value type."); + return BoolExpr::always(false); + } + template + requires(!is_rhs_for_value_v && !is_col>::value) + [[nodiscard]] auto Lt(const TRhs& rhs) const { return lt(rhs); } + template + requires(!is_rhs_for_value_v && !is_col>::value) + [[nodiscard]] auto gte(const TRhs&) const { + static_assert(is_rhs_for_value_v, "Column comparison requires both sides to have the same value type."); + return BoolExpr::always(false); + } + template + requires(!is_rhs_for_value_v && !is_col>::value) + [[nodiscard]] auto Gte(const TRhs& rhs) const { return gte(rhs); } + template + requires(!is_rhs_for_value_v && !is_col>::value) + [[nodiscard]] auto lte(const TRhs&) const { + static_assert(is_rhs_for_value_v, "Column comparison requires both sides to have the same value type."); + return BoolExpr::always(false); + } + template + requires(!is_rhs_for_value_v && !is_col>::value) + [[nodiscard]] auto Lte(const TRhs& rhs) const { return lte(rhs); } + + template + requires(std::is_same_v) + [[nodiscard]] BoolExpr eq(const Col& rhs) const { return compare(BoolExpr::Kind::Eq, rhs); } + template + requires(std::is_same_v) + [[nodiscard]] BoolExpr Eq(const Col& rhs) const { return eq(rhs); } + template + requires(std::is_same_v) + [[nodiscard]] BoolExpr ne(const Col& rhs) const { return compare(BoolExpr::Kind::Ne, rhs); } + template + requires(std::is_same_v) + [[nodiscard]] BoolExpr Ne(const Col& rhs) const { return ne(rhs); } + template + requires(std::is_same_v) + [[nodiscard]] BoolExpr gt(const Col& rhs) const { return compare(BoolExpr::Kind::Gt, rhs); } + template + requires(std::is_same_v) + [[nodiscard]] BoolExpr Gt(const Col& rhs) const { return gt(rhs); } + template + requires(std::is_same_v) + [[nodiscard]] BoolExpr lt(const Col& rhs) const { return compare(BoolExpr::Kind::Lt, rhs); } + template + requires(std::is_same_v) + [[nodiscard]] BoolExpr Lt(const Col& rhs) const { return lt(rhs); } + template + requires(std::is_same_v) + [[nodiscard]] BoolExpr gte(const Col& rhs) const { return compare(BoolExpr::Kind::Gte, rhs); } + template + requires(std::is_same_v) + [[nodiscard]] BoolExpr Gte(const Col& rhs) const { return gte(rhs); } + template + requires(std::is_same_v) + [[nodiscard]] BoolExpr lte(const Col& rhs) const { return compare(BoolExpr::Kind::Lte, rhs); } + template + requires(std::is_same_v) + [[nodiscard]] BoolExpr Lte(const Col& rhs) const { return lte(rhs); } + + // Keep mismatched column-to-column comparisons on a dedicated overload so + // they fail here with a clear diagnostic instead of disappearing into the + // generic operand-conversion path. + template + requires(!std::is_same_v) + [[nodiscard]] auto eq(const Col&) const { + static_assert(std::is_same_v, "Column comparison requires both sides to have the same value type."); + return BoolExpr::always(false); + } + template + requires(!std::is_same_v) + [[nodiscard]] auto Eq(const Col& rhs) const { return eq(rhs); } + template + requires(!std::is_same_v) + [[nodiscard]] auto ne(const Col&) const { + static_assert(std::is_same_v, "Column comparison requires both sides to have the same value type."); + return BoolExpr::always(false); + } + template + requires(!std::is_same_v) + [[nodiscard]] auto Ne(const Col& rhs) const { return ne(rhs); } + template + requires(!std::is_same_v) + [[nodiscard]] auto gt(const Col&) const { + static_assert(std::is_same_v, "Column comparison requires both sides to have the same value type."); + return BoolExpr::always(false); + } + template + requires(!std::is_same_v) + [[nodiscard]] auto Gt(const Col& rhs) const { return gt(rhs); } + template + requires(!std::is_same_v) + [[nodiscard]] auto lt(const Col&) const { + static_assert(std::is_same_v, "Column comparison requires both sides to have the same value type."); + return BoolExpr::always(false); + } + template + requires(!std::is_same_v) + [[nodiscard]] auto Lt(const Col& rhs) const { return lt(rhs); } + template + requires(!std::is_same_v) + [[nodiscard]] auto gte(const Col&) const { + static_assert(std::is_same_v, "Column comparison requires both sides to have the same value type."); + return BoolExpr::always(false); + } + template + requires(!std::is_same_v) + [[nodiscard]] auto Gte(const Col& rhs) const { return gte(rhs); } + template + requires(!std::is_same_v) + [[nodiscard]] auto lte(const Col&) const { + static_assert(std::is_same_v, "Column comparison requires both sides to have the same value type."); + return BoolExpr::always(false); + } + template + requires(!std::is_same_v) + [[nodiscard]] auto Lte(const Col& rhs) const { return lte(rhs); } + [[nodiscard]] constexpr const ColumnRef& column_ref() const { return column_; } private: @@ -349,6 +521,9 @@ private: ColumnRef column_; }; +template +struct is_col> : std::true_type {}; + namespace detail { template diff --git a/crates/bindings-cpp/include/spacetimedb/query_builder/join.h b/crates/bindings-cpp/include/spacetimedb/query_builder/join.h index 98ee79ea8..b4c63ddae 100644 --- a/crates/bindings-cpp/include/spacetimedb/query_builder/join.h +++ b/crates/bindings-cpp/include/spacetimedb/query_builder/join.h @@ -13,6 +13,12 @@ struct IxJoinEq; template struct member_tag {}; +template +struct is_ix_col : std::false_type {}; + +template +struct is_ix_join_eq : std::false_type {}; + inline std::false_type indexed_member_lookup(...); template @@ -36,12 +42,27 @@ public: return eq(rhs); } + // Keep mismatched indexed-column comparisons on a dedicated overload so they + // fail here with a clear diagnostic instead of falling through to BoolExpr. + template + [[nodiscard]] auto eq(const IxCol&) const { + static_assert(std::is_same_v, "Semijoin indexed equality requires both sides to have the same value type."); + return IxJoinEq{}; + } + + template + [[nodiscard]] auto Eq(const IxCol& rhs) const { + return eq(rhs); + } + template + requires(!is_ix_col>::value) [[nodiscard]] BoolExpr eq(const TRhs& rhs) const { return compare(BoolExpr::Kind::Eq, rhs); } template + requires(!is_ix_col>::value) [[nodiscard]] BoolExpr Eq(const TRhs& rhs) const { return eq(rhs); } [[nodiscard]] constexpr const ColumnRef& column_ref() const { return column_; } @@ -58,6 +79,9 @@ private: friend class IxCol; }; +template +struct is_ix_col> : std::true_type {}; + namespace detail { template @@ -71,6 +95,9 @@ struct IxJoinEq { ColumnRef rhs; }; +template +struct is_ix_join_eq> : std::true_type {}; + template class LeftSemiJoin { public: @@ -269,28 +296,52 @@ template& left, const Table& right, TFn&& predicate) { static_assert(can_be_lookup_table_v>, "Lookup side of a semijoin must opt in via CanBeLookupTable."); const auto join = std::forward(predicate)(left.ix_cols(), right.ix_cols()); - return LeftSemiJoin(left, right, join.lhs, join.rhs); + using TJoin = std::remove_cvref_t; + if constexpr (is_ix_join_eq::value) { + return LeftSemiJoin(left, right, join.lhs, join.rhs); + } else { + static_assert(is_ix_join_eq::value, "Semijoin predicate must compare two indexed columns with eq()."); + return LeftSemiJoin(left, right, {}, {}); + } } template [[nodiscard]] auto left_semijoin_impl(const FromWhere& left, const Table& right, TFn&& predicate) { static_assert(can_be_lookup_table_v>, "Lookup side of a semijoin must opt in via CanBeLookupTable."); const auto join = std::forward(predicate)(left.table().ix_cols(), right.ix_cols()); - return LeftSemiJoin(left.table(), right, join.lhs, join.rhs, left.expr()); + using TJoin = std::remove_cvref_t; + if constexpr (is_ix_join_eq::value) { + return LeftSemiJoin(left.table(), right, join.lhs, join.rhs, left.expr()); + } else { + static_assert(is_ix_join_eq::value, "Semijoin predicate must compare two indexed columns with eq()."); + return LeftSemiJoin(left.table(), right, {}, {}, left.expr()); + } } template [[nodiscard]] auto right_semijoin_impl(const Table& left, const Table& right, TFn&& predicate) { static_assert(can_be_lookup_table_v>, "Lookup side of a semijoin must opt in via CanBeLookupTable."); const auto join = std::forward(predicate)(left.ix_cols(), right.ix_cols()); - return RightSemiJoin(left, right, join.lhs, join.rhs); + using TJoin = std::remove_cvref_t; + if constexpr (is_ix_join_eq::value) { + return RightSemiJoin(left, right, join.lhs, join.rhs); + } else { + static_assert(is_ix_join_eq::value, "Semijoin predicate must compare two indexed columns with eq()."); + return RightSemiJoin(left, right, {}, {}); + } } template [[nodiscard]] auto right_semijoin_impl(const FromWhere& left, const Table& right, TFn&& predicate) { static_assert(can_be_lookup_table_v>, "Lookup side of a semijoin must opt in via CanBeLookupTable."); const auto join = std::forward(predicate)(left.table().ix_cols(), right.ix_cols()); - return RightSemiJoin(left.table(), right, join.lhs, join.rhs, left.expr()); + using TJoin = std::remove_cvref_t; + if constexpr (is_ix_join_eq::value) { + return RightSemiJoin(left.table(), right, join.lhs, join.rhs, left.expr()); + } else { + static_assert(is_ix_join_eq::value, "Semijoin predicate must compare two indexed columns with eq()."); + return RightSemiJoin(left.table(), right, {}, {}, left.expr()); + } } } // namespace detail diff --git a/crates/bindings-cpp/tests/query-builder-compile/fail_implicit_numeric_where_types.cpp b/crates/bindings-cpp/tests/query-builder-compile/fail_implicit_numeric_where_types.cpp new file mode 100644 index 000000000..ca3b3c890 --- /dev/null +++ b/crates/bindings-cpp/tests/query-builder-compile/fail_implicit_numeric_where_types.cpp @@ -0,0 +1,21 @@ +#include + +using namespace SpacetimeDB; + +template +auto TableFor(const char* table_name) { + return QueryBuilder{}.table( + table_name, + query_builder::HasCols::get(table_name), + query_builder::HasIxCols::get(table_name)); +} + +struct PlayerInfo { + uint8_t age; +}; +SPACETIMEDB_STRUCT(PlayerInfo, age) +SPACETIMEDB_TABLE(PlayerInfo, player_info, Public) + +auto invalid_filter = TableFor("player_info").where([](const auto& players) { + return players.age.eq(4200); +}); diff --git a/crates/bindings-cpp/tests/query-builder-compile/fail_incompatible_join_types.cpp b/crates/bindings-cpp/tests/query-builder-compile/fail_incompatible_join_types.cpp new file mode 100644 index 000000000..05fab0255 --- /dev/null +++ b/crates/bindings-cpp/tests/query-builder-compile/fail_incompatible_join_types.cpp @@ -0,0 +1,33 @@ +#include + +using namespace SpacetimeDB; + +template +auto TableFor(const char* table_name) { + return QueryBuilder{}.table( + table_name, + query_builder::HasCols::get(table_name), + query_builder::HasIxCols::get(table_name)); +} + +struct User { + Identity identity; +}; +SPACETIMEDB_STRUCT(User, identity) +SPACETIMEDB_TABLE(User, user, Public) +FIELD_PrimaryKey(user, identity) + +struct Membership { + Identity membership_identity; + uint64_t tenant_id; +}; +SPACETIMEDB_STRUCT(Membership, membership_identity, tenant_id) +SPACETIMEDB_TABLE(Membership, membership, Public) +FIELD_PrimaryKey(membership, membership_identity) +FIELD_Index(membership, tenant_id) + +auto invalid_join = TableFor("user").right_semijoin( + TableFor("membership"), + [](const auto& users, const auto& memberships) { + return users.identity.eq(memberships.tenant_id); + }); diff --git a/crates/bindings-cpp/tests/query-builder-compile/fail_incompatible_where_types.cpp b/crates/bindings-cpp/tests/query-builder-compile/fail_incompatible_where_types.cpp new file mode 100644 index 000000000..9e50dc919 --- /dev/null +++ b/crates/bindings-cpp/tests/query-builder-compile/fail_incompatible_where_types.cpp @@ -0,0 +1,23 @@ +#include + +using namespace SpacetimeDB; + +template +auto TableFor(const char* table_name) { + return QueryBuilder{}.table( + table_name, + query_builder::HasCols::get(table_name), + query_builder::HasIxCols::get(table_name)); +} + +struct User { + Identity identity; + uint64_t tenant_id; +}; +SPACETIMEDB_STRUCT(User, identity, tenant_id) +SPACETIMEDB_TABLE(User, user, Public) +FIELD_PrimaryKey(user, identity) + +auto invalid_filter = TableFor("user").where([](const auto& users) { + return users.identity.eq(users.tenant_id); +}); diff --git a/crates/bindings-cpp/tests/query-builder-compile/fail_invalid_join_predicate.cpp b/crates/bindings-cpp/tests/query-builder-compile/fail_invalid_join_predicate.cpp new file mode 100644 index 000000000..0afc27aa3 --- /dev/null +++ b/crates/bindings-cpp/tests/query-builder-compile/fail_invalid_join_predicate.cpp @@ -0,0 +1,33 @@ +#include + +using namespace SpacetimeDB; + +template +auto TableFor(const char* table_name) { + return QueryBuilder{}.table( + table_name, + query_builder::HasCols::get(table_name), + query_builder::HasIxCols::get(table_name)); +} + +struct User { + uint64_t id; +}; +SPACETIMEDB_STRUCT(User, id) +SPACETIMEDB_TABLE(User, user, Public) +FIELD_PrimaryKey(user, id) + +struct Membership { + uint64_t id; + uint64_t user_id; +}; +SPACETIMEDB_STRUCT(Membership, id, user_id) +SPACETIMEDB_TABLE(Membership, membership, Public) +FIELD_PrimaryKey(membership, id) +FIELD_Index(membership, user_id) + +auto invalid_join = TableFor("user").right_semijoin( + TableFor("membership"), + [](const auto& users, const auto& memberships) { + return users.id.eq(1ULL); + }); diff --git a/crates/bindings-cpp/tests/query-builder-compile/run_query_builder_compile_tests.sh b/crates/bindings-cpp/tests/query-builder-compile/run_query_builder_compile_tests.sh index d64df72d7..fafb27c6d 100644 --- a/crates/bindings-cpp/tests/query-builder-compile/run_query_builder_compile_tests.sh +++ b/crates/bindings-cpp/tests/query-builder-compile/run_query_builder_compile_tests.sh @@ -95,7 +95,11 @@ compile_should_fail() { } compile_should_pass "$SCRIPT_DIR/pass_query_integration.cpp" +compile_should_fail "$SCRIPT_DIR/fail_invalid_join_predicate.cpp" "Semijoin predicate must compare two indexed columns with eq()." +compile_should_fail "$SCRIPT_DIR/fail_incompatible_where_types.cpp" "Column comparison requires both sides to have the same value type." +compile_should_fail "$SCRIPT_DIR/fail_implicit_numeric_where_types.cpp" "Column comparison requires both sides to have the same value type." compile_should_fail "$SCRIPT_DIR/fail_non_index_join.cpp" "no member named 'tenant_id'" +compile_should_fail "$SCRIPT_DIR/fail_incompatible_join_types.cpp" "Semijoin indexed equality requires both sides to have the same value type." compile_should_fail "$SCRIPT_DIR/fail_event_lookup.cpp" "Lookup side of a semijoin must opt in via CanBeLookupTable." echo "All query-builder compile tests passed"