diff --git a/crates/bench/benches/subscription.rs b/crates/bench/benches/subscription.rs index 065916f78b..9f291f0fb9 100644 --- a/crates/bench/benches/subscription.rs +++ b/crates/bench/benches/subscription.rs @@ -126,10 +126,10 @@ fn eval(c: &mut Criterion) { // A benchmark runner for the new query engine let bench_query = |c: &mut Criterion, name, sql| { c.bench_function(name, |b| { - let tx = raw.db.begin_tx(Workload::Subscribe); + let mut tx = raw.db.begin_tx(Workload::Subscribe); let auth = AuthCtx::for_testing(); - let schema_viewer = &SchemaViewer::new(&tx, &auth); - let (plans, table_id, table_name, _) = compile_subscription(sql, schema_viewer, &auth).unwrap(); + let mut schema_viewer = SchemaViewer::new(&mut tx, &auth); + let (plans, table_id, table_name, _) = compile_subscription(sql, &mut schema_viewer, &auth).unwrap(); let plans = plans .into_iter() .map(|plan| plan.optimize(&auth).unwrap()) @@ -155,8 +155,8 @@ fn eval(c: &mut Criterion) { let bench_eval = |c: &mut Criterion, name, sql| { c.bench_function(name, |b| { - let tx = raw.db.begin_tx(Workload::Update); - let query = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), &tx, sql).unwrap(); + let mut tx = raw.db.begin_tx(Workload::Update); + let query = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), &mut tx, sql).unwrap(); let query: ExecutionSet = query.into(); b.iter(|| { @@ -207,11 +207,11 @@ fn eval(c: &mut Criterion) { // A passthru executed independently of the database. let select_lhs = "select * from footprint"; let select_rhs = "select * from location"; - let tx = &raw.db.begin_tx(Workload::Update); - let query_lhs = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), tx, select_lhs).unwrap(); - let query_rhs = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), tx, select_rhs).unwrap(); + let mut tx = raw.db.begin_tx(Workload::Update); + let query_lhs = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), &mut tx, select_lhs).unwrap(); + let query_rhs = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), &mut tx, select_rhs).unwrap(); let query = ExecutionSet::from_iter(query_lhs.into_iter().chain(query_rhs)); - let tx = &tx.into(); + let tx = &(&mut tx).into(); b.iter(|| drop(black_box(query.eval_incr_for_test(&raw.db, tx, &update, None)))) }); @@ -226,10 +226,10 @@ fn eval(c: &mut Criterion) { from footprint join location on footprint.entity_id = location.entity_id \ where location.chunk_index = {chunk_index}" ); - let tx = &raw.db.begin_tx(Workload::Update); - let query = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), tx, &join).unwrap(); + let mut tx = raw.db.begin_tx(Workload::Update); + let query = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), &mut tx, &join).unwrap(); let query: ExecutionSet = query.into(); - let tx = &tx.into(); + let tx = &(&mut tx).into(); b.iter(|| drop(black_box(query.eval_incr_for_test(&raw.db, tx, &update, None)))); }); diff --git a/crates/core/src/estimation.rs b/crates/core/src/estimation.rs index 24130a145f..5e046b06c3 100644 --- a/crates/core/src/estimation.rs +++ b/crates/core/src/estimation.rs @@ -181,8 +181,8 @@ mod tests { } fn num_rows_for(db: &RelationalDB, sql: &str) -> u64 { - let tx = begin_tx(db); - match &*compile_sql(db, &AuthCtx::for_testing(), &tx, sql).expect("Failed to compile sql") { + let mut tx = begin_tx(db); + match &*compile_sql(db, &AuthCtx::for_testing(), &mut tx, sql).expect("Failed to compile sql") { [CrudExpr::Query(expr)] => num_rows(&tx, expr), exprs => panic!("unexpected result from compilation: {exprs:#?}"), } @@ -191,10 +191,10 @@ mod tests { /// Using the new query plan fn new_row_estimate(db: &RelationalDB, sql: &str) -> u64 { let auth = AuthCtx::for_testing(); - let tx = begin_tx(db); - let tx = SchemaViewer::new(&tx, &auth); + let mut tx = begin_tx(db); + let mut tx = SchemaViewer::new(&mut tx, &auth); - compile_subscription(sql, &tx, &auth) + compile_subscription(sql, &mut tx, &auth) .map(|(plans, ..)| plans) .expect("failed to compile sql query") .into_iter() diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index aaafeb337f..3f77ba4e6f 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -1869,7 +1869,7 @@ impl ModuleHost { let metrics = self .on_module_thread("one_off_query", move || { let (tx_offset_sender, tx_offset_receiver) = oneshot::channel(); - let tx = scopeguard::guard(db.begin_tx(Workload::Sql), |tx| { + let mut tx = scopeguard::guard(db.begin_tx(Workload::Sql), |tx| { let (tx_offset, tx_metrics, reducer) = db.release_tx(tx); let _ = tx_offset_sender.send(tx_offset); db.report_read_tx_metrics(reducer, tx_metrics); @@ -1878,7 +1878,7 @@ impl ModuleHost { // We wrap the actual query in a closure so we can use ? to handle errors without making // the entire transaction abort with an error. let result: Result<(OneOffTable, ExecutionMetrics), anyhow::Error> = (|| { - let tx = SchemaViewer::new(&*tx, &auth); + let mut tx = SchemaViewer::new(&mut *tx, &auth); let ( // A query may compile down to several plans. @@ -1888,7 +1888,7 @@ impl ModuleHost { _, table_name, _, - ) = compile_subscription(&query, &tx, &auth)?; + ) = compile_subscription(&query, &mut tx, &auth)?; // Optimize each fragment let optimized = plans diff --git a/crates/core/src/host/wasm_common/module_host_actor.rs b/crates/core/src/host/wasm_common/module_host_actor.rs index 13cdd04fd1..543c07c5a6 100644 --- a/crates/core/src/host/wasm_common/module_host_actor.rs +++ b/crates/core/src/host/wasm_common/module_host_actor.rs @@ -1126,10 +1126,10 @@ impl InstanceCommon { // Views bypass RLS, since views should enforce their own access control procedurally. let auth = AuthCtx::for_current(self.info.database_identity); - let schema_view = SchemaViewer::new(&*tx, &auth); + let mut schema_view = SchemaViewer::new(&mut *tx, &auth); // Compile to subscription plans. - let (plans, has_params) = SubscriptionPlan::compile(the_query, &schema_view, &auth)?; + let (plans, has_params) = SubscriptionPlan::compile(the_query, &mut schema_view, &auth)?; ensure!( !has_params, "parameterized SQL is not supported for view materialization yet" diff --git a/crates/core/src/sql/ast.rs b/crates/core/src/sql/ast.rs index a11775331c..51f73c8f4f 100644 --- a/crates/core/src/sql/ast.rs +++ b/crates/core/src/sql/ast.rs @@ -5,6 +5,7 @@ use spacetimedb_data_structures::map::{HashCollectionExt as _, IntMap}; use spacetimedb_datastore::locking_tx_datastore::state_view::StateView; use spacetimedb_datastore::system_tables::{StRowLevelSecurityFields, ST_ROW_LEVEL_SECURITY_ID}; use spacetimedb_expr::check::{SchemaView, TypingResult}; +use spacetimedb_expr::errors::TypingError; use spacetimedb_expr::statement::compile_sql_stmt; use spacetimedb_lib::identity::AuthCtx; use spacetimedb_primitives::{ArgId, ColId, TableId}; @@ -22,7 +23,7 @@ use sqlparser::ast::{ }; use sqlparser::dialect::PostgreSqlDialect; use sqlparser::parser::Parser; -use std::ops::Deref; +use std::ops::{Deref, DerefMut}; use std::sync::Arc; /// Simplify to detect features of the syntax we don't support yet @@ -477,7 +478,7 @@ fn compile_where(table: &From, filter: Option) -> Result { - pub(crate) tx: &'a T, + tx: &'a mut T, auth: &'a AuthCtx, } @@ -489,6 +490,12 @@ impl Deref for SchemaViewer<'_, T> { } } +impl DerefMut for SchemaViewer<'_, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.tx + } +} + impl SchemaView for SchemaViewer<'_, T> { fn table_id(&self, name: &str) -> Option { // Get the schema from the in-memory state instead of fetching from the database for speed @@ -536,10 +543,15 @@ impl SchemaView for SchemaViewer<'_, T> { }) .collect::>() } + + fn get_or_create_params(&mut self, _params: ProductValue) -> TypingResult { + // Caller should have used `SchemaViewerMut` on crate `core` + Err(TypingError::ParamsReadOnly) + } } impl<'a, T> SchemaViewer<'a, T> { - pub fn new(tx: &'a T, auth: &'a AuthCtx) -> Self { + pub fn new(tx: &'a mut T, auth: &'a AuthCtx) -> Self { Self { tx, auth } } } @@ -1000,13 +1012,12 @@ fn compile_statement( pub(crate) fn compile_to_ast( db: &RelationalDB, auth: &AuthCtx, - tx: &T, + tx: &mut T, sql_text: &str, ) -> Result, DBError> { // NOTE: The following ensures compliance with the 1.0 sql api. // Come 1.0, it will have replaced the current compilation stack. - compile_sql_stmt(sql_text, &SchemaViewer::new(tx, auth), auth)?; - + compile_sql_stmt(sql_text, &mut SchemaViewer::new(tx, auth), auth)?; let dialect = PostgreSqlDialect {}; let ast = Parser::parse_sql(&dialect, sql_text).map_err(|error| DBError::SqlParser { sql: sql_text.to_string(), diff --git a/crates/core/src/sql/compiler.rs b/crates/core/src/sql/compiler.rs index eb17de1e8a..97ba4a507e 100644 --- a/crates/core/src/sql/compiler.rs +++ b/crates/core/src/sql/compiler.rs @@ -23,7 +23,7 @@ const MAX_SQL_LENGTH: usize = 50_000; pub fn compile_sql( db: &RelationalDB, auth: &AuthCtx, - tx: &T, + tx: &mut T, sql_text: &str, ) -> Result, DBError> { if sql_text.len() > MAX_SQL_LENGTH { @@ -266,7 +266,7 @@ mod tests { fn compile_sql( db: &RelationalDB, - tx: &T, + tx: &mut T, sql: &str, ) -> Result, DBError> { super::compile_sql(db, &AuthCtx::for_testing(), tx, sql) @@ -281,10 +281,10 @@ mod tests { let indexes = &[]; db.create_table_for_test("test", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Compile query let sql = "select * from test where a = 1"; - let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else { + let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else { panic!("Expected QueryExpr"); }; assert_eq!(1, query.len()); @@ -303,10 +303,10 @@ mod tests { &[1.into(), 0.into()], )?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Should work with any qualified field. let sql = "select * from test where a = 1 and b <> 3"; - let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else { + let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else { panic!("Expected QueryExpr"); }; assert_eq!(2, query.len()); @@ -324,10 +324,10 @@ mod tests { let indexes = &[0.into()]; db.create_table_for_test("test", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); //Compile query let sql = "select * from test where a = 1"; - let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else { + let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else { panic!("Expected QueryExpr"); }; assert_eq!(1, query.len()); @@ -377,11 +377,11 @@ mod tests { let rows = run_for_testing(&db, sql)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); let CrudExpr::Query(QueryExpr { source: _, query: mut ops, - }) = compile_sql(&db, &tx, sql)?.remove(0) + }) = compile_sql(&db, &mut tx, sql)?.remove(0) else { panic!("Expected QueryExpr"); }; @@ -407,11 +407,11 @@ mod tests { let indexes = &[1.into()]; db.create_table_for_test("test", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Note, order does not matter. // The sargable predicate occurs last, but we can still generate an index scan. let sql = "select * from test where a = 1 and b = 2"; - let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else { + let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else { panic!("Expected QueryExpr"); }; assert_eq!(2, query.len()); @@ -429,11 +429,11 @@ mod tests { let indexes = &[1.into()]; db.create_table_for_test("test", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Note, order does not matter. // The sargable predicate occurs first and we can generate an index scan. let sql = "select * from test where b = 2 and a = 1"; - let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else { + let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else { panic!("Expected QueryExpr"); }; assert_eq!(2, query.len()); @@ -455,9 +455,9 @@ mod tests { ]; db.create_table_for_test_multi_column("test", schema, col_list![0, 1])?; - let tx = begin_mut_tx(&db); + let mut tx = begin_mut_tx(&db); let sql = "select * from test where b = 2 and a = 1"; - let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else { + let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else { panic!("Expected QueryExpr"); }; assert_eq!(1, query.len()); @@ -474,10 +474,10 @@ mod tests { let indexes = &[0.into(), 1.into()]; db.create_table_for_test("test", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Compile query let sql = "select * from test where a = 1 or b = 2"; - let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else { + let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else { panic!("Expected QueryExpr"); }; assert_eq!(1, query.len()); @@ -495,10 +495,10 @@ mod tests { let indexes = &[1.into()]; db.create_table_for_test("test", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Compile query let sql = "select * from test where b > 2"; - let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else { + let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else { panic!("Expected QueryExpr"); }; assert_eq!(1, query.len()); @@ -516,10 +516,10 @@ mod tests { let indexes = &[1.into()]; db.create_table_for_test("test", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Compile query let sql = "select * from test where b > 2 and b < 5"; - let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else { + let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else { panic!("Expected QueryExpr"); }; assert_eq!(1, query.len()); @@ -542,11 +542,11 @@ mod tests { let indexes = &[0.into(), 1.into()]; db.create_table_for_test("test", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Note, order matters - the equality condition occurs first which // means an index scan will be generated rather than the range condition. let sql = "select * from test where a = 3 and b > 2 and b < 5"; - let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else { + let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else { panic!("Expected QueryExpr"); }; assert_eq!(2, query.len()); @@ -569,10 +569,10 @@ mod tests { let indexes = &[]; let rhs_id = db.create_table_for_test("rhs", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Should push sargable equality condition below join let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where lhs.a = 3"; - let exp = compile_sql(&db, &tx, sql)?.remove(0); + let exp = compile_sql(&db, &mut tx, sql)?.remove(0); let CrudExpr::Query(QueryExpr { source: source_lhs, @@ -621,10 +621,10 @@ mod tests { let schema = &[("b", AlgebraicType::U64), ("c", AlgebraicType::U64)]; let rhs_id = db.create_table_for_test("rhs", schema, &[])?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Should push equality condition below join let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where lhs.a = 3"; - let exp = compile_sql(&db, &tx, sql)?.remove(0); + let exp = compile_sql(&db, &mut tx, sql)?.remove(0); let CrudExpr::Query(QueryExpr { source: source_lhs, @@ -678,10 +678,10 @@ mod tests { let schema = &[("b", AlgebraicType::U64), ("c", AlgebraicType::U64)]; let rhs_id = db.create_table_for_test("rhs", schema, &[])?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Should push equality condition below join let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c = 3"; - let exp = compile_sql(&db, &tx, sql)?.remove(0); + let exp = compile_sql(&db, &mut tx, sql)?.remove(0); let CrudExpr::Query(QueryExpr { source: source_lhs, @@ -736,11 +736,11 @@ mod tests { let indexes = &[1.into()]; let rhs_id = db.create_table_for_test("rhs", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Should push the sargable equality condition into the join's left arg. // Should push the sargable range condition into the join's right arg. let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where lhs.a = 3 and rhs.c < 4"; - let exp = compile_sql(&db, &tx, sql)?.remove(0); + let exp = compile_sql(&db, &mut tx, sql)?.remove(0); let CrudExpr::Query(QueryExpr { source: source_lhs, @@ -807,11 +807,11 @@ mod tests { let indexes = &[0.into(), 1.into()]; let rhs_id = db.create_table_for_test("rhs", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Should generate an index join since there is an index on `lhs.b`. // Should push the sargable range condition into the index join's probe side. let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c > 2 and rhs.c < 4 and rhs.d = 3"; - let exp = compile_sql(&db, &tx, sql)?.remove(0); + let exp = compile_sql(&db, &mut tx, sql)?.remove(0); let CrudExpr::Query(QueryExpr { source: SourceExpr::DbTable(DbTable { table_id, .. }), @@ -889,11 +889,11 @@ mod tests { let indexes = col_list![0, 1]; let rhs_id = db.create_table_for_test_multi_column("rhs", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Should generate an index join since there is an index on `lhs.b`. // Should push the sargable range condition into the index join's probe side. let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c = 2 and rhs.b = 4 and rhs.d = 3"; - let exp = compile_sql(&db, &tx, sql)?.remove(0); + let exp = compile_sql(&db, &mut tx, sql)?.remove(0); let CrudExpr::Query(QueryExpr { source: SourceExpr::DbTable(DbTable { table_id, .. }), @@ -953,7 +953,7 @@ mod tests { let db = TestDB::durable()?; db.create_table_for_test("A", &[("x", AlgebraicType::U64)], &[])?; db.create_table_for_test("B", &[("y", AlgebraicType::U64)], &[])?; - assert!(compile_sql(&db, &begin_tx(&db), "select B.* from B join A on B.y = A.x").is_ok()); + assert!(compile_sql(&db, &mut begin_tx(&db), "select B.* from B join A on B.y = A.x").is_ok()); Ok(()) } @@ -970,27 +970,27 @@ mod tests { // TODO: Type check other operations deferred for the new query engine. assert!( - compile_sql(&db, &begin_tx(&db), sql).is_err(), + compile_sql(&db, &mut begin_tx(&db), sql).is_err(), // Err("SqlError: Type Mismatch: `PlayerState.entity_id: U64` != `String(\"161853\"): String`, executing: `SELECT * FROM PlayerState WHERE entity_id = '161853'`".into()) ); // Check we can still compile the query if we remove the type mismatch and have multiple logical operations. let sql = "SELECT * FROM PlayerState WHERE entity_id = 1 AND entity_id = 2 AND entity_id = 3 OR entity_id = 4 OR entity_id = 5"; - assert!(compile_sql(&db, &begin_tx(&db), sql).is_ok()); + assert!(compile_sql(&db, &mut begin_tx(&db), sql).is_ok()); // Now verify when we have a type mismatch in the middle of the logical operations. let sql = "SELECT * FROM PlayerState WHERE entity_id = 1 AND entity_id"; assert!( - compile_sql(&db, &begin_tx(&db), sql).is_err(), + compile_sql(&db, &mut begin_tx(&db), sql).is_err(), // Err("SqlError: Type Mismatch: `PlayerState.entity_id: U64 == U64(1): U64` and `PlayerState.entity_id: U64`, both sides must be an `Bool` expression, executing: `SELECT * FROM PlayerState WHERE entity_id = 1 AND entity_id`".into()) ); // Verify that all operands of `AND` must be `Bool`. let sql = "SELECT * FROM PlayerState WHERE entity_id AND entity_id"; assert!( - compile_sql(&db, &begin_tx(&db), sql).is_err(), + compile_sql(&db, &mut begin_tx(&db), sql).is_err(), // Err("SqlError: Type Mismatch: `PlayerState.entity_id: U64` and `PlayerState.entity_id: U64`, both sides must be an `Bool` expression, executing: `SELECT * FROM PlayerState WHERE entity_id AND entity_id`".into()) ); Ok(()) diff --git a/crates/core/src/sql/execute.rs b/crates/core/src/sql/execute.rs index 2e2e233750..6f30872e85 100644 --- a/crates/core/src/sql/execute.rs +++ b/crates/core/src/sql/execute.rs @@ -20,9 +20,8 @@ use anyhow::anyhow; use spacetimedb_datastore::execution_context::Workload; use spacetimedb_datastore::locking_tx_datastore::state_view::StateView; use spacetimedb_datastore::traits::IsolationLevel; -use spacetimedb_expr::check::SchemaView; +use spacetimedb_expr::check::{SchemaView, TypingResult}; use spacetimedb_expr::errors::TypingError; -use spacetimedb_expr::expr::CallParams; use spacetimedb_expr::statement::Statement; use spacetimedb_lib::identity::AuthCtx; use spacetimedb_lib::metrics::ExecutionMetrics; @@ -191,15 +190,27 @@ pub struct SqlResult { pub metrics: ExecutionMetrics, } -struct DbParams<'a> { +struct SchemaViewerMut<'a> { db: &'a RelationalDB, - tx: &'a mut MutTx, + schema: SchemaViewer<'a, MutTx>, } -impl CallParams for DbParams<'_> { - fn create_or_get_param(&mut self, param: &ProductValue) -> Result { +impl SchemaView for SchemaViewerMut<'_> { + fn table_id(&self, name: &str) -> Option { + self.schema.table_id(name) + } + + fn schema_for_table(&self, table_id: TableId) -> Option> { + self.schema.schema_for_table(table_id) + } + + fn rls_rules_for_table(&self, table_id: TableId) -> anyhow::Result>> { + self.schema.rls_rules_for_table(table_id) + } + + fn get_or_create_params(&mut self, params: ProductValue) -> TypingResult { self.db - .create_or_get_params(self.tx, ¶m) + .create_or_get_params(&mut self.schema, ¶ms) .map_err(|err| TypingError::Other(err.into())) } } @@ -215,20 +226,16 @@ pub async fn run( ) -> Result { // We parse the sql statement in a mutable transaction. // If it turns out to be a query, we downgrade the tx. - let (tx, stmt) = - db.with_auto_rollback( - db.begin_mut_tx(IsolationLevel::Serializable, Workload::Sql), - |tx| match compile_sql_stmt(sql_text, &SchemaViewer::new(tx, &auth), &auth) { - Ok(Statement::Select(mut stmt)) => { - stmt.for_each_fun_call(&mut |param| { - db.create_or_get_params(tx, ¶m) - .map_err(|err| TypingError::Other(err.into())) - })?; - Ok(Statement::Select(stmt)) - } - result => result, + let (tx, stmt) = db.with_auto_rollback(db.begin_mut_tx(IsolationLevel::Serializable, Workload::Sql), |tx| { + compile_sql_stmt( + sql_text, + &mut SchemaViewerMut { + db, + schema: SchemaViewer::new(tx, &auth), }, - )?; + &auth, + ) + })?; let mut metrics = ExecutionMetrics::default(); @@ -1619,6 +1626,10 @@ pub(crate) mod tests { true, )?; let arg_id = ST_RESERVED_SEQUENCE_RANGE as u64; + assert_eq!( + run_for_testing(&db, "select view_id, param_pos, param_name FROM st_view_param")?, + vec![product![arg_id as u32, 0u16, "x"]] + ); with_auto_commit(&db, |tx| -> Result<_, DBError> { tests_utils::insert_into_view(&db, tx, table_id, None, product![arg_id + 1, 0u8, 1i64])?; @@ -1636,6 +1647,12 @@ pub(crate) mod tests { vec![product![0u8, 1i64]] ); + // We have created the internal rows for view args + assert_eq!( + run_for_testing(&db, "select id FROM st_view_arg")?, + vec![product![arg_id], product![arg_id + 1]] + ); + Ok(()) } } diff --git a/crates/core/src/sql/parser.rs b/crates/core/src/sql/parser.rs index 66d216a559..3dde1a4926 100644 --- a/crates/core/src/sql/parser.rs +++ b/crates/core/src/sql/parser.rs @@ -19,7 +19,7 @@ impl RowLevelExpr { auth_ctx: &AuthCtx, rls: &RawRowLevelSecurityDefV9, ) -> anyhow::Result { - let (sql, _) = parse_and_type_sub(&rls.sql, &SchemaViewer::new(tx, auth_ctx), auth_ctx)?; + let (sql, _) = parse_and_type_sub(&rls.sql, &mut SchemaViewer::new(tx, auth_ctx), auth_ctx)?; let table_id = sql.return_table_id().unwrap(); let schema = tx.schema_for_table(table_id)?; diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index 0396a12f10..842996c830 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -472,7 +472,7 @@ impl ModuleSubscriptions { let hash = QueryHash::from_string(&sql, auth.caller(), false); let hash_with_param = QueryHash::from_string(&sql, auth.caller(), true); - let (mut_tx, _) = self.begin_mut_tx(Workload::Subscribe); + let (mut mut_tx, _) = self.begin_mut_tx(Workload::Subscribe); let existing_query = { let guard = self.subscriptions.read(); @@ -482,7 +482,7 @@ impl ModuleSubscriptions { let query = return_on_err_with_sql!( existing_query.map(Ok).unwrap_or_else(|| compile_query_with_hashes( &auth, - &*mut_tx, + &mut *mut_tx, &sql, hash, hash_with_param @@ -736,7 +736,7 @@ impl ModuleSubscriptions { } // We always get the db lock before the subscription lock to avoid deadlocks. - let (mut_tx, _tx_offset) = self.begin_mut_tx(Workload::Subscribe); + let (mut mut_tx, _tx_offset) = self.begin_mut_tx(Workload::Subscribe); let compile_timer = metrics.compilation_time.start_timer(); @@ -752,7 +752,7 @@ impl ModuleSubscriptions { super::subscription::get_all( |relational_db, tx| relational_db.get_all_tables_mut(tx).map(|schemas| schemas.into_iter()), &self.relational_db, - &*mut_tx, + &mut *mut_tx, &auth, )? .into_iter() @@ -769,7 +769,7 @@ impl ModuleSubscriptions { plans.push(unit); } else { plans.push(Arc::new( - compile_query_with_hashes(&auth, &*mut_tx, sql, hash, hash_with_param).map_err(|err| { + compile_query_with_hashes(&auth, &mut *mut_tx, sql, hash, hash_with_param).map_err(|err| { DBError::WithSql { error: Box::new(DBError::Other(err.into())), sql: sql.into(), @@ -1807,8 +1807,8 @@ mod tests { let auth = AuthCtx::for_testing(); let sql = "select * from t where id = 1"; - let tx = begin_tx(&db); - let plan = compile_read_only_query(&auth, &tx, sql)?; + let mut tx = begin_tx(&db); + let plan = compile_read_only_query(&auth, &mut tx, sql)?; let plan = Arc::new(plan); let (_, metrics) = subs.evaluate_queries(sender, &[plan], &tx, &auth, TableUpdateType::Subscribe)?; diff --git a/crates/core/src/subscription/module_subscription_manager.rs b/crates/core/src/subscription/module_subscription_manager.rs index 6b93e996ce..dccb0b1dfc 100644 --- a/crates/core/src/subscription/module_subscription_manager.rs +++ b/crates/core/src/subscription/module_subscription_manager.rs @@ -1681,8 +1681,8 @@ mod tests { fn compile_plan(db: &RelationalDB, sql: &str) -> ResultTest> { with_read_only(db, |tx| { let auth = AuthCtx::for_testing(); - let tx = SchemaViewer::new(&*tx, &auth); - let (plans, has_param) = SubscriptionPlan::compile(sql, &tx, &auth).unwrap(); + let mut tx = SchemaViewer::new(tx, &auth); + let (plans, has_param) = SubscriptionPlan::compile(sql, &mut tx, &auth).unwrap(); let hash = QueryHash::from_string(sql, auth.caller(), has_param); Ok(Arc::new(Plan::new(plans, hash, sql.into()))) }) diff --git a/crates/core/src/subscription/query.rs b/crates/core/src/subscription/query.rs index 968a092e5d..31b81b3dfc 100644 --- a/crates/core/src/subscription/query.rs +++ b/crates/core/src/subscription/query.rs @@ -42,7 +42,7 @@ pub fn is_subscribe_to_all_tables(sql: &str) -> bool { pub fn compile_read_only_queryset( relational_db: &RelationalDB, auth: &AuthCtx, - tx: &Tx, + tx: &mut Tx, input: &str, ) -> Result, DBError> { let input = input.trim(); @@ -82,13 +82,13 @@ pub fn compile_read_only_queryset( /// Compile a string into a single read-only query. /// This returns an error if the string has multiple queries or mutations. -pub fn compile_read_only_query(auth: &AuthCtx, tx: &Tx, input: &str) -> Result { +pub fn compile_read_only_query(auth: &AuthCtx, tx: &mut Tx, input: &str) -> Result { if is_whitespace_or_empty(input) { return Err(SubscriptionError::Empty.into()); } - let tx = SchemaViewer::new(tx, auth); - let (plans, has_param) = SubscriptionPlan::compile(input, &tx, auth)?; + let mut tx = SchemaViewer::new(tx, auth); + let (plans, has_param) = SubscriptionPlan::compile(input, &mut tx, auth)?; let hash = QueryHash::from_string(input, auth.caller(), has_param); Ok(Plan::new(plans, hash, input.to_owned())) } @@ -97,7 +97,7 @@ pub fn compile_read_only_query(auth: &AuthCtx, tx: &Tx, input: &str) -> Result

( auth: &AuthCtx, - tx: &Tx, + tx: &mut Tx, input: &str, hash: QueryHash, hash_with_param: QueryHash, @@ -106,8 +106,8 @@ pub fn compile_query_with_hashes( return Err(SubscriptionError::Empty.into()); } - let tx = SchemaViewer::new(tx, auth); - let (plans, has_param) = SubscriptionPlan::compile(input, &tx, auth)?; + let mut tx = SchemaViewer::new(tx, auth); + let (plans, has_param) = SubscriptionPlan::compile(input, &mut tx, auth)?; if auth.bypass_rls() || has_param { // Note that when generating hashes for queries from owners, @@ -151,7 +151,7 @@ mod tests { use crate::db::relational_db::tests_utils::{ begin_mut_tx, begin_tx, insert, with_auto_commit, with_read_only, TestDB, }; - use crate::db::relational_db::MutTx; + use crate::db::relational_db::{tests_utils, MutTx}; use crate::host::module_host::{DatabaseTableUpdate, DatabaseUpdate, UpdatesRelValue}; use crate::sql::execute::collect_result; use crate::sql::execute::tests::run_for_testing; @@ -164,6 +164,7 @@ mod tests { use itertools::Itertools; use spacetimedb_client_api_messages::websocket::{BsatnFormat, CompressableQueryUpdate, Compression}; use spacetimedb_datastore::execution_context::Workload; + use spacetimedb_datastore::system_tables::ST_RESERVED_SEQUENCE_RANGE; use spacetimedb_lib::bsatn; use spacetimedb_lib::db::auth::{StAccess, StTableType}; use spacetimedb_lib::error::ResultTest; @@ -421,9 +422,9 @@ mod tests { db.create_table_for_test("a", schema, indexes)?; db.create_table_for_test("b", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); let sql = "SELECT b.* FROM b JOIN a ON b.n = a.n WHERE b.data > 200"; - let result = compile_read_only_query(&AuthCtx::for_testing(), &tx, sql); + let result = compile_read_only_query(&AuthCtx::for_testing(), &mut tx, sql); assert!(result.is_ok()); Ok(()) } @@ -454,10 +455,10 @@ mod tests { }; db.commit_tx(tx)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); let sql = "select * from test where b = 3"; - let mut exp = compile_sql(&db, &AuthCtx::for_testing(), &tx, sql)?; + let mut exp = compile_sql(&db, &AuthCtx::for_testing(), &mut tx, sql)?; let Some(CrudExpr::Query(query)) = exp.pop() else { panic!("unexpected query {:#?}", exp[0]); @@ -609,8 +610,8 @@ mod tests { AND MobileEntityState.location_z > 96000 \ AND MobileEntityState.location_z < 192000"; - let tx = begin_tx(&db); - let qset = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &tx, sql_query)?; + let mut tx = begin_tx(&db); + let qset = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &mut tx, sql_query)?; for q in qset { let result = run_query( @@ -684,7 +685,7 @@ mod tests { let indexes = &[ColId(0), ColId(1)]; db.create_table_for_test("rhs", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // All single table queries are supported let scans = [ @@ -696,7 +697,7 @@ mod tests { "SELECT * FROM lhs WHERE id > 5", ]; for scan in scans { - let expr = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &tx, scan)? + let expr = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &mut tx, scan)? .pop() .unwrap(); assert_eq!(expr.kind(), Supported::Select, "{scan}\n{expr:#?}"); @@ -705,7 +706,7 @@ mod tests { // Only index semijoins are supported let joins = ["SELECT lhs.* FROM lhs JOIN rhs ON lhs.id = rhs.id WHERE rhs.y < 10"]; for join in joins { - let expr = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &tx, join)? + let expr = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &mut tx, join)? .pop() .unwrap(); assert_eq!(expr.kind(), Supported::Semijoin, "{join}\n{expr:#?}"); @@ -718,7 +719,7 @@ mod tests { "SELECT * FROM lhs JOIN rhs ON lhs.id = rhs.id WHERE lhs.x < 10", ]; for join in joins { - match compile_read_only_queryset(&db, &AuthCtx::for_testing(), &tx, join) { + match compile_read_only_queryset(&db, &AuthCtx::for_testing(), &mut tx, join) { Err(DBError::Subscription(SubscriptionError::Unsupported(_)) | DBError::TypeError(_)) => (), x => panic!("Unexpected: {x:?}"), } @@ -756,10 +757,10 @@ mod tests { fn compile_query(db: &RelationalDB) -> ResultTest { with_read_only(db, |tx| { let auth = AuthCtx::for_testing(); - let tx = SchemaViewer::new(tx, &auth); + let mut tx = SchemaViewer::new(tx, &auth); // Should be answered using an index semijion let sql = "select lhs.* from lhs join rhs on lhs.id = rhs.id where rhs.y >= 2 and rhs.y <= 4"; - Ok(SubscriptionPlan::compile(sql, &tx, &auth) + Ok(SubscriptionPlan::compile(sql, &mut tx, &auth) .map(|(mut plans, _)| { assert_eq!(plans.len(), 1); plans.pop().unwrap() @@ -781,10 +782,10 @@ mod tests { fn compile_query(db: &RelationalDB) -> ResultTest { with_read_only(db, |tx| { let auth = AuthCtx::for_testing(); - let tx = SchemaViewer::new(tx, &auth); + let mut tx = SchemaViewer::new(tx, &auth); // Should be answered using an index semijion let sql = "select lhs.* from lhs join rhs on lhs.id = rhs.id where lhs.x >= 5 and lhs.x <= 7"; - Ok(SubscriptionPlan::compile(sql, &tx, &auth) + Ok(SubscriptionPlan::compile(sql, &mut tx, &auth) .map(|(mut plans, _)| { assert_eq!(plans.len(), 1); plans.pop().unwrap() @@ -1447,4 +1448,36 @@ mod tests { assert_eq!(metrics.index_seeks, 8); Ok(()) } + + // Verify calling views with params + // TODO: All testing use the old query compiler, so we can't test this yet. + #[test] + fn test_view_params() -> ResultTest<()> { + let db = TestDB::durable()?; + let schema = [("a", AlgebraicType::U8), ("b", AlgebraicType::I64)]; + let (_view_id, table_id) = tests_utils::create_view_for_test( + &db, + "my_view", + &schema, + ProductType::from([("x", AlgebraicType::U8)]), + true, + )?; + let arg_id = ST_RESERVED_SEQUENCE_RANGE as u64; + + with_auto_commit(&db, |tx| -> Result<_, DBError> { + tests_utils::insert_into_view(&db, tx, table_id, None, product![arg_id + 1, 0u8, 1i64])?; + tests_utils::insert_into_view(&db, tx, table_id, None, product![arg_id, 1u8, 2i64])?; + Ok(()) + })?; + + let mut tx = begin_tx(&db); + + let err = + compile_read_only_queryset(&db, &AuthCtx::for_testing(), &mut tx, "SELECT * FROM my_view(1)").unwrap_err(); + assert_eq!( + err.to_string(), + "InternalError: Read-only queries cannot create parameters".to_string() + ); + Ok(()) + } } diff --git a/crates/core/src/subscription/subscription.rs b/crates/core/src/subscription/subscription.rs index 1fb13e8ee5..5b14700c2f 100644 --- a/crates/core/src/subscription/subscription.rs +++ b/crates/core/src/subscription/subscription.rs @@ -615,7 +615,7 @@ impl AuthAccess for ExecutionSet { pub(crate) fn get_all( get_all_tables: F, relational_db: &RelationalDB, - tx: &T, + tx: &mut T, auth: &AuthCtx, ) -> Result, DBError> where @@ -627,8 +627,8 @@ where .filter(|t| t.table_type == StTableType::User && auth.has_read_access(t.table_access)) .map(|schema| { let sql = format!("SELECT * FROM {}", schema.table_name); - let tx = SchemaViewer::new(tx, auth); - SubscriptionPlan::compile(&sql, &tx, auth).map(|(plans, has_param)| { + let mut tx = SchemaViewer::new(tx, auth); + SubscriptionPlan::compile(&sql, &mut tx, auth).map(|(plans, has_param)| { Plan::new( plans, QueryHash::from_string( @@ -701,11 +701,11 @@ mod tests { let indexes = &[0.into(), 1.into()]; let rhs_id = db.create_table_for_test("rhs", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Should generate an index join since there is an index on `lhs.b`. // Should push the sargable range condition into the index join's probe side. let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c > 2 and rhs.c < 4 and rhs.d = 3"; - let exp = compile_sql(&db, &AuthCtx::for_testing(), &tx, sql)?.remove(0); + let exp = compile_sql(&db, &AuthCtx::for_testing(), &mut tx, sql)?.remove(0); let CrudExpr::Query(mut expr) = exp else { panic!("unexpected result from compilation: {exp:#?}"); @@ -781,11 +781,11 @@ mod tests { let indexes = &[0.into(), 1.into()]; let _ = db.create_table_for_test("rhs", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Should generate an index join since there is an index on `lhs.b`. // Should push the sargable range condition into the index join's probe side. let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c > 2 and rhs.c < 4 and rhs.d = 3"; - let exp = compile_sql(&db, &AuthCtx::for_testing(), &tx, sql)?.remove(0); + let exp = compile_sql(&db, &AuthCtx::for_testing(), &mut tx, sql)?.remove(0); let CrudExpr::Query(mut expr) = exp else { panic!("unexpected result from compilation: {exp:#?}"); @@ -865,12 +865,12 @@ mod tests { .create_table_for_test("rhs", schema, indexes) .expect("Failed to create_table_for_test rhs"); - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Should generate an index join since there is an index on `lhs.b`. // Should push the sargable range condition into the index join's probe side. let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c > 2 and rhs.c < 4 and rhs.d = 3"; - let exp = compile_sql(&db, &AuthCtx::for_testing(), &tx, sql) + let exp = compile_sql(&db, &AuthCtx::for_testing(), &mut tx, sql) .expect("Failed to compile_sql") .remove(0); diff --git a/crates/core/src/util/slow.rs b/crates/core/src/util/slow.rs index ea6701658c..17522e62b7 100644 --- a/crates/core/src/util/slow.rs +++ b/crates/core/src/util/slow.rs @@ -65,14 +65,14 @@ mod tests { use spacetimedb_vm::relation::MemTable; fn run_query(db: &Arc, sql: String) -> ResultTest { - let tx = begin_tx(db); - let q = compile_sql(db, &AuthCtx::for_testing(), &tx, &sql)?; + let mut tx = begin_tx(db); + let q = compile_sql(db, &AuthCtx::for_testing(), &mut tx, &sql)?; Ok(execute_for_testing(db, &sql, q)?.pop().unwrap()) } fn run_query_write(db: &Arc, sql: String) -> ResultTest<()> { - let tx = begin_tx(db); - let q = compile_sql(db, &AuthCtx::for_testing(), &tx, &sql)?; + let mut tx = begin_tx(db); + let q = compile_sql(db, &AuthCtx::for_testing(), &mut tx, &sql)?; drop(tx); execute_for_testing(db, &sql, q)?; @@ -92,10 +92,10 @@ mod tests { } Ok(()) })?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); let sql = "select * from test where x > 0"; - let q = compile_sql(&db, &AuthCtx::for_testing(), &tx, sql)?; + let q = compile_sql(&db, &AuthCtx::for_testing(), &mut tx, sql)?; let slow = SlowQueryLogger::new(sql, Some(Duration::from_millis(1)), tx.ctx.workload()); diff --git a/crates/expr/src/check.rs b/crates/expr/src/check.rs index 14e4c59070..c6190e2b10 100644 --- a/crates/expr/src/check.rs +++ b/crates/expr/src/check.rs @@ -9,12 +9,12 @@ use super::{ type_expr, type_proj, type_select, }; use crate::errors::{TableFunc, UnexpectedFunctionType}; -use crate::expr::{Expr, LeftDeepJoin, ProjectList, ProjectName, Relvar}; +use crate::expr::{Expr, FieldProject, LeftDeepJoin, ProjectList, ProjectName, Relvar}; use spacetimedb_lib::identity::AuthCtx; use spacetimedb_lib::AlgebraicType; use spacetimedb_primitives::{ArgId, TableId}; use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type; -use spacetimedb_sats::ProductValue; +use spacetimedb_sats::{AlgebraicValue, ProductValue}; use spacetimedb_schema::schema::TableOrViewSchema; use spacetimedb_sql_parser::ast::{BinOp, SqlExpr, SqlLiteral}; use spacetimedb_sql_parser::{ @@ -35,9 +35,7 @@ pub trait SchemaView { self.table_id(name).and_then(|table_id| self.schema_for_table(table_id)) } - fn get_or_create_params(&self, params: &ProductValue) -> TypingResult { - Ok(ArgId::SENTINEL) - } + fn get_or_create_params(&mut self, params: ProductValue) -> TypingResult; } #[derive(Default)] @@ -60,9 +58,9 @@ pub trait TypeChecker { type Ast; type Set; - fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult; + fn type_ast(ast: Self::Ast, tx: &mut impl SchemaView) -> TypingResult; - fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult; + fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &mut impl SchemaView) -> TypingResult; fn type_view_params( schema: &TableOrViewSchema, @@ -152,25 +150,35 @@ pub trait TypeChecker { } fn type_params( + tx: &mut impl SchemaView, from: RelExpr, schema: Arc, alias: Box, params: Option, - ) -> RelExpr { + ) -> TypingResult { match params { - None => from, - Some(args) => RelExpr::FunCall( - Relvar { - schema, - alias, - delta: None, - }, - args, - ), + None => Ok(from), + Some(args) => { + let new_arg_id = tx.get_or_create_params(args)?; + let arg_id_col = schema.inner().get_column_by_name("arg_id").unwrap().col_pos; + + Ok(RelExpr::Select( + Box::new(from), + Expr::BinOp( + BinOp::Eq, + Box::new(Expr::Field(FieldProject { + table: alias, + field: arg_id_col.idx(), + ty: AlgebraicType::U64, + })), + Box::new(Expr::Value(AlgebraicValue::U64(new_arg_id.0), AlgebraicType::U64)), + ), + )) + } } } - fn type_from(from: SqlFrom, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult { + fn type_from(from: SqlFrom, vars: &mut Relvars, tx: &mut impl SchemaView) -> TypingResult { match from { SqlFrom::Expr(SqlIdent(name), SqlIdent(alias)) => { let schema = Self::type_relvar(tx, &name)?; @@ -202,7 +210,7 @@ pub trait TypeChecker { } let schema = Self::type_relvar(tx, &name)?; let arg = Self::type_view_params(&schema, vars, params)?; - let lhs = Box::new(Self::type_params(join, schema.clone(), alias.clone(), arg)); + let lhs = Box::new(Self::type_params(tx, join, schema.clone(), alias.clone(), arg)?); let rhs = Relvar { schema, @@ -237,7 +245,7 @@ pub trait TypeChecker { delta: None, }); - Ok(Self::type_params(from, schema, alias, arg)) + Self::type_params(tx, from, schema, alias, arg) } } } @@ -256,11 +264,11 @@ impl TypeChecker for SubChecker { type Ast = SqlSelect; type Set = SqlSelect; - fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult { + fn type_ast(ast: Self::Ast, tx: &mut impl SchemaView) -> TypingResult { Self::type_set(ast, &mut Relvars::default(), tx) } - fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult { + fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &mut impl SchemaView) -> TypingResult { match ast { SqlSelect { project, @@ -283,7 +291,7 @@ impl TypeChecker for SubChecker { } /// Parse and type check a subscription query -pub fn parse_and_type_sub(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> TypingResult<(ProjectName, bool)> { +pub fn parse_and_type_sub(sql: &str, tx: &mut impl SchemaView, auth: &AuthCtx) -> TypingResult<(ProjectName, bool)> { let ast = parse_subscription(sql)?; let has_param = ast.has_parameter(); let ast = ast.resolve_sender(auth.caller()); @@ -303,15 +311,16 @@ fn expect_table_type(expr: ProjectList) -> TypingResult { pub mod test_utils { use spacetimedb_lib::{db::raw_def::v9::RawModuleDefV9Builder, ProductType}; - use spacetimedb_primitives::TableId; - use spacetimedb_sats::AlgebraicType; + use spacetimedb_primitives::{ArgId, TableId}; + use spacetimedb_sats::{AlgebraicType, ProductValue}; use spacetimedb_schema::{ def::ModuleDef, schema::{Schema, TableOrViewSchema, TableSchema}, }; + use std::collections::HashMap; use std::sync::Arc; - use super::SchemaView; + use super::{SchemaView, TypingResult}; pub struct ViewInfo<'a> { pub(crate) name: &'a str, pub(crate) columns: &'a [(&'a str, AlgebraicType)], @@ -333,7 +342,38 @@ pub mod test_utils { builder.finish().try_into().expect("failed to generate module def") } - pub struct SchemaViewer(pub ModuleDef); + pub struct MockCallParams { + counter: u64, + params: HashMap, + } + + impl Default for MockCallParams { + fn default() -> Self { + Self::new() + } + } + + impl MockCallParams { + pub fn new() -> Self { + Self { + counter: 0, + params: HashMap::new(), + } + } + + pub fn get_or_insert(&mut self, value: ProductValue) -> ArgId { + if let Some(existing) = self.params.get(&value) { + *existing + } else { + self.counter += 1; + let arg_id = ArgId(self.counter - 1); + self.params.insert(value, arg_id); + arg_id + } + } + } + + pub struct SchemaViewer(pub ModuleDef, pub MockCallParams); impl SchemaView for SchemaViewer { fn table_id(&self, name: &str) -> Option { @@ -370,6 +410,10 @@ pub mod test_utils { fn rls_rules_for_table(&self, _: TableId) -> anyhow::Result>> { Ok(vec![]) } + + fn get_or_create_params(&mut self, params: ProductValue) -> TypingResult { + Ok(self.1.get_or_insert(params)) + } } } @@ -423,13 +467,13 @@ mod tests { } /// A wrapper around [super::parse_and_type_sub] that takes a dummy [AuthCtx] - fn parse_and_type_sub(sql: &str, tx: &impl SchemaView) -> TypingResult { + fn parse_and_type_sub(sql: &str, tx: &mut impl SchemaView) -> TypingResult { super::parse_and_type_sub(sql, tx, &AuthCtx::for_testing()).map(|(plan, _)| plan) } #[test] fn valid_literals() { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); struct TestCase { sql: &'static str, @@ -498,27 +542,27 @@ mod tests { msg: "timestamp ms with timezone", }, ] { - let result = parse_and_type_sub(sql, &tx); + let result = parse_and_type_sub(sql, &mut tx); assert!(result.is_ok(), "name: {}, error: {}", msg, result.unwrap_err()); } } #[test] fn valid_literals_for_type() { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); for ty in [ "i8", "u8", "i16", "u16", "i32", "u32", "i64", "u64", "f32", "f64", "i128", "u128", "i256", "u256", ] { let sql = format!("select * from t where {ty} = 127"); - let result = parse_and_type_sub(&sql, &tx); + let result = parse_and_type_sub(&sql, &mut tx); assert!(result.is_ok(), "Failed to parse {ty}: {}", result.unwrap_err()); } } #[test] fn invalid_literals() { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); struct TestCase { sql: &'static str, @@ -547,14 +591,14 @@ mod tests { msg: "Float as integer", }, ] { - let result = parse_and_type_sub(sql, &tx); + let result = parse_and_type_sub(sql, &mut tx); assert!(result.is_err(), "{msg}"); } } #[test] fn valid() { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); struct TestCase { sql: &'static str, @@ -611,14 +655,14 @@ mod tests { msg: "Type inner join + projection", }, ] { - let result = parse_and_type_sub(sql, &tx); + let result = parse_and_type_sub(sql, &mut tx); assert!(result.is_ok(), "{msg}"); } } #[test] fn invalid() { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); struct TestCase { sql: &'static str, @@ -683,7 +727,7 @@ mod tests { msg: "Columns must be qualified in join expressions", }, ] { - let result = parse_and_type_sub(sql, &tx); + let result = parse_and_type_sub(sql, &mut tx); assert!(result.is_err(), "{msg}"); } } diff --git a/crates/expr/src/errors.rs b/crates/expr/src/errors.rs index 4a91980fde..3129ab476f 100644 --- a/crates/expr/src/errors.rs +++ b/crates/expr/src/errors.rs @@ -172,6 +172,8 @@ pub enum TypingError { FilterReturnType(#[from] FilterReturnType), #[error(transparent)] TableFunc(#[from] TableFunc), + #[error("InternalError: Read-only queries cannot create parameters")] + ParamsReadOnly, #[error(transparent)] Other(#[from] anyhow::Error), } diff --git a/crates/expr/src/expr.rs b/crates/expr/src/expr.rs index f6556f3456..5ae8484c67 100644 --- a/crates/expr/src/expr.rs +++ b/crates/expr/src/expr.rs @@ -238,35 +238,6 @@ impl ProjectList { Self::Agg(_, _, name, ty) => f(name, ty), } } - - /// Iterate over the function calls in this projection list - pub fn for_each_fun_call( - &mut self, - f: &mut impl FnMut(ProductValue) -> Result, - ) -> Result<(), TypingError> { - match self { - ProjectList::Name(input) => { - for proj in input { - match proj { - ProjectName::None(expr) | ProjectName::Some(expr, _) => { - expr.for_each_fun_call(f)?; - } - } - } - } - ProjectList::List(input, _) => { - for expr in input { - expr.for_each_fun_call(f)?; - } - } - ProjectList::Limit(input, _) => { - input.for_each_fun_call(f)?; - } - ProjectList::Agg(_, _, _, _) => {} - } - - Ok(()) - } } /// A logical relational expression @@ -398,31 +369,6 @@ impl RelExpr { _ => None, } } - - fn for_each_fun_call( - &mut self, - f: &mut impl FnMut(ProductValue) -> Result, - ) -> Result<(), TypingError> { - // For function calls, we need to filter by the argument id - if let RelExpr::FunCall(relvar, param) = self { - let new_arg_id = f(param.clone())?; - let arg_id_col = relvar.schema.inner().get_column_by_name("arg_id").unwrap().col_pos; - - *self = RelExpr::Select( - Box::new(RelExpr::RelVar(relvar.clone())), - Expr::BinOp( - BinOp::Eq, - Box::new(Expr::Field(FieldProject { - table: relvar.alias.clone(), - field: arg_id_col.idx(), - ty: AlgebraicType::U64, - })), - Box::new(Expr::Value(AlgebraicValue::U64(new_arg_id.0), AlgebraicType::U64)), - ), - ); - } - Ok(()) - } } /// A left deep binary cross product diff --git a/crates/expr/src/rls.rs b/crates/expr/src/rls.rs index 89cd9e5a2a..5a1cc312dc 100644 --- a/crates/expr/src/rls.rs +++ b/crates/expr/src/rls.rs @@ -12,7 +12,7 @@ use crate::{ /// The main driver of RLS resolution for subscription queries. /// Mainly a wrapper around [resolve_views_for_expr]. pub fn resolve_views_for_sub( - tx: &impl SchemaView, + tx: &mut impl SchemaView, expr: ProjectName, auth: &AuthCtx, has_param: &mut bool, @@ -54,7 +54,11 @@ pub fn resolve_views_for_sub( /// The main driver of RLS resolution for sql queries. /// Mainly a wrapper around [resolve_views_for_expr]. -pub fn resolve_views_for_sql(tx: &impl SchemaView, expr: ProjectList, auth: &AuthCtx) -> anyhow::Result { +pub fn resolve_views_for_sql( + tx: &mut impl SchemaView, + expr: ProjectList, + auth: &AuthCtx, +) -> anyhow::Result { // RLS does not apply to the database owner if auth.bypass_rls() { return Ok(expr); @@ -62,44 +66,41 @@ pub fn resolve_views_for_sql(tx: &impl SchemaView, expr: ProjectList, auth: &Aut // The subscription language is a subset of the sql language. // Use the subscription helper if this is a compliant expression. // Use the generic resolver otherwise. - let resolve_for_sub = |expr| resolve_views_for_sub(tx, expr, auth, &mut false); - let resolve_for_sql = |expr| { - resolve_views_for_expr( - // Use all default values - tx, - expr, - None, - Rc::new(ResolveList::None), - &mut false, - &mut 0, - auth, - ) - }; match expr { - ProjectList::Limit(expr, n) => Ok(ProjectList::Limit(Box::new(resolve_views_for_sql(tx, *expr, auth)?), n)), + ProjectList::Limit(expr, n) => { + let expr = resolve_views_for_sql(tx, *expr, auth)?; + Ok(ProjectList::Limit(Box::new(expr), n)) + } + ProjectList::Name(exprs) => Ok(ProjectList::Name( exprs .into_iter() - .map(resolve_for_sub) + .map(|expr| resolve_views_for_sub(tx, expr, auth, &mut false)) .collect::, _>>()? .into_iter() .flatten() .collect(), )), + ProjectList::List(exprs, fields) => Ok(ProjectList::List( exprs .into_iter() - .map(resolve_for_sql) + .map(|expr| { + resolve_views_for_expr(tx, expr, None, Rc::new(ResolveList::None), &mut false, &mut 0, auth) + }) .collect::, _>>()? .into_iter() .flatten() .collect(), fields, )), + ProjectList::Agg(exprs, AggType::Count, name, ty) => Ok(ProjectList::Agg( exprs .into_iter() - .map(resolve_for_sql) + .map(|expr| { + resolve_views_for_expr(tx, expr, None, Rc::new(ResolveList::None), &mut false, &mut 0, auth) + }) .collect::, _>>()? .into_iter() .flatten() @@ -203,7 +204,7 @@ impl ResolveList { /// i.e. the subtree rooted at `a` in the above example, /// must be pushed below the leftmost leaf node of the view expansion. fn resolve_views_for_expr( - tx: &impl SchemaView, + tx: &mut impl SchemaView, view: RelExpr, return_table_id: Option, resolving: Rc, @@ -473,21 +474,23 @@ mod tests { use pretty_assertions as pretty; use spacetimedb_lib::{identity::AuthCtx, AlgebraicType, AlgebraicValue, Identity, ProductType}; - use spacetimedb_primitives::TableId; + use spacetimedb_primitives::{ArgId, TableId}; + use spacetimedb_sats::ProductValue; use spacetimedb_schema::{ def::ModuleDef, schema::{Schema, TableOrViewSchema, TableSchema}, }; use spacetimedb_sql_parser::ast::BinOp; + use super::resolve_views_for_sub; + use crate::check::test_utils::MockCallParams; + use crate::check::TypingResult; use crate::{ check::{parse_and_type_sub, test_utils::build_module_def, SchemaView}, expr::{Expr, FieldProject, LeftDeepJoin, ProjectName, RelExpr, Relvar}, }; - use super::resolve_views_for_sub; - - pub struct SchemaViewer(pub ModuleDef); + pub struct SchemaViewer(pub ModuleDef, pub MockCallParams); impl SchemaView for SchemaViewer { fn table_id(&self, name: &str) -> Option { @@ -526,6 +529,10 @@ mod tests { _ => Ok(vec![]), } } + + fn get_or_create_params(&mut self, params: ProductValue) -> TypingResult { + Ok(self.1.get_or_insert(params)) + } } fn module_def() -> ModuleDef { @@ -549,17 +556,17 @@ mod tests { } /// Parse, type check, and resolve RLS rules - fn resolve(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> anyhow::Result> { + fn resolve(sql: &str, tx: &mut impl SchemaView, auth: &AuthCtx) -> anyhow::Result> { let (expr, _) = parse_and_type_sub(sql, tx, auth)?; resolve_views_for_sub(tx, expr, auth, &mut false) } #[test] fn test_rls_for_owner() -> anyhow::Result<()> { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); let auth = AuthCtx::new(Identity::ONE, Identity::ONE); let sql = "select * from users"; - let resolved = resolve(sql, &tx, &auth)?; + let resolved = resolve(sql, &mut tx, &auth)?; let users_schema = tx.schema("users").unwrap(); @@ -577,10 +584,10 @@ mod tests { #[test] fn test_rls_for_non_owner() -> anyhow::Result<()> { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); let auth = AuthCtx::new(Identity::ZERO, Identity::ONE); let sql = "select * from users"; - let resolved = resolve(sql, &tx, &auth)?; + let resolved = resolve(sql, &mut tx, &auth)?; let users_schema = tx.schema("users").unwrap(); @@ -612,10 +619,10 @@ mod tests { #[test] fn test_multiple_rls_rules_for_table() -> anyhow::Result<()> { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); let auth = AuthCtx::new(Identity::ZERO, Identity::ONE); let sql = "select * from player where level_num = 5"; - let resolved = resolve(sql, &tx, &auth)?; + let resolved = resolve(sql, &mut tx, &auth)?; let users_schema = tx.schema("users").unwrap(); let admins_schema = tx.schema("admins").unwrap(); diff --git a/crates/expr/src/statement.rs b/crates/expr/src/statement.rs index 1fc882cf5c..317cd72a7a 100644 --- a/crates/expr/src/statement.rs +++ b/crates/expr/src/statement.rs @@ -394,11 +394,11 @@ impl TypeChecker for SqlChecker { type Ast = SqlSelect; type Set = SqlSelect; - fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult { + fn type_ast(ast: Self::Ast, tx: &mut impl SchemaView) -> TypingResult { Self::type_set(ast, &mut Relvars::default(), tx) } - fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult { + fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &mut impl SchemaView) -> TypingResult { match ast { SqlSelect { project, @@ -439,7 +439,7 @@ impl TypeChecker for SqlChecker { } } -pub fn parse_and_type_sql(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> TypingResult { +pub fn parse_and_type_sql(sql: &str, tx: &mut impl SchemaView, auth: &AuthCtx) -> TypingResult { match parse_sql(sql)?.resolve_sender(auth.caller()) { SqlAst::Select(ast) => Ok(Statement::Select(SqlChecker::type_ast(ast, tx)?)), SqlAst::Insert(insert) => Ok(Statement::DML(DML::Insert(type_insert(insert, tx)?))), @@ -451,7 +451,7 @@ pub fn parse_and_type_sql(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> Ty } /// Parse and type check a *general* query into a [StatementCtx]. -pub fn compile_sql_stmt<'a>(sql: &'a str, tx: &impl SchemaView, auth: &AuthCtx) -> TypingResult> { +pub fn compile_sql_stmt<'a>(sql: &'a str, tx: &mut impl SchemaView, auth: &AuthCtx) -> TypingResult> { let statement = parse_and_type_sql(sql, tx, auth)?; Ok(StatementCtx { statement, @@ -520,13 +520,13 @@ mod tests { } /// A wrapper around [super::parse_and_type_sql] that takes a dummy [AuthCtx] - fn parse_and_type_sql(sql: &str, tx: &impl SchemaView) -> TypingResult { + fn parse_and_type_sql(sql: &str, tx: &mut impl SchemaView) -> TypingResult { super::parse_and_type_sql(sql, tx, &AuthCtx::for_testing()) } #[test] fn valid() { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); for sql in [ "select str from t", @@ -534,14 +534,14 @@ mod tests { "select t.str, arr from t", "select * from t limit 5", ] { - let result = parse_and_type_sql(sql, &tx); + let result = parse_and_type_sql(sql, &mut tx); assert!(result.is_ok()); } } #[test] fn invalid() { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); for sql in [ // Unqualified columns in a join @@ -551,7 +551,7 @@ mod tests { // Unqualified name in join expression "select t.* from t join s on t.u32 = s.u32 where bytes = 0xABCD", ] { - let result = parse_and_type_sql(sql, &tx); + let result = parse_and_type_sql(sql, &mut tx); assert!(result.is_err()); } } @@ -581,7 +581,7 @@ mod tests { #[test] fn views() { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); struct TestCase { sql: &'static str, @@ -606,7 +606,7 @@ mod tests { msg: "Function call returning view with parameters", }, ] { - let result = parse_and_type_sql(sql, &tx).inspect_err(|e| { + let result = parse_and_type_sql(sql, &mut tx).inspect_err(|e| { panic!("Expected OK for `{sql}` but got error: {e}"); }); assert!(result.is_ok(), "{msg}: {sql}"); @@ -630,14 +630,14 @@ mod tests { msg: "`v` does not take parameters", }, ] { - let result = parse_and_type_sql(sql, &tx); + let result = parse_and_type_sql(sql, &mut tx); assert!(result.is_err(), "{msg}"); } } #[test] fn params_validation() { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); struct TestCase { sql: &'static str, @@ -670,7 +670,7 @@ mod tests { msg: "Unexpected function type. Expected: (U32, String) != Inferred: (U32, String, Num?)", }, ] { - let result = parse_and_type_sql(sql, &tx); + let result = parse_and_type_sql(sql, &mut tx); if msg == "Correct parameters" { assert!(result.is_ok(), "{msg}: {sql}"); } else if let Err(err) = &result { diff --git a/crates/physical-plan/src/plan.rs b/crates/physical-plan/src/plan.rs index 597ab6dd6e..20b17d1e64 100644 --- a/crates/physical-plan/src/plan.rs +++ b/crates/physical-plan/src/plan.rs @@ -1402,6 +1402,7 @@ mod tests { use std::sync::Arc; use pretty_assertions::assert_eq; + use spacetimedb_expr::check::test_utils::MockCallParams; use spacetimedb_expr::{ check::{SchemaView, TypingResult}, expr::ProjectName, @@ -1410,9 +1411,9 @@ mod tests { use spacetimedb_lib::{ db::auth::{StAccess, StTableType}, identity::AuthCtx, - AlgebraicType, AlgebraicValue, + AlgebraicType, AlgebraicValue, ProductValue, }; - use spacetimedb_primitives::{ColId, ColList, ColSet, TableId, ViewId}; + use spacetimedb_primitives::{ArgId, ColId, ColList, ColSet, TableId, ViewId}; use spacetimedb_schema::def::ViewParamDefSimple; use spacetimedb_schema::identifier::Identifier; use spacetimedb_schema::schema::ViewDefInfo; @@ -1431,6 +1432,7 @@ mod tests { struct SchemaViewer { schemas: Vec>, + params: MockCallParams, } impl SchemaView for SchemaViewer { @@ -1448,6 +1450,10 @@ mod tests { fn rls_rules_for_table(&self, _: TableId) -> anyhow::Result>> { Ok(vec![]) } + + fn get_or_create_params(&mut self, params: ProductValue) -> TypingResult { + Ok(self.params.get_or_insert(params)) + } } fn schema_with_params( @@ -1534,7 +1540,7 @@ mod tests { } /// A wrapper around [spacetimedb_expr::check::parse_and_type_sub] that takes a dummy [AuthCtx] - fn parse_and_type_sub(sql: &str, tx: &impl SchemaView) -> TypingResult { + fn parse_and_type_sub(sql: &str, tx: &mut impl SchemaView) -> TypingResult { spacetimedb_expr::check::parse_and_type_sub(sql, tx, &AuthCtx::for_testing()).map(|(plan, _)| plan) } @@ -1552,14 +1558,15 @@ mod tests { Some(0), )); - let db = SchemaViewer { + let mut db = SchemaViewer { schemas: vec![t.clone()], + params: Default::default(), }; let sql = "select * from t"; let auth = AuthCtx::for_testing(); - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); match pp { @@ -1584,14 +1591,15 @@ mod tests { Some(0), )); - let db = SchemaViewer { + let mut db = SchemaViewer { schemas: vec![t.clone()], + params: Default::default(), }; let sql = "select * from t where x = 5"; let auth = AuthCtx::for_testing(); - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); match pp { @@ -1678,8 +1686,9 @@ mod tests { Some(0), )); - let db = SchemaViewer { + let mut db = SchemaViewer { schemas: vec![u.clone(), l.clone(), b.clone()], + params: Default::default(), }; let sql = " @@ -1691,7 +1700,7 @@ mod tests { where u.identity = 5 "; let auth = AuthCtx::for_testing(); - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); // Plan: @@ -1870,8 +1879,9 @@ mod tests { Some(0), )); - let db = SchemaViewer { + let mut db = SchemaViewer { schemas: vec![m.clone(), w.clone(), p.clone()], + params: Default::default(), }; let sql = " @@ -1884,7 +1894,7 @@ mod tests { where 5 = m.employee and 5 = v.employee "; let auth = AuthCtx::for_testing(); - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); // Plan: @@ -2052,13 +2062,14 @@ mod tests { None, )); - let db = SchemaViewer { + let mut db = SchemaViewer { schemas: vec![t.clone()], + params: Default::default(), }; let sql = "select * from t where x = 3 and y = 4 and z = 5"; let auth = AuthCtx::for_testing(); - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); // Select index on (x, y, z) @@ -2081,7 +2092,7 @@ mod tests { // Test permutations of the same query let sql = "select * from t where z = 5 and y = 4 and x = 3"; - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); match pp { @@ -2102,7 +2113,7 @@ mod tests { }; let sql = "select * from t where x = 3 and y = 4"; - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); // Select index on x @@ -2130,7 +2141,7 @@ mod tests { }; let sql = "select * from t where w = 5 and x = 4"; - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); // Select index on x @@ -2158,7 +2169,7 @@ mod tests { }; let sql = "select * from t where y = 1"; - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); // Do not select index on (y, z) @@ -2173,7 +2184,7 @@ mod tests { // Select index on [y, z] let sql = "select * from t where y = 1 and z = 2"; - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); match pp { @@ -2192,7 +2203,7 @@ mod tests { // Check permutations of the same query let sql = "select * from t where z = 2 and y = 1"; - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); match pp { @@ -2211,7 +2222,7 @@ mod tests { // Select index on (y, z) and filter on (w) let sql = "select * from t where w = 1 and y = 2 and z = 3"; - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); let plan = match pp { @@ -2251,12 +2262,13 @@ mod tests { None, )); - let db = SchemaViewer { + let mut db = SchemaViewer { schemas: vec![t.clone()], + params: Default::default(), }; - let compile = |sql| { - let stmt = parse_and_type_sql(sql, &db, &AuthCtx::for_testing()).unwrap(); + let mut compile = |sql| { + let stmt = parse_and_type_sql(sql, &mut db, &AuthCtx::for_testing()).unwrap(); let Statement::Select(select) = stmt else { unreachable!() }; @@ -2322,20 +2334,20 @@ mod tests { Some(&[("param_id", AlgebraicType::U64)]), )); - let db = SchemaViewer { + let mut db = SchemaViewer { schemas: vec![t.clone(), v.clone()], + params: Default::default(), }; let sql = "select * from v(0)"; let auth = AuthCtx::for_testing(); - let lp = parse_and_type_sub(sql, &db).unwrap(); - dbg!(&lp); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); - dbg!(&pp); match pp { ProjectPlan::None(PhysicalPlan::Filter(input, PhysicalExpr::BinOp(BinOp::Eq, field, value))) => { - assert!(matches!(*field, PhysicalExpr::Field(TupleField { field_pos: 1, .. }))); + // This is the internal parameter filter + assert!(matches!(*field, PhysicalExpr::Field(TupleField { field_pos: 0, .. }))); assert!(matches!(*value, PhysicalExpr::Value(AlgebraicValue::U64(0)))); match *input { @@ -2349,12 +2361,12 @@ mod tests { }; let sql = "select * from v(0) as x JOIN t ON x.id = t.id"; - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); match pp { ProjectPlan::None(PhysicalPlan::Filter(_, PhysicalExpr::BinOp(BinOp::Eq, field, value))) => { - assert!(matches!(*field, PhysicalExpr::Field(TupleField { field_pos: 1, .. }))); + assert!(matches!(*field, PhysicalExpr::Field(TupleField { field_pos: 0, .. }))); assert!(matches!(*value, PhysicalExpr::Value(AlgebraicValue::U64(0)))); } proj => panic!("unexpected project: {proj:#?}"), diff --git a/crates/query/src/lib.rs b/crates/query/src/lib.rs index 96efff60dd..7647efc744 100644 --- a/crates/query/src/lib.rs +++ b/crates/query/src/lib.rs @@ -4,8 +4,6 @@ use spacetimedb_execution::{ pipelined::ProjectListExecutor, Datastore, DeltaStore, }; -use spacetimedb_expr::errors::TypingError; -use spacetimedb_expr::expr::CallParams; use spacetimedb_expr::{ check::{parse_and_type_sub, SchemaView}, expr::ProjectList, @@ -17,30 +15,15 @@ use spacetimedb_physical_plan::{ compile::{compile_dml_plan, compile_select, compile_select_list}, plan::{ProjectListPlan, ProjectPlan}, }; -use spacetimedb_primitives::{ArgId, TableId}; -use std::collections::HashMap; +use spacetimedb_primitives::TableId; /// DIRTY HACK ALERT: Maximum allowed length, in UTF-8 bytes, of SQL queries. /// Any query longer than this will be rejected. /// This prevents a stack overflow when compiling queries with deeply-nested `AND` and `OR` conditions. const MAX_SQL_LENGTH: usize = 50_000; - -pub trait CallParamsExt { - fn get_arg(&self, params: &ProductValue) -> Result; -} - -pub struct MockCallParams { - params: HashMap, -} -impl CallParamsExt for MockCallParams { - fn get_arg(&self, params: &ProductValue) -> Result { - todo!() - } -} - pub fn compile_subscription( sql: &str, - tx: &impl SchemaView, + tx: &mut impl SchemaView, auth: &AuthCtx, ) -> Result<(Vec, TableId, Box, bool)> { if sql.len() > MAX_SQL_LENGTH { @@ -72,7 +55,7 @@ pub fn compile_subscription( } /// A utility for parsing and type checking a sql statement -pub fn compile_sql_stmt(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> Result { +pub fn compile_sql_stmt(sql: &str, tx: &mut impl SchemaView, auth: &AuthCtx) -> Result { if sql.len() > MAX_SQL_LENGTH { bail!("SQL query exceeds maximum allowed length: \"{sql:.120}...\"") } diff --git a/crates/subscription/src/lib.rs b/crates/subscription/src/lib.rs index d0f4668459..4b53147046 100644 --- a/crates/subscription/src/lib.rs +++ b/crates/subscription/src/lib.rs @@ -508,7 +508,7 @@ impl SubscriptionPlan { } /// Generate a plan for incrementally maintaining a subscription - pub fn compile(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> Result<(Vec, bool)> { + pub fn compile(sql: &str, tx: &mut impl SchemaView, auth: &AuthCtx) -> Result<(Vec, bool)> { let (plans, return_id, return_name, has_param) = compile_subscription(sql, tx, auth)?; /// Does this plan have any non-index joins?