mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-07 01:10:52 -04:00
b9e3cacb0e
Added support for Python 3.14+ template strings (t-strings) via the new :func:`_sql.tstring` construct, as defined in :pep:`750`. This feature allows for ergonomic SQL statement construction by automatically interpolating Python values and SQLAlchemy expressions within template strings. Part of the challenge here is that the syntax only works on py314, so we have to exclude the test file at many levels when py314 is not used. Not sure yet how I want to adjust pep8 tests and rules for this. Fixes: #12548 Change-Id: Ia060d1387ff452fe4f5d924f683529a22a8e1f72
203 lines
6.5 KiB
Python
203 lines
6.5 KiB
Python
import random
|
|
|
|
from sqlalchemy import testing
|
|
from sqlalchemy.schema import Column
|
|
from sqlalchemy.sql import bindparam
|
|
from sqlalchemy.sql import column
|
|
from sqlalchemy.sql import dml
|
|
from sqlalchemy.sql import func
|
|
from sqlalchemy.sql import select
|
|
from sqlalchemy.sql import text
|
|
from sqlalchemy.sql import tstring
|
|
from sqlalchemy.sql.base import ExecutableStatement
|
|
from sqlalchemy.sql.elements import literal
|
|
from sqlalchemy.testing import eq_
|
|
from sqlalchemy.testing import expect_deprecated
|
|
from sqlalchemy.testing import fixtures
|
|
from sqlalchemy.testing import is_not
|
|
from sqlalchemy.testing import ne_
|
|
from sqlalchemy.testing.schema import Table
|
|
from sqlalchemy.types import Integer
|
|
from sqlalchemy.types import Text
|
|
from sqlalchemy.util.compat import Template
|
|
from sqlalchemy.util.langhelpers import class_hierarchy
|
|
|
|
|
|
class BasicTests(fixtures.TestBase):
    """Checks of ``params()`` across ``ExecutableStatement`` subclasses.

    Covers which classes share the ``ExecutableStatement.params``
    attribute and how values given to ``params()`` show up in the
    compiled statement and in the generated cache key.
    """

    def _all_subclasses(self, cls_):
        """Return a dict keyed by the subclasses of ``cls_``.

        The keys are the entries of ``class_hierarchy(cls_)`` that pass
        ``issubclass(s, cls_)``; every value is ``None``.
        """
        return dict.fromkeys(
            s
            for s in class_hierarchy(cls_)
            # class_hierarchy may return values that
            # aren't subclasses of cls
            if issubclass(s, cls_)
        )

    @staticmethod
    def _relevant_impls():
        """Return the sample constructs used by the tests of this class.

        ``test_params_impl`` compares their ``__visit_name__`` values
        with the ones it collects; ``test_compile_params`` runs once per
        entry.
        """
        return (
            text("select 1 + 2"),
            tstring(Template("select 1 + 2")),
            text("select 42 as q").columns(column("q", Integer)),
            func.max(42),
            select(1, 2).union(select(3, 4)),
            select(1, 2),
        )

    def test_params_impl(self):
        """Check which classes use ``ExecutableStatement.params``.

        Classes outside the ``dml.UpdateBase`` hierarchy must expose the
        same ``params`` attribute as ``ExecutableStatement``; classes
        inside it must expose a different one, equal to the ``params``
        of the excluded base they derive from.  The ``__visit_name__``
        values declared directly on the non-excluded classes, minus
        ``"orm_from_statement"``, must equal those of the constructs
        returned by ``_relevant_impls()``.
        """
        exclude = (dml.UpdateBase,)
        visit_names = set()
        for cls_ in self._all_subclasses(ExecutableStatement):
            if not issubclass(cls_, exclude):
                # only names declared on the class itself are collected,
                # not inherited ones
                if "__visit_name__" in cls_.__dict__:
                    visit_names.add(cls_.__visit_name__)
                eq_(cls_.params, ExecutableStatement.params, cls_)
            else:
                ne_(cls_.params, ExecutableStatement.params, cls_)
                for other in exclude:
                    if issubclass(cls_, other):
                        eq_(cls_.params, other.params, cls_)
                        break
                else:
                    # Raise explicitly instead of ``assert False``: a bare
                    # assert is dropped when Python runs with -O.
                    raise AssertionError(
                        "%r is not a subclass of any of %r"
                        % (cls_, exclude)
                    )

        extra = {"orm_from_statement"}
        eq_(
            visit_names - extra,
            {i.__visit_name__ for i in self._relevant_impls()},
        )

    @testing.combinations(*_relevant_impls())
    def test_compile_params(self, impl):
        """Check that ``params()`` values reach compilation and cache key.

        ``params()`` must return a new object; the original compiles
        with empty ``_collected_params`` while the new object's values
        appear both in ``_collected_params`` and at index 2 of the
        generated cache key.
        """
        new = impl.params(foo=5, bar=10)
        is_not(new, impl)
        eq_(impl.compile()._collected_params, {})
        eq_(new.compile()._collected_params, {"foo": 5, "bar": 10})
        eq_(new._generate_cache_key()[2], {"foo": 5, "bar": 10})
