Files
sqlalchemy/test/sql/test_statement_params.py
Mike Bayer b9e3cacb0e add TString support
Added support for Python 3.14+ template strings (t-strings) via the new
:func:`_sql.tstring` construct, as defined in :pep:`750`. This feature
allows for ergonomic SQL statement construction by automatically
interpolating Python values and SQLAlchemy expressions within template
strings.

Part of the challenge here is that the syntax only works on py314, so we have
to exclude the test file at many levels when py314 is not used. Not
sure yet how I want to adjust the pep8 tests and rules for this.

Fixes: #12548
Change-Id: Ia060d1387ff452fe4f5d924f683529a22a8e1f72
2025-11-30 14:38:13 -05:00

203 lines
6.5 KiB
Python

import random
from sqlalchemy import testing
from sqlalchemy.schema import Column
from sqlalchemy.sql import bindparam
from sqlalchemy.sql import column
from sqlalchemy.sql import dml
from sqlalchemy.sql import func
from sqlalchemy.sql import select
from sqlalchemy.sql import text
from sqlalchemy.sql import tstring
from sqlalchemy.sql.base import ExecutableStatement
from sqlalchemy.sql.elements import literal
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_deprecated
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_not
from sqlalchemy.testing import ne_
from sqlalchemy.testing.schema import Table
from sqlalchemy.types import Integer
from sqlalchemy.types import Text
from sqlalchemy.util.compat import Template
from sqlalchemy.util.langhelpers import class_hierarchy
class BasicTests(fixtures.TestBase):
    """Structural checks for ``params()`` across executable statements.

    Every subclass of ``ExecutableStatement`` outside the excluded DML
    hierarchy must inherit ``ExecutableStatement.params`` unchanged, and
    each representative statement kind must return a new object that
    carries the supplied parameters.
    """

    def _all_subclasses(self, cls_):
        # class_hierarchy() may also yield classes that are not
        # subclasses of cls_; drop those.  dict.fromkeys() removes
        # duplicates while keeping first-seen order.
        candidates = (
            klass
            for klass in class_hierarchy(cls_)
            if issubclass(klass, cls_)
        )
        return dict.fromkeys(candidates)

    @staticmethod
    def _relevant_impls():
        # One instance per statement kind.  test_params_impl compares
        # their __visit_name__ values against the subclasses it finds.
        impls = [
            text("select 1 + 2"),
            tstring(Template("select 1 + 2")),
            text("select 42 as q").columns(column("q", Integer)),
            func.max(42),
            select(1, 2).union(select(3, 4)),
            select(1, 2),
        ]
        return tuple(impls)

    def test_params_impl(self):
        exclude = (dml.UpdateBase,)
        base_params = ExecutableStatement.params
        found = set()

        for klass in self._all_subclasses(ExecutableStatement):
            if issubclass(klass, exclude):
                # Excluded classes override params().  The override must
                # be the one defined on whichever excluded base they have.
                ne_(klass.params, base_params, klass)
                owner = next(
                    (base for base in exclude if issubclass(klass, base)),
                    None,
                )
                assert owner is not None
                eq_(klass.params, owner.params, klass)
                continue

            if "__visit_name__" in vars(klass):
                found.add(klass.__visit_name__)
            eq_(klass.params, base_params, klass)

        # "orm_from_statement" is tolerated in addition to the
        # representatives listed by _relevant_impls().
        expected = {impl.__visit_name__ for impl in self._relevant_impls()}
        eq_(found - {"orm_from_statement"}, expected)

    # _relevant_impls is still a bare staticmethod object at this point in
    # the class body.  Calling it directly relies on staticmethod objects
    # being callable (Python 3.10+).
    @testing.combinations(*_relevant_impls())
    def test_compile_params(self, impl):
        bound = {"foo": 5, "bar": 10}
        updated = impl.params(**bound)

        # params() is generative: a new object comes back, and the
        # original statement collects no parameters.
        is_not(updated, impl)
        eq_(impl.compile()._collected_params, {})
        eq_(updated.compile()._collected_params, bound)
        # NOTE(review): index 2 of the cache key is assumed to hold the
        # collected params -- confirm against _generate_cache_key().
        eq_(updated._generate_cache_key()[2], bound)
