mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-06 17:01:07 -04:00
a9e82a6abe
these were failing on Oracle which has no json type Change-Id: I3c35c4f38c581378c74ee9dc3fb7a71655c5ee9c
442 lines
14 KiB
Python
442 lines
14 KiB
Python
"""Test the TString construct for Python 3.14+ template strings."""
|
|
|
|
from itertools import zip_longest
|
|
|
|
from sqlalchemy import column
|
|
from sqlalchemy import exc
|
|
from sqlalchemy import Integer
|
|
from sqlalchemy import JSON
|
|
from sqlalchemy import literal
|
|
from sqlalchemy import select
|
|
from sqlalchemy import String
|
|
from sqlalchemy import testing
|
|
from sqlalchemy import tstring
|
|
from sqlalchemy.engine.interfaces import CacheStats
|
|
from sqlalchemy.sql import table
|
|
from sqlalchemy.sql.elements import ColumnClause
|
|
from sqlalchemy.sql.sqltypes import TypeEngine
|
|
from sqlalchemy.testing import AssertsCompiledSQL
|
|
from sqlalchemy.testing import eq_
|
|
from sqlalchemy.testing import fixtures
|
|
from sqlalchemy.testing.assertions import expect_raises_message
|
|
|
|
# Lightweight table constructs used as interpolation targets by the tests
# in this module.
table1 = table(
    "mytable",
    column("myid", Integer),
    column("name", String),
    column("description", String),
)

