mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-15 13:17:24 -04:00
Fix adaption in AnnotatedLabel; repair needless expense in coercion
Fixed regression involving clause adaption of labeled ORM compound elements, such as single-table inheritance discriminator expressions with conditionals or CASE expressions, which could cause aliased expressions such as those used in ORM join / joinedload operations to not be adapted correctly, such as referring to the wrong table in the ON clause in a join. This change also improves a performance bump that was located within the process of invoking :meth:`_sql.Select.join` given an ORM attribute as a target. Fixes: #6550 Change-Id: I98906476f0cce6f41ea00b77c789baa818e9d167
This commit is contained in:
+13
@@ -0,0 +1,13 @@
|
||||
.. change::
|
||||
:tags: bug, orm, regression
|
||||
:tickets: 6550
|
||||
|
||||
Fixed regression involving clause adaption of labeled ORM compound
|
||||
elements, such as single-table inheritance discriminator expressions with
|
||||
conditionals or CASE expressions, which could cause aliased expressions
|
||||
such as those used in ORM join / joinedload operations to not be adapted
|
||||
correctly, such as referring to the wrong table in the ON clause in a join.
|
||||
|
||||
This change also improves a performance bump that was located within the
|
||||
process of invoking :meth:`_sql.Select.join` given an ORM attribute
|
||||
as a target.
|
||||
@@ -151,12 +151,25 @@ def expect(
|
||||
|
||||
is_clause_element = False
|
||||
|
||||
while hasattr(element, "__clause_element__"):
|
||||
# this is a special performance optimization for ORM
|
||||
# joins used by JoinTargetImpl that we don't go through the
|
||||
# work of creating __clause_element__() when we only need the
|
||||
# original QueryableAttribute, as the former will do clause
|
||||
# adaption and all that which is just thrown away here.
|
||||
if (
|
||||
impl._skip_clauseelement_for_target_match
|
||||
and isinstance(element, role)
|
||||
and hasattr(element, "__clause_element__")
|
||||
):
|
||||
is_clause_element = True
|
||||
if not getattr(element, "is_clause_element", False):
|
||||
element = element.__clause_element__()
|
||||
else:
|
||||
break
|
||||
else:
|
||||
while hasattr(element, "__clause_element__"):
|
||||
is_clause_element = True
|
||||
|
||||
if not getattr(element, "is_clause_element", False):
|
||||
element = element.__clause_element__()
|
||||
else:
|
||||
break
|
||||
|
||||
if not is_clause_element:
|
||||
if impl._use_inspection:
|
||||
@@ -230,6 +243,7 @@ class RoleImpl(object):
|
||||
|
||||
_post_coercion = None
|
||||
_resolve_literal_only = False
|
||||
_skip_clauseelement_for_target_match = False
|
||||
|
||||
def __init__(self, role_class):
|
||||
self._role_class = role_class
|
||||
@@ -860,6 +874,8 @@ class HasCTEImpl(ReturnsRowsImpl):
|
||||
class JoinTargetImpl(RoleImpl):
|
||||
__slots__ = ()
|
||||
|
||||
_skip_clauseelement_for_target_match = True
|
||||
|
||||
def _literal_coercion(self, element, legacy=False, **kw):
|
||||
if isinstance(element, str):
|
||||
return element
|
||||
|
||||
@@ -4395,6 +4395,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
|
||||
return self.element.foreign_keys
|
||||
|
||||
def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw):
|
||||
self._reset_memoizations()
|
||||
self._element = clone(self._element, **kw)
|
||||
if anonymize_labels:
|
||||
self.name = self._resolve_label = _anonymous_label.safe_construct(
|
||||
|
||||
@@ -737,7 +737,6 @@ class HasCopyInternals(object):
|
||||
continue
|
||||
|
||||
if obj is not None:
|
||||
|
||||
result = meth(attrname, self, obj, **kw)
|
||||
if result is not None:
|
||||
setattr(self, attrname, result)
|
||||
|
||||
@@ -2,6 +2,7 @@ from sqlalchemy import and_
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import join
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import testing
|
||||
from sqlalchemy.orm import aliased
|
||||
@@ -966,6 +967,21 @@ class JoinConditionTest(NoCache, fixtures.DeclarativeMappedTest):
|
||||
|
||||
go()
|
||||
|
||||
def test_a_to_b_aliased_select_join(self):
|
||||
A, B = self.classes("A", "B")
|
||||
|
||||
b1 = aliased(B)
|
||||
|
||||
stmt = select(A)
|
||||
|
||||
@profiling.function_call_count(times=50, warmup=1)
|
||||
def go():
|
||||
# should not do any adaption or aliasing, this is just getting
|
||||
# the args. See #6550 where we also fixed this.
|
||||
stmt.join(A.b.of_type(b1))
|
||||
|
||||
go()
|
||||
|
||||
def test_a_to_d(self):
|
||||
A, D = self.classes("A", "D")
|
||||
|
||||
|
||||
@@ -273,6 +273,10 @@ test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased x86_64_linux_c
|
||||
test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased x86_64_linux_cpython_3.9_sqlite_pysqlite_dbapiunicode_cextensions 10304
|
||||
test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased x86_64_linux_cpython_3.9_sqlite_pysqlite_dbapiunicode_nocextensions 10454
|
||||
|
||||
# TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased_select_join
|
||||
|
||||
test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased_select_join x86_64_linux_cpython_3.9_sqlite_pysqlite_dbapiunicode_cextensions 1104
|
||||
|
||||
# TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_plain
|
||||
|
||||
test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_plain x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_cextensions 4053
|
||||
|
||||
@@ -458,6 +458,34 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
|
||||
select(f), "SELECT t1_1.col1 * :col1_1 AS anon_1 FROM t1 AS t1_1"
|
||||
)
|
||||
|
||||
@testing.combinations(
|
||||
(lambda t1: t1.c.col1, "t1_1.col1"),
|
||||
(lambda t1: t1.c.col1 == "foo", "t1_1.col1 = :col1_1"),
|
||||
(
|
||||
lambda t1: case((t1.c.col1 == "foo", "bar"), else_=t1.c.col1),
|
||||
"CASE WHEN (t1_1.col1 = :col1_1) THEN :param_1 ELSE t1_1.col1 END",
|
||||
),
|
||||
argnames="case, expected",
|
||||
)
|
||||
@testing.combinations(False, True, argnames="label_")
|
||||
@testing.combinations(False, True, argnames="annotate")
|
||||
def test_annotated_label_cases(self, case, expected, label_, annotate):
|
||||
"""test #6550"""
|
||||
|
||||
t1 = table("t1", column("col1"))
|
||||
a1 = t1.alias()
|
||||
|
||||
expr = case(t1=t1)
|
||||
|
||||
if label_:
|
||||
expr = expr.label(None)
|
||||
if annotate:
|
||||
expr = expr._annotate({"foo": "bar"})
|
||||
|
||||
adapted = sql_util.ClauseAdapter(a1).traverse(expr)
|
||||
|
||||
self.assert_compile(adapted, expected)
|
||||
|
||||
@testing.combinations((null(),), (true(),))
|
||||
def test_dont_adapt_singleton_elements(self, elem):
|
||||
"""test :ticket:`6259`"""
|
||||
|
||||
Reference in New Issue
Block a user