mirror of
https://github.com/clockworklabs/SpacetimeDB.git
synced 2026-05-14 19:58:24 -04:00
Making tx mut so we can auto-create params on compilation
This commit is contained in:
@@ -126,10 +126,10 @@ fn eval(c: &mut Criterion) {
|
||||
// A benchmark runner for the new query engine
|
||||
let bench_query = |c: &mut Criterion, name, sql| {
|
||||
c.bench_function(name, |b| {
|
||||
let tx = raw.db.begin_tx(Workload::Subscribe);
|
||||
let mut tx = raw.db.begin_tx(Workload::Subscribe);
|
||||
let auth = AuthCtx::for_testing();
|
||||
let schema_viewer = &SchemaViewer::new(&tx, &auth);
|
||||
let (plans, table_id, table_name, _) = compile_subscription(sql, schema_viewer, &auth).unwrap();
|
||||
let mut schema_viewer = SchemaViewer::new(&mut tx, &auth);
|
||||
let (plans, table_id, table_name, _) = compile_subscription(sql, &mut schema_viewer, &auth).unwrap();
|
||||
let plans = plans
|
||||
.into_iter()
|
||||
.map(|plan| plan.optimize(&auth).unwrap())
|
||||
@@ -155,8 +155,8 @@ fn eval(c: &mut Criterion) {
|
||||
|
||||
let bench_eval = |c: &mut Criterion, name, sql| {
|
||||
c.bench_function(name, |b| {
|
||||
let tx = raw.db.begin_tx(Workload::Update);
|
||||
let query = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), &tx, sql).unwrap();
|
||||
let mut tx = raw.db.begin_tx(Workload::Update);
|
||||
let query = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), &mut tx, sql).unwrap();
|
||||
let query: ExecutionSet = query.into();
|
||||
|
||||
b.iter(|| {
|
||||
@@ -207,11 +207,11 @@ fn eval(c: &mut Criterion) {
|
||||
// A passthru executed independently of the database.
|
||||
let select_lhs = "select * from footprint";
|
||||
let select_rhs = "select * from location";
|
||||
let tx = &raw.db.begin_tx(Workload::Update);
|
||||
let query_lhs = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), tx, select_lhs).unwrap();
|
||||
let query_rhs = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), tx, select_rhs).unwrap();
|
||||
let mut tx = raw.db.begin_tx(Workload::Update);
|
||||
let query_lhs = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), &mut tx, select_lhs).unwrap();
|
||||
let query_rhs = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), &mut tx, select_rhs).unwrap();
|
||||
let query = ExecutionSet::from_iter(query_lhs.into_iter().chain(query_rhs));
|
||||
let tx = &tx.into();
|
||||
let tx = &(&mut tx).into();
|
||||
|
||||
b.iter(|| drop(black_box(query.eval_incr_for_test(&raw.db, tx, &update, None))))
|
||||
});
|
||||
@@ -226,10 +226,10 @@ fn eval(c: &mut Criterion) {
|
||||
from footprint join location on footprint.entity_id = location.entity_id \
|
||||
where location.chunk_index = {chunk_index}"
|
||||
);
|
||||
let tx = &raw.db.begin_tx(Workload::Update);
|
||||
let query = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), tx, &join).unwrap();
|
||||
let mut tx = raw.db.begin_tx(Workload::Update);
|
||||
let query = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), &mut tx, &join).unwrap();
|
||||
let query: ExecutionSet = query.into();
|
||||
let tx = &tx.into();
|
||||
let tx = &(&mut tx).into();
|
||||
|
||||
b.iter(|| drop(black_box(query.eval_incr_for_test(&raw.db, tx, &update, None))));
|
||||
});
|
||||
|
||||
@@ -181,8 +181,8 @@ mod tests {
|
||||
}
|
||||
|
||||
fn num_rows_for(db: &RelationalDB, sql: &str) -> u64 {
|
||||
let tx = begin_tx(db);
|
||||
match &*compile_sql(db, &AuthCtx::for_testing(), &tx, sql).expect("Failed to compile sql") {
|
||||
let mut tx = begin_tx(db);
|
||||
match &*compile_sql(db, &AuthCtx::for_testing(), &mut tx, sql).expect("Failed to compile sql") {
|
||||
[CrudExpr::Query(expr)] => num_rows(&tx, expr),
|
||||
exprs => panic!("unexpected result from compilation: {exprs:#?}"),
|
||||
}
|
||||
@@ -191,10 +191,10 @@ mod tests {
|
||||
/// Using the new query plan
|
||||
fn new_row_estimate(db: &RelationalDB, sql: &str) -> u64 {
|
||||
let auth = AuthCtx::for_testing();
|
||||
let tx = begin_tx(db);
|
||||
let tx = SchemaViewer::new(&tx, &auth);
|
||||
let mut tx = begin_tx(db);
|
||||
let mut tx = SchemaViewer::new(&mut tx, &auth);
|
||||
|
||||
compile_subscription(sql, &tx, &auth)
|
||||
compile_subscription(sql, &mut tx, &auth)
|
||||
.map(|(plans, ..)| plans)
|
||||
.expect("failed to compile sql query")
|
||||
.into_iter()
|
||||
|
||||
@@ -1869,7 +1869,7 @@ impl ModuleHost {
|
||||
let metrics = self
|
||||
.on_module_thread("one_off_query", move || {
|
||||
let (tx_offset_sender, tx_offset_receiver) = oneshot::channel();
|
||||
let tx = scopeguard::guard(db.begin_tx(Workload::Sql), |tx| {
|
||||
let mut tx = scopeguard::guard(db.begin_tx(Workload::Sql), |tx| {
|
||||
let (tx_offset, tx_metrics, reducer) = db.release_tx(tx);
|
||||
let _ = tx_offset_sender.send(tx_offset);
|
||||
db.report_read_tx_metrics(reducer, tx_metrics);
|
||||
@@ -1878,7 +1878,7 @@ impl ModuleHost {
|
||||
// We wrap the actual query in a closure so we can use ? to handle errors without making
|
||||
// the entire transaction abort with an error.
|
||||
let result: Result<(OneOffTable<F>, ExecutionMetrics), anyhow::Error> = (|| {
|
||||
let tx = SchemaViewer::new(&*tx, &auth);
|
||||
let mut tx = SchemaViewer::new(&mut *tx, &auth);
|
||||
|
||||
let (
|
||||
// A query may compile down to several plans.
|
||||
@@ -1888,7 +1888,7 @@ impl ModuleHost {
|
||||
_,
|
||||
table_name,
|
||||
_,
|
||||
) = compile_subscription(&query, &tx, &auth)?;
|
||||
) = compile_subscription(&query, &mut tx, &auth)?;
|
||||
|
||||
// Optimize each fragment
|
||||
let optimized = plans
|
||||
|
||||
@@ -1126,10 +1126,10 @@ impl InstanceCommon {
|
||||
|
||||
// Views bypass RLS, since views should enforce their own access control procedurally.
|
||||
let auth = AuthCtx::for_current(self.info.database_identity);
|
||||
let schema_view = SchemaViewer::new(&*tx, &auth);
|
||||
let mut schema_view = SchemaViewer::new(&mut *tx, &auth);
|
||||
|
||||
// Compile to subscription plans.
|
||||
let (plans, has_params) = SubscriptionPlan::compile(the_query, &schema_view, &auth)?;
|
||||
let (plans, has_params) = SubscriptionPlan::compile(the_query, &mut schema_view, &auth)?;
|
||||
ensure!(
|
||||
!has_params,
|
||||
"parameterized SQL is not supported for view materialization yet"
|
||||
|
||||
@@ -5,6 +5,7 @@ use spacetimedb_data_structures::map::{HashCollectionExt as _, IntMap};
|
||||
use spacetimedb_datastore::locking_tx_datastore::state_view::StateView;
|
||||
use spacetimedb_datastore::system_tables::{StRowLevelSecurityFields, ST_ROW_LEVEL_SECURITY_ID};
|
||||
use spacetimedb_expr::check::{SchemaView, TypingResult};
|
||||
use spacetimedb_expr::errors::TypingError;
|
||||
use spacetimedb_expr::statement::compile_sql_stmt;
|
||||
use spacetimedb_lib::identity::AuthCtx;
|
||||
use spacetimedb_primitives::{ArgId, ColId, TableId};
|
||||
@@ -22,7 +23,7 @@ use sqlparser::ast::{
|
||||
};
|
||||
use sqlparser::dialect::PostgreSqlDialect;
|
||||
use sqlparser::parser::Parser;
|
||||
use std::ops::Deref;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Simplify to detect features of the syntax we don't support yet
|
||||
@@ -477,7 +478,7 @@ fn compile_where(table: &From, filter: Option<SqlExpr>) -> Result<Option<Selecti
|
||||
}
|
||||
|
||||
pub struct SchemaViewer<'a, T> {
|
||||
pub(crate) tx: &'a T,
|
||||
tx: &'a mut T,
|
||||
auth: &'a AuthCtx,
|
||||
}
|
||||
|
||||
@@ -489,6 +490,12 @@ impl<T> Deref for SchemaViewer<'_, T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> DerefMut for SchemaViewer<'_, T> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
self.tx
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: StateView> SchemaView for SchemaViewer<'_, T> {
|
||||
fn table_id(&self, name: &str) -> Option<TableId> {
|
||||
// Get the schema from the in-memory state instead of fetching from the database for speed
|
||||
@@ -536,10 +543,15 @@ impl<T: StateView> SchemaView for SchemaViewer<'_, T> {
|
||||
})
|
||||
.collect::<anyhow::Result<_>>()
|
||||
}
|
||||
|
||||
fn get_or_create_params(&mut self, _params: ProductValue) -> TypingResult<ArgId> {
|
||||
// Caller should have used `SchemaViewerMut` on crate `core`
|
||||
Err(TypingError::ParamsReadOnly)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> SchemaViewer<'a, T> {
|
||||
pub fn new(tx: &'a T, auth: &'a AuthCtx) -> Self {
|
||||
pub fn new(tx: &'a mut T, auth: &'a AuthCtx) -> Self {
|
||||
Self { tx, auth }
|
||||
}
|
||||
}
|
||||
@@ -1000,13 +1012,12 @@ fn compile_statement<T: TableSchemaView + StateView>(
|
||||
pub(crate) fn compile_to_ast<T: TableSchemaView + StateView>(
|
||||
db: &RelationalDB,
|
||||
auth: &AuthCtx,
|
||||
tx: &T,
|
||||
tx: &mut T,
|
||||
sql_text: &str,
|
||||
) -> Result<Vec<SqlAst>, DBError> {
|
||||
// NOTE: The following ensures compliance with the 1.0 sql api.
|
||||
// Come 1.0, it will have replaced the current compilation stack.
|
||||
compile_sql_stmt(sql_text, &SchemaViewer::new(tx, auth), auth)?;
|
||||
|
||||
compile_sql_stmt(sql_text, &mut SchemaViewer::new(tx, auth), auth)?;
|
||||
let dialect = PostgreSqlDialect {};
|
||||
let ast = Parser::parse_sql(&dialect, sql_text).map_err(|error| DBError::SqlParser {
|
||||
sql: sql_text.to_string(),
|
||||
|
||||
@@ -23,7 +23,7 @@ const MAX_SQL_LENGTH: usize = 50_000;
|
||||
pub fn compile_sql<T: TableSchemaView + StateView>(
|
||||
db: &RelationalDB,
|
||||
auth: &AuthCtx,
|
||||
tx: &T,
|
||||
tx: &mut T,
|
||||
sql_text: &str,
|
||||
) -> Result<Vec<CrudExpr>, DBError> {
|
||||
if sql_text.len() > MAX_SQL_LENGTH {
|
||||
@@ -266,7 +266,7 @@ mod tests {
|
||||
|
||||
fn compile_sql<T: TableSchemaView + StateView>(
|
||||
db: &RelationalDB,
|
||||
tx: &T,
|
||||
tx: &mut T,
|
||||
sql: &str,
|
||||
) -> Result<Vec<CrudExpr>, DBError> {
|
||||
super::compile_sql(db, &AuthCtx::for_testing(), tx, sql)
|
||||
@@ -281,10 +281,10 @@ mod tests {
|
||||
let indexes = &[];
|
||||
db.create_table_for_test("test", schema, indexes)?;
|
||||
|
||||
let tx = begin_tx(&db);
|
||||
let mut tx = begin_tx(&db);
|
||||
// Compile query
|
||||
let sql = "select * from test where a = 1";
|
||||
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
|
||||
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else {
|
||||
panic!("Expected QueryExpr");
|
||||
};
|
||||
assert_eq!(1, query.len());
|
||||
@@ -303,10 +303,10 @@ mod tests {
|
||||
&[1.into(), 0.into()],
|
||||
)?;
|
||||
|
||||
let tx = begin_tx(&db);
|
||||
let mut tx = begin_tx(&db);
|
||||
// Should work with any qualified field.
|
||||
let sql = "select * from test where a = 1 and b <> 3";
|
||||
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
|
||||
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else {
|
||||
panic!("Expected QueryExpr");
|
||||
};
|
||||
assert_eq!(2, query.len());
|
||||
@@ -324,10 +324,10 @@ mod tests {
|
||||
let indexes = &[0.into()];
|
||||
db.create_table_for_test("test", schema, indexes)?;
|
||||
|
||||
let tx = begin_tx(&db);
|
||||
let mut tx = begin_tx(&db);
|
||||
//Compile query
|
||||
let sql = "select * from test where a = 1";
|
||||
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
|
||||
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else {
|
||||
panic!("Expected QueryExpr");
|
||||
};
|
||||
assert_eq!(1, query.len());
|
||||
@@ -377,11 +377,11 @@ mod tests {
|
||||
|
||||
let rows = run_for_testing(&db, sql)?;
|
||||
|
||||
let tx = begin_tx(&db);
|
||||
let mut tx = begin_tx(&db);
|
||||
let CrudExpr::Query(QueryExpr {
|
||||
source: _,
|
||||
query: mut ops,
|
||||
}) = compile_sql(&db, &tx, sql)?.remove(0)
|
||||
}) = compile_sql(&db, &mut tx, sql)?.remove(0)
|
||||
else {
|
||||
panic!("Expected QueryExpr");
|
||||
};
|
||||
@@ -407,11 +407,11 @@ mod tests {
|
||||
let indexes = &[1.into()];
|
||||
db.create_table_for_test("test", schema, indexes)?;
|
||||
|
||||
let tx = begin_tx(&db);
|
||||
let mut tx = begin_tx(&db);
|
||||
// Note, order does not matter.
|
||||
// The sargable predicate occurs last, but we can still generate an index scan.
|
||||
let sql = "select * from test where a = 1 and b = 2";
|
||||
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
|
||||
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else {
|
||||
panic!("Expected QueryExpr");
|
||||
};
|
||||
assert_eq!(2, query.len());
|
||||
@@ -429,11 +429,11 @@ mod tests {
|
||||
let indexes = &[1.into()];
|
||||
db.create_table_for_test("test", schema, indexes)?;
|
||||
|
||||
let tx = begin_tx(&db);
|
||||
let mut tx = begin_tx(&db);
|
||||
// Note, order does not matter.
|
||||
// The sargable predicate occurs first and we can generate an index scan.
|
||||
let sql = "select * from test where b = 2 and a = 1";
|
||||
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
|
||||
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else {
|
||||
panic!("Expected QueryExpr");
|
||||
};
|
||||
assert_eq!(2, query.len());
|
||||
@@ -455,9 +455,9 @@ mod tests {
|
||||
];
|
||||
db.create_table_for_test_multi_column("test", schema, col_list![0, 1])?;
|
||||
|
||||
let tx = begin_mut_tx(&db);
|
||||
let mut tx = begin_mut_tx(&db);
|
||||
let sql = "select * from test where b = 2 and a = 1";
|
||||
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
|
||||
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else {
|
||||
panic!("Expected QueryExpr");
|
||||
};
|
||||
assert_eq!(1, query.len());
|
||||
@@ -474,10 +474,10 @@ mod tests {
|
||||
let indexes = &[0.into(), 1.into()];
|
||||
db.create_table_for_test("test", schema, indexes)?;
|
||||
|
||||
let tx = begin_tx(&db);
|
||||
let mut tx = begin_tx(&db);
|
||||
// Compile query
|
||||
let sql = "select * from test where a = 1 or b = 2";
|
||||
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
|
||||
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else {
|
||||
panic!("Expected QueryExpr");
|
||||
};
|
||||
assert_eq!(1, query.len());
|
||||
@@ -495,10 +495,10 @@ mod tests {
|
||||
let indexes = &[1.into()];
|
||||
db.create_table_for_test("test", schema, indexes)?;
|
||||
|
||||
let tx = begin_tx(&db);
|
||||
let mut tx = begin_tx(&db);
|
||||
// Compile query
|
||||
let sql = "select * from test where b > 2";
|
||||
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
|
||||
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else {
|
||||
panic!("Expected QueryExpr");
|
||||
};
|
||||
assert_eq!(1, query.len());
|
||||
@@ -516,10 +516,10 @@ mod tests {
|
||||
let indexes = &[1.into()];
|
||||
db.create_table_for_test("test", schema, indexes)?;
|
||||
|
||||
let tx = begin_tx(&db);
|
||||
let mut tx = begin_tx(&db);
|
||||
// Compile query
|
||||
let sql = "select * from test where b > 2 and b < 5";
|
||||
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
|
||||
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else {
|
||||
panic!("Expected QueryExpr");
|
||||
};
|
||||
assert_eq!(1, query.len());
|
||||
@@ -542,11 +542,11 @@ mod tests {
|
||||
let indexes = &[0.into(), 1.into()];
|
||||
db.create_table_for_test("test", schema, indexes)?;
|
||||
|
||||
let tx = begin_tx(&db);
|
||||
let mut tx = begin_tx(&db);
|
||||
// Note, order matters - the equality condition occurs first which
|
||||
// means an index scan will be generated rather than the range condition.
|
||||
let sql = "select * from test where a = 3 and b > 2 and b < 5";
|
||||
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
|
||||
let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else {
|
||||
panic!("Expected QueryExpr");
|
||||
};
|
||||
assert_eq!(2, query.len());
|
||||
@@ -569,10 +569,10 @@ mod tests {
|
||||
let indexes = &[];
|
||||
let rhs_id = db.create_table_for_test("rhs", schema, indexes)?;
|
||||
|
||||
let tx = begin_tx(&db);
|
||||
let mut tx = begin_tx(&db);
|
||||
// Should push sargable equality condition below join
|
||||
let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where lhs.a = 3";
|
||||
let exp = compile_sql(&db, &tx, sql)?.remove(0);
|
||||
let exp = compile_sql(&db, &mut tx, sql)?.remove(0);
|
||||
|
||||
let CrudExpr::Query(QueryExpr {
|
||||
source: source_lhs,
|
||||
@@ -621,10 +621,10 @@ mod tests {
|
||||
let schema = &[("b", AlgebraicType::U64), ("c", AlgebraicType::U64)];
|
||||
let rhs_id = db.create_table_for_test("rhs", schema, &[])?;
|
||||
|
||||
let tx = begin_tx(&db);
|
||||
let mut tx = begin_tx(&db);
|
||||
// Should push equality condition below join
|
||||
let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where lhs.a = 3";
|
||||
let exp = compile_sql(&db, &tx, sql)?.remove(0);
|
||||
let exp = compile_sql(&db, &mut tx, sql)?.remove(0);
|
||||
|
||||
let CrudExpr::Query(QueryExpr {
|
||||
source: source_lhs,
|
||||
@@ -678,10 +678,10 @@ mod tests {
|
||||
let schema = &[("b", AlgebraicType::U64), ("c", AlgebraicType::U64)];
|
||||
let rhs_id = db.create_table_for_test("rhs", schema, &[])?;
|
||||
|
||||
let tx = begin_tx(&db);
|
||||
let mut tx = begin_tx(&db);
|
||||
// Should push equality condition below join
|
||||
let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c = 3";
|
||||
let exp = compile_sql(&db, &tx, sql)?.remove(0);
|
||||
let exp = compile_sql(&db, &mut tx, sql)?.remove(0);
|
||||
|
||||
let CrudExpr::Query(QueryExpr {
|
||||
source: source_lhs,
|
||||
@@ -736,11 +736,11 @@ mod tests {
|
||||
let indexes = &[1.into()];
|
||||
let rhs_id = db.create_table_for_test("rhs", schema, indexes)?;
|
||||
|
||||
let tx = begin_tx(&db);
|
||||
let mut tx = begin_tx(&db);
|
||||
// Should push the sargable equality condition into the join's left arg.
|
||||
// Should push the sargable range condition into the join's right arg.
|
||||
let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where lhs.a = 3 and rhs.c < 4";
|
||||
let exp = compile_sql(&db, &tx, sql)?.remove(0);
|
||||
let exp = compile_sql(&db, &mut tx, sql)?.remove(0);
|
||||
|
||||
let CrudExpr::Query(QueryExpr {
|
||||
source: source_lhs,
|
||||
@@ -807,11 +807,11 @@ mod tests {
|
||||
let indexes = &[0.into(), 1.into()];
|
||||
let rhs_id = db.create_table_for_test("rhs", schema, indexes)?;
|
||||
|
||||
let tx = begin_tx(&db);
|
||||
let mut tx = begin_tx(&db);
|
||||
// Should generate an index join since there is an index on `lhs.b`.
|
||||
// Should push the sargable range condition into the index join's probe side.
|
||||
let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c > 2 and rhs.c < 4 and rhs.d = 3";
|
||||
let exp = compile_sql(&db, &tx, sql)?.remove(0);
|
||||
let exp = compile_sql(&db, &mut tx, sql)?.remove(0);
|
||||
|
||||
let CrudExpr::Query(QueryExpr {
|
||||
source: SourceExpr::DbTable(DbTable { table_id, .. }),
|
||||
@@ -889,11 +889,11 @@ mod tests {
|
||||
let indexes = col_list![0, 1];
|
||||
let rhs_id = db.create_table_for_test_multi_column("rhs", schema, indexes)?;
|
||||
|
||||
let tx = begin_tx(&db);
|
||||
let mut tx = begin_tx(&db);
|
||||
// Should generate an index join since there is an index on `lhs.b`.
|
||||
// Should push the sargable range condition into the index join's probe side.
|
||||
let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c = 2 and rhs.b = 4 and rhs.d = 3";
|
||||
let exp = compile_sql(&db, &tx, sql)?.remove(0);
|
||||
let exp = compile_sql(&db, &mut tx, sql)?.remove(0);
|
||||
|
||||
let CrudExpr::Query(QueryExpr {
|
||||
source: SourceExpr::DbTable(DbTable { table_id, .. }),
|
||||
@@ -953,7 +953,7 @@ mod tests {
|
||||
let db = TestDB::durable()?;
|
||||
db.create_table_for_test("A", &[("x", AlgebraicType::U64)], &[])?;
|
||||
db.create_table_for_test("B", &[("y", AlgebraicType::U64)], &[])?;
|
||||
assert!(compile_sql(&db, &begin_tx(&db), "select B.* from B join A on B.y = A.x").is_ok());
|
||||
assert!(compile_sql(&db, &mut begin_tx(&db), "select B.* from B join A on B.y = A.x").is_ok());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -970,27 +970,27 @@ mod tests {
|
||||
// TODO: Type check other operations deferred for the new query engine.
|
||||
|
||||
assert!(
|
||||
compile_sql(&db, &begin_tx(&db), sql).is_err(),
|
||||
compile_sql(&db, &mut begin_tx(&db), sql).is_err(),
|
||||
// Err("SqlError: Type Mismatch: `PlayerState.entity_id: U64` != `String(\"161853\"): String`, executing: `SELECT * FROM PlayerState WHERE entity_id = '161853'`".into())
|
||||
);
|
||||
|
||||
// Check we can still compile the query if we remove the type mismatch and have multiple logical operations.
|
||||
let sql = "SELECT * FROM PlayerState WHERE entity_id = 1 AND entity_id = 2 AND entity_id = 3 OR entity_id = 4 OR entity_id = 5";
|
||||
|
||||
assert!(compile_sql(&db, &begin_tx(&db), sql).is_ok());
|
||||
assert!(compile_sql(&db, &mut begin_tx(&db), sql).is_ok());
|
||||
|
||||
// Now verify when we have a type mismatch in the middle of the logical operations.
|
||||
let sql = "SELECT * FROM PlayerState WHERE entity_id = 1 AND entity_id";
|
||||
|
||||
assert!(
|
||||
compile_sql(&db, &begin_tx(&db), sql).is_err(),
|
||||
compile_sql(&db, &mut begin_tx(&db), sql).is_err(),
|
||||
// Err("SqlError: Type Mismatch: `PlayerState.entity_id: U64 == U64(1): U64` and `PlayerState.entity_id: U64`, both sides must be an `Bool` expression, executing: `SELECT * FROM PlayerState WHERE entity_id = 1 AND entity_id`".into())
|
||||
);
|
||||
// Verify that all operands of `AND` must be `Bool`.
|
||||
let sql = "SELECT * FROM PlayerState WHERE entity_id AND entity_id";
|
||||
|
||||
assert!(
|
||||
compile_sql(&db, &begin_tx(&db), sql).is_err(),
|
||||
compile_sql(&db, &mut begin_tx(&db), sql).is_err(),
|
||||
// Err("SqlError: Type Mismatch: `PlayerState.entity_id: U64` and `PlayerState.entity_id: U64`, both sides must be an `Bool` expression, executing: `SELECT * FROM PlayerState WHERE entity_id AND entity_id`".into())
|
||||
);
|
||||
Ok(())
|
||||
|
||||
@@ -20,9 +20,8 @@ use anyhow::anyhow;
|
||||
use spacetimedb_datastore::execution_context::Workload;
|
||||
use spacetimedb_datastore::locking_tx_datastore::state_view::StateView;
|
||||
use spacetimedb_datastore::traits::IsolationLevel;
|
||||
use spacetimedb_expr::check::SchemaView;
|
||||
use spacetimedb_expr::check::{SchemaView, TypingResult};
|
||||
use spacetimedb_expr::errors::TypingError;
|
||||
use spacetimedb_expr::expr::CallParams;
|
||||
use spacetimedb_expr::statement::Statement;
|
||||
use spacetimedb_lib::identity::AuthCtx;
|
||||
use spacetimedb_lib::metrics::ExecutionMetrics;
|
||||
@@ -191,15 +190,27 @@ pub struct SqlResult {
|
||||
pub metrics: ExecutionMetrics,
|
||||
}
|
||||
|
||||
struct DbParams<'a> {
|
||||
struct SchemaViewerMut<'a> {
|
||||
db: &'a RelationalDB,
|
||||
tx: &'a mut MutTx,
|
||||
schema: SchemaViewer<'a, MutTx>,
|
||||
}
|
||||
|
||||
impl CallParams for DbParams<'_> {
|
||||
fn create_or_get_param(&mut self, param: &ProductValue) -> Result<ArgId, TypingError> {
|
||||
impl SchemaView for SchemaViewerMut<'_> {
|
||||
fn table_id(&self, name: &str) -> Option<TableId> {
|
||||
self.schema.table_id(name)
|
||||
}
|
||||
|
||||
fn schema_for_table(&self, table_id: TableId) -> Option<Arc<TableOrViewSchema>> {
|
||||
self.schema.schema_for_table(table_id)
|
||||
}
|
||||
|
||||
fn rls_rules_for_table(&self, table_id: TableId) -> anyhow::Result<Vec<Box<str>>> {
|
||||
self.schema.rls_rules_for_table(table_id)
|
||||
}
|
||||
|
||||
fn get_or_create_params(&mut self, params: ProductValue) -> TypingResult<ArgId> {
|
||||
self.db
|
||||
.create_or_get_params(self.tx, ¶m)
|
||||
.create_or_get_params(&mut self.schema, ¶ms)
|
||||
.map_err(|err| TypingError::Other(err.into()))
|
||||
}
|
||||
}
|
||||
@@ -215,20 +226,16 @@ pub async fn run(
|
||||
) -> Result<SqlResult, DBError> {
|
||||
// We parse the sql statement in a mutable transaction.
|
||||
// If it turns out to be a query, we downgrade the tx.
|
||||
let (tx, stmt) =
|
||||
db.with_auto_rollback(
|
||||
db.begin_mut_tx(IsolationLevel::Serializable, Workload::Sql),
|
||||
|tx| match compile_sql_stmt(sql_text, &SchemaViewer::new(tx, &auth), &auth) {
|
||||
Ok(Statement::Select(mut stmt)) => {
|
||||
stmt.for_each_fun_call(&mut |param| {
|
||||
db.create_or_get_params(tx, ¶m)
|
||||
.map_err(|err| TypingError::Other(err.into()))
|
||||
})?;
|
||||
Ok(Statement::Select(stmt))
|
||||
}
|
||||
result => result,
|
||||
let (tx, stmt) = db.with_auto_rollback(db.begin_mut_tx(IsolationLevel::Serializable, Workload::Sql), |tx| {
|
||||
compile_sql_stmt(
|
||||
sql_text,
|
||||
&mut SchemaViewerMut {
|
||||
db,
|
||||
schema: SchemaViewer::new(tx, &auth),
|
||||
},
|
||||
)?;
|
||||
&auth,
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut metrics = ExecutionMetrics::default();
|
||||
|
||||
@@ -1619,6 +1626,10 @@ pub(crate) mod tests {
|
||||
true,
|
||||
)?;
|
||||
let arg_id = ST_RESERVED_SEQUENCE_RANGE as u64;
|
||||
assert_eq!(
|
||||
run_for_testing(&db, "select view_id, param_pos, param_name FROM st_view_param")?,
|
||||
vec![product![arg_id as u32, 0u16, "x"]]
|
||||
);
|
||||
|
||||
with_auto_commit(&db, |tx| -> Result<_, DBError> {
|
||||
tests_utils::insert_into_view(&db, tx, table_id, None, product![arg_id + 1, 0u8, 1i64])?;
|
||||
@@ -1636,6 +1647,12 @@ pub(crate) mod tests {
|
||||
vec![product![0u8, 1i64]]
|
||||
);
|
||||
|
||||
// We have created the internal rows for view args
|
||||
assert_eq!(
|
||||
run_for_testing(&db, "select id FROM st_view_arg")?,
|
||||
vec![product![arg_id], product![arg_id + 1]]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ impl RowLevelExpr {
|
||||
auth_ctx: &AuthCtx,
|
||||
rls: &RawRowLevelSecurityDefV9,
|
||||
) -> anyhow::Result<Self> {
|
||||
let (sql, _) = parse_and_type_sub(&rls.sql, &SchemaViewer::new(tx, auth_ctx), auth_ctx)?;
|
||||
let (sql, _) = parse_and_type_sub(&rls.sql, &mut SchemaViewer::new(tx, auth_ctx), auth_ctx)?;
|
||||
let table_id = sql.return_table_id().unwrap();
|
||||
let schema = tx.schema_for_table(table_id)?;
|
||||
|
||||
|
||||
@@ -472,7 +472,7 @@ impl ModuleSubscriptions {
|
||||
let hash = QueryHash::from_string(&sql, auth.caller(), false);
|
||||
let hash_with_param = QueryHash::from_string(&sql, auth.caller(), true);
|
||||
|
||||
let (mut_tx, _) = self.begin_mut_tx(Workload::Subscribe);
|
||||
let (mut mut_tx, _) = self.begin_mut_tx(Workload::Subscribe);
|
||||
|
||||
let existing_query = {
|
||||
let guard = self.subscriptions.read();
|
||||
@@ -482,7 +482,7 @@ impl ModuleSubscriptions {
|
||||
let query = return_on_err_with_sql!(
|
||||
existing_query.map(Ok).unwrap_or_else(|| compile_query_with_hashes(
|
||||
&auth,
|
||||
&*mut_tx,
|
||||
&mut *mut_tx,
|
||||
&sql,
|
||||
hash,
|
||||
hash_with_param
|
||||
@@ -736,7 +736,7 @@ impl ModuleSubscriptions {
|
||||
}
|
||||
|
||||
// We always get the db lock before the subscription lock to avoid deadlocks.
|
||||
let (mut_tx, _tx_offset) = self.begin_mut_tx(Workload::Subscribe);
|
||||
let (mut mut_tx, _tx_offset) = self.begin_mut_tx(Workload::Subscribe);
|
||||
|
||||
let compile_timer = metrics.compilation_time.start_timer();
|
||||
|
||||
@@ -752,7 +752,7 @@ impl ModuleSubscriptions {
|
||||
super::subscription::get_all(
|
||||
|relational_db, tx| relational_db.get_all_tables_mut(tx).map(|schemas| schemas.into_iter()),
|
||||
&self.relational_db,
|
||||
&*mut_tx,
|
||||
&mut *mut_tx,
|
||||
&auth,
|
||||
)?
|
||||
.into_iter()
|
||||
@@ -769,7 +769,7 @@ impl ModuleSubscriptions {
|
||||
plans.push(unit);
|
||||
} else {
|
||||
plans.push(Arc::new(
|
||||
compile_query_with_hashes(&auth, &*mut_tx, sql, hash, hash_with_param).map_err(|err| {
|
||||
compile_query_with_hashes(&auth, &mut *mut_tx, sql, hash, hash_with_param).map_err(|err| {
|
||||
DBError::WithSql {
|
||||
error: Box::new(DBError::Other(err.into())),
|
||||
sql: sql.into(),
|
||||
@@ -1807,8 +1807,8 @@ mod tests {
|
||||
|
||||
let auth = AuthCtx::for_testing();
|
||||
let sql = "select * from t where id = 1";
|
||||
let tx = begin_tx(&db);
|
||||
let plan = compile_read_only_query(&auth, &tx, sql)?;
|
||||
let mut tx = begin_tx(&db);
|
||||
let plan = compile_read_only_query(&auth, &mut tx, sql)?;
|
||||
let plan = Arc::new(plan);
|
||||
|
||||
let (_, metrics) = subs.evaluate_queries(sender, &[plan], &tx, &auth, TableUpdateType::Subscribe)?;
|
||||
|
||||
@@ -1681,8 +1681,8 @@ mod tests {
|
||||
fn compile_plan(db: &RelationalDB, sql: &str) -> ResultTest<Arc<Plan>> {
|
||||
with_read_only(db, |tx| {
|
||||
let auth = AuthCtx::for_testing();
|
||||
let tx = SchemaViewer::new(&*tx, &auth);
|
||||
let (plans, has_param) = SubscriptionPlan::compile(sql, &tx, &auth).unwrap();
|
||||
let mut tx = SchemaViewer::new(tx, &auth);
|
||||
let (plans, has_param) = SubscriptionPlan::compile(sql, &mut tx, &auth).unwrap();
|
||||
let hash = QueryHash::from_string(sql, auth.caller(), has_param);
|
||||
Ok(Arc::new(Plan::new(plans, hash, sql.into())))
|
||||
})
|
||||
|
||||
@@ -42,7 +42,7 @@ pub fn is_subscribe_to_all_tables(sql: &str) -> bool {
|
||||
pub fn compile_read_only_queryset(
|
||||
relational_db: &RelationalDB,
|
||||
auth: &AuthCtx,
|
||||
tx: &Tx,
|
||||
tx: &mut Tx,
|
||||
input: &str,
|
||||
) -> Result<Vec<SupportedQuery>, DBError> {
|
||||
let input = input.trim();
|
||||
@@ -82,13 +82,13 @@ pub fn compile_read_only_queryset(
|
||||
|
||||
/// Compile a string into a single read-only query.
|
||||
/// This returns an error if the string has multiple queries or mutations.
|
||||
pub fn compile_read_only_query(auth: &AuthCtx, tx: &Tx, input: &str) -> Result<Plan, DBError> {
|
||||
pub fn compile_read_only_query(auth: &AuthCtx, tx: &mut Tx, input: &str) -> Result<Plan, DBError> {
|
||||
if is_whitespace_or_empty(input) {
|
||||
return Err(SubscriptionError::Empty.into());
|
||||
}
|
||||
|
||||
let tx = SchemaViewer::new(tx, auth);
|
||||
let (plans, has_param) = SubscriptionPlan::compile(input, &tx, auth)?;
|
||||
let mut tx = SchemaViewer::new(tx, auth);
|
||||
let (plans, has_param) = SubscriptionPlan::compile(input, &mut tx, auth)?;
|
||||
let hash = QueryHash::from_string(input, auth.caller(), has_param);
|
||||
Ok(Plan::new(plans, hash, input.to_owned()))
|
||||
}
|
||||
@@ -97,7 +97,7 @@ pub fn compile_read_only_query(auth: &AuthCtx, tx: &Tx, input: &str) -> Result<P
|
||||
/// This returns an error if the string has multiple queries or mutations.
|
||||
pub fn compile_query_with_hashes<Tx: Datastore + StateView>(
|
||||
auth: &AuthCtx,
|
||||
tx: &Tx,
|
||||
tx: &mut Tx,
|
||||
input: &str,
|
||||
hash: QueryHash,
|
||||
hash_with_param: QueryHash,
|
||||
@@ -106,8 +106,8 @@ pub fn compile_query_with_hashes<Tx: Datastore + StateView>(
|
||||
return Err(SubscriptionError::Empty.into());
|
||||
}
|
||||
|
||||
let tx = SchemaViewer::new(tx, auth);
|
||||
let (plans, has_param) = SubscriptionPlan::compile(input, &tx, auth)?;
|
||||
let mut tx = SchemaViewer::new(tx, auth);
|
||||
let (plans, has_param) = SubscriptionPlan::compile(input, &mut tx, auth)?;
|
||||
|
||||
if auth.bypass_rls() || has_param {
|
||||
// Note that when generating hashes for queries from owners,
|
||||
@@ -151,7 +151,7 @@ mod tests {
|
||||
use crate::db::relational_db::tests_utils::{
|
||||
begin_mut_tx, begin_tx, insert, with_auto_commit, with_read_only, TestDB,
|
||||
};
|
||||
use crate::db::relational_db::MutTx;
|
||||
use crate::db::relational_db::{tests_utils, MutTx};
|
||||
use crate::host::module_host::{DatabaseTableUpdate, DatabaseUpdate, UpdatesRelValue};
|
||||
use crate::sql::execute::collect_result;
|
||||
use crate::sql::execute::tests::run_for_testing;
|
||||
@@ -164,6 +164,7 @@ mod tests {
|
||||
use itertools::Itertools;
|
||||
use spacetimedb_client_api_messages::websocket::{BsatnFormat, CompressableQueryUpdate, Compression};
|
||||
use spacetimedb_datastore::execution_context::Workload;
|
||||
use spacetimedb_datastore::system_tables::ST_RESERVED_SEQUENCE_RANGE;
|
||||
use spacetimedb_lib::bsatn;
|
||||
use spacetimedb_lib::db::auth::{StAccess, StTableType};
|
||||
use spacetimedb_lib::error::ResultTest;
|
||||
@@ -421,9 +422,9 @@ mod tests {
|
||||
db.create_table_for_test("a", schema, indexes)?;
|
||||
db.create_table_for_test("b", schema, indexes)?;
|
||||
|
||||
let tx = begin_tx(&db);
|
||||
let mut tx = begin_tx(&db);
|
||||
let sql = "SELECT b.* FROM b JOIN a ON b.n = a.n WHERE b.data > 200";
|
||||
let result = compile_read_only_query(&AuthCtx::for_testing(), &tx, sql);
|
||||
let result = compile_read_only_query(&AuthCtx::for_testing(), &mut tx, sql);
|
||||
assert!(result.is_ok());
|
||||
Ok(())
|
||||
}
|
||||
@@ -454,10 +455,10 @@ mod tests {
|
||||
};
|
||||
|
||||
db.commit_tx(tx)?;
|
||||
let tx = begin_tx(&db);
|
||||
let mut tx = begin_tx(&db);
|
||||
|
||||
let sql = "select * from test where b = 3";
|
||||
let mut exp = compile_sql(&db, &AuthCtx::for_testing(), &tx, sql)?;
|
||||
let mut exp = compile_sql(&db, &AuthCtx::for_testing(), &mut tx, sql)?;
|
||||
|
||||
let Some(CrudExpr::Query(query)) = exp.pop() else {
|
||||
panic!("unexpected query {:#?}", exp[0]);
|
||||
@@ -609,8 +610,8 @@ mod tests {
|
||||
AND MobileEntityState.location_z > 96000 \
|
||||
AND MobileEntityState.location_z < 192000";
|
||||
|
||||
let tx = begin_tx(&db);
|
||||
let qset = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &tx, sql_query)?;
|
||||
let mut tx = begin_tx(&db);
|
||||
let qset = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &mut tx, sql_query)?;
|
||||
|
||||
for q in qset {
|
||||
let result = run_query(
|
||||
@@ -684,7 +685,7 @@ mod tests {
|
||||
let indexes = &[ColId(0), ColId(1)];
|
||||
db.create_table_for_test("rhs", schema, indexes)?;
|
||||
|
||||
let tx = begin_tx(&db);
|
||||
let mut tx = begin_tx(&db);
|
||||
|
||||
// All single table queries are supported
|
||||
let scans = [
|
||||
@@ -696,7 +697,7 @@ mod tests {
|
||||
"SELECT * FROM lhs WHERE id > 5",
|
||||
];
|
||||
for scan in scans {
|
||||
let expr = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &tx, scan)?
|
||||
let expr = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &mut tx, scan)?
|
||||
.pop()
|
||||
.unwrap();
|
||||
assert_eq!(expr.kind(), Supported::Select, "{scan}\n{expr:#?}");
|
||||
@@ -705,7 +706,7 @@ mod tests {
|
||||
// Only index semijoins are supported
|
||||
let joins = ["SELECT lhs.* FROM lhs JOIN rhs ON lhs.id = rhs.id WHERE rhs.y < 10"];
|
||||
for join in joins {
|
||||
let expr = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &tx, join)?
|
||||
let expr = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &mut tx, join)?
|
||||
.pop()
|
||||
.unwrap();
|
||||
assert_eq!(expr.kind(), Supported::Semijoin, "{join}\n{expr:#?}");
|
||||
@@ -718,7 +719,7 @@ mod tests {
|
||||
"SELECT * FROM lhs JOIN rhs ON lhs.id = rhs.id WHERE lhs.x < 10",
|
||||
];
|
||||
for join in joins {
|
||||
match compile_read_only_queryset(&db, &AuthCtx::for_testing(), &tx, join) {
|
||||
match compile_read_only_queryset(&db, &AuthCtx::for_testing(), &mut tx, join) {
|
||||
Err(DBError::Subscription(SubscriptionError::Unsupported(_)) | DBError::TypeError(_)) => (),
|
||||
x => panic!("Unexpected: {x:?}"),
|
||||
}
|
||||
@@ -756,10 +757,10 @@ mod tests {
|
||||
fn compile_query(db: &RelationalDB) -> ResultTest<SubscriptionPlan> {
|
||||
with_read_only(db, |tx| {
|
||||
let auth = AuthCtx::for_testing();
|
||||
let tx = SchemaViewer::new(tx, &auth);
|
||||
let mut tx = SchemaViewer::new(tx, &auth);
|
||||
// Should be answered using an index semijion
|
||||
let sql = "select lhs.* from lhs join rhs on lhs.id = rhs.id where rhs.y >= 2 and rhs.y <= 4";
|
||||
Ok(SubscriptionPlan::compile(sql, &tx, &auth)
|
||||
Ok(SubscriptionPlan::compile(sql, &mut tx, &auth)
|
||||
.map(|(mut plans, _)| {
|
||||
assert_eq!(plans.len(), 1);
|
||||
plans.pop().unwrap()
|
||||
@@ -781,10 +782,10 @@ mod tests {
|
||||
fn compile_query(db: &RelationalDB) -> ResultTest<SubscriptionPlan> {
|
||||
with_read_only(db, |tx| {
|
||||
let auth = AuthCtx::for_testing();
|
||||
let tx = SchemaViewer::new(tx, &auth);
|
||||
let mut tx = SchemaViewer::new(tx, &auth);
|
||||
// Should be answered using an index semijion
|
||||
let sql = "select lhs.* from lhs join rhs on lhs.id = rhs.id where lhs.x >= 5 and lhs.x <= 7";
|
||||
Ok(SubscriptionPlan::compile(sql, &tx, &auth)
|
||||
Ok(SubscriptionPlan::compile(sql, &mut tx, &auth)
|
||||
.map(|(mut plans, _)| {
|
||||
assert_eq!(plans.len(), 1);
|
||||
plans.pop().unwrap()
|
||||
@@ -1447,4 +1448,36 @@ mod tests {
|
||||
assert_eq!(metrics.index_seeks, 8);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Verify calling views with params
|
||||
// TODO: All testing use the old query compiler, so we can't test this yet.
|
||||
#[test]
|
||||
fn test_view_params() -> ResultTest<()> {
|
||||
let db = TestDB::durable()?;
|
||||
let schema = [("a", AlgebraicType::U8), ("b", AlgebraicType::I64)];
|
||||
let (_view_id, table_id) = tests_utils::create_view_for_test(
|
||||
&db,
|
||||
"my_view",
|
||||
&schema,
|
||||
ProductType::from([("x", AlgebraicType::U8)]),
|
||||
true,
|
||||
)?;
|
||||
let arg_id = ST_RESERVED_SEQUENCE_RANGE as u64;
|
||||
|
||||
with_auto_commit(&db, |tx| -> Result<_, DBError> {
|
||||
tests_utils::insert_into_view(&db, tx, table_id, None, product![arg_id + 1, 0u8, 1i64])?;
|
||||
tests_utils::insert_into_view(&db, tx, table_id, None, product![arg_id, 1u8, 2i64])?;
|
||||
Ok(())
|
||||
})?;
|
||||
|
||||
let mut tx = begin_tx(&db);
|
||||
|
||||
let err =
|
||||
compile_read_only_queryset(&db, &AuthCtx::for_testing(), &mut tx, "SELECT * FROM my_view(1)").unwrap_err();
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"InternalError: Read-only queries cannot create parameters".to_string()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -615,7 +615,7 @@ impl AuthAccess for ExecutionSet {
|
||||
pub(crate) fn get_all<T, F, I>(
|
||||
get_all_tables: F,
|
||||
relational_db: &RelationalDB,
|
||||
tx: &T,
|
||||
tx: &mut T,
|
||||
auth: &AuthCtx,
|
||||
) -> Result<Vec<Plan>, DBError>
|
||||
where
|
||||
@@ -627,8 +627,8 @@ where
|
||||
.filter(|t| t.table_type == StTableType::User && auth.has_read_access(t.table_access))
|
||||
.map(|schema| {
|
||||
let sql = format!("SELECT * FROM {}", schema.table_name);
|
||||
let tx = SchemaViewer::new(tx, auth);
|
||||
SubscriptionPlan::compile(&sql, &tx, auth).map(|(plans, has_param)| {
|
||||
let mut tx = SchemaViewer::new(tx, auth);
|
||||
SubscriptionPlan::compile(&sql, &mut tx, auth).map(|(plans, has_param)| {
|
||||
Plan::new(
|
||||
plans,
|
||||
QueryHash::from_string(
|
||||
@@ -701,11 +701,11 @@ mod tests {
|
||||
let indexes = &[0.into(), 1.into()];
|
||||
let rhs_id = db.create_table_for_test("rhs", schema, indexes)?;
|
||||
|
||||
let tx = begin_tx(&db);
|
||||
let mut tx = begin_tx(&db);
|
||||
// Should generate an index join since there is an index on `lhs.b`.
|
||||
// Should push the sargable range condition into the index join's probe side.
|
||||
let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c > 2 and rhs.c < 4 and rhs.d = 3";
|
||||
let exp = compile_sql(&db, &AuthCtx::for_testing(), &tx, sql)?.remove(0);
|
||||
let exp = compile_sql(&db, &AuthCtx::for_testing(), &mut tx, sql)?.remove(0);
|
||||
|
||||
let CrudExpr::Query(mut expr) = exp else {
|
||||
panic!("unexpected result from compilation: {exp:#?}");
|
||||
@@ -781,11 +781,11 @@ mod tests {
|
||||
let indexes = &[0.into(), 1.into()];
|
||||
let _ = db.create_table_for_test("rhs", schema, indexes)?;
|
||||
|
||||
let tx = begin_tx(&db);
|
||||
let mut tx = begin_tx(&db);
|
||||
// Should generate an index join since there is an index on `lhs.b`.
|
||||
// Should push the sargable range condition into the index join's probe side.
|
||||
let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c > 2 and rhs.c < 4 and rhs.d = 3";
|
||||
let exp = compile_sql(&db, &AuthCtx::for_testing(), &tx, sql)?.remove(0);
|
||||
let exp = compile_sql(&db, &AuthCtx::for_testing(), &mut tx, sql)?.remove(0);
|
||||
|
||||
let CrudExpr::Query(mut expr) = exp else {
|
||||
panic!("unexpected result from compilation: {exp:#?}");
|
||||
@@ -865,12 +865,12 @@ mod tests {
|
||||
.create_table_for_test("rhs", schema, indexes)
|
||||
.expect("Failed to create_table_for_test rhs");
|
||||
|
||||
let tx = begin_tx(&db);
|
||||
let mut tx = begin_tx(&db);
|
||||
|
||||
// Should generate an index join since there is an index on `lhs.b`.
|
||||
// Should push the sargable range condition into the index join's probe side.
|
||||
let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c > 2 and rhs.c < 4 and rhs.d = 3";
|
||||
let exp = compile_sql(&db, &AuthCtx::for_testing(), &tx, sql)
|
||||
let exp = compile_sql(&db, &AuthCtx::for_testing(), &mut tx, sql)
|
||||
.expect("Failed to compile_sql")
|
||||
.remove(0);
|
||||
|
||||
|
||||
@@ -65,14 +65,14 @@ mod tests {
|
||||
use spacetimedb_vm::relation::MemTable;
|
||||
|
||||
fn run_query(db: &Arc<RelationalDB>, sql: String) -> ResultTest<MemTable> {
|
||||
let tx = begin_tx(db);
|
||||
let q = compile_sql(db, &AuthCtx::for_testing(), &tx, &sql)?;
|
||||
let mut tx = begin_tx(db);
|
||||
let q = compile_sql(db, &AuthCtx::for_testing(), &mut tx, &sql)?;
|
||||
Ok(execute_for_testing(db, &sql, q)?.pop().unwrap())
|
||||
}
|
||||
|
||||
fn run_query_write(db: &Arc<RelationalDB>, sql: String) -> ResultTest<()> {
|
||||
let tx = begin_tx(db);
|
||||
let q = compile_sql(db, &AuthCtx::for_testing(), &tx, &sql)?;
|
||||
let mut tx = begin_tx(db);
|
||||
let q = compile_sql(db, &AuthCtx::for_testing(), &mut tx, &sql)?;
|
||||
drop(tx);
|
||||
|
||||
execute_for_testing(db, &sql, q)?;
|
||||
@@ -92,10 +92,10 @@ mod tests {
|
||||
}
|
||||
Ok(())
|
||||
})?;
|
||||
let tx = begin_tx(&db);
|
||||
let mut tx = begin_tx(&db);
|
||||
|
||||
let sql = "select * from test where x > 0";
|
||||
let q = compile_sql(&db, &AuthCtx::for_testing(), &tx, sql)?;
|
||||
let q = compile_sql(&db, &AuthCtx::for_testing(), &mut tx, sql)?;
|
||||
|
||||
let slow = SlowQueryLogger::new(sql, Some(Duration::from_millis(1)), tx.ctx.workload());
|
||||
|
||||
|
||||
+82
-38
@@ -9,12 +9,12 @@ use super::{
|
||||
type_expr, type_proj, type_select,
|
||||
};
|
||||
use crate::errors::{TableFunc, UnexpectedFunctionType};
|
||||
use crate::expr::{Expr, LeftDeepJoin, ProjectList, ProjectName, Relvar};
|
||||
use crate::expr::{Expr, FieldProject, LeftDeepJoin, ProjectList, ProjectName, Relvar};
|
||||
use spacetimedb_lib::identity::AuthCtx;
|
||||
use spacetimedb_lib::AlgebraicType;
|
||||
use spacetimedb_primitives::{ArgId, TableId};
|
||||
use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type;
|
||||
use spacetimedb_sats::ProductValue;
|
||||
use spacetimedb_sats::{AlgebraicValue, ProductValue};
|
||||
use spacetimedb_schema::schema::TableOrViewSchema;
|
||||
use spacetimedb_sql_parser::ast::{BinOp, SqlExpr, SqlLiteral};
|
||||
use spacetimedb_sql_parser::{
|
||||
@@ -35,9 +35,7 @@ pub trait SchemaView {
|
||||
self.table_id(name).and_then(|table_id| self.schema_for_table(table_id))
|
||||
}
|
||||
|
||||
fn get_or_create_params(&self, params: &ProductValue) -> TypingResult<ArgId> {
|
||||
Ok(ArgId::SENTINEL)
|
||||
}
|
||||
fn get_or_create_params(&mut self, params: ProductValue) -> TypingResult<ArgId>;
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
@@ -60,9 +58,9 @@ pub trait TypeChecker {
|
||||
type Ast;
|
||||
type Set;
|
||||
|
||||
fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<ProjectList>;
|
||||
fn type_ast(ast: Self::Ast, tx: &mut impl SchemaView) -> TypingResult<ProjectList>;
|
||||
|
||||
fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<ProjectList>;
|
||||
fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &mut impl SchemaView) -> TypingResult<ProjectList>;
|
||||
|
||||
fn type_view_params(
|
||||
schema: &TableOrViewSchema,
|
||||
@@ -152,25 +150,35 @@ pub trait TypeChecker {
|
||||
}
|
||||
|
||||
fn type_params(
|
||||
tx: &mut impl SchemaView,
|
||||
from: RelExpr,
|
||||
schema: Arc<TableOrViewSchema>,
|
||||
alias: Box<str>,
|
||||
params: Option<ProductValue>,
|
||||
) -> RelExpr {
|
||||
) -> TypingResult<RelExpr> {
|
||||
match params {
|
||||
None => from,
|
||||
Some(args) => RelExpr::FunCall(
|
||||
Relvar {
|
||||
schema,
|
||||
alias,
|
||||
delta: None,
|
||||
},
|
||||
args,
|
||||
),
|
||||
None => Ok(from),
|
||||
Some(args) => {
|
||||
let new_arg_id = tx.get_or_create_params(args)?;
|
||||
let arg_id_col = schema.inner().get_column_by_name("arg_id").unwrap().col_pos;
|
||||
|
||||
Ok(RelExpr::Select(
|
||||
Box::new(from),
|
||||
Expr::BinOp(
|
||||
BinOp::Eq,
|
||||
Box::new(Expr::Field(FieldProject {
|
||||
table: alias,
|
||||
field: arg_id_col.idx(),
|
||||
ty: AlgebraicType::U64,
|
||||
})),
|
||||
Box::new(Expr::Value(AlgebraicValue::U64(new_arg_id.0), AlgebraicType::U64)),
|
||||
),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn type_from(from: SqlFrom, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<RelExpr> {
|
||||
fn type_from(from: SqlFrom, vars: &mut Relvars, tx: &mut impl SchemaView) -> TypingResult<RelExpr> {
|
||||
match from {
|
||||
SqlFrom::Expr(SqlIdent(name), SqlIdent(alias)) => {
|
||||
let schema = Self::type_relvar(tx, &name)?;
|
||||
@@ -202,7 +210,7 @@ pub trait TypeChecker {
|
||||
}
|
||||
let schema = Self::type_relvar(tx, &name)?;
|
||||
let arg = Self::type_view_params(&schema, vars, params)?;
|
||||
let lhs = Box::new(Self::type_params(join, schema.clone(), alias.clone(), arg));
|
||||
let lhs = Box::new(Self::type_params(tx, join, schema.clone(), alias.clone(), arg)?);
|
||||
|
||||
let rhs = Relvar {
|
||||
schema,
|
||||
@@ -237,7 +245,7 @@ pub trait TypeChecker {
|
||||
delta: None,
|
||||
});
|
||||
|
||||
Ok(Self::type_params(from, schema, alias, arg))
|
||||
Self::type_params(tx, from, schema, alias, arg)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -256,11 +264,11 @@ impl TypeChecker for SubChecker {
|
||||
type Ast = SqlSelect;
|
||||
type Set = SqlSelect;
|
||||
|
||||
fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<ProjectList> {
|
||||
fn type_ast(ast: Self::Ast, tx: &mut impl SchemaView) -> TypingResult<ProjectList> {
|
||||
Self::type_set(ast, &mut Relvars::default(), tx)
|
||||
}
|
||||
|
||||
fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<ProjectList> {
|
||||
fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &mut impl SchemaView) -> TypingResult<ProjectList> {
|
||||
match ast {
|
||||
SqlSelect {
|
||||
project,
|
||||
@@ -283,7 +291,7 @@ impl TypeChecker for SubChecker {
|
||||
}
|
||||
|
||||
/// Parse and type check a subscription query
|
||||
pub fn parse_and_type_sub(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> TypingResult<(ProjectName, bool)> {
|
||||
pub fn parse_and_type_sub(sql: &str, tx: &mut impl SchemaView, auth: &AuthCtx) -> TypingResult<(ProjectName, bool)> {
|
||||
let ast = parse_subscription(sql)?;
|
||||
let has_param = ast.has_parameter();
|
||||
let ast = ast.resolve_sender(auth.caller());
|
||||
@@ -303,15 +311,16 @@ fn expect_table_type(expr: ProjectList) -> TypingResult<ProjectName> {
|
||||
|
||||
pub mod test_utils {
|
||||
use spacetimedb_lib::{db::raw_def::v9::RawModuleDefV9Builder, ProductType};
|
||||
use spacetimedb_primitives::TableId;
|
||||
use spacetimedb_sats::AlgebraicType;
|
||||
use spacetimedb_primitives::{ArgId, TableId};
|
||||
use spacetimedb_sats::{AlgebraicType, ProductValue};
|
||||
use spacetimedb_schema::{
|
||||
def::ModuleDef,
|
||||
schema::{Schema, TableOrViewSchema, TableSchema},
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::SchemaView;
|
||||
use super::{SchemaView, TypingResult};
|
||||
pub struct ViewInfo<'a> {
|
||||
pub(crate) name: &'a str,
|
||||
pub(crate) columns: &'a [(&'a str, AlgebraicType)],
|
||||
@@ -333,7 +342,38 @@ pub mod test_utils {
|
||||
builder.finish().try_into().expect("failed to generate module def")
|
||||
}
|
||||
|
||||
pub struct SchemaViewer(pub ModuleDef);
|
||||
pub struct MockCallParams {
|
||||
counter: u64,
|
||||
params: HashMap<ProductValue, ArgId>,
|
||||
}
|
||||
|
||||
impl Default for MockCallParams {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl MockCallParams {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
counter: 0,
|
||||
params: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_or_insert(&mut self, value: ProductValue) -> ArgId {
|
||||
if let Some(existing) = self.params.get(&value) {
|
||||
*existing
|
||||
} else {
|
||||
self.counter += 1;
|
||||
let arg_id = ArgId(self.counter - 1);
|
||||
self.params.insert(value, arg_id);
|
||||
arg_id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SchemaViewer(pub ModuleDef, pub MockCallParams);
|
||||
|
||||
impl SchemaView for SchemaViewer {
|
||||
fn table_id(&self, name: &str) -> Option<TableId> {
|
||||
@@ -370,6 +410,10 @@ pub mod test_utils {
|
||||
fn rls_rules_for_table(&self, _: TableId) -> anyhow::Result<Vec<Box<str>>> {
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
fn get_or_create_params(&mut self, params: ProductValue) -> TypingResult<ArgId> {
|
||||
Ok(self.1.get_or_insert(params))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -423,13 +467,13 @@ mod tests {
|
||||
}
|
||||
|
||||
/// A wrapper around [super::parse_and_type_sub] that takes a dummy [AuthCtx]
|
||||
fn parse_and_type_sub(sql: &str, tx: &impl SchemaView) -> TypingResult<ProjectName> {
|
||||
fn parse_and_type_sub(sql: &str, tx: &mut impl SchemaView) -> TypingResult<ProjectName> {
|
||||
super::parse_and_type_sub(sql, tx, &AuthCtx::for_testing()).map(|(plan, _)| plan)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn valid_literals() {
|
||||
let tx = SchemaViewer(module_def());
|
||||
let mut tx = SchemaViewer(module_def(), Default::default());
|
||||
|
||||
struct TestCase {
|
||||
sql: &'static str,
|
||||
@@ -498,27 +542,27 @@ mod tests {
|
||||
msg: "timestamp ms with timezone",
|
||||
},
|
||||
] {
|
||||
let result = parse_and_type_sub(sql, &tx);
|
||||
let result = parse_and_type_sub(sql, &mut tx);
|
||||
assert!(result.is_ok(), "name: {}, error: {}", msg, result.unwrap_err());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn valid_literals_for_type() {
|
||||
let tx = SchemaViewer(module_def());
|
||||
let mut tx = SchemaViewer(module_def(), Default::default());
|
||||
|
||||
for ty in [
|
||||
"i8", "u8", "i16", "u16", "i32", "u32", "i64", "u64", "f32", "f64", "i128", "u128", "i256", "u256",
|
||||
] {
|
||||
let sql = format!("select * from t where {ty} = 127");
|
||||
let result = parse_and_type_sub(&sql, &tx);
|
||||
let result = parse_and_type_sub(&sql, &mut tx);
|
||||
assert!(result.is_ok(), "Failed to parse {ty}: {}", result.unwrap_err());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_literals() {
|
||||
let tx = SchemaViewer(module_def());
|
||||
let mut tx = SchemaViewer(module_def(), Default::default());
|
||||
|
||||
struct TestCase {
|
||||
sql: &'static str,
|
||||
@@ -547,14 +591,14 @@ mod tests {
|
||||
msg: "Float as integer",
|
||||
},
|
||||
] {
|
||||
let result = parse_and_type_sub(sql, &tx);
|
||||
let result = parse_and_type_sub(sql, &mut tx);
|
||||
assert!(result.is_err(), "{msg}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn valid() {
|
||||
let tx = SchemaViewer(module_def());
|
||||
let mut tx = SchemaViewer(module_def(), Default::default());
|
||||
|
||||
struct TestCase {
|
||||
sql: &'static str,
|
||||
@@ -611,14 +655,14 @@ mod tests {
|
||||
msg: "Type inner join + projection",
|
||||
},
|
||||
] {
|
||||
let result = parse_and_type_sub(sql, &tx);
|
||||
let result = parse_and_type_sub(sql, &mut tx);
|
||||
assert!(result.is_ok(), "{msg}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid() {
|
||||
let tx = SchemaViewer(module_def());
|
||||
let mut tx = SchemaViewer(module_def(), Default::default());
|
||||
|
||||
struct TestCase {
|
||||
sql: &'static str,
|
||||
@@ -683,7 +727,7 @@ mod tests {
|
||||
msg: "Columns must be qualified in join expressions",
|
||||
},
|
||||
] {
|
||||
let result = parse_and_type_sub(sql, &tx);
|
||||
let result = parse_and_type_sub(sql, &mut tx);
|
||||
assert!(result.is_err(), "{msg}");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,6 +172,8 @@ pub enum TypingError {
|
||||
FilterReturnType(#[from] FilterReturnType),
|
||||
#[error(transparent)]
|
||||
TableFunc(#[from] TableFunc),
|
||||
#[error("InternalError: Read-only queries cannot create parameters")]
|
||||
ParamsReadOnly,
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
@@ -238,35 +238,6 @@ impl ProjectList {
|
||||
Self::Agg(_, _, name, ty) => f(name, ty),
|
||||
}
|
||||
}
|
||||
|
||||
/// Iterate over the function calls in this projection list
|
||||
pub fn for_each_fun_call(
|
||||
&mut self,
|
||||
f: &mut impl FnMut(ProductValue) -> Result<ArgId, TypingError>,
|
||||
) -> Result<(), TypingError> {
|
||||
match self {
|
||||
ProjectList::Name(input) => {
|
||||
for proj in input {
|
||||
match proj {
|
||||
ProjectName::None(expr) | ProjectName::Some(expr, _) => {
|
||||
expr.for_each_fun_call(f)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ProjectList::List(input, _) => {
|
||||
for expr in input {
|
||||
expr.for_each_fun_call(f)?;
|
||||
}
|
||||
}
|
||||
ProjectList::Limit(input, _) => {
|
||||
input.for_each_fun_call(f)?;
|
||||
}
|
||||
ProjectList::Agg(_, _, _, _) => {}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A logical relational expression
|
||||
@@ -398,31 +369,6 @@ impl RelExpr {
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn for_each_fun_call(
|
||||
&mut self,
|
||||
f: &mut impl FnMut(ProductValue) -> Result<ArgId, TypingError>,
|
||||
) -> Result<(), TypingError> {
|
||||
// For function calls, we need to filter by the argument id
|
||||
if let RelExpr::FunCall(relvar, param) = self {
|
||||
let new_arg_id = f(param.clone())?;
|
||||
let arg_id_col = relvar.schema.inner().get_column_by_name("arg_id").unwrap().col_pos;
|
||||
|
||||
*self = RelExpr::Select(
|
||||
Box::new(RelExpr::RelVar(relvar.clone())),
|
||||
Expr::BinOp(
|
||||
BinOp::Eq,
|
||||
Box::new(Expr::Field(FieldProject {
|
||||
table: relvar.alias.clone(),
|
||||
field: arg_id_col.idx(),
|
||||
ty: AlgebraicType::U64,
|
||||
})),
|
||||
Box::new(Expr::Value(AlgebraicValue::U64(new_arg_id.0), AlgebraicType::U64)),
|
||||
),
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A left deep binary cross product
|
||||
|
||||
+38
-31
@@ -12,7 +12,7 @@ use crate::{
|
||||
/// The main driver of RLS resolution for subscription queries.
|
||||
/// Mainly a wrapper around [resolve_views_for_expr].
|
||||
pub fn resolve_views_for_sub(
|
||||
tx: &impl SchemaView,
|
||||
tx: &mut impl SchemaView,
|
||||
expr: ProjectName,
|
||||
auth: &AuthCtx,
|
||||
has_param: &mut bool,
|
||||
@@ -54,7 +54,11 @@ pub fn resolve_views_for_sub(
|
||||
|
||||
/// The main driver of RLS resolution for sql queries.
|
||||
/// Mainly a wrapper around [resolve_views_for_expr].
|
||||
pub fn resolve_views_for_sql(tx: &impl SchemaView, expr: ProjectList, auth: &AuthCtx) -> anyhow::Result<ProjectList> {
|
||||
pub fn resolve_views_for_sql(
|
||||
tx: &mut impl SchemaView,
|
||||
expr: ProjectList,
|
||||
auth: &AuthCtx,
|
||||
) -> anyhow::Result<ProjectList> {
|
||||
// RLS does not apply to the database owner
|
||||
if auth.bypass_rls() {
|
||||
return Ok(expr);
|
||||
@@ -62,44 +66,41 @@ pub fn resolve_views_for_sql(tx: &impl SchemaView, expr: ProjectList, auth: &Aut
|
||||
// The subscription language is a subset of the sql language.
|
||||
// Use the subscription helper if this is a compliant expression.
|
||||
// Use the generic resolver otherwise.
|
||||
let resolve_for_sub = |expr| resolve_views_for_sub(tx, expr, auth, &mut false);
|
||||
let resolve_for_sql = |expr| {
|
||||
resolve_views_for_expr(
|
||||
// Use all default values
|
||||
tx,
|
||||
expr,
|
||||
None,
|
||||
Rc::new(ResolveList::None),
|
||||
&mut false,
|
||||
&mut 0,
|
||||
auth,
|
||||
)
|
||||
};
|
||||
match expr {
|
||||
ProjectList::Limit(expr, n) => Ok(ProjectList::Limit(Box::new(resolve_views_for_sql(tx, *expr, auth)?), n)),
|
||||
ProjectList::Limit(expr, n) => {
|
||||
let expr = resolve_views_for_sql(tx, *expr, auth)?;
|
||||
Ok(ProjectList::Limit(Box::new(expr), n))
|
||||
}
|
||||
|
||||
ProjectList::Name(exprs) => Ok(ProjectList::Name(
|
||||
exprs
|
||||
.into_iter()
|
||||
.map(resolve_for_sub)
|
||||
.map(|expr| resolve_views_for_sub(tx, expr, auth, &mut false))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect(),
|
||||
)),
|
||||
|
||||
ProjectList::List(exprs, fields) => Ok(ProjectList::List(
|
||||
exprs
|
||||
.into_iter()
|
||||
.map(resolve_for_sql)
|
||||
.map(|expr| {
|
||||
resolve_views_for_expr(tx, expr, None, Rc::new(ResolveList::None), &mut false, &mut 0, auth)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect(),
|
||||
fields,
|
||||
)),
|
||||
|
||||
ProjectList::Agg(exprs, AggType::Count, name, ty) => Ok(ProjectList::Agg(
|
||||
exprs
|
||||
.into_iter()
|
||||
.map(resolve_for_sql)
|
||||
.map(|expr| {
|
||||
resolve_views_for_expr(tx, expr, None, Rc::new(ResolveList::None), &mut false, &mut 0, auth)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
@@ -203,7 +204,7 @@ impl ResolveList {
|
||||
/// i.e. the subtree rooted at `a` in the above example,
|
||||
/// must be pushed below the leftmost leaf node of the view expansion.
|
||||
fn resolve_views_for_expr(
|
||||
tx: &impl SchemaView,
|
||||
tx: &mut impl SchemaView,
|
||||
view: RelExpr,
|
||||
return_table_id: Option<TableId>,
|
||||
resolving: Rc<ResolveList>,
|
||||
@@ -473,21 +474,23 @@ mod tests {
|
||||
use pretty_assertions as pretty;
|
||||
|
||||
use spacetimedb_lib::{identity::AuthCtx, AlgebraicType, AlgebraicValue, Identity, ProductType};
|
||||
use spacetimedb_primitives::TableId;
|
||||
use spacetimedb_primitives::{ArgId, TableId};
|
||||
use spacetimedb_sats::ProductValue;
|
||||
use spacetimedb_schema::{
|
||||
def::ModuleDef,
|
||||
schema::{Schema, TableOrViewSchema, TableSchema},
|
||||
};
|
||||
use spacetimedb_sql_parser::ast::BinOp;
|
||||
|
||||
use super::resolve_views_for_sub;
|
||||
use crate::check::test_utils::MockCallParams;
|
||||
use crate::check::TypingResult;
|
||||
use crate::{
|
||||
check::{parse_and_type_sub, test_utils::build_module_def, SchemaView},
|
||||
expr::{Expr, FieldProject, LeftDeepJoin, ProjectName, RelExpr, Relvar},
|
||||
};
|
||||
|
||||
use super::resolve_views_for_sub;
|
||||
|
||||
pub struct SchemaViewer(pub ModuleDef);
|
||||
pub struct SchemaViewer(pub ModuleDef, pub MockCallParams);
|
||||
|
||||
impl SchemaView for SchemaViewer {
|
||||
fn table_id(&self, name: &str) -> Option<TableId> {
|
||||
@@ -526,6 +529,10 @@ mod tests {
|
||||
_ => Ok(vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_or_create_params(&mut self, params: ProductValue) -> TypingResult<ArgId> {
|
||||
Ok(self.1.get_or_insert(params))
|
||||
}
|
||||
}
|
||||
|
||||
fn module_def() -> ModuleDef {
|
||||
@@ -549,17 +556,17 @@ mod tests {
|
||||
}
|
||||
|
||||
/// Parse, type check, and resolve RLS rules
|
||||
fn resolve(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> anyhow::Result<Vec<ProjectName>> {
|
||||
fn resolve(sql: &str, tx: &mut impl SchemaView, auth: &AuthCtx) -> anyhow::Result<Vec<ProjectName>> {
|
||||
let (expr, _) = parse_and_type_sub(sql, tx, auth)?;
|
||||
resolve_views_for_sub(tx, expr, auth, &mut false)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rls_for_owner() -> anyhow::Result<()> {
|
||||
let tx = SchemaViewer(module_def());
|
||||
let mut tx = SchemaViewer(module_def(), Default::default());
|
||||
let auth = AuthCtx::new(Identity::ONE, Identity::ONE);
|
||||
let sql = "select * from users";
|
||||
let resolved = resolve(sql, &tx, &auth)?;
|
||||
let resolved = resolve(sql, &mut tx, &auth)?;
|
||||
|
||||
let users_schema = tx.schema("users").unwrap();
|
||||
|
||||
@@ -577,10 +584,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_rls_for_non_owner() -> anyhow::Result<()> {
|
||||
let tx = SchemaViewer(module_def());
|
||||
let mut tx = SchemaViewer(module_def(), Default::default());
|
||||
let auth = AuthCtx::new(Identity::ZERO, Identity::ONE);
|
||||
let sql = "select * from users";
|
||||
let resolved = resolve(sql, &tx, &auth)?;
|
||||
let resolved = resolve(sql, &mut tx, &auth)?;
|
||||
|
||||
let users_schema = tx.schema("users").unwrap();
|
||||
|
||||
@@ -612,10 +619,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_multiple_rls_rules_for_table() -> anyhow::Result<()> {
|
||||
let tx = SchemaViewer(module_def());
|
||||
let mut tx = SchemaViewer(module_def(), Default::default());
|
||||
let auth = AuthCtx::new(Identity::ZERO, Identity::ONE);
|
||||
let sql = "select * from player where level_num = 5";
|
||||
let resolved = resolve(sql, &tx, &auth)?;
|
||||
let resolved = resolve(sql, &mut tx, &auth)?;
|
||||
|
||||
let users_schema = tx.schema("users").unwrap();
|
||||
let admins_schema = tx.schema("admins").unwrap();
|
||||
|
||||
@@ -394,11 +394,11 @@ impl TypeChecker for SqlChecker {
|
||||
type Ast = SqlSelect;
|
||||
type Set = SqlSelect;
|
||||
|
||||
fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<ProjectList> {
|
||||
fn type_ast(ast: Self::Ast, tx: &mut impl SchemaView) -> TypingResult<ProjectList> {
|
||||
Self::type_set(ast, &mut Relvars::default(), tx)
|
||||
}
|
||||
|
||||
fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<ProjectList> {
|
||||
fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &mut impl SchemaView) -> TypingResult<ProjectList> {
|
||||
match ast {
|
||||
SqlSelect {
|
||||
project,
|
||||
@@ -439,7 +439,7 @@ impl TypeChecker for SqlChecker {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse_and_type_sql(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> TypingResult<Statement> {
|
||||
pub fn parse_and_type_sql(sql: &str, tx: &mut impl SchemaView, auth: &AuthCtx) -> TypingResult<Statement> {
|
||||
match parse_sql(sql)?.resolve_sender(auth.caller()) {
|
||||
SqlAst::Select(ast) => Ok(Statement::Select(SqlChecker::type_ast(ast, tx)?)),
|
||||
SqlAst::Insert(insert) => Ok(Statement::DML(DML::Insert(type_insert(insert, tx)?))),
|
||||
@@ -451,7 +451,7 @@ pub fn parse_and_type_sql(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> Ty
|
||||
}
|
||||
|
||||
/// Parse and type check a *general* query into a [StatementCtx].
|
||||
pub fn compile_sql_stmt<'a>(sql: &'a str, tx: &impl SchemaView, auth: &AuthCtx) -> TypingResult<StatementCtx<'a>> {
|
||||
pub fn compile_sql_stmt<'a>(sql: &'a str, tx: &mut impl SchemaView, auth: &AuthCtx) -> TypingResult<StatementCtx<'a>> {
|
||||
let statement = parse_and_type_sql(sql, tx, auth)?;
|
||||
Ok(StatementCtx {
|
||||
statement,
|
||||
@@ -520,13 +520,13 @@ mod tests {
|
||||
}
|
||||
|
||||
/// A wrapper around [super::parse_and_type_sql] that takes a dummy [AuthCtx]
|
||||
fn parse_and_type_sql(sql: &str, tx: &impl SchemaView) -> TypingResult<Statement> {
|
||||
fn parse_and_type_sql(sql: &str, tx: &mut impl SchemaView) -> TypingResult<Statement> {
|
||||
super::parse_and_type_sql(sql, tx, &AuthCtx::for_testing())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn valid() {
|
||||
let tx = SchemaViewer(module_def());
|
||||
let mut tx = SchemaViewer(module_def(), Default::default());
|
||||
|
||||
for sql in [
|
||||
"select str from t",
|
||||
@@ -534,14 +534,14 @@ mod tests {
|
||||
"select t.str, arr from t",
|
||||
"select * from t limit 5",
|
||||
] {
|
||||
let result = parse_and_type_sql(sql, &tx);
|
||||
let result = parse_and_type_sql(sql, &mut tx);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid() {
|
||||
let tx = SchemaViewer(module_def());
|
||||
let mut tx = SchemaViewer(module_def(), Default::default());
|
||||
|
||||
for sql in [
|
||||
// Unqualified columns in a join
|
||||
@@ -551,7 +551,7 @@ mod tests {
|
||||
// Unqualified name in join expression
|
||||
"select t.* from t join s on t.u32 = s.u32 where bytes = 0xABCD",
|
||||
] {
|
||||
let result = parse_and_type_sql(sql, &tx);
|
||||
let result = parse_and_type_sql(sql, &mut tx);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
@@ -581,7 +581,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn views() {
|
||||
let tx = SchemaViewer(module_def());
|
||||
let mut tx = SchemaViewer(module_def(), Default::default());
|
||||
|
||||
struct TestCase {
|
||||
sql: &'static str,
|
||||
@@ -606,7 +606,7 @@ mod tests {
|
||||
msg: "Function call returning view with parameters",
|
||||
},
|
||||
] {
|
||||
let result = parse_and_type_sql(sql, &tx).inspect_err(|e| {
|
||||
let result = parse_and_type_sql(sql, &mut tx).inspect_err(|e| {
|
||||
panic!("Expected OK for `{sql}` but got error: {e}");
|
||||
});
|
||||
assert!(result.is_ok(), "{msg}: {sql}");
|
||||
@@ -630,14 +630,14 @@ mod tests {
|
||||
msg: "`v` does not take parameters",
|
||||
},
|
||||
] {
|
||||
let result = parse_and_type_sql(sql, &tx);
|
||||
let result = parse_and_type_sql(sql, &mut tx);
|
||||
assert!(result.is_err(), "{msg}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn params_validation() {
|
||||
let tx = SchemaViewer(module_def());
|
||||
let mut tx = SchemaViewer(module_def(), Default::default());
|
||||
|
||||
struct TestCase {
|
||||
sql: &'static str,
|
||||
@@ -670,7 +670,7 @@ mod tests {
|
||||
msg: "Unexpected function type. Expected: (U32, String) != Inferred: (U32, String, Num?)",
|
||||
},
|
||||
] {
|
||||
let result = parse_and_type_sql(sql, &tx);
|
||||
let result = parse_and_type_sql(sql, &mut tx);
|
||||
if msg == "Correct parameters" {
|
||||
assert!(result.is_ok(), "{msg}: {sql}");
|
||||
} else if let Err(err) = &result {
|
||||
|
||||
@@ -1402,6 +1402,7 @@ mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use pretty_assertions::assert_eq;
|
||||
use spacetimedb_expr::check::test_utils::MockCallParams;
|
||||
use spacetimedb_expr::{
|
||||
check::{SchemaView, TypingResult},
|
||||
expr::ProjectName,
|
||||
@@ -1410,9 +1411,9 @@ mod tests {
|
||||
use spacetimedb_lib::{
|
||||
db::auth::{StAccess, StTableType},
|
||||
identity::AuthCtx,
|
||||
AlgebraicType, AlgebraicValue,
|
||||
AlgebraicType, AlgebraicValue, ProductValue,
|
||||
};
|
||||
use spacetimedb_primitives::{ColId, ColList, ColSet, TableId, ViewId};
|
||||
use spacetimedb_primitives::{ArgId, ColId, ColList, ColSet, TableId, ViewId};
|
||||
use spacetimedb_schema::def::ViewParamDefSimple;
|
||||
use spacetimedb_schema::identifier::Identifier;
|
||||
use spacetimedb_schema::schema::ViewDefInfo;
|
||||
@@ -1431,6 +1432,7 @@ mod tests {
|
||||
|
||||
struct SchemaViewer {
|
||||
schemas: Vec<Arc<TableOrViewSchema>>,
|
||||
params: MockCallParams,
|
||||
}
|
||||
|
||||
impl SchemaView for SchemaViewer {
|
||||
@@ -1448,6 +1450,10 @@ mod tests {
|
||||
fn rls_rules_for_table(&self, _: TableId) -> anyhow::Result<Vec<Box<str>>> {
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
fn get_or_create_params(&mut self, params: ProductValue) -> TypingResult<ArgId> {
|
||||
Ok(self.params.get_or_insert(params))
|
||||
}
|
||||
}
|
||||
|
||||
fn schema_with_params(
|
||||
@@ -1534,7 +1540,7 @@ mod tests {
|
||||
}
|
||||
|
||||
/// A wrapper around [spacetimedb_expr::check::parse_and_type_sub] that takes a dummy [AuthCtx]
|
||||
fn parse_and_type_sub(sql: &str, tx: &impl SchemaView) -> TypingResult<ProjectName> {
|
||||
fn parse_and_type_sub(sql: &str, tx: &mut impl SchemaView) -> TypingResult<ProjectName> {
|
||||
spacetimedb_expr::check::parse_and_type_sub(sql, tx, &AuthCtx::for_testing()).map(|(plan, _)| plan)
|
||||
}
|
||||
|
||||
@@ -1552,14 +1558,15 @@ mod tests {
|
||||
Some(0),
|
||||
));
|
||||
|
||||
let db = SchemaViewer {
|
||||
let mut db = SchemaViewer {
|
||||
schemas: vec![t.clone()],
|
||||
params: Default::default(),
|
||||
};
|
||||
|
||||
let sql = "select * from t";
|
||||
|
||||
let auth = AuthCtx::for_testing();
|
||||
let lp = parse_and_type_sub(sql, &db).unwrap();
|
||||
let lp = parse_and_type_sub(sql, &mut db).unwrap();
|
||||
let pp = compile_select(lp).optimize(&auth).unwrap();
|
||||
|
||||
match pp {
|
||||
@@ -1584,14 +1591,15 @@ mod tests {
|
||||
Some(0),
|
||||
));
|
||||
|
||||
let db = SchemaViewer {
|
||||
let mut db = SchemaViewer {
|
||||
schemas: vec![t.clone()],
|
||||
params: Default::default(),
|
||||
};
|
||||
|
||||
let sql = "select * from t where x = 5";
|
||||
|
||||
let auth = AuthCtx::for_testing();
|
||||
let lp = parse_and_type_sub(sql, &db).unwrap();
|
||||
let lp = parse_and_type_sub(sql, &mut db).unwrap();
|
||||
let pp = compile_select(lp).optimize(&auth).unwrap();
|
||||
|
||||
match pp {
|
||||
@@ -1678,8 +1686,9 @@ mod tests {
|
||||
Some(0),
|
||||
));
|
||||
|
||||
let db = SchemaViewer {
|
||||
let mut db = SchemaViewer {
|
||||
schemas: vec![u.clone(), l.clone(), b.clone()],
|
||||
params: Default::default(),
|
||||
};
|
||||
|
||||
let sql = "
|
||||
@@ -1691,7 +1700,7 @@ mod tests {
|
||||
where u.identity = 5
|
||||
";
|
||||
let auth = AuthCtx::for_testing();
|
||||
let lp = parse_and_type_sub(sql, &db).unwrap();
|
||||
let lp = parse_and_type_sub(sql, &mut db).unwrap();
|
||||
let pp = compile_select(lp).optimize(&auth).unwrap();
|
||||
|
||||
// Plan:
|
||||
@@ -1870,8 +1879,9 @@ mod tests {
|
||||
Some(0),
|
||||
));
|
||||
|
||||
let db = SchemaViewer {
|
||||
let mut db = SchemaViewer {
|
||||
schemas: vec![m.clone(), w.clone(), p.clone()],
|
||||
params: Default::default(),
|
||||
};
|
||||
|
||||
let sql = "
|
||||
@@ -1884,7 +1894,7 @@ mod tests {
|
||||
where 5 = m.employee and 5 = v.employee
|
||||
";
|
||||
let auth = AuthCtx::for_testing();
|
||||
let lp = parse_and_type_sub(sql, &db).unwrap();
|
||||
let lp = parse_and_type_sub(sql, &mut db).unwrap();
|
||||
let pp = compile_select(lp).optimize(&auth).unwrap();
|
||||
|
||||
// Plan:
|
||||
@@ -2052,13 +2062,14 @@ mod tests {
|
||||
None,
|
||||
));
|
||||
|
||||
let db = SchemaViewer {
|
||||
let mut db = SchemaViewer {
|
||||
schemas: vec![t.clone()],
|
||||
params: Default::default(),
|
||||
};
|
||||
|
||||
let sql = "select * from t where x = 3 and y = 4 and z = 5";
|
||||
let auth = AuthCtx::for_testing();
|
||||
let lp = parse_and_type_sub(sql, &db).unwrap();
|
||||
let lp = parse_and_type_sub(sql, &mut db).unwrap();
|
||||
let pp = compile_select(lp).optimize(&auth).unwrap();
|
||||
|
||||
// Select index on (x, y, z)
|
||||
@@ -2081,7 +2092,7 @@ mod tests {
|
||||
|
||||
// Test permutations of the same query
|
||||
let sql = "select * from t where z = 5 and y = 4 and x = 3";
|
||||
let lp = parse_and_type_sub(sql, &db).unwrap();
|
||||
let lp = parse_and_type_sub(sql, &mut db).unwrap();
|
||||
let pp = compile_select(lp).optimize(&auth).unwrap();
|
||||
|
||||
match pp {
|
||||
@@ -2102,7 +2113,7 @@ mod tests {
|
||||
};
|
||||
|
||||
let sql = "select * from t where x = 3 and y = 4";
|
||||
let lp = parse_and_type_sub(sql, &db).unwrap();
|
||||
let lp = parse_and_type_sub(sql, &mut db).unwrap();
|
||||
let pp = compile_select(lp).optimize(&auth).unwrap();
|
||||
|
||||
// Select index on x
|
||||
@@ -2130,7 +2141,7 @@ mod tests {
|
||||
};
|
||||
|
||||
let sql = "select * from t where w = 5 and x = 4";
|
||||
let lp = parse_and_type_sub(sql, &db).unwrap();
|
||||
let lp = parse_and_type_sub(sql, &mut db).unwrap();
|
||||
let pp = compile_select(lp).optimize(&auth).unwrap();
|
||||
|
||||
// Select index on x
|
||||
@@ -2158,7 +2169,7 @@ mod tests {
|
||||
};
|
||||
|
||||
let sql = "select * from t where y = 1";
|
||||
let lp = parse_and_type_sub(sql, &db).unwrap();
|
||||
let lp = parse_and_type_sub(sql, &mut db).unwrap();
|
||||
let pp = compile_select(lp).optimize(&auth).unwrap();
|
||||
|
||||
// Do not select index on (y, z)
|
||||
@@ -2173,7 +2184,7 @@ mod tests {
|
||||
|
||||
// Select index on [y, z]
|
||||
let sql = "select * from t where y = 1 and z = 2";
|
||||
let lp = parse_and_type_sub(sql, &db).unwrap();
|
||||
let lp = parse_and_type_sub(sql, &mut db).unwrap();
|
||||
let pp = compile_select(lp).optimize(&auth).unwrap();
|
||||
|
||||
match pp {
|
||||
@@ -2192,7 +2203,7 @@ mod tests {
|
||||
|
||||
// Check permutations of the same query
|
||||
let sql = "select * from t where z = 2 and y = 1";
|
||||
let lp = parse_and_type_sub(sql, &db).unwrap();
|
||||
let lp = parse_and_type_sub(sql, &mut db).unwrap();
|
||||
let pp = compile_select(lp).optimize(&auth).unwrap();
|
||||
|
||||
match pp {
|
||||
@@ -2211,7 +2222,7 @@ mod tests {
|
||||
|
||||
// Select index on (y, z) and filter on (w)
|
||||
let sql = "select * from t where w = 1 and y = 2 and z = 3";
|
||||
let lp = parse_and_type_sub(sql, &db).unwrap();
|
||||
let lp = parse_and_type_sub(sql, &mut db).unwrap();
|
||||
let pp = compile_select(lp).optimize(&auth).unwrap();
|
||||
|
||||
let plan = match pp {
|
||||
@@ -2251,12 +2262,13 @@ mod tests {
|
||||
None,
|
||||
));
|
||||
|
||||
let db = SchemaViewer {
|
||||
let mut db = SchemaViewer {
|
||||
schemas: vec![t.clone()],
|
||||
params: Default::default(),
|
||||
};
|
||||
|
||||
let compile = |sql| {
|
||||
let stmt = parse_and_type_sql(sql, &db, &AuthCtx::for_testing()).unwrap();
|
||||
let mut compile = |sql| {
|
||||
let stmt = parse_and_type_sql(sql, &mut db, &AuthCtx::for_testing()).unwrap();
|
||||
let Statement::Select(select) = stmt else {
|
||||
unreachable!()
|
||||
};
|
||||
@@ -2322,20 +2334,20 @@ mod tests {
|
||||
Some(&[("param_id", AlgebraicType::U64)]),
|
||||
));
|
||||
|
||||
let db = SchemaViewer {
|
||||
let mut db = SchemaViewer {
|
||||
schemas: vec![t.clone(), v.clone()],
|
||||
params: Default::default(),
|
||||
};
|
||||
|
||||
let sql = "select * from v(0)";
|
||||
|
||||
let auth = AuthCtx::for_testing();
|
||||
let lp = parse_and_type_sub(sql, &db).unwrap();
|
||||
dbg!(&lp);
|
||||
let lp = parse_and_type_sub(sql, &mut db).unwrap();
|
||||
let pp = compile_select(lp).optimize(&auth).unwrap();
|
||||
dbg!(&pp);
|
||||
match pp {
|
||||
ProjectPlan::None(PhysicalPlan::Filter(input, PhysicalExpr::BinOp(BinOp::Eq, field, value))) => {
|
||||
assert!(matches!(*field, PhysicalExpr::Field(TupleField { field_pos: 1, .. })));
|
||||
// This is the internal parameter filter
|
||||
assert!(matches!(*field, PhysicalExpr::Field(TupleField { field_pos: 0, .. })));
|
||||
assert!(matches!(*value, PhysicalExpr::Value(AlgebraicValue::U64(0))));
|
||||
|
||||
match *input {
|
||||
@@ -2349,12 +2361,12 @@ mod tests {
|
||||
};
|
||||
|
||||
let sql = "select * from v(0) as x JOIN t ON x.id = t.id";
|
||||
let lp = parse_and_type_sub(sql, &db).unwrap();
|
||||
let lp = parse_and_type_sub(sql, &mut db).unwrap();
|
||||
let pp = compile_select(lp).optimize(&auth).unwrap();
|
||||
|
||||
match pp {
|
||||
ProjectPlan::None(PhysicalPlan::Filter(_, PhysicalExpr::BinOp(BinOp::Eq, field, value))) => {
|
||||
assert!(matches!(*field, PhysicalExpr::Field(TupleField { field_pos: 1, .. })));
|
||||
assert!(matches!(*field, PhysicalExpr::Field(TupleField { field_pos: 0, .. })));
|
||||
assert!(matches!(*value, PhysicalExpr::Value(AlgebraicValue::U64(0))));
|
||||
}
|
||||
proj => panic!("unexpected project: {proj:#?}"),
|
||||
|
||||
+3
-20
@@ -4,8 +4,6 @@ use spacetimedb_execution::{
|
||||
pipelined::ProjectListExecutor,
|
||||
Datastore, DeltaStore,
|
||||
};
|
||||
use spacetimedb_expr::errors::TypingError;
|
||||
use spacetimedb_expr::expr::CallParams;
|
||||
use spacetimedb_expr::{
|
||||
check::{parse_and_type_sub, SchemaView},
|
||||
expr::ProjectList,
|
||||
@@ -17,30 +15,15 @@ use spacetimedb_physical_plan::{
|
||||
compile::{compile_dml_plan, compile_select, compile_select_list},
|
||||
plan::{ProjectListPlan, ProjectPlan},
|
||||
};
|
||||
use spacetimedb_primitives::{ArgId, TableId};
|
||||
use std::collections::HashMap;
|
||||
use spacetimedb_primitives::TableId;
|
||||
|
||||
/// DIRTY HACK ALERT: Maximum allowed length, in UTF-8 bytes, of SQL queries.
|
||||
/// Any query longer than this will be rejected.
|
||||
/// This prevents a stack overflow when compiling queries with deeply-nested `AND` and `OR` conditions.
|
||||
const MAX_SQL_LENGTH: usize = 50_000;
|
||||
|
||||
pub trait CallParamsExt {
|
||||
fn get_arg(&self, params: &ProductValue) -> Result<ArgId, TypingError>;
|
||||
}
|
||||
|
||||
pub struct MockCallParams {
|
||||
params: HashMap<ProductValue, ArgId>,
|
||||
}
|
||||
impl CallParamsExt for MockCallParams {
|
||||
fn get_arg(&self, params: &ProductValue) -> Result<ArgId, TypingError> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn compile_subscription(
|
||||
sql: &str,
|
||||
tx: &impl SchemaView,
|
||||
tx: &mut impl SchemaView,
|
||||
auth: &AuthCtx,
|
||||
) -> Result<(Vec<ProjectPlan>, TableId, Box<str>, bool)> {
|
||||
if sql.len() > MAX_SQL_LENGTH {
|
||||
@@ -72,7 +55,7 @@ pub fn compile_subscription(
|
||||
}
|
||||
|
||||
/// A utility for parsing and type checking a sql statement
|
||||
pub fn compile_sql_stmt(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> Result<Statement> {
|
||||
pub fn compile_sql_stmt(sql: &str, tx: &mut impl SchemaView, auth: &AuthCtx) -> Result<Statement> {
|
||||
if sql.len() > MAX_SQL_LENGTH {
|
||||
bail!("SQL query exceeds maximum allowed length: \"{sql:.120}...\"")
|
||||
}
|
||||
|
||||
@@ -508,7 +508,7 @@ impl SubscriptionPlan {
|
||||
}
|
||||
|
||||
/// Generate a plan for incrementally maintaining a subscription
|
||||
pub fn compile(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> Result<(Vec<Self>, bool)> {
|
||||
pub fn compile(sql: &str, tx: &mut impl SchemaView, auth: &AuthCtx) -> Result<(Vec<Self>, bool)> {
|
||||
let (plans, return_id, return_name, has_param) = compile_subscription(sql, tx, auth)?;
|
||||
|
||||
/// Does this plan have any non-index joins?
|
||||
|
||||
Reference in New Issue
Block a user