table2 = table(
    "myothertable",
    column("otherid", Integer),
    column("othername", String),
)
|
|
|
|
|
|
class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
|
|
__dialect__ = "default"
|
|
|
|
def test_basic_literal_interpolation(self):
|
|
a = 5
|
|
b = 10
|
|
stmt = tstring(t"select {a}, {b}")
|
|
self.assert_compile(
|
|
stmt,
|
|
"select :param_1, :param_2",
|
|
checkparams={"param_1": 5, "param_2": 10},
|
|
)
|
|
|
|
def test_no_strings(self):
|
|
with expect_raises_message(
|
|
exc.ArgumentError, r"pep-750 Tstring \(e.g. t'...'\) expected"
|
|
):
|
|
tstring("select * from table") # type: ignore
|
|
|
|
def test_tstring_literal_passthrough(self):
|
|
stmt = tstring(t"select * from foo where lala = bar")
|
|
self.assert_compile(stmt, "select * from foo where lala = bar")
|
|
|
|
def test_sqlalchemy_expression_interpolation(self):
|
|
subq = select(literal(1)).scalar_subquery()
|
|
stmt = tstring(t"SELECT {subq}")
|
|
self.assert_compile(
|
|
stmt,
|
|
"SELECT (SELECT :param_1 AS anon_1)",
|
|
checkparams={"param_1": 1},
|
|
)
|
|
|
|
def test_column_interpolation(self):
|
|
stmt = tstring(t"SELECT {table1.c.myid}, {table1.c.name} FROM mytable")
|
|
self.assert_compile(
|
|
stmt, "SELECT mytable.myid, mytable.name FROM mytable"
|
|
)
|
|
|
|
def test_column_interpolation_labeled(self):
|
|
# Labels are not supported inside tstring as they're ambiguous
|
|
# (should they render with AS in all contexts?)
|
|
label1 = table1.c.myid.label("label1")
|
|
label2 = table1.c.name.label("label2")
|
|
|
|
with expect_raises_message(
|
|
exc.CompileError,
|
|
"Using label\\(\\) directly inside tstring is not supported",
|
|
):
|
|
tstring(t"SELECT {label1}, {label2} FROM mytable").compile()
|
|
|
|
def test_arithmetic_expression(self):
|
|
# Python arithmetic is evaluated before being passed to tstring
|
|
a = 1
|
|
stmt = tstring(t"SELECT {a + 7}")
|
|
self.assert_compile(
|
|
stmt, "SELECT :param_1", checkparams={"param_1": 8}
|
|
)
|
|
|
|
def test_embed_tstring_as_select_criteria(self):
|
|
user_id = 123
|
|
stmt = select(table1).where(tstring(t"{table1.c.myid} = {user_id}"))
|
|
self.assert_compile(
|
|
stmt,
|
|
"SELECT mytable.myid, mytable.name, mytable.description FROM "
|
|
"mytable WHERE mytable.myid = :param_1",
|
|
checkparams={"param_1": 123},
|
|
)
|
|
|
|
def test_embed_tstring_as_fromclause(self):
|
|
status = "some status"
|
|
stmt = select(column("x")).select_from(
|
|
tstring(
|
|
t"foobar left outer join lala on foobar.foo = lala.foo "
|
|
t"AND foobar.status = {status}"
|
|
)
|
|
)
|
|
self.assert_compile(
|
|
stmt,
|
|
"SELECT x FROM foobar left outer join "
|
|
"lala on foobar.foo = lala.foo AND foobar.status = :param_1",
|
|
checkparams={"param_1": "some status"},
|
|
)
|
|
|
|
def test_and_operator(self):
|
|
stmt = tstring(t"1 = 1") & tstring(t"2 = 2")
|
|
self.assert_compile(stmt, "1 = 1 AND 2 = 2")
|
|
|
|
def test_multiple_literals(self):
|
|
a, b, c, d = 1, 2, 3, 4
|
|
stmt = tstring(t"SELECT {a}, {b}, {c}, {d}")
|
|
self.assert_compile(
|
|
stmt,
|
|
"SELECT :param_1, :param_2, :param_3, :param_4",
|
|
checkparams={
|
|
"param_1": 1,
|
|
"param_2": 2,
|
|
"param_3": 3,
|
|
"param_4": 4,
|
|
},
|
|
)
|
|
|
|
def test_nested_tstring_execution(self):
|
|
inner = tstring(t"(SELECT {'some value'} AS anon_1)")
|
|
self.assert_compile(
|
|
tstring(t"select {inner}"),
|
|
"select (SELECT :param_1 AS anon_1)",
|
|
checkparams={"param_1": "some value"},
|
|
)
|
|
|
|
def test_nested_scalar_subquery_execution(self):
|
|
inner = select(literal("some value")).scalar_subquery()
|
|
self.assert_compile(
|
|
tstring(t"select {inner}"),
|
|
"select (SELECT :param_1 AS anon_1)",
|
|
checkparams={"param_1": "some value"},
|
|
)
|
|
|
|
def test_nested_subquery_execution(self):
|
|
inner = select(literal("some value")).subquery()
|
|
self.assert_compile(
|
|
tstring(t"select * from {inner}"),
|
|
"select * from (SELECT :param_1 AS anon_2) AS anon_1",
|
|
checkparams={"param_1": "some value"},
|
|
)
|
|
|
|
|
|
class ColumnsTest(fixtures.TestBase, AssertsCompiledSQL):
|
|
__dialect__ = "default"
|
|
|
|
def _assert_columns(self, stmt, columns):
|
|
"""Assert that stmt.selected_columns matches the given columns.
|
|
|
|
Also verifies that the result map structure matches what we'd get
|
|
from a regular select() statement with the same columns.
|
|
"""
|
|
# Check that selected_columns matches
|
|
eq_(
|
|
[c.name for c in stmt.selected_columns],
|
|
[c.name for c in columns],
|
|
)
|
|
for stmt_col, expected_col in zip(stmt.selected_columns, columns):
|
|
eq_(stmt_col.type._type_affinity, expected_col.type._type_affinity)
|
|
|
|
# Verify result map structure matches what select() would produce
|
|
stmt_compiled = stmt.compile()
|
|
select_compiled = select(*columns).compile()
|
|
stmt_map = stmt_compiled._create_result_map()
|
|
select_map = select_compiled._create_result_map()
|
|
|
|
# Compare result map structure using recursive comparison
|
|
eq_(list(stmt_map.keys()), list(select_map.keys()))
|
|
for key in stmt_map:
|
|
stmt_entry = stmt_map[key]
|
|
select_entry = select_map[key]
|
|
# Use recursive comparison for the entire entry tuple
|
|
assert self._compare_recursive(
|
|
stmt_entry, select_entry
|
|
), f"Result map entries differ:\n {stmt_entry}\n {select_entry}"
|
|
|
|
def _compare_recursive(self, left, right):
|
|
if isinstance(left, ColumnClause) and isinstance(right, ColumnClause):
|
|
return (
|
|
left.name == right.name
|
|
and left.type._type_affinity == right.type._type_affinity
|
|
)
|
|
elif isinstance(left, TypeEngine) and isinstance(right, TypeEngine):
|
|
return left._type_affinity == right._type_affinity
|
|
elif isinstance(left, (tuple, list)) and isinstance(
|
|
right, (tuple, list)
|
|
):
|
|
return all(
|
|
self._compare_recursive(l, r)
|
|
for l, r in zip_longest(left, right)
|
|
)
|
|
else:
|
|
return left == right
|
|
|
|
def test_columns_positional(self):
|
|
cols = [column("id", Integer), column("name", String)]
|
|
stmt = tstring(t"SELECT id, name FROM users").columns(*cols)
|
|
self.assert_compile(stmt, "SELECT id, name FROM users")
|
|
self._assert_columns(stmt, cols)
|
|
|
|
def test_columns_keyword(self):
|
|
stmt = tstring(t"SELECT id, name FROM users").columns(
|
|
id=Integer, name=String
|
|
)
|
|
self.assert_compile(stmt, "SELECT id, name FROM users")
|
|
cols = [column("id", Integer), column("name", String)]
|
|
self._assert_columns(stmt, cols)
|
|
|
|
def test_columns_mixed(self):
|
|
cols = [
|
|
column("id", Integer),
|
|
column("name", String),
|
|
column("age", Integer),
|
|
]
|
|
stmt = tstring(t"SELECT id, name, age FROM users").columns(
|
|
cols[0], name=String, age=Integer
|
|
)
|
|
self.assert_compile(stmt, "SELECT id, name, age FROM users")
|
|
self._assert_columns(stmt, cols)
|
|
|
|
def test_columns_subquery(self):
|
|
stmt = (
|
|
tstring(t"SELECT id, name FROM users")
|
|
.columns(column("id", Integer), column("name", String))
|
|
.subquery("st")
|
|
)
|
|
outer = select(table1).select_from(
|
|
table1.join(stmt, table1.c.name == stmt.c.name)
|
|
)
|
|
self.assert_compile(
|
|
outer,
|
|
"SELECT mytable.myid, mytable.name, mytable.description FROM "
|
|
"mytable JOIN (SELECT id, name FROM users) AS st ON "
|
|
"mytable.name = st.name",
|
|
)
|
|
|
|
|
|
class ExecutionTest(fixtures.TestBase):
|
|
__backend__ = True
|
|
|
|
def test_basic_execution(self, connection):
|
|
a = 1
|
|
b = 2
|
|
result = connection.execute(tstring(t"select {a + 7}, {b}"))
|
|
eq_(result.all(), [(8, 2)])
|
|
|
|
@testing.requires.json_type
|
|
def test_json_literal_execution(self, connection):
|
|
some_json = {"foo": "bar"}
|
|
stmt = tstring(t"select {literal(some_json, JSON)}").columns(
|
|
column("jj", JSON)
|
|
)
|
|
result = connection.execute(stmt)
|
|
row = result.scalar()
|
|
eq_(row, {"foo": "bar"})
|
|
|
|
@testing.requires.json_type
|
|
def test_statement_caching(self, connection):
|
|
"""Test that tstring statements are properly cached."""
|
|
some_json = {"foo": "bar"}
|
|
stmt1 = tstring(t"select {literal(some_json, JSON)}").columns(
|
|
column("jj", JSON)
|
|
)
|
|
result1 = connection.execute(stmt1)
|
|
eq_(result1.scalar(), {"foo": "bar"})
|
|
|
|
# Execute same structure with different value
|
|
some_json = {"foo": "newbar", "bat": "hoho"}
|
|
stmt2 = tstring(t"select {literal(some_json, JSON)}").columns(
|
|
column("jj", JSON)
|
|
)
|
|
result2 = connection.execute(stmt2)
|
|
|
|
# Should hit cache
|
|
if hasattr(result2.context, "cache_hit"):
|
|
eq_(result2.context.cache_hit, CacheStats.CACHE_HIT)
|
|
|
|
eq_(result2.scalar(), {"foo": "newbar", "bat": "hoho"})
|
|
|
|
def test_nested_scalar_subquery_execution(self, connection):
|
|
inner = select(literal("some value")).scalar_subquery()
|
|
result = connection.execute(tstring(t"select {inner}"))
|
|
eq_(result.all(), [("some value",)])
|
|
|
|
def test_nested_subquery_execution(self, connection):
|
|
inner = select(literal("some value")).subquery()
|
|
result = connection.execute(tstring(t"select * from {inner}"))
|
|
eq_(result.all(), [("some value",)])
|
|
|
|
def test_multiple_values(self, connection):
|
|
values = [1, 2, 3, 4, 5]
|
|
result = connection.execute(
|
|
tstring(
|
|
t"select {values[0]}, {values[1]}, {values[2]}, "
|
|
t"{values[3]}, {values[4]}"
|
|
)
|
|
)
|
|
eq_(result.all(), [(1, 2, 3, 4, 5)])
|
|
|
|
|
|
class IntegrationTest(fixtures.TablesTest):
|
|
__backend__ = True
|
|
|
|
@classmethod
|
|
def define_tables(cls, metadata):
|
|
from sqlalchemy import Column
|
|
from sqlalchemy import Table
|
|
|
|
Table(
|
|
"users",
|
|
metadata,
|
|
Column("id", Integer, primary_key=True),
|
|
Column("name", String(50)),
|
|
)
|
|
|
|
@classmethod
|
|
def insert_data(cls, connection):
|
|
connection.execute(
|
|
cls.tables.users.insert(),
|
|
[
|
|
{"id": 1, "name": "alice"},
|
|
{"id": 2, "name": "bob"},
|
|
{"id": 3, "name": "charlie"},
|
|
],
|
|
)
|
|
|
|
def test_select_from_real_table(self, connection):
|
|
user_id = 2
|
|
stmt = tstring(t"SELECT * FROM users WHERE id = {user_id}").columns(
|
|
column("id", Integer), column("name", String)
|
|
)
|
|
result = connection.execute(stmt)
|
|
row = result.one()
|
|
eq_(row.id, 2)
|
|
eq_(row.name, "bob")
|
|
|
|
def test_where_clause_with_real_table(self, connection):
|
|
users = self.tables.users
|
|
name_filter = "alice"
|
|
stmt = select(users).where(
|
|
tstring(t"{users.c.name} = {literal(name_filter)}")
|
|
)
|
|
result = connection.execute(stmt)
|
|
row = result.one()
|
|
eq_(row.id, 1)
|
|
eq_(row.name, "alice")
|
|
|
|
def test_complex_query(self, connection):
|
|
min_id = 1
|
|
max_id = 2
|
|
stmt = tstring(
|
|
t"SELECT id, name FROM users WHERE id >= {min_id} "
|
|
t"AND id <= {max_id}"
|
|
).columns(column("id", Integer), column("name", String))
|
|
result = connection.execute(stmt)
|
|
rows = result.all()
|
|
eq_(len(rows), 2)
|
|
eq_(rows[0].name, "alice")
|
|
eq_(rows[1].name, "bob")
|
|
|
|
|
|
class CacheKeyTest(fixtures.CacheKeyFixture, fixtures.TestBase):
|
|
"""Test cache key generation for tstring constructs."""
|
|
|
|
@fixtures.CacheKeySuite.run_suite_tests
|
|
def test_tstring_cache_key(self):
|
|
|
|
def stmt1():
|
|
# Basic tstring with literal
|
|
a = 5
|
|
return tstring(t"SELECT {a}")
|
|
|
|
def stmt2():
|
|
# Different structure - two literals
|
|
a = 5
|
|
b = 10
|
|
return tstring(t"SELECT {a}, {b}")
|
|
|
|
def stmt3():
|
|
# With column reference
|
|
return tstring(t"SELECT {table1.c.myid}")
|
|
|
|
def stmt4():
|
|
# Different column - different cache key
|
|
return tstring(t"SELECT {table1.c.name}")
|
|
|
|
def stmt5():
|
|
# With .columns()
|
|
a = 5
|
|
return tstring(t"SELECT {a}").columns(column("val", Integer))
|
|
|
|
def stmt6():
|
|
# String literal passthrough
|
|
return tstring(t"SELECT * FROM users")
|
|
|
|
def stmt7():
|
|
# Different string literal
|
|
return tstring(t"SELECT id FROM users")
|
|
|
|
def stmt8():
|
|
# With SQLAlchemy scalar subquery
|
|
return tstring(t"SELECT {select(literal(1)).scalar_subquery()}")
|
|
|
|
def stmt9():
|
|
# Mixed: text and literal
|
|
user_id = 42
|
|
return tstring(t"SELECT * FROM users WHERE id = {user_id}")
|
|
|
|
def stmt10():
|
|
# Mixed: text and column
|
|
return tstring(t"SELECT * FROM users WHERE id = {table1.c.myid}")
|
|
|
|
return lambda: [
|
|
stmt1(),
|
|
stmt2(),
|
|
stmt3(),
|
|
stmt4(),
|
|
stmt5(),
|
|
stmt6(),
|
|
stmt7(),
|
|
stmt8(),
|
|
stmt9(),
|
|
stmt10(),
|
|
]
|