Making tx mut so we can auto-create params on compilation

This commit is contained in:
Mario Alejandro Montoya Cortes
2025-12-25 16:43:09 -05:00
parent d2dde1fb5b
commit 4289cd8517
21 changed files with 379 additions and 324 deletions
+12 -12
View File
@@ -126,10 +126,10 @@ fn eval(c: &mut Criterion) {
// A benchmark runner for the new query engine
let bench_query = |c: &mut Criterion, name, sql| {
c.bench_function(name, |b| {
let tx = raw.db.begin_tx(Workload::Subscribe);
let mut tx = raw.db.begin_tx(Workload::Subscribe);
let auth = AuthCtx::for_testing();
let schema_viewer = &SchemaViewer::new(&tx, &auth);
let (plans, table_id, table_name, _) = compile_subscription(sql, schema_viewer, &auth).unwrap();
let mut schema_viewer = SchemaViewer::new(&mut tx, &auth);
let (plans, table_id, table_name, _) = compile_subscription(sql, &mut schema_viewer, &auth).unwrap();
let plans = plans
.into_iter()
.map(|plan| plan.optimize(&auth).unwrap())
@@ -155,8 +155,8 @@ fn eval(c: &mut Criterion) {
let bench_eval = |c: &mut Criterion, name, sql| {
c.bench_function(name, |b| {
let tx = raw.db.begin_tx(Workload::Update);
let query = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), &tx, sql).unwrap();
let mut tx = raw.db.begin_tx(Workload::Update);
let query = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), &mut tx, sql).unwrap();
let query: ExecutionSet = query.into();
b.iter(|| {
@@ -207,11 +207,11 @@ fn eval(c: &mut Criterion) {
// A passthru executed independently of the database.
let select_lhs = "select * from footprint";
let select_rhs = "select * from location";
let tx = &raw.db.begin_tx(Workload::Update);
let query_lhs = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), tx, select_lhs).unwrap();
let query_rhs = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), tx, select_rhs).unwrap();
let mut tx = raw.db.begin_tx(Workload::Update);
let query_lhs = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), &mut tx, select_lhs).unwrap();
let query_rhs = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), &mut tx, select_rhs).unwrap();
let query = ExecutionSet::from_iter(query_lhs.into_iter().chain(query_rhs));
let tx = &tx.into();
let tx = &(&mut tx).into();
b.iter(|| drop(black_box(query.eval_incr_for_test(&raw.db, tx, &update, None))))
});
@@ -226,10 +226,10 @@ fn eval(c: &mut Criterion) {
from footprint join location on footprint.entity_id = location.entity_id \
where location.chunk_index = {chunk_index}"
);
let tx = &raw.db.begin_tx(Workload::Update);
let query = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), tx, &join).unwrap();
let mut tx = raw.db.begin_tx(Workload::Update);
let query = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), &mut tx, &join).unwrap();
let query: ExecutionSet = query.into();
let tx = &tx.into();
let tx = &(&mut tx).into();
b.iter(|| drop(black_box(query.eval_incr_for_test(&raw.db, tx, &update, None))));
});
+5 -5
View File
@@ -181,8 +181,8 @@ mod tests {
}
fn num_rows_for(db: &RelationalDB, sql: &str) -> u64 {
let tx = begin_tx(db);
match &*compile_sql(db, &AuthCtx::for_testing(), &tx, sql).expect("Failed to compile sql") {
let mut tx = begin_tx(db);
match &*compile_sql(db, &AuthCtx::for_testing(), &mut tx, sql).expect("Failed to compile sql") {
[CrudExpr::Query(expr)] => num_rows(&tx, expr),
exprs => panic!("unexpected result from compilation: {exprs:#?}"),
}
@@ -191,10 +191,10 @@ mod tests {
/// Using the new query plan
fn new_row_estimate(db: &RelationalDB, sql: &str) -> u64 {
let auth = AuthCtx::for_testing();
let tx = begin_tx(db);
let tx = SchemaViewer::new(&tx, &auth);
let mut tx = begin_tx(db);
let mut tx = SchemaViewer::new(&mut tx, &auth);
compile_subscription(sql, &tx, &auth)
compile_subscription(sql, &mut tx, &auth)
.map(|(plans, ..)| plans)
.expect("failed to compile sql query")
.into_iter()
+3 -3
View File
@@ -1869,7 +1869,7 @@ impl ModuleHost {
let metrics = self
.on_module_thread("one_off_query", move || {
let (tx_offset_sender, tx_offset_receiver) = oneshot::channel();
let tx = scopeguard::guard(db.begin_tx(Workload::Sql), |tx| {
let mut tx = scopeguard::guard(db.begin_tx(Workload::Sql), |tx| {
let (tx_offset, tx_metrics, reducer) = db.release_tx(tx);
let _ = tx_offset_sender.send(tx_offset);
db.report_read_tx_metrics(reducer, tx_metrics);
@@ -1878,7 +1878,7 @@ impl ModuleHost {
// We wrap the actual query in a closure so we can use ? to handle errors without making
// the entire transaction abort with an error.
let result: Result<(OneOffTable<F>, ExecutionMetrics), anyhow::Error> = (|| {
let tx = SchemaViewer::new(&*tx, &auth);
let mut tx = SchemaViewer::new(&mut *tx, &auth);
let (
// A query may compile down to several plans.
@@ -1888,7 +1888,7 @@ impl ModuleHost {
_,
table_name,
_,
) = compile_subscription(&query, &tx, &auth)?;
) = compile_subscription(&query, &mut tx, &auth)?;
// Optimize each fragment
let optimized = plans
@@ -1126,10 +1126,10 @@ impl InstanceCommon {
// Views bypass RLS, since views should enforce their own access control procedurally.
let auth = AuthCtx::for_current(self.info.database_identity);
let schema_view = SchemaViewer::new(&*tx, &auth);
let mut schema_view = SchemaViewer::new(&mut *tx, &auth);
// Compile to subscription plans.
let (plans, has_params) = SubscriptionPlan::compile(the_query, &schema_view, &auth)?;
let (plans, has_params) = SubscriptionPlan::compile(the_query, &mut schema_view, &auth)?;
ensure!(
!has_params,
"parameterized SQL is not supported for view materialization yet"
+17 -6
View File
@@ -5,6 +5,7 @@ use spacetimedb_data_structures::map::{HashCollectionExt as _, IntMap};
use spacetimedb_datastore::locking_tx_datastore::state_view::StateView;
use spacetimedb_datastore::system_tables::{StRowLevelSecurityFields, ST_ROW_LEVEL_SECURITY_ID};
use spacetimedb_expr::check::{SchemaView, TypingResult};
use spacetimedb_expr::errors::TypingError;
use spacetimedb_expr::statement::compile_sql_stmt;
use spacetimedb_lib::identity::AuthCtx;
use spacetimedb_primitives::{ArgId, ColId, TableId};
@@ -22,7 +23,7 @@ use sqlparser::ast::{
};
use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;
use std::ops::Deref;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
/// Simplify to detect features of the syntax we don't support yet
@@ -477,7 +478,7 @@ fn compile_where(table: &From, filter: Option<SqlExpr>) -> Result<Option<Selecti
}
pub struct SchemaViewer<'a, T> {
pub(crate) tx: &'a T,
tx: &'a mut T,
auth: &'a AuthCtx,
}
@@ -489,6 +490,12 @@ impl<T> Deref for SchemaViewer<'_, T> {
}
}
impl<T> DerefMut for SchemaViewer<'_, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.tx
}
}
impl<T: StateView> SchemaView for SchemaViewer<'_, T> {
fn table_id(&self, name: &str) -> Option<TableId> {
// Get the schema from the in-memory state instead of fetching from the database for speed
@@ -536,10 +543,15 @@ impl<T: StateView> SchemaView for SchemaViewer<'_, T> {
})
.collect::<anyhow::Result<_>>()
}
fn get_or_create_params(&mut self, _params: ProductValue) -> TypingResult<ArgId> {
// Caller should have used `SchemaViewerMut` on crate `core`
Err(TypingError::ParamsReadOnly)
}
}
impl<'a, T> SchemaViewer<'a, T> {
pub fn new(tx: &'a T, auth: &'a AuthCtx) -> Self {
pub fn new(tx: &'a mut T, auth: &'a AuthCtx) -> Self {
Self { tx, auth }
}
}
@@ -1000,13 +1012,12 @@ fn compile_statement<T: TableSchemaView + StateView>(
pub(crate) fn compile_to_ast<T: TableSchemaView + StateView>(
db: &RelationalDB,
auth: &AuthCtx,
tx: &T,
tx: &mut T,
sql_text: &str,
) -> Result<Vec<SqlAst>, DBError> {
// NOTE: The following ensures compliance with the 1.0 sql api.
// Come 1.0, it will have replaced the current compilation stack.
compile_sql_stmt(sql_text, &SchemaViewer::new(tx, auth), auth)?;
compile_sql_stmt(sql_text, &mut SchemaViewer::new(tx, auth), auth)?;
let dialect = PostgreSqlDialect {};
let ast = Parser::parse_sql(&dialect, sql_text).map_err(|error| DBError::SqlParser {
sql: sql_text.to_string(),
+41 -41
View File
@@ -23,7 +23,7 @@ const MAX_SQL_LENGTH: usize = 50_000;
pub fn compile_sql<T: TableSchemaView + StateView>(
db: &RelationalDB,
auth: &AuthCtx,
tx: &T,
tx: &mut T,
sql_text: &str,
) -> Result<Vec<CrudExpr>, DBError> {
if sql_text.len() > MAX_SQL_LENGTH {
@@ -266,7 +266,7 @@ mod tests {
fn compile_sql<T: TableSchemaView + StateView>(
db: &RelationalDB,
tx: &T,
tx: &mut T,
sql: &str,
) -> Result<Vec<CrudExpr>, DBError> {
super::compile_sql(db, &AuthCtx::for_testing(), tx, sql)
@@ -281,10 +281,10 @@ mod tests {
let indexes = &[];
db.create_table_for_test("test", schema, indexes)?;
let tx = begin_tx(&db);
let mut tx = begin_tx(&db);
// Compile query
let sql = "select * from test where a = 1";
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else {
panic!("Expected QueryExpr");
};
assert_eq!(1, query.len());
@@ -303,10 +303,10 @@ mod tests {
&[1.into(), 0.into()],
)?;
let tx = begin_tx(&db);
let mut tx = begin_tx(&db);
// Should work with any qualified field.
let sql = "select * from test where a = 1 and b <> 3";
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else {
panic!("Expected QueryExpr");
};
assert_eq!(2, query.len());
@@ -324,10 +324,10 @@ mod tests {
let indexes = &[0.into()];
db.create_table_for_test("test", schema, indexes)?;
let tx = begin_tx(&db);
let mut tx = begin_tx(&db);
//Compile query
let sql = "select * from test where a = 1";
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else {
panic!("Expected QueryExpr");
};
assert_eq!(1, query.len());
@@ -377,11 +377,11 @@ mod tests {
let rows = run_for_testing(&db, sql)?;
let tx = begin_tx(&db);
let mut tx = begin_tx(&db);
let CrudExpr::Query(QueryExpr {
source: _,
query: mut ops,
}) = compile_sql(&db, &tx, sql)?.remove(0)
}) = compile_sql(&db, &mut tx, sql)?.remove(0)
else {
panic!("Expected QueryExpr");
};
@@ -407,11 +407,11 @@ mod tests {
let indexes = &[1.into()];
db.create_table_for_test("test", schema, indexes)?;
let tx = begin_tx(&db);
let mut tx = begin_tx(&db);
// Note, order does not matter.
// The sargable predicate occurs last, but we can still generate an index scan.
let sql = "select * from test where a = 1 and b = 2";
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else {
panic!("Expected QueryExpr");
};
assert_eq!(2, query.len());
@@ -429,11 +429,11 @@ mod tests {
let indexes = &[1.into()];
db.create_table_for_test("test", schema, indexes)?;
let tx = begin_tx(&db);
let mut tx = begin_tx(&db);
// Note, order does not matter.
// The sargable predicate occurs first and we can generate an index scan.
let sql = "select * from test where b = 2 and a = 1";
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else {
panic!("Expected QueryExpr");
};
assert_eq!(2, query.len());
@@ -455,9 +455,9 @@ mod tests {
];
db.create_table_for_test_multi_column("test", schema, col_list![0, 1])?;
let tx = begin_mut_tx(&db);
let mut tx = begin_mut_tx(&db);
let sql = "select * from test where b = 2 and a = 1";
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else {
panic!("Expected QueryExpr");
};
assert_eq!(1, query.len());
@@ -474,10 +474,10 @@ mod tests {
let indexes = &[0.into(), 1.into()];
db.create_table_for_test("test", schema, indexes)?;
let tx = begin_tx(&db);
let mut tx = begin_tx(&db);
// Compile query
let sql = "select * from test where a = 1 or b = 2";
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else {
panic!("Expected QueryExpr");
};
assert_eq!(1, query.len());
@@ -495,10 +495,10 @@ mod tests {
let indexes = &[1.into()];
db.create_table_for_test("test", schema, indexes)?;
let tx = begin_tx(&db);
let mut tx = begin_tx(&db);
// Compile query
let sql = "select * from test where b > 2";
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else {
panic!("Expected QueryExpr");
};
assert_eq!(1, query.len());
@@ -516,10 +516,10 @@ mod tests {
let indexes = &[1.into()];
db.create_table_for_test("test", schema, indexes)?;
let tx = begin_tx(&db);
let mut tx = begin_tx(&db);
// Compile query
let sql = "select * from test where b > 2 and b < 5";
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else {
panic!("Expected QueryExpr");
};
assert_eq!(1, query.len());
@@ -542,11 +542,11 @@ mod tests {
let indexes = &[0.into(), 1.into()];
db.create_table_for_test("test", schema, indexes)?;
let tx = begin_tx(&db);
let mut tx = begin_tx(&db);
// Note, order matters - the equality condition occurs first which
// means an index scan will be generated rather than the range condition.
let sql = "select * from test where a = 3 and b > 2 and b < 5";
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else {
panic!("Expected QueryExpr");
};
assert_eq!(2, query.len());
@@ -569,10 +569,10 @@ mod tests {
let indexes = &[];
let rhs_id = db.create_table_for_test("rhs", schema, indexes)?;
let tx = begin_tx(&db);
let mut tx = begin_tx(&db);
// Should push sargable equality condition below join
let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where lhs.a = 3";
let exp = compile_sql(&db, &tx, sql)?.remove(0);
let exp = compile_sql(&db, &mut tx, sql)?.remove(0);
let CrudExpr::Query(QueryExpr {
source: source_lhs,
@@ -621,10 +621,10 @@ mod tests {
let schema = &[("b", AlgebraicType::U64), ("c", AlgebraicType::U64)];
let rhs_id = db.create_table_for_test("rhs", schema, &[])?;
let tx = begin_tx(&db);
let mut tx = begin_tx(&db);
// Should push equality condition below join
let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where lhs.a = 3";
let exp = compile_sql(&db, &tx, sql)?.remove(0);
let exp = compile_sql(&db, &mut tx, sql)?.remove(0);
let CrudExpr::Query(QueryExpr {
source: source_lhs,
@@ -678,10 +678,10 @@ mod tests {
let schema = &[("b", AlgebraicType::U64), ("c", AlgebraicType::U64)];
let rhs_id = db.create_table_for_test("rhs", schema, &[])?;
let tx = begin_tx(&db);
let mut tx = begin_tx(&db);
// Should push equality condition below join
let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c = 3";
let exp = compile_sql(&db, &tx, sql)?.remove(0);
let exp = compile_sql(&db, &mut tx, sql)?.remove(0);
let CrudExpr::Query(QueryExpr {
source: source_lhs,
@@ -736,11 +736,11 @@ mod tests {
let indexes = &[1.into()];
let rhs_id = db.create_table_for_test("rhs", schema, indexes)?;
let tx = begin_tx(&db);
let mut tx = begin_tx(&db);
// Should push the sargable equality condition into the join's left arg.
// Should push the sargable range condition into the join's right arg.
let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where lhs.a = 3 and rhs.c < 4";
let exp = compile_sql(&db, &tx, sql)?.remove(0);
let exp = compile_sql(&db, &mut tx, sql)?.remove(0);
let CrudExpr::Query(QueryExpr {
source: source_lhs,
@@ -807,11 +807,11 @@ mod tests {
let indexes = &[0.into(), 1.into()];
let rhs_id = db.create_table_for_test("rhs", schema, indexes)?;
let tx = begin_tx(&db);
let mut tx = begin_tx(&db);
// Should generate an index join since there is an index on `lhs.b`.
// Should push the sargable range condition into the index join's probe side.
let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c > 2 and rhs.c < 4 and rhs.d = 3";
let exp = compile_sql(&db, &tx, sql)?.remove(0);
let exp = compile_sql(&db, &mut tx, sql)?.remove(0);
let CrudExpr::Query(QueryExpr {
source: SourceExpr::DbTable(DbTable { table_id, .. }),
@@ -889,11 +889,11 @@ mod tests {
let indexes = col_list![0, 1];
let rhs_id = db.create_table_for_test_multi_column("rhs", schema, indexes)?;
let tx = begin_tx(&db);
let mut tx = begin_tx(&db);
// Should generate an index join since there is an index on `lhs.b`.
// Should push the sargable range condition into the index join's probe side.
let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c = 2 and rhs.b = 4 and rhs.d = 3";
let exp = compile_sql(&db, &tx, sql)?.remove(0);
let exp = compile_sql(&db, &mut tx, sql)?.remove(0);
let CrudExpr::Query(QueryExpr {
source: SourceExpr::DbTable(DbTable { table_id, .. }),
@@ -953,7 +953,7 @@ mod tests {
let db = TestDB::durable()?;
db.create_table_for_test("A", &[("x", AlgebraicType::U64)], &[])?;
db.create_table_for_test("B", &[("y", AlgebraicType::U64)], &[])?;
assert!(compile_sql(&db, &begin_tx(&db), "select B.* from B join A on B.y = A.x").is_ok());
assert!(compile_sql(&db, &mut begin_tx(&db), "select B.* from B join A on B.y = A.x").is_ok());
Ok(())
}
@@ -970,27 +970,27 @@ mod tests {
// TODO: Type check other operations deferred for the new query engine.
assert!(
compile_sql(&db, &begin_tx(&db), sql).is_err(),
compile_sql(&db, &mut begin_tx(&db), sql).is_err(),
// Err("SqlError: Type Mismatch: `PlayerState.entity_id: U64` != `String(\"161853\"): String`, executing: `SELECT * FROM PlayerState WHERE entity_id = '161853'`".into())
);
// Check we can still compile the query if we remove the type mismatch and have multiple logical operations.
let sql = "SELECT * FROM PlayerState WHERE entity_id = 1 AND entity_id = 2 AND entity_id = 3 OR entity_id = 4 OR entity_id = 5";
assert!(compile_sql(&db, &begin_tx(&db), sql).is_ok());
assert!(compile_sql(&db, &mut begin_tx(&db), sql).is_ok());
// Now verify when we have a type mismatch in the middle of the logical operations.
let sql = "SELECT * FROM PlayerState WHERE entity_id = 1 AND entity_id";
assert!(
compile_sql(&db, &begin_tx(&db), sql).is_err(),
compile_sql(&db, &mut begin_tx(&db), sql).is_err(),
// Err("SqlError: Type Mismatch: `PlayerState.entity_id: U64 == U64(1): U64` and `PlayerState.entity_id: U64`, both sides must be an `Bool` expression, executing: `SELECT * FROM PlayerState WHERE entity_id = 1 AND entity_id`".into())
);
// Verify that all operands of `AND` must be `Bool`.
let sql = "SELECT * FROM PlayerState WHERE entity_id AND entity_id";
assert!(
compile_sql(&db, &begin_tx(&db), sql).is_err(),
compile_sql(&db, &mut begin_tx(&db), sql).is_err(),
// Err("SqlError: Type Mismatch: `PlayerState.entity_id: U64` and `PlayerState.entity_id: U64`, both sides must be an `Bool` expression, executing: `SELECT * FROM PlayerState WHERE entity_id AND entity_id`".into())
);
Ok(())
+37 -20
View File
@@ -20,9 +20,8 @@ use anyhow::anyhow;
use spacetimedb_datastore::execution_context::Workload;
use spacetimedb_datastore::locking_tx_datastore::state_view::StateView;
use spacetimedb_datastore::traits::IsolationLevel;
use spacetimedb_expr::check::SchemaView;
use spacetimedb_expr::check::{SchemaView, TypingResult};
use spacetimedb_expr::errors::TypingError;
use spacetimedb_expr::expr::CallParams;
use spacetimedb_expr::statement::Statement;
use spacetimedb_lib::identity::AuthCtx;
use spacetimedb_lib::metrics::ExecutionMetrics;
@@ -191,15 +190,27 @@ pub struct SqlResult {
pub metrics: ExecutionMetrics,
}
struct DbParams<'a> {
struct SchemaViewerMut<'a> {
db: &'a RelationalDB,
tx: &'a mut MutTx,
schema: SchemaViewer<'a, MutTx>,
}
impl CallParams for DbParams<'_> {
fn create_or_get_param(&mut self, param: &ProductValue) -> Result<ArgId, TypingError> {
impl SchemaView for SchemaViewerMut<'_> {
fn table_id(&self, name: &str) -> Option<TableId> {
self.schema.table_id(name)
}
fn schema_for_table(&self, table_id: TableId) -> Option<Arc<TableOrViewSchema>> {
self.schema.schema_for_table(table_id)
}
fn rls_rules_for_table(&self, table_id: TableId) -> anyhow::Result<Vec<Box<str>>> {
self.schema.rls_rules_for_table(table_id)
}
fn get_or_create_params(&mut self, params: ProductValue) -> TypingResult<ArgId> {
self.db
.create_or_get_params(self.tx, &param)
.create_or_get_params(&mut self.schema, &params)
.map_err(|err| TypingError::Other(err.into()))
}
}
@@ -215,20 +226,16 @@ pub async fn run(
) -> Result<SqlResult, DBError> {
// We parse the sql statement in a mutable transaction.
// If it turns out to be a query, we downgrade the tx.
let (tx, stmt) =
db.with_auto_rollback(
db.begin_mut_tx(IsolationLevel::Serializable, Workload::Sql),
|tx| match compile_sql_stmt(sql_text, &SchemaViewer::new(tx, &auth), &auth) {
Ok(Statement::Select(mut stmt)) => {
stmt.for_each_fun_call(&mut |param| {
db.create_or_get_params(tx, &param)
.map_err(|err| TypingError::Other(err.into()))
})?;
Ok(Statement::Select(stmt))
}
result => result,
let (tx, stmt) = db.with_auto_rollback(db.begin_mut_tx(IsolationLevel::Serializable, Workload::Sql), |tx| {
compile_sql_stmt(
sql_text,
&mut SchemaViewerMut {
db,
schema: SchemaViewer::new(tx, &auth),
},
)?;
&auth,
)
})?;
let mut metrics = ExecutionMetrics::default();
@@ -1619,6 +1626,10 @@ pub(crate) mod tests {
true,
)?;
let arg_id = ST_RESERVED_SEQUENCE_RANGE as u64;
assert_eq!(
run_for_testing(&db, "select view_id, param_pos, param_name FROM st_view_param")?,
vec![product![arg_id as u32, 0u16, "x"]]
);
with_auto_commit(&db, |tx| -> Result<_, DBError> {
tests_utils::insert_into_view(&db, tx, table_id, None, product![arg_id + 1, 0u8, 1i64])?;
@@ -1636,6 +1647,12 @@ pub(crate) mod tests {
vec![product![0u8, 1i64]]
);
// We have created the internal rows for view args
assert_eq!(
run_for_testing(&db, "select id FROM st_view_arg")?,
vec![product![arg_id], product![arg_id + 1]]
);
Ok(())
}
}
+1 -1
View File
@@ -19,7 +19,7 @@ impl RowLevelExpr {
auth_ctx: &AuthCtx,
rls: &RawRowLevelSecurityDefV9,
) -> anyhow::Result<Self> {
let (sql, _) = parse_and_type_sub(&rls.sql, &SchemaViewer::new(tx, auth_ctx), auth_ctx)?;
let (sql, _) = parse_and_type_sub(&rls.sql, &mut SchemaViewer::new(tx, auth_ctx), auth_ctx)?;
let table_id = sql.return_table_id().unwrap();
let schema = tx.schema_for_table(table_id)?;
@@ -472,7 +472,7 @@ impl ModuleSubscriptions {
let hash = QueryHash::from_string(&sql, auth.caller(), false);
let hash_with_param = QueryHash::from_string(&sql, auth.caller(), true);
let (mut_tx, _) = self.begin_mut_tx(Workload::Subscribe);
let (mut mut_tx, _) = self.begin_mut_tx(Workload::Subscribe);
let existing_query = {
let guard = self.subscriptions.read();
@@ -482,7 +482,7 @@ impl ModuleSubscriptions {
let query = return_on_err_with_sql!(
existing_query.map(Ok).unwrap_or_else(|| compile_query_with_hashes(
&auth,
&*mut_tx,
&mut *mut_tx,
&sql,
hash,
hash_with_param
@@ -736,7 +736,7 @@ impl ModuleSubscriptions {
}
// We always get the db lock before the subscription lock to avoid deadlocks.
let (mut_tx, _tx_offset) = self.begin_mut_tx(Workload::Subscribe);
let (mut mut_tx, _tx_offset) = self.begin_mut_tx(Workload::Subscribe);
let compile_timer = metrics.compilation_time.start_timer();
@@ -752,7 +752,7 @@ impl ModuleSubscriptions {
super::subscription::get_all(
|relational_db, tx| relational_db.get_all_tables_mut(tx).map(|schemas| schemas.into_iter()),
&self.relational_db,
&*mut_tx,
&mut *mut_tx,
&auth,
)?
.into_iter()
@@ -769,7 +769,7 @@ impl ModuleSubscriptions {
plans.push(unit);
} else {
plans.push(Arc::new(
compile_query_with_hashes(&auth, &*mut_tx, sql, hash, hash_with_param).map_err(|err| {
compile_query_with_hashes(&auth, &mut *mut_tx, sql, hash, hash_with_param).map_err(|err| {
DBError::WithSql {
error: Box::new(DBError::Other(err.into())),
sql: sql.into(),
@@ -1807,8 +1807,8 @@ mod tests {
let auth = AuthCtx::for_testing();
let sql = "select * from t where id = 1";
let tx = begin_tx(&db);
let plan = compile_read_only_query(&auth, &tx, sql)?;
let mut tx = begin_tx(&db);
let plan = compile_read_only_query(&auth, &mut tx, sql)?;
let plan = Arc::new(plan);
let (_, metrics) = subs.evaluate_queries(sender, &[plan], &tx, &auth, TableUpdateType::Subscribe)?;
@@ -1681,8 +1681,8 @@ mod tests {
fn compile_plan(db: &RelationalDB, sql: &str) -> ResultTest<Arc<Plan>> {
with_read_only(db, |tx| {
let auth = AuthCtx::for_testing();
let tx = SchemaViewer::new(&*tx, &auth);
let (plans, has_param) = SubscriptionPlan::compile(sql, &tx, &auth).unwrap();
let mut tx = SchemaViewer::new(tx, &auth);
let (plans, has_param) = SubscriptionPlan::compile(sql, &mut tx, &auth).unwrap();
let hash = QueryHash::from_string(sql, auth.caller(), has_param);
Ok(Arc::new(Plan::new(plans, hash, sql.into())))
})
+55 -22
View File
@@ -42,7 +42,7 @@ pub fn is_subscribe_to_all_tables(sql: &str) -> bool {
pub fn compile_read_only_queryset(
relational_db: &RelationalDB,
auth: &AuthCtx,
tx: &Tx,
tx: &mut Tx,
input: &str,
) -> Result<Vec<SupportedQuery>, DBError> {
let input = input.trim();
@@ -82,13 +82,13 @@ pub fn compile_read_only_queryset(
/// Compile a string into a single read-only query.
/// This returns an error if the string has multiple queries or mutations.
pub fn compile_read_only_query(auth: &AuthCtx, tx: &Tx, input: &str) -> Result<Plan, DBError> {
pub fn compile_read_only_query(auth: &AuthCtx, tx: &mut Tx, input: &str) -> Result<Plan, DBError> {
if is_whitespace_or_empty(input) {
return Err(SubscriptionError::Empty.into());
}
let tx = SchemaViewer::new(tx, auth);
let (plans, has_param) = SubscriptionPlan::compile(input, &tx, auth)?;
let mut tx = SchemaViewer::new(tx, auth);
let (plans, has_param) = SubscriptionPlan::compile(input, &mut tx, auth)?;
let hash = QueryHash::from_string(input, auth.caller(), has_param);
Ok(Plan::new(plans, hash, input.to_owned()))
}
@@ -97,7 +97,7 @@ pub fn compile_read_only_query(auth: &AuthCtx, tx: &Tx, input: &str) -> Result<P
/// This returns an error if the string has multiple queries or mutations.
pub fn compile_query_with_hashes<Tx: Datastore + StateView>(
auth: &AuthCtx,
tx: &Tx,
tx: &mut Tx,
input: &str,
hash: QueryHash,
hash_with_param: QueryHash,
@@ -106,8 +106,8 @@ pub fn compile_query_with_hashes<Tx: Datastore + StateView>(
return Err(SubscriptionError::Empty.into());
}
let tx = SchemaViewer::new(tx, auth);
let (plans, has_param) = SubscriptionPlan::compile(input, &tx, auth)?;
let mut tx = SchemaViewer::new(tx, auth);
let (plans, has_param) = SubscriptionPlan::compile(input, &mut tx, auth)?;
if auth.bypass_rls() || has_param {
// Note that when generating hashes for queries from owners,
@@ -151,7 +151,7 @@ mod tests {
use crate::db::relational_db::tests_utils::{
begin_mut_tx, begin_tx, insert, with_auto_commit, with_read_only, TestDB,
};
use crate::db::relational_db::MutTx;
use crate::db::relational_db::{tests_utils, MutTx};
use crate::host::module_host::{DatabaseTableUpdate, DatabaseUpdate, UpdatesRelValue};
use crate::sql::execute::collect_result;
use crate::sql::execute::tests::run_for_testing;
@@ -164,6 +164,7 @@ mod tests {
use itertools::Itertools;
use spacetimedb_client_api_messages::websocket::{BsatnFormat, CompressableQueryUpdate, Compression};
use spacetimedb_datastore::execution_context::Workload;
use spacetimedb_datastore::system_tables::ST_RESERVED_SEQUENCE_RANGE;
use spacetimedb_lib::bsatn;
use spacetimedb_lib::db::auth::{StAccess, StTableType};
use spacetimedb_lib::error::ResultTest;
@@ -421,9 +422,9 @@ mod tests {
db.create_table_for_test("a", schema, indexes)?;
db.create_table_for_test("b", schema, indexes)?;
let tx = begin_tx(&db);
let mut tx = begin_tx(&db);
let sql = "SELECT b.* FROM b JOIN a ON b.n = a.n WHERE b.data > 200";
let result = compile_read_only_query(&AuthCtx::for_testing(), &tx, sql);
let result = compile_read_only_query(&AuthCtx::for_testing(), &mut tx, sql);
assert!(result.is_ok());
Ok(())
}
@@ -454,10 +455,10 @@ mod tests {
};
db.commit_tx(tx)?;
let tx = begin_tx(&db);
let mut tx = begin_tx(&db);
let sql = "select * from test where b = 3";
let mut exp = compile_sql(&db, &AuthCtx::for_testing(), &tx, sql)?;
let mut exp = compile_sql(&db, &AuthCtx::for_testing(), &mut tx, sql)?;
let Some(CrudExpr::Query(query)) = exp.pop() else {
panic!("unexpected query {:#?}", exp[0]);
@@ -609,8 +610,8 @@ mod tests {
AND MobileEntityState.location_z > 96000 \
AND MobileEntityState.location_z < 192000";
let tx = begin_tx(&db);
let qset = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &tx, sql_query)?;
let mut tx = begin_tx(&db);
let qset = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &mut tx, sql_query)?;
for q in qset {
let result = run_query(
@@ -684,7 +685,7 @@ mod tests {
let indexes = &[ColId(0), ColId(1)];
db.create_table_for_test("rhs", schema, indexes)?;
let tx = begin_tx(&db);
let mut tx = begin_tx(&db);
// All single table queries are supported
let scans = [
@@ -696,7 +697,7 @@ mod tests {
"SELECT * FROM lhs WHERE id > 5",
];
for scan in scans {
let expr = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &tx, scan)?
let expr = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &mut tx, scan)?
.pop()
.unwrap();
assert_eq!(expr.kind(), Supported::Select, "{scan}\n{expr:#?}");
@@ -705,7 +706,7 @@ mod tests {
// Only index semijoins are supported
let joins = ["SELECT lhs.* FROM lhs JOIN rhs ON lhs.id = rhs.id WHERE rhs.y < 10"];
for join in joins {
let expr = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &tx, join)?
let expr = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &mut tx, join)?
.pop()
.unwrap();
assert_eq!(expr.kind(), Supported::Semijoin, "{join}\n{expr:#?}");
@@ -718,7 +719,7 @@ mod tests {
"SELECT * FROM lhs JOIN rhs ON lhs.id = rhs.id WHERE lhs.x < 10",
];
for join in joins {
match compile_read_only_queryset(&db, &AuthCtx::for_testing(), &tx, join) {
match compile_read_only_queryset(&db, &AuthCtx::for_testing(), &mut tx, join) {
Err(DBError::Subscription(SubscriptionError::Unsupported(_)) | DBError::TypeError(_)) => (),
x => panic!("Unexpected: {x:?}"),
}
@@ -756,10 +757,10 @@ mod tests {
fn compile_query(db: &RelationalDB) -> ResultTest<SubscriptionPlan> {
with_read_only(db, |tx| {
let auth = AuthCtx::for_testing();
let tx = SchemaViewer::new(tx, &auth);
let mut tx = SchemaViewer::new(tx, &auth);
// Should be answered using an index semijion
let sql = "select lhs.* from lhs join rhs on lhs.id = rhs.id where rhs.y >= 2 and rhs.y <= 4";
Ok(SubscriptionPlan::compile(sql, &tx, &auth)
Ok(SubscriptionPlan::compile(sql, &mut tx, &auth)
.map(|(mut plans, _)| {
assert_eq!(plans.len(), 1);
plans.pop().unwrap()
@@ -781,10 +782,10 @@ mod tests {
fn compile_query(db: &RelationalDB) -> ResultTest<SubscriptionPlan> {
with_read_only(db, |tx| {
let auth = AuthCtx::for_testing();
let tx = SchemaViewer::new(tx, &auth);
let mut tx = SchemaViewer::new(tx, &auth);
// Should be answered using an index semijion
let sql = "select lhs.* from lhs join rhs on lhs.id = rhs.id where lhs.x >= 5 and lhs.x <= 7";
Ok(SubscriptionPlan::compile(sql, &tx, &auth)
Ok(SubscriptionPlan::compile(sql, &mut tx, &auth)
.map(|(mut plans, _)| {
assert_eq!(plans.len(), 1);
plans.pop().unwrap()
@@ -1447,4 +1448,36 @@ mod tests {
assert_eq!(metrics.index_seeks, 8);
Ok(())
}
// Verify calling views with params
// TODO: All testing use the old query compiler, so we can't test this yet.
#[test]
fn test_view_params() -> ResultTest<()> {
let db = TestDB::durable()?;
let schema = [("a", AlgebraicType::U8), ("b", AlgebraicType::I64)];
let (_view_id, table_id) = tests_utils::create_view_for_test(
&db,
"my_view",
&schema,
ProductType::from([("x", AlgebraicType::U8)]),
true,
)?;
let arg_id = ST_RESERVED_SEQUENCE_RANGE as u64;
with_auto_commit(&db, |tx| -> Result<_, DBError> {
tests_utils::insert_into_view(&db, tx, table_id, None, product![arg_id + 1, 0u8, 1i64])?;
tests_utils::insert_into_view(&db, tx, table_id, None, product![arg_id, 1u8, 2i64])?;
Ok(())
})?;
let mut tx = begin_tx(&db);
let err =
compile_read_only_queryset(&db, &AuthCtx::for_testing(), &mut tx, "SELECT * FROM my_view(1)").unwrap_err();
assert_eq!(
err.to_string(),
"InternalError: Read-only queries cannot create parameters".to_string()
);
Ok(())
}
}
+9 -9
View File
@@ -615,7 +615,7 @@ impl AuthAccess for ExecutionSet {
pub(crate) fn get_all<T, F, I>(
get_all_tables: F,
relational_db: &RelationalDB,
tx: &T,
tx: &mut T,
auth: &AuthCtx,
) -> Result<Vec<Plan>, DBError>
where
@@ -627,8 +627,8 @@ where
.filter(|t| t.table_type == StTableType::User && auth.has_read_access(t.table_access))
.map(|schema| {
let sql = format!("SELECT * FROM {}", schema.table_name);
let tx = SchemaViewer::new(tx, auth);
SubscriptionPlan::compile(&sql, &tx, auth).map(|(plans, has_param)| {
let mut tx = SchemaViewer::new(tx, auth);
SubscriptionPlan::compile(&sql, &mut tx, auth).map(|(plans, has_param)| {
Plan::new(
plans,
QueryHash::from_string(
@@ -701,11 +701,11 @@ mod tests {
let indexes = &[0.into(), 1.into()];
let rhs_id = db.create_table_for_test("rhs", schema, indexes)?;
let tx = begin_tx(&db);
let mut tx = begin_tx(&db);
// Should generate an index join since there is an index on `lhs.b`.
// Should push the sargable range condition into the index join's probe side.
let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c > 2 and rhs.c < 4 and rhs.d = 3";
let exp = compile_sql(&db, &AuthCtx::for_testing(), &tx, sql)?.remove(0);
let exp = compile_sql(&db, &AuthCtx::for_testing(), &mut tx, sql)?.remove(0);
let CrudExpr::Query(mut expr) = exp else {
panic!("unexpected result from compilation: {exp:#?}");
@@ -781,11 +781,11 @@ mod tests {
let indexes = &[0.into(), 1.into()];
let _ = db.create_table_for_test("rhs", schema, indexes)?;
let tx = begin_tx(&db);
let mut tx = begin_tx(&db);
// Should generate an index join since there is an index on `lhs.b`.
// Should push the sargable range condition into the index join's probe side.
let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c > 2 and rhs.c < 4 and rhs.d = 3";
let exp = compile_sql(&db, &AuthCtx::for_testing(), &tx, sql)?.remove(0);
let exp = compile_sql(&db, &AuthCtx::for_testing(), &mut tx, sql)?.remove(0);
let CrudExpr::Query(mut expr) = exp else {
panic!("unexpected result from compilation: {exp:#?}");
@@ -865,12 +865,12 @@ mod tests {
.create_table_for_test("rhs", schema, indexes)
.expect("Failed to create_table_for_test rhs");
let tx = begin_tx(&db);
let mut tx = begin_tx(&db);
// Should generate an index join since there is an index on `lhs.b`.
// Should push the sargable range condition into the index join's probe side.
let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c > 2 and rhs.c < 4 and rhs.d = 3";
let exp = compile_sql(&db, &AuthCtx::for_testing(), &tx, sql)
let exp = compile_sql(&db, &AuthCtx::for_testing(), &mut tx, sql)
.expect("Failed to compile_sql")
.remove(0);
+6 -6
View File
@@ -65,14 +65,14 @@ mod tests {
use spacetimedb_vm::relation::MemTable;
fn run_query(db: &Arc<RelationalDB>, sql: String) -> ResultTest<MemTable> {
let tx = begin_tx(db);
let q = compile_sql(db, &AuthCtx::for_testing(), &tx, &sql)?;
let mut tx = begin_tx(db);
let q = compile_sql(db, &AuthCtx::for_testing(), &mut tx, &sql)?;
Ok(execute_for_testing(db, &sql, q)?.pop().unwrap())
}
fn run_query_write(db: &Arc<RelationalDB>, sql: String) -> ResultTest<()> {
let tx = begin_tx(db);
let q = compile_sql(db, &AuthCtx::for_testing(), &tx, &sql)?;
let mut tx = begin_tx(db);
let q = compile_sql(db, &AuthCtx::for_testing(), &mut tx, &sql)?;
drop(tx);
execute_for_testing(db, &sql, q)?;
@@ -92,10 +92,10 @@ mod tests {
}
Ok(())
})?;
let tx = begin_tx(&db);
let mut tx = begin_tx(&db);
let sql = "select * from test where x > 0";
let q = compile_sql(&db, &AuthCtx::for_testing(), &tx, sql)?;
let q = compile_sql(&db, &AuthCtx::for_testing(), &mut tx, sql)?;
let slow = SlowQueryLogger::new(sql, Some(Duration::from_millis(1)), tx.ctx.workload());
+82 -38
View File
@@ -9,12 +9,12 @@ use super::{
type_expr, type_proj, type_select,
};
use crate::errors::{TableFunc, UnexpectedFunctionType};
use crate::expr::{Expr, LeftDeepJoin, ProjectList, ProjectName, Relvar};
use crate::expr::{Expr, FieldProject, LeftDeepJoin, ProjectList, ProjectName, Relvar};
use spacetimedb_lib::identity::AuthCtx;
use spacetimedb_lib::AlgebraicType;
use spacetimedb_primitives::{ArgId, TableId};
use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type;
use spacetimedb_sats::ProductValue;
use spacetimedb_sats::{AlgebraicValue, ProductValue};
use spacetimedb_schema::schema::TableOrViewSchema;
use spacetimedb_sql_parser::ast::{BinOp, SqlExpr, SqlLiteral};
use spacetimedb_sql_parser::{
@@ -35,9 +35,7 @@ pub trait SchemaView {
self.table_id(name).and_then(|table_id| self.schema_for_table(table_id))
}
fn get_or_create_params(&self, params: &ProductValue) -> TypingResult<ArgId> {
Ok(ArgId::SENTINEL)
}
fn get_or_create_params(&mut self, params: ProductValue) -> TypingResult<ArgId>;
}
#[derive(Default)]
@@ -60,9 +58,9 @@ pub trait TypeChecker {
type Ast;
type Set;
fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<ProjectList>;
fn type_ast(ast: Self::Ast, tx: &mut impl SchemaView) -> TypingResult<ProjectList>;
fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<ProjectList>;
fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &mut impl SchemaView) -> TypingResult<ProjectList>;
fn type_view_params(
schema: &TableOrViewSchema,
@@ -152,25 +150,35 @@ pub trait TypeChecker {
}
fn type_params(
tx: &mut impl SchemaView,
from: RelExpr,
schema: Arc<TableOrViewSchema>,
alias: Box<str>,
params: Option<ProductValue>,
) -> RelExpr {
) -> TypingResult<RelExpr> {
match params {
None => from,
Some(args) => RelExpr::FunCall(
Relvar {
schema,
alias,
delta: None,
},
args,
),
None => Ok(from),
Some(args) => {
let new_arg_id = tx.get_or_create_params(args)?;
let arg_id_col = schema.inner().get_column_by_name("arg_id").unwrap().col_pos;
Ok(RelExpr::Select(
Box::new(from),
Expr::BinOp(
BinOp::Eq,
Box::new(Expr::Field(FieldProject {
table: alias,
field: arg_id_col.idx(),
ty: AlgebraicType::U64,
})),
Box::new(Expr::Value(AlgebraicValue::U64(new_arg_id.0), AlgebraicType::U64)),
),
))
}
}
}
fn type_from(from: SqlFrom, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<RelExpr> {
fn type_from(from: SqlFrom, vars: &mut Relvars, tx: &mut impl SchemaView) -> TypingResult<RelExpr> {
match from {
SqlFrom::Expr(SqlIdent(name), SqlIdent(alias)) => {
let schema = Self::type_relvar(tx, &name)?;
@@ -202,7 +210,7 @@ pub trait TypeChecker {
}
let schema = Self::type_relvar(tx, &name)?;
let arg = Self::type_view_params(&schema, vars, params)?;
let lhs = Box::new(Self::type_params(join, schema.clone(), alias.clone(), arg));
let lhs = Box::new(Self::type_params(tx, join, schema.clone(), alias.clone(), arg)?);
let rhs = Relvar {
schema,
@@ -237,7 +245,7 @@ pub trait TypeChecker {
delta: None,
});
Ok(Self::type_params(from, schema, alias, arg))
Self::type_params(tx, from, schema, alias, arg)
}
}
}
@@ -256,11 +264,11 @@ impl TypeChecker for SubChecker {
type Ast = SqlSelect;
type Set = SqlSelect;
fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<ProjectList> {
fn type_ast(ast: Self::Ast, tx: &mut impl SchemaView) -> TypingResult<ProjectList> {
Self::type_set(ast, &mut Relvars::default(), tx)
}
fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<ProjectList> {
fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &mut impl SchemaView) -> TypingResult<ProjectList> {
match ast {
SqlSelect {
project,
@@ -283,7 +291,7 @@ impl TypeChecker for SubChecker {
}
/// Parse and type check a subscription query
pub fn parse_and_type_sub(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> TypingResult<(ProjectName, bool)> {
pub fn parse_and_type_sub(sql: &str, tx: &mut impl SchemaView, auth: &AuthCtx) -> TypingResult<(ProjectName, bool)> {
let ast = parse_subscription(sql)?;
let has_param = ast.has_parameter();
let ast = ast.resolve_sender(auth.caller());
@@ -303,15 +311,16 @@ fn expect_table_type(expr: ProjectList) -> TypingResult<ProjectName> {
pub mod test_utils {
use spacetimedb_lib::{db::raw_def::v9::RawModuleDefV9Builder, ProductType};
use spacetimedb_primitives::TableId;
use spacetimedb_sats::AlgebraicType;
use spacetimedb_primitives::{ArgId, TableId};
use spacetimedb_sats::{AlgebraicType, ProductValue};
use spacetimedb_schema::{
def::ModuleDef,
schema::{Schema, TableOrViewSchema, TableSchema},
};
use std::collections::HashMap;
use std::sync::Arc;
use super::SchemaView;
use super::{SchemaView, TypingResult};
pub struct ViewInfo<'a> {
pub(crate) name: &'a str,
pub(crate) columns: &'a [(&'a str, AlgebraicType)],
@@ -333,7 +342,38 @@ pub mod test_utils {
builder.finish().try_into().expect("failed to generate module def")
}
pub struct SchemaViewer(pub ModuleDef);
pub struct MockCallParams {
counter: u64,
params: HashMap<ProductValue, ArgId>,
}
impl Default for MockCallParams {
fn default() -> Self {
Self::new()
}
}
impl MockCallParams {
pub fn new() -> Self {
Self {
counter: 0,
params: HashMap::new(),
}
}
pub fn get_or_insert(&mut self, value: ProductValue) -> ArgId {
if let Some(existing) = self.params.get(&value) {
*existing
} else {
self.counter += 1;
let arg_id = ArgId(self.counter - 1);
self.params.insert(value, arg_id);
arg_id
}
}
}
pub struct SchemaViewer(pub ModuleDef, pub MockCallParams);
impl SchemaView for SchemaViewer {
fn table_id(&self, name: &str) -> Option<TableId> {
@@ -370,6 +410,10 @@ pub mod test_utils {
fn rls_rules_for_table(&self, _: TableId) -> anyhow::Result<Vec<Box<str>>> {
Ok(vec![])
}
fn get_or_create_params(&mut self, params: ProductValue) -> TypingResult<ArgId> {
Ok(self.1.get_or_insert(params))
}
}
}
@@ -423,13 +467,13 @@ mod tests {
}
/// A wrapper around [super::parse_and_type_sub] that takes a dummy [AuthCtx]
fn parse_and_type_sub(sql: &str, tx: &impl SchemaView) -> TypingResult<ProjectName> {
fn parse_and_type_sub(sql: &str, tx: &mut impl SchemaView) -> TypingResult<ProjectName> {
super::parse_and_type_sub(sql, tx, &AuthCtx::for_testing()).map(|(plan, _)| plan)
}
#[test]
fn valid_literals() {
let tx = SchemaViewer(module_def());
let mut tx = SchemaViewer(module_def(), Default::default());
struct TestCase {
sql: &'static str,
@@ -498,27 +542,27 @@ mod tests {
msg: "timestamp ms with timezone",
},
] {
let result = parse_and_type_sub(sql, &tx);
let result = parse_and_type_sub(sql, &mut tx);
assert!(result.is_ok(), "name: {}, error: {}", msg, result.unwrap_err());
}
}
#[test]
fn valid_literals_for_type() {
let tx = SchemaViewer(module_def());
let mut tx = SchemaViewer(module_def(), Default::default());
for ty in [
"i8", "u8", "i16", "u16", "i32", "u32", "i64", "u64", "f32", "f64", "i128", "u128", "i256", "u256",
] {
let sql = format!("select * from t where {ty} = 127");
let result = parse_and_type_sub(&sql, &tx);
let result = parse_and_type_sub(&sql, &mut tx);
assert!(result.is_ok(), "Failed to parse {ty}: {}", result.unwrap_err());
}
}
#[test]
fn invalid_literals() {
let tx = SchemaViewer(module_def());
let mut tx = SchemaViewer(module_def(), Default::default());
struct TestCase {
sql: &'static str,
@@ -547,14 +591,14 @@ mod tests {
msg: "Float as integer",
},
] {
let result = parse_and_type_sub(sql, &tx);
let result = parse_and_type_sub(sql, &mut tx);
assert!(result.is_err(), "{msg}");
}
}
#[test]
fn valid() {
let tx = SchemaViewer(module_def());
let mut tx = SchemaViewer(module_def(), Default::default());
struct TestCase {
sql: &'static str,
@@ -611,14 +655,14 @@ mod tests {
msg: "Type inner join + projection",
},
] {
let result = parse_and_type_sub(sql, &tx);
let result = parse_and_type_sub(sql, &mut tx);
assert!(result.is_ok(), "{msg}");
}
}
#[test]
fn invalid() {
let tx = SchemaViewer(module_def());
let mut tx = SchemaViewer(module_def(), Default::default());
struct TestCase {
sql: &'static str,
@@ -683,7 +727,7 @@ mod tests {
msg: "Columns must be qualified in join expressions",
},
] {
let result = parse_and_type_sub(sql, &tx);
let result = parse_and_type_sub(sql, &mut tx);
assert!(result.is_err(), "{msg}");
}
}
+2
View File
@@ -172,6 +172,8 @@ pub enum TypingError {
FilterReturnType(#[from] FilterReturnType),
#[error(transparent)]
TableFunc(#[from] TableFunc),
#[error("InternalError: Read-only queries cannot create parameters")]
ParamsReadOnly,
#[error(transparent)]
Other(#[from] anyhow::Error),
}
-54
View File
@@ -238,35 +238,6 @@ impl ProjectList {
Self::Agg(_, _, name, ty) => f(name, ty),
}
}
/// Iterate over the function calls in this projection list
pub fn for_each_fun_call(
&mut self,
f: &mut impl FnMut(ProductValue) -> Result<ArgId, TypingError>,
) -> Result<(), TypingError> {
match self {
ProjectList::Name(input) => {
for proj in input {
match proj {
ProjectName::None(expr) | ProjectName::Some(expr, _) => {
expr.for_each_fun_call(f)?;
}
}
}
}
ProjectList::List(input, _) => {
for expr in input {
expr.for_each_fun_call(f)?;
}
}
ProjectList::Limit(input, _) => {
input.for_each_fun_call(f)?;
}
ProjectList::Agg(_, _, _, _) => {}
}
Ok(())
}
}
/// A logical relational expression
@@ -398,31 +369,6 @@ impl RelExpr {
_ => None,
}
}
fn for_each_fun_call(
&mut self,
f: &mut impl FnMut(ProductValue) -> Result<ArgId, TypingError>,
) -> Result<(), TypingError> {
// For function calls, we need to filter by the argument id
if let RelExpr::FunCall(relvar, param) = self {
let new_arg_id = f(param.clone())?;
let arg_id_col = relvar.schema.inner().get_column_by_name("arg_id").unwrap().col_pos;
*self = RelExpr::Select(
Box::new(RelExpr::RelVar(relvar.clone())),
Expr::BinOp(
BinOp::Eq,
Box::new(Expr::Field(FieldProject {
table: relvar.alias.clone(),
field: arg_id_col.idx(),
ty: AlgebraicType::U64,
})),
Box::new(Expr::Value(AlgebraicValue::U64(new_arg_id.0), AlgebraicType::U64)),
),
);
}
Ok(())
}
}
/// A left deep binary cross product
+38 -31
View File
@@ -12,7 +12,7 @@ use crate::{
/// The main driver of RLS resolution for subscription queries.
/// Mainly a wrapper around [resolve_views_for_expr].
pub fn resolve_views_for_sub(
tx: &impl SchemaView,
tx: &mut impl SchemaView,
expr: ProjectName,
auth: &AuthCtx,
has_param: &mut bool,
@@ -54,7 +54,11 @@ pub fn resolve_views_for_sub(
/// The main driver of RLS resolution for sql queries.
/// Mainly a wrapper around [resolve_views_for_expr].
pub fn resolve_views_for_sql(tx: &impl SchemaView, expr: ProjectList, auth: &AuthCtx) -> anyhow::Result<ProjectList> {
pub fn resolve_views_for_sql(
tx: &mut impl SchemaView,
expr: ProjectList,
auth: &AuthCtx,
) -> anyhow::Result<ProjectList> {
// RLS does not apply to the database owner
if auth.bypass_rls() {
return Ok(expr);
@@ -62,44 +66,41 @@ pub fn resolve_views_for_sql(tx: &impl SchemaView, expr: ProjectList, auth: &Aut
// The subscription language is a subset of the sql language.
// Use the subscription helper if this is a compliant expression.
// Use the generic resolver otherwise.
let resolve_for_sub = |expr| resolve_views_for_sub(tx, expr, auth, &mut false);
let resolve_for_sql = |expr| {
resolve_views_for_expr(
// Use all default values
tx,
expr,
None,
Rc::new(ResolveList::None),
&mut false,
&mut 0,
auth,
)
};
match expr {
ProjectList::Limit(expr, n) => Ok(ProjectList::Limit(Box::new(resolve_views_for_sql(tx, *expr, auth)?), n)),
ProjectList::Limit(expr, n) => {
let expr = resolve_views_for_sql(tx, *expr, auth)?;
Ok(ProjectList::Limit(Box::new(expr), n))
}
ProjectList::Name(exprs) => Ok(ProjectList::Name(
exprs
.into_iter()
.map(resolve_for_sub)
.map(|expr| resolve_views_for_sub(tx, expr, auth, &mut false))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.collect(),
)),
ProjectList::List(exprs, fields) => Ok(ProjectList::List(
exprs
.into_iter()
.map(resolve_for_sql)
.map(|expr| {
resolve_views_for_expr(tx, expr, None, Rc::new(ResolveList::None), &mut false, &mut 0, auth)
})
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.collect(),
fields,
)),
ProjectList::Agg(exprs, AggType::Count, name, ty) => Ok(ProjectList::Agg(
exprs
.into_iter()
.map(resolve_for_sql)
.map(|expr| {
resolve_views_for_expr(tx, expr, None, Rc::new(ResolveList::None), &mut false, &mut 0, auth)
})
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
@@ -203,7 +204,7 @@ impl ResolveList {
/// i.e. the subtree rooted at `a` in the above example,
/// must be pushed below the leftmost leaf node of the view expansion.
fn resolve_views_for_expr(
tx: &impl SchemaView,
tx: &mut impl SchemaView,
view: RelExpr,
return_table_id: Option<TableId>,
resolving: Rc<ResolveList>,
@@ -473,21 +474,23 @@ mod tests {
use pretty_assertions as pretty;
use spacetimedb_lib::{identity::AuthCtx, AlgebraicType, AlgebraicValue, Identity, ProductType};
use spacetimedb_primitives::TableId;
use spacetimedb_primitives::{ArgId, TableId};
use spacetimedb_sats::ProductValue;
use spacetimedb_schema::{
def::ModuleDef,
schema::{Schema, TableOrViewSchema, TableSchema},
};
use spacetimedb_sql_parser::ast::BinOp;
use super::resolve_views_for_sub;
use crate::check::test_utils::MockCallParams;
use crate::check::TypingResult;
use crate::{
check::{parse_and_type_sub, test_utils::build_module_def, SchemaView},
expr::{Expr, FieldProject, LeftDeepJoin, ProjectName, RelExpr, Relvar},
};
use super::resolve_views_for_sub;
pub struct SchemaViewer(pub ModuleDef);
pub struct SchemaViewer(pub ModuleDef, pub MockCallParams);
impl SchemaView for SchemaViewer {
fn table_id(&self, name: &str) -> Option<TableId> {
@@ -526,6 +529,10 @@ mod tests {
_ => Ok(vec![]),
}
}
fn get_or_create_params(&mut self, params: ProductValue) -> TypingResult<ArgId> {
Ok(self.1.get_or_insert(params))
}
}
fn module_def() -> ModuleDef {
@@ -549,17 +556,17 @@ mod tests {
}
/// Parse, type check, and resolve RLS rules
fn resolve(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> anyhow::Result<Vec<ProjectName>> {
fn resolve(sql: &str, tx: &mut impl SchemaView, auth: &AuthCtx) -> anyhow::Result<Vec<ProjectName>> {
let (expr, _) = parse_and_type_sub(sql, tx, auth)?;
resolve_views_for_sub(tx, expr, auth, &mut false)
}
#[test]
fn test_rls_for_owner() -> anyhow::Result<()> {
let tx = SchemaViewer(module_def());
let mut tx = SchemaViewer(module_def(), Default::default());
let auth = AuthCtx::new(Identity::ONE, Identity::ONE);
let sql = "select * from users";
let resolved = resolve(sql, &tx, &auth)?;
let resolved = resolve(sql, &mut tx, &auth)?;
let users_schema = tx.schema("users").unwrap();
@@ -577,10 +584,10 @@ mod tests {
#[test]
fn test_rls_for_non_owner() -> anyhow::Result<()> {
let tx = SchemaViewer(module_def());
let mut tx = SchemaViewer(module_def(), Default::default());
let auth = AuthCtx::new(Identity::ZERO, Identity::ONE);
let sql = "select * from users";
let resolved = resolve(sql, &tx, &auth)?;
let resolved = resolve(sql, &mut tx, &auth)?;
let users_schema = tx.schema("users").unwrap();
@@ -612,10 +619,10 @@ mod tests {
#[test]
fn test_multiple_rls_rules_for_table() -> anyhow::Result<()> {
let tx = SchemaViewer(module_def());
let mut tx = SchemaViewer(module_def(), Default::default());
let auth = AuthCtx::new(Identity::ZERO, Identity::ONE);
let sql = "select * from player where level_num = 5";
let resolved = resolve(sql, &tx, &auth)?;
let resolved = resolve(sql, &mut tx, &auth)?;
let users_schema = tx.schema("users").unwrap();
let admins_schema = tx.schema("admins").unwrap();
+14 -14
View File
@@ -394,11 +394,11 @@ impl TypeChecker for SqlChecker {
type Ast = SqlSelect;
type Set = SqlSelect;
fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<ProjectList> {
fn type_ast(ast: Self::Ast, tx: &mut impl SchemaView) -> TypingResult<ProjectList> {
Self::type_set(ast, &mut Relvars::default(), tx)
}
fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<ProjectList> {
fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &mut impl SchemaView) -> TypingResult<ProjectList> {
match ast {
SqlSelect {
project,
@@ -439,7 +439,7 @@ impl TypeChecker for SqlChecker {
}
}
pub fn parse_and_type_sql(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> TypingResult<Statement> {
pub fn parse_and_type_sql(sql: &str, tx: &mut impl SchemaView, auth: &AuthCtx) -> TypingResult<Statement> {
match parse_sql(sql)?.resolve_sender(auth.caller()) {
SqlAst::Select(ast) => Ok(Statement::Select(SqlChecker::type_ast(ast, tx)?)),
SqlAst::Insert(insert) => Ok(Statement::DML(DML::Insert(type_insert(insert, tx)?))),
@@ -451,7 +451,7 @@ pub fn parse_and_type_sql(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> Ty
}
/// Parse and type check a *general* query into a [StatementCtx].
pub fn compile_sql_stmt<'a>(sql: &'a str, tx: &impl SchemaView, auth: &AuthCtx) -> TypingResult<StatementCtx<'a>> {
pub fn compile_sql_stmt<'a>(sql: &'a str, tx: &mut impl SchemaView, auth: &AuthCtx) -> TypingResult<StatementCtx<'a>> {
let statement = parse_and_type_sql(sql, tx, auth)?;
Ok(StatementCtx {
statement,
@@ -520,13 +520,13 @@ mod tests {
}
/// A wrapper around [super::parse_and_type_sql] that takes a dummy [AuthCtx]
fn parse_and_type_sql(sql: &str, tx: &impl SchemaView) -> TypingResult<Statement> {
fn parse_and_type_sql(sql: &str, tx: &mut impl SchemaView) -> TypingResult<Statement> {
super::parse_and_type_sql(sql, tx, &AuthCtx::for_testing())
}
#[test]
fn valid() {
let tx = SchemaViewer(module_def());
let mut tx = SchemaViewer(module_def(), Default::default());
for sql in [
"select str from t",
@@ -534,14 +534,14 @@ mod tests {
"select t.str, arr from t",
"select * from t limit 5",
] {
let result = parse_and_type_sql(sql, &tx);
let result = parse_and_type_sql(sql, &mut tx);
assert!(result.is_ok());
}
}
#[test]
fn invalid() {
let tx = SchemaViewer(module_def());
let mut tx = SchemaViewer(module_def(), Default::default());
for sql in [
// Unqualified columns in a join
@@ -551,7 +551,7 @@ mod tests {
// Unqualified name in join expression
"select t.* from t join s on t.u32 = s.u32 where bytes = 0xABCD",
] {
let result = parse_and_type_sql(sql, &tx);
let result = parse_and_type_sql(sql, &mut tx);
assert!(result.is_err());
}
}
@@ -581,7 +581,7 @@ mod tests {
#[test]
fn views() {
let tx = SchemaViewer(module_def());
let mut tx = SchemaViewer(module_def(), Default::default());
struct TestCase {
sql: &'static str,
@@ -606,7 +606,7 @@ mod tests {
msg: "Function call returning view with parameters",
},
] {
let result = parse_and_type_sql(sql, &tx).inspect_err(|e| {
let result = parse_and_type_sql(sql, &mut tx).inspect_err(|e| {
panic!("Expected OK for `{sql}` but got error: {e}");
});
assert!(result.is_ok(), "{msg}: {sql}");
@@ -630,14 +630,14 @@ mod tests {
msg: "`v` does not take parameters",
},
] {
let result = parse_and_type_sql(sql, &tx);
let result = parse_and_type_sql(sql, &mut tx);
assert!(result.is_err(), "{msg}");
}
}
#[test]
fn params_validation() {
let tx = SchemaViewer(module_def());
let mut tx = SchemaViewer(module_def(), Default::default());
struct TestCase {
sql: &'static str,
@@ -670,7 +670,7 @@ mod tests {
msg: "Unexpected function type. Expected: (U32, String) != Inferred: (U32, String, Num?)",
},
] {
let result = parse_and_type_sql(sql, &tx);
let result = parse_and_type_sql(sql, &mut tx);
if msg == "Correct parameters" {
assert!(result.is_ok(), "{msg}: {sql}");
} else if let Err(err) = &result {
+42 -30
View File
@@ -1402,6 +1402,7 @@ mod tests {
use std::sync::Arc;
use pretty_assertions::assert_eq;
use spacetimedb_expr::check::test_utils::MockCallParams;
use spacetimedb_expr::{
check::{SchemaView, TypingResult},
expr::ProjectName,
@@ -1410,9 +1411,9 @@ mod tests {
use spacetimedb_lib::{
db::auth::{StAccess, StTableType},
identity::AuthCtx,
AlgebraicType, AlgebraicValue,
AlgebraicType, AlgebraicValue, ProductValue,
};
use spacetimedb_primitives::{ColId, ColList, ColSet, TableId, ViewId};
use spacetimedb_primitives::{ArgId, ColId, ColList, ColSet, TableId, ViewId};
use spacetimedb_schema::def::ViewParamDefSimple;
use spacetimedb_schema::identifier::Identifier;
use spacetimedb_schema::schema::ViewDefInfo;
@@ -1431,6 +1432,7 @@ mod tests {
struct SchemaViewer {
schemas: Vec<Arc<TableOrViewSchema>>,
params: MockCallParams,
}
impl SchemaView for SchemaViewer {
@@ -1448,6 +1450,10 @@ mod tests {
fn rls_rules_for_table(&self, _: TableId) -> anyhow::Result<Vec<Box<str>>> {
Ok(vec![])
}
fn get_or_create_params(&mut self, params: ProductValue) -> TypingResult<ArgId> {
Ok(self.params.get_or_insert(params))
}
}
fn schema_with_params(
@@ -1534,7 +1540,7 @@ mod tests {
}
/// A wrapper around [spacetimedb_expr::check::parse_and_type_sub] that takes a dummy [AuthCtx]
fn parse_and_type_sub(sql: &str, tx: &impl SchemaView) -> TypingResult<ProjectName> {
fn parse_and_type_sub(sql: &str, tx: &mut impl SchemaView) -> TypingResult<ProjectName> {
spacetimedb_expr::check::parse_and_type_sub(sql, tx, &AuthCtx::for_testing()).map(|(plan, _)| plan)
}
@@ -1552,14 +1558,15 @@ mod tests {
Some(0),
));
let db = SchemaViewer {
let mut db = SchemaViewer {
schemas: vec![t.clone()],
params: Default::default(),
};
let sql = "select * from t";
let auth = AuthCtx::for_testing();
let lp = parse_and_type_sub(sql, &db).unwrap();
let lp = parse_and_type_sub(sql, &mut db).unwrap();
let pp = compile_select(lp).optimize(&auth).unwrap();
match pp {
@@ -1584,14 +1591,15 @@ mod tests {
Some(0),
));
let db = SchemaViewer {
let mut db = SchemaViewer {
schemas: vec![t.clone()],
params: Default::default(),
};
let sql = "select * from t where x = 5";
let auth = AuthCtx::for_testing();
let lp = parse_and_type_sub(sql, &db).unwrap();
let lp = parse_and_type_sub(sql, &mut db).unwrap();
let pp = compile_select(lp).optimize(&auth).unwrap();
match pp {
@@ -1678,8 +1686,9 @@ mod tests {
Some(0),
));
let db = SchemaViewer {
let mut db = SchemaViewer {
schemas: vec![u.clone(), l.clone(), b.clone()],
params: Default::default(),
};
let sql = "
@@ -1691,7 +1700,7 @@ mod tests {
where u.identity = 5
";
let auth = AuthCtx::for_testing();
let lp = parse_and_type_sub(sql, &db).unwrap();
let lp = parse_and_type_sub(sql, &mut db).unwrap();
let pp = compile_select(lp).optimize(&auth).unwrap();
// Plan:
@@ -1870,8 +1879,9 @@ mod tests {
Some(0),
));
let db = SchemaViewer {
let mut db = SchemaViewer {
schemas: vec![m.clone(), w.clone(), p.clone()],
params: Default::default(),
};
let sql = "
@@ -1884,7 +1894,7 @@ mod tests {
where 5 = m.employee and 5 = v.employee
";
let auth = AuthCtx::for_testing();
let lp = parse_and_type_sub(sql, &db).unwrap();
let lp = parse_and_type_sub(sql, &mut db).unwrap();
let pp = compile_select(lp).optimize(&auth).unwrap();
// Plan:
@@ -2052,13 +2062,14 @@ mod tests {
None,
));
let db = SchemaViewer {
let mut db = SchemaViewer {
schemas: vec![t.clone()],
params: Default::default(),
};
let sql = "select * from t where x = 3 and y = 4 and z = 5";
let auth = AuthCtx::for_testing();
let lp = parse_and_type_sub(sql, &db).unwrap();
let lp = parse_and_type_sub(sql, &mut db).unwrap();
let pp = compile_select(lp).optimize(&auth).unwrap();
// Select index on (x, y, z)
@@ -2081,7 +2092,7 @@ mod tests {
// Test permutations of the same query
let sql = "select * from t where z = 5 and y = 4 and x = 3";
let lp = parse_and_type_sub(sql, &db).unwrap();
let lp = parse_and_type_sub(sql, &mut db).unwrap();
let pp = compile_select(lp).optimize(&auth).unwrap();
match pp {
@@ -2102,7 +2113,7 @@ mod tests {
};
let sql = "select * from t where x = 3 and y = 4";
let lp = parse_and_type_sub(sql, &db).unwrap();
let lp = parse_and_type_sub(sql, &mut db).unwrap();
let pp = compile_select(lp).optimize(&auth).unwrap();
// Select index on x
@@ -2130,7 +2141,7 @@ mod tests {
};
let sql = "select * from t where w = 5 and x = 4";
let lp = parse_and_type_sub(sql, &db).unwrap();
let lp = parse_and_type_sub(sql, &mut db).unwrap();
let pp = compile_select(lp).optimize(&auth).unwrap();
// Select index on x
@@ -2158,7 +2169,7 @@ mod tests {
};
let sql = "select * from t where y = 1";
let lp = parse_and_type_sub(sql, &db).unwrap();
let lp = parse_and_type_sub(sql, &mut db).unwrap();
let pp = compile_select(lp).optimize(&auth).unwrap();
// Do not select index on (y, z)
@@ -2173,7 +2184,7 @@ mod tests {
// Select index on [y, z]
let sql = "select * from t where y = 1 and z = 2";
let lp = parse_and_type_sub(sql, &db).unwrap();
let lp = parse_and_type_sub(sql, &mut db).unwrap();
let pp = compile_select(lp).optimize(&auth).unwrap();
match pp {
@@ -2192,7 +2203,7 @@ mod tests {
// Check permutations of the same query
let sql = "select * from t where z = 2 and y = 1";
let lp = parse_and_type_sub(sql, &db).unwrap();
let lp = parse_and_type_sub(sql, &mut db).unwrap();
let pp = compile_select(lp).optimize(&auth).unwrap();
match pp {
@@ -2211,7 +2222,7 @@ mod tests {
// Select index on (y, z) and filter on (w)
let sql = "select * from t where w = 1 and y = 2 and z = 3";
let lp = parse_and_type_sub(sql, &db).unwrap();
let lp = parse_and_type_sub(sql, &mut db).unwrap();
let pp = compile_select(lp).optimize(&auth).unwrap();
let plan = match pp {
@@ -2251,12 +2262,13 @@ mod tests {
None,
));
let db = SchemaViewer {
let mut db = SchemaViewer {
schemas: vec![t.clone()],
params: Default::default(),
};
let compile = |sql| {
let stmt = parse_and_type_sql(sql, &db, &AuthCtx::for_testing()).unwrap();
let mut compile = |sql| {
let stmt = parse_and_type_sql(sql, &mut db, &AuthCtx::for_testing()).unwrap();
let Statement::Select(select) = stmt else {
unreachable!()
};
@@ -2322,20 +2334,20 @@ mod tests {
Some(&[("param_id", AlgebraicType::U64)]),
));
let db = SchemaViewer {
let mut db = SchemaViewer {
schemas: vec![t.clone(), v.clone()],
params: Default::default(),
};
let sql = "select * from v(0)";
let auth = AuthCtx::for_testing();
let lp = parse_and_type_sub(sql, &db).unwrap();
dbg!(&lp);
let lp = parse_and_type_sub(sql, &mut db).unwrap();
let pp = compile_select(lp).optimize(&auth).unwrap();
dbg!(&pp);
match pp {
ProjectPlan::None(PhysicalPlan::Filter(input, PhysicalExpr::BinOp(BinOp::Eq, field, value))) => {
assert!(matches!(*field, PhysicalExpr::Field(TupleField { field_pos: 1, .. })));
// This is the internal parameter filter
assert!(matches!(*field, PhysicalExpr::Field(TupleField { field_pos: 0, .. })));
assert!(matches!(*value, PhysicalExpr::Value(AlgebraicValue::U64(0))));
match *input {
@@ -2349,12 +2361,12 @@ mod tests {
};
let sql = "select * from v(0) as x JOIN t ON x.id = t.id";
let lp = parse_and_type_sub(sql, &db).unwrap();
let lp = parse_and_type_sub(sql, &mut db).unwrap();
let pp = compile_select(lp).optimize(&auth).unwrap();
match pp {
ProjectPlan::None(PhysicalPlan::Filter(_, PhysicalExpr::BinOp(BinOp::Eq, field, value))) => {
assert!(matches!(*field, PhysicalExpr::Field(TupleField { field_pos: 1, .. })));
assert!(matches!(*field, PhysicalExpr::Field(TupleField { field_pos: 0, .. })));
assert!(matches!(*value, PhysicalExpr::Value(AlgebraicValue::U64(0))));
}
proj => panic!("unexpected project: {proj:#?}"),
+3 -20
View File
@@ -4,8 +4,6 @@ use spacetimedb_execution::{
pipelined::ProjectListExecutor,
Datastore, DeltaStore,
};
use spacetimedb_expr::errors::TypingError;
use spacetimedb_expr::expr::CallParams;
use spacetimedb_expr::{
check::{parse_and_type_sub, SchemaView},
expr::ProjectList,
@@ -17,30 +15,15 @@ use spacetimedb_physical_plan::{
compile::{compile_dml_plan, compile_select, compile_select_list},
plan::{ProjectListPlan, ProjectPlan},
};
use spacetimedb_primitives::{ArgId, TableId};
use std::collections::HashMap;
use spacetimedb_primitives::TableId;
/// DIRTY HACK ALERT: Maximum allowed length, in UTF-8 bytes, of SQL queries.
/// Any query longer than this will be rejected.
/// This prevents a stack overflow when compiling queries with deeply-nested `AND` and `OR` conditions.
const MAX_SQL_LENGTH: usize = 50_000;
pub trait CallParamsExt {
fn get_arg(&self, params: &ProductValue) -> Result<ArgId, TypingError>;
}
pub struct MockCallParams {
params: HashMap<ProductValue, ArgId>,
}
impl CallParamsExt for MockCallParams {
fn get_arg(&self, params: &ProductValue) -> Result<ArgId, TypingError> {
todo!()
}
}
pub fn compile_subscription(
sql: &str,
tx: &impl SchemaView,
tx: &mut impl SchemaView,
auth: &AuthCtx,
) -> Result<(Vec<ProjectPlan>, TableId, Box<str>, bool)> {
if sql.len() > MAX_SQL_LENGTH {
@@ -72,7 +55,7 @@ pub fn compile_subscription(
}
/// A utility for parsing and type checking a sql statement
pub fn compile_sql_stmt(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> Result<Statement> {
pub fn compile_sql_stmt(sql: &str, tx: &mut impl SchemaView, auth: &AuthCtx) -> Result<Statement> {
if sql.len() > MAX_SQL_LENGTH {
bail!("SQL query exceeds maximum allowed length: \"{sql:.120}...\"")
}
+1 -1
View File
@@ -508,7 +508,7 @@ impl SubscriptionPlan {
}
/// Generate a plan for incrementally maintaining a subscription
pub fn compile(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> Result<(Vec<Self>, bool)> {
pub fn compile(sql: &str, tx: &mut impl SchemaView, auth: &AuthCtx) -> Result<(Vec<Self>, bool)> {
let (plans, return_id, return_name, has_param) = compile_subscription(sql, tx, auth)?;
/// Does this plan have any non-index joins?