Add anonymizing context to cache keys, comparison; convert traversal

Created new visitor system called "internal traversal" that
applies a data driven approach to the concept of a class that
defines its own traversal steps, in contrast to the existing
style of traversal now known as "external traversal" where
the visitor class defines the traversal, i.e. the SQLCompiler.

The internal traversal system now implements get_children(),
_copy_internals(), compare() and _cache_key() for most Core elements.
Core elements with special needs like Select still implement
some of these methods directly however most of these methods
are no longer explicitly implemented.

The data-driven system is also applied to ORM elements that
take part in SQL expressions so that these objects, like mappers,
aliasedclass, query options, etc. can all participate in the
cache key process.

Still not considered is that this approach to defining traversibility
will be used to create some kind of generic introspection system
that works across Core / ORM.  It's also not clear if
real statement caching using the _cache_key() method is feasible,
if it is shown that running _cache_key() is nearly as expensive as
compiling in any case.    Because it is data driven, it is more
straightforward to optimize using inlined code, as is the case now,
as well as potentially using C code to speed it up.

In addition, the caching sytem now accommodates for anonymous
name labels, which is essential so that constructs which have
anonymous labels can be cacheable, that is, their position
within a statement in relation to other anonymous names causes
them to generate an integer counter relative to that construct
which will be the same every time.   Gathering of bound parameters
from any cache key generation is also now required as there is
no use case for a cache key that does not extract bound parameter
values.

