mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-13 04:07:20 -04:00
Merge "ORM bulk insert via execute" into main
This commit is contained in:
Vendored
+11
@@ -660,6 +660,17 @@ Selecting a Synchronization Strategy
|
||||
With both the 1.x and 2.0 form of ORM-enabled updates and deletes, the following
|
||||
values for ``synchronize_session`` are supported:
|
||||
|
||||
* ``'auto'`` - this is the default. The ``'fetch'`` strategy will be used on
|
||||
backends that support RETURNING, which includes all SQLAlchemy-native drivers
|
||||
except for MySQL. If RETURNING is not supported, the ``'evaluate'``
|
||||
strategy will be used instead.
|
||||
|
||||
.. versionchanged:: 2.0 Added the ``'auto'`` synchronization strategy. As
|
||||
most backends now support RETURNING, selecting ``'fetch'`` for these
|
||||
backends specifically is the more efficient and error-free default for
|
||||
these backends. The MySQL backend as well as third party backends without
|
||||
RETURNING support will continue to use ``'evaluate'`` by default.
|
||||
|
||||
* ``False`` - don't synchronize the session. This option is the most
|
||||
efficient and is reliable once the session is expired, which
|
||||
typically occurs after a commit(), or explicitly using
|
||||
|
||||
@@ -301,7 +301,7 @@ Fast Executemany Mode
|
||||
The SQL Server ``fast_executemany`` parameter may be used at the same time
|
||||
as ``insertmanyvalues`` is enabled; however, the parameter will not be used
|
||||
in as many cases as INSERT statements that are invoked using Core
|
||||
:class:`.Insert` constructs as well as all ORM use no longer use the
|
||||
:class:`_dml.Insert` constructs as well as all ORM use no longer use the
|
||||
``.executemany()`` DBAPI cursor method.
|
||||
|
||||
The PyODBC driver includes support for a "fast executemany" mode of execution
|
||||
|
||||
@@ -134,9 +134,11 @@ def _upsert(cfg, table, returning, set_lambda=None):
|
||||
|
||||
stmt = insert(table)
|
||||
|
||||
table_pk = inspect(table).selectable
|
||||
|
||||
if set_lambda:
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=table.primary_key, set_=set_lambda(stmt.excluded)
|
||||
index_elements=table_pk.primary_key, set_=set_lambda(stmt.excluded)
|
||||
)
|
||||
else:
|
||||
stmt = stmt.on_conflict_do_nothing()
|
||||
|
||||
@@ -1466,11 +1466,6 @@ class SQLiteCompiler(compiler.SQLCompiler):
|
||||
|
||||
return target_text
|
||||
|
||||
def visit_insert(self, insert_stmt, **kw):
|
||||
if insert_stmt._post_values_clause is not None:
|
||||
kw["disable_implicit_returning"] = True
|
||||
return super().visit_insert(insert_stmt, **kw)
|
||||
|
||||
def visit_on_conflict_do_nothing(self, on_conflict, **kw):
|
||||
|
||||
target_text = self._on_conflict_target(on_conflict, **kw)
|
||||
|
||||
+336
-54
@@ -23,12 +23,14 @@ from typing import Iterator
|
||||
from typing import List
|
||||
from typing import NoReturn
|
||||
from typing import Optional
|
||||
from typing import overload
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
from typing import Union
|
||||
|
||||
from .result import IteratorResult
|
||||
from .result import MergedResult
|
||||
from .result import Result
|
||||
from .result import ResultMetaData
|
||||
@@ -62,36 +64,80 @@ if typing.TYPE_CHECKING:
|
||||
from .interfaces import ExecutionContext
|
||||
from .result import _KeyIndexType
|
||||
from .result import _KeyMapRecType
|
||||
from .result import _KeyMapType
|
||||
from .result import _KeyType
|
||||
from .result import _ProcessorsType
|
||||
from .result import _TupleGetterType
|
||||
from ..sql.type_api import _ResultProcessorType
|
||||
|
||||
|
||||
_T = TypeVar("_T", bound=Any)
|
||||
|
||||
|
||||
# metadata entry tuple indexes.
|
||||
# using raw tuple is faster than namedtuple.
|
||||
MD_INDEX: Literal[0] = 0 # integer index in cursor.description
|
||||
MD_RESULT_MAP_INDEX: Literal[
|
||||
1
|
||||
] = 1 # integer index in compiled._result_columns
|
||||
MD_OBJECTS: Literal[
|
||||
2
|
||||
] = 2 # other string keys and ColumnElement obj that can match
|
||||
MD_LOOKUP_KEY: Literal[
|
||||
3
|
||||
] = 3 # string key we usually expect for key-based lookup
|
||||
MD_RENDERED_NAME: Literal[4] = 4 # name that is usually in cursor.description
|
||||
MD_PROCESSOR: Literal[5] = 5 # callable to process a result value into a row
|
||||
MD_UNTRANSLATED: Literal[6] = 6 # raw name from cursor.description
|
||||
# these match up to the positions in
|
||||
# _CursorKeyMapRecType
|
||||
MD_INDEX: Literal[0] = 0
|
||||
"""integer index in cursor.description
|
||||
|
||||
"""
|
||||
|
||||
MD_RESULT_MAP_INDEX: Literal[1] = 1
|
||||
"""integer index in compiled._result_columns"""
|
||||
|
||||
MD_OBJECTS: Literal[2] = 2
|
||||
"""other string keys and ColumnElement obj that can match.
|
||||
|
||||
This comes from compiler.RM_OBJECTS / compiler.ResultColumnsEntry.objects
|
||||
|
||||
"""
|
||||
|
||||
MD_LOOKUP_KEY: Literal[3] = 3
|
||||
"""string key we usually expect for key-based lookup
|
||||
|
||||
this comes from compiler.RM_NAME / compiler.ResultColumnsEntry.name
|
||||
"""
|
||||
|
||||
|
||||
MD_RENDERED_NAME: Literal[4] = 4
|
||||
"""name that is usually in cursor.description
|
||||
|
||||
this comes from compiler.RENDERED_NAME / compiler.ResultColumnsEntry.keyname
|
||||
"""
|
||||
|
||||
|
||||
MD_PROCESSOR: Literal[5] = 5
|
||||
"""callable to process a result value into a row"""
|
||||
|
||||
MD_UNTRANSLATED: Literal[6] = 6
|
||||
"""raw name from cursor.description"""
|
||||
|
||||
|
||||
_CursorKeyMapRecType = Tuple[
|
||||
int, int, List[Any], str, str, Optional["_ResultProcessorType"], str
|
||||
Optional[int], # MD_INDEX, None means the record is ambiguously named
|
||||
int, # MD_RESULT_MAP_INDEX
|
||||
List[Any], # MD_OBJECTS
|
||||
str, # MD_LOOKUP_KEY
|
||||
str, # MD_RENDERED_NAME
|
||||
Optional["_ResultProcessorType"], # MD_PROCESSOR
|
||||
Optional[str], # MD_UNTRANSLATED
|
||||
]
|
||||
|
||||
_CursorKeyMapType = Dict["_KeyType", _CursorKeyMapRecType]
|
||||
|
||||
# same as _CursorKeyMapRecType except the MD_INDEX value is definitely
|
||||
# not None
|
||||
_NonAmbigCursorKeyMapRecType = Tuple[
|
||||
int,
|
||||
int,
|
||||
List[Any],
|
||||
str,
|
||||
str,
|
||||
Optional["_ResultProcessorType"],
|
||||
str,
|
||||
]
|
||||
|
||||
|
||||
class CursorResultMetaData(ResultMetaData):
|
||||
"""Result metadata for DBAPI cursors."""
|
||||
@@ -127,38 +173,112 @@ class CursorResultMetaData(ResultMetaData):
|
||||
extra=[self._keymap[key][MD_OBJECTS] for key in self._keys],
|
||||
)
|
||||
|
||||
def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData:
|
||||
recs = cast(
|
||||
"List[_CursorKeyMapRecType]", list(self._metadata_for_keys(keys))
|
||||
def _make_new_metadata(
|
||||
self,
|
||||
*,
|
||||
unpickled: bool,
|
||||
processors: _ProcessorsType,
|
||||
keys: Sequence[str],
|
||||
keymap: _KeyMapType,
|
||||
tuplefilter: Optional[_TupleGetterType],
|
||||
translated_indexes: Optional[List[int]],
|
||||
safe_for_cache: bool,
|
||||
keymap_by_result_column_idx: Any,
|
||||
) -> CursorResultMetaData:
|
||||
new_obj = self.__class__.__new__(self.__class__)
|
||||
new_obj._unpickled = unpickled
|
||||
new_obj._processors = processors
|
||||
new_obj._keys = keys
|
||||
new_obj._keymap = keymap
|
||||
new_obj._tuplefilter = tuplefilter
|
||||
new_obj._translated_indexes = translated_indexes
|
||||
new_obj._safe_for_cache = safe_for_cache
|
||||
new_obj._keymap_by_result_column_idx = keymap_by_result_column_idx
|
||||
return new_obj
|
||||
|
||||
def _remove_processors(self) -> CursorResultMetaData:
|
||||
assert not self._tuplefilter
|
||||
return self._make_new_metadata(
|
||||
unpickled=self._unpickled,
|
||||
processors=[None] * len(self._processors),
|
||||
tuplefilter=None,
|
||||
translated_indexes=None,
|
||||
keymap={
|
||||
key: value[0:5] + (None,) + value[6:]
|
||||
for key, value in self._keymap.items()
|
||||
},
|
||||
keys=self._keys,
|
||||
safe_for_cache=self._safe_for_cache,
|
||||
keymap_by_result_column_idx=self._keymap_by_result_column_idx,
|
||||
)
|
||||
|
||||
def _splice_horizontally(
|
||||
self, other: CursorResultMetaData
|
||||
) -> CursorResultMetaData:
|
||||
|
||||
assert not self._tuplefilter
|
||||
|
||||
keymap = self._keymap.copy()
|
||||
offset = len(self._keys)
|
||||
keymap.update(
|
||||
{
|
||||
key: (
|
||||
# int index should be None for ambiguous key
|
||||
value[0] + offset
|
||||
if value[0] is not None and key not in keymap
|
||||
else None,
|
||||
value[1] + offset,
|
||||
*value[2:],
|
||||
)
|
||||
for key, value in other._keymap.items()
|
||||
}
|
||||
)
|
||||
|
||||
return self._make_new_metadata(
|
||||
unpickled=self._unpickled,
|
||||
processors=self._processors + other._processors, # type: ignore
|
||||
tuplefilter=None,
|
||||
translated_indexes=None,
|
||||
keys=self._keys + other._keys, # type: ignore
|
||||
keymap=keymap,
|
||||
safe_for_cache=self._safe_for_cache,
|
||||
keymap_by_result_column_idx={
|
||||
metadata_entry[MD_RESULT_MAP_INDEX]: metadata_entry
|
||||
for metadata_entry in keymap.values()
|
||||
},
|
||||
)
|
||||
|
||||
def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData:
|
||||
recs = list(self._metadata_for_keys(keys))
|
||||
|
||||
indexes = [rec[MD_INDEX] for rec in recs]
|
||||
new_keys: List[str] = [rec[MD_LOOKUP_KEY] for rec in recs]
|
||||
|
||||
if self._translated_indexes:
|
||||
indexes = [self._translated_indexes[idx] for idx in indexes]
|
||||
tup = tuplegetter(*indexes)
|
||||
|
||||
new_metadata = self.__class__.__new__(self.__class__)
|
||||
new_metadata._unpickled = self._unpickled
|
||||
new_metadata._processors = self._processors
|
||||
new_metadata._keys = new_keys
|
||||
new_metadata._tuplefilter = tup
|
||||
new_metadata._translated_indexes = indexes
|
||||
|
||||
new_recs = [(index,) + rec[1:] for index, rec in enumerate(recs)]
|
||||
new_metadata._keymap = {rec[MD_LOOKUP_KEY]: rec for rec in new_recs}
|
||||
|
||||
keymap: _KeyMapType = {rec[MD_LOOKUP_KEY]: rec for rec in new_recs}
|
||||
# TODO: need unit test for:
|
||||
# result = connection.execute("raw sql, no columns").scalars()
|
||||
# without the "or ()" it's failing because MD_OBJECTS is None
|
||||
new_metadata._keymap.update(
|
||||
keymap.update(
|
||||
(e, new_rec)
|
||||
for new_rec in new_recs
|
||||
for e in new_rec[MD_OBJECTS] or ()
|
||||
)
|
||||
|
||||
return new_metadata
|
||||
return self._make_new_metadata(
|
||||
unpickled=self._unpickled,
|
||||
processors=self._processors,
|
||||
keys=new_keys,
|
||||
tuplefilter=tup,
|
||||
translated_indexes=indexes,
|
||||
keymap=keymap,
|
||||
safe_for_cache=self._safe_for_cache,
|
||||
keymap_by_result_column_idx=self._keymap_by_result_column_idx,
|
||||
)
|
||||
|
||||
def _adapt_to_context(self, context: ExecutionContext) -> ResultMetaData:
|
||||
"""When using a cached Compiled construct that has a _result_map,
|
||||
@@ -168,6 +288,7 @@ class CursorResultMetaData(ResultMetaData):
|
||||
as matched to those of the cached statement.
|
||||
|
||||
"""
|
||||
|
||||
if not context.compiled or not context.compiled._result_columns:
|
||||
return self
|
||||
|
||||
@@ -189,7 +310,6 @@ class CursorResultMetaData(ResultMetaData):
|
||||
|
||||
# make a copy and add the columns from the invoked statement
|
||||
# to the result map.
|
||||
md = self.__class__.__new__(self.__class__)
|
||||
|
||||
keymap_by_position = self._keymap_by_result_column_idx
|
||||
|
||||
@@ -201,26 +321,26 @@ class CursorResultMetaData(ResultMetaData):
|
||||
for metadata_entry in self._keymap.values()
|
||||
}
|
||||
|
||||
md._keymap = compat.dict_union(
|
||||
self._keymap,
|
||||
{
|
||||
new: keymap_by_position[idx]
|
||||
for idx, new in enumerate(
|
||||
invoked_statement._all_selected_columns
|
||||
)
|
||||
if idx in keymap_by_position
|
||||
},
|
||||
)
|
||||
|
||||
md._unpickled = self._unpickled
|
||||
md._processors = self._processors
|
||||
assert not self._tuplefilter
|
||||
md._tuplefilter = None
|
||||
md._translated_indexes = None
|
||||
md._keys = self._keys
|
||||
md._keymap_by_result_column_idx = self._keymap_by_result_column_idx
|
||||
md._safe_for_cache = self._safe_for_cache
|
||||
return md
|
||||
return self._make_new_metadata(
|
||||
keymap=compat.dict_union(
|
||||
self._keymap,
|
||||
{
|
||||
new: keymap_by_position[idx]
|
||||
for idx, new in enumerate(
|
||||
invoked_statement._all_selected_columns
|
||||
)
|
||||
if idx in keymap_by_position
|
||||
},
|
||||
),
|
||||
unpickled=self._unpickled,
|
||||
processors=self._processors,
|
||||
tuplefilter=None,
|
||||
translated_indexes=None,
|
||||
keys=self._keys,
|
||||
safe_for_cache=self._safe_for_cache,
|
||||
keymap_by_result_column_idx=self._keymap_by_result_column_idx,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -683,7 +803,27 @@ class CursorResultMetaData(ResultMetaData):
|
||||
untranslated,
|
||||
)
|
||||
|
||||
def _key_fallback(self, key, err, raiseerr=True):
|
||||
@overload
|
||||
def _key_fallback(
|
||||
self, key: Any, err: Exception, raiseerr: Literal[True] = ...
|
||||
) -> NoReturn:
|
||||
...
|
||||
|
||||
@overload
|
||||
def _key_fallback(
|
||||
self, key: Any, err: Exception, raiseerr: Literal[False] = ...
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@overload
|
||||
def _key_fallback(
|
||||
self, key: Any, err: Exception, raiseerr: bool = ...
|
||||
) -> Optional[NoReturn]:
|
||||
...
|
||||
|
||||
def _key_fallback(
|
||||
self, key: Any, err: Exception, raiseerr: bool = True
|
||||
) -> Optional[NoReturn]:
|
||||
|
||||
if raiseerr:
|
||||
if self._unpickled and isinstance(key, elements.ColumnElement):
|
||||
@@ -714,9 +854,9 @@ class CursorResultMetaData(ResultMetaData):
|
||||
try:
|
||||
rec = self._keymap[key]
|
||||
except KeyError as ke:
|
||||
rec = self._key_fallback(key, ke, raiseerr)
|
||||
if rec is None:
|
||||
return None
|
||||
x = self._key_fallback(key, ke, raiseerr)
|
||||
assert x is None
|
||||
return None
|
||||
|
||||
index = rec[0]
|
||||
|
||||
@@ -734,7 +874,7 @@ class CursorResultMetaData(ResultMetaData):
|
||||
|
||||
def _metadata_for_keys(
|
||||
self, keys: Sequence[Any]
|
||||
) -> Iterator[_CursorKeyMapRecType]:
|
||||
) -> Iterator[_NonAmbigCursorKeyMapRecType]:
|
||||
for key in keys:
|
||||
if int in key.__class__.__mro__:
|
||||
key = self._keys[key]
|
||||
@@ -750,7 +890,7 @@ class CursorResultMetaData(ResultMetaData):
|
||||
if index is None:
|
||||
self._raise_for_ambiguous_column_name(rec)
|
||||
|
||||
yield rec
|
||||
yield cast(_NonAmbigCursorKeyMapRecType, rec)
|
||||
|
||||
def __getstate__(self):
|
||||
return {
|
||||
@@ -1237,6 +1377,12 @@ _NO_RESULT_METADATA = _NoResultMetaData()
|
||||
SelfCursorResult = TypeVar("SelfCursorResult", bound="CursorResult[Any]")
|
||||
|
||||
|
||||
def null_dml_result() -> IteratorResult[Any]:
|
||||
it: IteratorResult[Any] = IteratorResult(_NoResultMetaData(), iter([]))
|
||||
it._soft_close()
|
||||
return it
|
||||
|
||||
|
||||
class CursorResult(Result[_T]):
|
||||
"""A Result that is representing state from a DBAPI cursor.
|
||||
|
||||
@@ -1586,6 +1732,142 @@ class CursorResult(Result[_T]):
|
||||
"""
|
||||
return self.context.returned_default_rows
|
||||
|
||||
def splice_horizontally(self, other):
|
||||
"""Return a new :class:`.CursorResult` that "horizontally splices"
|
||||
together the rows of this :class:`.CursorResult` with that of another
|
||||
:class:`.CursorResult`.
|
||||
|
||||
.. tip:: This method is for the benefit of the SQLAlchemy ORM and is
|
||||
not intended for general use.
|
||||
|
||||
"horizontally splices" means that for each row in the first and second
|
||||
result sets, a new row that concatenates the two rows together is
|
||||
produced, which then becomes the new row. The incoming
|
||||
:class:`.CursorResult` must have the identical number of rows. It is
|
||||
typically expected that the two result sets come from the same sort
|
||||
order as well, as the result rows are spliced together based on their
|
||||
position in the result.
|
||||
|
||||
The expected use case here is so that multiple INSERT..RETURNING
|
||||
statements against different tables can produce a single result
|
||||
that looks like a JOIN of those two tables.
|
||||
|
||||
E.g.::
|
||||
|
||||
r1 = connection.execute(
|
||||
users.insert().returning(users.c.user_name, users.c.user_id),
|
||||
user_values
|
||||
)
|
||||
|
||||
r2 = connection.execute(
|
||||
addresses.insert().returning(
|
||||
addresses.c.address_id,
|
||||
addresses.c.address,
|
||||
addresses.c.user_id,
|
||||
),
|
||||
address_values
|
||||
)
|
||||
|
||||
rows = r1.splice_horizontally(r2).all()
|
||||
assert (
|
||||
rows ==
|
||||
[
|
||||
("john", 1, 1, "foo@bar.com", 1),
|
||||
("jack", 2, 2, "bar@bat.com", 2),
|
||||
]
|
||||
)
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
.. seealso::
|
||||
|
||||
:meth:`.CursorResult.splice_vertically`
|
||||
|
||||
|
||||
"""
|
||||
|
||||
clone = self._generate()
|
||||
total_rows = [
|
||||
tuple(r1) + tuple(r2)
|
||||
for r1, r2 in zip(
|
||||
list(self._raw_row_iterator()),
|
||||
list(other._raw_row_iterator()),
|
||||
)
|
||||
]
|
||||
|
||||
clone._metadata = clone._metadata._splice_horizontally(other._metadata)
|
||||
|
||||
clone.cursor_strategy = FullyBufferedCursorFetchStrategy(
|
||||
None,
|
||||
initial_buffer=total_rows,
|
||||
)
|
||||
clone._reset_memoizations()
|
||||
return clone
|
||||
|
||||
def splice_vertically(self, other):
|
||||
"""Return a new :class:`.CursorResult` that "vertically splices",
|
||||
i.e. "extends", the rows of this :class:`.CursorResult` with that of
|
||||
another :class:`.CursorResult`.
|
||||
|
||||
.. tip:: This method is for the benefit of the SQLAlchemy ORM and is
|
||||
not intended for general use.
|
||||
|
||||
"vertically splices" means the rows of the given result are appended to
|
||||
the rows of this cursor result. The incoming :class:`.CursorResult`
|
||||
must have rows that represent the identical list of columns in the
|
||||
identical order as they are in this :class:`.CursorResult`.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`.CursorResult.splice_horizontally`
|
||||
|
||||
"""
|
||||
clone = self._generate()
|
||||
total_rows = list(self._raw_row_iterator()) + list(
|
||||
other._raw_row_iterator()
|
||||
)
|
||||
|
||||
clone.cursor_strategy = FullyBufferedCursorFetchStrategy(
|
||||
None,
|
||||
initial_buffer=total_rows,
|
||||
)
|
||||
clone._reset_memoizations()
|
||||
return clone
|
||||
|
||||
def _rewind(self, rows):
|
||||
"""rewind this result back to the given rowset.
|
||||
|
||||
this is used internally for the case where an :class:`.Insert`
|
||||
construct combines the use of
|
||||
:meth:`.Insert.return_defaults` along with the
|
||||
"supplemental columns" feature.
|
||||
|
||||
"""
|
||||
|
||||
if self._echo:
|
||||
self.context.connection._log_debug(
|
||||
"CursorResult rewound %d row(s)", len(rows)
|
||||
)
|
||||
|
||||
# the rows given are expected to be Row objects, so we
|
||||
# have to clear out processors which have already run on these
|
||||
# rows
|
||||
self._metadata = cast(
|
||||
CursorResultMetaData, self._metadata
|
||||
)._remove_processors()
|
||||
|
||||
self.cursor_strategy = FullyBufferedCursorFetchStrategy(
|
||||
None,
|
||||
# TODO: if these are Row objects, can we save on not having to
|
||||
# re-make new Row objects out of them a second time? is that
|
||||
# what's actually happening right now? maybe look into this
|
||||
initial_buffer=rows,
|
||||
)
|
||||
self._reset_memoizations()
|
||||
return self
|
||||
|
||||
@property
|
||||
def returned_defaults(self):
|
||||
"""Return the values of default columns that were fetched using
|
||||
|
||||
@@ -1007,6 +1007,7 @@ class DefaultExecutionContext(ExecutionContext):
|
||||
|
||||
_is_implicit_returning = False
|
||||
_is_explicit_returning = False
|
||||
_is_supplemental_returning = False
|
||||
_is_server_side = False
|
||||
|
||||
_soft_closed = False
|
||||
@@ -1125,18 +1126,19 @@ class DefaultExecutionContext(ExecutionContext):
|
||||
self.is_text = compiled.isplaintext
|
||||
|
||||
if ii or iu or id_:
|
||||
dml_statement = compiled.compile_state.statement # type: ignore
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(compiled.statement, UpdateBase)
|
||||
assert isinstance(dml_statement, UpdateBase)
|
||||
self.is_crud = True
|
||||
self._is_explicit_returning = ier = bool(
|
||||
compiled.statement._returning
|
||||
)
|
||||
self._is_implicit_returning = iir = is_implicit_returning = bool(
|
||||
self._is_explicit_returning = ier = bool(dml_statement._returning)
|
||||
self._is_implicit_returning = iir = bool(
|
||||
compiled.implicit_returning
|
||||
)
|
||||
assert not (
|
||||
is_implicit_returning and compiled.statement._returning
|
||||
)
|
||||
if iir and dml_statement._supplemental_returning:
|
||||
self._is_supplemental_returning = True
|
||||
|
||||
# dont mix implicit and explicit returning
|
||||
assert not (iir and ier)
|
||||
|
||||
if (ier or iir) and compiled.for_executemany:
|
||||
if ii and not self.dialect.insert_executemany_returning:
|
||||
@@ -1711,7 +1713,14 @@ class DefaultExecutionContext(ExecutionContext):
|
||||
# are that the result has only one row, until executemany()
|
||||
# support is added here.
|
||||
assert result._metadata.returns_rows
|
||||
result._soft_close()
|
||||
|
||||
# Insert statement has both return_defaults() and
|
||||
# returning(). rewind the result on the list of rows
|
||||
# we just used.
|
||||
if self._is_supplemental_returning:
|
||||
result._rewind(rows)
|
||||
else:
|
||||
result._soft_close()
|
||||
elif not self._is_explicit_returning:
|
||||
result._soft_close()
|
||||
|
||||
@@ -1721,21 +1730,18 @@ class DefaultExecutionContext(ExecutionContext):
|
||||
# function so this is not necessarily true.
|
||||
# assert not result.returns_rows
|
||||
|
||||
elif self.isupdate and self._is_implicit_returning:
|
||||
# get rowcount
|
||||
# (which requires open cursor on some drivers)
|
||||
# we were not doing this in 1.4, however
|
||||
# test_rowcount -> test_update_rowcount_return_defaults
|
||||
# is testing this, and psycopg will no longer return
|
||||
# rowcount after cursor is closed.
|
||||
result.rowcount
|
||||
elif self._is_implicit_returning:
|
||||
rows = result.all()
|
||||
|
||||
if rows:
|
||||
self.returned_default_rows = rows
|
||||
result.rowcount = len(rows)
|
||||
self._has_rowcount = True
|
||||
|
||||
row = result.fetchone()
|
||||
if row is not None:
|
||||
self.returned_default_rows = [row]
|
||||
|
||||
result._soft_close()
|
||||
if self._is_supplemental_returning:
|
||||
result._rewind(rows)
|
||||
else:
|
||||
result._soft_close()
|
||||
|
||||
# test that it has a cursor metadata that is accurate.
|
||||
# the rows have all been fetched however.
|
||||
@@ -1750,7 +1756,6 @@ class DefaultExecutionContext(ExecutionContext):
|
||||
elif self.isupdate or self.isdelete:
|
||||
result.rowcount
|
||||
self._has_rowcount = True
|
||||
|
||||
return result
|
||||
|
||||
@util.memoized_property
|
||||
|
||||
@@ -109,9 +109,27 @@ class ResultMetaData:
|
||||
def _for_freeze(self) -> ResultMetaData:
|
||||
raise NotImplementedError()
|
||||
|
||||
@overload
|
||||
def _key_fallback(
|
||||
self, key: _KeyType, err: Exception, raiseerr: bool = True
|
||||
self, key: Any, err: Exception, raiseerr: Literal[True] = ...
|
||||
) -> NoReturn:
|
||||
...
|
||||
|
||||
@overload
|
||||
def _key_fallback(
|
||||
self, key: Any, err: Exception, raiseerr: Literal[False] = ...
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@overload
|
||||
def _key_fallback(
|
||||
self, key: Any, err: Exception, raiseerr: bool = ...
|
||||
) -> Optional[NoReturn]:
|
||||
...
|
||||
|
||||
def _key_fallback(
|
||||
self, key: Any, err: Exception, raiseerr: bool = True
|
||||
) -> Optional[NoReturn]:
|
||||
assert raiseerr
|
||||
raise KeyError(key) from err
|
||||
|
||||
@@ -2148,6 +2166,7 @@ class IteratorResult(Result[_TP]):
|
||||
"""
|
||||
|
||||
_hard_closed = False
|
||||
_soft_closed = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -2168,6 +2187,7 @@ class IteratorResult(Result[_TP]):
|
||||
self.raw._soft_close(hard=hard, **kw)
|
||||
self.iterator = iter([])
|
||||
self._reset_memoizations()
|
||||
self._soft_closed = True
|
||||
|
||||
def _raise_hard_closed(self) -> NoReturn:
|
||||
raise exc.ResourceClosedError("This result object is closed.")
|
||||
|
||||
+1147
-324
File diff suppressed because it is too large
Load Diff
+141
-32
@@ -73,6 +73,7 @@ if TYPE_CHECKING:
|
||||
from .query import Query
|
||||
from .session import _BindArguments
|
||||
from .session import Session
|
||||
from ..engine import Result
|
||||
from ..engine.interfaces import _CoreSingleExecuteParams
|
||||
from ..engine.interfaces import _ExecuteOptionsParameter
|
||||
from ..sql._typing import _ColumnsClauseArgument
|
||||
@@ -203,15 +204,19 @@ _orm_load_exec_options = util.immutabledict(
|
||||
|
||||
|
||||
class AbstractORMCompileState(CompileState):
|
||||
is_dml_returning = False
|
||||
|
||||
@classmethod
|
||||
def create_for_statement(
|
||||
cls,
|
||||
statement: Union[Select, FromStatement],
|
||||
compiler: Optional[SQLCompiler],
|
||||
**kw: Any,
|
||||
) -> ORMCompileState:
|
||||
) -> AbstractORMCompileState:
|
||||
"""Create a context for a statement given a :class:`.Compiler`.
|
||||
|
||||
This method is always invoked in the context of SQLCompiler.process().
|
||||
|
||||
For a Select object, this would be invoked from
|
||||
SQLCompiler.visit_select(). For the special FromStatement object used
|
||||
by Query to indicate "Query.from_statement()", this is called by
|
||||
@@ -232,6 +237,28 @@ class AbstractORMCompileState(CompileState):
|
||||
):
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def orm_execute_statement(
|
||||
cls,
|
||||
session,
|
||||
statement,
|
||||
params,
|
||||
execution_options,
|
||||
bind_arguments,
|
||||
conn,
|
||||
) -> Result:
|
||||
result = conn.execute(
|
||||
statement, params or {}, execution_options=execution_options
|
||||
)
|
||||
return cls.orm_setup_cursor_result(
|
||||
session,
|
||||
statement,
|
||||
params,
|
||||
execution_options,
|
||||
bind_arguments,
|
||||
result,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def orm_setup_cursor_result(
|
||||
cls,
|
||||
@@ -309,6 +336,17 @@ class ORMCompileState(AbstractORMCompileState):
|
||||
def __init__(self, *arg, **kw):
|
||||
raise NotImplementedError()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@classmethod
|
||||
def create_for_statement(
|
||||
cls,
|
||||
statement: Union[Select, FromStatement],
|
||||
compiler: Optional[SQLCompiler],
|
||||
**kw: Any,
|
||||
) -> ORMCompileState:
|
||||
...
|
||||
|
||||
def _append_dedupe_col_collection(self, obj, col_collection):
|
||||
dedupe = self.dedupe_columns
|
||||
if obj not in dedupe:
|
||||
@@ -332,26 +370,6 @@ class ORMCompileState(AbstractORMCompileState):
|
||||
else:
|
||||
return SelectState._column_naming_convention(label_style)
|
||||
|
||||
@classmethod
|
||||
def create_for_statement(
|
||||
cls,
|
||||
statement: Union[Select, FromStatement],
|
||||
compiler: Optional[SQLCompiler],
|
||||
**kw: Any,
|
||||
) -> ORMCompileState:
|
||||
"""Create a context for a statement given a :class:`.Compiler`.
|
||||
|
||||
This method is always invoked in the context of SQLCompiler.process().
|
||||
|
||||
For a Select object, this would be invoked from
|
||||
SQLCompiler.visit_select(). For the special FromStatement object used
|
||||
by Query to indicate "Query.from_statement()", this is called by
|
||||
FromStatement._compiler_dispatch() that would be called by
|
||||
SQLCompiler.process().
|
||||
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def get_column_descriptions(cls, statement):
|
||||
return _column_descriptions(statement)
|
||||
@@ -518,6 +536,49 @@ class ORMCompileState(AbstractORMCompileState):
|
||||
)
|
||||
|
||||
|
||||
class DMLReturningColFilter:
|
||||
"""an adapter used for the DML RETURNING case.
|
||||
|
||||
Has a subset of the interface used by
|
||||
:class:`.ORMAdapter` and is used for :class:`._QueryEntity`
|
||||
instances to set up their columns as used in RETURNING for a
|
||||
DML statement.
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ("mapper", "columns", "__weakref__")
|
||||
|
||||
def __init__(self, target_mapper, immediate_dml_mapper):
|
||||
if (
|
||||
immediate_dml_mapper is not None
|
||||
and target_mapper.local_table
|
||||
is not immediate_dml_mapper.local_table
|
||||
):
|
||||
# joined inh, or in theory other kinds of multi-table mappings
|
||||
self.mapper = immediate_dml_mapper
|
||||
else:
|
||||
# single inh, normal mappings, etc.
|
||||
self.mapper = target_mapper
|
||||
self.columns = self.columns = util.WeakPopulateDict(
|
||||
self.adapt_check_present # type: ignore
|
||||
)
|
||||
|
||||
def __call__(self, col, as_filter):
|
||||
for cc in sql_util._find_columns(col):
|
||||
c2 = self.adapt_check_present(cc)
|
||||
if c2 is not None:
|
||||
return col
|
||||
else:
|
||||
return None
|
||||
|
||||
def adapt_check_present(self, col):
|
||||
mapper = self.mapper
|
||||
prop = mapper._columntoproperty.get(col, None)
|
||||
if prop is None:
|
||||
return None
|
||||
return mapper.local_table.c.corresponding_column(col)
|
||||
|
||||
|
||||
@sql.base.CompileState.plugin_for("orm", "orm_from_statement")
|
||||
class ORMFromStatementCompileState(ORMCompileState):
|
||||
_from_obj_alias = None
|
||||
@@ -525,7 +586,7 @@ class ORMFromStatementCompileState(ORMCompileState):
|
||||
|
||||
statement_container: FromStatement
|
||||
requested_statement: Union[SelectBase, TextClause, UpdateBase]
|
||||
dml_table: _DMLTableElement
|
||||
dml_table: Optional[_DMLTableElement] = None
|
||||
|
||||
_has_orm_entities = False
|
||||
multi_row_eager_loaders = False
|
||||
@@ -541,7 +602,7 @@ class ORMFromStatementCompileState(ORMCompileState):
|
||||
statement_container: Union[Select, FromStatement],
|
||||
compiler: Optional[SQLCompiler],
|
||||
**kw: Any,
|
||||
) -> ORMCompileState:
|
||||
) -> ORMFromStatementCompileState:
|
||||
|
||||
if compiler is not None:
|
||||
toplevel = not compiler.stack
|
||||
@@ -565,6 +626,7 @@ class ORMFromStatementCompileState(ORMCompileState):
|
||||
|
||||
if statement.is_dml:
|
||||
self.dml_table = statement.table
|
||||
self.is_dml_returning = True
|
||||
|
||||
self._entities = []
|
||||
self._polymorphic_adapters = {}
|
||||
@@ -674,6 +736,18 @@ class ORMFromStatementCompileState(ORMCompileState):
|
||||
def _get_current_adapter(self):
|
||||
return None
|
||||
|
||||
def setup_dml_returning_compile_state(self, dml_mapper):
|
||||
"""used by BulkORMInsert (and Update / Delete?) to set up a handler
|
||||
for RETURNING to return ORM objects and expressions
|
||||
|
||||
"""
|
||||
target_mapper = self.statement._propagate_attrs.get(
|
||||
"plugin_subject", None
|
||||
)
|
||||
adapter = DMLReturningColFilter(target_mapper, dml_mapper)
|
||||
for entity in self._entities:
|
||||
entity.setup_dml_returning_compile_state(self, adapter)
|
||||
|
||||
|
||||
class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]):
|
||||
"""Core construct that represents a load of ORM objects from various
|
||||
@@ -813,7 +887,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
|
||||
statement: Union[Select, FromStatement],
|
||||
compiler: Optional[SQLCompiler],
|
||||
**kw: Any,
|
||||
) -> ORMCompileState:
|
||||
) -> ORMSelectCompileState:
|
||||
"""compiler hook, we arrive here from compiler.visit_select() only."""
|
||||
|
||||
self = cls.__new__(cls)
|
||||
@@ -2312,6 +2386,13 @@ class _QueryEntity:
|
||||
def setup_compile_state(self, compile_state: ORMCompileState) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
def setup_dml_returning_compile_state(
|
||||
self,
|
||||
compile_state: ORMCompileState,
|
||||
adapter: DMLReturningColFilter,
|
||||
) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
def row_processor(self, context, result):
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -2509,8 +2590,24 @@ class _MapperEntity(_QueryEntity):
|
||||
|
||||
return _instance, self._label_name, self._extra_entities
|
||||
|
||||
def setup_compile_state(self, compile_state):
|
||||
def setup_dml_returning_compile_state(
|
||||
self,
|
||||
compile_state: ORMCompileState,
|
||||
adapter: DMLReturningColFilter,
|
||||
) -> None:
|
||||
loading._setup_entity_query(
|
||||
compile_state,
|
||||
self.mapper,
|
||||
self,
|
||||
self.path,
|
||||
adapter,
|
||||
compile_state.primary_columns,
|
||||
with_polymorphic=self._with_polymorphic_mappers,
|
||||
only_load_props=compile_state.compile_options._only_load_props,
|
||||
polymorphic_discriminator=self._polymorphic_discriminator,
|
||||
)
|
||||
|
||||
def setup_compile_state(self, compile_state):
|
||||
adapter = self._get_entity_clauses(compile_state)
|
||||
|
||||
single_table_crit = self.mapper._single_table_criterion
|
||||
@@ -2536,7 +2633,6 @@ class _MapperEntity(_QueryEntity):
|
||||
only_load_props=compile_state.compile_options._only_load_props,
|
||||
polymorphic_discriminator=self._polymorphic_discriminator,
|
||||
)
|
||||
|
||||
compile_state._fallback_from_clauses.append(self.selectable)
|
||||
|
||||
|
||||
@@ -2743,9 +2839,7 @@ class _ColumnEntity(_QueryEntity):
|
||||
getter, label_name, extra_entities = self._row_processor
|
||||
if self.translate_raw_column:
|
||||
extra_entities += (
|
||||
result.context.invoked_statement._raw_columns[
|
||||
self.raw_column_index
|
||||
],
|
||||
context.query._raw_columns[self.raw_column_index],
|
||||
)
|
||||
|
||||
return getter, label_name, extra_entities
|
||||
@@ -2781,9 +2875,7 @@ class _ColumnEntity(_QueryEntity):
|
||||
|
||||
if self.translate_raw_column:
|
||||
extra_entities = self._extra_entities + (
|
||||
result.context.invoked_statement._raw_columns[
|
||||
self.raw_column_index
|
||||
],
|
||||
context.query._raw_columns[self.raw_column_index],
|
||||
)
|
||||
return getter, self._label_name, extra_entities
|
||||
else:
|
||||
@@ -2843,6 +2935,8 @@ class _RawColumnEntity(_ColumnEntity):
|
||||
current_adapter = compile_state._get_current_adapter()
|
||||
if current_adapter:
|
||||
column = current_adapter(self.column, False)
|
||||
if column is None:
|
||||
return
|
||||
else:
|
||||
column = self.column
|
||||
|
||||
@@ -2944,10 +3038,25 @@ class _ORMColumnEntity(_ColumnEntity):
|
||||
self.entity_zero
|
||||
) and entity.common_parent(self.entity_zero)
|
||||
|
||||
def setup_dml_returning_compile_state(
|
||||
self,
|
||||
compile_state: ORMCompileState,
|
||||
adapter: DMLReturningColFilter,
|
||||
) -> None:
|
||||
self._fetch_column = self.column
|
||||
column = adapter(self.column, False)
|
||||
if column is not None:
|
||||
compile_state.dedupe_columns.add(column)
|
||||
compile_state.primary_columns.append(column)
|
||||
|
||||
def setup_compile_state(self, compile_state):
|
||||
current_adapter = compile_state._get_current_adapter()
|
||||
if current_adapter:
|
||||
column = current_adapter(self.column, False)
|
||||
if column is None:
|
||||
assert compile_state.is_dml_returning
|
||||
self._fetch_column = self.column
|
||||
return
|
||||
else:
|
||||
column = self.column
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import operator
|
||||
import typing
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import NoReturn
|
||||
from typing import Optional
|
||||
@@ -602,6 +603,31 @@ class Composite(
|
||||
def _attribute_keys(self) -> Sequence[str]:
|
||||
return [prop.key for prop in self.props]
|
||||
|
||||
def _populate_composite_bulk_save_mappings_fn(
|
||||
self,
|
||||
) -> Callable[[Dict[str, Any]], None]:
|
||||
|
||||
if self._generated_composite_accessor:
|
||||
get_values = self._generated_composite_accessor
|
||||
else:
|
||||
|
||||
def get_values(val: Any) -> Tuple[Any]:
|
||||
return val.__composite_values__() # type: ignore
|
||||
|
||||
attrs = [prop.key for prop in self.props]
|
||||
|
||||
def populate(dest_dict: Dict[str, Any]) -> None:
|
||||
dest_dict.update(
|
||||
{
|
||||
key: val
|
||||
for key, val in zip(
|
||||
attrs, get_values(dest_dict.pop(self.key))
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
return populate
|
||||
|
||||
def get_history(
|
||||
self,
|
||||
state: InstanceState[Any],
|
||||
|
||||
@@ -9,8 +9,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
|
||||
from .base import LoaderCallableStatus
|
||||
from .base import PassiveFlag
|
||||
from .. import exc
|
||||
from .. import inspect
|
||||
from .. import util
|
||||
@@ -32,7 +32,16 @@ class _NoObject(operators.ColumnOperators):
|
||||
return None
|
||||
|
||||
|
||||
class _ExpiredObject(operators.ColumnOperators):
|
||||
def operate(self, *arg, **kw):
|
||||
return self
|
||||
|
||||
def reverse_operate(self, *arg, **kw):
|
||||
return self
|
||||
|
||||
|
||||
_NO_OBJECT = _NoObject()
|
||||
_EXPIRED_OBJECT = _ExpiredObject()
|
||||
|
||||
|
||||
class EvaluatorCompiler:
|
||||
@@ -73,6 +82,24 @@ class EvaluatorCompiler:
|
||||
f"alternate class {parentmapper.class_}"
|
||||
)
|
||||
key = parentmapper._columntoproperty[clause].key
|
||||
impl = parentmapper.class_manager[key].impl
|
||||
|
||||
if impl is not None:
|
||||
|
||||
def get_corresponding_attr(obj):
|
||||
if obj is None:
|
||||
return _NO_OBJECT
|
||||
state = inspect(obj)
|
||||
dict_ = state.dict
|
||||
|
||||
value = impl.get(
|
||||
state, dict_, passive=PassiveFlag.PASSIVE_NO_FETCH
|
||||
)
|
||||
if value is LoaderCallableStatus.PASSIVE_NO_RESULT:
|
||||
return _EXPIRED_OBJECT
|
||||
return value
|
||||
|
||||
return get_corresponding_attr
|
||||
else:
|
||||
key = clause.key
|
||||
if (
|
||||
@@ -85,15 +112,16 @@ class EvaluatorCompiler:
|
||||
"make use of the actual mapped columns in ORM-evaluated "
|
||||
"UPDATE / DELETE expressions."
|
||||
)
|
||||
|
||||
else:
|
||||
raise UnevaluatableError(f"Cannot evaluate column: {clause}")
|
||||
|
||||
get_corresponding_attr = operator.attrgetter(key)
|
||||
return (
|
||||
lambda obj: get_corresponding_attr(obj)
|
||||
if obj is not None
|
||||
else _NO_OBJECT
|
||||
)
|
||||
def get_corresponding_attr(obj):
|
||||
if obj is None:
|
||||
return _NO_OBJECT
|
||||
return getattr(obj, key, _EXPIRED_OBJECT)
|
||||
|
||||
return get_corresponding_attr
|
||||
|
||||
def visit_tuple(self, clause):
|
||||
return self.visit_clauselist(clause)
|
||||
@@ -134,7 +162,9 @@ class EvaluatorCompiler:
|
||||
has_null = False
|
||||
for sub_evaluate in evaluators:
|
||||
value = sub_evaluate(obj)
|
||||
if value:
|
||||
if value is _EXPIRED_OBJECT:
|
||||
return _EXPIRED_OBJECT
|
||||
elif value:
|
||||
return True
|
||||
has_null = has_null or value is None
|
||||
if has_null:
|
||||
@@ -147,6 +177,9 @@ class EvaluatorCompiler:
|
||||
def evaluate(obj):
|
||||
for sub_evaluate in evaluators:
|
||||
value = sub_evaluate(obj)
|
||||
if value is _EXPIRED_OBJECT:
|
||||
return _EXPIRED_OBJECT
|
||||
|
||||
if not value:
|
||||
if value is None or value is _NO_OBJECT:
|
||||
return None
|
||||
@@ -160,7 +193,9 @@ class EvaluatorCompiler:
|
||||
values = []
|
||||
for sub_evaluate in evaluators:
|
||||
value = sub_evaluate(obj)
|
||||
if value is None or value is _NO_OBJECT:
|
||||
if value is _EXPIRED_OBJECT:
|
||||
return _EXPIRED_OBJECT
|
||||
elif value is None or value is _NO_OBJECT:
|
||||
return None
|
||||
values.append(value)
|
||||
return tuple(values)
|
||||
@@ -183,13 +218,21 @@ class EvaluatorCompiler:
|
||||
|
||||
def visit_is_binary_op(self, operator, eval_left, eval_right, clause):
|
||||
def evaluate(obj):
|
||||
return eval_left(obj) == eval_right(obj)
|
||||
left_val = eval_left(obj)
|
||||
right_val = eval_right(obj)
|
||||
if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
|
||||
return _EXPIRED_OBJECT
|
||||
return left_val == right_val
|
||||
|
||||
return evaluate
|
||||
|
||||
def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause):
|
||||
def evaluate(obj):
|
||||
return eval_left(obj) != eval_right(obj)
|
||||
left_val = eval_left(obj)
|
||||
right_val = eval_right(obj)
|
||||
if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
|
||||
return _EXPIRED_OBJECT
|
||||
return left_val != right_val
|
||||
|
||||
return evaluate
|
||||
|
||||
@@ -197,8 +240,11 @@ class EvaluatorCompiler:
|
||||
def evaluate(obj):
|
||||
left_val = eval_left(obj)
|
||||
right_val = eval_right(obj)
|
||||
if left_val is None or right_val is None:
|
||||
if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
|
||||
return _EXPIRED_OBJECT
|
||||
elif left_val is None or right_val is None:
|
||||
return None
|
||||
|
||||
return operator(eval_left(obj), eval_right(obj))
|
||||
|
||||
return evaluate
|
||||
@@ -274,7 +320,9 @@ class EvaluatorCompiler:
|
||||
|
||||
def evaluate(obj):
|
||||
value = eval_inner(obj)
|
||||
if value is None:
|
||||
if value is _EXPIRED_OBJECT:
|
||||
return _EXPIRED_OBJECT
|
||||
elif value is None:
|
||||
return None
|
||||
return not value
|
||||
|
||||
|
||||
@@ -68,6 +68,11 @@ class IdentityMap:
|
||||
) -> Optional[_O]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def fast_get_state(
|
||||
self, key: _IdentityKeyType[_O]
|
||||
) -> Optional[InstanceState[_O]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def keys(self) -> Iterable[_IdentityKeyType[Any]]:
|
||||
return self._dict.keys()
|
||||
|
||||
@@ -206,6 +211,11 @@ class WeakInstanceDict(IdentityMap):
|
||||
self._dict[key] = state
|
||||
state._instance_dict = self._wr
|
||||
|
||||
def fast_get_state(
|
||||
self, key: _IdentityKeyType[_O]
|
||||
) -> Optional[InstanceState[_O]]:
|
||||
return self._dict.get(key)
|
||||
|
||||
def get(
|
||||
self, key: _IdentityKeyType[_O], default: Optional[_O] = None
|
||||
) -> Optional[_O]:
|
||||
|
||||
@@ -29,7 +29,6 @@ from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy.orm.context import FromStatement
|
||||
from . import attributes
|
||||
from . import exc as orm_exc
|
||||
from . import path_registry
|
||||
@@ -37,6 +36,7 @@ from .base import _DEFER_FOR_STATE
|
||||
from .base import _RAISE_FOR_STATE
|
||||
from .base import _SET_DEFERRED_EXPIRED
|
||||
from .base import PassiveFlag
|
||||
from .context import FromStatement
|
||||
from .util import _none_set
|
||||
from .util import state_str
|
||||
from .. import exc as sa_exc
|
||||
@@ -50,6 +50,7 @@ from ..sql import util as sql_util
|
||||
from ..sql.selectable import ForUpdateArg
|
||||
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
|
||||
from ..sql.selectable import SelectState
|
||||
from ..util import EMPTY_DICT
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._typing import _IdentityKeyType
|
||||
@@ -764,7 +765,7 @@ def _instance_processor(
|
||||
)
|
||||
|
||||
quick_populators = path.get(
|
||||
context.attributes, "memoized_setups", _none_set
|
||||
context.attributes, "memoized_setups", EMPTY_DICT
|
||||
)
|
||||
|
||||
todo = []
|
||||
|
||||
@@ -854,6 +854,7 @@ class Mapper(
|
||||
_memoized_values: Dict[Any, Callable[[], Any]]
|
||||
_inheriting_mappers: util.WeakSequence[Mapper[Any]]
|
||||
_all_tables: Set[Table]
|
||||
_polymorphic_attr_key: Optional[str]
|
||||
|
||||
_pks_by_table: Dict[FromClause, OrderedSet[ColumnClause[Any]]]
|
||||
_cols_by_table: Dict[FromClause, OrderedSet[ColumnElement[Any]]]
|
||||
@@ -1653,6 +1654,7 @@ class Mapper(
|
||||
|
||||
"""
|
||||
setter = False
|
||||
polymorphic_key: Optional[str] = None
|
||||
|
||||
if self.polymorphic_on is not None:
|
||||
setter = True
|
||||
@@ -1772,17 +1774,23 @@ class Mapper(
|
||||
self._set_polymorphic_identity = (
|
||||
mapper._set_polymorphic_identity
|
||||
)
|
||||
self._polymorphic_attr_key = (
|
||||
mapper._polymorphic_attr_key
|
||||
)
|
||||
self._validate_polymorphic_identity = (
|
||||
mapper._validate_polymorphic_identity
|
||||
)
|
||||
else:
|
||||
self._set_polymorphic_identity = None
|
||||
self._polymorphic_attr_key = None
|
||||
return
|
||||
|
||||
if setter:
|
||||
|
||||
def _set_polymorphic_identity(state):
|
||||
dict_ = state.dict
|
||||
# TODO: what happens if polymorphic_on column attribute name
|
||||
# does not match .key?
|
||||
state.get_impl(polymorphic_key).set(
|
||||
state,
|
||||
dict_,
|
||||
@@ -1790,6 +1798,8 @@ class Mapper(
|
||||
None,
|
||||
)
|
||||
|
||||
self._polymorphic_attr_key = polymorphic_key
|
||||
|
||||
def _validate_polymorphic_identity(mapper, state, dict_):
|
||||
if (
|
||||
polymorphic_key in dict_
|
||||
@@ -1808,6 +1818,7 @@ class Mapper(
|
||||
_validate_polymorphic_identity
|
||||
)
|
||||
else:
|
||||
self._polymorphic_attr_key = None
|
||||
self._set_polymorphic_identity = None
|
||||
|
||||
_validate_polymorphic_identity = None
|
||||
@@ -3561,6 +3572,10 @@ class Mapper(
|
||||
def _compiled_cache(self):
|
||||
return util.LRUCache(self._compiled_cache_size)
|
||||
|
||||
@HasMemoized.memoized_attribute
|
||||
def _multiple_persistence_tables(self):
|
||||
return len(self.tables) > 1
|
||||
|
||||
@HasMemoized.memoized_attribute
|
||||
def _sorted_tables(self):
|
||||
table_to_mapper: Dict[Table, Mapper[Any]] = {}
|
||||
|
||||
@@ -31,6 +31,7 @@ from .. import exc as sa_exc
|
||||
from .. import future
|
||||
from .. import sql
|
||||
from .. import util
|
||||
from ..engine import cursor as _cursor
|
||||
from ..sql import operators
|
||||
from ..sql.elements import BooleanClauseList
|
||||
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
|
||||
@@ -398,6 +399,11 @@ def _collect_insert_commands(
|
||||
None
|
||||
)
|
||||
|
||||
if bulk and mapper._set_polymorphic_identity:
|
||||
params.setdefault(
|
||||
mapper._polymorphic_attr_key, mapper.polymorphic_identity
|
||||
)
|
||||
|
||||
yield (
|
||||
state,
|
||||
state_dict,
|
||||
@@ -411,7 +417,11 @@ def _collect_insert_commands(
|
||||
|
||||
|
||||
def _collect_update_commands(
|
||||
uowtransaction, table, states_to_update, bulk=False
|
||||
uowtransaction,
|
||||
table,
|
||||
states_to_update,
|
||||
bulk=False,
|
||||
use_orm_update_stmt=None,
|
||||
):
|
||||
"""Identify sets of values to use in UPDATE statements for a
|
||||
list of states.
|
||||
@@ -437,7 +447,11 @@ def _collect_update_commands(
|
||||
|
||||
pks = mapper._pks_by_table[table]
|
||||
|
||||
value_params = {}
|
||||
if use_orm_update_stmt is not None:
|
||||
# TODO: ordered values, etc
|
||||
value_params = use_orm_update_stmt._values
|
||||
else:
|
||||
value_params = {}
|
||||
|
||||
propkey_to_col = mapper._propkey_to_col[table]
|
||||
|
||||
@@ -697,6 +711,7 @@ def _emit_update_statements(
|
||||
table,
|
||||
update,
|
||||
bookkeeping=True,
|
||||
use_orm_update_stmt=None,
|
||||
):
|
||||
"""Emit UPDATE statements corresponding to value lists collected
|
||||
by _collect_update_commands()."""
|
||||
@@ -708,7 +723,7 @@ def _emit_update_statements(
|
||||
|
||||
execution_options = {"compiled_cache": base_mapper._compiled_cache}
|
||||
|
||||
def update_stmt():
|
||||
def update_stmt(existing_stmt=None):
|
||||
clauses = BooleanClauseList._construct_raw(operators.and_)
|
||||
|
||||
for col in mapper._pks_by_table[table]:
|
||||
@@ -725,10 +740,17 @@ def _emit_update_statements(
|
||||
)
|
||||
)
|
||||
|
||||
stmt = table.update().where(clauses)
|
||||
if existing_stmt is not None:
|
||||
stmt = existing_stmt.where(clauses)
|
||||
else:
|
||||
stmt = table.update().where(clauses)
|
||||
return stmt
|
||||
|
||||
cached_stmt = base_mapper._memo(("update", table), update_stmt)
|
||||
if use_orm_update_stmt is not None:
|
||||
cached_stmt = update_stmt(use_orm_update_stmt)
|
||||
|
||||
else:
|
||||
cached_stmt = base_mapper._memo(("update", table), update_stmt)
|
||||
|
||||
for (
|
||||
(connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
|
||||
@@ -747,6 +769,15 @@ def _emit_update_statements(
|
||||
records = list(records)
|
||||
|
||||
statement = cached_stmt
|
||||
|
||||
if use_orm_update_stmt is not None:
|
||||
statement = statement._annotate(
|
||||
{
|
||||
"_emit_update_table": table,
|
||||
"_emit_update_mapper": mapper,
|
||||
}
|
||||
)
|
||||
|
||||
return_defaults = False
|
||||
|
||||
if not has_all_pks:
|
||||
@@ -904,16 +935,35 @@ def _emit_insert_statements(
|
||||
table,
|
||||
insert,
|
||||
bookkeeping=True,
|
||||
use_orm_insert_stmt=None,
|
||||
execution_options=None,
|
||||
):
|
||||
"""Emit INSERT statements corresponding to value lists collected
|
||||
by _collect_insert_commands()."""
|
||||
|
||||
cached_stmt = base_mapper._memo(("insert", table), table.insert)
|
||||
if use_orm_insert_stmt is not None:
|
||||
cached_stmt = use_orm_insert_stmt
|
||||
exec_opt = util.EMPTY_DICT
|
||||
|
||||
execution_options = {"compiled_cache": base_mapper._compiled_cache}
|
||||
# if a user query with RETURNING was passed, we definitely need
|
||||
# to use RETURNING.
|
||||
returning_is_required_anyway = bool(use_orm_insert_stmt._returning)
|
||||
else:
|
||||
returning_is_required_anyway = False
|
||||
cached_stmt = base_mapper._memo(("insert", table), table.insert)
|
||||
exec_opt = {"compiled_cache": base_mapper._compiled_cache}
|
||||
|
||||
if execution_options:
|
||||
execution_options = util.EMPTY_DICT.merge_with(
|
||||
exec_opt, execution_options
|
||||
)
|
||||
else:
|
||||
execution_options = exec_opt
|
||||
|
||||
return_result = None
|
||||
|
||||
for (
|
||||
(connection, pkeys, hasvalue, has_all_pks, has_all_defaults),
|
||||
(connection, _, hasvalue, has_all_pks, has_all_defaults),
|
||||
records,
|
||||
) in groupby(
|
||||
insert,
|
||||
@@ -928,17 +978,29 @@ def _emit_insert_statements(
|
||||
|
||||
statement = cached_stmt
|
||||
|
||||
if (
|
||||
not bookkeeping
|
||||
or (
|
||||
has_all_defaults
|
||||
or not base_mapper.eager_defaults
|
||||
or not base_mapper.local_table.implicit_returning
|
||||
or not connection.dialect.insert_returning
|
||||
if use_orm_insert_stmt is not None:
|
||||
statement = statement._annotate(
|
||||
{
|
||||
"_emit_insert_table": table,
|
||||
"_emit_insert_mapper": mapper,
|
||||
}
|
||||
)
|
||||
|
||||
if (
|
||||
(
|
||||
not bookkeeping
|
||||
or (
|
||||
has_all_defaults
|
||||
or not base_mapper.eager_defaults
|
||||
or not base_mapper.local_table.implicit_returning
|
||||
or not connection.dialect.insert_returning
|
||||
)
|
||||
)
|
||||
and not returning_is_required_anyway
|
||||
and has_all_pks
|
||||
and not hasvalue
|
||||
):
|
||||
|
||||
# the "we don't need newly generated values back" section.
|
||||
# here we have all the PKs, all the defaults or we don't want
|
||||
# to fetch them, or the dialect doesn't support RETURNING at all
|
||||
@@ -946,7 +1008,7 @@ def _emit_insert_statements(
|
||||
records = list(records)
|
||||
multiparams = [rec[2] for rec in records]
|
||||
|
||||
c = connection.execute(
|
||||
result = connection.execute(
|
||||
statement, multiparams, execution_options=execution_options
|
||||
)
|
||||
if bookkeeping:
|
||||
@@ -962,7 +1024,7 @@ def _emit_insert_statements(
|
||||
has_all_defaults,
|
||||
),
|
||||
last_inserted_params,
|
||||
) in zip(records, c.context.compiled_parameters):
|
||||
) in zip(records, result.context.compiled_parameters):
|
||||
if state:
|
||||
_postfetch(
|
||||
mapper_rec,
|
||||
@@ -970,19 +1032,20 @@ def _emit_insert_statements(
|
||||
table,
|
||||
state,
|
||||
state_dict,
|
||||
c,
|
||||
result,
|
||||
last_inserted_params,
|
||||
value_params,
|
||||
False,
|
||||
c.returned_defaults
|
||||
if not c.context.executemany
|
||||
result.returned_defaults
|
||||
if not result.context.executemany
|
||||
else None,
|
||||
)
|
||||
else:
|
||||
_postfetch_bulk_save(mapper_rec, state_dict, table)
|
||||
|
||||
else:
|
||||
# here, we need defaults and/or pk values back.
|
||||
# here, we need defaults and/or pk values back or we otherwise
|
||||
# know that we are using RETURNING in any case
|
||||
|
||||
records = list(records)
|
||||
if (
|
||||
@@ -991,6 +1054,16 @@ def _emit_insert_statements(
|
||||
and len(records) > 1
|
||||
):
|
||||
do_executemany = True
|
||||
elif returning_is_required_anyway:
|
||||
if connection.dialect.insert_executemany_returning:
|
||||
do_executemany = True
|
||||
else:
|
||||
raise sa_exc.InvalidRequestError(
|
||||
f"Can't use explicit RETURNING for bulk INSERT "
|
||||
f"operation with "
|
||||
f"{connection.dialect.dialect_description} backend; "
|
||||
f"executemany is not supported with RETURNING"
|
||||
)
|
||||
else:
|
||||
do_executemany = False
|
||||
|
||||
@@ -998,6 +1071,7 @@ def _emit_insert_statements(
|
||||
statement = statement.return_defaults(
|
||||
*mapper._server_default_cols[table]
|
||||
)
|
||||
|
||||
if mapper.version_id_col is not None:
|
||||
statement = statement.return_defaults(mapper.version_id_col)
|
||||
elif do_executemany:
|
||||
@@ -1006,10 +1080,16 @@ def _emit_insert_statements(
|
||||
if do_executemany:
|
||||
multiparams = [rec[2] for rec in records]
|
||||
|
||||
c = connection.execute(
|
||||
result = connection.execute(
|
||||
statement, multiparams, execution_options=execution_options
|
||||
)
|
||||
|
||||
if use_orm_insert_stmt is not None:
|
||||
if return_result is None:
|
||||
return_result = result
|
||||
else:
|
||||
return_result = return_result.splice_vertically(result)
|
||||
|
||||
if bookkeeping:
|
||||
for (
|
||||
(
|
||||
@@ -1027,9 +1107,9 @@ def _emit_insert_statements(
|
||||
returned_defaults,
|
||||
) in zip_longest(
|
||||
records,
|
||||
c.context.compiled_parameters,
|
||||
c.inserted_primary_key_rows,
|
||||
c.returned_defaults_rows or (),
|
||||
result.context.compiled_parameters,
|
||||
result.inserted_primary_key_rows,
|
||||
result.returned_defaults_rows or (),
|
||||
):
|
||||
if inserted_primary_key is None:
|
||||
# this is a real problem and means that we didn't
|
||||
@@ -1062,7 +1142,7 @@ def _emit_insert_statements(
|
||||
table,
|
||||
state,
|
||||
state_dict,
|
||||
c,
|
||||
result,
|
||||
last_inserted_params,
|
||||
value_params,
|
||||
False,
|
||||
@@ -1071,6 +1151,8 @@ def _emit_insert_statements(
|
||||
else:
|
||||
_postfetch_bulk_save(mapper_rec, state_dict, table)
|
||||
else:
|
||||
assert not returning_is_required_anyway
|
||||
|
||||
for (
|
||||
state,
|
||||
state_dict,
|
||||
@@ -1132,6 +1214,12 @@ def _emit_insert_statements(
|
||||
else:
|
||||
_postfetch_bulk_save(mapper_rec, state_dict, table)
|
||||
|
||||
if use_orm_insert_stmt is not None:
|
||||
if return_result is None:
|
||||
return _cursor.null_dml_result()
|
||||
else:
|
||||
return return_result
|
||||
|
||||
|
||||
def _emit_post_update_statements(
|
||||
base_mapper, uowtransaction, mapper, table, update
|
||||
|
||||
@@ -2978,7 +2978,7 @@ class Query(
|
||||
)
|
||||
|
||||
def delete(
|
||||
self, synchronize_session: _SynchronizeSessionArgument = "evaluate"
|
||||
self, synchronize_session: _SynchronizeSessionArgument = "auto"
|
||||
) -> int:
|
||||
r"""Perform a DELETE with an arbitrary WHERE clause.
|
||||
|
||||
@@ -3042,7 +3042,7 @@ class Query(
|
||||
def update(
|
||||
self,
|
||||
values: Dict[_DMLColumnArgument, Any],
|
||||
synchronize_session: _SynchronizeSessionArgument = "evaluate",
|
||||
synchronize_session: _SynchronizeSessionArgument = "auto",
|
||||
update_args: Optional[Dict[Any, Any]] = None,
|
||||
) -> int:
|
||||
r"""Perform an UPDATE with an arbitrary WHERE clause.
|
||||
|
||||
@@ -1828,12 +1828,13 @@ class Session(_SessionClassMethods, EventTarget):
|
||||
statement._propagate_attrs.get("compile_state_plugin", None)
|
||||
== "orm"
|
||||
):
|
||||
# note that even without "future" mode, we need
|
||||
compile_state_cls = CompileState._get_plugin_class_for_plugin(
|
||||
statement, "orm"
|
||||
)
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(compile_state_cls, ORMCompileState)
|
||||
assert isinstance(
|
||||
compile_state_cls, context.AbstractORMCompileState
|
||||
)
|
||||
else:
|
||||
compile_state_cls = None
|
||||
|
||||
@@ -1897,18 +1898,18 @@ class Session(_SessionClassMethods, EventTarget):
|
||||
statement, params or {}, execution_options=execution_options
|
||||
)
|
||||
|
||||
result: Result[Any] = conn.execute(
|
||||
statement, params or {}, execution_options=execution_options
|
||||
)
|
||||
|
||||
if compile_state_cls:
|
||||
result = compile_state_cls.orm_setup_cursor_result(
|
||||
result: Result[Any] = compile_state_cls.orm_execute_statement(
|
||||
self,
|
||||
statement,
|
||||
params,
|
||||
params or {},
|
||||
execution_options,
|
||||
bind_arguments,
|
||||
result,
|
||||
conn,
|
||||
)
|
||||
else:
|
||||
result = conn.execute(
|
||||
statement, params or {}, execution_options=execution_options
|
||||
)
|
||||
|
||||
if _scalar_result:
|
||||
@@ -2066,7 +2067,7 @@ class Session(_SessionClassMethods, EventTarget):
|
||||
def scalars(
|
||||
self,
|
||||
statement: TypedReturnsRows[Tuple[_T]],
|
||||
params: Optional[_CoreSingleExecuteParams] = None,
|
||||
params: Optional[_CoreAnyExecuteParams] = None,
|
||||
*,
|
||||
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
|
||||
bind_arguments: Optional[_BindArguments] = None,
|
||||
@@ -2078,7 +2079,7 @@ class Session(_SessionClassMethods, EventTarget):
|
||||
def scalars(
|
||||
self,
|
||||
statement: Executable,
|
||||
params: Optional[_CoreSingleExecuteParams] = None,
|
||||
params: Optional[_CoreAnyExecuteParams] = None,
|
||||
*,
|
||||
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
|
||||
bind_arguments: Optional[_BindArguments] = None,
|
||||
@@ -2089,7 +2090,7 @@ class Session(_SessionClassMethods, EventTarget):
|
||||
def scalars(
|
||||
self,
|
||||
statement: Executable,
|
||||
params: Optional[_CoreSingleExecuteParams] = None,
|
||||
params: Optional[_CoreAnyExecuteParams] = None,
|
||||
*,
|
||||
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
|
||||
bind_arguments: Optional[_BindArguments] = None,
|
||||
|
||||
@@ -227,6 +227,11 @@ class ColumnLoader(LoaderStrategy):
|
||||
fetch = self.columns[0]
|
||||
if adapter:
|
||||
fetch = adapter.columns[fetch]
|
||||
if fetch is None:
|
||||
# None happens here only for dml bulk_persistence cases
|
||||
# when context.DMLReturningColFilter is used
|
||||
return
|
||||
|
||||
memoized_populators[self.parent_property] = fetch
|
||||
|
||||
def init_class_attribute(self, mapper):
|
||||
@@ -318,6 +323,12 @@ class ExpressionColumnLoader(ColumnLoader):
|
||||
fetch = columns[0]
|
||||
if adapter:
|
||||
fetch = adapter.columns[fetch]
|
||||
if fetch is None:
|
||||
# None is not expected to be the result of any
|
||||
# adapter implementation here, however there may be theoretical
|
||||
# usages of returning() with context.DMLReturningColFilter
|
||||
return
|
||||
|
||||
memoized_populators[self.parent_property] = fetch
|
||||
|
||||
def create_row_processor(
|
||||
|
||||
@@ -552,6 +552,8 @@ def _new_annotation_type(
|
||||
# e.g. BindParameter, add it if present.
|
||||
if cls.__dict__.get("inherit_cache", False):
|
||||
anno_cls.inherit_cache = True # type: ignore
|
||||
elif "inherit_cache" in cls.__dict__:
|
||||
anno_cls.inherit_cache = cls.__dict__["inherit_cache"] # type: ignore
|
||||
|
||||
anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators)
|
||||
|
||||
|
||||
@@ -5166,6 +5166,8 @@ class SQLCompiler(Compiled):
|
||||
delete_stmt, delete_stmt.table, extra_froms
|
||||
)
|
||||
|
||||
crud._get_crud_params(self, delete_stmt, compile_state, toplevel, **kw)
|
||||
|
||||
if delete_stmt._hints:
|
||||
dialect_hints, table_text = self._setup_crud_hints(
|
||||
delete_stmt, table_text
|
||||
@@ -5178,13 +5180,14 @@ class SQLCompiler(Compiled):
|
||||
|
||||
text += table_text
|
||||
|
||||
if delete_stmt._returning:
|
||||
if self.returning_precedes_values:
|
||||
text += " " + self.returning_clause(
|
||||
delete_stmt,
|
||||
delete_stmt._returning,
|
||||
populate_result_map=toplevel,
|
||||
)
|
||||
if (
|
||||
self.implicit_returning or delete_stmt._returning
|
||||
) and self.returning_precedes_values:
|
||||
text += " " + self.returning_clause(
|
||||
delete_stmt,
|
||||
self.implicit_returning or delete_stmt._returning,
|
||||
populate_result_map=toplevel,
|
||||
)
|
||||
|
||||
if extra_froms:
|
||||
extra_from_text = self.delete_extra_from_clause(
|
||||
@@ -5204,10 +5207,12 @@ class SQLCompiler(Compiled):
|
||||
if t:
|
||||
text += " WHERE " + t
|
||||
|
||||
if delete_stmt._returning and not self.returning_precedes_values:
|
||||
if (
|
||||
self.implicit_returning or delete_stmt._returning
|
||||
) and not self.returning_precedes_values:
|
||||
text += " " + self.returning_clause(
|
||||
delete_stmt,
|
||||
delete_stmt._returning,
|
||||
self.implicit_returning or delete_stmt._returning,
|
||||
populate_result_map=toplevel,
|
||||
)
|
||||
|
||||
@@ -5297,7 +5302,6 @@ class StrSQLCompiler(SQLCompiler):
|
||||
self._label_select_column(None, c, True, False, {})
|
||||
for c in base._select_iterables(returning_cols)
|
||||
]
|
||||
|
||||
return "RETURNING " + ", ".join(columns)
|
||||
|
||||
def update_from_clause(
|
||||
|
||||
+101
-18
@@ -150,6 +150,22 @@ def _get_crud_params(
|
||||
"return_defaults() simultaneously"
|
||||
)
|
||||
|
||||
if compile_state.isdelete:
|
||||
_setup_delete_return_defaults(
|
||||
compiler,
|
||||
stmt,
|
||||
compile_state,
|
||||
(),
|
||||
_getattr_col_key,
|
||||
_column_as_key,
|
||||
_col_bind_name,
|
||||
(),
|
||||
(),
|
||||
toplevel,
|
||||
kw,
|
||||
)
|
||||
return _CrudParams([], [])
|
||||
|
||||
# no parameters in the statement, no parameters in the
|
||||
# compiled params - return binds for all columns
|
||||
if compiler.column_keys is None and compile_state._no_parameters:
|
||||
@@ -466,13 +482,6 @@ def _scan_insert_from_select_cols(
|
||||
kw,
|
||||
):
|
||||
|
||||
(
|
||||
need_pks,
|
||||
implicit_returning,
|
||||
implicit_return_defaults,
|
||||
postfetch_lastrowid,
|
||||
) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel)
|
||||
|
||||
cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names]
|
||||
|
||||
assert compiler.stack[-1]["selectable"] is stmt
|
||||
@@ -537,6 +546,8 @@ def _scan_cols(
|
||||
postfetch_lastrowid,
|
||||
) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel)
|
||||
|
||||
assert compile_state.isupdate or compile_state.isinsert
|
||||
|
||||
if compile_state._parameter_ordering:
|
||||
parameter_ordering = [
|
||||
_column_as_key(key) for key in compile_state._parameter_ordering
|
||||
@@ -563,6 +574,13 @@ def _scan_cols(
|
||||
else:
|
||||
autoincrement_col = insert_null_pk_still_autoincrements = None
|
||||
|
||||
if stmt._supplemental_returning:
|
||||
supplemental_returning = set(stmt._supplemental_returning)
|
||||
else:
|
||||
supplemental_returning = set()
|
||||
|
||||
compiler_implicit_returning = compiler.implicit_returning
|
||||
|
||||
for c in cols:
|
||||
# scan through every column in the target table
|
||||
|
||||
@@ -627,11 +645,13 @@ def _scan_cols(
|
||||
# column has a DDL-level default, and is either not a pk
|
||||
# column or we don't need the pk.
|
||||
if implicit_return_defaults and c in implicit_return_defaults:
|
||||
compiler.implicit_returning.append(c)
|
||||
compiler_implicit_returning.append(c)
|
||||
elif not c.primary_key:
|
||||
compiler.postfetch.append(c)
|
||||
|
||||
elif implicit_return_defaults and c in implicit_return_defaults:
|
||||
compiler.implicit_returning.append(c)
|
||||
compiler_implicit_returning.append(c)
|
||||
|
||||
elif (
|
||||
c.primary_key
|
||||
and c is not stmt.table._autoincrement_column
|
||||
@@ -652,6 +672,59 @@ def _scan_cols(
|
||||
kw,
|
||||
)
|
||||
|
||||
# adding supplemental cols to implicit_returning in table
|
||||
# order so that order is maintained between multiple INSERT
|
||||
# statements which may have different parameters included, but all
|
||||
# have the same RETURNING clause
|
||||
if (
|
||||
c in supplemental_returning
|
||||
and c not in compiler_implicit_returning
|
||||
):
|
||||
compiler_implicit_returning.append(c)
|
||||
|
||||
if supplemental_returning:
|
||||
# we should have gotten every col into implicit_returning,
|
||||
# however supplemental returning can also have SQL functions etc.
|
||||
# in it
|
||||
remaining_supplemental = supplemental_returning.difference(
|
||||
compiler_implicit_returning
|
||||
)
|
||||
compiler_implicit_returning.extend(
|
||||
c
|
||||
for c in stmt._supplemental_returning
|
||||
if c in remaining_supplemental
|
||||
)
|
||||
|
||||
|
||||
def _setup_delete_return_defaults(
|
||||
compiler,
|
||||
stmt,
|
||||
compile_state,
|
||||
parameters,
|
||||
_getattr_col_key,
|
||||
_column_as_key,
|
||||
_col_bind_name,
|
||||
check_columns,
|
||||
values,
|
||||
toplevel,
|
||||
kw,
|
||||
):
|
||||
(_, _, implicit_return_defaults, _) = _get_returning_modifiers(
|
||||
compiler, stmt, compile_state, toplevel
|
||||
)
|
||||
|
||||
if not implicit_return_defaults:
|
||||
return
|
||||
|
||||
if stmt._return_defaults_columns:
|
||||
compiler.implicit_returning.extend(implicit_return_defaults)
|
||||
|
||||
if stmt._supplemental_returning:
|
||||
ir_set = set(compiler.implicit_returning)
|
||||
compiler.implicit_returning.extend(
|
||||
c for c in stmt._supplemental_returning if c not in ir_set
|
||||
)
|
||||
|
||||
|
||||
def _append_param_parameter(
|
||||
compiler,
|
||||
@@ -743,7 +816,7 @@ def _append_param_parameter(
|
||||
elif compiler.dialect.postfetch_lastrowid:
|
||||
compiler.postfetch_lastrowid = True
|
||||
|
||||
elif implicit_return_defaults and c in implicit_return_defaults:
|
||||
elif implicit_return_defaults and (c in implicit_return_defaults):
|
||||
compiler.implicit_returning.append(c)
|
||||
|
||||
else:
|
||||
@@ -1303,6 +1376,7 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
|
||||
INSERT or UPDATE statement after it's invoked.
|
||||
|
||||
"""
|
||||
|
||||
need_pks = (
|
||||
toplevel
|
||||
and _compile_state_isinsert(compile_state)
|
||||
@@ -1315,6 +1389,7 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
|
||||
)
|
||||
)
|
||||
and not stmt._returning
|
||||
# and (not stmt._returning or stmt._return_defaults)
|
||||
and not compile_state._has_multi_parameters
|
||||
)
|
||||
|
||||
@@ -1357,33 +1432,41 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
|
||||
or stmt._return_defaults
|
||||
)
|
||||
)
|
||||
|
||||
if implicit_returning:
|
||||
postfetch_lastrowid = False
|
||||
|
||||
if _compile_state_isinsert(compile_state):
|
||||
implicit_return_defaults = implicit_returning and stmt._return_defaults
|
||||
should_implicit_return_defaults = (
|
||||
implicit_returning and stmt._return_defaults
|
||||
)
|
||||
elif compile_state.isupdate:
|
||||
implicit_return_defaults = (
|
||||
should_implicit_return_defaults = (
|
||||
stmt._return_defaults
|
||||
and compile_state._primary_table.implicit_returning
|
||||
and compile_state._supports_implicit_returning
|
||||
and compiler.dialect.update_returning
|
||||
)
|
||||
elif compile_state.isdelete:
|
||||
should_implicit_return_defaults = (
|
||||
stmt._return_defaults
|
||||
and compile_state._primary_table.implicit_returning
|
||||
and compile_state._supports_implicit_returning
|
||||
and compiler.dialect.delete_returning
|
||||
)
|
||||
else:
|
||||
# this line is unused, currently we are always
|
||||
# isinsert or isupdate
|
||||
implicit_return_defaults = False # pragma: no cover
|
||||
should_implicit_return_defaults = False # pragma: no cover
|
||||
|
||||
if implicit_return_defaults:
|
||||
if should_implicit_return_defaults:
|
||||
if not stmt._return_defaults_columns:
|
||||
implicit_return_defaults = set(stmt.table.c)
|
||||
else:
|
||||
implicit_return_defaults = set(stmt._return_defaults_columns)
|
||||
else:
|
||||
implicit_return_defaults = None
|
||||
|
||||
return (
|
||||
need_pks,
|
||||
implicit_returning,
|
||||
implicit_returning or should_implicit_return_defaults,
|
||||
implicit_return_defaults,
|
||||
postfetch_lastrowid,
|
||||
)
|
||||
|
||||
+241
-168
@@ -164,16 +164,33 @@ class DMLState(CompileState):
|
||||
def get_plugin_class(cls, statement: Executable) -> Type[DMLState]:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def _get_multi_crud_kv_pairs(
|
||||
cls,
|
||||
statement: UpdateBase,
|
||||
multi_kv_iterator: Iterable[Dict[_DMLColumnArgument, Any]],
|
||||
) -> List[Dict[_DMLColumnElement, Any]]:
|
||||
return [
|
||||
{
|
||||
coercions.expect(roles.DMLColumnRole, k): v
|
||||
for k, v in mapping.items()
|
||||
}
|
||||
for mapping in multi_kv_iterator
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _get_crud_kv_pairs(
|
||||
cls,
|
||||
statement: UpdateBase,
|
||||
kv_iterator: Iterable[Tuple[_DMLColumnArgument, Any]],
|
||||
needs_to_be_cacheable: bool,
|
||||
) -> List[Tuple[_DMLColumnElement, Any]]:
|
||||
return [
|
||||
(
|
||||
coercions.expect(roles.DMLColumnRole, k),
|
||||
coercions.expect(
|
||||
v
|
||||
if not needs_to_be_cacheable
|
||||
else coercions.expect(
|
||||
roles.ExpressionElementRole,
|
||||
v,
|
||||
type_=NullType(),
|
||||
@@ -269,7 +286,7 @@ class InsertDMLState(DMLState):
|
||||
def _insert_col_keys(self) -> List[str]:
|
||||
# this is also done in crud.py -> _key_getters_for_crud_column
|
||||
return [
|
||||
coercions.expect_as_key(roles.DMLColumnRole, col)
|
||||
coercions.expect(roles.DMLColumnRole, col, as_key=True)
|
||||
for col in self._dict_parameters or ()
|
||||
]
|
||||
|
||||
@@ -326,7 +343,6 @@ class UpdateDMLState(DMLState):
|
||||
self._extra_froms = ef
|
||||
|
||||
self.is_multitable = mt = ef
|
||||
|
||||
self.include_table_with_column_exprs = bool(
|
||||
mt and compiler.render_table_with_column_in_update_from
|
||||
)
|
||||
@@ -389,6 +405,7 @@ class UpdateBase(
|
||||
_return_defaults_columns: Optional[
|
||||
Tuple[_ColumnsClauseElement, ...]
|
||||
] = None
|
||||
_supplemental_returning: Optional[Tuple[_ColumnsClauseElement, ...]] = None
|
||||
_returning: Tuple[_ColumnsClauseElement, ...] = ()
|
||||
|
||||
is_dml = True
|
||||
@@ -434,6 +451,215 @@ class UpdateBase(
|
||||
self._validate_dialect_kwargs(opt)
|
||||
return self
|
||||
|
||||
@_generative
|
||||
def return_defaults(
|
||||
self: SelfUpdateBase,
|
||||
*cols: _DMLColumnArgument,
|
||||
supplemental_cols: Optional[Iterable[_DMLColumnArgument]] = None,
|
||||
) -> SelfUpdateBase:
|
||||
"""Make use of a :term:`RETURNING` clause for the purpose
|
||||
of fetching server-side expressions and defaults, for supporting
|
||||
backends only.
|
||||
|
||||
.. deepalchemy::
|
||||
|
||||
The :meth:`.UpdateBase.return_defaults` method is used by the ORM
|
||||
for its internal work in fetching newly generated primary key
|
||||
and server default values, in particular to provide the underyling
|
||||
implementation of the :paramref:`_orm.Mapper.eager_defaults`
|
||||
ORM feature as well as to allow RETURNING support with bulk
|
||||
ORM inserts. Its behavior is fairly idiosyncratic
|
||||
and is not really intended for general use. End users should
|
||||
stick with using :meth:`.UpdateBase.returning` in order to
|
||||
add RETURNING clauses to their INSERT, UPDATE and DELETE
|
||||
statements.
|
||||
|
||||
Normally, a single row INSERT statement will automatically populate the
|
||||
:attr:`.CursorResult.inserted_primary_key` attribute when executed,
|
||||
which stores the primary key of the row that was just inserted in the
|
||||
form of a :class:`.Row` object with column names as named tuple keys
|
||||
(and the :attr:`.Row._mapping` view fully populated as well). The
|
||||
dialect in use chooses the strategy to use in order to populate this
|
||||
data; if it was generated using server-side defaults and / or SQL
|
||||
expressions, dialect-specific approaches such as ``cursor.lastrowid``
|
||||
or ``RETURNING`` are typically used to acquire the new primary key
|
||||
value.
|
||||
|
||||
However, when the statement is modified by calling
|
||||
:meth:`.UpdateBase.return_defaults` before executing the statement,
|
||||
additional behaviors take place **only** for backends that support
|
||||
RETURNING and for :class:`.Table` objects that maintain the
|
||||
:paramref:`.Table.implicit_returning` parameter at its default value of
|
||||
``True``. In these cases, when the :class:`.CursorResult` is returned
|
||||
from the statement's execution, not only will
|
||||
:attr:`.CursorResult.inserted_primary_key` be populated as always, the
|
||||
:attr:`.CursorResult.returned_defaults` attribute will also be
|
||||
populated with a :class:`.Row` named-tuple representing the full range
|
||||
of server generated
|
||||
values from that single row, including values for any columns that
|
||||
specify :paramref:`_schema.Column.server_default` or which make use of
|
||||
:paramref:`_schema.Column.default` using a SQL expression.
|
||||
|
||||
When invoking INSERT statements with multiple rows using
|
||||
:ref:`insertmanyvalues <engine_insertmanyvalues>`, the
|
||||
:meth:`.UpdateBase.return_defaults` modifier will have the effect of
|
||||
the :attr:`_engine.CursorResult.inserted_primary_key_rows` and
|
||||
:attr:`_engine.CursorResult.returned_defaults_rows` attributes being
|
||||
fully populated with lists of :class:`.Row` objects representing newly
|
||||
inserted primary key values as well as newly inserted server generated
|
||||
values for each row inserted. The
|
||||
:attr:`.CursorResult.inserted_primary_key` and
|
||||
:attr:`.CursorResult.returned_defaults` attributes will also continue
|
||||
to be populated with the first row of these two collections.
|
||||
|
||||
If the backend does not support RETURNING or the :class:`.Table` in use
|
||||
has disabled :paramref:`.Table.implicit_returning`, then no RETURNING
|
||||
clause is added and no additional data is fetched, however the
|
||||
INSERT, UPDATE or DELETE statement proceeds normally.
|
||||
|
||||
E.g.::
|
||||
|
||||
stmt = table.insert().values(data='newdata').return_defaults()
|
||||
|
||||
result = connection.execute(stmt)
|
||||
|
||||
server_created_at = result.returned_defaults['created_at']
|
||||
|
||||
When used against an UPDATE statement
|
||||
:meth:`.UpdateBase.return_defaults` instead looks for columns that
|
||||
include :paramref:`_schema.Column.onupdate` or
|
||||
:paramref:`_schema.Column.server_onupdate` parameters assigned, when
|
||||
constructing the columns that will be included in the RETURNING clause
|
||||
by default if explicit columns were not specified. When used against a
|
||||
DELETE statement, no columns are included in RETURNING by default, they
|
||||
instead must be specified explicitly as there are no columns that
|
||||
normally change values when a DELETE statement proceeds.
|
||||
|
||||
.. versionadded:: 2.0 :meth:`.UpdateBase.return_defaults` is supported
|
||||
for DELETE statements also and has been moved from
|
||||
:class:`.ValuesBase` to :class:`.UpdateBase`.
|
||||
|
||||
The :meth:`.UpdateBase.return_defaults` method is mutually exclusive
|
||||
against the :meth:`.UpdateBase.returning` method and errors will be
|
||||
raised during the SQL compilation process if both are used at the same
|
||||
time on one statement. The RETURNING clause of the INSERT, UPDATE or
|
||||
DELETE statement is therefore controlled by only one of these methods
|
||||
at a time.
|
||||
|
||||
The :meth:`.UpdateBase.return_defaults` method differs from
|
||||
:meth:`.UpdateBase.returning` in these ways:
|
||||
|
||||
1. :meth:`.UpdateBase.return_defaults` method causes the
|
||||
:attr:`.CursorResult.returned_defaults` collection to be populated
|
||||
with the first row from the RETURNING result. This attribute is not
|
||||
populated when using :meth:`.UpdateBase.returning`.
|
||||
|
||||
2. :meth:`.UpdateBase.return_defaults` is compatible with existing
|
||||
logic used to fetch auto-generated primary key values that are then
|
||||
populated into the :attr:`.CursorResult.inserted_primary_key`
|
||||
attribute. By contrast, using :meth:`.UpdateBase.returning` will
|
||||
have the effect of the :attr:`.CursorResult.inserted_primary_key`
|
||||
attribute being left unpopulated.
|
||||
|
||||
3. :meth:`.UpdateBase.return_defaults` can be called against any
|
||||
backend. Backends that don't support RETURNING will skip the usage
|
||||
of the feature, rather than raising an exception. The return value
|
||||
of :attr:`_engine.CursorResult.returned_defaults` will be ``None``
|
||||
for backends that don't support RETURNING or for which the target
|
||||
:class:`.Table` sets :paramref:`.Table.implicit_returning` to
|
||||
``False``.
|
||||
|
||||
4. An INSERT statement invoked with executemany() is supported if the
|
||||
backend database driver supports the
|
||||
:ref:`insertmanyvalues <engine_insertmanyvalues>`
|
||||
feature which is now supported by most SQLAlchemy-included backends.
|
||||
When executemany is used, the
|
||||
:attr:`_engine.CursorResult.returned_defaults_rows` and
|
||||
:attr:`_engine.CursorResult.inserted_primary_key_rows` accessors
|
||||
will return the inserted defaults and primary keys.
|
||||
|
||||
.. versionadded:: 1.4 Added
|
||||
:attr:`_engine.CursorResult.returned_defaults_rows` and
|
||||
:attr:`_engine.CursorResult.inserted_primary_key_rows` accessors.
|
||||
In version 2.0, the underlying implementation which fetches and
|
||||
populates the data for these attributes was generalized to be
|
||||
supported by most backends, whereas in 1.4 they were only
|
||||
supported by the ``psycopg2`` driver.
|
||||
|
||||
|
||||
:param cols: optional list of column key names or
|
||||
:class:`_schema.Column` that acts as a filter for those columns that
|
||||
will be fetched.
|
||||
:param supplemental_cols: optional list of RETURNING expressions,
|
||||
in the same form as one would pass to the
|
||||
:meth:`.UpdateBase.returning` method. When present, the additional
|
||||
columns will be included in the RETURNING clause, and the
|
||||
:class:`.CursorResult` object will be "rewound" when returned, so
|
||||
that methods like :meth:`.CursorResult.all` will return new rows
|
||||
mostly as though the statement used :meth:`.UpdateBase.returning`
|
||||
directly. However, unlike when using :meth:`.UpdateBase.returning`
|
||||
directly, the **order of the columns is undefined**, so can only be
|
||||
targeted using names or :attr:`.Row._mapping` keys; they cannot
|
||||
reliably be targeted positionally.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
.. seealso::
|
||||
|
||||
:meth:`.UpdateBase.returning`
|
||||
|
||||
:attr:`_engine.CursorResult.returned_defaults`
|
||||
|
||||
:attr:`_engine.CursorResult.returned_defaults_rows`
|
||||
|
||||
:attr:`_engine.CursorResult.inserted_primary_key`
|
||||
|
||||
:attr:`_engine.CursorResult.inserted_primary_key_rows`
|
||||
|
||||
"""
|
||||
|
||||
if self._return_defaults:
|
||||
# note _return_defaults_columns = () means return all columns,
|
||||
# so if we have been here before, only update collection if there
|
||||
# are columns in the collection
|
||||
if self._return_defaults_columns and cols:
|
||||
self._return_defaults_columns = tuple(
|
||||
util.OrderedSet(self._return_defaults_columns).union(
|
||||
coercions.expect(roles.ColumnsClauseRole, c)
|
||||
for c in cols
|
||||
)
|
||||
)
|
||||
else:
|
||||
# set for all columns
|
||||
self._return_defaults_columns = ()
|
||||
else:
|
||||
self._return_defaults_columns = tuple(
|
||||
coercions.expect(roles.ColumnsClauseRole, c) for c in cols
|
||||
)
|
||||
self._return_defaults = True
|
||||
|
||||
if supplemental_cols:
|
||||
# uniquifying while also maintaining order (the maintain of order
|
||||
# is for test suites but also for vertical splicing
|
||||
supplemental_col_tup = (
|
||||
coercions.expect(roles.ColumnsClauseRole, c)
|
||||
for c in supplemental_cols
|
||||
)
|
||||
|
||||
if self._supplemental_returning is None:
|
||||
self._supplemental_returning = tuple(
|
||||
util.unique_list(supplemental_col_tup)
|
||||
)
|
||||
else:
|
||||
self._supplemental_returning = tuple(
|
||||
util.unique_list(
|
||||
self._supplemental_returning
|
||||
+ tuple(supplemental_col_tup)
|
||||
)
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@_generative
|
||||
def returning(
|
||||
self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
|
||||
@@ -500,7 +726,7 @@ class UpdateBase(
|
||||
|
||||
.. seealso::
|
||||
|
||||
:meth:`.ValuesBase.return_defaults` - an alternative method tailored
|
||||
:meth:`.UpdateBase.return_defaults` - an alternative method tailored
|
||||
towards efficient fetching of server-side defaults and triggers
|
||||
for single-row INSERTs or UPDATEs.
|
||||
|
||||
@@ -703,7 +929,6 @@ class ValuesBase(UpdateBase):
|
||||
|
||||
_select_names: Optional[List[str]] = None
|
||||
_inline: bool = False
|
||||
_returning: Tuple[_ColumnsClauseElement, ...] = ()
|
||||
|
||||
def __init__(self, table: _DMLTableArgument):
|
||||
self.table = coercions.expect(
|
||||
@@ -859,7 +1084,15 @@ class ValuesBase(UpdateBase):
|
||||
)
|
||||
|
||||
elif isinstance(arg, collections_abc.Sequence):
|
||||
if arg and isinstance(arg[0], (list, dict, tuple)):
|
||||
|
||||
if arg and isinstance(arg[0], dict):
|
||||
multi_kv_generator = DMLState.get_plugin_class(
|
||||
self
|
||||
)._get_multi_crud_kv_pairs
|
||||
self._multi_values += (multi_kv_generator(self, arg),)
|
||||
return self
|
||||
|
||||
if arg and isinstance(arg[0], (list, tuple)):
|
||||
self._multi_values += (arg,)
|
||||
return self
|
||||
|
||||
@@ -888,173 +1121,13 @@ class ValuesBase(UpdateBase):
|
||||
# and ensures they get the "crud"-style name when rendered.
|
||||
|
||||
kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs
|
||||
coerced_arg = {k: v for k, v in kv_generator(self, arg.items())}
|
||||
coerced_arg = dict(kv_generator(self, arg.items(), True))
|
||||
if self._values:
|
||||
self._values = self._values.union(coerced_arg)
|
||||
else:
|
||||
self._values = util.immutabledict(coerced_arg)
|
||||
return self
|
||||
|
||||
@_generative
|
||||
def return_defaults(
|
||||
self: SelfValuesBase, *cols: _DMLColumnArgument
|
||||
) -> SelfValuesBase:
|
||||
"""Make use of a :term:`RETURNING` clause for the purpose
|
||||
of fetching server-side expressions and defaults, for supporting
|
||||
backends only.
|
||||
|
||||
.. tip::
|
||||
|
||||
The :meth:`.ValuesBase.return_defaults` method is used by the ORM
|
||||
for its internal work in fetching newly generated primary key
|
||||
and server default values, in particular to provide the underyling
|
||||
implementation of the :paramref:`_orm.Mapper.eager_defaults`
|
||||
ORM feature. Its behavior is fairly idiosyncratic
|
||||
and is not really intended for general use. End users should
|
||||
stick with using :meth:`.UpdateBase.returning` in order to
|
||||
add RETURNING clauses to their INSERT, UPDATE and DELETE
|
||||
statements.
|
||||
|
||||
Normally, a single row INSERT statement will automatically populate the
|
||||
:attr:`.CursorResult.inserted_primary_key` attribute when executed,
|
||||
which stores the primary key of the row that was just inserted in the
|
||||
form of a :class:`.Row` object with column names as named tuple keys
|
||||
(and the :attr:`.Row._mapping` view fully populated as well). The
|
||||
dialect in use chooses the strategy to use in order to populate this
|
||||
data; if it was generated using server-side defaults and / or SQL
|
||||
expressions, dialect-specific approaches such as ``cursor.lastrowid``
|
||||
or ``RETURNING`` are typically used to acquire the new primary key
|
||||
value.
|
||||
|
||||
However, when the statement is modified by calling
|
||||
:meth:`.ValuesBase.return_defaults` before executing the statement,
|
||||
additional behaviors take place **only** for backends that support
|
||||
RETURNING and for :class:`.Table` objects that maintain the
|
||||
:paramref:`.Table.implicit_returning` parameter at its default value of
|
||||
``True``. In these cases, when the :class:`.CursorResult` is returned
|
||||
from the statement's execution, not only will
|
||||
:attr:`.CursorResult.inserted_primary_key` be populated as always, the
|
||||
:attr:`.CursorResult.returned_defaults` attribute will also be
|
||||
populated with a :class:`.Row` named-tuple representing the full range
|
||||
of server generated
|
||||
values from that single row, including values for any columns that
|
||||
specify :paramref:`_schema.Column.server_default` or which make use of
|
||||
:paramref:`_schema.Column.default` using a SQL expression.
|
||||
|
||||
When invoking INSERT statements with multiple rows using
|
||||
:ref:`insertmanyvalues <engine_insertmanyvalues>`, the
|
||||
:meth:`.ValuesBase.return_defaults` modifier will have the effect of
|
||||
the :attr:`_engine.CursorResult.inserted_primary_key_rows` and
|
||||
:attr:`_engine.CursorResult.returned_defaults_rows` attributes being
|
||||
fully populated with lists of :class:`.Row` objects representing newly
|
||||
inserted primary key values as well as newly inserted server generated
|
||||
values for each row inserted. The
|
||||
:attr:`.CursorResult.inserted_primary_key` and
|
||||
:attr:`.CursorResult.returned_defaults` attributes will also continue
|
||||
to be populated with the first row of these two collections.
|
||||
|
||||
If the backend does not support RETURNING or the :class:`.Table` in use
|
||||
has disabled :paramref:`.Table.implicit_returning`, then no RETURNING
|
||||
clause is added and no additional data is fetched, however the
|
||||
INSERT or UPDATE statement proceeds normally.
|
||||
|
||||
|
||||
E.g.::
|
||||
|
||||
stmt = table.insert().values(data='newdata').return_defaults()
|
||||
|
||||
result = connection.execute(stmt)
|
||||
|
||||
server_created_at = result.returned_defaults['created_at']
|
||||
|
||||
|
||||
The :meth:`.ValuesBase.return_defaults` method is mutually exclusive
|
||||
against the :meth:`.UpdateBase.returning` method and errors will be
|
||||
raised during the SQL compilation process if both are used at the same
|
||||
time on one statement. The RETURNING clause of the INSERT or UPDATE
|
||||
statement is therefore controlled by only one of these methods at a
|
||||
time.
|
||||
|
||||
The :meth:`.ValuesBase.return_defaults` method differs from
|
||||
:meth:`.UpdateBase.returning` in these ways:
|
||||
|
||||
1. :meth:`.ValuesBase.return_defaults` method causes the
|
||||
:attr:`.CursorResult.returned_defaults` collection to be populated
|
||||
with the first row from the RETURNING result. This attribute is not
|
||||
populated when using :meth:`.UpdateBase.returning`.
|
||||
|
||||
2. :meth:`.ValuesBase.return_defaults` is compatible with existing
|
||||
logic used to fetch auto-generated primary key values that are then
|
||||
populated into the :attr:`.CursorResult.inserted_primary_key`
|
||||
attribute. By contrast, using :meth:`.UpdateBase.returning` will
|
||||
have the effect of the :attr:`.CursorResult.inserted_primary_key`
|
||||
attribute being left unpopulated.
|
||||
|
||||
3. :meth:`.ValuesBase.return_defaults` can be called against any
|
||||
backend. Backends that don't support RETURNING will skip the usage
|
||||
of the feature, rather than raising an exception. The return value
|
||||
of :attr:`_engine.CursorResult.returned_defaults` will be ``None``
|
||||
for backends that don't support RETURNING or for which the target
|
||||
:class:`.Table` sets :paramref:`.Table.implicit_returning` to
|
||||
``False``.
|
||||
|
||||
4. An INSERT statement invoked with executemany() is supported if the
|
||||
backend database driver supports the
|
||||
:ref:`insertmanyvalues <engine_insertmanyvalues>`
|
||||
feature which is now supported by most SQLAlchemy-included backends.
|
||||
When executemany is used, the
|
||||
:attr:`_engine.CursorResult.returned_defaults_rows` and
|
||||
:attr:`_engine.CursorResult.inserted_primary_key_rows` accessors
|
||||
will return the inserted defaults and primary keys.
|
||||
|
||||
.. versionadded:: 1.4 Added
|
||||
:attr:`_engine.CursorResult.returned_defaults_rows` and
|
||||
:attr:`_engine.CursorResult.inserted_primary_key_rows` accessors.
|
||||
In version 2.0, the underlying implementation which fetches and
|
||||
populates the data for these attributes was generalized to be
|
||||
supported by most backends, whereas in 1.4 they were only
|
||||
supported by the ``psycopg2`` driver.
|
||||
|
||||
|
||||
:param cols: optional list of column key names or
|
||||
:class:`_schema.Column` that acts as a filter for those columns that
|
||||
will be fetched.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:meth:`.UpdateBase.returning`
|
||||
|
||||
:attr:`_engine.CursorResult.returned_defaults`
|
||||
|
||||
:attr:`_engine.CursorResult.returned_defaults_rows`
|
||||
|
||||
:attr:`_engine.CursorResult.inserted_primary_key`
|
||||
|
||||
:attr:`_engine.CursorResult.inserted_primary_key_rows`
|
||||
|
||||
"""
|
||||
|
||||
if self._return_defaults:
|
||||
# note _return_defaults_columns = () means return all columns,
|
||||
# so if we have been here before, only update collection if there
|
||||
# are columns in the collection
|
||||
if self._return_defaults_columns and cols:
|
||||
self._return_defaults_columns = tuple(
|
||||
set(self._return_defaults_columns).union(
|
||||
coercions.expect(roles.ColumnsClauseRole, c)
|
||||
for c in cols
|
||||
)
|
||||
)
|
||||
else:
|
||||
# set for all columns
|
||||
self._return_defaults_columns = ()
|
||||
else:
|
||||
self._return_defaults_columns = tuple(
|
||||
coercions.expect(roles.ColumnsClauseRole, c) for c in cols
|
||||
)
|
||||
self._return_defaults = True
|
||||
return self
|
||||
|
||||
|
||||
SelfInsert = typing.TypeVar("SelfInsert", bound="Insert")
|
||||
|
||||
@@ -1459,7 +1532,7 @@ class Update(DMLWhereBase, ValuesBase):
|
||||
)
|
||||
|
||||
kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs
|
||||
self._ordered_values = kv_generator(self, args)
|
||||
self._ordered_values = kv_generator(self, args, True)
|
||||
return self
|
||||
|
||||
@_generative
|
||||
|
||||
@@ -68,7 +68,7 @@ class CursorSQL(SQLMatchRule):
|
||||
|
||||
class CompiledSQL(SQLMatchRule):
|
||||
def __init__(
|
||||
self, statement, params=None, dialect="default", enable_returning=False
|
||||
self, statement, params=None, dialect="default", enable_returning=True
|
||||
):
|
||||
self.statement = statement
|
||||
self.params = params
|
||||
@@ -90,6 +90,17 @@ class CompiledSQL(SQLMatchRule):
|
||||
dialect.insert_returning = (
|
||||
dialect.update_returning
|
||||
) = dialect.delete_returning = True
|
||||
dialect.use_insertmanyvalues = True
|
||||
dialect.supports_multivalues_insert = True
|
||||
dialect.update_returning_multifrom = True
|
||||
dialect.delete_returning_multifrom = True
|
||||
# dialect.favor_returning_over_lastrowid = True
|
||||
# dialect.insert_null_pk_still_autoincrements = True
|
||||
|
||||
# this is calculated but we need it to be True for this
|
||||
# to look like all the current RETURNING dialects
|
||||
assert dialect.insert_executemany_returning
|
||||
|
||||
return dialect
|
||||
else:
|
||||
return url.URL.create(self.dialect).get_dialect()()
|
||||
|
||||
@@ -23,7 +23,6 @@ from .util import adict
|
||||
from .util import drop_all_tables_from_metadata
|
||||
from .. import event
|
||||
from .. import util
|
||||
from ..orm import declarative_base
|
||||
from ..orm import DeclarativeBase
|
||||
from ..orm import MappedAsDataclass
|
||||
from ..orm import registry
|
||||
@@ -117,7 +116,7 @@ class TestBase:
|
||||
metadata=metadata,
|
||||
type_annotation_map={
|
||||
str: sa.String().with_variant(
|
||||
sa.String(50), "mysql", "mariadb"
|
||||
sa.String(50), "mysql", "mariadb", "oracle"
|
||||
)
|
||||
},
|
||||
)
|
||||
@@ -132,7 +131,7 @@ class TestBase:
|
||||
metadata = _md
|
||||
type_annotation_map = {
|
||||
str: sa.String().with_variant(
|
||||
sa.String(50), "mysql", "mariadb"
|
||||
sa.String(50), "mysql", "mariadb", "oracle"
|
||||
)
|
||||
}
|
||||
|
||||
@@ -780,18 +779,19 @@ class DeclarativeMappedTest(MappedTest):
|
||||
def _with_register_classes(cls, fn):
|
||||
cls_registry = cls.classes
|
||||
|
||||
class DeclarativeBasic:
|
||||
class _DeclBase(DeclarativeBase):
|
||||
__table_cls__ = schema.Table
|
||||
metadata = cls._tables_metadata
|
||||
type_annotation_map = {
|
||||
str: sa.String().with_variant(
|
||||
sa.String(50), "mysql", "mariadb", "oracle"
|
||||
)
|
||||
}
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
def __init_subclass__(cls, **kw) -> None:
|
||||
assert cls_registry is not None
|
||||
cls_registry[cls.__name__] = cls
|
||||
super().__init_subclass__()
|
||||
|
||||
_DeclBase = declarative_base(
|
||||
metadata=cls._tables_metadata,
|
||||
cls=DeclarativeBasic,
|
||||
)
|
||||
super().__init_subclass__(**kw)
|
||||
|
||||
cls.DeclarativeBasic = _DeclBase
|
||||
|
||||
|
||||
@@ -89,8 +89,13 @@ class RowCountTest(fixtures.TablesTest):
|
||||
eq_(r.rowcount, 3)
|
||||
|
||||
@testing.requires.update_returning
|
||||
@testing.requires.sane_rowcount_w_returning
|
||||
def test_update_rowcount_return_defaults(self, connection):
|
||||
"""note this test should succeed for all RETURNING backends
|
||||
as of 2.0. In
|
||||
Idf28379f8705e403a3c6a937f6a798a042ef2540 we changed rowcount to use
|
||||
len(rows) when we have implicit returning
|
||||
|
||||
"""
|
||||
employees_table = self.tables.employees
|
||||
|
||||
department = employees_table.c.department
|
||||
|
||||
@@ -11,6 +11,7 @@ from __future__ import annotations
|
||||
from itertools import filterfalse
|
||||
from typing import AbstractSet
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import cast
|
||||
from typing import Collection
|
||||
from typing import Dict
|
||||
@@ -481,7 +482,9 @@ class IdentitySet:
|
||||
return "%s(%r)" % (type(self).__name__, list(self._members.values()))
|
||||
|
||||
|
||||
def unique_list(seq, hashfunc=None):
|
||||
def unique_list(
|
||||
seq: Iterable[_T], hashfunc: Optional[Callable[[_T], int]] = None
|
||||
) -> List[_T]:
|
||||
seen: Set[Any] = set()
|
||||
seen_add = seen.add
|
||||
if not hashfunc:
|
||||
|
||||
@@ -465,7 +465,11 @@ class ShardTest:
|
||||
t = get_tokyo(sess2)
|
||||
eq_(t.city, tokyo.city)
|
||||
|
||||
def test_bulk_update_synchronize_evaluate(self):
|
||||
@testing.combinations(
|
||||
"fetch", "evaluate", "auto", argnames="synchronize_session"
|
||||
)
|
||||
@testing.combinations(True, False, argnames="legacy")
|
||||
def test_orm_update_synchronize(self, synchronize_session, legacy):
|
||||
sess = self._fixture_data()
|
||||
|
||||
eq_(
|
||||
@@ -476,199 +480,67 @@ class ShardTest:
|
||||
temps = sess.query(Report).all()
|
||||
eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
|
||||
|
||||
sess.query(Report).filter(Report.temperature >= 80).update(
|
||||
{"temperature": Report.temperature + 6},
|
||||
synchronize_session="evaluate",
|
||||
)
|
||||
|
||||
eq_(
|
||||
set(row.temperature for row in sess.query(Report.temperature)),
|
||||
{86.0, 75.0, 91.0},
|
||||
)
|
||||
|
||||
# test synchronize session as well
|
||||
eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
|
||||
|
||||
def test_bulk_update_synchronize_fetch(self):
|
||||
sess = self._fixture_data()
|
||||
|
||||
eq_(
|
||||
set(row.temperature for row in sess.query(Report.temperature)),
|
||||
{80.0, 75.0, 85.0},
|
||||
)
|
||||
|
||||
temps = sess.query(Report).all()
|
||||
eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
|
||||
|
||||
sess.query(Report).filter(Report.temperature >= 80).update(
|
||||
{"temperature": Report.temperature + 6},
|
||||
synchronize_session="fetch",
|
||||
)
|
||||
|
||||
eq_(
|
||||
set(row.temperature for row in sess.query(Report.temperature)),
|
||||
{86.0, 75.0, 91.0},
|
||||
)
|
||||
|
||||
# test synchronize session as well
|
||||
eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
|
||||
|
||||
def test_bulk_delete_synchronize_evaluate(self):
|
||||
sess = self._fixture_data()
|
||||
|
||||
temps = sess.query(Report).all()
|
||||
eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
|
||||
|
||||
sess.query(Report).filter(Report.temperature >= 80).delete(
|
||||
synchronize_session="evaluate"
|
||||
)
|
||||
|
||||
eq_(
|
||||
set(row.temperature for row in sess.query(Report.temperature)),
|
||||
{75.0},
|
||||
)
|
||||
|
||||
# test synchronize session as well
|
||||
for t in temps:
|
||||
assert inspect(t).deleted is (t.temperature >= 80)
|
||||
|
||||
def test_bulk_delete_synchronize_fetch(self):
|
||||
sess = self._fixture_data()
|
||||
|
||||
temps = sess.query(Report).all()
|
||||
eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
|
||||
|
||||
sess.query(Report).filter(Report.temperature >= 80).delete(
|
||||
synchronize_session="fetch"
|
||||
)
|
||||
|
||||
eq_(
|
||||
set(row.temperature for row in sess.query(Report.temperature)),
|
||||
{75.0},
|
||||
)
|
||||
|
||||
# test synchronize session as well
|
||||
for t in temps:
|
||||
assert inspect(t).deleted is (t.temperature >= 80)
|
||||
|
||||
def test_bulk_update_future_synchronize_evaluate(self):
|
||||
sess = self._fixture_data()
|
||||
|
||||
eq_(
|
||||
set(
|
||||
row.temperature
|
||||
for row in sess.execute(select(Report.temperature))
|
||||
),
|
||||
{80.0, 75.0, 85.0},
|
||||
)
|
||||
|
||||
temps = sess.execute(select(Report)).scalars().all()
|
||||
eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
|
||||
|
||||
sess.execute(
|
||||
update(Report)
|
||||
.filter(Report.temperature >= 80)
|
||||
.values(
|
||||
if legacy:
|
||||
sess.query(Report).filter(Report.temperature >= 80).update(
|
||||
{"temperature": Report.temperature + 6},
|
||||
synchronize_session=synchronize_session,
|
||||
)
|
||||
.execution_options(synchronize_session="evaluate")
|
||||
else:
|
||||
sess.execute(
|
||||
update(Report)
|
||||
.filter(Report.temperature >= 80)
|
||||
.values(temperature=Report.temperature + 6)
|
||||
.execution_options(synchronize_session=synchronize_session)
|
||||
)
|
||||
|
||||
# test synchronize session
|
||||
def go():
|
||||
eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
|
||||
|
||||
self.assert_sql_count(
|
||||
sess._ShardedSession__binds["north_america"], go, 0
|
||||
)
|
||||
|
||||
eq_(
|
||||
set(
|
||||
row.temperature
|
||||
for row in sess.execute(select(Report.temperature))
|
||||
),
|
||||
set(row.temperature for row in sess.query(Report.temperature)),
|
||||
{86.0, 75.0, 91.0},
|
||||
)
|
||||
|
||||
# test synchronize session as well
|
||||
eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
|
||||
|
||||
def test_bulk_update_future_synchronize_fetch(self):
|
||||
@testing.combinations(
|
||||
"fetch", "evaluate", "auto", argnames="synchronize_session"
|
||||
)
|
||||
@testing.combinations(True, False, argnames="legacy")
|
||||
def test_orm_delete_synchronize(self, synchronize_session, legacy):
|
||||
sess = self._fixture_data()
|
||||
|
||||
eq_(
|
||||
set(
|
||||
row.temperature
|
||||
for row in sess.execute(select(Report.temperature))
|
||||
),
|
||||
{80.0, 75.0, 85.0},
|
||||
)
|
||||
|
||||
temps = sess.execute(select(Report)).scalars().all()
|
||||
temps = sess.query(Report).all()
|
||||
eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
|
||||
|
||||
# MARKMARK
|
||||
# omitting the criteria so that the UPDATE affects three out of
|
||||
# four shards
|
||||
sess.execute(
|
||||
update(Report)
|
||||
.values(
|
||||
{"temperature": Report.temperature + 6},
|
||||
if legacy:
|
||||
sess.query(Report).filter(Report.temperature >= 80).delete(
|
||||
synchronize_session=synchronize_session
|
||||
)
|
||||
.execution_options(synchronize_session="fetch")
|
||||
else:
|
||||
sess.execute(
|
||||
delete(Report)
|
||||
.filter(Report.temperature >= 80)
|
||||
.execution_options(synchronize_session=synchronize_session)
|
||||
)
|
||||
|
||||
def go():
|
||||
# test synchronize session
|
||||
for t in temps:
|
||||
assert inspect(t).deleted is (t.temperature >= 80)
|
||||
|
||||
self.assert_sql_count(
|
||||
sess._ShardedSession__binds["north_america"], go, 0
|
||||
)
|
||||
|
||||
eq_(
|
||||
set(
|
||||
row.temperature
|
||||
for row in sess.execute(select(Report.temperature))
|
||||
),
|
||||
{86.0, 81.0, 91.0},
|
||||
)
|
||||
|
||||
# test synchronize session as well
|
||||
eq_(set(t.temperature for t in temps), {86.0, 81.0, 91.0})
|
||||
|
||||
def test_bulk_delete_future_synchronize_evaluate(self):
|
||||
sess = self._fixture_data()
|
||||
|
||||
temps = sess.execute(select(Report)).scalars().all()
|
||||
eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
|
||||
|
||||
sess.execute(
|
||||
delete(Report)
|
||||
.filter(Report.temperature >= 80)
|
||||
.execution_options(synchronize_session="evaluate")
|
||||
)
|
||||
|
||||
eq_(
|
||||
set(
|
||||
row.temperature
|
||||
for row in sess.execute(select(Report.temperature))
|
||||
),
|
||||
set(row.temperature for row in sess.query(Report.temperature)),
|
||||
{75.0},
|
||||
)
|
||||
|
||||
# test synchronize session as well
|
||||
for t in temps:
|
||||
assert inspect(t).deleted is (t.temperature >= 80)
|
||||
|
||||
def test_bulk_delete_future_synchronize_fetch(self):
|
||||
sess = self._fixture_data()
|
||||
|
||||
temps = sess.execute(select(Report)).scalars().all()
|
||||
eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
|
||||
|
||||
sess.execute(
|
||||
delete(Report)
|
||||
.filter(Report.temperature >= 80)
|
||||
.execution_options(synchronize_session="fetch")
|
||||
)
|
||||
|
||||
eq_(
|
||||
set(
|
||||
row.temperature
|
||||
for row in sess.execute(select(Report.temperature))
|
||||
),
|
||||
{75.0},
|
||||
)
|
||||
|
||||
# test synchronize session as well
|
||||
for t in temps:
|
||||
assert inspect(t).deleted is (t.temperature >= 80)
|
||||
|
||||
|
||||
class DistinctEngineShardTest(ShardTest, fixtures.MappedTest):
|
||||
def _init_dbs(self):
|
||||
|
||||
+32
-3
@@ -3,6 +3,7 @@ from decimal import Decimal
|
||||
from sqlalchemy import exc
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL
|
||||
@@ -1017,15 +1018,43 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL):
|
||||
params={"first_name": "Dr."},
|
||||
)
|
||||
|
||||
def test_update_expr(self):
|
||||
@testing.combinations("attr", "str", "kwarg", argnames="keytype")
|
||||
def test_update_expr(self, keytype):
|
||||
Person = self.classes.Person
|
||||
|
||||
statement = update(Person).values({Person.name: "Dr. No"})
|
||||
if keytype == "attr":
|
||||
statement = update(Person).values({Person.name: "Dr. No"})
|
||||
elif keytype == "str":
|
||||
statement = update(Person).values({"name": "Dr. No"})
|
||||
elif keytype == "kwarg":
|
||||
statement = update(Person).values(name="Dr. No")
|
||||
else:
|
||||
assert False
|
||||
|
||||
self.assert_compile(
|
||||
statement,
|
||||
"UPDATE person SET first_name=:first_name, last_name=:last_name",
|
||||
params={"first_name": "Dr.", "last_name": "No"},
|
||||
checkparams={"first_name": "Dr.", "last_name": "No"},
|
||||
)
|
||||
|
||||
@testing.combinations("attr", "str", "kwarg", argnames="keytype")
|
||||
def test_insert_expr(self, keytype):
|
||||
Person = self.classes.Person
|
||||
|
||||
if keytype == "attr":
|
||||
statement = insert(Person).values({Person.name: "Dr. No"})
|
||||
elif keytype == "str":
|
||||
statement = insert(Person).values({"name": "Dr. No"})
|
||||
elif keytype == "kwarg":
|
||||
statement = insert(Person).values(name="Dr. No")
|
||||
else:
|
||||
assert False
|
||||
|
||||
self.assert_compile(
|
||||
statement,
|
||||
"INSERT INTO person (first_name, last_name) VALUES "
|
||||
"(:first_name, :last_name)",
|
||||
checkparams={"first_name": "Dr.", "last_name": "No"},
|
||||
)
|
||||
|
||||
# these tests all run two UPDATES to assert that caching is not
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from sqlalchemy import FetchedValue
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import Identity
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import testing
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.testing import eq_
|
||||
from sqlalchemy.testing import fixtures
|
||||
from sqlalchemy.testing import mock
|
||||
@@ -20,6 +23,8 @@ class BulkTest(testing.AssertsExecutionResults):
|
||||
|
||||
|
||||
class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest):
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
@@ -73,6 +78,8 @@ class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest):
|
||||
|
||||
|
||||
class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def setup_mappers(cls):
|
||||
User, Address, Order = cls.classes("User", "Address", "Order")
|
||||
@@ -82,22 +89,42 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
|
||||
cls.mapper_registry.map_imperatively(Address, a)
|
||||
cls.mapper_registry.map_imperatively(Order, o)
|
||||
|
||||
def test_bulk_save_return_defaults(self):
|
||||
@testing.combinations("save_objects", "insert_mappings", "insert_stmt")
|
||||
def test_bulk_save_return_defaults(self, statement_type):
|
||||
(User,) = self.classes("User")
|
||||
|
||||
s = fixture_session()
|
||||
objects = [User(name="u1"), User(name="u2"), User(name="u3")]
|
||||
assert "id" not in objects[0].__dict__
|
||||
|
||||
with self.sql_execution_asserter() as asserter:
|
||||
s.bulk_save_objects(objects, return_defaults=True)
|
||||
if statement_type == "save_objects":
|
||||
objects = [User(name="u1"), User(name="u2"), User(name="u3")]
|
||||
assert "id" not in objects[0].__dict__
|
||||
|
||||
returning_users_id = " RETURNING users.id"
|
||||
with self.sql_execution_asserter() as asserter:
|
||||
s.bulk_save_objects(objects, return_defaults=True)
|
||||
elif statement_type == "insert_mappings":
|
||||
data = [dict(name="u1"), dict(name="u2"), dict(name="u3")]
|
||||
returning_users_id = " RETURNING users.id"
|
||||
with self.sql_execution_asserter() as asserter:
|
||||
s.bulk_insert_mappings(User, data, return_defaults=True)
|
||||
elif statement_type == "insert_stmt":
|
||||
data = [dict(name="u1"), dict(name="u2"), dict(name="u3")]
|
||||
|
||||
# for statement, "return_defaults" is heuristic on if we are
|
||||
# a joined inh mapping if we don't otherwise include
|
||||
# .returning() on the statement itself
|
||||
returning_users_id = ""
|
||||
with self.sql_execution_asserter() as asserter:
|
||||
s.execute(insert(User), data)
|
||||
|
||||
asserter.assert_(
|
||||
Conditional(
|
||||
testing.db.dialect.insert_executemany_returning,
|
||||
testing.db.dialect.insert_executemany_returning
|
||||
or statement_type == "insert_stmt",
|
||||
[
|
||||
CompiledSQL(
|
||||
"INSERT INTO users (name) VALUES (:name)",
|
||||
"INSERT INTO users (name) "
|
||||
f"VALUES (:name){returning_users_id}",
|
||||
[{"name": "u1"}, {"name": "u2"}, {"name": "u3"}],
|
||||
),
|
||||
],
|
||||
@@ -117,7 +144,8 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
|
||||
],
|
||||
)
|
||||
)
|
||||
eq_(objects[0].__dict__["id"], 1)
|
||||
if statement_type == "save_objects":
|
||||
eq_(objects[0].__dict__["id"], 1)
|
||||
|
||||
def test_bulk_save_mappings_preserve_order(self):
|
||||
(User,) = self.classes("User")
|
||||
@@ -219,8 +247,9 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
|
||||
)
|
||||
)
|
||||
|
||||
def test_bulk_update(self):
|
||||
(User,) = self.classes("User")
|
||||
@testing.combinations("update_mappings", "update_stmt")
|
||||
def test_bulk_update(self, statement_type):
|
||||
User = self.classes.User
|
||||
|
||||
s = fixture_session(expire_on_commit=False)
|
||||
objects = [User(name="u1"), User(name="u2"), User(name="u3")]
|
||||
@@ -228,15 +257,18 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
|
||||
s.commit()
|
||||
|
||||
s = fixture_session()
|
||||
with self.sql_execution_asserter() as asserter:
|
||||
s.bulk_update_mappings(
|
||||
User,
|
||||
[
|
||||
{"id": 1, "name": "u1new"},
|
||||
{"id": 2, "name": "u2"},
|
||||
{"id": 3, "name": "u3new"},
|
||||
],
|
||||
)
|
||||
data = [
|
||||
{"id": 1, "name": "u1new"},
|
||||
{"id": 2, "name": "u2"},
|
||||
{"id": 3, "name": "u3new"},
|
||||
]
|
||||
|
||||
if statement_type == "update_mappings":
|
||||
with self.sql_execution_asserter() as asserter:
|
||||
s.bulk_update_mappings(User, data)
|
||||
elif statement_type == "update_stmt":
|
||||
with self.sql_execution_asserter() as asserter:
|
||||
s.execute(update(User), data)
|
||||
|
||||
asserter.assert_(
|
||||
CompiledSQL(
|
||||
@@ -303,6 +335,8 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
|
||||
|
||||
|
||||
class BulkUDPostfetchTest(BulkTest, fixtures.MappedTest):
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
@@ -360,6 +394,8 @@ class BulkUDPostfetchTest(BulkTest, fixtures.MappedTest):
|
||||
|
||||
|
||||
class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest):
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
@@ -547,6 +583,8 @@ class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest):
|
||||
|
||||
|
||||
class BulkInheritanceTest(BulkTest, fixtures.MappedTest):
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
@@ -643,6 +681,7 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest):
|
||||
)
|
||||
|
||||
s = fixture_session()
|
||||
|
||||
objects = [
|
||||
Manager(name="m1", status="s1", manager_name="mn1"),
|
||||
Engineer(name="e1", status="s2", primary_language="l1"),
|
||||
@@ -669,7 +708,7 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest):
|
||||
[
|
||||
CompiledSQL(
|
||||
"INSERT INTO people (name, type) "
|
||||
"VALUES (:name, :type)",
|
||||
"VALUES (:name, :type) RETURNING people.person_id",
|
||||
[
|
||||
{"type": "engineer", "name": "e1"},
|
||||
{"type": "engineer", "name": "e2"},
|
||||
@@ -798,59 +837,74 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest):
|
||||
),
|
||||
)
|
||||
|
||||
def test_bulk_insert_joined_inh_return_defaults(self):
|
||||
@testing.combinations("insert_mappings", "insert_stmt")
|
||||
def test_bulk_insert_joined_inh_return_defaults(self, statement_type):
|
||||
Person, Engineer, Manager, Boss = self.classes(
|
||||
"Person", "Engineer", "Manager", "Boss"
|
||||
)
|
||||
|
||||
s = fixture_session()
|
||||
with self.sql_execution_asserter() as asserter:
|
||||
s.bulk_insert_mappings(
|
||||
Boss,
|
||||
[
|
||||
dict(
|
||||
name="b1",
|
||||
status="s1",
|
||||
manager_name="mn1",
|
||||
golf_swing="g1",
|
||||
),
|
||||
dict(
|
||||
name="b2",
|
||||
status="s2",
|
||||
manager_name="mn2",
|
||||
golf_swing="g2",
|
||||
),
|
||||
dict(
|
||||
name="b3",
|
||||
status="s3",
|
||||
manager_name="mn3",
|
||||
golf_swing="g3",
|
||||
),
|
||||
],
|
||||
return_defaults=True,
|
||||
)
|
||||
data = [
|
||||
dict(
|
||||
name="b1",
|
||||
status="s1",
|
||||
manager_name="mn1",
|
||||
golf_swing="g1",
|
||||
),
|
||||
dict(
|
||||
name="b2",
|
||||
status="s2",
|
||||
manager_name="mn2",
|
||||
golf_swing="g2",
|
||||
),
|
||||
dict(
|
||||
name="b3",
|
||||
status="s3",
|
||||
manager_name="mn3",
|
||||
golf_swing="g3",
|
||||
),
|
||||
]
|
||||
|
||||
if statement_type == "insert_mappings":
|
||||
with self.sql_execution_asserter() as asserter:
|
||||
s.bulk_insert_mappings(
|
||||
Boss,
|
||||
data,
|
||||
return_defaults=True,
|
||||
)
|
||||
elif statement_type == "insert_stmt":
|
||||
with self.sql_execution_asserter() as asserter:
|
||||
s.execute(insert(Boss), data)
|
||||
|
||||
asserter.assert_(
|
||||
Conditional(
|
||||
testing.db.dialect.insert_executemany_returning,
|
||||
[
|
||||
CompiledSQL(
|
||||
"INSERT INTO people (name) VALUES (:name)",
|
||||
[{"name": "b1"}, {"name": "b2"}, {"name": "b3"}],
|
||||
"INSERT INTO people (name, type) "
|
||||
"VALUES (:name, :type) RETURNING people.person_id",
|
||||
[
|
||||
{"name": "b1", "type": "boss"},
|
||||
{"name": "b2", "type": "boss"},
|
||||
{"name": "b3", "type": "boss"},
|
||||
],
|
||||
),
|
||||
],
|
||||
[
|
||||
CompiledSQL(
|
||||
"INSERT INTO people (name) VALUES (:name)",
|
||||
[{"name": "b1"}],
|
||||
"INSERT INTO people (name, type) "
|
||||
"VALUES (:name, :type)",
|
||||
[{"name": "b1", "type": "boss"}],
|
||||
),
|
||||
CompiledSQL(
|
||||
"INSERT INTO people (name) VALUES (:name)",
|
||||
[{"name": "b2"}],
|
||||
"INSERT INTO people (name, type) "
|
||||
"VALUES (:name, :type)",
|
||||
[{"name": "b2", "type": "boss"}],
|
||||
),
|
||||
CompiledSQL(
|
||||
"INSERT INTO people (name) VALUES (:name)",
|
||||
[{"name": "b3"}],
|
||||
"INSERT INTO people (name, type) "
|
||||
"VALUES (:name, :type)",
|
||||
[{"name": "b3", "type": "boss"}],
|
||||
),
|
||||
],
|
||||
),
|
||||
@@ -874,15 +928,79 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest):
|
||||
),
|
||||
)
|
||||
|
||||
@testing.combinations("update_mappings", "update_stmt")
|
||||
def test_bulk_update(self, statement_type):
|
||||
Person, Engineer, Manager, Boss = self.classes(
|
||||
"Person", "Engineer", "Manager", "Boss"
|
||||
)
|
||||
|
||||
s = fixture_session()
|
||||
|
||||
b1, b2, b3 = (
|
||||
Boss(name="b1", status="s1", manager_name="mn1", golf_swing="g1"),
|
||||
Boss(name="b2", status="s2", manager_name="mn2", golf_swing="g2"),
|
||||
Boss(name="b3", status="s3", manager_name="mn3", golf_swing="g3"),
|
||||
)
|
||||
s.add_all([b1, b2, b3])
|
||||
s.commit()
|
||||
|
||||
# slight non-convenient thing. we have to fill in boss_id here
|
||||
# for update, this is not sent along automatically. this is not a
|
||||
# new behavior in bulk
|
||||
new_data = [
|
||||
{
|
||||
"person_id": b1.person_id,
|
||||
"boss_id": b1.boss_id,
|
||||
"name": "b1_updated",
|
||||
"manager_name": "mn1_updated",
|
||||
},
|
||||
{
|
||||
"person_id": b3.person_id,
|
||||
"boss_id": b3.boss_id,
|
||||
"manager_name": "mn2_updated",
|
||||
"golf_swing": "g1_updated",
|
||||
},
|
||||
]
|
||||
|
||||
if statement_type == "update_mappings":
|
||||
with self.sql_execution_asserter() as asserter:
|
||||
s.bulk_update_mappings(Boss, new_data)
|
||||
elif statement_type == "update_stmt":
|
||||
with self.sql_execution_asserter() as asserter:
|
||||
s.execute(update(Boss), new_data)
|
||||
|
||||
asserter.assert_(
|
||||
CompiledSQL(
|
||||
"UPDATE people SET name=:name WHERE "
|
||||
"people.person_id = :people_person_id",
|
||||
[{"name": "b1_updated", "people_person_id": 1}],
|
||||
),
|
||||
CompiledSQL(
|
||||
"UPDATE managers SET manager_name=:manager_name WHERE "
|
||||
"managers.person_id = :managers_person_id",
|
||||
[
|
||||
{"manager_name": "mn1_updated", "managers_person_id": 1},
|
||||
{"manager_name": "mn2_updated", "managers_person_id": 3},
|
||||
],
|
||||
),
|
||||
CompiledSQL(
|
||||
"UPDATE boss SET golf_swing=:golf_swing WHERE "
|
||||
"boss.boss_id = :boss_boss_id",
|
||||
[{"golf_swing": "g1_updated", "boss_boss_id": 3}],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class BulkIssue6793Test(BulkTest, fixtures.DeclarativeMappedTest):
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def setup_classes(cls):
|
||||
Base = cls.DeclarativeBasic
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
id = Column(Integer, primary_key=True)
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
name = Column(String(255), nullable=False)
|
||||
|
||||
def test_issue_6793(self):
|
||||
@@ -907,7 +1025,8 @@ class BulkIssue6793Test(BulkTest, fixtures.DeclarativeMappedTest):
|
||||
[{"name": "A"}, {"name": "B"}],
|
||||
),
|
||||
CompiledSQL(
|
||||
"INSERT INTO users (name) VALUES (:name)",
|
||||
"INSERT INTO users (name) VALUES (:name) "
|
||||
"RETURNING users.id",
|
||||
[{"name": "C"}, {"name": "D"}],
|
||||
),
|
||||
],
|
||||
File diff suppressed because it is too large
Load Diff
@@ -324,7 +324,6 @@ class EvaluateTest(fixtures.MappedTest):
|
||||
"""test #3162"""
|
||||
|
||||
User = self.classes.User
|
||||
|
||||
with expect_raises_message(
|
||||
evaluator.UnevaluatableError,
|
||||
r"Custom operator '\^\^' can't be evaluated in "
|
||||
@@ -1,3 +1,4 @@
|
||||
from sqlalchemy import bindparam
|
||||
from sqlalchemy import Boolean
|
||||
from sqlalchemy import case
|
||||
from sqlalchemy import column
|
||||
@@ -7,6 +8,7 @@ from sqlalchemy import exc
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import lambda_stmt
|
||||
from sqlalchemy import MetaData
|
||||
@@ -17,6 +19,7 @@ from sqlalchemy import testing
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import backref
|
||||
from sqlalchemy.orm import exc as orm_exc
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -26,6 +29,7 @@ from sqlalchemy.orm import with_loader_criteria
|
||||
from sqlalchemy.testing import assert_raises
|
||||
from sqlalchemy.testing import assert_raises_message
|
||||
from sqlalchemy.testing import eq_
|
||||
from sqlalchemy.testing import expect_raises
|
||||
from sqlalchemy.testing import fixtures
|
||||
from sqlalchemy.testing import in_
|
||||
from sqlalchemy.testing import not_in
|
||||
@@ -123,6 +127,25 @@ class UpdateDeleteTest(fixtures.MappedTest):
|
||||
},
|
||||
)
|
||||
|
||||
def test_update_dont_use_col_key(self):
|
||||
User = self.classes.User
|
||||
|
||||
s = fixture_session()
|
||||
|
||||
# make sure objects are present to synchronize
|
||||
_ = s.query(User).all()
|
||||
|
||||
with expect_raises_message(
|
||||
exc.InvalidRequestError,
|
||||
"Attribute name not found, can't be synchronized back "
|
||||
"to objects: 'age_int'",
|
||||
):
|
||||
s.execute(update(User).values(age_int=5))
|
||||
|
||||
stmt = update(User).values(age=5)
|
||||
s.execute(stmt)
|
||||
eq_(s.scalars(select(User.age)).all(), [5, 5, 5, 5])
|
||||
|
||||
@testing.combinations("table", "mapper", "both", argnames="bind_type")
|
||||
@testing.combinations(
|
||||
"update", "insert", "delete", argnames="statement_type"
|
||||
@@ -162,7 +185,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
|
||||
assert_raises_message(
|
||||
exc.ArgumentError,
|
||||
"Valid strategies for session synchronization "
|
||||
"are 'evaluate', 'fetch', False",
|
||||
"are 'auto', 'evaluate', 'fetch', False",
|
||||
s.query(User).update,
|
||||
{},
|
||||
synchronize_session="fake",
|
||||
@@ -351,6 +374,12 @@ class UpdateDeleteTest(fixtures.MappedTest):
|
||||
def test_evaluate_dont_refresh_expired_objects(
|
||||
self, expire_jane_age, add_filter_criteria
|
||||
):
|
||||
"""test #5664.
|
||||
|
||||
approach is revised in SQLAlchemy 2.0 to not pre-emptively
|
||||
unexpire the involved attributes
|
||||
|
||||
"""
|
||||
User = self.classes.User
|
||||
|
||||
sess = fixture_session()
|
||||
@@ -379,15 +408,10 @@ class UpdateDeleteTest(fixtures.MappedTest):
|
||||
if add_filter_criteria:
|
||||
if expire_jane_age:
|
||||
asserter.assert_(
|
||||
# it has to unexpire jane.name, because jane is not fully
|
||||
# expired and the criteria needs to look at this particular
|
||||
# key
|
||||
CompiledSQL(
|
||||
"SELECT users.age_int AS users_age_int, "
|
||||
"users.name AS users_name FROM users "
|
||||
"WHERE users.id = :pk_1",
|
||||
[{"pk_1": 4}],
|
||||
),
|
||||
# previously, this would unexpire the attribute and
|
||||
# cause an additional SELECT. The
|
||||
# 2.0 approach is that if the object has expired attrs
|
||||
# we just expire the whole thing, avoiding SQL up front
|
||||
CompiledSQL(
|
||||
"UPDATE users "
|
||||
"SET age_int=(users.age_int + :age_int_1) "
|
||||
@@ -397,14 +421,10 @@ class UpdateDeleteTest(fixtures.MappedTest):
|
||||
)
|
||||
else:
|
||||
asserter.assert_(
|
||||
# it has to unexpire jane.name, because jane is not fully
|
||||
# expired and the criteria needs to look at this particular
|
||||
# key
|
||||
CompiledSQL(
|
||||
"SELECT users.name AS users_name FROM users "
|
||||
"WHERE users.id = :pk_1",
|
||||
[{"pk_1": 4}],
|
||||
),
|
||||
# previously, this would unexpire the attribute and
|
||||
# cause an additional SELECT. The
|
||||
# 2.0 approach is that if the object has expired attrs
|
||||
# we just expire the whole thing, avoiding SQL up front
|
||||
CompiledSQL(
|
||||
"UPDATE users SET "
|
||||
"age_int=(users.age_int + :age_int_1) "
|
||||
@@ -443,9 +463,9 @@ class UpdateDeleteTest(fixtures.MappedTest):
|
||||
),
|
||||
]
|
||||
|
||||
if expire_jane_age and not add_filter_criteria:
|
||||
if expire_jane_age:
|
||||
to_assert.append(
|
||||
# refresh jane
|
||||
# refresh jane for partial attributes
|
||||
CompiledSQL(
|
||||
"SELECT users.age_int AS users_age_int, "
|
||||
"users.name AS users_name FROM users "
|
||||
@@ -455,6 +475,75 @@ class UpdateDeleteTest(fixtures.MappedTest):
|
||||
)
|
||||
asserter.assert_(*to_assert)
|
||||
|
||||
@testing.combinations(True, False, argnames="is_evaluable")
|
||||
def test_auto_synchronize(self, is_evaluable):
|
||||
User = self.classes.User
|
||||
|
||||
sess = fixture_session()
|
||||
|
||||
john, jack, jill, jane = sess.query(User).order_by(User.id).all()
|
||||
|
||||
if is_evaluable:
|
||||
crit = or_(User.name == "jack", User.name == "jane")
|
||||
else:
|
||||
crit = case((User.name.in_(["jack", "jane"]), True), else_=False)
|
||||
|
||||
with self.sql_execution_asserter() as asserter:
|
||||
sess.execute(update(User).where(crit).values(age=User.age + 10))
|
||||
|
||||
if is_evaluable:
|
||||
asserter.assert_(
|
||||
CompiledSQL(
|
||||
"UPDATE users SET age_int=(users.age_int + :age_int_1) "
|
||||
"WHERE users.name = :name_1 OR users.name = :name_2",
|
||||
[{"age_int_1": 10, "name_1": "jack", "name_2": "jane"}],
|
||||
),
|
||||
)
|
||||
elif testing.db.dialect.update_returning:
|
||||
asserter.assert_(
|
||||
CompiledSQL(
|
||||
"UPDATE users SET age_int=(users.age_int + :age_int_1) "
|
||||
"WHERE CASE WHEN (users.name IN (__[POSTCOMPILE_name_1])) "
|
||||
"THEN :param_1 ELSE :param_2 END = 1 RETURNING users.id",
|
||||
[
|
||||
{
|
||||
"age_int_1": 10,
|
||||
"name_1": ["jack", "jane"],
|
||||
"param_1": True,
|
||||
"param_2": False,
|
||||
}
|
||||
],
|
||||
),
|
||||
)
|
||||
else:
|
||||
asserter.assert_(
|
||||
CompiledSQL(
|
||||
"SELECT users.id FROM users WHERE CASE WHEN "
|
||||
"(users.name IN (__[POSTCOMPILE_name_1])) "
|
||||
"THEN :param_1 ELSE :param_2 END = 1",
|
||||
[
|
||||
{
|
||||
"name_1": ["jack", "jane"],
|
||||
"param_1": True,
|
||||
"param_2": False,
|
||||
}
|
||||
],
|
||||
),
|
||||
CompiledSQL(
|
||||
"UPDATE users SET age_int=(users.age_int + :age_int_1) "
|
||||
"WHERE CASE WHEN (users.name IN (__[POSTCOMPILE_name_1])) "
|
||||
"THEN :param_1 ELSE :param_2 END = 1",
|
||||
[
|
||||
{
|
||||
"age_int_1": 10,
|
||||
"name_1": ["jack", "jane"],
|
||||
"param_1": True,
|
||||
"param_2": False,
|
||||
}
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
def test_fetch_dont_refresh_expired_objects(self):
|
||||
User = self.classes.User
|
||||
|
||||
@@ -518,17 +607,25 @@ class UpdateDeleteTest(fixtures.MappedTest):
|
||||
),
|
||||
)
|
||||
|
||||
def test_delete(self):
|
||||
@testing.combinations(False, None, "auto", "evaluate", "fetch")
|
||||
def test_delete(self, synchronize_session):
|
||||
User = self.classes.User
|
||||
|
||||
sess = fixture_session()
|
||||
|
||||
john, jack, jill, jane = sess.query(User).order_by(User.id).all()
|
||||
sess.query(User).filter(
|
||||
or_(User.name == "john", User.name == "jill")
|
||||
).delete()
|
||||
|
||||
assert john not in sess and jill not in sess
|
||||
stmt = delete(User).filter(
|
||||
or_(User.name == "john", User.name == "jill")
|
||||
)
|
||||
if synchronize_session is not None:
|
||||
stmt = stmt.execution_options(
|
||||
synchronize_session=synchronize_session
|
||||
)
|
||||
sess.execute(stmt)
|
||||
|
||||
if synchronize_session not in (False, None):
|
||||
assert john not in sess and jill not in sess
|
||||
|
||||
eq_(sess.query(User).order_by(User.id).all(), [jack, jane])
|
||||
|
||||
@@ -629,6 +726,33 @@ class UpdateDeleteTest(fixtures.MappedTest):
|
||||
|
||||
eq_(sess.query(User).order_by(User.id).all(), [jack, jill, jane])
|
||||
|
||||
def test_update_multirow_not_supported(self):
|
||||
User = self.classes.User
|
||||
|
||||
sess = fixture_session()
|
||||
|
||||
with expect_raises_message(
|
||||
exc.InvalidRequestError,
|
||||
"WHERE clause with bulk ORM UPDATE not supported " "right now.",
|
||||
):
|
||||
sess.execute(
|
||||
update(User).where(User.id == bindparam("id")),
|
||||
[{"id": 1, "age": 27}, {"id": 2, "age": 37}],
|
||||
)
|
||||
|
||||
def test_delete_bulk_not_supported(self):
|
||||
User = self.classes.User
|
||||
|
||||
sess = fixture_session()
|
||||
|
||||
with expect_raises_message(
|
||||
exc.InvalidRequestError, "Bulk ORM DELETE not supported right now."
|
||||
):
|
||||
sess.execute(
|
||||
delete(User),
|
||||
[{"id": 1}, {"id": 2}],
|
||||
)
|
||||
|
||||
def test_update(self):
|
||||
User, users = self.classes.User, self.tables.users
|
||||
|
||||
@@ -640,6 +764,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
|
||||
)
|
||||
|
||||
eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27])
|
||||
|
||||
eq_(
|
||||
sess.query(User.age).order_by(User.id).all(),
|
||||
list(zip([25, 37, 29, 27])),
|
||||
@@ -974,7 +1099,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
|
||||
)
|
||||
|
||||
@testing.requires.update_returning
|
||||
def test_update_explicit_returning(self):
|
||||
def test_update_evaluate_w_explicit_returning(self):
|
||||
User = self.classes.User
|
||||
|
||||
sess = fixture_session()
|
||||
@@ -987,6 +1112,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
|
||||
.filter(User.age > 29)
|
||||
.values({"age": User.age - 10})
|
||||
.returning(User.id)
|
||||
.execution_options(synchronize_session="evaluate")
|
||||
)
|
||||
|
||||
rows = sess.execute(stmt).all()
|
||||
@@ -1006,24 +1132,41 @@ class UpdateDeleteTest(fixtures.MappedTest):
|
||||
)
|
||||
|
||||
@testing.requires.update_returning
|
||||
def test_no_fetch_w_explicit_returning(self):
|
||||
@testing.combinations("update", "delete", argnames="crud_type")
|
||||
def test_fetch_w_explicit_returning(self, crud_type):
|
||||
User = self.classes.User
|
||||
|
||||
sess = fixture_session()
|
||||
|
||||
stmt = (
|
||||
update(User)
|
||||
.filter(User.age > 29)
|
||||
.values({"age": User.age - 10})
|
||||
.execution_options(synchronize_session="fetch")
|
||||
.returning(User.id)
|
||||
)
|
||||
with expect_raises_message(
|
||||
exc.InvalidRequestError,
|
||||
r"Can't use synchronize_session='fetch' "
|
||||
r"with explicit returning\(\)",
|
||||
):
|
||||
sess.execute(stmt)
|
||||
if crud_type == "update":
|
||||
stmt = (
|
||||
update(User)
|
||||
.filter(User.age > 29)
|
||||
.values({"age": User.age - 10})
|
||||
.execution_options(synchronize_session="fetch")
|
||||
.returning(User, User.name)
|
||||
)
|
||||
expected = [
|
||||
(User(age=37), "jack"),
|
||||
(User(age=27), "jane"),
|
||||
]
|
||||
elif crud_type == "delete":
|
||||
stmt = (
|
||||
delete(User)
|
||||
.filter(User.age > 29)
|
||||
.execution_options(synchronize_session="fetch")
|
||||
.returning(User, User.name)
|
||||
)
|
||||
expected = [
|
||||
(User(age=47), "jack"),
|
||||
(User(age=37), "jane"),
|
||||
]
|
||||
else:
|
||||
assert False
|
||||
|
||||
result = sess.execute(stmt)
|
||||
|
||||
eq_(result.all(), expected)
|
||||
|
||||
@testing.combinations(True, False, argnames="implicit_returning")
|
||||
def test_delete_fetch_returning(self, implicit_returning):
|
||||
@@ -1142,7 +1285,8 @@ class UpdateDeleteTest(fixtures.MappedTest):
|
||||
list(zip([25, 47, 44, 37])),
|
||||
)
|
||||
|
||||
def test_update_changes_resets_dirty(self):
|
||||
@testing.combinations("orm", "bulk")
|
||||
def test_update_changes_resets_dirty(self, update_type):
|
||||
User = self.classes.User
|
||||
|
||||
sess = fixture_session(autoflush=False)
|
||||
@@ -1155,9 +1299,30 @@ class UpdateDeleteTest(fixtures.MappedTest):
|
||||
# autoflush is false. therefore our '50' and '37' are getting
|
||||
# blown away by this operation.
|
||||
|
||||
sess.query(User).filter(User.age > 29).update(
|
||||
{"age": User.age - 10}, synchronize_session="evaluate"
|
||||
)
|
||||
if update_type == "orm":
|
||||
sess.execute(
|
||||
update(User)
|
||||
.filter(User.age > 29)
|
||||
.values({"age": User.age - 10}),
|
||||
execution_options=dict(synchronize_session="evaluate"),
|
||||
)
|
||||
elif update_type == "bulk":
|
||||
|
||||
data = [
|
||||
{"id": john.id, "age": 25},
|
||||
{"id": jack.id, "age": 37},
|
||||
{"id": jill.id, "age": 29},
|
||||
{"id": jane.id, "age": 27},
|
||||
]
|
||||
|
||||
sess.execute(
|
||||
update(User),
|
||||
data,
|
||||
execution_options=dict(synchronize_session="evaluate"),
|
||||
)
|
||||
|
||||
else:
|
||||
assert False
|
||||
|
||||
for x in (john, jack, jill, jane):
|
||||
assert not sess.is_modified(x)
|
||||
@@ -1171,6 +1336,93 @@ class UpdateDeleteTest(fixtures.MappedTest):
|
||||
assert not sess.is_modified(john)
|
||||
assert not sess.is_modified(jack)
|
||||
|
||||
@testing.combinations(
|
||||
None, False, "evaluate", "fetch", argnames="synchronize_session"
|
||||
)
|
||||
@testing.combinations(True, False, argnames="homogeneous_keys")
|
||||
def test_bulk_update_synchronize_session(
|
||||
self, synchronize_session, homogeneous_keys
|
||||
):
|
||||
User = self.classes.User
|
||||
|
||||
sess = fixture_session(expire_on_commit=False)
|
||||
|
||||
john, jack, jill, jane = sess.query(User).order_by(User.id).all()
|
||||
|
||||
if homogeneous_keys:
|
||||
data = [
|
||||
{"id": john.id, "age": 35},
|
||||
{"id": jack.id, "age": 27},
|
||||
{"id": jill.id, "age": 30},
|
||||
]
|
||||
else:
|
||||
data = [
|
||||
{"id": john.id, "age": 35},
|
||||
{"id": jack.id, "name": "new jack"},
|
||||
{"id": jill.id, "age": 30, "name": "new jill"},
|
||||
]
|
||||
|
||||
with self.sql_execution_asserter() as asserter:
|
||||
if synchronize_session is not None:
|
||||
opts = {"synchronize_session": synchronize_session}
|
||||
else:
|
||||
opts = {}
|
||||
|
||||
if synchronize_session == "fetch":
|
||||
with expect_raises_message(
|
||||
exc.InvalidRequestError,
|
||||
"The 'fetch' synchronization strategy is not available "
|
||||
"for 'bulk' ORM updates",
|
||||
):
|
||||
sess.execute(update(User), data, execution_options=opts)
|
||||
return
|
||||
else:
|
||||
sess.execute(update(User), data, execution_options=opts)
|
||||
|
||||
if homogeneous_keys:
|
||||
asserter.assert_(
|
||||
CompiledSQL(
|
||||
"UPDATE users SET age_int=:age_int "
|
||||
"WHERE users.id = :users_id",
|
||||
[
|
||||
{"age_int": 35, "users_id": 1},
|
||||
{"age_int": 27, "users_id": 2},
|
||||
{"age_int": 30, "users_id": 3},
|
||||
],
|
||||
)
|
||||
)
|
||||
else:
|
||||
asserter.assert_(
|
||||
CompiledSQL(
|
||||
"UPDATE users SET age_int=:age_int "
|
||||
"WHERE users.id = :users_id",
|
||||
[{"age_int": 35, "users_id": 1}],
|
||||
),
|
||||
CompiledSQL(
|
||||
"UPDATE users SET name=:name WHERE users.id = :users_id",
|
||||
[{"name": "new jack", "users_id": 2}],
|
||||
),
|
||||
CompiledSQL(
|
||||
"UPDATE users SET name=:name, age_int=:age_int "
|
||||
"WHERE users.id = :users_id",
|
||||
[{"name": "new jill", "age_int": 30, "users_id": 3}],
|
||||
),
|
||||
)
|
||||
|
||||
if synchronize_session is False:
|
||||
eq_(jill.name, "jill")
|
||||
eq_(jack.name, "jack")
|
||||
eq_(jill.age, 29)
|
||||
eq_(jack.age, 47)
|
||||
else:
|
||||
if not homogeneous_keys:
|
||||
eq_(jill.name, "new jill")
|
||||
eq_(jack.name, "new jack")
|
||||
eq_(jack.age, 47)
|
||||
else:
|
||||
eq_(jack.age, 27)
|
||||
eq_(jill.age, 30)
|
||||
|
||||
def test_update_changes_with_autoflush(self):
|
||||
User = self.classes.User
|
||||
|
||||
@@ -1214,7 +1466,8 @@ class UpdateDeleteTest(fixtures.MappedTest):
|
||||
)
|
||||
|
||||
@testing.fails_if(lambda: not testing.db.dialect.supports_sane_rowcount)
|
||||
def test_update_returns_rowcount(self):
|
||||
@testing.combinations("auto", "fetch", "evaluate")
|
||||
def test_update_returns_rowcount(self, synchronize_session):
|
||||
User = self.classes.User
|
||||
|
||||
sess = fixture_session()
|
||||
@@ -1222,20 +1475,25 @@ class UpdateDeleteTest(fixtures.MappedTest):
|
||||
rowcount = (
|
||||
sess.query(User)
|
||||
.filter(User.age > 29)
|
||||
.update({"age": User.age + 0})
|
||||
.update(
|
||||
{"age": User.age + 0}, synchronize_session=synchronize_session
|
||||
)
|
||||
)
|
||||
eq_(rowcount, 2)
|
||||
|
||||
rowcount = (
|
||||
sess.query(User)
|
||||
.filter(User.age > 29)
|
||||
.update({"age": User.age - 10})
|
||||
.update(
|
||||
{"age": User.age - 10}, synchronize_session=synchronize_session
|
||||
)
|
||||
)
|
||||
eq_(rowcount, 2)
|
||||
|
||||
# test future
|
||||
result = sess.execute(
|
||||
update(User).where(User.age > 19).values({"age": User.age - 10})
|
||||
update(User).where(User.age > 19).values({"age": User.age - 10}),
|
||||
execution_options={"synchronize_session": synchronize_session},
|
||||
)
|
||||
eq_(result.rowcount, 4)
|
||||
|
||||
@@ -1327,12 +1585,17 @@ class UpdateDeleteTest(fixtures.MappedTest):
|
||||
)
|
||||
assert john not in sess
|
||||
|
||||
def test_evaluate_before_update(self):
|
||||
@testing.combinations(True, False)
|
||||
def test_evaluate_before_update(self, full_expiration):
|
||||
User = self.classes.User
|
||||
|
||||
sess = fixture_session()
|
||||
john = sess.query(User).filter_by(name="john").one()
|
||||
sess.expire(john, ["age"])
|
||||
|
||||
if full_expiration:
|
||||
sess.expire(john)
|
||||
else:
|
||||
sess.expire(john, ["age"])
|
||||
|
||||
# eval must be before the update. otherwise
|
||||
# we eval john, age has been expired and doesn't
|
||||
@@ -1356,17 +1619,47 @@ class UpdateDeleteTest(fixtures.MappedTest):
|
||||
eq_(john.name, "j2")
|
||||
eq_(john.age, 40)
|
||||
|
||||
def test_evaluate_before_delete(self):
|
||||
@testing.combinations(True, False)
|
||||
def test_evaluate_before_delete(self, full_expiration):
|
||||
User = self.classes.User
|
||||
|
||||
sess = fixture_session()
|
||||
john = sess.query(User).filter_by(name="john").one()
|
||||
sess.expire(john, ["age"])
|
||||
jill = sess.query(User).filter_by(name="jill").one()
|
||||
jane = sess.query(User).filter_by(name="jane").one()
|
||||
|
||||
sess.query(User).filter_by(name="john").filter_by(age=25).delete(
|
||||
if full_expiration:
|
||||
sess.expire(jill)
|
||||
sess.expire(john)
|
||||
else:
|
||||
sess.expire(jill, ["age"])
|
||||
sess.expire(john, ["age"])
|
||||
|
||||
sess.query(User).filter(or_(User.age == 25, User.age == 37)).delete(
|
||||
synchronize_session="evaluate"
|
||||
)
|
||||
assert john not in sess
|
||||
|
||||
# was fully deleted
|
||||
assert jane not in sess
|
||||
|
||||
# deleted object was expired, but not otherwise affected
|
||||
assert jill in sess
|
||||
|
||||
# deleted object was expired, but not otherwise affected
|
||||
assert john in sess
|
||||
|
||||
# partially expired row fully expired
|
||||
assert inspect(jill).expired
|
||||
|
||||
# non-deleted row still present
|
||||
eq_(jill.age, 29)
|
||||
|
||||
# partially expired row fully expired
|
||||
assert inspect(john).expired
|
||||
|
||||
# is deleted
|
||||
with expect_raises(orm_exc.ObjectDeletedError):
|
||||
john.name
|
||||
|
||||
def test_fetch_before_delete(self):
|
||||
User = self.classes.User
|
||||
@@ -1378,6 +1671,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
|
||||
sess.query(User).filter_by(name="john").filter_by(age=25).delete(
|
||||
synchronize_session="fetch"
|
||||
)
|
||||
|
||||
assert john not in sess
|
||||
|
||||
def test_update_unordered_dict(self):
|
||||
@@ -1495,6 +1789,60 @@ class UpdateDeleteTest(fixtures.MappedTest):
|
||||
]
|
||||
eq_(["name", "age_int"], cols)
|
||||
|
||||
@testing.requires.sqlite
|
||||
def test_sharding_extension_returning_mismatch(self, testing_engine):
|
||||
"""test one horizontal shard case where the given binds don't match
|
||||
for RETURNING support; we dont support this.
|
||||
|
||||
See test/ext/test_horizontal_shard.py for complete round trip
|
||||
test cases for ORM update/delete
|
||||
|
||||
"""
|
||||
e1 = testing_engine("sqlite://")
|
||||
e2 = testing_engine("sqlite://")
|
||||
e1.connect().close()
|
||||
e2.connect().close()
|
||||
|
||||
e1.dialect.update_returning = True
|
||||
e2.dialect.update_returning = False
|
||||
|
||||
engines = [e1, e2]
|
||||
|
||||
# a simulated version of the horizontal sharding extension
|
||||
def execute_and_instances(orm_context):
|
||||
execution_options = dict(orm_context.local_execution_options)
|
||||
partial = []
|
||||
for engine in engines:
|
||||
bind_arguments = dict(orm_context.bind_arguments)
|
||||
bind_arguments["bind"] = engine
|
||||
result_ = orm_context.invoke_statement(
|
||||
bind_arguments=bind_arguments,
|
||||
execution_options=execution_options,
|
||||
)
|
||||
|
||||
partial.append(result_)
|
||||
return partial[0].merge(*partial[1:])
|
||||
|
||||
User = self.classes.User
|
||||
session = Session()
|
||||
|
||||
event.listen(
|
||||
session, "do_orm_execute", execute_and_instances, retval=True
|
||||
)
|
||||
|
||||
stmt = (
|
||||
update(User)
|
||||
.filter(User.id == 15)
|
||||
.values(age=123)
|
||||
.execution_options(synchronize_session="fetch")
|
||||
)
|
||||
with expect_raises_message(
|
||||
exc.InvalidRequestError,
|
||||
"For synchronize_session='fetch', can't mix multiple backends "
|
||||
"where some support RETURNING and others don't",
|
||||
):
|
||||
session.execute(stmt)
|
||||
|
||||
|
||||
class UpdateDeleteIgnoresLoadersTest(fixtures.MappedTest):
|
||||
@classmethod
|
||||
@@ -1748,6 +2096,7 @@ class UpdateDeleteFromTest(fixtures.MappedTest):
|
||||
"Could not evaluate current criteria in Python.",
|
||||
q.update,
|
||||
{"samename": "ed"},
|
||||
synchronize_session="evaluate",
|
||||
)
|
||||
|
||||
@testing.requires.multi_table_update
|
||||
@@ -1901,7 +2250,7 @@ class ExpressionUpdateTest(fixtures.MappedTest):
|
||||
sess.commit()
|
||||
eq_(d1.cnt, 0)
|
||||
|
||||
sess.query(Data).update({Data.cnt: Data.cnt + 1})
|
||||
sess.query(Data).update({Data.cnt: Data.cnt + 1}, "evaluate")
|
||||
sess.flush()
|
||||
|
||||
eq_(d1.cnt, 1)
|
||||
@@ -2443,7 +2792,8 @@ class LoadFromReturningTest(fixtures.MappedTest):
|
||||
)
|
||||
|
||||
@testing.requires.update_returning
|
||||
def test_load_from_update(self, connection):
|
||||
@testing.combinations(True, False, argnames="use_from_statement")
|
||||
def test_load_from_update(self, connection, use_from_statement):
|
||||
User = self.classes.User
|
||||
|
||||
stmt = (
|
||||
@@ -2453,7 +2803,16 @@ class LoadFromReturningTest(fixtures.MappedTest):
|
||||
.returning(User)
|
||||
)
|
||||
|
||||
stmt = select(User).from_statement(stmt)
|
||||
if use_from_statement:
|
||||
# this is now a legacy-ish case, because as of 2.0 you can just
|
||||
# use returning() directly to get the objects back.
|
||||
#
|
||||
# when from_statement is used, the UPDATE statement is no
|
||||
# longer interpreted by
|
||||
# BulkUDCompileState.orm_pre_session_exec or
|
||||
# BulkUDCompileState.orm_setup_cursor_result. The compilation
|
||||
# level routines still take place though
|
||||
stmt = select(User).from_statement(stmt)
|
||||
|
||||
with Session(connection) as sess:
|
||||
rows = sess.execute(stmt).scalars().all()
|
||||
@@ -2468,7 +2827,8 @@ class LoadFromReturningTest(fixtures.MappedTest):
|
||||
("multiple", testing.requires.multivalues_inserts),
|
||||
argnames="params",
|
||||
)
|
||||
def test_load_from_insert(self, connection, params):
|
||||
@testing.combinations(True, False, argnames="use_from_statement")
|
||||
def test_load_from_insert(self, connection, params, use_from_statement):
|
||||
User = self.classes.User
|
||||
|
||||
if params == "multiple":
|
||||
@@ -2484,7 +2844,8 @@ class LoadFromReturningTest(fixtures.MappedTest):
|
||||
|
||||
stmt = insert(User).values(values).returning(User)
|
||||
|
||||
stmt = select(User).from_statement(stmt)
|
||||
if use_from_statement:
|
||||
stmt = select(User).from_statement(stmt)
|
||||
|
||||
with Session(connection) as sess:
|
||||
rows = sess.execute(stmt).scalars().all()
|
||||
@@ -2505,3 +2866,25 @@ class LoadFromReturningTest(fixtures.MappedTest):
|
||||
)
|
||||
else:
|
||||
assert False
|
||||
|
||||
@testing.requires.delete_returning
|
||||
@testing.combinations(True, False, argnames="use_from_statement")
|
||||
def test_load_from_delete(self, connection, use_from_statement):
|
||||
User = self.classes.User
|
||||
|
||||
stmt = (
|
||||
delete(User).where(User.name.in_(["jack", "jill"])).returning(User)
|
||||
)
|
||||
|
||||
if use_from_statement:
|
||||
stmt = select(User).from_statement(stmt)
|
||||
|
||||
with Session(connection) as sess:
|
||||
rows = sess.execute(stmt).scalars().all()
|
||||
|
||||
eq_(
|
||||
rows,
|
||||
[User(name="jack", age=47), User(name="jill", age=29)],
|
||||
)
|
||||
|
||||
# TODO: state of above objects should be "deleted"
|
||||
@@ -2012,7 +2012,8 @@ class JoinedNoFKSortingTest(fixtures.MappedTest):
|
||||
and testing.db.dialect.supports_default_metavalue,
|
||||
[
|
||||
CompiledSQL(
|
||||
"INSERT INTO a (id) VALUES (DEFAULT)", [{}, {}, {}, {}]
|
||||
"INSERT INTO a (id) VALUES (DEFAULT) RETURNING a.id",
|
||||
[{}, {}, {}, {}],
|
||||
),
|
||||
],
|
||||
[
|
||||
|
||||
+10
-1
@@ -326,6 +326,7 @@ class BindIntegrationTest(_fixtures.FixtureTest):
|
||||
),
|
||||
(
|
||||
lambda User: update(User)
|
||||
.execution_options(synchronize_session=False)
|
||||
.values(name="not ed")
|
||||
.where(User.name == "ed"),
|
||||
lambda User: {"clause": mock.ANY, "mapper": inspect(User)},
|
||||
@@ -392,7 +393,15 @@ class BindIntegrationTest(_fixtures.FixtureTest):
|
||||
engine = {"e1": e1, "e2": e2, "e3": e3}[expected_engine_name]
|
||||
|
||||
with mock.patch(
|
||||
"sqlalchemy.orm.context.ORMCompileState.orm_setup_cursor_result"
|
||||
"sqlalchemy.orm.context." "ORMCompileState.orm_setup_cursor_result"
|
||||
), mock.patch(
|
||||
"sqlalchemy.orm.context.ORMCompileState.orm_execute_statement"
|
||||
), mock.patch(
|
||||
"sqlalchemy.orm.bulk_persistence."
|
||||
"BulkORMInsert.orm_execute_statement"
|
||||
), mock.patch(
|
||||
"sqlalchemy.orm.bulk_persistence."
|
||||
"BulkUDCompileState.orm_setup_cursor_result"
|
||||
):
|
||||
sess.execute(statement)
|
||||
|
||||
|
||||
+204
-3
@@ -1,8 +1,10 @@
|
||||
import dataclasses
|
||||
import operator
|
||||
import random
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import String
|
||||
@@ -233,7 +235,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
|
||||
is g.edges[1]
|
||||
)
|
||||
|
||||
def test_bulk_update_sql(self):
|
||||
def test_update_crit_sql(self):
|
||||
Edge, Point = (self.classes.Edge, self.classes.Point)
|
||||
|
||||
sess = self._fixture()
|
||||
@@ -258,7 +260,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
|
||||
dialect="default",
|
||||
)
|
||||
|
||||
def test_bulk_update_evaluate(self):
|
||||
def test_update_crit_evaluate(self):
|
||||
Edge, Point = (self.classes.Edge, self.classes.Point)
|
||||
|
||||
sess = self._fixture()
|
||||
@@ -287,7 +289,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
|
||||
|
||||
eq_(e1.end, Point(17, 8))
|
||||
|
||||
def test_bulk_update_fetch(self):
|
||||
def test_update_crit_fetch(self):
|
||||
Edge, Point = (self.classes.Edge, self.classes.Point)
|
||||
|
||||
sess = self._fixture()
|
||||
@@ -305,6 +307,205 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
|
||||
|
||||
eq_(e1.end, Point(17, 8))
|
||||
|
||||
@testing.combinations(
|
||||
"legacy", "statement", "values", "stmt_returning", "values_returning"
|
||||
)
|
||||
def test_bulk_insert(self, type_):
|
||||
Edge, Point = (self.classes.Edge, self.classes.Point)
|
||||
Graph = self.classes.Graph
|
||||
|
||||
sess = self._fixture()
|
||||
|
||||
graph = Graph(id=2)
|
||||
sess.add(graph)
|
||||
sess.flush()
|
||||
graph_id = 2
|
||||
|
||||
data = [
|
||||
{
|
||||
"start": Point(random.randint(1, 50), random.randint(1, 50)),
|
||||
"end": Point(random.randint(1, 50), random.randint(1, 50)),
|
||||
"graph_id": graph_id,
|
||||
}
|
||||
for i in range(25)
|
||||
]
|
||||
returning = False
|
||||
if type_ == "statement":
|
||||
sess.execute(insert(Edge), data)
|
||||
elif type_ == "stmt_returning":
|
||||
result = sess.scalars(insert(Edge).returning(Edge), data)
|
||||
returning = True
|
||||
elif type_ == "values":
|
||||
sess.execute(insert(Edge).values(data))
|
||||
elif type_ == "values_returning":
|
||||
result = sess.scalars(insert(Edge).values(data).returning(Edge))
|
||||
returning = True
|
||||
elif type_ == "legacy":
|
||||
sess.bulk_insert_mappings(Edge, data)
|
||||
else:
|
||||
assert False
|
||||
|
||||
if returning:
|
||||
eq_(result.all(), [Edge(rec["start"], rec["end"]) for rec in data])
|
||||
|
||||
edges = self.tables.edges
|
||||
eq_(
|
||||
sess.execute(
|
||||
select(edges.c["x1", "y1", "x2", "y2"])
|
||||
.where(edges.c.graph_id == graph_id)
|
||||
.order_by(edges.c.id)
|
||||
).all(),
|
||||
[
|
||||
(e["start"].x, e["start"].y, e["end"].x, e["end"].y)
|
||||
for e in data
|
||||
],
|
||||
)
|
||||
|
||||
@testing.combinations("legacy", "statement")
|
||||
def test_bulk_insert_heterogeneous(self, type_):
|
||||
Edge, Point = (self.classes.Edge, self.classes.Point)
|
||||
Graph = self.classes.Graph
|
||||
|
||||
sess = self._fixture()
|
||||
|
||||
graph = Graph(id=2)
|
||||
sess.add(graph)
|
||||
sess.flush()
|
||||
graph_id = 2
|
||||
|
||||
d1 = [
|
||||
{
|
||||
"start": Point(random.randint(1, 50), random.randint(1, 50)),
|
||||
"end": Point(random.randint(1, 50), random.randint(1, 50)),
|
||||
"graph_id": graph_id,
|
||||
}
|
||||
for i in range(3)
|
||||
]
|
||||
d2 = [
|
||||
{
|
||||
"start": Point(random.randint(1, 50), random.randint(1, 50)),
|
||||
"graph_id": graph_id,
|
||||
}
|
||||
for i in range(2)
|
||||
]
|
||||
d3 = [
|
||||
{
|
||||
"x2": random.randint(1, 50),
|
||||
"y2": random.randint(1, 50),
|
||||
"graph_id": graph_id,
|
||||
}
|
||||
for i in range(2)
|
||||
]
|
||||
data = d1 + d2 + d3
|
||||
random.shuffle(data)
|
||||
|
||||
assert_data = [
|
||||
{
|
||||
"start": d["start"] if "start" in d else None,
|
||||
"end": d["end"]
|
||||
if "end" in d
|
||||
else Point(d["x2"], d["y2"])
|
||||
if "x2" in d
|
||||
else None,
|
||||
"graph_id": d["graph_id"],
|
||||
}
|
||||
for d in data
|
||||
]
|
||||
|
||||
if type_ == "statement":
|
||||
sess.execute(insert(Edge), data)
|
||||
elif type_ == "legacy":
|
||||
sess.bulk_insert_mappings(Edge, data)
|
||||
else:
|
||||
assert False
|
||||
|
||||
edges = self.tables.edges
|
||||
eq_(
|
||||
sess.execute(
|
||||
select(edges.c["x1", "y1", "x2", "y2"])
|
||||
.where(edges.c.graph_id == graph_id)
|
||||
.order_by(edges.c.id)
|
||||
).all(),
|
||||
[
|
||||
(
|
||||
e["start"].x if e["start"] else None,
|
||||
e["start"].y if e["start"] else None,
|
||||
e["end"].x if e["end"] else None,
|
||||
e["end"].y if e["end"] else None,
|
||||
)
|
||||
for e in assert_data
|
||||
],
|
||||
)
|
||||
|
||||
@testing.combinations("legacy", "statement")
|
||||
def test_bulk_update(self, type_):
|
||||
Edge, Point = (self.classes.Edge, self.classes.Point)
|
||||
Graph = self.classes.Graph
|
||||
|
||||
sess = self._fixture()
|
||||
|
||||
graph = Graph(id=2)
|
||||
sess.add(graph)
|
||||
sess.flush()
|
||||
graph_id = 2
|
||||
|
||||
data = [
|
||||
{
|
||||
"start": Point(random.randint(1, 50), random.randint(1, 50)),
|
||||
"end": Point(random.randint(1, 50), random.randint(1, 50)),
|
||||
"graph_id": graph_id,
|
||||
}
|
||||
for i in range(25)
|
||||
]
|
||||
sess.execute(insert(Edge), data)
|
||||
|
||||
inserted_data = [
|
||||
dict(row._mapping)
|
||||
for row in sess.execute(
|
||||
select(Edge.id, Edge.start, Edge.end, Edge.graph_id)
|
||||
.where(Edge.graph_id == graph_id)
|
||||
.order_by(Edge.id)
|
||||
)
|
||||
]
|
||||
|
||||
to_update = []
|
||||
updated_pks = {}
|
||||
for rec in random.choices(inserted_data, k=7):
|
||||
rec_copy = dict(rec)
|
||||
updated_pks[rec_copy["id"]] = rec_copy
|
||||
rec_copy["start"] = Point(
|
||||
random.randint(1, 50), random.randint(1, 50)
|
||||
)
|
||||
rec_copy["end"] = Point(
|
||||
random.randint(1, 50), random.randint(1, 50)
|
||||
)
|
||||
to_update.append(rec_copy)
|
||||
|
||||
expected_dataset = [
|
||||
updated_pks[row["id"]] if row["id"] in updated_pks else row
|
||||
for row in inserted_data
|
||||
]
|
||||
|
||||
if type_ == "statement":
|
||||
sess.execute(update(Edge), to_update)
|
||||
elif type_ == "legacy":
|
||||
sess.bulk_update_mappings(Edge, to_update)
|
||||
else:
|
||||
assert False
|
||||
|
||||
edges = self.tables.edges
|
||||
eq_(
|
||||
sess.execute(
|
||||
select(edges.c["x1", "y1", "x2", "y2"])
|
||||
.where(edges.c.graph_id == graph_id)
|
||||
.order_by(edges.c.id)
|
||||
).all(),
|
||||
[
|
||||
(e["start"].x, e["start"].y, e["end"].x, e["end"].y)
|
||||
for e in expected_dataset
|
||||
],
|
||||
)
|
||||
|
||||
def test_get_history(self):
|
||||
Edge = self.classes.Edge
|
||||
Point = self.classes.Point
|
||||
|
||||
@@ -1122,7 +1122,7 @@ class OneToManyManyToOneTest(fixtures.MappedTest):
|
||||
[
|
||||
CompiledSQL(
|
||||
"INSERT INTO ball (person_id, data) "
|
||||
"VALUES (:person_id, :data)",
|
||||
"VALUES (:person_id, :data) RETURNING ball.id",
|
||||
[
|
||||
{"person_id": None, "data": "some data"},
|
||||
{"person_id": None, "data": "some data"},
|
||||
|
||||
@@ -383,20 +383,24 @@ class ComputedDefaultsOnUpdateTest(fixtures.MappedTest):
|
||||
CompiledSQL(
|
||||
"UPDATE test SET foo=:foo WHERE test.id = :test_id",
|
||||
[{"foo": 5, "test_id": 1}],
|
||||
enable_returning=False,
|
||||
),
|
||||
CompiledSQL(
|
||||
"UPDATE test SET foo=:foo WHERE test.id = :test_id",
|
||||
[{"foo": 6, "test_id": 2}],
|
||||
enable_returning=False,
|
||||
),
|
||||
CompiledSQL(
|
||||
"SELECT test.bar AS test_bar FROM test "
|
||||
"WHERE test.id = :pk_1",
|
||||
[{"pk_1": 1}],
|
||||
enable_returning=False,
|
||||
),
|
||||
CompiledSQL(
|
||||
"SELECT test.bar AS test_bar FROM test "
|
||||
"WHERE test.id = :pk_1",
|
||||
[{"pk_1": 2}],
|
||||
enable_returning=False,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
||||
+11
-2
@@ -661,8 +661,17 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest):
|
||||
|
||||
canary = self._flag_fixture(sess)
|
||||
|
||||
sess.execute(delete(User).filter_by(id=18))
|
||||
sess.execute(update(User).filter_by(id=18).values(name="eighteen"))
|
||||
sess.execute(
|
||||
delete(User)
|
||||
.filter_by(id=18)
|
||||
.execution_options(synchronize_session="evaluate")
|
||||
)
|
||||
sess.execute(
|
||||
update(User)
|
||||
.filter_by(id=18)
|
||||
.values(name="eighteen")
|
||||
.execution_options(synchronize_session="evaluate")
|
||||
)
|
||||
|
||||
eq_(
|
||||
canary.mock_calls,
|
||||
|
||||
@@ -2868,12 +2868,14 @@ class SaveTest2(_fixtures.FixtureTest):
|
||||
testing.db.dialect.insert_executemany_returning,
|
||||
[
|
||||
CompiledSQL(
|
||||
"INSERT INTO users (name) VALUES (:name)",
|
||||
"INSERT INTO users (name) VALUES (:name) "
|
||||
"RETURNING users.id",
|
||||
[{"name": "u1"}, {"name": "u2"}],
|
||||
),
|
||||
CompiledSQL(
|
||||
"INSERT INTO addresses (user_id, email_address) "
|
||||
"VALUES (:user_id, :email_address)",
|
||||
"VALUES (:user_id, :email_address) "
|
||||
"RETURNING addresses.id",
|
||||
[
|
||||
{"user_id": 1, "email_address": "a1"},
|
||||
{"user_id": 2, "email_address": "a2"},
|
||||
|
||||
@@ -98,7 +98,8 @@ class RudimentaryFlushTest(UOWTest):
|
||||
[
|
||||
CompiledSQL(
|
||||
"INSERT INTO addresses (user_id, email_address) "
|
||||
"VALUES (:user_id, :email_address)",
|
||||
"VALUES (:user_id, :email_address) "
|
||||
"RETURNING addresses.id",
|
||||
lambda ctx: [
|
||||
{"email_address": "a1", "user_id": u1.id},
|
||||
{"email_address": "a2", "user_id": u1.id},
|
||||
@@ -220,7 +221,8 @@ class RudimentaryFlushTest(UOWTest):
|
||||
[
|
||||
CompiledSQL(
|
||||
"INSERT INTO addresses (user_id, email_address) "
|
||||
"VALUES (:user_id, :email_address)",
|
||||
"VALUES (:user_id, :email_address) "
|
||||
"RETURNING addresses.id",
|
||||
lambda ctx: [
|
||||
{"email_address": "a1", "user_id": u1.id},
|
||||
{"email_address": "a2", "user_id": u1.id},
|
||||
@@ -889,7 +891,7 @@ class SingleCycleTest(UOWTest):
|
||||
[
|
||||
CompiledSQL(
|
||||
"INSERT INTO nodes (parent_id, data) VALUES "
|
||||
"(:parent_id, :data)",
|
||||
"(:parent_id, :data) RETURNING nodes.id",
|
||||
lambda ctx: [
|
||||
{"parent_id": n1.id, "data": "n2"},
|
||||
{"parent_id": n1.id, "data": "n3"},
|
||||
@@ -1003,7 +1005,7 @@ class SingleCycleTest(UOWTest):
|
||||
[
|
||||
CompiledSQL(
|
||||
"INSERT INTO nodes (parent_id, data) VALUES "
|
||||
"(:parent_id, :data)",
|
||||
"(:parent_id, :data) RETURNING nodes.id",
|
||||
lambda ctx: [
|
||||
{"parent_id": n1.id, "data": "n2"},
|
||||
{"parent_id": n1.id, "data": "n3"},
|
||||
@@ -1165,7 +1167,7 @@ class SingleCycleTest(UOWTest):
|
||||
[
|
||||
CompiledSQL(
|
||||
"INSERT INTO nodes (parent_id, data) VALUES "
|
||||
"(:parent_id, :data)",
|
||||
"(:parent_id, :data) RETURNING nodes.id",
|
||||
lambda ctx: [
|
||||
{"parent_id": n1.id, "data": "n11"},
|
||||
{"parent_id": n1.id, "data": "n12"},
|
||||
@@ -1196,7 +1198,7 @@ class SingleCycleTest(UOWTest):
|
||||
[
|
||||
CompiledSQL(
|
||||
"INSERT INTO nodes (parent_id, data) VALUES "
|
||||
"(:parent_id, :data)",
|
||||
"(:parent_id, :data) RETURNING nodes.id",
|
||||
lambda ctx: [
|
||||
{"parent_id": n12.id, "data": "n121"},
|
||||
{"parent_id": n12.id, "data": "n122"},
|
||||
@@ -2099,7 +2101,7 @@ class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults):
|
||||
testing.db.dialect.insert_executemany_returning,
|
||||
[
|
||||
CompiledSQL(
|
||||
"INSERT INTO t (data) VALUES (:data)",
|
||||
"INSERT INTO t (data) VALUES (:data) RETURNING t.id",
|
||||
[{"data": "t1"}, {"data": "t2"}],
|
||||
),
|
||||
],
|
||||
@@ -2472,20 +2474,24 @@ class EagerDefaultsTest(fixtures.MappedTest):
|
||||
CompiledSQL(
|
||||
"INSERT INTO test (id, foo) VALUES (:id, 2 + 5)",
|
||||
[{"id": 1}],
|
||||
enable_returning=False,
|
||||
),
|
||||
CompiledSQL(
|
||||
"INSERT INTO test (id, foo) VALUES (:id, 5 + 5)",
|
||||
[{"id": 2}],
|
||||
enable_returning=False,
|
||||
),
|
||||
CompiledSQL(
|
||||
"SELECT test.foo AS test_foo FROM test "
|
||||
"WHERE test.id = :pk_1",
|
||||
[{"pk_1": 1}],
|
||||
enable_returning=False,
|
||||
),
|
||||
CompiledSQL(
|
||||
"SELECT test.foo AS test_foo FROM test "
|
||||
"WHERE test.id = :pk_1",
|
||||
[{"pk_1": 2}],
|
||||
enable_returning=False,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -2678,20 +2684,24 @@ class EagerDefaultsTest(fixtures.MappedTest):
|
||||
CompiledSQL(
|
||||
"UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
|
||||
[{"foo": 5, "test2_id": 1}],
|
||||
enable_returning=False,
|
||||
),
|
||||
CompiledSQL(
|
||||
"UPDATE test2 SET foo=:foo, bar=:bar "
|
||||
"WHERE test2.id = :test2_id",
|
||||
[{"foo": 6, "bar": 10, "test2_id": 2}],
|
||||
enable_returning=False,
|
||||
),
|
||||
CompiledSQL(
|
||||
"UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
|
||||
[{"foo": 7, "test2_id": 3}],
|
||||
enable_returning=False,
|
||||
),
|
||||
CompiledSQL(
|
||||
"UPDATE test2 SET foo=:foo, bar=:bar "
|
||||
"WHERE test2.id = :test2_id",
|
||||
[{"foo": 8, "bar": 12, "test2_id": 4}],
|
||||
enable_returning=False,
|
||||
),
|
||||
CompiledSQL(
|
||||
"SELECT test2.bar AS test2_bar FROM test2 "
|
||||
@@ -2772,31 +2782,37 @@ class EagerDefaultsTest(fixtures.MappedTest):
|
||||
"UPDATE test4 SET foo=:foo, bar=5 + 3 "
|
||||
"WHERE test4.id = :test4_id",
|
||||
[{"foo": 5, "test4_id": 1}],
|
||||
enable_returning=False,
|
||||
),
|
||||
CompiledSQL(
|
||||
"UPDATE test4 SET foo=:foo, bar=:bar "
|
||||
"WHERE test4.id = :test4_id",
|
||||
[{"foo": 6, "bar": 10, "test4_id": 2}],
|
||||
enable_returning=False,
|
||||
),
|
||||
CompiledSQL(
|
||||
"UPDATE test4 SET foo=:foo, bar=5 + 3 "
|
||||
"WHERE test4.id = :test4_id",
|
||||
[{"foo": 7, "test4_id": 3}],
|
||||
enable_returning=False,
|
||||
),
|
||||
CompiledSQL(
|
||||
"UPDATE test4 SET foo=:foo, bar=:bar "
|
||||
"WHERE test4.id = :test4_id",
|
||||
[{"foo": 8, "bar": 12, "test4_id": 4}],
|
||||
enable_returning=False,
|
||||
),
|
||||
CompiledSQL(
|
||||
"SELECT test4.bar AS test4_bar FROM test4 "
|
||||
"WHERE test4.id = :pk_1",
|
||||
[{"pk_1": 1}],
|
||||
enable_returning=False,
|
||||
),
|
||||
CompiledSQL(
|
||||
"SELECT test4.bar AS test4_bar FROM test4 "
|
||||
"WHERE test4.id = :pk_1",
|
||||
[{"pk_1": 3}],
|
||||
enable_returning=False,
|
||||
),
|
||||
],
|
||||
),
|
||||
@@ -2871,20 +2887,24 @@ class EagerDefaultsTest(fixtures.MappedTest):
|
||||
"UPDATE test2 SET foo=:foo, bar=1 + 1 "
|
||||
"WHERE test2.id = :test2_id",
|
||||
[{"foo": 5, "test2_id": 1}],
|
||||
enable_returning=False,
|
||||
),
|
||||
CompiledSQL(
|
||||
"UPDATE test2 SET foo=:foo, bar=:bar "
|
||||
"WHERE test2.id = :test2_id",
|
||||
[{"foo": 6, "bar": 10, "test2_id": 2}],
|
||||
enable_returning=False,
|
||||
),
|
||||
CompiledSQL(
|
||||
"UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
|
||||
[{"foo": 7, "test2_id": 3}],
|
||||
enable_returning=False,
|
||||
),
|
||||
CompiledSQL(
|
||||
"UPDATE test2 SET foo=:foo, bar=5 + 7 "
|
||||
"WHERE test2.id = :test2_id",
|
||||
[{"foo": 8, "test2_id": 4}],
|
||||
enable_returning=False,
|
||||
),
|
||||
CompiledSQL(
|
||||
"SELECT test2.bar AS test2_bar FROM test2 "
|
||||
|
||||
@@ -1424,12 +1424,10 @@ class ServerVersioningTest(fixtures.MappedTest):
|
||||
sess.add(f1)
|
||||
|
||||
statements = [
|
||||
# note that the assertsql tests the rule against
|
||||
# "default" - on a "returning" backend, the statement
|
||||
# includes "RETURNING"
|
||||
CompiledSQL(
|
||||
"INSERT INTO version_table (version_id, value) "
|
||||
"VALUES (1, :value)",
|
||||
"VALUES (1, :value) "
|
||||
"RETURNING version_table.id, version_table.version_id",
|
||||
lambda ctx: [{"value": "f1"}],
|
||||
)
|
||||
]
|
||||
@@ -1493,6 +1491,7 @@ class ServerVersioningTest(fixtures.MappedTest):
|
||||
"value": "f2",
|
||||
}
|
||||
],
|
||||
enable_returning=False,
|
||||
),
|
||||
CompiledSQL(
|
||||
"SELECT version_table.version_id "
|
||||
@@ -1618,6 +1617,7 @@ class ServerVersioningTest(fixtures.MappedTest):
|
||||
"value": "f1a",
|
||||
}
|
||||
],
|
||||
enable_returning=False,
|
||||
),
|
||||
CompiledSQL(
|
||||
"UPDATE version_table SET version_id=2, value=:value "
|
||||
@@ -1630,6 +1630,7 @@ class ServerVersioningTest(fixtures.MappedTest):
|
||||
"value": "f2a",
|
||||
}
|
||||
],
|
||||
enable_returning=False,
|
||||
),
|
||||
CompiledSQL(
|
||||
"UPDATE version_table SET version_id=2, value=:value "
|
||||
@@ -1642,6 +1643,7 @@ class ServerVersioningTest(fixtures.MappedTest):
|
||||
"value": "f3a",
|
||||
}
|
||||
],
|
||||
enable_returning=False,
|
||||
),
|
||||
CompiledSQL(
|
||||
"SELECT version_table.version_id "
|
||||
|
||||
@@ -100,10 +100,55 @@ class CursorResultTest(fixtures.TablesTest):
|
||||
Table(
|
||||
"test",
|
||||
metadata,
|
||||
Column("x", Integer, primary_key=True),
|
||||
Column(
|
||||
"x", Integer, primary_key=True, test_needs_autoincrement=False
|
||||
),
|
||||
Column("y", String(50)),
|
||||
)
|
||||
|
||||
@testing.requires.insert_returning
|
||||
def test_splice_horizontally(self, connection):
|
||||
users = self.tables.users
|
||||
addresses = self.tables.addresses
|
||||
|
||||
r1 = connection.execute(
|
||||
users.insert().returning(users.c.user_name, users.c.user_id),
|
||||
[
|
||||
dict(user_id=1, user_name="john"),
|
||||
dict(user_id=2, user_name="jack"),
|
||||
],
|
||||
)
|
||||
|
||||
r2 = connection.execute(
|
||||
addresses.insert().returning(
|
||||
addresses.c.address_id,
|
||||
addresses.c.address,
|
||||
addresses.c.user_id,
|
||||
),
|
||||
[
|
||||
dict(address_id=1, user_id=1, address="foo@bar.com"),
|
||||
dict(address_id=2, user_id=2, address="bar@bat.com"),
|
||||
],
|
||||
)
|
||||
|
||||
rows = r1.splice_horizontally(r2).all()
|
||||
eq_(
|
||||
rows,
|
||||
[
|
||||
("john", 1, 1, "foo@bar.com", 1),
|
||||
("jack", 2, 2, "bar@bat.com", 2),
|
||||
],
|
||||
)
|
||||
|
||||
eq_(rows[0]._mapping[users.c.user_id], 1)
|
||||
eq_(rows[0]._mapping[addresses.c.user_id], 1)
|
||||
eq_(rows[1].address, "bar@bat.com")
|
||||
|
||||
with expect_raises_message(
|
||||
exc.InvalidRequestError, "Ambiguous column name 'user_id'"
|
||||
):
|
||||
rows[0].user_id
|
||||
|
||||
def test_keys_no_rows(self, connection):
|
||||
|
||||
for i in range(2):
|
||||
|
||||
@@ -23,6 +23,7 @@ from sqlalchemy.testing import config
|
||||
from sqlalchemy.testing import eq_
|
||||
from sqlalchemy.testing import expect_raises_message
|
||||
from sqlalchemy.testing import fixtures
|
||||
from sqlalchemy.testing import is_
|
||||
from sqlalchemy.testing import mock
|
||||
from sqlalchemy.testing import provision
|
||||
from sqlalchemy.testing.schema import Column
|
||||
@@ -76,6 +77,7 @@ class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL):
|
||||
stmt = stmt.returning(t.c.x)
|
||||
|
||||
stmt = stmt.return_defaults()
|
||||
|
||||
assert_raises_message(
|
||||
sa_exc.CompileError,
|
||||
r"Can't compile statement that includes returning\(\) "
|
||||
@@ -330,6 +332,7 @@ class InsertReturningTest(fixtures.TablesTest, AssertsExecutionResults):
|
||||
table = self.tables.returning_tbl
|
||||
|
||||
exprs = testing.resolve_lambda(testcase, table=table)
|
||||
|
||||
result = connection.execute(
|
||||
table.insert().returning(*exprs),
|
||||
{"persons": 5, "full": False, "strval": "str1"},
|
||||
@@ -679,6 +682,30 @@ class InsertReturnDefaultsTest(fixtures.TablesTest):
|
||||
Column("upddef", Integer, onupdate=IncDefault()),
|
||||
)
|
||||
|
||||
Table(
|
||||
"table_no_addtl_defaults",
|
||||
metadata,
|
||||
Column(
|
||||
"id", Integer, primary_key=True, test_needs_autoincrement=True
|
||||
),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
|
||||
class MyType(TypeDecorator):
|
||||
impl = String(50)
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
return f"PROCESSED! {value}"
|
||||
|
||||
Table(
|
||||
"table_datatype_has_result_proc",
|
||||
metadata,
|
||||
Column(
|
||||
"id", Integer, primary_key=True, test_needs_autoincrement=True
|
||||
),
|
||||
Column("data", MyType()),
|
||||
)
|
||||
|
||||
def test_chained_insert_pk(self, connection):
|
||||
t1 = self.tables.t1
|
||||
result = connection.execute(
|
||||
@@ -758,6 +785,38 @@ class InsertReturnDefaultsTest(fixtures.TablesTest):
|
||||
)
|
||||
eq_(result.inserted_primary_key, (1,))
|
||||
|
||||
def test_insert_w_defaults_supplemental_cols(self, connection):
|
||||
t1 = self.tables.t1
|
||||
result = connection.execute(
|
||||
t1.insert().return_defaults(supplemental_cols=[t1.c.id]),
|
||||
{"data": "d1"},
|
||||
)
|
||||
eq_(result.all(), [(1, 0, None)])
|
||||
|
||||
def test_insert_w_no_defaults_supplemental_cols(self, connection):
|
||||
t1 = self.tables.table_no_addtl_defaults
|
||||
result = connection.execute(
|
||||
t1.insert().return_defaults(supplemental_cols=[t1.c.id]),
|
||||
{"data": "d1"},
|
||||
)
|
||||
eq_(result.all(), [(1,)])
|
||||
|
||||
def test_insert_w_defaults_supplemental_processor_cols(self, connection):
|
||||
"""test that the cursor._rewind() used by supplemental RETURNING
|
||||
clears out result-row processors as we will have already processed
|
||||
the rows.
|
||||
|
||||
"""
|
||||
|
||||
t1 = self.tables.table_datatype_has_result_proc
|
||||
result = connection.execute(
|
||||
t1.insert().return_defaults(
|
||||
supplemental_cols=[t1.c.id, t1.c.data]
|
||||
),
|
||||
{"data": "d1"},
|
||||
)
|
||||
eq_(result.all(), [(1, "PROCESSED! d1")])
|
||||
|
||||
|
||||
class UpdatedReturnDefaultsTest(fixtures.TablesTest):
|
||||
__requires__ = ("update_returning",)
|
||||
@@ -792,6 +851,7 @@ class UpdatedReturnDefaultsTest(fixtures.TablesTest):
|
||||
|
||||
t1 = self.tables.t1
|
||||
connection.execute(t1.insert().values(upddef=1))
|
||||
|
||||
result = connection.execute(
|
||||
t1.update().values(upddef=2).return_defaults(t1.c.data)
|
||||
)
|
||||
@@ -800,6 +860,72 @@ class UpdatedReturnDefaultsTest(fixtures.TablesTest):
|
||||
[None],
|
||||
)
|
||||
|
||||
def test_update_values_col_is_excluded(self, connection):
|
||||
"""columns that are in values() are not returned"""
|
||||
t1 = self.tables.t1
|
||||
connection.execute(t1.insert().values(upddef=1))
|
||||
|
||||
result = connection.execute(
|
||||
t1.update().values(data="x", upddef=2).return_defaults(t1.c.data)
|
||||
)
|
||||
is_(result.returned_defaults, None)
|
||||
|
||||
result = connection.execute(
|
||||
t1.update()
|
||||
.values(data="x", upddef=2)
|
||||
.return_defaults(t1.c.data, t1.c.id)
|
||||
)
|
||||
eq_(result.returned_defaults, (1,))
|
||||
|
||||
def test_update_supplemental_cols(self, connection):
|
||||
"""with supplemental_cols, we can get back arbitrary cols."""
|
||||
|
||||
t1 = self.tables.t1
|
||||
connection.execute(t1.insert().values(upddef=1))
|
||||
result = connection.execute(
|
||||
t1.update()
|
||||
.values(data="x", insdef=3)
|
||||
.return_defaults(supplemental_cols=[t1.c.data, t1.c.insdef])
|
||||
)
|
||||
|
||||
row = result.returned_defaults
|
||||
|
||||
# row has all the cols in it
|
||||
eq_(row, ("x", 3, 1))
|
||||
eq_(row._mapping[t1.c.upddef], 1)
|
||||
eq_(row._mapping[t1.c.insdef], 3)
|
||||
|
||||
# result is rewound
|
||||
# but has both return_defaults + supplemental_cols
|
||||
eq_(result.all(), [("x", 3, 1)])
|
||||
|
||||
def test_update_expl_return_defaults_plus_supplemental_cols(
|
||||
self, connection
|
||||
):
|
||||
"""with supplemental_cols, we can get back arbitrary cols."""
|
||||
|
||||
t1 = self.tables.t1
|
||||
connection.execute(t1.insert().values(upddef=1))
|
||||
result = connection.execute(
|
||||
t1.update()
|
||||
.values(data="x", insdef=3)
|
||||
.return_defaults(
|
||||
t1.c.id, supplemental_cols=[t1.c.data, t1.c.insdef]
|
||||
)
|
||||
)
|
||||
|
||||
row = result.returned_defaults
|
||||
|
||||
# row has all the cols in it
|
||||
eq_(row, (1, "x", 3))
|
||||
eq_(row._mapping[t1.c.id], 1)
|
||||
eq_(row._mapping[t1.c.insdef], 3)
|
||||
assert t1.c.upddef not in row._mapping
|
||||
|
||||
# result is rewound
|
||||
# but has both return_defaults + supplemental_cols
|
||||
eq_(result.all(), [(1, "x", 3)])
|
||||
|
||||
def test_update_sql_expr(self, connection):
|
||||
from sqlalchemy import literal
|
||||
|
||||
@@ -833,6 +959,75 @@ class UpdatedReturnDefaultsTest(fixtures.TablesTest):
|
||||
eq_(dict(result.returned_defaults._mapping), {"upddef": 1})
|
||||
|
||||
|
||||
class DeleteReturnDefaultsTest(fixtures.TablesTest):
|
||||
__requires__ = ("delete_returning",)
|
||||
run_define_tables = "each"
|
||||
__backend__ = True
|
||||
|
||||
define_tables = InsertReturnDefaultsTest.define_tables
|
||||
|
||||
def test_delete(self, connection):
|
||||
t1 = self.tables.t1
|
||||
connection.execute(t1.insert().values(upddef=1))
|
||||
result = connection.execute(t1.delete().return_defaults(t1.c.upddef))
|
||||
eq_(
|
||||
[result.returned_defaults._mapping[k] for k in (t1.c.upddef,)], [1]
|
||||
)
|
||||
|
||||
def test_delete_empty_return_defaults(self, connection):
|
||||
t1 = self.tables.t1
|
||||
connection.execute(t1.insert().values(upddef=5))
|
||||
result = connection.execute(t1.delete().return_defaults())
|
||||
|
||||
# there's no "delete" default, so we get None. we have to
|
||||
# ask for them in all cases
|
||||
eq_(result.returned_defaults, None)
|
||||
|
||||
def test_delete_non_default(self, connection):
|
||||
"""test that a column not marked at all as a
|
||||
default works with this feature."""
|
||||
|
||||
t1 = self.tables.t1
|
||||
connection.execute(t1.insert().values(upddef=1))
|
||||
result = connection.execute(t1.delete().return_defaults(t1.c.data))
|
||||
eq_(
|
||||
[result.returned_defaults._mapping[k] for k in (t1.c.data,)],
|
||||
[None],
|
||||
)
|
||||
|
||||
def test_delete_non_default_plus_default(self, connection):
|
||||
t1 = self.tables.t1
|
||||
connection.execute(t1.insert().values(upddef=1))
|
||||
result = connection.execute(
|
||||
t1.delete().return_defaults(t1.c.data, t1.c.upddef)
|
||||
)
|
||||
eq_(
|
||||
dict(result.returned_defaults._mapping),
|
||||
{"data": None, "upddef": 1},
|
||||
)
|
||||
|
||||
def test_delete_supplemental_cols(self, connection):
|
||||
"""with supplemental_cols, we can get back arbitrary cols."""
|
||||
|
||||
t1 = self.tables.t1
|
||||
connection.execute(t1.insert().values(upddef=1))
|
||||
result = connection.execute(
|
||||
t1.delete().return_defaults(
|
||||
t1.c.id, supplemental_cols=[t1.c.data, t1.c.insdef]
|
||||
)
|
||||
)
|
||||
|
||||
row = result.returned_defaults
|
||||
|
||||
# row has all the cols in it
|
||||
eq_(row, (1, None, 0))
|
||||
eq_(row._mapping[t1.c.insdef], 0)
|
||||
|
||||
# result is rewound
|
||||
# but has both return_defaults + supplemental_cols
|
||||
eq_(result.all(), [(1, None, 0)])
|
||||
|
||||
|
||||
class InsertManyReturnDefaultsTest(fixtures.TablesTest):
|
||||
__requires__ = ("insert_executemany_returning",)
|
||||
run_define_tables = "each"
|
||||
|
||||
@@ -44,6 +44,7 @@ from sqlalchemy.sql import operators
|
||||
from sqlalchemy.sql import table
|
||||
from sqlalchemy.sql import util as sql_util
|
||||
from sqlalchemy.sql import visitors
|
||||
from sqlalchemy.sql.dml import Insert
|
||||
from sqlalchemy.sql.selectable import LABEL_STYLE_NONE
|
||||
from sqlalchemy.testing import assert_raises
|
||||
from sqlalchemy.testing import assert_raises_message
|
||||
@@ -3029,6 +3030,26 @@ class AnnotationsTest(fixtures.TestBase):
|
||||
eq_(whereclause.left._annotations, {"foo": "bar"})
|
||||
eq_(whereclause.right._annotations, {"foo": "bar"})
|
||||
|
||||
@testing.combinations(True, False, None)
|
||||
def test_setup_inherit_cache(self, inherit_cache_value):
|
||||
if inherit_cache_value is None:
|
||||
|
||||
class MyInsertThing(Insert):
|
||||
pass
|
||||
|
||||
else:
|
||||
|
||||
class MyInsertThing(Insert):
|
||||
inherit_cache = inherit_cache_value
|
||||
|
||||
t = table("t", column("x"))
|
||||
anno = MyInsertThing(t)._annotate({"foo": "bar"})
|
||||
|
||||
if inherit_cache_value is not None:
|
||||
is_(type(anno).__dict__["inherit_cache"], inherit_cache_value)
|
||||
else:
|
||||
assert "inherit_cache" not in type(anno).__dict__
|
||||
|
||||
def test_proxy_set_iteration_includes_annotated(self):
|
||||
from sqlalchemy.schema import Column
|
||||
|
||||
|
||||
Reference in New Issue
Block a user