import random

from sqlalchemy import testing
from sqlalchemy.schema import Column
from sqlalchemy.sql import bindparam
from sqlalchemy.sql import column
from sqlalchemy.sql import dml
from sqlalchemy.sql import func
from sqlalchemy.sql import select
from sqlalchemy.sql import text
from sqlalchemy.sql import tstring
from sqlalchemy.sql.base import ExecutableStatement
from sqlalchemy.sql.elements import literal
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_deprecated
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_not
from sqlalchemy.testing import ne_
from sqlalchemy.testing.schema import Table
from sqlalchemy.types import Integer
from sqlalchemy.types import Text
from sqlalchemy.util.compat import Template
from sqlalchemy.util.langhelpers import class_hierarchy


class BasicTests(fixtures.TestBase):
    """Structural checks for ``params()`` on ``ExecutableStatement``.

    These tests inspect class attributes, compiled state and cache keys;
    they do not execute anything against a database.
    """

    def _all_subclasses(self, cls_):
        """Return the subclasses of ``cls_`` reported by ``class_hierarchy``.

        The result is a dict used as an insertion-ordered set, so each class
        appears once, in discovery order.
        """
        return dict.fromkeys(
            s
            for s in class_hierarchy(cls_)
            # class_hierarchy may return values that
            # aren't subclasses of cls
            if issubclass(s, cls_)
        )

    @staticmethod
    def _relevant_impls():
        """Return one sample statement per expected ``__visit_name__``.

        ``test_params_impl`` requires the visit names of these objects to
        equal the visit names declared directly on the non-``UpdateBase``
        ``ExecutableStatement`` subclasses (apart from
        ``orm_from_statement``); ``test_compile_params`` runs once per
        object returned here.
        """
        return (
            text("select 1 + 2"),
            tstring(Template("select 1 + 2")),
            text("select 42 as q").columns(column("q", Integer)),
            func.max(42),
            select(1, 2).union(select(3, 4)),
            select(1, 2),
        )

    def test_params_impl(self):
        """``params()`` is shared by every statement class except DML.

        Non-``UpdateBase`` subclasses must inherit
        ``ExecutableStatement.params`` unchanged, while ``UpdateBase``
        subclasses must use ``UpdateBase.params`` instead.  The
        ``__visit_name__`` values declared directly on the non-excluded
        classes (minus ``orm_from_statement``) must equal those produced by
        ``_relevant_impls()``, so a new statement type makes this test fail
        until that list is extended.
        """
        exclude = (dml.UpdateBase,)
        visit_names = set()
        for cls_ in self._all_subclasses(ExecutableStatement):
            if not issubclass(cls_, exclude):
                if "__visit_name__" in cls_.__dict__:
                    visit_names.add(cls_.__visit_name__)
                eq_(cls_.params, ExecutableStatement.params, cls_)
            else:
                ne_(cls_.params, ExecutableStatement.params, cls_)
                for other in exclude:
                    if issubclass(cls_, other):
                        eq_(cls_.params, other.params, cls_)
                        break
                else:
                    # unreachable: cls_ already matched ``exclude`` above
                    assert False
        # orm_from_statement is tolerated without a matching sample in
        # _relevant_impls()
        extra = {"orm_from_statement"}
        eq_(
            visit_names - extra,
            {i.__visit_name__ for i in self._relevant_impls()},
        )

    # While the class body executes, _relevant_impls is still the raw
    # staticmethod object; calling it directly here relies on staticmethod
    # objects being callable (Python 3.10+).
    @testing.combinations(*_relevant_impls())
    def test_compile_params(self, impl):
        """``params()`` returns a new object and only the copy has values.

        The values must show up both in the compiled form's
        ``_collected_params`` and in element ``[2]`` of the cache key.
        """
        new = impl.params(foo=5, bar=10)
        is_not(new, impl)
        eq_(impl.compile()._collected_params, {})
        eq_(new.compile()._collected_params, {"foo": 5, "bar": 10})
        eq_(new._generate_cache_key()[2], {"foo": 5, "bar": 10})