class CacheTests(fixtures.TablesTest):
    """Round-trip ``params()`` values through a real connection.

    Each test repeats its checks with freshly randomized values.  Judging
    by the class name, this is presumably meant to catch a compiled-
    statement cache handing back stale parameter values -- confirm.
    """

    __sparse_driver_backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table("a", metadata, Column("data", Integer))
        Table("b", metadata, Column("data", Text))

    @classmethod
    def insert_data(cls, connection):
        # a.data holds the integers 1..10; b.data holds "row 1".."row 10".
        connection.execute(
            cls.tables.a.insert(),
            [{"data": i} for i in range(1, 11)],
        )
        connection.execute(
            cls.tables.b.insert(),
            [{"data": "row %d" % i} for i in range(1, 11)],
        )

    def test_plain_select(self, connection):
        """params() on a plain SELECT and on a SELECT over its subquery."""
        a = self.tables.a
        cs = connection.scalars

        for _ in range(3):
            x1 = random.randint(1, 10)
            eq_(cs(select(a).where(a.c.data == x1)).all(), [x1])

            stmt = select(a).where(a.c.data == bindparam("x", x1))
            eq_(cs(stmt).all(), [x1])

            # params() replaces the bindparam's default value
            x1 = random.randint(1, 10)
            eq_(cs(stmt.params({"x": x1})).all(), [x1])

            # execution-time parameters
            x1 = random.randint(1, 10)
            eq_(cs(stmt, {"x": x1}).all(), [x1])

            # execution-time parameters win over params()
            x1 = random.randint(1, 10)
            x2 = random.randint(1, 10)
            eq_(cs(stmt.params({"x": x1}), {"x": x2}).all(), [x2])

            # params() survive wrapping in a subquery and can be
            # overridden again on the enclosing select
            stmt2 = stmt.params(x=6).subquery().select()
            eq_(cs(stmt2).all(), [6])
            eq_(cs(stmt2.params({"x": 2})).all(), [2])

            with expect_deprecated(
                r"The params\(\) and unique_params\(\) "
                "methods on non-statement"
            ):
                # NOTE: can't mix and match the two params styles here
                stmt3 = stmt.params(x=6).subquery().params(x=8).select()
            # per these assertions the statement-level x=6 is used rather
            # than the x=8 given to the deprecated subquery-level params()
            eq_(cs(stmt3).all(), [6])
            eq_(cs(stmt3.params({"x": 9})).all(), [9])

    def test_union(self, connection):
        """params() on UNION ALL statements and on selects wrapping them."""
        a = self.tables.a
        cs = connection.scalars

        for _ in range(3):
            x1 = random.randint(1, 10)
            x2 = random.randint(1, 10)
            eq_(
                cs(
                    select(a)
                    .where(a.c.data == x1)
                    .union_all(select(a).where(a.c.data == x2))
                    .order_by(a.c.data)
                ).all(),
                sorted([x1, x2]),
            )

            # the literal "ord" column pins the row order to
            # (x branch, y branch) regardless of the values chosen
            x1 = random.randint(1, 10)
            x2 = random.randint(1, 10)
            stmt = (
                select(a, literal(1).label("ord"))
                .where(a.c.data == bindparam("x", x1))
                .union_all(
                    select(a, literal(2)).where(a.c.data == bindparam("y", x2))
                )
                .order_by("ord")
            )
            eq_(cs(stmt).all(), [x1, x2])

            x1a = random.randint(1, 10)
            eq_(cs(stmt.params({"x": x1a})).all(), [x1a, x2])
            x2 = random.randint(1, 10)
            eq_(cs(stmt, {"y": x2}).all(), [x1, x2])
            x1 = random.randint(1, 10)
            x2 = random.randint(1, 10)
            eq_(cs(stmt.params({"x": x1}), {"y": x2}).all(), [x1, x2])

            x1 = random.randint(1, 10)
            x2 = random.randint(1, 10)
            stmt2 = (
                stmt.params(x=x1)
                .subquery()
                .select()
                .params(y=x2)
                .order_by("ord")
            )
            eq_(cs(stmt2).all(), [x1, x2])

            # Override both values on the outer select.  Previously this
            # re-applied x1/x2, which stmt2 already carries, so it passed
            # even if params() were ignored.  "v % 10 + 1" always differs
            # from v while staying inside the 1..10 range present in a.data.
            x1b = x1 % 10 + 1
            x2b = x2 % 10 + 1
            eq_(
                cs(stmt2.params({"x": x1b}).params({"y": x2b})).all(),
                [x1b, x2b],
            )

    def test_text(self, connection):
        """params() on text() and on its .columns() (TextualSelect) form."""
        a = self.tables.a
        cs = connection.scalars

        for _ in range(3):
            x0 = random.randint(1, 10)
            stmt = text("select data from a where data = :x").params(x=x0)
            eq_(cs(stmt).all(), [x0])
            x1 = random.randint(1, 10)
            eq_(cs(stmt.params({"x": x1})).all(), [x1])

            x2 = random.randint(1, 10)
            stmt2 = stmt.columns(a.c.data).params(x=x2)
            eq_(cs(stmt2).all(), [x2])
            eq_(cs(stmt2, {"x": 1}).all(), [1])
            eq_(cs(stmt2.params(x=1)).all(), [1])