mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-15 05:07:16 -04:00
60e7034a74
The "expanding IN" feature, which generates IN expressions at query execution time which are based on the particular parameters associated with the statement execution, is now used for all IN expressions made against lists of literal values. This allows IN expressions to be fully cacheable independently of the list of values being passed, and also includes support for empty lists. For any scenario where the IN expression contains non-literal SQL expressions, the old behavior of pre-rendering for each position in the IN is maintained. The change also completes support for expanding IN with tuples, where previously type-specific bind processors weren't taking effect. As part of this change, a more explicit separation between "literal execute" and "post compile" bound parameters is being made; as the "ansi bind rules" feature is rendering bound parameters inline, as we now support "postcompile" generically, these should be used here, however we have to render literal values at execution time even for "expanding" parameters. new test fixtures etc. are added to assert everything goes to the right place. Fixes: #4645 Change-Id: Iaa2b7bfbfaaf5b80799ee17c9b8507293cba6ed1
962 lines
34 KiB
Python
962 lines
34 KiB
Python
import importlib
|
|
import itertools
|
|
import random
|
|
|
|
from sqlalchemy import and_
|
|
from sqlalchemy import Boolean
|
|
from sqlalchemy import case
|
|
from sqlalchemy import cast
|
|
from sqlalchemy import Column
|
|
from sqlalchemy import column
|
|
from sqlalchemy import dialects
|
|
from sqlalchemy import exists
|
|
from sqlalchemy import extract
|
|
from sqlalchemy import Float
|
|
from sqlalchemy import Integer
|
|
from sqlalchemy import MetaData
|
|
from sqlalchemy import or_
|
|
from sqlalchemy import select
|
|
from sqlalchemy import String
|
|
from sqlalchemy import Table
|
|
from sqlalchemy import table
|
|
from sqlalchemy import text
|
|
from sqlalchemy import tuple_
|
|
from sqlalchemy import union
|
|
from sqlalchemy import union_all
|
|
from sqlalchemy import util
|
|
from sqlalchemy.schema import Sequence
|
|
from sqlalchemy.sql import bindparam
|
|
from sqlalchemy.sql import ColumnElement
|
|
from sqlalchemy.sql import False_
|
|
from sqlalchemy.sql import func
|
|
from sqlalchemy.sql import operators
|
|
from sqlalchemy.sql import True_
|
|
from sqlalchemy.sql import type_coerce
|
|
from sqlalchemy.sql import visitors
|
|
from sqlalchemy.sql.base import HasCacheKey
|
|
from sqlalchemy.sql.elements import _label_reference
|
|
from sqlalchemy.sql.elements import _textual_label_reference
|
|
from sqlalchemy.sql.elements import Annotated
|
|
from sqlalchemy.sql.elements import ClauseElement
|
|
from sqlalchemy.sql.elements import ClauseList
|
|
from sqlalchemy.sql.elements import CollationClause
|
|
from sqlalchemy.sql.elements import Immutable
|
|
from sqlalchemy.sql.elements import Null
|
|
from sqlalchemy.sql.elements import Slice
|
|
from sqlalchemy.sql.elements import UnaryExpression
|
|
from sqlalchemy.sql.functions import FunctionElement
|
|
from sqlalchemy.sql.functions import GenericFunction
|
|
from sqlalchemy.sql.functions import ReturnTypeFromArgs
|
|
from sqlalchemy.sql.selectable import _OffsetLimitParam
|
|
from sqlalchemy.sql.selectable import AliasedReturnsRows
|
|
from sqlalchemy.sql.selectable import FromGrouping
|
|
from sqlalchemy.sql.selectable import Selectable
|
|
from sqlalchemy.sql.selectable import SelectStatementGrouping
|
|
from sqlalchemy.sql.visitors import InternalTraversal
|
|
from sqlalchemy.testing import eq_
|
|
from sqlalchemy.testing import fixtures
|
|
from sqlalchemy.testing import is_false
|
|
from sqlalchemy.testing import is_not_
|
|
from sqlalchemy.testing import is_true
|
|
from sqlalchemy.testing import ne_
|
|
from sqlalchemy.testing.util import random_choices
|
|
from sqlalchemy.util import class_hierarchy
|
|
|
|
|
|
meta = MetaData()
|
|
meta2 = MetaData()
|
|
|
|
table_a = Table("a", meta, Column("a", Integer), Column("b", String))
|
|
table_b_like_a = Table("b2", meta, Column("a", Integer), Column("b", String))
|
|
|
|
table_a_2 = Table("a", meta2, Column("a", Integer), Column("b", String))
|
|
|
|
table_a_2_fs = Table(
|
|
"a", meta2, Column("a", Integer), Column("b", String), schema="fs"
|
|
)
|
|
table_a_2_bs = Table(
|
|
"a", meta2, Column("a", Integer), Column("b", String), schema="bs"
|
|
)
|
|
|
|
table_b = Table("b", meta, Column("a", Integer), Column("b", Integer))
|
|
|
|
table_c = Table("c", meta, Column("x", Integer), Column("y", Integer))
|
|
|
|
table_d = Table("d", meta, Column("y", Integer), Column("z", Integer))
|
|
|
|
|
|
class MyEntity(HasCacheKey):
|
|
def __init__(self, name, element):
|
|
self.name = name
|
|
self.element = element
|
|
|
|
_cache_key_traversal = [
|
|
("name", InternalTraversal.dp_string),
|
|
("element", InternalTraversal.dp_clauseelement),
|
|
]
|
|
|
|
|
|
class CoreFixtures(object):
|
|
# lambdas which return a tuple of ColumnElement objects.
|
|
# must return at least two objects that should compare differently.
|
|
# to test more varieties of "difference" additional objects can be added.
|
|
fixtures = [
|
|
lambda: (
|
|
column("q"),
|
|
column("x"),
|
|
column("q", Integer),
|
|
column("q", String),
|
|
),
|
|
lambda: (~column("q", Boolean), ~column("p", Boolean)),
|
|
lambda: (
|
|
table_a.c.a.label("foo"),
|
|
table_a.c.a.label("bar"),
|
|
table_a.c.b.label("foo"),
|
|
),
|
|
lambda: (
|
|
_label_reference(table_a.c.a.desc()),
|
|
_label_reference(table_a.c.a.asc()),
|
|
),
|
|
lambda: (_textual_label_reference("a"), _textual_label_reference("b")),
|
|
lambda: (
|
|
text("select a, b from table").columns(a=Integer, b=String),
|
|
text("select a, b, c from table").columns(
|
|
a=Integer, b=String, c=Integer
|
|
),
|
|
text("select a, b, c from table where foo=:bar").bindparams(
|
|
bindparam("bar", type_=Integer)
|
|
),
|
|
text("select a, b, c from table where foo=:foo").bindparams(
|
|
bindparam("foo", type_=Integer)
|
|
),
|
|
text("select a, b, c from table where foo=:bar").bindparams(
|
|
bindparam("bar", type_=String)
|
|
),
|
|
),
|
|
lambda: (
|
|
column("q") == column("x"),
|
|
column("q") == column("y"),
|
|
column("z") == column("x"),
|
|
column("z") + column("x"),
|
|
column("z") - column("x"),
|
|
column("x") - column("z"),
|
|
column("z") > column("x"),
|
|
column("x").in_([5, 7]),
|
|
column("x").in_([10, 7, 8]),
|
|
# note these two are mathematically equivalent but for now they
|
|
# are considered to be different
|
|
column("z") >= column("x"),
|
|
column("x") <= column("z"),
|
|
column("q").between(5, 6),
|
|
column("q").between(5, 6, symmetric=True),
|
|
column("q").like("somstr"),
|
|
column("q").like("somstr", escape="\\"),
|
|
column("q").like("somstr", escape="X"),
|
|
),
|
|
lambda: (
|
|
table_a.c.a,
|
|
table_a.c.a._annotate({"orm": True}),
|
|
table_a.c.a._annotate({"orm": True})._annotate({"bar": False}),
|
|
table_a.c.a._annotate(
|
|
{"orm": True, "parententity": MyEntity("a", table_a)}
|
|
),
|
|
table_a.c.a._annotate(
|
|
{"orm": True, "parententity": MyEntity("b", table_a)}
|
|
),
|
|
table_a.c.a._annotate(
|
|
{"orm": True, "parententity": MyEntity("b", select([table_a]))}
|
|
),
|
|
),
|
|
lambda: (
|
|
cast(column("q"), Integer),
|
|
cast(column("q"), Float),
|
|
cast(column("p"), Integer),
|
|
),
|
|
lambda: (
|
|
bindparam("x"),
|
|
bindparam("y"),
|
|
bindparam("x", type_=Integer),
|
|
bindparam("x", type_=String),
|
|
bindparam(None),
|
|
),
|
|
lambda: (_OffsetLimitParam("x"), _OffsetLimitParam("y")),
|
|
lambda: (func.foo(), func.foo(5), func.bar()),
|
|
lambda: (func.current_date(), func.current_time()),
|
|
lambda: (
|
|
func.next_value(Sequence("q")),
|
|
func.next_value(Sequence("p")),
|
|
),
|
|
lambda: (True_(), False_()),
|
|
lambda: (Null(),),
|
|
lambda: (ReturnTypeFromArgs("foo"), ReturnTypeFromArgs(5)),
|
|
lambda: (FunctionElement(5), FunctionElement(5, 6)),
|
|
lambda: (func.count(), func.not_count()),
|
|
lambda: (func.char_length("abc"), func.char_length("def")),
|
|
lambda: (GenericFunction("a", "b"), GenericFunction("a")),
|
|
lambda: (CollationClause("foobar"), CollationClause("batbar")),
|
|
lambda: (
|
|
type_coerce(column("q", Integer), String),
|
|
type_coerce(column("q", Integer), Float),
|
|
type_coerce(column("z", Integer), Float),
|
|
),
|
|
lambda: (table_a.c.a, table_b.c.a),
|
|
lambda: (tuple_(1, 2), tuple_(3, 4)),
|
|
lambda: (func.array_agg([1, 2]), func.array_agg([3, 4])),
|
|
lambda: (
|
|
func.percentile_cont(0.5).within_group(table_a.c.a),
|
|
func.percentile_cont(0.5).within_group(table_a.c.b),
|
|
func.percentile_cont(0.5).within_group(table_a.c.a, table_a.c.b),
|
|
func.percentile_cont(0.5).within_group(
|
|
table_a.c.a, table_a.c.b, column("q")
|
|
),
|
|
),
|
|
lambda: (
|
|
func.is_equal("a", "b").as_comparison(1, 2),
|
|
func.is_equal("a", "c").as_comparison(1, 2),
|
|
func.is_equal("a", "b").as_comparison(2, 1),
|
|
func.is_equal("a", "b", "c").as_comparison(1, 2),
|
|
func.foobar("a", "b").as_comparison(1, 2),
|
|
),
|
|
lambda: (
|
|
func.row_number().over(order_by=table_a.c.a),
|
|
func.row_number().over(order_by=table_a.c.a, range_=(0, 10)),
|
|
func.row_number().over(order_by=table_a.c.a, range_=(None, 10)),
|
|
func.row_number().over(order_by=table_a.c.a, rows=(None, 20)),
|
|
func.row_number().over(order_by=table_a.c.b),
|
|
func.row_number().over(
|
|
order_by=table_a.c.a, partition_by=table_a.c.b
|
|
),
|
|
),
|
|
lambda: (
|
|
func.count(1).filter(table_a.c.a == 5),
|
|
func.count(1).filter(table_a.c.a == 10),
|
|
func.foob(1).filter(table_a.c.a == 10),
|
|
),
|
|
lambda: (
|
|
and_(table_a.c.a == 5, table_a.c.b == table_b.c.a),
|
|
and_(table_a.c.a == 5, table_a.c.a == table_b.c.a),
|
|
or_(table_a.c.a == 5, table_a.c.b == table_b.c.a),
|
|
ClauseList(table_a.c.a == 5, table_a.c.b == table_b.c.a),
|
|
ClauseList(table_a.c.a == 5, table_a.c.b == table_a.c.a),
|
|
),
|
|
lambda: (
|
|
case(whens=[(table_a.c.a == 5, 10), (table_a.c.a == 10, 20)]),
|
|
case(whens=[(table_a.c.a == 18, 10), (table_a.c.a == 10, 20)]),
|
|
case(whens=[(table_a.c.a == 5, 10), (table_a.c.b == 10, 20)]),
|
|
case(
|
|
whens=[
|
|
(table_a.c.a == 5, 10),
|
|
(table_a.c.b == 10, 20),
|
|
(table_a.c.a == 9, 12),
|
|
]
|
|
),
|
|
case(
|
|
whens=[(table_a.c.a == 5, 10), (table_a.c.a == 10, 20)],
|
|
else_=30,
|
|
),
|
|
case({"wendy": "W", "jack": "J"}, value=table_a.c.a, else_="E"),
|
|
case({"wendy": "W", "jack": "J"}, value=table_a.c.b, else_="E"),
|
|
case({"wendy_w": "W", "jack": "J"}, value=table_a.c.a, else_="E"),
|
|
),
|
|
lambda: (
|
|
extract("foo", table_a.c.a),
|
|
extract("foo", table_a.c.b),
|
|
extract("bar", table_a.c.a),
|
|
),
|
|
lambda: (
|
|
Slice(1, 2, 5),
|
|
Slice(1, 5, 5),
|
|
Slice(1, 5, 10),
|
|
Slice(2, 10, 15),
|
|
),
|
|
lambda: (
|
|
select([table_a.c.a]),
|
|
select([table_a.c.a, table_a.c.b]),
|
|
select([table_a.c.b, table_a.c.a]),
|
|
select([table_a.c.a]).where(table_a.c.b == 5),
|
|
select([table_a.c.a])
|
|
.where(table_a.c.b == 5)
|
|
.where(table_a.c.a == 10),
|
|
select([table_a.c.a]).where(table_a.c.b == 5).with_for_update(),
|
|
select([table_a.c.a])
|
|
.where(table_a.c.b == 5)
|
|
.with_for_update(nowait=True),
|
|
select([table_a.c.a]).where(table_a.c.b == 5).correlate(table_b),
|
|
select([table_a.c.a])
|
|
.where(table_a.c.b == 5)
|
|
.correlate_except(table_b),
|
|
),
|
|
lambda: (
|
|
select([table_a.c.a]).cte(),
|
|
select([table_a.c.a]).cte(recursive=True),
|
|
select([table_a.c.a]).cte(name="some_cte", recursive=True),
|
|
select([table_a.c.a]).cte(name="some_cte"),
|
|
select([table_a.c.a]).cte(name="some_cte").alias("other_cte"),
|
|
select([table_a.c.a])
|
|
.cte(name="some_cte")
|
|
.union_all(select([table_a.c.a])),
|
|
select([table_a.c.a])
|
|
.cte(name="some_cte")
|
|
.union_all(select([table_a.c.b])),
|
|
select([table_a.c.a]).lateral(),
|
|
select([table_a.c.a]).lateral(name="bar"),
|
|
table_a.tablesample(func.bernoulli(1)),
|
|
table_a.tablesample(func.bernoulli(1), seed=func.random()),
|
|
table_a.tablesample(func.bernoulli(1), seed=func.other_random()),
|
|
table_a.tablesample(func.hoho(1)),
|
|
table_a.tablesample(func.bernoulli(1), name="bar"),
|
|
table_a.tablesample(
|
|
func.bernoulli(1), name="bar", seed=func.random()
|
|
),
|
|
),
|
|
lambda: (
|
|
select([table_a.c.a]),
|
|
select([table_a.c.a]).prefix_with("foo"),
|
|
select([table_a.c.a]).prefix_with("foo", dialect="mysql"),
|
|
select([table_a.c.a]).prefix_with("foo", dialect="postgresql"),
|
|
select([table_a.c.a]).prefix_with("bar"),
|
|
select([table_a.c.a]).suffix_with("bar"),
|
|
),
|
|
lambda: (
|
|
select([table_a_2.c.a]),
|
|
select([table_a_2_fs.c.a]),
|
|
select([table_a_2_bs.c.a]),
|
|
),
|
|
lambda: (
|
|
select([table_a.c.a]),
|
|
select([table_a.c.a]).with_hint(None, "some hint"),
|
|
select([table_a.c.a]).with_hint(None, "some other hint"),
|
|
select([table_a.c.a]).with_hint(table_a, "some hint"),
|
|
select([table_a.c.a])
|
|
.with_hint(table_a, "some hint")
|
|
.with_hint(None, "some other hint"),
|
|
select([table_a.c.a]).with_hint(table_a, "some other hint"),
|
|
select([table_a.c.a]).with_hint(
|
|
table_a, "some hint", dialect_name="mysql"
|
|
),
|
|
select([table_a.c.a]).with_hint(
|
|
table_a, "some hint", dialect_name="postgresql"
|
|
),
|
|
),
|
|
lambda: (
|
|
table_a.join(table_b, table_a.c.a == table_b.c.a),
|
|
table_a.join(
|
|
table_b, and_(table_a.c.a == table_b.c.a, table_a.c.b == 1)
|
|
),
|
|
table_a.outerjoin(table_b, table_a.c.a == table_b.c.a),
|
|
),
|
|
lambda: (
|
|
table_a.alias("a"),
|
|
table_a.alias("b"),
|
|
table_a.alias(),
|
|
table_b.alias("a"),
|
|
select([table_a.c.a]).alias("a"),
|
|
),
|
|
lambda: (
|
|
FromGrouping(table_a.alias("a")),
|
|
FromGrouping(table_a.alias("b")),
|
|
),
|
|
lambda: (
|
|
SelectStatementGrouping(select([table_a])),
|
|
SelectStatementGrouping(select([table_b])),
|
|
),
|
|
lambda: (
|
|
select([table_a.c.a]).scalar_subquery(),
|
|
select([table_a.c.a]).where(table_a.c.b == 5).scalar_subquery(),
|
|
),
|
|
lambda: (
|
|
exists().where(table_a.c.a == 5),
|
|
exists().where(table_a.c.b == 5),
|
|
),
|
|
lambda: (
|
|
union(select([table_a.c.a]), select([table_a.c.b])),
|
|
union(select([table_a.c.a]), select([table_a.c.b])).order_by("a"),
|
|
union_all(select([table_a.c.a]), select([table_a.c.b])),
|
|
union(select([table_a.c.a])),
|
|
union(
|
|
select([table_a.c.a]),
|
|
select([table_a.c.b]).where(table_a.c.b > 5),
|
|
),
|
|
),
|
|
lambda: (
|
|
table("a", column("x"), column("y")),
|
|
table("a", column("y"), column("x")),
|
|
table("b", column("x"), column("y")),
|
|
table("a", column("x"), column("y"), column("z")),
|
|
table("a", column("x"), column("y", Integer)),
|
|
table("a", column("q"), column("y", Integer)),
|
|
),
|
|
lambda: (table_a, table_b),
|
|
]
|
|
|
|
dont_compare_values_fixtures = [
|
|
lambda: (
|
|
# note the in_(...) all have different column names becuase
|
|
# otherwise all IN expressions would compare as equivalent
|
|
column("x").in_(random_choices(range(10), k=3)),
|
|
column("y").in_(
|
|
bindparam(
|
|
"q",
|
|
random_choices(range(10), k=random.randint(0, 7)),
|
|
expanding=True,
|
|
)
|
|
),
|
|
column("z").in_(random_choices(range(10), k=random.randint(0, 7))),
|
|
column("x") == random.randint(1, 10),
|
|
)
|
|
]
|
|
|
|
def _complex_fixtures():
|
|
def one():
|
|
a1 = table_a.alias()
|
|
a2 = table_b_like_a.alias()
|
|
|
|
stmt = (
|
|
select([table_a.c.a, a1.c.b, a2.c.b])
|
|
.where(table_a.c.b == a1.c.b)
|
|
.where(a1.c.b == a2.c.b)
|
|
.where(a1.c.a == 5)
|
|
)
|
|
|
|
return stmt
|
|
|
|
def one_diff():
|
|
a1 = table_b_like_a.alias()
|
|
a2 = table_a.alias()
|
|
|
|
stmt = (
|
|
select([table_a.c.a, a1.c.b, a2.c.b])
|
|
.where(table_a.c.b == a1.c.b)
|
|
.where(a1.c.b == a2.c.b)
|
|
.where(a1.c.a == 5)
|
|
)
|
|
|
|
return stmt
|
|
|
|
def two():
|
|
inner = one().subquery()
|
|
|
|
stmt = select([table_b.c.a, inner.c.a, inner.c.b]).select_from(
|
|
table_b.join(inner, table_b.c.b == inner.c.b)
|
|
)
|
|
|
|
return stmt
|
|
|
|
def three():
|
|
|
|
a1 = table_a.alias()
|
|
a2 = table_a.alias()
|
|
ex = exists().where(table_b.c.b == a1.c.a)
|
|
|
|
stmt = (
|
|
select([a1.c.a, a2.c.a])
|
|
.select_from(a1.join(a2, a1.c.b == a2.c.b))
|
|
.where(ex)
|
|
)
|
|
return stmt
|
|
|
|
return [one(), one_diff(), two(), three()]
|
|
|
|
fixtures.append(_complex_fixtures)
|
|
|
|
|
|
class CacheKeyFixture(object):
|
|
def _run_cache_key_fixture(self, fixture, compare_values):
|
|
case_a = fixture()
|
|
case_b = fixture()
|
|
|
|
for a, b in itertools.combinations_with_replacement(
|
|
range(len(case_a)), 2
|
|
):
|
|
if a == b:
|
|
a_key = case_a[a]._generate_cache_key()
|
|
b_key = case_b[b]._generate_cache_key()
|
|
is_not_(a_key, None)
|
|
is_not_(b_key, None)
|
|
|
|
eq_(a_key.key, b_key.key)
|
|
eq_(hash(a_key), hash(b_key))
|
|
|
|
for a_param, b_param in zip(
|
|
a_key.bindparams, b_key.bindparams
|
|
):
|
|
assert a_param.compare(
|
|
b_param, compare_values=compare_values
|
|
)
|
|
else:
|
|
a_key = case_a[a]._generate_cache_key()
|
|
b_key = case_b[b]._generate_cache_key()
|
|
|
|
if a_key.key == b_key.key:
|
|
for a_param, b_param in zip(
|
|
a_key.bindparams, b_key.bindparams
|
|
):
|
|
if not a_param.compare(
|
|
b_param, compare_values=compare_values
|
|
):
|
|
break
|
|
else:
|
|
# this fails unconditionally since we could not
|
|
# find bound parameter values that differed.
|
|
# Usually we intended to get two distinct keys here
|
|
# so the failure will be more descriptive using the
|
|
# ne_() assertion.
|
|
ne_(a_key.key, b_key.key)
|
|
else:
|
|
ne_(a_key.key, b_key.key)
|
|
|
|
# ClauseElement-specific test to ensure the cache key
|
|
# collected all the bound parameters
|
|
if isinstance(case_a[a], ClauseElement) and isinstance(
|
|
case_b[b], ClauseElement
|
|
):
|
|
assert_a_params = []
|
|
assert_b_params = []
|
|
visitors.traverse_depthfirst(
|
|
case_a[a], {}, {"bindparam": assert_a_params.append}
|
|
)
|
|
visitors.traverse_depthfirst(
|
|
case_b[b], {}, {"bindparam": assert_b_params.append}
|
|
)
|
|
|
|
# note we're asserting the order of the params as well as
|
|
# if there are dupes or not. ordering has to be
|
|
# deterministic and matches what a traversal would provide.
|
|
# regular traverse_depthfirst does produce dupes in cases
|
|
# like
|
|
# select([some_alias]).
|
|
# select_from(join(some_alias, other_table))
|
|
# where a bound parameter is inside of some_alias. the
|
|
# cache key case is more minimalistic
|
|
eq_(
|
|
sorted(a_key.bindparams, key=lambda b: b.key),
|
|
sorted(
|
|
util.unique_list(assert_a_params), key=lambda b: b.key
|
|
),
|
|
)
|
|
eq_(
|
|
sorted(b_key.bindparams, key=lambda b: b.key),
|
|
sorted(
|
|
util.unique_list(assert_b_params), key=lambda b: b.key
|
|
),
|
|
)
|
|
|
|
|
|
class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase):
|
|
def test_cache_key(self):
|
|
for fixtures_, compare_values in [
|
|
(self.fixtures, True),
|
|
(self.dont_compare_values_fixtures, False),
|
|
]:
|
|
for fixture in fixtures_:
|
|
self._run_cache_key_fixture(fixture, compare_values)
|
|
|
|
def test_cache_key_unknown_traverse(self):
|
|
class Foobar1(ClauseElement):
|
|
_traverse_internals = [
|
|
("key", InternalTraversal.dp_anon_name),
|
|
("type_", InternalTraversal.dp_unknown_structure),
|
|
]
|
|
|
|
def __init__(self, key, type_):
|
|
self.key = key
|
|
self.type_ = type_
|
|
|
|
f1 = Foobar1("foo", String())
|
|
eq_(f1._generate_cache_key(), None)
|
|
|
|
def test_cache_key_no_method(self):
|
|
class Foobar1(ClauseElement):
|
|
pass
|
|
|
|
class Foobar2(ColumnElement):
|
|
pass
|
|
|
|
# the None for cache key will prevent objects
|
|
# which contain these elements from being cached.
|
|
f1 = Foobar1()
|
|
eq_(f1._generate_cache_key(), None)
|
|
|
|
f2 = Foobar2()
|
|
eq_(f2._generate_cache_key(), None)
|
|
|
|
s1 = select([column("q"), Foobar2()])
|
|
|
|
eq_(s1._generate_cache_key(), None)
|
|
|
|
def test_get_children_no_method(self):
|
|
class Foobar1(ClauseElement):
|
|
pass
|
|
|
|
class Foobar2(ColumnElement):
|
|
pass
|
|
|
|
f1 = Foobar1()
|
|
eq_(f1.get_children(), [])
|
|
|
|
f2 = Foobar2()
|
|
eq_(f2.get_children(), [])
|
|
|
|
def test_copy_internals_no_method(self):
|
|
class Foobar1(ClauseElement):
|
|
pass
|
|
|
|
class Foobar2(ColumnElement):
|
|
pass
|
|
|
|
f1 = Foobar1()
|
|
f2 = Foobar2()
|
|
|
|
f1._copy_internals()
|
|
f2._copy_internals()
|
|
|
|
|
|
class CompareAndCopyTest(CoreFixtures, fixtures.TestBase):
|
|
@classmethod
|
|
def setup_class(cls):
|
|
# TODO: we need to get dialects here somehow, perhaps in test_suite?
|
|
[
|
|
importlib.import_module("sqlalchemy.dialects.%s" % d)
|
|
for d in dialects.__all__
|
|
if not d.startswith("_")
|
|
]
|
|
|
|
def test_all_present(self):
|
|
need = set(
|
|
cls
|
|
for cls in class_hierarchy(ClauseElement)
|
|
if issubclass(cls, (ColumnElement, Selectable))
|
|
and (
|
|
"__init__" in cls.__dict__
|
|
or issubclass(cls, AliasedReturnsRows)
|
|
)
|
|
and not issubclass(cls, (Annotated))
|
|
and "orm" not in cls.__module__
|
|
and "compiler" not in cls.__module__
|
|
and "crud" not in cls.__module__
|
|
and "dialects" not in cls.__module__ # TODO: dialects?
|
|
).difference({ColumnElement, UnaryExpression})
|
|
|
|
for fixture in self.fixtures + self.dont_compare_values_fixtures:
|
|
case_a = fixture()
|
|
for elem in case_a:
|
|
for mro in type(elem).__mro__:
|
|
need.discard(mro)
|
|
|
|
is_false(bool(need), "%d Remaining classes: %r" % (len(need), need))
|
|
|
|
def test_compare_labels(self):
|
|
for fixtures_, compare_values in [
|
|
(self.fixtures, True),
|
|
(self.dont_compare_values_fixtures, False),
|
|
]:
|
|
for fixture in fixtures_:
|
|
case_a = fixture()
|
|
case_b = fixture()
|
|
|
|
for a, b in itertools.combinations_with_replacement(
|
|
range(len(case_a)), 2
|
|
):
|
|
if a == b:
|
|
is_true(
|
|
case_a[a].compare(
|
|
case_b[b],
|
|
compare_annotations=True,
|
|
compare_values=compare_values,
|
|
),
|
|
"%r != %r" % (case_a[a], case_b[b]),
|
|
)
|
|
|
|
else:
|
|
is_false(
|
|
case_a[a].compare(
|
|
case_b[b],
|
|
compare_annotations=True,
|
|
compare_values=compare_values,
|
|
),
|
|
"%r == %r" % (case_a[a], case_b[b]),
|
|
)
|
|
|
|
def test_compare_col_identity(self):
|
|
stmt1 = (
|
|
select([table_a.c.a, table_b.c.b])
|
|
.where(table_a.c.a == table_b.c.b)
|
|
.alias()
|
|
)
|
|
stmt1_c = (
|
|
select([table_a.c.a, table_b.c.b])
|
|
.where(table_a.c.a == table_b.c.b)
|
|
.alias()
|
|
)
|
|
|
|
stmt2 = union(select([table_a]), select([table_b]))
|
|
|
|
equivalents = {table_a.c.a: [table_b.c.a]}
|
|
|
|
is_false(
|
|
stmt1.compare(stmt2, use_proxies=True, equivalents=equivalents)
|
|
)
|
|
|
|
is_true(
|
|
stmt1.compare(stmt1_c, use_proxies=True, equivalents=equivalents)
|
|
)
|
|
is_true(
|
|
(table_a.c.a == table_b.c.b).compare(
|
|
stmt1.c.a == stmt1.c.b,
|
|
use_proxies=True,
|
|
equivalents=equivalents,
|
|
)
|
|
)
|
|
|
|
def test_copy_internals(self):
|
|
for fixtures_, compare_values in [
|
|
(self.fixtures, True),
|
|
(self.dont_compare_values_fixtures, False),
|
|
]:
|
|
for fixture in fixtures_:
|
|
case_a = fixture()
|
|
case_b = fixture()
|
|
|
|
assert case_a[0].compare(
|
|
case_b[0], compare_values=compare_values
|
|
)
|
|
|
|
clone = visitors.replacement_traverse(
|
|
case_a[0], {}, lambda elem: None
|
|
)
|
|
|
|
assert clone.compare(case_b[0], compare_values=compare_values)
|
|
|
|
stack = [clone]
|
|
seen = {clone}
|
|
found_elements = False
|
|
while stack:
|
|
obj = stack.pop(0)
|
|
|
|
items = [
|
|
subelem
|
|
for key, elem in clone.__dict__.items()
|
|
if key != "_is_clone_of" and elem is not None
|
|
for subelem in util.to_list(elem)
|
|
if (
|
|
isinstance(subelem, (ColumnElement, ClauseList))
|
|
and subelem not in seen
|
|
and not isinstance(subelem, Immutable)
|
|
and subelem is not case_a[0]
|
|
)
|
|
]
|
|
stack.extend(items)
|
|
seen.update(items)
|
|
|
|
if obj is not clone:
|
|
found_elements = True
|
|
# ensure the element will not compare as true
|
|
obj.compare = lambda other, **kw: False
|
|
obj.__visit_name__ = "dont_match"
|
|
|
|
if found_elements:
|
|
assert not clone.compare(
|
|
case_b[0], compare_values=compare_values
|
|
)
|
|
assert case_a[0].compare(
|
|
case_b[0], compare_values=compare_values
|
|
)
|
|
|
|
|
|
class CompareClausesTest(fixtures.TestBase):
|
|
def test_compare_metadata_tables(self):
|
|
# metadata Table objects cache on their own identity, not their
|
|
# structure. This is mainly to reduce the size of cache keys
|
|
# as well as reduce computational overhead, as Table objects have
|
|
# very large internal state and they are also generally global
|
|
# objects.
|
|
|
|
t1 = Table("a", MetaData(), Column("q", Integer), Column("p", Integer))
|
|
t2 = Table("a", MetaData(), Column("q", Integer), Column("p", Integer))
|
|
|
|
ne_(t1._generate_cache_key(), t2._generate_cache_key())
|
|
|
|
eq_(t1._generate_cache_key().key, (t1,))
|
|
|
|
def test_compare_adhoc_tables(self):
|
|
# non-metadata tables compare on their structure. these objects are
|
|
# not commonly used.
|
|
|
|
# note this test is a bit redundant as we have a similar test
|
|
# via the fixtures also
|
|
t1 = table("a", Column("q", Integer), Column("p", Integer))
|
|
t2 = table("a", Column("q", Integer), Column("p", Integer))
|
|
t3 = table("b", Column("q", Integer), Column("p", Integer))
|
|
t4 = table("a", Column("q", Integer), Column("x", Integer))
|
|
|
|
eq_(t1._generate_cache_key(), t2._generate_cache_key())
|
|
|
|
ne_(t1._generate_cache_key(), t3._generate_cache_key())
|
|
ne_(t1._generate_cache_key(), t4._generate_cache_key())
|
|
ne_(t3._generate_cache_key(), t4._generate_cache_key())
|
|
|
|
def test_compare_comparison_associative(self):
|
|
|
|
l1 = table_c.c.x == table_d.c.y
|
|
l2 = table_d.c.y == table_c.c.x
|
|
l3 = table_c.c.x == table_d.c.z
|
|
|
|
is_true(l1.compare(l1))
|
|
is_true(l1.compare(l2))
|
|
is_false(l1.compare(l3))
|
|
|
|
def test_compare_comparison_non_commutative_inverses(self):
|
|
l1 = table_c.c.x >= table_d.c.y
|
|
l2 = table_d.c.y < table_c.c.x
|
|
l3 = table_d.c.y <= table_c.c.x
|
|
|
|
# we're not doing this kind of commutativity right now.
|
|
is_false(l1.compare(l2))
|
|
is_false(l1.compare(l3))
|
|
|
|
def test_compare_clauselist_associative(self):
|
|
|
|
l1 = and_(table_c.c.x == table_d.c.y, table_c.c.y == table_d.c.z)
|
|
|
|
l2 = and_(table_c.c.y == table_d.c.z, table_c.c.x == table_d.c.y)
|
|
|
|
l3 = and_(table_c.c.x == table_d.c.z, table_c.c.y == table_d.c.y)
|
|
|
|
is_true(l1.compare(l1))
|
|
is_true(l1.compare(l2))
|
|
is_false(l1.compare(l3))
|
|
|
|
def test_compare_clauselist_not_associative(self):
|
|
|
|
l1 = ClauseList(
|
|
table_c.c.x, table_c.c.y, table_d.c.y, operator=operators.sub
|
|
)
|
|
|
|
l2 = ClauseList(
|
|
table_d.c.y, table_c.c.x, table_c.c.y, operator=operators.sub
|
|
)
|
|
|
|
is_true(l1.compare(l1))
|
|
is_false(l1.compare(l2))
|
|
|
|
def test_compare_clauselist_assoc_different_operator(self):
|
|
|
|
l1 = and_(table_c.c.x == table_d.c.y, table_c.c.y == table_d.c.z)
|
|
|
|
l2 = or_(table_c.c.y == table_d.c.z, table_c.c.x == table_d.c.y)
|
|
|
|
is_false(l1.compare(l2))
|
|
|
|
def test_compare_clauselist_not_assoc_different_operator(self):
|
|
|
|
l1 = ClauseList(
|
|
table_c.c.x, table_c.c.y, table_d.c.y, operator=operators.sub
|
|
)
|
|
|
|
l2 = ClauseList(
|
|
table_c.c.x, table_c.c.y, table_d.c.y, operator=operators.div
|
|
)
|
|
|
|
is_false(l1.compare(l2))
|
|
|
|
def test_compare_labels(self):
|
|
is_true(column("q").label(None).compare(column("q").label(None)))
|
|
|
|
is_false(column("q").label("foo").compare(column("q").label(None)))
|
|
|
|
is_false(column("q").label(None).compare(column("q").label("foo")))
|
|
|
|
is_false(column("q").label("foo").compare(column("q").label("bar")))
|
|
|
|
is_true(column("q").label("foo").compare(column("q").label("foo")))
|
|
|
|
def test_compare_binds(self):
|
|
b1 = bindparam("foo", type_=Integer())
|
|
b2 = bindparam("foo", type_=Integer())
|
|
b3 = bindparam("foo", type_=String())
|
|
|
|
def c1():
|
|
return 5
|
|
|
|
def c2():
|
|
return 6
|
|
|
|
b4 = bindparam("foo", type_=Integer(), callable_=c1)
|
|
b5 = bindparam("foo", type_=Integer(), callable_=c2)
|
|
b6 = bindparam("foo", type_=Integer(), callable_=c1)
|
|
|
|
b7 = bindparam("foo", type_=Integer, value=5)
|
|
b8 = bindparam("foo", type_=Integer, value=6)
|
|
|
|
is_false(b1.compare(b4))
|
|
is_true(b4.compare(b6))
|
|
is_false(b4.compare(b5))
|
|
is_true(b1.compare(b2))
|
|
|
|
# currently not comparing "key", as we often have to compare
|
|
# anonymous names. however we should really check for that
|
|
# is_true(b1.compare(b3))
|
|
|
|
is_false(b1.compare(b3))
|
|
is_false(b1.compare(b7))
|
|
is_false(b7.compare(b8))
|
|
is_true(b7.compare(b7))
|
|
|
|
def test_compare_tables(self):
|
|
is_true(table_a.compare(table_a_2))
|
|
|
|
# the "proxy" version compares schema tables on metadata identity
|
|
is_false(table_a.compare(table_a_2, use_proxies=True))
|
|
|
|
# same for lower case tables since it compares lower case columns
|
|
# using proxies, which makes it very unlikely to have multiple
|
|
# table() objects with columns that compare equally
|
|
is_false(
|
|
table("a", column("x", Integer), column("q", String)).compare(
|
|
table("a", column("x", Integer), column("q", String)),
|
|
use_proxies=True,
|
|
)
|
|
)
|
|
|
|
def test_compare_annotated_clears_mapping(self):
|
|
t = table("t", column("x"), column("y"))
|
|
x_a = t.c.x._annotate({"foo": True})
|
|
x_b = t.c.x._annotate({"foo": True})
|
|
|
|
is_true(x_a.compare(x_b, compare_annotations=True))
|
|
is_false(
|
|
x_a.compare(x_b._annotate({"bar": True}), compare_annotations=True)
|
|
)
|
|
|
|
s1 = select([t.c.x])._annotate({"foo": True})
|
|
s2 = select([t.c.x])._annotate({"foo": True})
|
|
|
|
is_true(s1.compare(s2, compare_annotations=True))
|
|
|
|
is_false(
|
|
s1.compare(s2._annotate({"bar": True}), compare_annotations=True)
|
|
)
|
|
|
|
def test_compare_annotated_wo_annotations(self):
|
|
t = table("t", column("x"), column("y"))
|
|
x_a = t.c.x._annotate({})
|
|
x_b = t.c.x._annotate({"foo": True})
|
|
|
|
is_true(t.c.x.compare(x_a))
|
|
is_true(x_b.compare(x_a))
|
|
|
|
is_true(x_a.compare(t.c.x))
|
|
is_false(x_a.compare(t.c.y))
|
|
is_false(t.c.y.compare(x_a))
|
|
is_true((t.c.x == 5).compare(x_a == 5))
|
|
is_false((t.c.y == 5).compare(x_a == 5))
|
|
|
|
s = select([t]).subquery()
|
|
x_p = s.c.x
|
|
is_false(x_a.compare(x_p))
|
|
is_false(t.c.x.compare(x_p))
|
|
x_p_a = x_p._annotate({})
|
|
is_true(x_p_a.compare(x_p))
|
|
is_true(x_p.compare(x_p_a))
|
|
is_false(x_p_a.compare(x_a))
|