class CacheTests(fixtures.TablesTest):
    """Execute statements whose bound values come from ``params()`` and/or
    execution-time parameters, and check the rows returned.

    Each test repeats its checks three times with random values; presumably
    this is so that a compiled form reused from the statement cache would
    surface stale parameter values -- confirm against the cache design.
    """

    __sparse_driver_backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table("a", metadata, Column("data", Integer))
        Table("b", metadata, Column("data", Text))

    @classmethod
    def insert_data(cls, connection):
        # a.data holds the integers 1..10; b.data holds "row 1".."row 10"
        connection.execute(
            cls.tables.a.insert(),
            [{"data": i} for i in range(1, 11)],
        )
        connection.execute(
            cls.tables.b.insert(),
            [{"data": "row %d" % i} for i in range(1, 11)],
        )

    def test_plain_select(self, connection):
        """Bind values on a plain SELECT, directly and via a subquery."""
        a = self.tables.a
        cs = connection.scalars
        for _ in range(3):
            x1 = random.randint(1, 10)
            eq_(cs(select(a).where(a.c.data == x1)).all(), [x1])

            # the bindparam's own default value is used
            stmt = select(a).where(a.c.data == bindparam("x", x1))
            eq_(cs(stmt).all(), [x1])
            # params() overrides the default
            x1 = random.randint(1, 10)
            eq_(cs(stmt.params({"x": x1})).all(), [x1])
            # so does an execution-time parameter
            x1 = random.randint(1, 10)
            eq_(cs(stmt, {"x": x1}).all(), [x1])
            # execution-time parameters take precedence over params()
            x1 = random.randint(1, 10)
            x2 = random.randint(1, 10)
            eq_(cs(stmt.params({"x": x1}), {"x": x2}).all(), [x2])

            # params() applied before wrapping in a subquery is retained,
            # and params() on the outer select overrides it
            stmt2 = stmt.params(x=6).subquery().select()
            eq_(cs(stmt2).all(), [6])
            eq_(cs(stmt2.params({"x": 2})).all(), [2])

            with expect_deprecated(
                r"The params\(\) and unique_params\(\) "
                "methods on non-statement"
            ):
                # NOTE: can't mix and match the two params styles here
                stmt3 = stmt.params(x=6).subquery().params(x=8).select()
            eq_(cs(stmt3).all(), [6])
            eq_(cs(stmt3.params({"x": 9})).all(), [9])

    def test_union(self, connection):
        """Bind values inside a UNION ALL, directly and via a subquery."""
        a = self.tables.a
        cs = connection.scalars
        for _ in range(3):
            x1 = random.randint(1, 10)
            x2 = random.randint(1, 10)
            eq_(
                cs(
                    select(a)
                    .where(a.c.data == x1)
                    .union_all(select(a).where(a.c.data == x2))
                    .order_by(a.c.data)
                ).all(),
                sorted([x1, x2]),
            )

            # the "ord" column pins row order: the "x" branch comes first,
            # then the "y" branch
            x1 = random.randint(1, 10)
            x2 = random.randint(1, 10)
            stmt = (
                select(a, literal(1).label("ord"))
                .where(a.c.data == bindparam("x", x1))
                .union_all(
                    select(a, literal(2)).where(a.c.data == bindparam("y", x2))
                )
                .order_by("ord")
            )
            eq_(cs(stmt).all(), [x1, x2])
            # x1a leaves x1 untouched because stmt's default for "x" is
            # still expected by the following assertion
            x1a = random.randint(1, 10)
            eq_(cs(stmt.params({"x": x1a})).all(), [x1a, x2])
            x2 = random.randint(1, 10)
            eq_(cs(stmt, {"y": x2}).all(), [x1, x2])
            # params() and execution-time parameters combine when they
            # target different names
            x1 = random.randint(1, 10)
            x2 = random.randint(1, 10)
            eq_(cs(stmt.params({"x": x1}), {"y": x2}).all(), [x1, x2])

            x1 = random.randint(1, 10)
            x2 = random.randint(1, 10)
            stmt2 = (
                stmt.params(x=x1)
                .subquery()
                .select()
                .params(y=x2)
                .order_by("ord")
            )
            eq_(cs(stmt2).all(), [x1, x2])
            # draw fresh values: re-applying the x1 / x2 already carried by
            # stmt2 would pass even if these params() calls were ignored
            x1 = random.randint(1, 10)
            x2 = random.randint(1, 10)
            eq_(cs(stmt2.params({"x": x1}).params({"y": x2})).all(), [x1, x2])

    def test_text(self, connection):
        """Bind values on ``text()`` and on ``text().columns()``."""
        a = self.tables.a
        cs = connection.scalars
        for _ in range(3):
            x0 = random.randint(1, 10)
            stmt = text("select data from a where data = :x").params(x=x0)
            eq_(cs(stmt).all(), [x0])
            x1 = random.randint(1, 10)
            eq_(cs(stmt.params({"x": x1})).all(), [x1])

            # the x0 applied before .columns() is overridden by a later
            # params() call, and that in turn by execution-time parameters
            x2 = random.randint(1, 10)
            stmt2 = stmt.columns(a.c.data).params(x=x2)
            eq_(cs(stmt2).all(), [x2])
            eq_(cs(stmt2, {"x": 1}).all(), [1])
            eq_(cs(stmt2.params(x=1)).all(), [1])