Applies-to: #4639
Change-Id: I0660584def8627cad566719ee98d3be045db4b8d
This commit is contained in:
Mike Bayer
2019-08-29 14:45:23 -04:00
parent db47859dca
commit 29330ec159
37 changed files with 2341 additions and 1295 deletions
+4
View File
@@ -82,6 +82,9 @@ the FROM clause of a SELECT statement.
.. autoclass:: BindParameter
:members:
.. autoclass:: CacheKey
:members:
.. autoclass:: Case
:members:
@@ -90,6 +93,7 @@ the FROM clause of a SELECT statement.
.. autoclass:: ClauseElement
:members:
:inherited-members:
.. autoclass:: ClauseList
+1
View File
@@ -23,3 +23,4 @@ as well as when building out custom SQL expressions using the
.. automodule:: sqlalchemy.sql.visitors
:members:
:private-members:
-1
View File
@@ -103,7 +103,6 @@ class Insert(StandardInsert):
inserted_alias = getattr(self, "inserted_alias", None)
self._post_values_clause = OnDuplicateClause(inserted_alias, values)
return self
insert = public_factory(Insert, ".dialects.mysql.insert")
+7 -10
View File
@@ -1658,23 +1658,20 @@ class PGCompiler(compiler.SQLCompiler):
return "ONLY " + sqltext
def get_select_precolumns(self, select, **kw):
if select._distinct is not False:
if select._distinct is True:
return "DISTINCT "
elif isinstance(select._distinct, (list, tuple)):
if select._distinct or select._distinct_on:
if select._distinct_on:
return (
"DISTINCT ON ("
+ ", ".join(
[self.process(col, **kw) for col in select._distinct]
[
self.process(col, **kw)
for col in select._distinct_on
]
)
+ ") "
)
else:
return (
"DISTINCT ON ("
+ self.process(select._distinct, **kw)
+ ") "
)
return "DISTINCT "
else:
return ""
@@ -103,7 +103,6 @@ class Insert(StandardInsert):
self._post_values_clause = OnConflictDoUpdate(
constraint, index_elements, index_where, set_, where
)
return self
@_generative
def on_conflict_do_nothing(
@@ -138,7 +137,6 @@ class Insert(StandardInsert):
self._post_values_clause = OnConflictDoNothing(
constraint, index_elements, index_where
)
return self
insert = public_factory(Insert, ".dialects.postgresql.insert")
+1 -1
View File
@@ -198,7 +198,7 @@ class BakedQuery(object):
self.spoil()
else:
for opt in options:
cache_key = opt._generate_cache_key(cache_path)
cache_key = opt._generate_path_cache_key(cache_path)
if cache_key is False:
self.spoil()
elif cache_key is not None:
+1 -1
View File
@@ -455,7 +455,7 @@ def deregister(class_):
if hasattr(class_, "_compiler_dispatcher"):
# regenerate default _compiler_dispatch
visitors._generate_dispatch(class_)
visitors._generate_compiler_dispatch(class_)
# remove custom directive
del class_._compiler_dispatcher
+10
View File
@@ -47,6 +47,8 @@ from .base import state_str
from .. import event
from .. import inspection
from .. import util
from ..sql import base as sql_base
from ..sql import visitors
@inspection._self_inspects
@@ -54,6 +56,7 @@ class QueryableAttribute(
interfaces._MappedAttribute,
interfaces.InspectionAttr,
interfaces.PropComparator,
sql_base.HasCacheKey,
):
"""Base class for :term:`descriptor` objects that intercept
attribute events on behalf of a :class:`.MapperProperty`
@@ -102,6 +105,13 @@ class QueryableAttribute(
if base[key].dispatch._active_history:
self.dispatch._active_history = True
_cache_key_traversal = [
# ("class_", visitors.ExtendedInternalTraversal.dp_plain_obj),
("key", visitors.ExtendedInternalTraversal.dp_string),
("_parententity", visitors.ExtendedInternalTraversal.dp_multi),
("_of_type", visitors.ExtendedInternalTraversal.dp_multi),
]
@util.memoized_property
def _supports_population(self):
return self.impl.supports_population
-1
View File
@@ -216,7 +216,6 @@ def _assertions(*assertions):
for assertion in assertions:
assertion(self, fn.__name__)
fn(self, *args[1:], **kw)
return self
return generate
+11 -2
View File
@@ -36,6 +36,8 @@ from .. import inspect
from .. import inspection
from .. import util
from ..sql import operators
from ..sql import visitors
from ..sql.traversals import HasCacheKey
__all__ = (
@@ -54,7 +56,9 @@ __all__ = (
)
class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots):
class MapperProperty(
HasCacheKey, _MappedAttribute, InspectionAttr, util.MemoizedSlots
):
"""Represent a particular class attribute mapped by :class:`.Mapper`.
The most common occurrences of :class:`.MapperProperty` are the
@@ -74,6 +78,11 @@ class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots):
"info",
)
_cache_key_traversal = [
("parent", visitors.ExtendedInternalTraversal.dp_has_cache_key),
("key", visitors.ExtendedInternalTraversal.dp_string),
]
cascade = frozenset()
"""The set of 'cascade' attribute names.
@@ -647,7 +656,7 @@ class MapperOption(object):
self.process_query(query)
def _generate_cache_key(self, path):
def _generate_path_cache_key(self, path):
"""Used by the "baked lazy loader" to see if this option can be cached.
The "baked lazy loader" refers to the :class:`.Query` that is
+5 -1
View File
@@ -71,7 +71,7 @@ _CONFIGURE_MUTEX = util.threading.RLock()
@inspection._self_inspects
@log.class_logger
class Mapper(InspectionAttr):
class Mapper(sql_base.HasCacheKey, InspectionAttr):
"""Define the correlation of class attributes to database table
columns.
@@ -729,6 +729,10 @@ class Mapper(InspectionAttr):
"""
return self
_cache_key_traversal = [
("class_", visitors.ExtendedInternalTraversal.dp_plain_obj)
]
@property
def entity(self):
r"""Part of the inspection API.
+10 -2
View File
@@ -15,7 +15,8 @@ from .base import class_mapper
from .. import exc
from .. import inspection
from .. import util
from ..sql import visitors
from ..sql.traversals import HasCacheKey
log = logging.getLogger(__name__)
@@ -28,7 +29,7 @@ _WILDCARD_TOKEN = "*"
_DEFAULT_TOKEN = "_sa_default"
class PathRegistry(object):
class PathRegistry(HasCacheKey):
"""Represent query load paths and registry functions.
Basically represents structures like:
@@ -57,6 +58,10 @@ class PathRegistry(object):
is_token = False
is_root = False
_cache_key_traversal = [
("path", visitors.ExtendedInternalTraversal.dp_has_cache_key_list)
]
def __eq__(self, other):
return other is not None and self.path == other.path
@@ -78,6 +83,9 @@ class PathRegistry(object):
def __len__(self):
return len(self.path)
def __hash__(self):
return id(self)
@property
def length(self):
return len(self.path)
+33 -5
View File
@@ -26,11 +26,13 @@ from .. import inspect
from .. import util
from ..sql import coercions
from ..sql import roles
from ..sql import visitors
from ..sql.base import _generative
from ..sql.base import Generative
from ..sql.traversals import HasCacheKey
class Load(Generative, MapperOption):
class Load(HasCacheKey, Generative, MapperOption):
"""Represents loader options which modify the state of a
:class:`.Query` in order to affect how various mapped attributes are
loaded.
@@ -70,6 +72,17 @@ class Load(Generative, MapperOption):
"""
_cache_key_traversal = [
("path", visitors.ExtendedInternalTraversal.dp_has_cache_key),
("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
("_of_type", visitors.ExtendedInternalTraversal.dp_multi),
(
"_context_cache_key",
visitors.ExtendedInternalTraversal.dp_has_cache_key_tuples,
),
("local_opts", visitors.ExtendedInternalTraversal.dp_plain_dict),
]
def __init__(self, entity):
insp = inspect(entity)
self.path = insp._path_registry
@@ -89,7 +102,16 @@ class Load(Generative, MapperOption):
load._of_type = None
return load
def _generate_cache_key(self, path):
@property
def _context_cache_key(self):
serialized = []
for (key, loader_path), obj in self.context.items():
if key != "loader":
continue
serialized.append(loader_path + (obj,))
return serialized
def _generate_path_cache_key(self, path):
if path.path[0].is_aliased_class:
return False
@@ -522,9 +544,16 @@ class _UnboundLoad(Load):
self._to_bind = []
self.local_opts = {}
_cache_key_traversal = [
("path", visitors.ExtendedInternalTraversal.dp_multi_list),
("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
("_to_bind", visitors.ExtendedInternalTraversal.dp_has_cache_key_list),
("local_opts", visitors.ExtendedInternalTraversal.dp_plain_dict),
]
_is_chain_link = False
def _generate_cache_key(self, path):
def _generate_path_cache_key(self, path):
serialized = ()
for val in self._to_bind:
for local_elem, val_elem in zip(self.path, val.path):
@@ -533,7 +562,7 @@ class _UnboundLoad(Load):
else:
opt = val._bind_loader([path.path[0]], None, None, False)
if opt:
c_key = opt._generate_cache_key(path)
c_key = opt._generate_path_cache_key(path)
if c_key is False:
return False
elif c_key:
@@ -660,7 +689,6 @@ class _UnboundLoad(Load):
opt = meth(opt, all_tokens[-1], **kw)
opt._is_chain_link = False
return opt
def _chop_path(self, to_chop, path):
+9 -1
View File
@@ -30,10 +30,12 @@ from .. import exc as sa_exc
from .. import inspection
from .. import sql
from .. import util
from ..sql import base as sql_base
from ..sql import coercions
from ..sql import expression
from ..sql import roles
from ..sql import util as sql_util
from ..sql import visitors
all_cascades = frozenset(
@@ -530,7 +532,7 @@ class AliasedClass(object):
return str(self._aliased_insp)
class AliasedInsp(InspectionAttr):
class AliasedInsp(sql_base.HasCacheKey, InspectionAttr):
"""Provide an inspection interface for an
:class:`.AliasedClass` object.
@@ -627,6 +629,12 @@ class AliasedInsp(InspectionAttr):
def __clause_element__(self):
return self.selectable
_cache_key_traversal = [
("name", visitors.ExtendedInternalTraversal.dp_string),
("_adapt_on_names", visitors.ExtendedInternalTraversal.dp_boolean),
("selectable", visitors.ExtendedInternalTraversal.dp_clauseelement),
]
@property
def class_(self):
"""Return the mapped class ultimately represented by this
+54 -15
View File
@@ -12,12 +12,32 @@ associations.
"""
from . import operators
from .base import HasCacheKey
from .visitors import InternalTraversal
from .. import util
class SupportsCloneAnnotations(object):
class SupportsAnnotations(object):
@util.memoized_property
def _annotation_traversals(self):
return [
(
key,
InternalTraversal.dp_has_cache_key
if isinstance(value, HasCacheKey)
else InternalTraversal.dp_plain_obj,
)
for key, value in self._annotations.items()
]
class SupportsCloneAnnotations(SupportsAnnotations):
_annotations = util.immutabledict()
_traverse_internals = [
("_annotations", InternalTraversal.dp_annotations_state)
]
def _annotate(self, values):
"""return a copy of this ClauseElement with annotations
updated by the given dictionary.
@@ -25,6 +45,7 @@ class SupportsCloneAnnotations(object):
"""
new = self._clone()
new._annotations = new._annotations.union(values)
new.__dict__.pop("_annotation_traversals", None)
return new
def _with_annotations(self, values):
@@ -34,6 +55,7 @@ class SupportsCloneAnnotations(object):
"""
new = self._clone()
new._annotations = util.immutabledict(values)
new.__dict__.pop("_annotation_traversals", None)
return new
def _deannotate(self, values=None, clone=False):
@@ -49,12 +71,13 @@ class SupportsCloneAnnotations(object):
# the expression for a deep deannotation
new = self._clone()
new._annotations = {}
new.__dict__.pop("_annotation_traversals", None)
return new
else:
return self
class SupportsWrappingAnnotations(object):
class SupportsWrappingAnnotations(SupportsAnnotations):
def _annotate(self, values):
"""return a copy of this ClauseElement with annotations
updated by the given dictionary.
@@ -123,6 +146,7 @@ class Annotated(object):
def __init__(self, element, values):
self.__dict__ = element.__dict__.copy()
self.__dict__.pop("_annotation_traversals", None)
self.__element = element
self._annotations = values
self._hash = hash(element)
@@ -135,6 +159,7 @@ class Annotated(object):
def _with_annotations(self, values):
clone = self.__class__.__new__(self.__class__)
clone.__dict__ = self.__dict__.copy()
clone.__dict__.pop("_annotation_traversals", None)
clone._annotations = values
return clone
@@ -192,7 +217,17 @@ def _deep_annotate(element, annotations, exclude=None):
"""
def clone(elem):
# annotated objects hack the __hash__() method so if we want to
# uniquely process them we have to use id()
cloned_ids = {}
def clone(elem, **kw):
id_ = id(elem)
if id_ in cloned_ids:
return cloned_ids[id_]
if (
exclude
and hasattr(elem, "proxy_set")
@@ -204,6 +239,7 @@ def _deep_annotate(element, annotations, exclude=None):
else:
newelem = elem
newelem._copy_internals(clone=clone)
cloned_ids[id_] = newelem
return newelem
if element is not None:
@@ -214,23 +250,21 @@ def _deep_annotate(element, annotations, exclude=None):
def _deep_deannotate(element, values=None):
"""Deep copy the given element, removing annotations."""
cloned = util.column_dict()
cloned = {}
def clone(elem):
# if a values dict is given,
# the elem must be cloned each time it appears,
# as there may be different annotations in source
# elements that are remaining. if totally
# removing all annotations, can assume the same
# slate...
if values or elem not in cloned:
def clone(elem, **kw):
if values:
key = id(elem)
else:
key = elem
if key not in cloned:
newelem = elem._deannotate(values=values, clone=True)
newelem._copy_internals(clone=clone)
if not values:
cloned[elem] = newelem
cloned[key] = newelem
return newelem
else:
return cloned[elem]
return cloned[key]
if element is not None:
element = clone(element)
@@ -268,6 +302,11 @@ def _new_annotation_type(cls, base_cls):
"Annotated%s" % cls.__name__, (base_cls, cls), {}
)
globals()["Annotated%s" % cls.__name__] = anno_cls
if "_traverse_internals" in cls.__dict__:
anno_cls._traverse_internals = list(cls._traverse_internals) + [
("_annotations", InternalTraversal.dp_annotations_state)
]
return anno_cls
+29 -7
View File
@@ -14,6 +14,7 @@ import itertools
import operator
import re
from .traversals import HasCacheKey # noqa
from .visitors import ClauseVisitor
from .. import exc
from .. import util
@@ -38,18 +39,41 @@ class Immutable(object):
def _clone(self):
return self
def _copy_internals(self, **kw):
pass
class HasMemoized(object):
def _reset_memoizations(self):
self._memoized_property.expire_instance(self)
def _reset_exported(self):
self._memoized_property.expire_instance(self)
def _copy_internals(self, **kw):
super(HasMemoized, self)._copy_internals(**kw)
self._reset_memoizations()
def _from_objects(*elements):
return itertools.chain(*[element._from_objects for element in elements])
def _generative(fn):
"""non-caching _generative() decorator.
This is basically the legacy decorator that copies the object and
runs a method on the new copy.
"""
@util.decorator
def _generative(fn, *args, **kw):
def _generative(fn, self, *args, **kw):
"""Mark a method as generative."""
self = args[0]._generate()
fn(self, *args[1:], **kw)
self = self._generate()
x = fn(self, *args, **kw)
assert x is None, "generative methods must have no return value"
return self
decorated = _generative(fn)
@@ -357,10 +381,8 @@ class DialectKWArgs(object):
class Generative(object):
"""Allow a ClauseElement to generate itself via the
@_generative decorator.
"""
"""Provide a method-chaining pattern in conjunction with the
@_generative decorator."""
def _generate(self):
s = self.__class__.__new__(self.__class__)
-334
View File
@@ -1,334 +0,0 @@
from collections import deque
from . import operators
from .. import util
SKIP_TRAVERSE = util.symbol("skip_traverse")
def compare(obj1, obj2, **kw):
if kw.get("use_proxies", False):
strategy = ColIdentityComparatorStrategy()
else:
strategy = StructureComparatorStrategy()
return strategy.compare(obj1, obj2, **kw)
class StructureComparatorStrategy(object):
__slots__ = "compare_stack", "cache"
def __init__(self):
self.compare_stack = deque()
self.cache = set()
def compare(self, obj1, obj2, **kw):
stack = self.compare_stack
cache = self.cache
stack.append((obj1, obj2))
while stack:
left, right = stack.popleft()
if left is right:
continue
elif left is None or right is None:
# we know they are different so no match
return False
elif (left, right) in cache:
continue
cache.add((left, right))
visit_name = left.__visit_name__
# we're not exactly looking for identical types, because
# there are things like Column and AnnotatedColumn. So the
# visit_name has to at least match up
if visit_name != right.__visit_name__:
return False
meth = getattr(self, "compare_%s" % visit_name, None)
if meth:
comparison = meth(left, right, **kw)
if comparison is False:
return False
elif comparison is SKIP_TRAVERSE:
continue
for c1, c2 in util.zip_longest(
left.get_children(column_collections=False),
right.get_children(column_collections=False),
fillvalue=None,
):
if c1 is None or c2 is None:
# collections are different sizes, comparison fails
return False
stack.append((c1, c2))
return True
def compare_inner(self, obj1, obj2, **kw):
stack = self.compare_stack
try:
self.compare_stack = deque()
return self.compare(obj1, obj2, **kw)
finally:
self.compare_stack = stack
def _compare_unordered_sequences(self, seq1, seq2, **kw):
if seq1 is None:
return seq2 is None
completed = set()
for clause in seq1:
for other_clause in set(seq2).difference(completed):
if self.compare_inner(clause, other_clause, **kw):
completed.add(other_clause)
break
return len(completed) == len(seq1) == len(seq2)
def compare_bindparam(self, left, right, **kw):
# note the ".key" is often generated from id(self) so can't
# be compared, as far as determining structure.
return (
left.type._compare_type_affinity(right.type)
and left.value == right.value
and left.callable == right.callable
and left._orig_key == right._orig_key
)
def compare_clauselist(self, left, right, **kw):
if left.operator is right.operator:
if operators.is_associative(left.operator):
if self._compare_unordered_sequences(
left.clauses, right.clauses
):
return SKIP_TRAVERSE
else:
return False
else:
# normal ordered traversal
return True
else:
return False
def compare_unary(self, left, right, **kw):
if left.operator:
disp = self._get_operator_dispatch(
left.operator, "unary", "operator"
)
if disp is not None:
result = disp(left, right, left.operator, **kw)
if result is not True:
return result
elif left.modifier:
disp = self._get_operator_dispatch(
left.modifier, "unary", "modifier"
)
if disp is not None:
result = disp(left, right, left.operator, **kw)
if result is not True:
return result
return (
left.operator == right.operator and left.modifier == right.modifier
)
def compare_binary(self, left, right, **kw):
disp = self._get_operator_dispatch(left.operator, "binary", None)
if disp:
result = disp(left, right, left.operator, **kw)
if result is not True:
return result
if left.operator == right.operator:
if operators.is_commutative(left.operator):
if (
compare(left.left, right.left, **kw)
and compare(left.right, right.right, **kw)
) or (
compare(left.left, right.right, **kw)
and compare(left.right, right.left, **kw)
):
return SKIP_TRAVERSE
else:
return False
else:
return True
else:
return False
def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
# used by compare_binary, compare_unary
attrname = "visit_%s_%s%s" % (
operator_.__name__,
qualifier1,
"_" + qualifier2 if qualifier2 else "",
)
return getattr(self, attrname, None)
def visit_function_as_comparison_op_binary(
self, left, right, operator, **kw
):
return (
left.left_index == right.left_index
and left.right_index == right.right_index
)
def compare_function(self, left, right, **kw):
return left.name == right.name
def compare_column(self, left, right, **kw):
if left.table is not None:
self.compare_stack.appendleft((left.table, right.table))
return (
left.key == right.key
and left.name == right.name
and (
left.type._compare_type_affinity(right.type)
if left.type is not None
else right.type is None
)
and left.is_literal == right.is_literal
)
def compare_collation(self, left, right, **kw):
return left.collation == right.collation
def compare_type_coerce(self, left, right, **kw):
return left.type._compare_type_affinity(right.type)
@util.dependencies("sqlalchemy.sql.elements")
def compare_alias(self, elements, left, right, **kw):
return (
left.name == right.name
if not isinstance(left.name, elements._anonymous_label)
else isinstance(right.name, elements._anonymous_label)
)
def compare_cte(self, elements, left, right, **kw):
raise NotImplementedError("TODO")
def compare_extract(self, left, right, **kw):
return left.field == right.field
def compare_textual_label_reference(self, left, right, **kw):
return left.element == right.element
def compare_slice(self, left, right, **kw):
return (
left.start == right.start
and left.stop == right.stop
and left.step == right.step
)
def compare_over(self, left, right, **kw):
return left.range_ == right.range_ and left.rows == right.rows
@util.dependencies("sqlalchemy.sql.elements")
def compare_label(self, elements, left, right, **kw):
return left._type._compare_type_affinity(right._type) and (
left.name == right.name
if not isinstance(left.name, elements._anonymous_label)
else isinstance(right.name, elements._anonymous_label)
)
def compare_typeclause(self, left, right, **kw):
return left.type._compare_type_affinity(right.type)
def compare_join(self, left, right, **kw):
return left.isouter == right.isouter and left.full == right.full
def compare_table(self, left, right, **kw):
if left.name != right.name:
return False
self.compare_stack.extendleft(
util.zip_longest(left.columns, right.columns)
)
def compare_compound_select(self, left, right, **kw):
if not self._compare_unordered_sequences(
left.selects, right.selects, **kw
):
return False
if left.keyword != right.keyword:
return False
if left._for_update_arg != right._for_update_arg:
return False
if not self.compare_inner(
left._order_by_clause, right._order_by_clause, **kw
):
return False
if not self.compare_inner(
left._group_by_clause, right._group_by_clause, **kw
):
return False
return SKIP_TRAVERSE
def compare_select(self, left, right, **kw):
if not self._compare_unordered_sequences(
left._correlate, right._correlate
):
return False
if not self._compare_unordered_sequences(
left._correlate_except, right._correlate_except
):
return False
if not self._compare_unordered_sequences(
left._from_obj, right._from_obj
):
return False
if left._for_update_arg != right._for_update_arg:
return False
return True
def compare_textual_select(self, left, right, **kw):
self.compare_stack.extendleft(
util.zip_longest(left.column_args, right.column_args)
)
return left.positional == right.positional
class ColIdentityComparatorStrategy(StructureComparatorStrategy):
def compare_column_element(
self, left, right, use_proxies=True, equivalents=(), **kw
):
"""Compare ColumnElements using proxies and equivalent collections.
This is a comparison strategy specific to the ORM.
"""
to_compare = (right,)
if equivalents and right in equivalents:
to_compare = equivalents[right].union(to_compare)
for oth in to_compare:
if use_proxies and left.shares_lineage(oth):
return True
elif hash(left) == hash(right):
return True
else:
return False
def compare_column(self, left, right, **kw):
return self.compare_column_element(left, right, **kw)
def compare_label(self, left, right, **kw):
return self.compare_column_element(left, right, **kw)
def compare_table(self, left, right, **kw):
# tables compare on identity, since it's not really feasible to
# compare them column by column with the above rules
return left is right
+22 -7
View File
@@ -434,6 +434,27 @@ class _CompileLabel(elements.ColumnElement):
return self
class prefix_anon_map(dict):
"""A map that creates new keys for missing key access.
Considers keys of the form "<ident> <name>" to produce
new symbols "<name>_<index>", where "index" is an incrementing integer
corresponding to <name>.
Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which
is otherwise usually used for this type of operation.
"""
def __missing__(self, key):
(ident, derived) = key.split(" ", 1)
anonymous_counter = self.get(derived, 1)
self[derived] = anonymous_counter + 1
value = derived + "_" + str(anonymous_counter)
self[key] = value
return value
class SQLCompiler(Compiled):
"""Default implementation of :class:`.Compiled`.
@@ -574,7 +595,7 @@ class SQLCompiler(Compiled):
# a map which tracks "anonymous" identifiers that are created on
# the fly here
self.anon_map = util.PopulateDict(self._process_anon)
self.anon_map = prefix_anon_map()
# a map which tracks "truncated" names based on
# dialect.label_length or dialect.max_identifier_length
@@ -1712,12 +1733,6 @@ class SQLCompiler(Compiled):
def _anonymize(self, name):
return name % self.anon_map
def _process_anon(self, key):
(ident, derived) = key.split(" ", 1)
anonymous_counter = self.anon_map.get(derived, 1)
self.anon_map[derived] = anonymous_counter + 1
return derived + "_" + str(anonymous_counter)
def bindparam_string(
self,
name,
+3
View File
@@ -178,6 +178,9 @@ def _unsupported_impl(expr, op, *arg, **kw):
def _inv_impl(expr, op, **kw):
"""See :meth:`.ColumnOperators.__inv__`."""
# undocumented element currently used by the ORM for
# relationship.contains()
if hasattr(expr, "negation_clause"):
return expr.negation_clause
else:
+186 -329
View File
@@ -16,23 +16,29 @@ import itertools
import operator
import re
from . import clause_compare
from . import coercions
from . import operators
from . import roles
from . import traversals
from . import type_api
from .annotation import Annotated
from .annotation import SupportsWrappingAnnotations
from .base import _clone
from .base import _generative
from .base import Executable
from .base import HasCacheKey
from .base import HasMemoized
from .base import Immutable
from .base import NO_ARG
from .base import PARSE_AUTOCOMMIT
from .coercions import _document_text_coercion
from .traversals import _copy_internals
from .traversals import _get_children
from .traversals import NO_CACHE
from .visitors import cloned_traverse
from .visitors import InternalTraversal
from .visitors import traverse
from .visitors import Visitable
from .visitors import Traversible
from .. import exc
from .. import inspection
from .. import util
@@ -162,7 +168,9 @@ def not_(clause):
@inspection._self_inspects
class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
class ClauseElement(
roles.SQLRole, SupportsWrappingAnnotations, HasCacheKey, Traversible
):
"""Base class for elements of a programmatically constructed SQL
expression.
@@ -190,6 +198,13 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
_order_by_label_element = None
@property
def _cache_key_traversal(self):
try:
return self._traverse_internals
except AttributeError:
return NO_CACHE
def _clone(self):
"""Create a shallow copy of this ClauseElement.
@@ -221,28 +236,6 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
"""
return self
def _cache_key(self, **kw):
"""return an optional cache key.
The cache key is a tuple which can contain any series of
objects that are hashable and also identifies
this object uniquely within the presence of a larger SQL expression
or statement, for the purposes of caching the resulting query.
The cache key should be based on the SQL compiled structure that would
ultimately be produced. That is, two structures that are composed in
exactly the same way should produce the same cache key; any difference
in the strucures that would affect the SQL string or the type handlers
should result in a different cache key.
If a structure cannot produce a useful cache key, it should raise
NotImplementedError, which will result in the entire structure
for which it's part of not being useful as a cache key.
"""
raise NotImplementedError()
@property
def _constructor(self):
"""return the 'constructor' for this ClauseElement.
@@ -336,9 +329,9 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
(see :class:`.ColumnElement`)
"""
return clause_compare.compare(self, other, **kw)
return traversals.compare(self, other, **kw)
def _copy_internals(self, clone=_clone, **kw):
def _copy_internals(self, **kw):
"""Reassign internal elements to be clones of themselves.
Called during a copy-and-traverse operation on newly
@@ -349,21 +342,46 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
traversal, cloned traversal, annotations).
"""
pass
def get_children(self, **kwargs):
r"""Return immediate child elements of this :class:`.ClauseElement`.
try:
traverse_internals = self._traverse_internals
except AttributeError:
return
for attrname, obj, meth in _copy_internals.run_generated_dispatch(
self, traverse_internals, "_generated_copy_internals_traversal"
):
if obj is not None:
result = meth(self, obj, **kw)
if result is not None:
setattr(self, attrname, result)
def get_children(self, omit_attrs=None, **kw):
r"""Return immediate child :class:`.Traversible` elements of this
:class:`.Traversible`.
This is used for visit traversal.
\**kwargs may contain flags that change the collection that is
\**kw may contain flags that change the collection that is
returned, for example to return a subset of items in order to
cut down on larger traversals, or to return child items from a
different context (such as schema-level collections instead of
clause-level).
"""
return []
result = []
try:
traverse_internals = self._traverse_internals
except AttributeError:
return result
for attrname, obj, meth in _get_children.run_generated_dispatch(
self, traverse_internals, "_generated_get_children_traversal"
):
if obj is None or omit_attrs and attrname in omit_attrs:
continue
result.extend(meth(obj, **kw))
return result
def self_group(self, against=None):
# type: (Optional[Any]) -> ClauseElement
@@ -501,6 +519,8 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
return or_(self, other)
def __invert__(self):
# undocumented element currently used by the ORM for
# relationship.contains()
if hasattr(self, "negation_clause"):
return self.negation_clause
else:
@@ -508,9 +528,7 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
def _negate(self):
return UnaryExpression(
self.self_group(against=operators.inv),
operator=operators.inv,
negate=None,
self.self_group(against=operators.inv), operator=operators.inv
)
def __bool__(self):
@@ -731,9 +749,6 @@ class ColumnElement(
else:
return comparator_factory(self)
def _cache_key(self, **kw):
raise NotImplementedError(self.__class__)
def __getattr__(self, key):
try:
return getattr(self.comparator, key)
@@ -969,6 +984,13 @@ class BindParameter(roles.InElementRole, ColumnElement):
__visit_name__ = "bindparam"
_traverse_internals = [
("key", InternalTraversal.dp_anon_name),
("type", InternalTraversal.dp_type),
("callable", InternalTraversal.dp_plain_dict),
("value", InternalTraversal.dp_plain_obj),
]
_is_crud = False
_expanding_in_types = ()
@@ -1321,26 +1343,19 @@ class BindParameter(roles.InElementRole, ColumnElement):
)
return c
def _cache_key(self, bindparams=None, **kw):
if bindparams is None:
# even though _cache_key is a private method, we would like to
# be super paranoid about this point. You can't include the
# "value" or "callable" in the cache key, because the value is
# not part of the structure of a statement and is likely to
# change every time. However you cannot *throw it away* either,
# because you can't invoke the statement without the parameter
# values that were explicitly placed. So require that they
# are collected here to make sure this happens.
if self._value_required_for_cache:
raise NotImplementedError(
"bindparams collection argument required for _cache_key "
"implementation. Bound parameter cache keys are not safe "
"to use without accommodating for the value or callable "
"within the parameter itself."
)
else:
bindparams.append(self)
return (BindParameter, self.type._cache_key, self._orig_key)
def _gen_cache_key(self, anon_map, bindparams):
if self in anon_map:
return (anon_map[self], self.__class__)
id_ = anon_map[self]
bindparams.append(self)
return (
id_,
self.__class__,
self.type._gen_cache_key,
traversals._resolve_name_for_compare(self, self.key, anon_map),
)
def _convert_to_unique(self):
if not self.unique:
@@ -1377,12 +1392,11 @@ class TypeClause(ClauseElement):
__visit_name__ = "typeclause"
_traverse_internals = [("type", InternalTraversal.dp_type)]
def __init__(self, type_):
self.type = type_
def _cache_key(self, **kw):
return (TypeClause, self.type._cache_key)
class TextClause(
roles.DDLConstraintColumnRole,
@@ -1419,6 +1433,11 @@ class TextClause(
__visit_name__ = "textclause"
_traverse_internals = [
("_bindparams", InternalTraversal.dp_string_clauseelement_dict),
("text", InternalTraversal.dp_string),
]
_is_text_clause = True
_is_textual = True
@@ -1861,19 +1880,6 @@ class TextClause(
else:
return self
def _copy_internals(self, clone=_clone, **kw):
self._bindparams = dict(
(b.key, clone(b, **kw)) for b in self._bindparams.values()
)
def get_children(self, **kwargs):
return list(self._bindparams.values())
def _cache_key(self, **kw):
return (self.text,) + tuple(
bind._cache_key for bind in self._bindparams.values()
)
class Null(roles.ConstExprRole, ColumnElement):
"""Represent the NULL keyword in a SQL statement.
@@ -1885,6 +1891,8 @@ class Null(roles.ConstExprRole, ColumnElement):
__visit_name__ = "null"
_traverse_internals = []
@util.memoized_property
def type(self):
return type_api.NULLTYPE
@@ -1895,9 +1903,6 @@ class Null(roles.ConstExprRole, ColumnElement):
return Null()
def _cache_key(self, **kw):
return (Null,)
class False_(roles.ConstExprRole, ColumnElement):
"""Represent the ``false`` keyword, or equivalent, in a SQL statement.
@@ -1908,6 +1913,7 @@ class False_(roles.ConstExprRole, ColumnElement):
"""
__visit_name__ = "false"
_traverse_internals = []
@util.memoized_property
def type(self):
@@ -1954,9 +1960,6 @@ class False_(roles.ConstExprRole, ColumnElement):
return False_()
def _cache_key(self, **kw):
return (False_,)
class True_(roles.ConstExprRole, ColumnElement):
"""Represent the ``true`` keyword, or equivalent, in a SQL statement.
@@ -1968,6 +1971,8 @@ class True_(roles.ConstExprRole, ColumnElement):
__visit_name__ = "true"
_traverse_internals = []
@util.memoized_property
def type(self):
return type_api.BOOLEANTYPE
@@ -2020,9 +2025,6 @@ class True_(roles.ConstExprRole, ColumnElement):
return True_()
def _cache_key(self, **kw):
return (True_,)
class ClauseList(
roles.InElementRole,
@@ -2038,6 +2040,11 @@ class ClauseList(
__visit_name__ = "clauselist"
_traverse_internals = [
("clauses", InternalTraversal.dp_clauseelement_list),
("operator", InternalTraversal.dp_operator),
]
def __init__(self, *clauses, **kwargs):
self.operator = kwargs.pop("operator", operators.comma_op)
self.group = kwargs.pop("group", True)
@@ -2082,17 +2089,6 @@ class ClauseList(
coercions.expect(self._text_converter_role, clause)
)
def _copy_internals(self, clone=_clone, **kw):
self.clauses = [clone(clause, **kw) for clause in self.clauses]
def get_children(self, **kwargs):
return self.clauses
def _cache_key(self, **kw):
return (ClauseList, self.operator) + tuple(
clause._cache_key(**kw) for clause in self.clauses
)
@property
def _from_objects(self):
return list(itertools.chain(*[c._from_objects for c in self.clauses]))
@@ -2115,11 +2111,6 @@ class BooleanClauseList(ClauseList, ColumnElement):
"BooleanClauseList has a private constructor"
)
def _cache_key(self, **kw):
return (BooleanClauseList, self.operator) + tuple(
clause._cache_key(**kw) for clause in self.clauses
)
@classmethod
def _construct(cls, operator, continue_on, skip_on, *clauses, **kw):
convert_clauses = []
@@ -2250,6 +2241,8 @@ or_ = BooleanClauseList.or_
class Tuple(ClauseList, ColumnElement):
"""Represent a SQL tuple."""
_traverse_internals = ClauseList._traverse_internals + []
def __init__(self, *clauses, **kw):
"""Return a :class:`.Tuple`.
@@ -2289,11 +2282,6 @@ class Tuple(ClauseList, ColumnElement):
def _select_iterable(self):
return (self,)
def _cache_key(self, **kw):
return (Tuple,) + tuple(
clause._cache_key(**kw) for clause in self.clauses
)
def _bind_param(self, operator, obj, type_=None):
return Tuple(
*[
@@ -2339,6 +2327,12 @@ class Case(ColumnElement):
__visit_name__ = "case"
_traverse_internals = [
("value", InternalTraversal.dp_clauseelement),
("whens", InternalTraversal.dp_clauseelement_tuples),
("else_", InternalTraversal.dp_clauseelement),
]
def __init__(self, whens, value=None, else_=None):
r"""Produce a ``CASE`` expression.
@@ -2501,40 +2495,6 @@ class Case(ColumnElement):
else:
self.else_ = None
def _copy_internals(self, clone=_clone, **kw):
if self.value is not None:
self.value = clone(self.value, **kw)
self.whens = [(clone(x, **kw), clone(y, **kw)) for x, y in self.whens]
if self.else_ is not None:
self.else_ = clone(self.else_, **kw)
def get_children(self, **kwargs):
if self.value is not None:
yield self.value
for x, y in self.whens:
yield x
yield y
if self.else_ is not None:
yield self.else_
def _cache_key(self, **kw):
return (
(
Case,
self.value._cache_key(**kw)
if self.value is not None
else None,
)
+ tuple(
(x._cache_key(**kw), y._cache_key(**kw)) for x, y in self.whens
)
+ (
self.else_._cache_key(**kw)
if self.else_ is not None
else None,
)
)
@property
def _from_objects(self):
return list(
@@ -2603,6 +2563,11 @@ class Cast(WrapsColumnExpression, ColumnElement):
__visit_name__ = "cast"
_traverse_internals = [
("clause", InternalTraversal.dp_clauseelement),
("typeclause", InternalTraversal.dp_clauseelement),
]
def __init__(self, expression, type_):
r"""Produce a ``CAST`` expression.
@@ -2662,20 +2627,6 @@ class Cast(WrapsColumnExpression, ColumnElement):
)
self.typeclause = TypeClause(self.type)
def _copy_internals(self, clone=_clone, **kw):
self.clause = clone(self.clause, **kw)
self.typeclause = clone(self.typeclause, **kw)
def get_children(self, **kwargs):
return self.clause, self.typeclause
def _cache_key(self, **kw):
return (
Cast,
self.clause._cache_key(**kw),
self.typeclause._cache_key(**kw),
)
@property
def _from_objects(self):
return self.clause._from_objects
@@ -2685,7 +2636,7 @@ class Cast(WrapsColumnExpression, ColumnElement):
return self.clause
class TypeCoerce(WrapsColumnExpression, ColumnElement):
class TypeCoerce(HasMemoized, WrapsColumnExpression, ColumnElement):
"""Represent a Python-side type-coercion wrapper.
:class:`.TypeCoerce` supplies the :func:`.expression.type_coerce`
@@ -2705,6 +2656,13 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement):
__visit_name__ = "type_coerce"
_traverse_internals = [
("clause", InternalTraversal.dp_clauseelement),
("type", InternalTraversal.dp_type),
]
_memoized_property = util.group_expirable_memoized_property()
def __init__(self, expression, type_):
r"""Associate a SQL expression with a particular type, without rendering
``CAST``.
@@ -2773,21 +2731,11 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement):
roles.ExpressionElementRole, expression, type_=self.type
)
def _copy_internals(self, clone=_clone, **kw):
self.clause = clone(self.clause, **kw)
self.__dict__.pop("typed_expression", None)
def get_children(self, **kwargs):
return (self.clause,)
def _cache_key(self, **kw):
return (TypeCoerce, self.type._cache_key, self.clause._cache_key(**kw))
@property
def _from_objects(self):
return self.clause._from_objects
@util.memoized_property
@_memoized_property
def typed_expression(self):
if isinstance(self.clause, BindParameter):
bp = self.clause._clone()
@@ -2806,6 +2754,11 @@ class Extract(ColumnElement):
__visit_name__ = "extract"
_traverse_internals = [
("expr", InternalTraversal.dp_clauseelement),
("field", InternalTraversal.dp_string),
]
def __init__(self, field, expr, **kwargs):
"""Return a :class:`.Extract` construct.
@@ -2818,15 +2771,6 @@ class Extract(ColumnElement):
self.field = field
self.expr = coercions.expect(roles.ExpressionElementRole, expr)
def _copy_internals(self, clone=_clone, **kw):
self.expr = clone(self.expr, **kw)
def get_children(self, **kwargs):
return (self.expr,)
def _cache_key(self, **kw):
return (Extract, self.field, self.expr._cache_key(**kw))
@property
def _from_objects(self):
return self.expr._from_objects
@@ -2847,18 +2791,11 @@ class _label_reference(ColumnElement):
__visit_name__ = "label_reference"
_traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
def __init__(self, element):
self.element = element
def _copy_internals(self, clone=_clone, **kw):
self.element = clone(self.element, **kw)
def _cache_key(self, **kw):
return (_label_reference, self.element._cache_key(**kw))
def get_children(self, **kwargs):
return [self.element]
@property
def _from_objects(self):
return ()
@@ -2867,6 +2804,8 @@ class _label_reference(ColumnElement):
class _textual_label_reference(ColumnElement):
__visit_name__ = "textual_label_reference"
_traverse_internals = [("element", InternalTraversal.dp_string)]
def __init__(self, element):
self.element = element
@@ -2874,9 +2813,6 @@ class _textual_label_reference(ColumnElement):
def _text_clause(self):
return TextClause._create_text(self.element)
def _cache_key(self, **kw):
return (_textual_label_reference, self.element)
class UnaryExpression(ColumnElement):
"""Define a 'unary' expression.
@@ -2894,13 +2830,18 @@ class UnaryExpression(ColumnElement):
__visit_name__ = "unary"
_traverse_internals = [
("element", InternalTraversal.dp_clauseelement),
("operator", InternalTraversal.dp_operator),
("modifier", InternalTraversal.dp_operator),
]
def __init__(
self,
element,
operator=None,
modifier=None,
type_=None,
negate=None,
wraps_column_expression=False,
):
self.operator = operator
@@ -2909,7 +2850,6 @@ class UnaryExpression(ColumnElement):
against=self.operator or self.modifier
)
self.type = type_api.to_instance(type_)
self.negate = negate
self.wraps_column_expression = wraps_column_expression
@classmethod
@@ -3135,37 +3075,13 @@ class UnaryExpression(ColumnElement):
def _from_objects(self):
return self.element._from_objects
def _copy_internals(self, clone=_clone, **kw):
self.element = clone(self.element, **kw)
def _cache_key(self, **kw):
return (
UnaryExpression,
self.element._cache_key(**kw),
self.operator,
self.modifier,
)
def get_children(self, **kwargs):
return (self.element,)
def _negate(self):
if self.negate is not None:
return UnaryExpression(
self.element,
operator=self.negate,
negate=self.operator,
modifier=self.modifier,
type_=self.type,
wraps_column_expression=self.wraps_column_expression,
)
elif self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity:
if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity:
return UnaryExpression(
self.self_group(against=operators.inv),
operator=operators.inv,
type_=type_api.BOOLEANTYPE,
wraps_column_expression=self.wraps_column_expression,
negate=None,
)
else:
return ClauseElement._negate(self)
@@ -3286,15 +3202,6 @@ class AsBoolean(WrapsColumnExpression, UnaryExpression):
# type: (Optional[Any]) -> ClauseElement
return self
def _cache_key(self, **kw):
return (
self.element._cache_key(**kw),
self.type._cache_key,
self.operator,
self.negate,
self.modifier,
)
def _negate(self):
if isinstance(self.element, (True_, False_)):
return self.element._negate()
@@ -3318,6 +3225,14 @@ class BinaryExpression(ColumnElement):
__visit_name__ = "binary"
_traverse_internals = [
("left", InternalTraversal.dp_clauseelement),
("right", InternalTraversal.dp_clauseelement),
("operator", InternalTraversal.dp_operator),
("negate", InternalTraversal.dp_operator),
("modifiers", InternalTraversal.dp_plain_dict),
]
_is_implicitly_boolean = True
"""Indicates that any database will know this is a boolean expression
even if the database does not have an explicit boolean datatype.
@@ -3360,20 +3275,6 @@ class BinaryExpression(ColumnElement):
def _from_objects(self):
return self.left._from_objects + self.right._from_objects
def _copy_internals(self, clone=_clone, **kw):
self.left = clone(self.left, **kw)
self.right = clone(self.right, **kw)
def get_children(self, **kwargs):
return self.left, self.right
def _cache_key(self, **kw):
return (
BinaryExpression,
self.left._cache_key(**kw),
self.right._cache_key(**kw),
)
def self_group(self, against=None):
# type: (Optional[Any]) -> ClauseElement
@@ -3406,6 +3307,12 @@ class Slice(ColumnElement):
__visit_name__ = "slice"
_traverse_internals = [
("start", InternalTraversal.dp_plain_obj),
("stop", InternalTraversal.dp_plain_obj),
("step", InternalTraversal.dp_plain_obj),
]
def __init__(self, start, stop, step):
self.start = start
self.stop = stop
@@ -3417,9 +3324,6 @@ class Slice(ColumnElement):
assert against is operator.getitem
return self
def _cache_key(self, **kw):
return (Slice, self.start, self.stop, self.step)
class IndexExpression(BinaryExpression):
"""Represent the class of expressions that are like an "index" operation.
@@ -3444,6 +3348,11 @@ class GroupedElement(ClauseElement):
class Grouping(GroupedElement, ColumnElement):
"""Represent a grouping within a column expression"""
_traverse_internals = [
("element", InternalTraversal.dp_clauseelement),
("type", InternalTraversal.dp_type),
]
def __init__(self, element):
self.element = element
self.type = getattr(element, "type", type_api.NULLTYPE)
@@ -3460,15 +3369,6 @@ class Grouping(GroupedElement, ColumnElement):
def _label(self):
return getattr(self.element, "_label", None) or self.anon_label
def _copy_internals(self, clone=_clone, **kw):
self.element = clone(self.element, **kw)
def get_children(self, **kwargs):
return (self.element,)
def _cache_key(self, **kw):
return (Grouping, self.element._cache_key(**kw))
@property
def _from_objects(self):
return self.element._from_objects
@@ -3501,6 +3401,14 @@ class Over(ColumnElement):
__visit_name__ = "over"
_traverse_internals = [
("element", InternalTraversal.dp_clauseelement),
("order_by", InternalTraversal.dp_clauseelement),
("partition_by", InternalTraversal.dp_clauseelement),
("range_", InternalTraversal.dp_plain_obj),
("rows", InternalTraversal.dp_plain_obj),
]
order_by = None
partition_by = None
@@ -3667,30 +3575,6 @@ class Over(ColumnElement):
def type(self):
return self.element.type
def get_children(self, **kwargs):
return [
c
for c in (self.element, self.partition_by, self.order_by)
if c is not None
]
def _cache_key(self, **kw):
return (
(Over,)
+ tuple(
e._cache_key(**kw) if e is not None else None
for e in (self.element, self.partition_by, self.order_by)
)
+ (self.range_, self.rows)
)
def _copy_internals(self, clone=_clone, **kw):
self.element = clone(self.element, **kw)
if self.partition_by is not None:
self.partition_by = clone(self.partition_by, **kw)
if self.order_by is not None:
self.order_by = clone(self.order_by, **kw)
@property
def _from_objects(self):
return list(
@@ -3723,6 +3607,11 @@ class WithinGroup(ColumnElement):
__visit_name__ = "withingroup"
_traverse_internals = [
("element", InternalTraversal.dp_clauseelement),
("order_by", InternalTraversal.dp_clauseelement),
]
order_by = None
def __init__(self, element, *order_by):
@@ -3791,25 +3680,6 @@ class WithinGroup(ColumnElement):
else:
return self.element.type
def get_children(self, **kwargs):
return [c for c in (self.element, self.order_by) if c is not None]
def _cache_key(self, **kw):
return (
WithinGroup,
self.element._cache_key(**kw)
if self.element is not None
else None,
self.order_by._cache_key(**kw)
if self.order_by is not None
else None,
)
def _copy_internals(self, clone=_clone, **kw):
self.element = clone(self.element, **kw)
if self.order_by is not None:
self.order_by = clone(self.order_by, **kw)
@property
def _from_objects(self):
return list(
@@ -3845,6 +3715,11 @@ class FunctionFilter(ColumnElement):
__visit_name__ = "funcfilter"
_traverse_internals = [
("func", InternalTraversal.dp_clauseelement),
("criterion", InternalTraversal.dp_clauseelement),
]
criterion = None
def __init__(self, func, *criterion):
@@ -3932,23 +3807,6 @@ class FunctionFilter(ColumnElement):
def type(self):
return self.func.type
def get_children(self, **kwargs):
return [c for c in (self.func, self.criterion) if c is not None]
def _copy_internals(self, clone=_clone, **kw):
self.func = clone(self.func, **kw)
if self.criterion is not None:
self.criterion = clone(self.criterion, **kw)
def _cache_key(self, **kw):
return (
FunctionFilter,
self.func._cache_key(**kw),
self.criterion._cache_key(**kw)
if self.criterion is not None
else None,
)
@property
def _from_objects(self):
return list(
@@ -3962,7 +3820,7 @@ class FunctionFilter(ColumnElement):
)
class Label(roles.LabeledColumnExprRole, ColumnElement):
class Label(HasMemoized, roles.LabeledColumnExprRole, ColumnElement):
"""Represents a column label (AS).
Represent a label, as typically applied to any column-level
@@ -3972,6 +3830,14 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
__visit_name__ = "label"
_traverse_internals = [
("name", InternalTraversal.dp_anon_name),
("_type", InternalTraversal.dp_type),
("_element", InternalTraversal.dp_clauseelement),
]
_memoized_property = util.group_expirable_memoized_property()
def __init__(self, name, element, type_=None):
"""Return a :class:`Label` object for the
given :class:`.ColumnElement`.
@@ -4010,14 +3876,11 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
def __reduce__(self):
return self.__class__, (self.name, self._element, self._type)
def _cache_key(self, **kw):
return (Label, self.element._cache_key(**kw), self._resolve_label)
@util.memoized_property
def _is_implicitly_boolean(self):
return self.element._is_implicitly_boolean
@util.memoized_property
@_memoized_property
def _allow_label_resolve(self):
return self.element._allow_label_resolve
@@ -4031,7 +3894,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
self._type or getattr(self._element, "type", None)
)
@util.memoized_property
@_memoized_property
def element(self):
return self._element.self_group(against=operators.as_)
@@ -4057,13 +3920,9 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
def foreign_keys(self):
return self.element.foreign_keys
def get_children(self, **kwargs):
return (self.element,)
def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw):
self._reset_memoizations()
self._element = clone(self._element, **kw)
self.__dict__.pop("element", None)
self.__dict__.pop("_allow_label_resolve", None)
if anonymize_labels:
self.name = self._resolve_label = _anonymous_label(
"%%(%d %s)s"
@@ -4124,6 +3983,13 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement):
__visit_name__ = "column"
_traverse_internals = [
("name", InternalTraversal.dp_string),
("type", InternalTraversal.dp_type),
("table", InternalTraversal.dp_clauseelement),
("is_literal", InternalTraversal.dp_boolean),
]
onupdate = default = server_default = server_onupdate = None
_is_multiparam_column = False
@@ -4254,14 +4120,6 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement):
table = property(_get_table, _set_table)
def _cache_key(self, **kw):
return (
self.name,
self.table.name if self.table is not None else None,
self.is_literal,
self.type._cache_key,
)
@_memoized_property
def _from_objects(self):
t = self.table
@@ -4395,12 +4253,11 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement):
class CollationClause(ColumnElement):
__visit_name__ = "collation"
_traverse_internals = [("collation", InternalTraversal.dp_string)]
def __init__(self, collation):
self.collation = collation
def _cache_key(self, **kw):
return (CollationClause, self.collation)
class _IdentifiedClause(Executable, ClauseElement):
-2
View File
@@ -86,7 +86,6 @@ __all__ = [
from .base import _from_objects # noqa
from .base import ColumnCollection # noqa
from .base import Executable # noqa
from .base import Generative # noqa
from .base import PARSE_AUTOCOMMIT # noqa
from .dml import Delete # noqa
from .dml import Insert # noqa
@@ -242,7 +241,6 @@ _UnaryExpression = UnaryExpression
_Case = Case
_Tuple = Tuple
_Over = Over
_Generative = Generative
_TypeClause = TypeClause
_Extract = Extract
_Exists = Exists
+25 -45
View File
@@ -17,7 +17,6 @@ from . import sqltypes
from . import util as sqlutil
from .base import ColumnCollection
from .base import Executable
from .elements import _clone
from .elements import _type_from_args
from .elements import BinaryExpression
from .elements import BindParameter
@@ -33,7 +32,8 @@ from .elements import WithinGroup
from .selectable import Alias
from .selectable import FromClause
from .selectable import Select
from .visitors import VisitableType
from .visitors import InternalTraversal
from .visitors import TraversibleType
from .. import util
@@ -78,10 +78,14 @@ class FunctionElement(Executable, ColumnElement, FromClause):
"""
_traverse_internals = [("clause_expr", InternalTraversal.dp_clauseelement)]
packagenames = ()
_has_args = False
_memoized_property = FromClause._memoized_property
def __init__(self, *clauses, **kwargs):
r"""Construct a :class:`.FunctionElement`.
@@ -136,7 +140,7 @@ class FunctionElement(Executable, ColumnElement, FromClause):
col = self.label(None)
return ColumnCollection(columns=[(col.key, col)])
@util.memoized_property
@_memoized_property
def clauses(self):
"""Return the underlying :class:`.ClauseList` which contains
the arguments for this :class:`.FunctionElement`.
@@ -283,17 +287,6 @@ class FunctionElement(Executable, ColumnElement, FromClause):
def _from_objects(self):
return self.clauses._from_objects
def get_children(self, **kwargs):
return (self.clause_expr,)
def _cache_key(self, **kw):
return (FunctionElement, self.clause_expr._cache_key(**kw))
def _copy_internals(self, clone=_clone, **kw):
self.clause_expr = clone(self.clause_expr, **kw)
self._reset_exported()
FunctionElement.clauses._reset(self)
def within_group_type(self, within_group):
"""For types that define their return type as based on the criteria
within a WITHIN GROUP (ORDER BY) expression, called by the
@@ -404,6 +397,13 @@ class FunctionElement(Executable, ColumnElement, FromClause):
class FunctionAsBinary(BinaryExpression):
_traverse_internals = [
("sql_function", InternalTraversal.dp_clauseelement),
("left_index", InternalTraversal.dp_plain_obj),
("right_index", InternalTraversal.dp_plain_obj),
("modifiers", InternalTraversal.dp_plain_dict),
]
def __init__(self, fn, left_index, right_index):
self.sql_function = fn
self.left_index = left_index
@@ -431,20 +431,6 @@ class FunctionAsBinary(BinaryExpression):
def right(self, value):
self.sql_function.clauses.clauses[self.right_index - 1] = value
def _copy_internals(self, clone=_clone, **kw):
self.sql_function = clone(self.sql_function, **kw)
def get_children(self, **kw):
yield self.sql_function
def _cache_key(self, **kw):
return (
FunctionAsBinary,
self.sql_function._cache_key(**kw),
self.left_index,
self.right_index,
)
class _FunctionGenerator(object):
"""Generate SQL function expressions.
@@ -606,6 +592,12 @@ class Function(FunctionElement):
__visit_name__ = "function"
_traverse_internals = FunctionElement._traverse_internals + [
("packagenames", InternalTraversal.dp_plain_obj),
("name", InternalTraversal.dp_string),
("type", InternalTraversal.dp_type),
]
def __init__(self, name, *clauses, **kw):
"""Construct a :class:`.Function`.
@@ -630,15 +622,8 @@ class Function(FunctionElement):
unique=True,
)
def _cache_key(self, **kw):
return (
(Function,) + tuple(self.packagenames)
if self.packagenames
else () + (self.name, self.clause_expr._cache_key(**kw))
)
class _GenericMeta(VisitableType):
class _GenericMeta(TraversibleType):
def __init__(cls, clsname, bases, clsdict):
if annotation.Annotated not in cls.__mro__:
cls.name = name = clsdict.get("name", clsname)
@@ -764,6 +749,10 @@ class next_value(GenericFunction):
type = sqltypes.Integer()
name = "next_value"
_traverse_internals = [
("sequence", InternalTraversal.dp_named_ddl_element)
]
def __init__(self, seq, **kw):
assert isinstance(
seq, schema.Sequence
@@ -771,21 +760,12 @@ class next_value(GenericFunction):
self._bind = kw.get("bind", None)
self.sequence = seq
def _cache_key(self, **kw):
return (next_value, self.sequence.name)
def compare(self, other, **kw):
return (
isinstance(other, next_value)
and self.sequence.name == other.sequence.name
)
def get_children(self, **kwargs):
return []
def _copy_internals(self, **kw):
pass
@property
def _from_objects(self):
return []
+18
View File
@@ -50,6 +50,7 @@ from .elements import ColumnElement
from .elements import quoted_name
from .elements import TextClause
from .selectable import TableClause
from .visitors import InternalTraversal
from .. import event
from .. import exc
from .. import inspection
@@ -425,6 +426,21 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
__visit_name__ = "table"
_traverse_internals = TableClause._traverse_internals + [
("schema", InternalTraversal.dp_string)
]
def _gen_cache_key(self, anon_map, bindparams):
return (self,)
@util.deprecated_params(
useexisting=(
"0.7",
"The :paramref:`.Table.useexisting` parameter is deprecated and "
"will be removed in a future release. Please use "
":paramref:`.Table.extend_existing`.",
)
)
def __new__(cls, *args, **kw):
if not args:
# python3k pickle seems to call this
@@ -763,6 +779,8 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
def get_children(
self, column_collections=True, schema_visitor=False, **kw
):
# TODO: consider that we probably don't need column_collections=True
# at all, it does not seem to impact anything
if not schema_visitor:
return TableClause.get_children(
self, column_collections=column_collections, **kw
+149 -249
View File
@@ -31,6 +31,7 @@ from .base import ColumnSet
from .base import DedupeColumnCollection
from .base import Executable
from .base import Generative
from .base import HasMemoized
from .base import Immutable
from .coercions import _document_text_coercion
from .elements import _anonymous_label
@@ -39,11 +40,13 @@ from .elements import and_
from .elements import BindParameter
from .elements import ClauseElement
from .elements import ClauseList
from .elements import ColumnClause
from .elements import GroupedElement
from .elements import Grouping
from .elements import literal_column
from .elements import True_
from .elements import UnaryExpression
from .visitors import InternalTraversal
from .. import exc
from .. import util
@@ -201,6 +204,8 @@ class Selectable(ReturnsRows):
class HasPrefixes(object):
_prefixes = ()
_traverse_internals = [("_prefixes", InternalTraversal.dp_prefix_sequence)]
@_generative
@_document_text_coercion(
"expr",
@@ -252,6 +257,8 @@ class HasPrefixes(object):
class HasSuffixes(object):
_suffixes = ()
_traverse_internals = [("_suffixes", InternalTraversal.dp_prefix_sequence)]
@_generative
@_document_text_coercion(
"expr",
@@ -295,7 +302,7 @@ class HasSuffixes(object):
)
class FromClause(roles.AnonymizedFromClauseRole, Selectable):
class FromClause(HasMemoized, roles.AnonymizedFromClauseRole, Selectable):
"""Represent an element that can be used within the ``FROM``
clause of a ``SELECT`` statement.
@@ -529,11 +536,6 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
"""
return getattr(self, "name", self.__class__.__name__ + " object")
def _reset_exported(self):
"""delete memoized collections when a FromClause is cloned."""
self._memoized_property.expire_instance(self)
def _generate_fromclause_column_proxies(self, fromclause):
fromclause._columns._populate_separate_keys(
col._make_proxy(fromclause) for col in self.c
@@ -668,6 +670,14 @@ class Join(FromClause):
__visit_name__ = "join"
_traverse_internals = [
("left", InternalTraversal.dp_clauseelement),
("right", InternalTraversal.dp_clauseelement),
("onclause", InternalTraversal.dp_clauseelement),
("isouter", InternalTraversal.dp_boolean),
("full", InternalTraversal.dp_boolean),
]
_is_join = True
def __init__(self, left, right, onclause=None, isouter=False, full=False):
@@ -805,25 +815,6 @@ class Join(FromClause):
self.left._refresh_for_new_column(column)
self.right._refresh_for_new_column(column)
def _copy_internals(self, clone=_clone, **kw):
self._reset_exported()
self.left = clone(self.left, **kw)
self.right = clone(self.right, **kw)
self.onclause = clone(self.onclause, **kw)
def get_children(self, **kwargs):
return self.left, self.right, self.onclause
def _cache_key(self, **kw):
return (
Join,
self.isouter,
self.full,
self.left._cache_key(**kw),
self.right._cache_key(**kw),
self.onclause._cache_key(**kw),
)
def _match_primaries(self, left, right):
if isinstance(left, Join):
left_right = left.right
@@ -1175,6 +1166,11 @@ class AliasedReturnsRows(FromClause):
_is_from_container = True
named_with_column = True
_traverse_internals = [
("element", InternalTraversal.dp_clauseelement),
("name", InternalTraversal.dp_anon_name),
]
def __init__(self, *arg, **kw):
raise NotImplementedError(
"The %s class is not intended to be constructed "
@@ -1243,18 +1239,13 @@ class AliasedReturnsRows(FromClause):
def _copy_internals(self, clone=_clone, **kw):
element = clone(self.element, **kw)
# the element clone is usually against a Table that returns the
# same object. don't reset exported .c. collections and other
# memoized details if nothing changed
if element is not self.element:
self._reset_exported()
self.element = element
def get_children(self, column_collections=True, **kw):
if column_collections:
for c in self.c:
yield c
yield self.element
def _cache_key(self, **kw):
return (self.__class__, self.element._cache_key(**kw), self._orig_name)
self.element = element
@property
def _from_objects(self):
@@ -1396,6 +1387,11 @@ class TableSample(AliasedReturnsRows):
__visit_name__ = "tablesample"
_traverse_internals = AliasedReturnsRows._traverse_internals + [
("sampling", InternalTraversal.dp_clauseelement),
("seed", InternalTraversal.dp_clauseelement),
]
@classmethod
def _factory(cls, selectable, sampling, name=None, seed=None):
"""Return a :class:`.TableSample` object.
@@ -1466,6 +1462,16 @@ class CTE(Generative, HasSuffixes, AliasedReturnsRows):
__visit_name__ = "cte"
_traverse_internals = (
AliasedReturnsRows._traverse_internals
+ [
("_cte_alias", InternalTraversal.dp_clauseelement),
("_restates", InternalTraversal.dp_clauseelement_unordered_set),
("recursive", InternalTraversal.dp_boolean),
]
+ HasSuffixes._traverse_internals
)
@classmethod
def _factory(cls, selectable, name=None, recursive=False):
r"""Return a new :class:`.CTE`, or Common Table Expression instance.
@@ -1495,15 +1501,13 @@ class CTE(Generative, HasSuffixes, AliasedReturnsRows):
def _copy_internals(self, clone=_clone, **kw):
super(CTE, self)._copy_internals(clone, **kw)
# TODO: I don't like that we can't use the traversal data here
if self._cte_alias is not None:
self._cte_alias = clone(self._cte_alias, **kw)
self._restates = frozenset(
[clone(elem, **kw) for elem in self._restates]
)
def _cache_key(self, *arg, **kw):
raise NotImplementedError("TODO")
def alias(self, name=None, flat=False):
"""Return an :class:`.Alias` of this :class:`.CTE`.
@@ -1764,6 +1768,8 @@ class Subquery(AliasedReturnsRows):
class FromGrouping(GroupedElement, FromClause):
"""Represent a grouping of a FROM clause"""
_traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
def __init__(self, element):
self.element = coercions.expect(roles.FromClauseRole, element)
@@ -1792,15 +1798,6 @@ class FromGrouping(GroupedElement, FromClause):
def _hide_froms(self):
return self.element._hide_froms
def get_children(self, **kwargs):
return (self.element,)
def _copy_internals(self, clone=_clone, **kw):
self.element = clone(self.element, **kw)
def _cache_key(self, **kw):
return (FromGrouping, self.element._cache_key(**kw))
@property
def _from_objects(self):
return self.element._from_objects
@@ -1843,6 +1840,14 @@ class TableClause(Immutable, FromClause):
__visit_name__ = "table"
_traverse_internals = [
(
"columns",
InternalTraversal.dp_fromclause_canonical_column_collection,
),
("name", InternalTraversal.dp_string),
]
named_with_column = True
implicit_returning = False
@@ -1895,17 +1900,6 @@ class TableClause(Immutable, FromClause):
self._columns.add(c)
c.table = self
def get_children(self, column_collections=True, **kwargs):
if column_collections:
return [c for c in self.c]
else:
return []
def _cache_key(self, **kw):
return (TableClause, self.name) + tuple(
col._cache_key(**kw) for col in self._columns
)
@util.dependencies("sqlalchemy.sql.dml")
def insert(self, dml, values=None, inline=False, **kwargs):
"""Generate an :func:`.insert` construct against this
@@ -1965,6 +1959,13 @@ class TableClause(Immutable, FromClause):
class ForUpdateArg(ClauseElement):
_traverse_internals = [
("of", InternalTraversal.dp_clauseelement_list),
("nowait", InternalTraversal.dp_boolean),
("read", InternalTraversal.dp_boolean),
("skip_locked", InternalTraversal.dp_boolean),
]
@classmethod
def parse_legacy_select(self, arg):
"""Parse the for_update argument of :func:`.select`.
@@ -2029,19 +2030,6 @@ class ForUpdateArg(ClauseElement):
def __hash__(self):
return id(self)
def _copy_internals(self, clone=_clone, **kw):
if self.of is not None:
self.of = [clone(col, **kw) for col in self.of]
def _cache_key(self, **kw):
return (
ForUpdateArg,
self.nowait,
self.read,
self.skip_locked,
self.of._cache_key(**kw) if self.of is not None else None,
)
def __init__(
self,
nowait=False,
@@ -2074,6 +2062,7 @@ class SelectBase(
roles.DMLSelectRole,
roles.CompoundElementRole,
roles.InElementRole,
HasMemoized,
HasCTE,
Executable,
SupportsCloneAnnotations,
@@ -2092,9 +2081,6 @@ class SelectBase(
_memoized_property = util.group_expirable_memoized_property()
def _reset_memoizations(self):
self._memoized_property.expire_instance(self)
def _generate_fromclause_column_proxies(self, fromclause):
# type: (FromClause)
raise NotImplementedError()
@@ -2339,6 +2325,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
"""
__visit_name__ = "grouping"
_traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
_is_select_container = True
@@ -2350,9 +2337,6 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
def select_statement(self):
return self.element
def get_children(self, **kwargs):
return (self.element,)
def self_group(self, against=None):
# type: (Optional[Any]) -> FromClause
return self
@@ -2377,12 +2361,6 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
"""
return self.element.selected_columns
def _copy_internals(self, clone=_clone, **kw):
self.element = clone(self.element, **kw)
def _cache_key(self, **kw):
return (SelectStatementGrouping, self.element._cache_key(**kw))
@property
def _from_objects(self):
return self.element._from_objects
@@ -2758,9 +2736,6 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
def _label_resolve_dict(self):
raise NotImplementedError()
def _copy_internals(self, clone=_clone, **kw):
raise NotImplementedError()
class CompoundSelect(GenerativeSelect):
"""Forms the basis of ``UNION``, ``UNION ALL``, and other
@@ -2785,6 +2760,16 @@ class CompoundSelect(GenerativeSelect):
__visit_name__ = "compound_select"
_traverse_internals = [
("selects", InternalTraversal.dp_clauseelement_list),
("_limit_clause", InternalTraversal.dp_clauseelement),
("_offset_clause", InternalTraversal.dp_clauseelement),
("_order_by_clause", InternalTraversal.dp_clauseelement),
("_group_by_clause", InternalTraversal.dp_clauseelement),
("_for_update_arg", InternalTraversal.dp_clauseelement),
("keyword", InternalTraversal.dp_string),
] + SupportsCloneAnnotations._traverse_internals
UNION = util.symbol("UNION")
UNION_ALL = util.symbol("UNION ALL")
EXCEPT = util.symbol("EXCEPT")
@@ -3004,47 +2989,6 @@ class CompoundSelect(GenerativeSelect):
"""
return self.selects[0].selected_columns
def _copy_internals(self, clone=_clone, **kw):
self._reset_memoizations()
self.selects = [clone(s, **kw) for s in self.selects]
if hasattr(self, "_col_map"):
del self._col_map
for attr in (
"_limit_clause",
"_offset_clause",
"_order_by_clause",
"_group_by_clause",
"_for_update_arg",
):
if getattr(self, attr) is not None:
setattr(self, attr, clone(getattr(self, attr), **kw))
def get_children(self, **kwargs):
return [self._order_by_clause, self._group_by_clause] + list(
self.selects
)
def _cache_key(self, **kw):
return (
(CompoundSelect, self.keyword)
+ tuple(stmt._cache_key(**kw) for stmt in self.selects)
+ (
self._order_by_clause._cache_key(**kw)
if self._order_by_clause is not None
else None,
)
+ (
self._group_by_clause._cache_key(**kw)
if self._group_by_clause is not None
else None,
)
+ (
self._for_update_arg._cache_key(**kw)
if self._for_update_arg is not None
else None,
)
)
def bind(self):
if self._bind:
return self._bind
@@ -3193,11 +3137,35 @@ class Select(
_hints = util.immutabledict()
_statement_hints = ()
_distinct = False
_from_cloned = None
_distinct_on = ()
_correlate = ()
_correlate_except = None
_memoized_property = SelectBase._memoized_property
_traverse_internals = (
[
("_from_obj", InternalTraversal.dp_fromclause_ordered_set),
("_raw_columns", InternalTraversal.dp_clauseelement_list),
("_whereclause", InternalTraversal.dp_clauseelement),
("_having", InternalTraversal.dp_clauseelement),
("_order_by_clause", InternalTraversal.dp_clauseelement_list),
("_group_by_clause", InternalTraversal.dp_clauseelement_list),
("_correlate", InternalTraversal.dp_clauseelement_unordered_set),
(
"_correlate_except",
InternalTraversal.dp_clauseelement_unordered_set,
),
("_for_update_arg", InternalTraversal.dp_clauseelement),
("_statement_hints", InternalTraversal.dp_statement_hint_list),
("_hints", InternalTraversal.dp_table_hint_list),
("_distinct", InternalTraversal.dp_boolean),
("_distinct_on", InternalTraversal.dp_clauseelement_list),
]
+ HasPrefixes._traverse_internals
+ HasSuffixes._traverse_internals
+ SupportsCloneAnnotations._traverse_internals
)
@util.deprecated_params(
autocommit=(
"0.6",
@@ -3416,13 +3384,14 @@ class Select(
"""
self._auto_correlate = correlate
if distinct is not False:
if distinct is True:
self._distinct = True
else:
self._distinct = [
coercions.expect(roles.WhereHavingRole, e)
for e in util.to_list(distinct)
]
self._distinct = True
if not isinstance(distinct, bool):
self._distinct_on = tuple(
[
coercions.expect(roles.WhereHavingRole, e)
for e in util.to_list(distinct)
]
)
if from_obj is not None:
self._from_obj = util.OrderedSet(
@@ -3472,15 +3441,17 @@ class Select(
GenerativeSelect.__init__(self, **kwargs)
# @_memoized_property
@property
def _froms(self):
# would love to cache this,
# but there's just enough edge cases, particularly now that
# declarative encourages construction of SQL expressions
# without tables present, to just regen this each time.
# current roadblock to caching is two tests that test that the
# SELECT can be compiled to a string, then a Table is created against
# columns, then it can be compiled again and works. this is somewhat
# valid as people make select() against declarative class where
# columns don't have their Table yet and perhaps some operations
# call upon _froms and cache it too soon.
froms = []
seen = set()
translate = self._from_cloned
for item in itertools.chain(
_from_objects(*self._raw_columns),
@@ -3493,8 +3464,6 @@ class Select(
raise exc.InvalidRequestError(
"select() construct refers to itself as a FROM"
)
if translate and item in translate:
item = translate[item]
if not seen.intersection(item._cloned_set):
froms.append(item)
seen.update(item._cloned_set)
@@ -3518,15 +3487,6 @@ class Select(
itertools.chain(*[_expand_cloned(f._hide_froms) for f in froms])
)
if toremove:
# if we're maintaining clones of froms,
# add the copies out to the toremove list. only include
# clones that are lexical equivalents.
if self._from_cloned:
toremove.update(
self._from_cloned[f]
for f in toremove.intersection(self._from_cloned)
if self._from_cloned[f]._is_lexical_equivalent(f)
)
# filter out to FROM clauses not in the list,
# using a list to maintain ordering
froms = [f for f in froms if f not in toremove]
@@ -3707,7 +3667,6 @@ class Select(
return False
def _copy_internals(self, clone=_clone, **kw):
# Select() object has been cloned and probably adapted by the
# given clone function. Apply the cloning function to internal
# objects
@@ -3719,37 +3678,42 @@ class Select(
# as of 0.7.4 we also put the current version of _froms, which
# gets cleared on each generation. previously we were "baking"
# _froms into self._from_obj.
self._from_cloned = from_cloned = dict(
(f, clone(f, **kw)) for f in self._from_obj.union(self._froms)
)
# 3. update persistent _from_obj with the cloned versions.
all_the_froms = list(
itertools.chain(
_from_objects(*self._raw_columns),
_from_objects(self._whereclause)
if self._whereclause is not None
else (),
)
)
new_froms = {f: clone(f, **kw) for f in all_the_froms}
# copy FROM collections
self._from_obj = util.OrderedSet(
from_cloned[f] for f in self._from_obj
)
clone(f, **kw) for f in self._from_obj
).union(f for f in new_froms.values() if isinstance(f, Join))
# the _correlate collection is done separately, what can happen
# here is the same item is _correlate as in _from_obj but the
# _correlate version has an annotation on it - (specifically
# RelationshipProperty.Comparator._criterion_exists() does
# this). Also keep _correlate liberally open with its previous
# contents, as this set is used for matching, not rendering.
self._correlate = set(clone(f) for f in self._correlate).union(
self._correlate
)
# do something similar for _correlate_except - this is a more
# unusual case but same idea applies
self._correlate = set(clone(f) for f in self._correlate)
if self._correlate_except:
self._correlate_except = set(
clone(f) for f in self._correlate_except
).union(self._correlate_except)
)
# 4. clone other things. The difficulty here is that Column
# objects are not actually cloned, and refer to their original
# .table, resulting in the wrong "from" parent after a clone
# operation. Hence _from_cloned and _from_obj supersede what is
# present here.
# objects are usually not altered by a straight clone because they
# are dependent on the FROM cloning we just did above in order to
# be targeted correctly, or a new FROM we have might be a JOIN
# object which doesn't have its own columns. so give the cloner a
# hint.
def replace(obj, **kw):
if isinstance(obj, ColumnClause) and obj.table in new_froms:
newelem = new_froms[obj.table].corresponding_column(obj)
return newelem
kw["replace"] = replace
# TODO: I'd still like to try to leverage the traversal data
self._raw_columns = [clone(c, **kw) for c in self._raw_columns]
for attr in (
"_limit_clause",
@@ -3763,67 +3727,12 @@ class Select(
if getattr(self, attr) is not None:
setattr(self, attr, clone(getattr(self, attr), **kw))
# erase _froms collection,
# etc.
self._reset_memoizations()
def get_children(self, **kwargs):
"""return child elements as per the ClauseElement specification."""
return (
self._raw_columns
+ list(self._froms)
+ [
x
for x in (
self._whereclause,
self._having,
self._order_by_clause,
self._group_by_clause,
)
if x is not None
]
)
def _cache_key(self, **kw):
return (
(Select,)
+ ("raw_columns",)
+ tuple(elem._cache_key(**kw) for elem in self._raw_columns)
+ ("elements",)
+ tuple(
elem._cache_key(**kw) if elem is not None else None
for elem in (
self._whereclause,
self._having,
self._order_by_clause,
self._group_by_clause,
)
)
+ ("from_obj",)
+ tuple(elem._cache_key(**kw) for elem in self._from_obj)
+ ("correlate",)
+ tuple(
elem._cache_key(**kw)
for elem in (
self._correlate if self._correlate is not None else ()
)
)
+ ("correlate_except",)
+ tuple(
elem._cache_key(**kw)
for elem in (
self._correlate_except
if self._correlate_except is not None
else ()
)
)
+ ("for_update",),
(
self._for_update_arg._cache_key(**kw)
if self._for_update_arg is not None
else None,
),
# TODO: define "get_children" traversal items separately?
return self._froms + super(Select, self).get_children(
omit_attrs=["_from_obj", "_correlate", "_correlate_except"]
)
@_generative
@@ -3987,10 +3896,8 @@ class Select(
"""
if expr:
expr = [coercions.expect(roles.ByOfRole, e) for e in expr]
if isinstance(self._distinct, list):
self._distinct = self._distinct + expr
else:
self._distinct = expr
self._distinct = True
self._distinct_on = self._distinct_on + tuple(expr)
else:
self._distinct = True
@@ -4489,6 +4396,11 @@ class TextualSelect(SelectBase):
__visit_name__ = "textual_select"
_traverse_internals = [
("element", InternalTraversal.dp_clauseelement),
("column_args", InternalTraversal.dp_clauseelement_list),
] + SupportsCloneAnnotations._traverse_internals
_is_textual = True
def __init__(self, text, columns, positional=False):
@@ -4534,18 +4446,6 @@ class TextualSelect(SelectBase):
c._make_proxy(fromclause) for c in self.column_args
)
def _copy_internals(self, clone=_clone, **kw):
self._reset_memoizations()
self.element = clone(self.element, **kw)
def get_children(self, **kw):
return [self.element]
def _cache_key(self, **kw):
return (TextualSelect, self.element._cache_key(**kw)) + tuple(
col._cache_key(**kw) for col in self.column_args
)
def _scalar_type(self):
return self.column_args[0].type
+768
View File
@@ -0,0 +1,768 @@
from collections import deque
from collections import namedtuple
from . import operators
from .visitors import ExtendedInternalTraversal
from .visitors import InternalTraversal
from .. import inspect
from .. import util
SKIP_TRAVERSE = util.symbol("skip_traverse")
COMPARE_FAILED = False
COMPARE_SUCCEEDED = True
NO_CACHE = util.symbol("no_cache")
def compare(obj1, obj2, **kw):
if kw.get("use_proxies", False):
strategy = ColIdentityComparatorStrategy()
else:
strategy = TraversalComparatorStrategy()
return strategy.compare(obj1, obj2, **kw)
class HasCacheKey(object):
_cache_key_traversal = NO_CACHE
def _gen_cache_key(self, anon_map, bindparams):
"""return an optional cache key.
The cache key is a tuple which can contain any series of
objects that are hashable and also identifies
this object uniquely within the presence of a larger SQL expression
or statement, for the purposes of caching the resulting query.
The cache key should be based on the SQL compiled structure that would
ultimately be produced. That is, two structures that are composed in
exactly the same way should produce the same cache key; any difference
in the strucures that would affect the SQL string or the type handlers
should result in a different cache key.
If a structure cannot produce a useful cache key, it should raise
NotImplementedError, which will result in the entire structure
for which it's part of not being useful as a cache key.
"""
if self in anon_map:
return (anon_map[self], self.__class__)
id_ = anon_map[self]
if self._cache_key_traversal is NO_CACHE:
anon_map[NO_CACHE] = True
return None
result = (id_, self.__class__)
for attrname, obj, meth in _cache_key_traversal.run_generated_dispatch(
self, self._cache_key_traversal, "_generated_cache_key_traversal"
):
if obj is not None:
result += meth(attrname, obj, self, anon_map, bindparams)
return result
def _generate_cache_key(self):
"""return a cache key.
The cache key is a tuple which can contain any series of
objects that are hashable and also identifies
this object uniquely within the presence of a larger SQL expression
or statement, for the purposes of caching the resulting query.
The cache key should be based on the SQL compiled structure that would
ultimately be produced. That is, two structures that are composed in
exactly the same way should produce the same cache key; any difference
in the strucures that would affect the SQL string or the type handlers
should result in a different cache key.
The cache key returned by this method is an instance of
:class:`.CacheKey`, which consists of a tuple representing the
cache key, as well as a list of :class:`.BindParameter` objects
which are extracted from the expression. While two expressions
that produce identical cache key tuples will themselves generate
identical SQL strings, the list of :class:`.BindParameter` objects
indicates the bound values which may have different values in
each one; these bound parameters must be consulted in order to
execute the statement with the correct parameters.
a :class:`.ClauseElement` structure that does not implement
a :meth:`._gen_cache_key` method and does not implement a
:attr:`.traverse_internals` attribute will not be cacheable; when
such an element is embedded into a larger structure, this method
will return None, indicating no cache key is available.
"""
bindparams = []
_anon_map = anon_map()
key = self._gen_cache_key(_anon_map, bindparams)
if NO_CACHE in _anon_map:
return None
else:
return CacheKey(key, bindparams)
class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
def __hash__(self):
return hash(self.key)
def __eq__(self, other):
return self.key == other.key
def _clone(element, **kw):
return element._clone()
class _CacheKey(ExtendedInternalTraversal):
def visit_has_cache_key(self, attrname, obj, parent, anon_map, bindparams):
return (attrname, obj._gen_cache_key(anon_map, bindparams))
def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
return self.visit_has_cache_key(
attrname, inspect(obj), parent, anon_map, bindparams
)
def visit_clauseelement(self, attrname, obj, parent, anon_map, bindparams):
return (attrname, obj._gen_cache_key(anon_map, bindparams))
def visit_multi(self, attrname, obj, parent, anon_map, bindparams):
return (
attrname,
obj._gen_cache_key(anon_map, bindparams)
if isinstance(obj, HasCacheKey)
else obj,
)
def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams):
return (
attrname,
tuple(
elem._gen_cache_key(anon_map, bindparams)
if isinstance(elem, HasCacheKey)
else elem
for elem in obj
),
)
def visit_has_cache_key_tuples(
self, attrname, obj, parent, anon_map, bindparams
):
return (
attrname,
tuple(
tuple(
elem._gen_cache_key(anon_map, bindparams)
for elem in tup_elem
)
for tup_elem in obj
),
)
def visit_has_cache_key_list(
self, attrname, obj, parent, anon_map, bindparams
):
return (
attrname,
tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
)
def visit_inspectable_list(
self, attrname, obj, parent, anon_map, bindparams
):
return self.visit_has_cache_key_list(
attrname, [inspect(o) for o in obj], parent, anon_map, bindparams
)
def visit_clauseelement_list(
self, attrname, obj, parent, anon_map, bindparams
):
return (
attrname,
tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
)
def visit_clauseelement_tuples(
self, attrname, obj, parent, anon_map, bindparams
):
return self.visit_has_cache_key_tuples(
attrname, obj, parent, anon_map, bindparams
)
def visit_anon_name(self, attrname, obj, parent, anon_map, bindparams):
from . import elements
name = obj
if isinstance(name, elements._anonymous_label):
name = name.apply_map(anon_map)
return (attrname, name)
def visit_fromclause_ordered_set(
self, attrname, obj, parent, anon_map, bindparams
):
return (
attrname,
tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
)
def visit_clauseelement_unordered_set(
self, attrname, obj, parent, anon_map, bindparams
):
cache_keys = [
elem._gen_cache_key(anon_map, bindparams) for elem in obj
]
return (
attrname,
tuple(
sorted(cache_keys)
), # cache keys all start with (id_, class)
)
def visit_named_ddl_element(
self, attrname, obj, parent, anon_map, bindparams
):
return (attrname, obj.name)
def visit_prefix_sequence(
self, attrname, obj, parent, anon_map, bindparams
):
return (
attrname,
tuple(
(clause._gen_cache_key(anon_map, bindparams), strval)
for clause, strval in obj
),
)
def visit_statement_hint_list(
self, attrname, obj, parent, anon_map, bindparams
):
return (attrname, obj)
def visit_table_hint_list(
self, attrname, obj, parent, anon_map, bindparams
):
return (
attrname,
tuple(
(
clause._gen_cache_key(anon_map, bindparams),
dialect_name,
text,
)
for (clause, dialect_name), text in obj.items()
),
)
def visit_type(self, attrname, obj, parent, anon_map, bindparams):
return (attrname, obj._gen_cache_key)
def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams):
return (attrname, tuple((key, obj[key]) for key in sorted(obj)))
def visit_string_clauseelement_dict(
self, attrname, obj, parent, anon_map, bindparams
):
return (
attrname,
tuple(
(key, obj[key]._gen_cache_key(anon_map, bindparams))
for key in sorted(obj)
),
)
def visit_string_multi_dict(
self, attrname, obj, parent, anon_map, bindparams
):
return (
attrname,
tuple(
(
key,
value._gen_cache_key(anon_map, bindparams)
if isinstance(value, HasCacheKey)
else value,
)
for key, value in [(key, obj[key]) for key in sorted(obj)]
),
)
def visit_string(self, attrname, obj, parent, anon_map, bindparams):
return (attrname, obj)
def visit_boolean(self, attrname, obj, parent, anon_map, bindparams):
return (attrname, obj)
def visit_operator(self, attrname, obj, parent, anon_map, bindparams):
return (attrname, obj)
def visit_plain_obj(self, attrname, obj, parent, anon_map, bindparams):
return (attrname, obj)
def visit_fromclause_canonical_column_collection(
self, attrname, obj, parent, anon_map, bindparams
):
return (
attrname,
tuple(col._gen_cache_key(anon_map, bindparams) for col in obj),
)
def visit_annotations_state(
self, attrname, obj, parent, anon_map, bindparams
):
return (
attrname,
tuple(
(
key,
self.dispatch(sym)(
key, obj[key], obj, anon_map, bindparams
),
)
for key, sym in parent._annotation_traversals
),
)
def visit_unknown_structure(
self, attrname, obj, parent, anon_map, bindparams
):
anon_map[NO_CACHE] = True
return ()
_cache_key_traversal = _CacheKey()
class _CopyInternals(InternalTraversal):
"""Generate a _copy_internals internal traversal dispatch for classes
with a _traverse_internals collection."""
def visit_clauseelement(self, parent, element, clone=_clone, **kw):
return clone(element, **kw)
def visit_clauseelement_list(self, parent, element, clone=_clone, **kw):
return [clone(clause, **kw) for clause in element]
def visit_clauseelement_tuples(self, parent, element, clone=_clone, **kw):
return [
tuple(clone(tup_elem, **kw) for tup_elem in elem)
for elem in element
]
def visit_string_clauseelement_dict(
self, parent, element, clone=_clone, **kw
):
return dict(
(key, clone(value, **kw)) for key, value in element.items()
)
_copy_internals = _CopyInternals()
class _GetChildren(InternalTraversal):
"""Generate a _children_traversal internal traversal dispatch for classes
with a _traverse_internals collection."""
def visit_has_cache_key(self, element, **kw):
return (element,)
def visit_clauseelement(self, element, **kw):
return (element,)
def visit_clauseelement_list(self, element, **kw):
return tuple(element)
def visit_clauseelement_tuples(self, element, **kw):
tup = ()
for elem in element:
tup += elem
return tup
def visit_fromclause_canonical_column_collection(self, element, **kw):
if kw.get("column_collections", False):
return tuple(element)
else:
return ()
def visit_string_clauseelement_dict(self, element, **kw):
return tuple(element.values())
def visit_fromclause_ordered_set(self, element, **kw):
return tuple(element)
def visit_clauseelement_unordered_set(self, element, **kw):
return tuple(element)
_get_children = _GetChildren()
@util.dependencies("sqlalchemy.sql.elements")
def _resolve_name_for_compare(elements, element, name, anon_map, **kw):
if isinstance(name, elements._anonymous_label):
name = name.apply_map(anon_map)
return name
class anon_map(dict):
"""A map that creates new keys for missing key access.
Produces an incrementing sequence given a series of unique keys.
This is similar to the compiler prefix_anon_map class although simpler.
Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which
is otherwise usually used for this type of operation.
"""
def __init__(self):
self.index = 0
def __missing__(self, key):
self[key] = val = str(self.index)
self.index += 1
return val
class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
__slots__ = "stack", "cache", "anon_map"
def __init__(self):
self.stack = deque()
self.cache = set()
def _memoized_attr_anon_map(self):
return (anon_map(), anon_map())
def compare(self, obj1, obj2, **kw):
stack = self.stack
cache = self.cache
compare_annotations = kw.get("compare_annotations", False)
stack.append((obj1, obj2))
while stack:
left, right = stack.popleft()
if left is right:
continue
elif left is None or right is None:
# we know they are different so no match
return False
elif (left, right) in cache:
continue
cache.add((left, right))
visit_name = left.__visit_name__
if visit_name != right.__visit_name__:
return False
meth = getattr(self, "compare_%s" % visit_name, None)
if meth:
attributes_compared = meth(left, right, **kw)
if attributes_compared is COMPARE_FAILED:
return False
elif attributes_compared is SKIP_TRAVERSE:
continue
# attributes_compared is returned as a list of attribute
# names that were "handled" by the comparison method above.
# remaining attribute names in the _traverse_internals
# will be compared.
else:
attributes_compared = ()
for (
(left_attrname, left_visit_sym),
(right_attrname, right_visit_sym),
) in util.zip_longest(
left._traverse_internals,
right._traverse_internals,
fillvalue=(None, None),
):
if (
left_attrname != right_attrname
or left_visit_sym is not right_visit_sym
):
if not compare_annotations and (
(
left_visit_sym
is InternalTraversal.dp_annotations_state,
)
or (
right_visit_sym
is InternalTraversal.dp_annotations_state,
)
):
continue
return False
elif left_attrname in attributes_compared:
continue
dispatch = self.dispatch(left_visit_sym)
left_child = getattr(left, left_attrname)
right_child = getattr(right, right_attrname)
if left_child is None:
if right_child is not None:
return False
else:
continue
comparison = dispatch(
left, left_child, right, right_child, **kw
)
if comparison is COMPARE_FAILED:
return False
return True
def compare_inner(self, obj1, obj2, **kw):
comparator = self.__class__()
return comparator.compare(obj1, obj2, **kw)
def visit_has_cache_key(
self, left_parent, left, right_parent, right, **kw
):
if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key(
self.anon_map[1], []
):
return COMPARE_FAILED
def visit_clauseelement(
self, left_parent, left, right_parent, right, **kw
):
self.stack.append((left, right))
def visit_fromclause_canonical_column_collection(
self, left_parent, left, right_parent, right, **kw
):
for lcol, rcol in util.zip_longest(left, right, fillvalue=None):
self.stack.append((lcol, rcol))
def visit_fromclause_derived_column_collection(
self, left_parent, left, right_parent, right, **kw
):
pass
def visit_string_clauseelement_dict(
self, left_parent, left, right_parent, right, **kw
):
for lstr, rstr in util.zip_longest(
sorted(left), sorted(right), fillvalue=None
):
if lstr != rstr:
return COMPARE_FAILED
self.stack.append((left[lstr], right[rstr]))
def visit_annotations_state(
self, left_parent, left, right_parent, right, **kw
):
if not kw.get("compare_annotations", False):
return
for (lstr, lmeth), (rstr, rmeth) in util.zip_longest(
left_parent._annotation_traversals,
right_parent._annotation_traversals,
fillvalue=(None, None),
):
if lstr != rstr or (lmeth is not rmeth):
return COMPARE_FAILED
dispatch = self.dispatch(lmeth)
left_child = left[lstr]
right_child = right[rstr]
if left_child is None:
if right_child is not None:
return False
else:
continue
comparison = dispatch(None, left_child, None, right_child, **kw)
if comparison is COMPARE_FAILED:
return comparison
def visit_clauseelement_tuples(
self, left_parent, left, right_parent, right, **kw
):
for ltup, rtup in util.zip_longest(left, right, fillvalue=None):
if ltup is None or rtup is None:
return COMPARE_FAILED
for l, r in util.zip_longest(ltup, rtup, fillvalue=None):
self.stack.append((l, r))
def visit_clauseelement_list(
self, left_parent, left, right_parent, right, **kw
):
for l, r in util.zip_longest(left, right, fillvalue=None):
self.stack.append((l, r))
def _compare_unordered_sequences(self, seq1, seq2, **kw):
if seq1 is None:
return seq2 is None
completed = set()
for clause in seq1:
for other_clause in set(seq2).difference(completed):
if self.compare_inner(clause, other_clause, **kw):
completed.add(other_clause)
break
return len(completed) == len(seq1) == len(seq2)
def visit_clauseelement_unordered_set(
self, left_parent, left, right_parent, right, **kw
):
return self._compare_unordered_sequences(left, right, **kw)
def visit_fromclause_ordered_set(
self, left_parent, left, right_parent, right, **kw
):
for l, r in util.zip_longest(left, right, fillvalue=None):
self.stack.append((l, r))
def visit_string(self, left_parent, left, right_parent, right, **kw):
return left == right
def visit_anon_name(self, left_parent, left, right_parent, right, **kw):
return _resolve_name_for_compare(
left_parent, left, self.anon_map[0], **kw
) == _resolve_name_for_compare(
right_parent, right, self.anon_map[1], **kw
)
def visit_boolean(self, left_parent, left, right_parent, right, **kw):
return left == right
def visit_operator(self, left_parent, left, right_parent, right, **kw):
return left is right
def visit_type(self, left_parent, left, right_parent, right, **kw):
return left._compare_type_affinity(right)
def visit_plain_dict(self, left_parent, left, right_parent, right, **kw):
return left == right
def visit_plain_obj(self, left_parent, left, right_parent, right, **kw):
return left == right
def visit_named_ddl_element(
self, left_parent, left, right_parent, right, **kw
):
if left is None:
if right is not None:
return COMPARE_FAILED
return left.name == right.name
def visit_prefix_sequence(
self, left_parent, left, right_parent, right, **kw
):
for (l_clause, l_str), (r_clause, r_str) in util.zip_longest(
left, right, fillvalue=(None, None)
):
if l_str != r_str:
return COMPARE_FAILED
else:
self.stack.append((l_clause, r_clause))
def visit_table_hint_list(
self, left_parent, left, right_parent, right, **kw
):
left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1]))
right_keys = sorted(
right, key=lambda elem: (elem[0].fullname, elem[1])
)
for (ltable, ldialect), (rtable, rdialect) in util.zip_longest(
left_keys, right_keys, fillvalue=(None, None)
):
if ldialect != rdialect:
return COMPARE_FAILED
elif left[(ltable, ldialect)] != right[(rtable, rdialect)]:
return COMPARE_FAILED
else:
self.stack.append((ltable, rtable))
def visit_statement_hint_list(
self, left_parent, left, right_parent, right, **kw
):
return left == right
def visit_unknown_structure(
self, left_parent, left, right_parent, right, **kw
):
raise NotImplementedError()
def compare_clauselist(self, left, right, **kw):
if left.operator is right.operator:
if operators.is_associative(left.operator):
if self._compare_unordered_sequences(
left.clauses, right.clauses, **kw
):
return ["operator", "clauses"]
else:
return COMPARE_FAILED
else:
return ["operator"]
else:
return COMPARE_FAILED
def compare_binary(self, left, right, **kw):
if left.operator == right.operator:
if operators.is_commutative(left.operator):
if (
compare(left.left, right.left, **kw)
and compare(left.right, right.right, **kw)
) or (
compare(left.left, right.right, **kw)
and compare(left.right, right.left, **kw)
):
return ["operator", "negate", "left", "right"]
else:
return COMPARE_FAILED
else:
return ["operator", "negate"]
else:
return COMPARE_FAILED
class ColIdentityComparatorStrategy(TraversalComparatorStrategy):
def compare_column_element(
self, left, right, use_proxies=True, equivalents=(), **kw
):
"""Compare ColumnElements using proxies and equivalent collections.
This is a comparison strategy specific to the ORM.
"""
to_compare = (right,)
if equivalents and right in equivalents:
to_compare = equivalents[right].union(to_compare)
for oth in to_compare:
if use_proxies and left.shares_lineage(oth):
return SKIP_TRAVERSE
elif hash(left) == hash(right):
return SKIP_TRAVERSE
else:
return COMPARE_FAILED
def compare_column(self, left, right, **kw):
return self.compare_column_element(left, right, **kw)
def compare_label(self, left, right, **kw):
return self.compare_column_element(left, right, **kw)
def compare_table(self, left, right, **kw):
# tables compare on identity, since it's not really feasible to
# compare them column by column with the above rules
return SKIP_TRAVERSE if left is right else COMPARE_FAILED
+11 -6
View File
@@ -12,8 +12,8 @@
from . import operators
from .base import SchemaEventTarget
from .visitors import Visitable
from .visitors import VisitableType
from .visitors import Traversible
from .visitors import TraversibleType
from .. import exc
from .. import util
@@ -28,7 +28,7 @@ INDEXABLE = None
_resolve_value_to_type = None
class TypeEngine(Visitable):
class TypeEngine(Traversible):
"""The ultimate base class for all SQL datatypes.
Common subclasses of :class:`.TypeEngine` include
@@ -535,8 +535,13 @@ class TypeEngine(Visitable):
return dialect.type_descriptor(self)
@util.memoized_property
def _cache_key(self):
return util.constructor_key(self, self.__class__)
def _gen_cache_key(self):
names = util.get_cls_kwargs(self.__class__)
return (self.__class__,) + tuple(
(k, self.__dict__[k])
for k in names
if k in self.__dict__ and not k.startswith("_")
)
def adapt(self, cls, **kw):
"""Produce an "adapted" form of this type, given an "impl" class
@@ -617,7 +622,7 @@ class TypeEngine(Visitable):
return util.generic_repr(self)
class VisitableCheckKWArg(util.EnsureKWArgType, VisitableType):
class VisitableCheckKWArg(util.EnsureKWArgType, TraversibleType):
pass
+1 -1
View File
@@ -734,7 +734,7 @@ def criterion_as_pairs(
return pairs
class ClauseAdapter(visitors.ReplacingCloningVisitor):
class ClauseAdapter(visitors.ReplacingExternalTraversal):
"""Clones and modifies clauses based on column correspondence.
E.g.::
+385 -70
View File
@@ -28,14 +28,10 @@ import operator
from .. import exc
from .. import util
from ..util import langhelpers
from ..util import symbol
__all__ = [
"VisitableType",
"Visitable",
"ClauseVisitor",
"CloningVisitor",
"ReplacingCloningVisitor",
"iterate",
"iterate_depthfirst",
"traverse_using",
@@ -43,85 +39,382 @@ __all__ = [
"traverse_depthfirst",
"cloned_traverse",
"replacement_traverse",
"Traversible",
"TraversibleType",
"ExternalTraversal",
"InternalTraversal",
]
class VisitableType(type):
"""Metaclass which assigns a ``_compiler_dispatch`` method to classes
having a ``__visit_name__`` attribute.
def _generate_compiler_dispatch(cls):
"""Generate a _compiler_dispatch() external traversal on classes with a
__visit_name__ attribute.
The ``_compiler_dispatch`` attribute becomes an instance method which
looks approximately like the following::
"""
visit_name = cls.__visit_name__
def _compiler_dispatch (self, visitor, **kw):
'''Look for an attribute named "visit_" + self.__visit_name__
on the visitor, and call it with the same kw params.'''
visit_attr = 'visit_%s' % self.__visit_name__
return getattr(visitor, visit_attr)(self, **kw)
if isinstance(visit_name, util.compat.string_types):
# There is an optimization opportunity here because the
# the string name of the class's __visit_name__ is known at
# this early stage (import time) so it can be pre-constructed.
getter = operator.attrgetter("visit_%s" % visit_name)
Classes having no ``__visit_name__`` attribute will remain unaffected.
def _compiler_dispatch(self, visitor, **kw):
try:
meth = getter(visitor)
except AttributeError:
raise exc.UnsupportedCompilationError(visitor, cls)
else:
return meth(self, **kw)
else:
# The optimization opportunity is lost for this case because the
# __visit_name__ is not yet a string. As a result, the visit
# string has to be recalculated with each compilation.
def _compiler_dispatch(self, visitor, **kw):
visit_attr = "visit_%s" % self.__visit_name__
try:
meth = getattr(visitor, visit_attr)
except AttributeError:
raise exc.UnsupportedCompilationError(visitor, cls)
else:
return meth(self, **kw)
_compiler_dispatch.__doc__ = """Look for an attribute named "visit_"
+ self.__visit_name__ on the visitor, and call it with the same
kw params.
"""
cls._compiler_dispatch = _compiler_dispatch
class TraversibleType(type):
"""Metaclass which assigns dispatch attributes to various kinds of
"visitable" classes.
Attributes include:
* The ``_compiler_dispatch`` method, corresponding to ``__visit_name__``.
This is called "external traversal" because the caller of each visit()
method is responsible for sub-traversing the inner elements of each
object. This is appropriate for string compilers and other traversals
that need to call upon the inner elements in a specific pattern.
* internal traversal collections ``_children_traversal``,
``_cache_key_traversal``, ``_copy_internals_traversal``, generated from
an optional ``_traverse_internals`` collection of symbols which comes
from the :class:`.InternalTraversal` list of symbols. This is called
"internal traversal" MARKMARK
"""
def __init__(cls, clsname, bases, clsdict):
if clsname != "Visitable" and hasattr(cls, "__visit_name__"):
_generate_dispatch(cls)
if clsname != "Traversible":
if "__visit_name__" in clsdict:
_generate_compiler_dispatch(cls)
super(VisitableType, cls).__init__(clsname, bases, clsdict)
super(TraversibleType, cls).__init__(clsname, bases, clsdict)
def _generate_dispatch(cls):
"""Return an optimized visit dispatch function for the cls
for use by the compiler.
"""
if "__visit_name__" in cls.__dict__:
visit_name = cls.__visit_name__
if isinstance(visit_name, util.compat.string_types):
# There is an optimization opportunity here because the
# the string name of the class's __visit_name__ is known at
# this early stage (import time) so it can be pre-constructed.
getter = operator.attrgetter("visit_%s" % visit_name)
def _compiler_dispatch(self, visitor, **kw):
try:
meth = getter(visitor)
except AttributeError:
raise exc.UnsupportedCompilationError(visitor, cls)
else:
return meth(self, **kw)
else:
# The optimization opportunity is lost for this case because the
# __visit_name__ is not yet a string. As a result, the visit
# string has to be recalculated with each compilation.
def _compiler_dispatch(self, visitor, **kw):
visit_attr = "visit_%s" % self.__visit_name__
try:
meth = getattr(visitor, visit_attr)
except AttributeError:
raise exc.UnsupportedCompilationError(visitor, cls)
else:
return meth(self, **kw)
_compiler_dispatch.__doc__ = """Look for an attribute named "visit_" + self.__visit_name__
on the visitor, and call it with the same kw params.
"""
cls._compiler_dispatch = _compiler_dispatch
class Visitable(util.with_metaclass(VisitableType, object)):
class Traversible(util.with_metaclass(TraversibleType)):
"""Base class for visitable objects, applies the
:class:`.visitors.VisitableType` metaclass.
The :class:`.Visitable` class is essentially at the base of the
:class:`.ClauseElement` hierarchy.
:class:`.visitors.TraversibleType` metaclass.
"""
class ClauseVisitor(object):
"""Base class for visitor objects which can traverse using
class _InternalTraversalType(type):
def __init__(cls, clsname, bases, clsdict):
if cls.__name__ in ("InternalTraversal", "ExtendedInternalTraversal"):
lookup = {}
for key, sym in clsdict.items():
if key.startswith("dp_"):
visit_key = key.replace("dp_", "visit_")
sym_name = sym.name
assert sym_name not in lookup, sym_name
lookup[sym] = lookup[sym_name] = visit_key
if hasattr(cls, "_dispatch_lookup"):
lookup.update(cls._dispatch_lookup)
cls._dispatch_lookup = lookup
super(_InternalTraversalType, cls).__init__(clsname, bases, clsdict)
def _generate_dispatcher(visitor, internal_dispatch, method_name):
names = []
for attrname, visit_sym in internal_dispatch:
meth = visitor.dispatch(visit_sym)
if meth:
visit_name = ExtendedInternalTraversal._dispatch_lookup[visit_sym]
names.append((attrname, visit_name))
code = (
(" return [\n")
+ (
", \n".join(
" (%r, self.%s, visitor.%s)"
% (attrname, attrname, visit_name)
for attrname, visit_name in names
)
)
+ ("\n ]\n")
)
meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n"
# print(meth_text)
return langhelpers._exec_code_in_env(meth_text, {}, method_name)
class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
r"""Defines visitor symbols used for internal traversal.
The :class:`.InternalTraversal` class is used in two ways. One is that
it can serve as the superclass for an object that implements the
various visit methods of the class. The other is that the symbols
themselves of :class:`.InternalTraversal` are used within
the ``_traverse_internals`` collection. Such as, the :class:`.Case`
object defines ``_travserse_internals`` as ::
_traverse_internals = [
("value", InternalTraversal.dp_clauseelement),
("whens", InternalTraversal.dp_clauseelement_tuples),
("else_", InternalTraversal.dp_clauseelement),
]
Above, the :class:`.Case` class indicates its internal state as the
attribtues named ``value``, ``whens``, and ``else\_``. They each
link to an :class:`.InternalTraversal` method which indicates the type
of datastructure referred towards.
Using the ``_traverse_internals`` structure, objects of type
:class:`.InternalTraversible` will have the following methods automatically
implemented:
* :meth:`.Traversible.get_children`
* :meth:`.Traversible._copy_internals`
* :meth:`.Traversible._gen_cache_key`
Subclasses can also implement these methods directly, particularly for the
:meth:`.Traversible._copy_internals` method, when special steps
are needed.
.. versionadded:: 1.4
"""
def dispatch(self, visit_symbol):
"""Given a method from :class:`.InternalTraversal`, return the
corresponding method on a subclass.
"""
name = self._dispatch_lookup[visit_symbol]
return getattr(self, name, None)
def run_generated_dispatch(
self, target, internal_dispatch, generate_dispatcher_name
):
try:
dispatcher = target.__class__.__dict__[generate_dispatcher_name]
except KeyError:
dispatcher = _generate_dispatcher(
self, internal_dispatch, generate_dispatcher_name
)
setattr(target.__class__, generate_dispatcher_name, dispatcher)
return dispatcher(target, self)
dp_has_cache_key = symbol("HC")
"""Visit a :class:`.HasCacheKey` object."""
dp_clauseelement = symbol("CE")
"""Visit a :class:`.ClauseElement` object."""
dp_fromclause_canonical_column_collection = symbol("FC")
"""Visit a :class:`.FromClause` object in the context of the
``columns`` attribute.
The column collection is "canonical", meaning it is the originally
defined location of the :class:`.ColumnClause` objects. Right now
this means that the object being visited is a :class:`.TableClause`
or :class:`.Table` object only.
"""
dp_clauseelement_tuples = symbol("CT")
"""Visit a list of tuples which contain :class:`.ClauseElement`
objects.
"""
dp_clauseelement_list = symbol("CL")
"""Visit a list of :class:`.ClauseElement` objects.
"""
dp_clauseelement_unordered_set = symbol("CU")
"""Visit an unordered set of :class:`.ClauseElement` objects. """
dp_fromclause_ordered_set = symbol("CO")
"""Visit an ordered set of :class:`.FromClause` objects. """
dp_string = symbol("S")
"""Visit a plain string value.
Examples include table and column names, bound parameter keys, special
keywords such as "UNION", "UNION ALL".
The string value is considered to be significant for cache key
generation.
"""
dp_anon_name = symbol("AN")
"""Visit a potentially "anonymized" string value.
The string value is considered to be significant for cache key
generation.
"""
dp_boolean = symbol("B")
"""Visit a boolean value.
The boolean value is considered to be significant for cache key
generation.
"""
dp_operator = symbol("O")
"""Visit an operator.
The operator is a function from the :mod:`sqlalchemy.sql.operators`
module.
The operator value is considered to be significant for cache key
generation.
"""
dp_type = symbol("T")
"""Visit a :class:`.TypeEngine` object
The type object is considered to be significant for cache key
generation.
"""
dp_plain_dict = symbol("PD")
"""Visit a dictionary with string keys.
The keys of the dictionary should be strings, the values should
be immutable and hashable. The dictionary is considered to be
significant for cache key generation.
"""
dp_string_clauseelement_dict = symbol("CD")
"""Visit a dictionary of string keys to :class:`.ClauseElement`
objects.
"""
dp_string_multi_dict = symbol("MD")
"""Visit a dictionary of string keys to values which may either be
plain immutable/hashable or :class:`.HasCacheKey` objects.
"""
dp_plain_obj = symbol("PO")
"""Visit a plain python object.
The value should be immutable and hashable, such as an integer.
The value is considered to be significant for cache key generation.
"""
dp_annotations_state = symbol("A")
"""Visit the state of the :class:`.Annotatated` version of an object.
"""
dp_named_ddl_element = symbol("DD")
"""Visit a simple named DDL element.
The current object used by this method is the :class:`.Sequence`.
The object is only considered to be important for cache key generation
as far as its name, but not any other aspects of it.
"""
dp_prefix_sequence = symbol("PS")
"""Visit the sequence represented by :class:`.HasPrefixes`
or :class:`.HasSuffixes`.
"""
dp_table_hint_list = symbol("TH")
"""Visit the ``_hints`` collection of a :class:`.Select` object.
"""
dp_statement_hint_list = symbol("SH")
"""Visit the ``_statement_hints`` collection of a :class:`.Select`
object.
"""
dp_unknown_structure = symbol("UK")
"""Visit an unknown structure.
"""
class ExtendedInternalTraversal(InternalTraversal):
"""defines additional symbols that are useful in caching applications.
Traversals for :class:`.ClauseElement` objects only need to use
those symbols present in :class:`.InternalTraversal`. However, for
additional caching use cases within the ORM, symbols dealing with the
:class:`.HasCacheKey` class are added here.
"""
dp_ignore = symbol("IG")
"""Specify an object that should be ignored entirely.
This currently applies function call argument caching where some
arguments should not be considered to be part of a cache key.
"""
dp_inspectable = symbol("IS")
"""Visit an inspectable object where the return value is a HasCacheKey`
object."""
dp_multi = symbol("M")
"""Visit an object that may be a :class:`.HasCacheKey` or may be a
plain hashable object."""
dp_multi_list = symbol("MT")
"""Visit a tuple containing elements that may be :class:`.HasCacheKey` or
may be a plain hashable object."""
dp_has_cache_key_tuples = symbol("HT")
"""Visit a list of tuples which contain :class:`.HasCacheKey`
objects.
"""
dp_has_cache_key_list = symbol("HL")
"""Visit a list of :class:`.HasCacheKey` objects."""
dp_inspectable_list = symbol("IL")
"""Visit a list of inspectable objects which upon inspection are
HasCacheKey objects."""
class ExternalTraversal(object):
"""Base class for visitor objects which can traverse externally using
the :func:`.visitors.traverse` function.
Direct usage of the :func:`.visitors.traverse` function is usually
@@ -178,7 +471,7 @@ class ClauseVisitor(object):
return self
class CloningVisitor(ClauseVisitor):
class CloningExternalTraversal(ExternalTraversal):
"""Base class for visitor objects which can traverse using
the :func:`.visitors.cloned_traverse` function.
@@ -203,7 +496,7 @@ class CloningVisitor(ClauseVisitor):
)
class ReplacingCloningVisitor(CloningVisitor):
class ReplacingExternalTraversal(CloningExternalTraversal):
"""Base class for visitor objects which can traverse using
the :func:`.visitors.replacement_traverse` function.
@@ -233,6 +526,14 @@ class ReplacingCloningVisitor(CloningVisitor):
return replacement_traverse(obj, self.__traverse_options__, replace)
# backwards compatibility
Visitable = Traversible
VisitableType = TraversibleType
ClauseVisitor = ExternalTraversal
CloningVisitor = CloningExternalTraversal
ReplacingCloningVisitor = ReplacingExternalTraversal
def iterate(obj, opts):
r"""traverse the given expression structure, returning an iterator.
@@ -405,11 +706,18 @@ def cloned_traverse(obj, opts, visitors):
cloned = {}
stop_on = set(opts.get("stop_on", []))
def clone(elem):
def clone(elem, **kw):
if elem in stop_on:
return elem
else:
if id(elem) not in cloned:
if "replace" in kw:
newelem = kw["replace"](elem)
if newelem is not None:
cloned[id(elem)] = newelem
return newelem
cloned[id(elem)] = newelem = elem._clone()
newelem._copy_internals(clone=clone)
meth = visitors.get(newelem.__visit_name__, None)
@@ -461,7 +769,14 @@ def replacement_traverse(obj, opts, replace):
stop_on.add(id(newelem))
return newelem
else:
if elem not in cloned:
if "replace" in kw:
newelem = kw["replace"](elem)
if newelem is not None:
cloned[elem] = newelem
return newelem
cloned[elem] = newelem = elem._clone()
newelem._copy_internals(clone=clone, **kw)
return cloned[elem]
+4 -4
View File
@@ -934,7 +934,7 @@ class BranchedOptionTest(fixtures.MappedTest):
configure_mappers()
def test_generate_cache_key_unbound_branching(self):
def test_generate_path_cache_key_unbound_branching(self):
A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G")
base = joinedload(A.bs)
@@ -950,11 +950,11 @@ class BranchedOptionTest(fixtures.MappedTest):
@profiling.function_call_count()
def go():
for opt in opts:
opt._generate_cache_key(cache_path)
opt._generate_path_cache_key(cache_path)
go()
def test_generate_cache_key_bound_branching(self):
def test_generate_path_cache_key_bound_branching(self):
A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G")
base = Load(A).joinedload(A.bs)
@@ -970,7 +970,7 @@ class BranchedOptionTest(fixtures.MappedTest):
@profiling.function_call_count()
def go():
for opt in opts:
opt._generate_cache_key(cache_path)
opt._generate_path_cache_key(cache_path)
go()
+1 -1
View File
@@ -1533,7 +1533,7 @@ class CustomIntegrationTest(testing.AssertsCompiledSQL, BakedTest):
if query._current_path:
query._cache_key = "user7_addresses"
def _generate_cache_key(self, path):
def _generate_path_cache_key(self, path):
return None
return RelationshipCache()
+120
View File
@@ -0,0 +1,120 @@
from sqlalchemy import inspect
from sqlalchemy.orm import aliased
from sqlalchemy.orm import defaultload
from sqlalchemy.orm import defer
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Load
from sqlalchemy.orm import subqueryload
from sqlalchemy.testing import eq_
from test.orm import _fixtures
from ..sql.test_compare import CacheKeyFixture
class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
run_setup_mappers = "once"
run_inserts = None
run_deletes = None
@classmethod
def setup_mappers(cls):
cls._setup_stock_mapping()
def test_mapper_and_aliased(self):
User, Address, Keyword = self.classes("User", "Address", "Keyword")
self._run_cache_key_fixture(
lambda: (inspect(User), inspect(Address), inspect(aliased(User)))
)
def test_attributes(self):
User, Address, Keyword = self.classes("User", "Address", "Keyword")
self._run_cache_key_fixture(
lambda: (
User.id,
Address.id,
aliased(User).id,
aliased(User, name="foo").id,
aliased(User, name="bar").id,
User.name,
User.addresses,
Address.email_address,
aliased(User).addresses,
)
)
def test_unbound_options(self):
User, Address, Keyword, Order, Item = self.classes(
"User", "Address", "Keyword", "Order", "Item"
)
self._run_cache_key_fixture(
lambda: (
joinedload(User.addresses),
joinedload("addresses"),
joinedload(User.orders).selectinload("items"),
joinedload(User.orders).selectinload(Order.items),
defer(User.id),
defer("id"),
defer(Address.id),
joinedload(User.addresses).defer(Address.id),
joinedload(aliased(User).addresses).defer(Address.id),
joinedload(User.addresses).defer("id"),
joinedload(User.orders).joinedload(Order.items),
joinedload(User.orders).subqueryload(Order.items),
subqueryload(User.orders).subqueryload(Order.items),
subqueryload(User.orders)
.subqueryload(Order.items)
.defer(Item.description),
defaultload(User.orders).defaultload(Order.items),
defaultload(User.orders),
)
)
def test_bound_options(self):
User, Address, Keyword, Order, Item = self.classes(
"User", "Address", "Keyword", "Order", "Item"
)
self._run_cache_key_fixture(
lambda: (
Load(User).joinedload(User.addresses),
Load(User).joinedload(User.orders),
Load(User).defer(User.id),
Load(User).subqueryload("addresses"),
Load(Address).defer("id"),
Load(aliased(Address)).defer("id"),
Load(User).joinedload(User.addresses).defer(Address.id),
Load(User).joinedload(User.orders).joinedload(Order.items),
Load(User).joinedload(User.orders).subqueryload(Order.items),
Load(User).subqueryload(User.orders).subqueryload(Order.items),
Load(User)
.subqueryload(User.orders)
.subqueryload(Order.items)
.defer(Item.description),
Load(User).defaultload(User.orders).defaultload(Order.items),
Load(User).defaultload(User.orders),
)
)
def test_bound_options_equiv_on_strname(self):
"""Bound loader options resolve on string name so test that the cache
key for the string version matches the resolved version.
"""
User, Address, Keyword, Order, Item = self.classes(
"User", "Address", "Keyword", "Order", "Item"
)
for left, right in [
(Load(User).defer(User.id), Load(User).defer("id")),
(
Load(User).joinedload(User.addresses),
Load(User).joinedload("addresses"),
),
(
Load(User).joinedload(User.orders).joinedload(Order.items),
Load(User).joinedload("orders").joinedload("items"),
),
]:
eq_(left._generate_cache_key(), right._generate_cache_key())
+46 -46
View File
@@ -1790,7 +1790,7 @@ class SubOptionsTest(PathTest, QueryTest):
)
class CacheKeyTest(PathTest, QueryTest):
class PathedCacheKeyTest(PathTest, QueryTest):
run_create_tables = False
run_inserts = None
@@ -1805,7 +1805,7 @@ class CacheKeyTest(PathTest, QueryTest):
opt = joinedload(User.orders).joinedload(Order.items)
eq_(
opt._generate_cache_key(query_path),
opt._generate_path_cache_key(query_path),
(((Order, "items", Item, ("lazy", "joined")),)),
)
@@ -1821,12 +1821,12 @@ class CacheKeyTest(PathTest, QueryTest):
opt2 = base.joinedload(Order.address)
eq_(
opt1._generate_cache_key(query_path),
opt1._generate_path_cache_key(query_path),
(((Order, "items", Item, ("lazy", "joined")),)),
)
eq_(
opt2._generate_cache_key(query_path),
opt2._generate_path_cache_key(query_path),
(((Order, "address", Address, ("lazy", "joined")),)),
)
@@ -1842,12 +1842,12 @@ class CacheKeyTest(PathTest, QueryTest):
opt2 = base.joinedload(Order.address)
eq_(
opt1._generate_cache_key(query_path),
opt1._generate_path_cache_key(query_path),
(((Order, "items", Item, ("lazy", "joined")),)),
)
eq_(
opt2._generate_cache_key(query_path),
opt2._generate_path_cache_key(query_path),
(((Order, "address", Address, ("lazy", "joined")),)),
)
@@ -1860,7 +1860,7 @@ class CacheKeyTest(PathTest, QueryTest):
opt = Load(User).joinedload(User.orders).joinedload(Order.items)
eq_(
opt._generate_cache_key(query_path),
opt._generate_path_cache_key(query_path),
(((Order, "items", Item, ("lazy", "joined")),)),
)
@@ -1872,7 +1872,7 @@ class CacheKeyTest(PathTest, QueryTest):
query_path = self._make_path_registry([User, "addresses"])
opt = joinedload(User.orders).joinedload(Order.items)
eq_(opt._generate_cache_key(query_path), None)
eq_(opt._generate_path_cache_key(query_path), None)
def test_bound_cache_key_excluded_on_other(self):
User, Address, Order, Item, SubItem = self.classes(
@@ -1882,7 +1882,7 @@ class CacheKeyTest(PathTest, QueryTest):
query_path = self._make_path_registry([User, "addresses"])
opt = Load(User).joinedload(User.orders).joinedload(Order.items)
eq_(opt._generate_cache_key(query_path), None)
eq_(opt._generate_path_cache_key(query_path), None)
def test_unbound_cache_key_excluded_on_aliased(self):
User, Address, Order, Item, SubItem = self.classes(
@@ -1901,7 +1901,7 @@ class CacheKeyTest(PathTest, QueryTest):
query_path = self._make_path_registry([User, "orders"])
opt = joinedload(aliased(User).orders).joinedload(Order.items)
eq_(opt._generate_cache_key(query_path), None)
eq_(opt._generate_path_cache_key(query_path), None)
def test_bound_cache_key_wildcard_one(self):
# do not change this test, it is testing
@@ -1911,7 +1911,7 @@ class CacheKeyTest(PathTest, QueryTest):
query_path = self._make_path_registry([User, "addresses"])
opt = Load(User).lazyload("*")
eq_(opt._generate_cache_key(query_path), None)
eq_(opt._generate_path_cache_key(query_path), None)
def test_unbound_cache_key_wildcard_one(self):
User, Address = self.classes("User", "Address")
@@ -1920,7 +1920,7 @@ class CacheKeyTest(PathTest, QueryTest):
opt = lazyload("*")
eq_(
opt._generate_cache_key(query_path),
opt._generate_path_cache_key(query_path),
(("relationship:_sa_default", ("lazy", "select")),),
)
@@ -1933,7 +1933,7 @@ class CacheKeyTest(PathTest, QueryTest):
opt = Load(User).lazyload("orders").lazyload("*")
eq_(
opt._generate_cache_key(query_path),
opt._generate_path_cache_key(query_path),
(
("orders", Order, ("lazy", "select")),
("orders", Order, "relationship:*", ("lazy", "select")),
@@ -1949,7 +1949,7 @@ class CacheKeyTest(PathTest, QueryTest):
opt = lazyload("orders").lazyload("*")
eq_(
opt._generate_cache_key(query_path),
opt._generate_path_cache_key(query_path),
(
("orders", Order, ("lazy", "select")),
("orders", Order, "relationship:*", ("lazy", "select")),
@@ -1968,7 +1968,7 @@ class CacheKeyTest(PathTest, QueryTest):
)
eq_(
opt._generate_cache_key(query_path),
opt._generate_path_cache_key(query_path),
(
(SubItem, ("lazy", "subquery")),
("extra_keywords", Keyword, ("lazy", "subquery")),
@@ -1987,7 +1987,7 @@ class CacheKeyTest(PathTest, QueryTest):
)
eq_(
opt._generate_cache_key(query_path),
opt._generate_path_cache_key(query_path),
(
(SubItem, ("lazy", "subquery")),
("extra_keywords", Keyword, ("lazy", "subquery")),
@@ -2008,7 +2008,7 @@ class CacheKeyTest(PathTest, QueryTest):
)
eq_(
opt._generate_cache_key(query_path),
opt._generate_path_cache_key(query_path),
(
(SubItem, ("lazy", "subquery")),
("extra_keywords", Keyword, ("lazy", "subquery")),
@@ -2029,7 +2029,7 @@ class CacheKeyTest(PathTest, QueryTest):
)
eq_(
opt._generate_cache_key(query_path),
opt._generate_path_cache_key(query_path),
(
(SubItem, ("lazy", "subquery")),
("extra_keywords", Keyword, ("lazy", "subquery")),
@@ -2056,7 +2056,7 @@ class CacheKeyTest(PathTest, QueryTest):
opt = subqueryload(User.orders).subqueryload(
Order.items.of_type(SubItem)
)
eq_(opt._generate_cache_key(query_path), None)
eq_(opt._generate_path_cache_key(query_path), None)
def test_unbound_cache_key_excluded_of_type_unsafe(self):
User, Address, Order, Item, SubItem = self.classes(
@@ -2078,7 +2078,7 @@ class CacheKeyTest(PathTest, QueryTest):
opt = subqueryload(User.orders).subqueryload(
Order.items.of_type(aliased(SubItem))
)
eq_(opt._generate_cache_key(query_path), None)
eq_(opt._generate_path_cache_key(query_path), None)
def test_bound_cache_key_excluded_of_type_safe(self):
User, Address, Order, Item, SubItem = self.classes(
@@ -2102,7 +2102,7 @@ class CacheKeyTest(PathTest, QueryTest):
.subqueryload(User.orders)
.subqueryload(Order.items.of_type(SubItem))
)
eq_(opt._generate_cache_key(query_path), None)
eq_(opt._generate_path_cache_key(query_path), None)
def test_bound_cache_key_excluded_of_type_unsafe(self):
User, Address, Order, Item, SubItem = self.classes(
@@ -2126,7 +2126,7 @@ class CacheKeyTest(PathTest, QueryTest):
.subqueryload(User.orders)
.subqueryload(Order.items.of_type(aliased(SubItem)))
)
eq_(opt._generate_cache_key(query_path), None)
eq_(opt._generate_path_cache_key(query_path), None)
def test_unbound_cache_key_included_of_type_safe(self):
User, Address, Order, Item, SubItem = self.classes(
@@ -2137,7 +2137,7 @@ class CacheKeyTest(PathTest, QueryTest):
opt = joinedload(User.orders).joinedload(Order.items.of_type(SubItem))
eq_(
opt._generate_cache_key(query_path),
opt._generate_path_cache_key(query_path),
((Order, "items", SubItem, ("lazy", "joined")),),
)
@@ -2155,7 +2155,7 @@ class CacheKeyTest(PathTest, QueryTest):
)
eq_(
opt._generate_cache_key(query_path),
opt._generate_path_cache_key(query_path),
((Order, "items", SubItem, ("lazy", "joined")),),
)
@@ -2169,7 +2169,7 @@ class CacheKeyTest(PathTest, QueryTest):
opt = joinedload(User.orders).joinedload(
Order.items.of_type(aliased(SubItem))
)
eq_(opt._generate_cache_key(query_path), False)
eq_(opt._generate_path_cache_key(query_path), False)
def test_unbound_cache_key_included_unsafe_option_two(self):
User, Address, Order, Item, SubItem = self.classes(
@@ -2181,7 +2181,7 @@ class CacheKeyTest(PathTest, QueryTest):
opt = joinedload(User.orders).joinedload(
Order.items.of_type(aliased(SubItem))
)
eq_(opt._generate_cache_key(query_path), False)
eq_(opt._generate_path_cache_key(query_path), False)
def test_unbound_cache_key_included_unsafe_option_three(self):
User, Address, Order, Item, SubItem = self.classes(
@@ -2193,7 +2193,7 @@ class CacheKeyTest(PathTest, QueryTest):
opt = joinedload(User.orders).joinedload(
Order.items.of_type(aliased(SubItem))
)
eq_(opt._generate_cache_key(query_path), False)
eq_(opt._generate_path_cache_key(query_path), False)
def test_unbound_cache_key_included_unsafe_query(self):
User, Address, Order, Item, SubItem = self.classes(
@@ -2204,7 +2204,7 @@ class CacheKeyTest(PathTest, QueryTest):
query_path = self._make_path_registry([inspect(au), "orders"])
opt = joinedload(au.orders).joinedload(Order.items)
eq_(opt._generate_cache_key(query_path), False)
eq_(opt._generate_path_cache_key(query_path), False)
def test_unbound_cache_key_included_safe_w_deferred(self):
User, Address, Order, Item, SubItem = self.classes(
@@ -2219,7 +2219,7 @@ class CacheKeyTest(PathTest, QueryTest):
.defer(Address.user_id)
)
eq_(
opt._generate_cache_key(query_path),
opt._generate_path_cache_key(query_path),
(
(
Address,
@@ -2247,12 +2247,12 @@ class CacheKeyTest(PathTest, QueryTest):
)
eq_(
opt1._generate_cache_key(query_path),
opt1._generate_path_cache_key(query_path),
((Order, "items", Item, ("lazy", "joined")),),
)
eq_(
opt2._generate_cache_key(query_path),
opt2._generate_path_cache_key(query_path),
(
(Order, "address", Address, ("lazy", "joined")),
(
@@ -2288,7 +2288,7 @@ class CacheKeyTest(PathTest, QueryTest):
.defer(Address.user_id)
)
eq_(
opt._generate_cache_key(query_path),
opt._generate_path_cache_key(query_path),
(
(
Address,
@@ -2316,12 +2316,12 @@ class CacheKeyTest(PathTest, QueryTest):
)
eq_(
opt1._generate_cache_key(query_path),
opt1._generate_path_cache_key(query_path),
((Order, "items", Item, ("lazy", "joined")),),
)
eq_(
opt2._generate_cache_key(query_path),
opt2._generate_path_cache_key(query_path),
(
(Order, "address", Address, ("lazy", "joined")),
(
@@ -2356,7 +2356,7 @@ class CacheKeyTest(PathTest, QueryTest):
query_path = self._make_path_registry([User, "orders"])
eq_(
opt._generate_cache_key(query_path),
opt._generate_path_cache_key(query_path),
(
(
Order,
@@ -2385,7 +2385,7 @@ class CacheKeyTest(PathTest, QueryTest):
au = aliased(User)
opt = Load(au).joinedload(au.orders).joinedload(Order.items)
eq_(opt._generate_cache_key(query_path), None)
eq_(opt._generate_path_cache_key(query_path), None)
def test_bound_cache_key_included_unsafe_option_one(self):
User, Address, Order, Item, SubItem = self.classes(
@@ -2399,7 +2399,7 @@ class CacheKeyTest(PathTest, QueryTest):
.joinedload(User.orders)
.joinedload(Order.items.of_type(aliased(SubItem)))
)
eq_(opt._generate_cache_key(query_path), False)
eq_(opt._generate_path_cache_key(query_path), False)
def test_bound_cache_key_included_unsafe_option_two(self):
User, Address, Order, Item, SubItem = self.classes(
@@ -2413,7 +2413,7 @@ class CacheKeyTest(PathTest, QueryTest):
.joinedload(User.orders)
.joinedload(Order.items.of_type(aliased(SubItem)))
)
eq_(opt._generate_cache_key(query_path), False)
eq_(opt._generate_path_cache_key(query_path), False)
def test_bound_cache_key_included_unsafe_option_three(self):
User, Address, Order, Item, SubItem = self.classes(
@@ -2427,7 +2427,7 @@ class CacheKeyTest(PathTest, QueryTest):
.joinedload(User.orders)
.joinedload(Order.items.of_type(aliased(SubItem)))
)
eq_(opt._generate_cache_key(query_path), False)
eq_(opt._generate_path_cache_key(query_path), False)
def test_bound_cache_key_included_unsafe_query(self):
User, Address, Order, Item, SubItem = self.classes(
@@ -2438,7 +2438,7 @@ class CacheKeyTest(PathTest, QueryTest):
query_path = self._make_path_registry([inspect(au), "orders"])
opt = Load(au).joinedload(au.orders).joinedload(Order.items)
eq_(opt._generate_cache_key(query_path), False)
eq_(opt._generate_path_cache_key(query_path), False)
def test_bound_cache_key_included_safe_w_option(self):
User, Address, Order, Item, SubItem = self.classes(
@@ -2454,7 +2454,7 @@ class CacheKeyTest(PathTest, QueryTest):
query_path = self._make_path_registry([User, "orders"])
eq_(
opt._generate_cache_key(query_path),
opt._generate_path_cache_key(query_path),
(
(
Order,
@@ -2483,7 +2483,7 @@ class CacheKeyTest(PathTest, QueryTest):
opt = defaultload(User.addresses).load_only("id", "email_address")
eq_(
opt._generate_cache_key(query_path),
opt._generate_path_cache_key(query_path),
(
(Address, "id", ("deferred", False), ("instrument", True)),
(
@@ -2513,7 +2513,7 @@ class CacheKeyTest(PathTest, QueryTest):
Address.id, Address.email_address
)
eq_(
opt._generate_cache_key(query_path),
opt._generate_path_cache_key(query_path),
(
(Address, "id", ("deferred", False), ("instrument", True)),
(
@@ -2545,7 +2545,7 @@ class CacheKeyTest(PathTest, QueryTest):
.load_only("id", "email_address")
)
eq_(
opt._generate_cache_key(query_path),
opt._generate_path_cache_key(query_path),
(
(Address, "id", ("deferred", False), ("instrument", True)),
(
@@ -2572,7 +2572,7 @@ class CacheKeyTest(PathTest, QueryTest):
opt = defaultload(User.addresses).undefer_group("xyz")
eq_(
opt._generate_cache_key(query_path),
opt._generate_path_cache_key(query_path),
((Address, "column:*", ("undefer_group_xyz", True)),),
)
@@ -2584,6 +2584,6 @@ class CacheKeyTest(PathTest, QueryTest):
opt = Load(User).defaultload(User.addresses).undefer_group("xyz")
eq_(
opt._generate_cache_key(query_path),
opt._generate_path_cache_key(query_path),
((Address, "column:*", ("undefer_group_xyz", True)),),
)
+396 -119
View File
@@ -32,6 +32,7 @@ from sqlalchemy.sql import operators
from sqlalchemy.sql import True_
from sqlalchemy.sql import type_coerce
from sqlalchemy.sql import visitors
from sqlalchemy.sql.base import HasCacheKey
from sqlalchemy.sql.elements import _label_reference
from sqlalchemy.sql.elements import _textual_label_reference
from sqlalchemy.sql.elements import Annotated
@@ -46,13 +47,13 @@ from sqlalchemy.sql.functions import FunctionElement
from sqlalchemy.sql.functions import GenericFunction
from sqlalchemy.sql.functions import ReturnTypeFromArgs
from sqlalchemy.sql.selectable import _OffsetLimitParam
from sqlalchemy.sql.selectable import AliasedReturnsRows
from sqlalchemy.sql.selectable import FromGrouping
from sqlalchemy.sql.selectable import Selectable
from sqlalchemy.sql.selectable import SelectStatementGrouping
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.sql.visitors import InternalTraversal
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing import is_false
from sqlalchemy.testing import is_true
from sqlalchemy.testing import ne_
@@ -63,8 +64,17 @@ meta = MetaData()
meta2 = MetaData()
table_a = Table("a", meta, Column("a", Integer), Column("b", String))
table_b_like_a = Table("b2", meta, Column("a", Integer), Column("b", String))
table_a_2 = Table("a", meta2, Column("a", Integer), Column("b", String))
table_a_2_fs = Table(
"a", meta2, Column("a", Integer), Column("b", String), schema="fs"
)
table_a_2_bs = Table(
"a", meta2, Column("a", Integer), Column("b", String), schema="bs"
)
table_b = Table("b", meta, Column("a", Integer), Column("b", Integer))
table_c = Table("c", meta, Column("x", Integer), Column("y", Integer))
@@ -72,8 +82,18 @@ table_c = Table("c", meta, Column("x", Integer), Column("y", Integer))
table_d = Table("d", meta, Column("y", Integer), Column("z", Integer))
class CompareAndCopyTest(fixtures.TestBase):
class MyEntity(HasCacheKey):
def __init__(self, name, element):
self.name = name
self.element = element
_cache_key_traversal = [
("name", InternalTraversal.dp_string),
("element", InternalTraversal.dp_clauseelement),
]
class CoreFixtures(object):
# lambdas which return a tuple of ColumnElement objects.
# must return at least two objects that should compare differently.
# to test more varieties of "difference" additional objects can be added.
@@ -100,11 +120,47 @@ class CompareAndCopyTest(fixtures.TestBase):
text("select a, b, c from table").columns(
a=Integer, b=String, c=Integer
),
text("select a, b, c from table where foo=:bar").bindparams(
bindparam("bar", Integer)
),
text("select a, b, c from table where foo=:foo").bindparams(
bindparam("foo", Integer)
),
text("select a, b, c from table where foo=:bar").bindparams(
bindparam("bar", String)
),
),
lambda: (
column("q") == column("x"),
column("q") == column("y"),
column("z") == column("x"),
column("z") + column("x"),
column("z") - column("x"),
column("x") - column("z"),
column("z") > column("x"),
# note these two are mathematically equivalent but for now they
# are considered to be different
column("z") >= column("x"),
column("x") <= column("z"),
column("q").between(5, 6),
column("q").between(5, 6, symmetric=True),
column("q").like("somstr"),
column("q").like("somstr", escape="\\"),
column("q").like("somstr", escape="X"),
),
lambda: (
table_a.c.a,
table_a.c.a._annotate({"orm": True}),
table_a.c.a._annotate({"orm": True})._annotate({"bar": False}),
table_a.c.a._annotate(
{"orm": True, "parententity": MyEntity("a", table_a)}
),
table_a.c.a._annotate(
{"orm": True, "parententity": MyEntity("b", table_a)}
),
table_a.c.a._annotate(
{"orm": True, "parententity": MyEntity("b", select([table_a]))}
),
),
lambda: (
cast(column("q"), Integer),
@@ -225,6 +281,58 @@ class CompareAndCopyTest(fixtures.TestBase):
.where(table_a.c.b == 5)
.correlate_except(table_b),
),
lambda: (
select([table_a.c.a]).cte(),
select([table_a.c.a]).cte(recursive=True),
select([table_a.c.a]).cte(name="some_cte", recursive=True),
select([table_a.c.a]).cte(name="some_cte"),
select([table_a.c.a]).cte(name="some_cte").alias("other_cte"),
select([table_a.c.a])
.cte(name="some_cte")
.union_all(select([table_a.c.a])),
select([table_a.c.a])
.cte(name="some_cte")
.union_all(select([table_a.c.b])),
select([table_a.c.a]).lateral(),
select([table_a.c.a]).lateral(name="bar"),
table_a.tablesample(func.bernoulli(1)),
table_a.tablesample(func.bernoulli(1), seed=func.random()),
table_a.tablesample(func.bernoulli(1), seed=func.other_random()),
table_a.tablesample(func.hoho(1)),
table_a.tablesample(func.bernoulli(1), name="bar"),
table_a.tablesample(
func.bernoulli(1), name="bar", seed=func.random()
),
),
lambda: (
select([table_a.c.a]),
select([table_a.c.a]).prefix_with("foo"),
select([table_a.c.a]).prefix_with("foo", dialect="mysql"),
select([table_a.c.a]).prefix_with("foo", dialect="postgresql"),
select([table_a.c.a]).prefix_with("bar"),
select([table_a.c.a]).suffix_with("bar"),
),
lambda: (
select([table_a_2.c.a]),
select([table_a_2_fs.c.a]),
select([table_a_2_bs.c.a]),
),
lambda: (
select([table_a.c.a]),
select([table_a.c.a]).with_hint(None, "some hint"),
select([table_a.c.a]).with_hint(None, "some other hint"),
select([table_a.c.a]).with_hint(table_a, "some hint"),
select([table_a.c.a])
.with_hint(table_a, "some hint")
.with_hint(None, "some other hint"),
select([table_a.c.a]).with_hint(table_a, "some other hint"),
select([table_a.c.a]).with_hint(
table_a, "some hint", dialect_name="mysql"
),
select([table_a.c.a]).with_hint(
table_a, "some hint", dialect_name="postgresql"
),
),
lambda: (
table_a.join(table_b, table_a.c.a == table_b.c.a),
table_a.join(
@@ -273,12 +381,202 @@ class CompareAndCopyTest(fixtures.TestBase):
table("a", column("x"), column("y", Integer)),
table("a", column("q"), column("y", Integer)),
),
lambda: (
Table("a", MetaData(), Column("q", Integer), Column("b", String)),
Table("b", MetaData(), Column("q", Integer), Column("b", String)),
),
lambda: (table_a, table_b),
]
def _complex_fixtures():
def one():
a1 = table_a.alias()
a2 = table_b_like_a.alias()
stmt = (
select([table_a.c.a, a1.c.b, a2.c.b])
.where(table_a.c.b == a1.c.b)
.where(a1.c.b == a2.c.b)
.where(a1.c.a == 5)
)
return stmt
def one_diff():
a1 = table_b_like_a.alias()
a2 = table_a.alias()
stmt = (
select([table_a.c.a, a1.c.b, a2.c.b])
.where(table_a.c.b == a1.c.b)
.where(a1.c.b == a2.c.b)
.where(a1.c.a == 5)
)
return stmt
def two():
inner = one().subquery()
stmt = select([table_b.c.a, inner.c.a, inner.c.b]).select_from(
table_b.join(inner, table_b.c.b == inner.c.b)
)
return stmt
def three():
a1 = table_a.alias()
a2 = table_a.alias()
ex = exists().where(table_b.c.b == a1.c.a)
stmt = (
select([a1.c.a, a2.c.a])
.select_from(a1.join(a2, a1.c.b == a2.c.b))
.where(ex)
)
return stmt
return [one(), one_diff(), two(), three()]
fixtures.append(_complex_fixtures)
class CacheKeyFixture(object):
def _run_cache_key_fixture(self, fixture):
case_a = fixture()
case_b = fixture()
for a, b in itertools.combinations_with_replacement(
range(len(case_a)), 2
):
if a == b:
a_key = case_a[a]._generate_cache_key()
b_key = case_b[b]._generate_cache_key()
eq_(a_key.key, b_key.key)
for a_param, b_param in zip(
a_key.bindparams, b_key.bindparams
):
assert a_param.compare(b_param, compare_values=False)
else:
a_key = case_a[a]._generate_cache_key()
b_key = case_b[b]._generate_cache_key()
if a_key.key == b_key.key:
for a_param, b_param in zip(
a_key.bindparams, b_key.bindparams
):
if not a_param.compare(b_param, compare_values=True):
break
else:
# this fails unconditionally since we could not
# find bound parameter values that differed.
# Usually we intended to get two distinct keys here
# so the failure will be more descriptive using the
# ne_() assertion.
ne_(a_key.key, b_key.key)
else:
ne_(a_key.key, b_key.key)
# ClauseElement-specific test to ensure the cache key
# collected all the bound parameters
if isinstance(case_a[a], ClauseElement) and isinstance(
case_b[b], ClauseElement
):
assert_a_params = []
assert_b_params = []
visitors.traverse_depthfirst(
case_a[a], {}, {"bindparam": assert_a_params.append}
)
visitors.traverse_depthfirst(
case_b[b], {}, {"bindparam": assert_b_params.append}
)
# note we're asserting the order of the params as well as
# if there are dupes or not. ordering has to be deterministic
# and matches what a traversal would provide.
# regular traverse_depthfirst does produce dupes in cases like
# select([some_alias]).
# select_from(join(some_alias, other_table))
# where a bound parameter is inside of some_alias. the
# cache key case is more minimalistic
eq_(
sorted(a_key.bindparams, key=lambda b: b.key),
sorted(
util.unique_list(assert_a_params), key=lambda b: b.key
),
)
eq_(
sorted(b_key.bindparams, key=lambda b: b.key),
sorted(
util.unique_list(assert_b_params), key=lambda b: b.key
),
)
class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase):
def test_cache_key(self):
for fixture in self.fixtures:
self._run_cache_key_fixture(fixture)
def test_cache_key_unknown_traverse(self):
class Foobar1(ClauseElement):
_traverse_internals = [
("key", InternalTraversal.dp_anon_name),
("type_", InternalTraversal.dp_unknown_structure),
]
def __init__(self, key, type_):
self.key = key
self.type_ = type_
f1 = Foobar1("foo", String())
eq_(f1._generate_cache_key(), None)
def test_cache_key_no_method(self):
class Foobar1(ClauseElement):
pass
class Foobar2(ColumnElement):
pass
# the None for cache key will prevent objects
# which contain these elements from being cached.
f1 = Foobar1()
eq_(f1._generate_cache_key(), None)
f2 = Foobar2()
eq_(f2._generate_cache_key(), None)
s1 = select([column("q"), Foobar2()])
eq_(s1._generate_cache_key(), None)
def test_get_children_no_method(self):
class Foobar1(ClauseElement):
pass
class Foobar2(ColumnElement):
pass
f1 = Foobar1()
eq_(f1.get_children(), [])
f2 = Foobar2()
eq_(f2.get_children(), [])
def test_copy_internals_no_method(self):
class Foobar1(ClauseElement):
pass
class Foobar2(ColumnElement):
pass
f1 = Foobar1()
f2 = Foobar2()
f1._copy_internals()
f2._copy_internals()
class CompareAndCopyTest(CoreFixtures, fixtures.TestBase):
@classmethod
def setup_class(cls):
# TODO: we need to get dialects here somehow, perhaps in test_suite?
@@ -293,7 +591,10 @@ class CompareAndCopyTest(fixtures.TestBase):
cls
for cls in class_hierarchy(ClauseElement)
if issubclass(cls, (ColumnElement, Selectable))
and "__init__" in cls.__dict__
and (
"__init__" in cls.__dict__
or issubclass(cls, AliasedReturnsRows)
)
and not issubclass(cls, (Annotated))
and "orm" not in cls.__module__
and "compiler" not in cls.__module__
@@ -318,123 +619,16 @@ class CompareAndCopyTest(fixtures.TestBase):
):
if a == b:
is_true(
case_a[a].compare(
case_b[b], arbitrary_expression=True
),
case_a[a].compare(case_b[b], compare_annotations=True),
"%r != %r" % (case_a[a], case_b[b]),
)
else:
is_false(
case_a[a].compare(
case_b[b], arbitrary_expression=True
),
case_a[a].compare(case_b[b], compare_annotations=True),
"%r == %r" % (case_a[a], case_b[b]),
)
def test_cache_key(self):
def assert_params_append(assert_params):
def append(param):
if param._value_required_for_cache:
assert_params.append(param)
else:
is_(param.value, None)
return append
for fixture in self.fixtures:
case_a = fixture()
case_b = fixture()
for a, b in itertools.combinations_with_replacement(
range(len(case_a)), 2
):
assert_a_params = []
assert_b_params = []
visitors.traverse_depthfirst(
case_a[a],
{},
{"bindparam": assert_params_append(assert_a_params)},
)
visitors.traverse_depthfirst(
case_b[b],
{},
{"bindparam": assert_params_append(assert_b_params)},
)
if assert_a_params:
assert_raises_message(
NotImplementedError,
"bindparams collection argument required ",
case_a[a]._cache_key,
)
if assert_b_params:
assert_raises_message(
NotImplementedError,
"bindparams collection argument required ",
case_b[b]._cache_key,
)
if not assert_a_params and not assert_b_params:
if a == b:
eq_(case_a[a]._cache_key(), case_b[b]._cache_key())
else:
ne_(case_a[a]._cache_key(), case_b[b]._cache_key())
def test_cache_key_gather_bindparams(self):
for fixture in self.fixtures:
case_a = fixture()
case_b = fixture()
# in the "bindparams" case, the cache keys for bound parameters
# with only different values will be the same, but the params
# themselves are gathered into a collection.
for a, b in itertools.combinations_with_replacement(
range(len(case_a)), 2
):
a_params = {"bindparams": []}
b_params = {"bindparams": []}
if a == b:
a_key = case_a[a]._cache_key(**a_params)
b_key = case_b[b]._cache_key(**b_params)
eq_(a_key, b_key)
if a_params["bindparams"]:
for a_param, b_param in zip(
a_params["bindparams"], b_params["bindparams"]
):
assert a_param.compare(b_param)
else:
a_key = case_a[a]._cache_key(**a_params)
b_key = case_b[b]._cache_key(**b_params)
if a_key == b_key:
for a_param, b_param in zip(
a_params["bindparams"], b_params["bindparams"]
):
if not a_param.compare(b_param):
break
else:
assert False, "Bound parameters are all the same"
else:
ne_(a_key, b_key)
assert_a_params = []
assert_b_params = []
visitors.traverse_depthfirst(
case_a[a], {}, {"bindparam": assert_a_params.append}
)
visitors.traverse_depthfirst(
case_b[b], {}, {"bindparam": assert_b_params.append}
)
# note we're asserting the order of the params as well as
# if there are dupes or not. ordering has to be deterministic
# and matches what a traversal would provide.
eq_(a_params["bindparams"], assert_a_params)
eq_(b_params["bindparams"], assert_b_params)
def test_compare_col_identity(self):
stmt1 = (
select([table_a.c.a, table_b.c.b])
@@ -473,8 +667,9 @@ class CompareAndCopyTest(fixtures.TestBase):
assert case_a[0].compare(case_b[0])
clone = case_a[0]._clone()
clone._copy_internals()
clone = visitors.replacement_traverse(
case_a[0], {}, lambda elem: None
)
assert clone.compare(case_b[0])
@@ -511,6 +706,37 @@ class CompareAndCopyTest(fixtures.TestBase):
class CompareClausesTest(fixtures.TestBase):
def test_compare_metadata_tables(self):
# metadata Table objects cache on their own identity, not their
# structure. This is mainly to reduce the size of cache keys
# as well as reduce computational overhead, as Table objects have
# very large internal state and they are also generally global
# objects.
t1 = Table("a", MetaData(), Column("q", Integer), Column("p", Integer))
t2 = Table("a", MetaData(), Column("q", Integer), Column("p", Integer))
ne_(t1._generate_cache_key(), t2._generate_cache_key())
eq_(t1._generate_cache_key().key, (t1,))
def test_compare_adhoc_tables(self):
# non-metadata tables compare on their structure. these objects are
# not commonly used.
# note this test is a bit redundant as we have a similar test
# via the fixtures also
t1 = table("a", Column("q", Integer), Column("p", Integer))
t2 = table("a", Column("q", Integer), Column("p", Integer))
t3 = table("b", Column("q", Integer), Column("p", Integer))
t4 = table("a", Column("q", Integer), Column("x", Integer))
eq_(t1._generate_cache_key(), t2._generate_cache_key())
ne_(t1._generate_cache_key(), t3._generate_cache_key())
ne_(t1._generate_cache_key(), t4._generate_cache_key())
ne_(t3._generate_cache_key(), t4._generate_cache_key())
def test_compare_comparison_associative(self):
l1 = table_c.c.x == table_d.c.y
@@ -521,6 +747,15 @@ class CompareClausesTest(fixtures.TestBase):
is_true(l1.compare(l2))
is_false(l1.compare(l3))
def test_compare_comparison_non_commutative_inverses(self):
l1 = table_c.c.x >= table_d.c.y
l2 = table_d.c.y < table_c.c.x
l3 = table_d.c.y <= table_c.c.x
# we're not doing this kind of commutativity right now.
is_false(l1.compare(l2))
is_false(l1.compare(l3))
def test_compare_clauselist_associative(self):
l1 = and_(table_c.c.x == table_d.c.y, table_c.c.y == table_d.c.z)
@@ -624,3 +859,45 @@ class CompareClausesTest(fixtures.TestBase):
use_proxies=True,
)
)
def test_compare_annotated_clears_mapping(self):
t = table("t", column("x"), column("y"))
x_a = t.c.x._annotate({"foo": True})
x_b = t.c.x._annotate({"foo": True})
is_true(x_a.compare(x_b, compare_annotations=True))
is_false(
x_a.compare(x_b._annotate({"bar": True}), compare_annotations=True)
)
s1 = select([t.c.x])._annotate({"foo": True})
s2 = select([t.c.x])._annotate({"foo": True})
is_true(s1.compare(s2, compare_annotations=True))
is_false(
s1.compare(s2._annotate({"bar": True}), compare_annotations=True)
)
def test_compare_annotated_wo_annotations(self):
t = table("t", column("x"), column("y"))
x_a = t.c.x._annotate({})
x_b = t.c.x._annotate({"foo": True})
is_true(t.c.x.compare(x_a))
is_true(x_b.compare(x_a))
is_true(x_a.compare(t.c.x))
is_false(x_a.compare(t.c.y))
is_false(t.c.y.compare(x_a))
is_true((t.c.x == 5).compare(x_a == 5))
is_false((t.c.y == 5).compare(x_a == 5))
s = select([t]).subquery()
x_p = s.c.x
is_false(x_a.compare(x_p))
is_false(t.c.x.compare(x_p))
x_p_a = x_p._annotate({})
is_true(x_p_a.compare(x_p))
is_true(x_p.compare(x_p_a))
is_false(x_p_a.compare(x_a))
@@ -55,6 +55,7 @@ class TraversalTest(fixtures.TestBase, AssertsExecutionResults):
# identity semantics.
class A(ClauseElement):
__visit_name__ = "a"
_traverse_internals = []
def __init__(self, expr):
self.expr = expr
+3
View File
@@ -118,11 +118,14 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
)
)
modifiers = operator(left, right).modifiers
assert operator(left, right).compare(
BinaryExpression(
coercions.expect(roles.WhereHavingRole, left),
coercions.expect(roles.WhereHavingRole, right),
operator,
modifiers=modifiers,
)
)
+26 -32
View File
@@ -1070,7 +1070,7 @@ class SelectableTest(
s4 = s3.with_only_columns([table2.c.b])
self.assert_compile(s4, "SELECT t2.b FROM t2")
def test_from_list_warning_against_existing(self):
def test_from_list_against_existing_one(self):
c1 = Column("c1", Integer)
s = select([c1])
@@ -1081,7 +1081,7 @@ class SelectableTest(
self.assert_compile(s, "SELECT t.c1 FROM t")
def test_from_list_recovers_after_warning(self):
def test_from_list_against_existing_two(self):
c1 = Column("c1", Integer)
c2 = Column("c2", Integer)
@@ -1090,18 +1090,11 @@ class SelectableTest(
# force a compile.
eq_(str(s), "SELECT c1")
@testing.emits_warning()
def go():
return Table("t", MetaData(), c1, c2)
t = go()
t = Table("t", MetaData(), c1, c2)
eq_(c1._from_objects, [t])
eq_(c2._from_objects, [t])
# 's' has been baked. Can't afford
# not caching select._froms.
# hopefully the warning will clue the user
self.assert_compile(s, "SELECT t.c1 FROM t")
self.assert_compile(select([c1]), "SELECT t.c1 FROM t")
self.assert_compile(select([c2]), "SELECT t.c2 FROM t")
@@ -1124,6 +1117,26 @@ class SelectableTest(
"foo",
)
def test_whereclause_adapted(self):
table1 = table("t1", column("a"))
s1 = select([table1]).subquery()
s2 = select([s1]).where(s1.c.a == 5)
assert s2._whereclause.left.table is s1
ta = select([table1]).subquery()
s3 = sql_util.ClauseAdapter(ta).traverse(s2)
assert s1 not in s3._froms
# these are new assumptions with the newer approach that
# actively swaps out whereclause and others
assert s3._whereclause.left.table is not s1
assert s3._whereclause.left.table in s3._froms
class RefreshForNewColTest(fixtures.TestBase):
def test_join_uninit(self):
@@ -2241,25 +2254,6 @@ class AnnotationsTest(fixtures.TestBase):
annot = obj._annotate({})
ne_(set([obj]), set([annot]))
def test_compare(self):
t = table("t", column("x"), column("y"))
x_a = t.c.x._annotate({})
assert t.c.x.compare(x_a)
assert x_a.compare(t.c.x)
assert not x_a.compare(t.c.y)
assert not t.c.y.compare(x_a)
assert (t.c.x == 5).compare(x_a == 5)
assert not (t.c.y == 5).compare(x_a == 5)
s = select([t]).subquery()
x_p = s.c.x
assert not x_a.compare(x_p)
assert not t.c.x.compare(x_p)
x_p_a = x_p._annotate({})
assert x_p_a.compare(x_p)
assert x_p.compare(x_p_a)
assert not x_p_a.compare(x_a)
def test_proxy_set_iteration_includes_annotated(self):
from sqlalchemy.schema import Column
@@ -2542,13 +2536,13 @@ class AnnotationsTest(fixtures.TestBase):
):
# the columns clause isn't changed at all
assert sel._raw_columns[0].table is a1
assert sel._froms[0] is sel._froms[1].left
assert sel._froms[0].element is sel._froms[1].left.element
eq_(str(s), str(sel))
# when we are modifying annotations sets only
# partially, each element is copied unconditionally
# when encountered.
# partially, elements are copied uniquely based on id().
# this is new as of 1.4, previously they'd be copied every time
for sel in (
sql_util._deep_deannotate(s, {"foo": "bar"}),
sql_util._deep_annotate(s, {"foo": "bar"}),
+1 -1
View File
@@ -7,6 +7,6 @@ from sqlalchemy.testing import fixtures
class MiscTest(fixtures.TestBase):
def test_column_element_no_visit(self):
class MyElement(ColumnElement):
pass
_traverse_internals = []
eq_(sql_util.find_tables(MyElement(), check_columns=True), [])