|
|
|
|
|
|
class CacheTests(fixtures.TablesTest):
    """Execute statements repeatedly with changing parameter values.

    Each test loops three times with fresh random values and checks
    that the rows returned reflect the values supplied through bound
    parameter defaults, ``params()`` and execution-time parameters.
    """

    __sparse_driver_backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        """Create tables ``a`` (Integer ``data``) and ``b`` (Text ``data``)."""
        for table_name, column_type in (("a", Integer), ("b", Text)):
            Table(table_name, metadata, Column("data", column_type))

    @classmethod
    def insert_data(cls, connection):
        """Insert ten rows into each table: 1..10 and ``row 1``..``row 10``."""
        numbers = range(1, 11)
        connection.execute(
            cls.tables.a.insert(),
            [{"data": number} for number in numbers],
        )
        connection.execute(
            cls.tables.b.insert(),
            [{"data": "row %d" % number} for number in numbers],
        )

    def test_plain_select(self, connection):
        """Run a SELECT on table ``a`` with parameters given several ways."""
        tbl = self.tables.a
        scalars = connection.scalars

        for _ in range(3):
            value = random.randint(1, 10)

            # plain Python value compared directly against the column
            direct = select(tbl).where(tbl.c.data == value)
            eq_(scalars(direct).all(), [value])

            # explicit bindparam "x" carrying ``value`` as its value
            stmt = select(tbl).where(tbl.c.data == bindparam("x", value))
            eq_(scalars(stmt).all(), [value])

            # value replaced through params()
            value = random.randint(1, 10)
            with_params = stmt.params({"x": value})
            eq_(scalars(with_params).all(), [value])

            # value supplied at execution time
            value = random.randint(1, 10)
            eq_(scalars(stmt, {"x": value}).all(), [value])

            # both given: the expected row is the execution-time value
            value = random.randint(1, 10)
            exec_value = random.randint(1, 10)
            with_params = stmt.params({"x": value})
            eq_(scalars(with_params, {"x": exec_value}).all(), [exec_value])

            stmt2 = stmt.params(x=6).subquery().select()
            eq_(scalars(stmt2).all(), [6])
            eq_(scalars(stmt2.params({"x": 2})).all(), [2])

            # NOTE(review): the extent of this ``with`` body was
            # reconstructed from flattened source; confirm upstream.
            with expect_deprecated(
                r"The params\(\) and unique_params\(\) "
                "methods on non-statement"
            ):
                # NOTE: can't mix and match the two params styles here
                stmt3 = stmt.params(x=6).subquery().params(x=8).select()
                eq_(scalars(stmt3).all(), [6])
                eq_(scalars(stmt3.params({"x": 9})).all(), [9])

    def test_union(self, connection):
        """Run UNION ALL statements on ``a`` with parameters on each side."""
        tbl = self.tables.a
        scalars = connection.scalars

        for _ in range(3):
            x1 = random.randint(1, 10)
            x2 = random.randint(1, 10)

            left = select(tbl).where(tbl.c.data == x1)
            right = select(tbl).where(tbl.c.data == x2)
            plain_union = left.union_all(right).order_by(tbl.c.data)
            eq_(scalars(plain_union).all(), sorted([x1, x2]))

            x1 = random.randint(1, 10)
            x2 = random.randint(1, 10)
            # the "ord" column keeps the first SELECT's row ahead of the
            # second one regardless of the data values
            first = select(tbl, literal(1).label("ord")).where(
                tbl.c.data == bindparam("x", x1)
            )
            second = select(tbl, literal(2)).where(
                tbl.c.data == bindparam("y", x2)
            )
            stmt = first.union_all(second).order_by("ord")
            eq_(scalars(stmt).all(), [x1, x2])

            # override only "x" through params()
            x1_alt = random.randint(1, 10)
            eq_(scalars(stmt.params({"x": x1_alt})).all(), [x1_alt, x2])

            # override only "y" at execution time
            x2 = random.randint(1, 10)
            eq_(scalars(stmt, {"y": x2}).all(), [x1, x2])

            # "x" through params(), "y" at execution time
            x1 = random.randint(1, 10)
            x2 = random.randint(1, 10)
            with_x = stmt.params({"x": x1})
            eq_(scalars(with_x, {"y": x2}).all(), [x1, x2])

            x1 = random.randint(1, 10)
            x2 = random.randint(1, 10)
            inner = stmt.params(x=x1).subquery()
            stmt2 = inner.select().params(y=x2).order_by("ord")
            eq_(scalars(stmt2).all(), [x1, x2])
            chained = stmt2.params({"x": x1}).params({"y": x2})
            eq_(scalars(chained).all(), [x1, x2])

    def test_text(self, connection):
        """Run a ``text()`` statement with ``:x`` supplied several ways."""
        tbl = self.tables.a
        scalars = connection.scalars

        for _ in range(3):
            first = random.randint(1, 10)
            base = text("select data from a where data = :x")
            stmt = base.params(x=first)
            eq_(scalars(stmt).all(), [first])

            second = random.randint(1, 10)
            eq_(scalars(stmt.params({"x": second})).all(), [second])

            third = random.randint(1, 10)
            stmt2 = stmt.columns(tbl.c.data).params(x=third)
            eq_(scalars(stmt2).all(), [third])
            eq_(scalars(stmt2, {"x": 1}).all(), [1])
            eq_(scalars(stmt2.params(x=1)).all(), [1])
|