Merge "ORM bulk insert via execute" into main

mike bayer
2022-09-26 01:17:44 +00:00
committed by Gerrit Code Review
45 changed files with 4761 additions and 1023 deletions
+11
@@ -660,6 +660,17 @@ Selecting a Synchronization Strategy
With both the 1.x and 2.0 form of ORM-enabled updates and deletes, the following
values for ``synchronize_session`` are supported:
* ``'auto'`` - this is the default. The ``'fetch'`` strategy will be used on
backends that support RETURNING, which includes all SQLAlchemy-native drivers
except for MySQL. If RETURNING is not supported, the ``'evaluate'``
strategy will be used instead.
.. versionchanged:: 2.0 Added the ``'auto'`` synchronization strategy. As
most backends now support RETURNING, selecting ``'fetch'`` for these
backends is the more efficient and error-free default. The MySQL
backend, as well as third party backends without RETURNING support,
will continue to use ``'evaluate'`` by default.
* ``False`` - don't synchronize the session. This option is the most
efficient and is reliable once the session is expired, which
typically occurs after a commit(), or explicitly using
+1 -1
@@ -301,7 +301,7 @@ Fast Executemany Mode
The SQL Server ``fast_executemany`` parameter may be used at the same time
as ``insertmanyvalues`` is enabled; however, the parameter will not be used
in as many cases, as INSERT statements that are invoked using Core
:class:`.Insert` constructs as well as all ORM use no longer use the
:class:`_dml.Insert` constructs as well as all ORM use no longer use the
``.executemany()`` DBAPI cursor method.
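As an illustration, a minimal sketch of enabling this mode on an engine; the
DSN and credentials here are hypothetical placeholders::

    from sqlalchemy import create_engine

    # fast_executemany applies to the mssql+pyodbc dialect only
    engine = create_engine(
        "mssql+pyodbc://scott:tiger@myhost/mydb?driver=ODBC+Driver+17+for+SQL+Server",
        fast_executemany=True,
    )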
The PyODBC driver includes support for a "fast executemany" mode of execution
@@ -134,9 +134,11 @@ def _upsert(cfg, table, returning, set_lambda=None):
stmt = insert(table)
table_pk = inspect(table).selectable
if set_lambda:
stmt = stmt.on_conflict_do_update(
index_elements=table.primary_key, set_=set_lambda(stmt.excluded)
index_elements=table_pk.primary_key, set_=set_lambda(stmt.excluded)
)
else:
stmt = stmt.on_conflict_do_nothing()
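For context, a hedged sketch of the kind of statement the ``_upsert`` helper
above builds, here using the SQLite dialect's ``insert()`` (table and column
names hypothetical)::

    from sqlalchemy.dialects.sqlite import insert

    stmt = insert(my_table).values(id=1, data="new value")
    # update the existing row on primary key conflict, using the
    # proposed values exposed via stmt.excluded
    stmt = stmt.on_conflict_do_update(
        index_elements=[my_table.c.id],
        set_={"data": stmt.excluded.data},
    )
    conn.execute(stmt)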
-5
@@ -1466,11 +1466,6 @@ class SQLiteCompiler(compiler.SQLCompiler):
return target_text
def visit_insert(self, insert_stmt, **kw):
if insert_stmt._post_values_clause is not None:
kw["disable_implicit_returning"] = True
return super().visit_insert(insert_stmt, **kw)
def visit_on_conflict_do_nothing(self, on_conflict, **kw):
target_text = self._on_conflict_target(on_conflict, **kw)
+336 -54
@@ -23,12 +23,14 @@ from typing import Iterator
from typing import List
from typing import NoReturn
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from .result import IteratorResult
from .result import MergedResult
from .result import Result
from .result import ResultMetaData
@@ -62,36 +64,80 @@ if typing.TYPE_CHECKING:
from .interfaces import ExecutionContext
from .result import _KeyIndexType
from .result import _KeyMapRecType
from .result import _KeyMapType
from .result import _KeyType
from .result import _ProcessorsType
from .result import _TupleGetterType
from ..sql.type_api import _ResultProcessorType
_T = TypeVar("_T", bound=Any)
# metadata entry tuple indexes.
# using raw tuple is faster than namedtuple.
MD_INDEX: Literal[0] = 0 # integer index in cursor.description
MD_RESULT_MAP_INDEX: Literal[
1
] = 1 # integer index in compiled._result_columns
MD_OBJECTS: Literal[
2
] = 2 # other string keys and ColumnElement obj that can match
MD_LOOKUP_KEY: Literal[
3
] = 3 # string key we usually expect for key-based lookup
MD_RENDERED_NAME: Literal[4] = 4 # name that is usually in cursor.description
MD_PROCESSOR: Literal[5] = 5 # callable to process a result value into a row
MD_UNTRANSLATED: Literal[6] = 6 # raw name from cursor.description
# these match up to the positions in
# _CursorKeyMapRecType
MD_INDEX: Literal[0] = 0
"""integer index in cursor.description
"""
MD_RESULT_MAP_INDEX: Literal[1] = 1
"""integer index in compiled._result_columns"""
MD_OBJECTS: Literal[2] = 2
"""other string keys and ColumnElement obj that can match.
This comes from compiler.RM_OBJECTS / compiler.ResultColumnsEntry.objects
"""
MD_LOOKUP_KEY: Literal[3] = 3
"""string key we usually expect for key-based lookup
this comes from compiler.RM_NAME / compiler.ResultColumnsEntry.name
"""
MD_RENDERED_NAME: Literal[4] = 4
"""name that is usually in cursor.description
this comes from compiler.RENDERED_NAME / compiler.ResultColumnsEntry.keyname
"""
MD_PROCESSOR: Literal[5] = 5
"""callable to process a result value into a row"""
MD_UNTRANSLATED: Literal[6] = 6
"""raw name from cursor.description"""
_CursorKeyMapRecType = Tuple[
int, int, List[Any], str, str, Optional["_ResultProcessorType"], str
Optional[int], # MD_INDEX, None means the record is ambiguously named
int, # MD_RESULT_MAP_INDEX
List[Any], # MD_OBJECTS
str, # MD_LOOKUP_KEY
str, # MD_RENDERED_NAME
Optional["_ResultProcessorType"], # MD_PROCESSOR
Optional[str], # MD_UNTRANSLATED
]
_CursorKeyMapType = Dict["_KeyType", _CursorKeyMapRecType]
# same as _CursorKeyMapRecType except the MD_INDEX value is definitely
# not None
_NonAmbigCursorKeyMapRecType = Tuple[
int,
int,
List[Any],
str,
str,
Optional["_ResultProcessorType"],
str,
]
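The comment above notes that raw tuples index faster than namedtuples; a
small standalone sketch of the same pattern (names invented here, not the
library's)::

    from typing import Literal, Tuple

    # module-level Literal constants name the tuple positions while
    # keeping lookups as plain, fast tuple indexing
    REC_NAME: Literal[0] = 0
    REC_VALUE: Literal[1] = 1

    RecType = Tuple[str, int]

    rec: RecType = ("answer", 42)
    assert rec[REC_NAME] == "answer"
    assert rec[REC_VALUE] == 42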
class CursorResultMetaData(ResultMetaData):
"""Result metadata for DBAPI cursors."""
@@ -127,38 +173,112 @@ class CursorResultMetaData(ResultMetaData):
extra=[self._keymap[key][MD_OBJECTS] for key in self._keys],
)
def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData:
recs = cast(
"List[_CursorKeyMapRecType]", list(self._metadata_for_keys(keys))
def _make_new_metadata(
self,
*,
unpickled: bool,
processors: _ProcessorsType,
keys: Sequence[str],
keymap: _KeyMapType,
tuplefilter: Optional[_TupleGetterType],
translated_indexes: Optional[List[int]],
safe_for_cache: bool,
keymap_by_result_column_idx: Any,
) -> CursorResultMetaData:
new_obj = self.__class__.__new__(self.__class__)
new_obj._unpickled = unpickled
new_obj._processors = processors
new_obj._keys = keys
new_obj._keymap = keymap
new_obj._tuplefilter = tuplefilter
new_obj._translated_indexes = translated_indexes
new_obj._safe_for_cache = safe_for_cache
new_obj._keymap_by_result_column_idx = keymap_by_result_column_idx
return new_obj
def _remove_processors(self) -> CursorResultMetaData:
assert not self._tuplefilter
return self._make_new_metadata(
unpickled=self._unpickled,
processors=[None] * len(self._processors),
tuplefilter=None,
translated_indexes=None,
keymap={
key: value[0:5] + (None,) + value[6:]
for key, value in self._keymap.items()
},
keys=self._keys,
safe_for_cache=self._safe_for_cache,
keymap_by_result_column_idx=self._keymap_by_result_column_idx,
)
def _splice_horizontally(
self, other: CursorResultMetaData
) -> CursorResultMetaData:
assert not self._tuplefilter
keymap = self._keymap.copy()
offset = len(self._keys)
keymap.update(
{
key: (
# int index should be None for ambiguous key
value[0] + offset
if value[0] is not None and key not in keymap
else None,
value[1] + offset,
*value[2:],
)
for key, value in other._keymap.items()
}
)
return self._make_new_metadata(
unpickled=self._unpickled,
processors=self._processors + other._processors, # type: ignore
tuplefilter=None,
translated_indexes=None,
keys=self._keys + other._keys, # type: ignore
keymap=keymap,
safe_for_cache=self._safe_for_cache,
keymap_by_result_column_idx={
metadata_entry[MD_RESULT_MAP_INDEX]: metadata_entry
for metadata_entry in keymap.values()
},
)
def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData:
recs = list(self._metadata_for_keys(keys))
indexes = [rec[MD_INDEX] for rec in recs]
new_keys: List[str] = [rec[MD_LOOKUP_KEY] for rec in recs]
if self._translated_indexes:
indexes = [self._translated_indexes[idx] for idx in indexes]
tup = tuplegetter(*indexes)
new_metadata = self.__class__.__new__(self.__class__)
new_metadata._unpickled = self._unpickled
new_metadata._processors = self._processors
new_metadata._keys = new_keys
new_metadata._tuplefilter = tup
new_metadata._translated_indexes = indexes
new_recs = [(index,) + rec[1:] for index, rec in enumerate(recs)]
new_metadata._keymap = {rec[MD_LOOKUP_KEY]: rec for rec in new_recs}
keymap: _KeyMapType = {rec[MD_LOOKUP_KEY]: rec for rec in new_recs}
# TODO: need unit test for:
# result = connection.execute("raw sql, no columns").scalars()
# without the "or ()" it's failing because MD_OBJECTS is None
new_metadata._keymap.update(
keymap.update(
(e, new_rec)
for new_rec in new_recs
for e in new_rec[MD_OBJECTS] or ()
)
return new_metadata
return self._make_new_metadata(
unpickled=self._unpickled,
processors=self._processors,
keys=new_keys,
tuplefilter=tup,
translated_indexes=indexes,
keymap=keymap,
safe_for_cache=self._safe_for_cache,
keymap_by_result_column_idx=self._keymap_by_result_column_idx,
)
def _adapt_to_context(self, context: ExecutionContext) -> ResultMetaData:
"""When using a cached Compiled construct that has a _result_map,
@@ -168,6 +288,7 @@ class CursorResultMetaData(ResultMetaData):
as matched to those of the cached statement.
"""
if not context.compiled or not context.compiled._result_columns:
return self
@@ -189,7 +310,6 @@ class CursorResultMetaData(ResultMetaData):
# make a copy and add the columns from the invoked statement
# to the result map.
md = self.__class__.__new__(self.__class__)
keymap_by_position = self._keymap_by_result_column_idx
@@ -201,26 +321,26 @@ class CursorResultMetaData(ResultMetaData):
for metadata_entry in self._keymap.values()
}
md._keymap = compat.dict_union(
self._keymap,
{
new: keymap_by_position[idx]
for idx, new in enumerate(
invoked_statement._all_selected_columns
)
if idx in keymap_by_position
},
)
md._unpickled = self._unpickled
md._processors = self._processors
assert not self._tuplefilter
md._tuplefilter = None
md._translated_indexes = None
md._keys = self._keys
md._keymap_by_result_column_idx = self._keymap_by_result_column_idx
md._safe_for_cache = self._safe_for_cache
return md
return self._make_new_metadata(
keymap=compat.dict_union(
self._keymap,
{
new: keymap_by_position[idx]
for idx, new in enumerate(
invoked_statement._all_selected_columns
)
if idx in keymap_by_position
},
),
unpickled=self._unpickled,
processors=self._processors,
tuplefilter=None,
translated_indexes=None,
keys=self._keys,
safe_for_cache=self._safe_for_cache,
keymap_by_result_column_idx=self._keymap_by_result_column_idx,
)
def __init__(
self,
@@ -683,7 +803,27 @@ class CursorResultMetaData(ResultMetaData):
untranslated,
)
def _key_fallback(self, key, err, raiseerr=True):
@overload
def _key_fallback(
self, key: Any, err: Exception, raiseerr: Literal[True] = ...
) -> NoReturn:
...
@overload
def _key_fallback(
self, key: Any, err: Exception, raiseerr: Literal[False] = ...
) -> None:
...
@overload
def _key_fallback(
self, key: Any, err: Exception, raiseerr: bool = ...
) -> Optional[NoReturn]:
...
def _key_fallback(
self, key: Any, err: Exception, raiseerr: bool = True
) -> Optional[NoReturn]:
if raiseerr:
if self._unpickled and isinstance(key, elements.ColumnElement):
@@ -714,9 +854,9 @@ class CursorResultMetaData(ResultMetaData):
try:
rec = self._keymap[key]
except KeyError as ke:
rec = self._key_fallback(key, ke, raiseerr)
if rec is None:
return None
x = self._key_fallback(key, ke, raiseerr)
assert x is None
return None
index = rec[0]
@@ -734,7 +874,7 @@ class CursorResultMetaData(ResultMetaData):
def _metadata_for_keys(
self, keys: Sequence[Any]
) -> Iterator[_CursorKeyMapRecType]:
) -> Iterator[_NonAmbigCursorKeyMapRecType]:
for key in keys:
if int in key.__class__.__mro__:
key = self._keys[key]
@@ -750,7 +890,7 @@ class CursorResultMetaData(ResultMetaData):
if index is None:
self._raise_for_ambiguous_column_name(rec)
yield rec
yield cast(_NonAmbigCursorKeyMapRecType, rec)
def __getstate__(self):
return {
@@ -1237,6 +1377,12 @@ _NO_RESULT_METADATA = _NoResultMetaData()
SelfCursorResult = TypeVar("SelfCursorResult", bound="CursorResult[Any]")
def null_dml_result() -> IteratorResult[Any]:
it: IteratorResult[Any] = IteratorResult(_NoResultMetaData(), iter([]))
it._soft_close()
return it
class CursorResult(Result[_T]):
"""A Result that is representing state from a DBAPI cursor.
@@ -1586,6 +1732,142 @@ class CursorResult(Result[_T]):
"""
return self.context.returned_default_rows
def splice_horizontally(self, other):
"""Return a new :class:`.CursorResult` that "horizontally splices"
together the rows of this :class:`.CursorResult` with that of another
:class:`.CursorResult`.
.. tip:: This method is for the benefit of the SQLAlchemy ORM and is
not intended for general use.
"horizontally splices" means that for each row in the first and second
result sets, a new row that concatenates the two rows together is
produced. The incoming :class:`.CursorResult` must have the identical
number of rows. The two result sets are typically expected to be in the
same sort order as well, as the result rows are spliced together based
on their position in the result.
The expected use case here is so that multiple INSERT..RETURNING
statements against different tables can produce a single result
that looks like a JOIN of those two tables.
E.g.::
r1 = connection.execute(
users.insert().returning(users.c.user_name, users.c.user_id),
user_values
)
r2 = connection.execute(
addresses.insert().returning(
addresses.c.address_id,
addresses.c.address,
addresses.c.user_id,
),
address_values
)
rows = r1.splice_horizontally(r2).all()
assert (
rows ==
[
("john", 1, 1, "foo@bar.com", 1),
("jack", 2, 2, "bar@bat.com", 2),
]
)
.. versionadded:: 2.0
.. seealso::
:meth:`.CursorResult.splice_vertically`
"""
clone = self._generate()
total_rows = [
tuple(r1) + tuple(r2)
for r1, r2 in zip(
list(self._raw_row_iterator()),
list(other._raw_row_iterator()),
)
]
clone._metadata = clone._metadata._splice_horizontally(other._metadata)
clone.cursor_strategy = FullyBufferedCursorFetchStrategy(
None,
initial_buffer=total_rows,
)
clone._reset_memoizations()
return clone
def splice_vertically(self, other):
"""Return a new :class:`.CursorResult` that "vertically splices",
i.e. "extends", the rows of this :class:`.CursorResult` with that of
another :class:`.CursorResult`.
.. tip:: This method is for the benefit of the SQLAlchemy ORM and is
not intended for general use.
"vertically splices" means the rows of the given result are appended to
the rows of this cursor result. The incoming :class:`.CursorResult`
must have rows that represent the identical list of columns in the
identical order as they are in this :class:`.CursorResult`.
.. versionadded:: 2.0
.. seealso::
:meth:`.CursorResult.splice_horizontally`
"""
clone = self._generate()
total_rows = list(self._raw_row_iterator()) + list(
other._raw_row_iterator()
)
clone.cursor_strategy = FullyBufferedCursorFetchStrategy(
None,
initial_buffer=total_rows,
)
clone._reset_memoizations()
return clone
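By way of illustration, a hedged sketch of usage mirroring the horizontal
example above (table and variable names hypothetical)::

    r1 = connection.execute(
        users.insert().returning(users.c.user_id, users.c.user_name),
        first_batch,
    )
    r2 = connection.execute(
        users.insert().returning(users.c.user_id, users.c.user_name),
        second_batch,
    )
    # rows of r2 are appended after the rows of r1; both results
    # must have the same columns in the same order
    all_rows = r1.splice_vertically(r2).all()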
def _rewind(self, rows):
"""rewind this result back to the given rowset.
this is used internally for the case where an :class:`.Insert`
construct combines the use of
:meth:`.Insert.return_defaults` along with the
"supplemental columns" feature.
"""
if self._echo:
self.context.connection._log_debug(
"CursorResult rewound %d row(s)", len(rows)
)
# the rows given are expected to be Row objects, so we
# have to clear out processors which have already run on these
# rows
self._metadata = cast(
CursorResultMetaData, self._metadata
)._remove_processors()
self.cursor_strategy = FullyBufferedCursorFetchStrategy(
None,
# TODO: if these are Row objects, can we save on not having to
# re-make new Row objects out of them a second time? is that
# what's actually happening right now? maybe look into this
initial_buffer=rows,
)
self._reset_memoizations()
return self
@property
def returned_defaults(self):
"""Return the values of default columns that were fetched using
+28 -23
@@ -1007,6 +1007,7 @@ class DefaultExecutionContext(ExecutionContext):
_is_implicit_returning = False
_is_explicit_returning = False
_is_supplemental_returning = False
_is_server_side = False
_soft_closed = False
@@ -1125,18 +1126,19 @@ class DefaultExecutionContext(ExecutionContext):
self.is_text = compiled.isplaintext
if ii or iu or id_:
dml_statement = compiled.compile_state.statement # type: ignore
if TYPE_CHECKING:
assert isinstance(compiled.statement, UpdateBase)
assert isinstance(dml_statement, UpdateBase)
self.is_crud = True
self._is_explicit_returning = ier = bool(
compiled.statement._returning
)
self._is_implicit_returning = iir = is_implicit_returning = bool(
self._is_explicit_returning = ier = bool(dml_statement._returning)
self._is_implicit_returning = iir = bool(
compiled.implicit_returning
)
assert not (
is_implicit_returning and compiled.statement._returning
)
if iir and dml_statement._supplemental_returning:
self._is_supplemental_returning = True
# dont mix implicit and explicit returning
assert not (iir and ier)
if (ier or iir) and compiled.for_executemany:
if ii and not self.dialect.insert_executemany_returning:
@@ -1711,7 +1713,14 @@ class DefaultExecutionContext(ExecutionContext):
# are that the result has only one row, until executemany()
# support is added here.
assert result._metadata.returns_rows
result._soft_close()
# Insert statement has both return_defaults() and
# returning(). rewind the result on the list of rows
# we just used.
if self._is_supplemental_returning:
result._rewind(rows)
else:
result._soft_close()
elif not self._is_explicit_returning:
result._soft_close()
@@ -1721,21 +1730,18 @@ class DefaultExecutionContext(ExecutionContext):
# function so this is not necessarily true.
# assert not result.returns_rows
elif self.isupdate and self._is_implicit_returning:
# get rowcount
# (which requires open cursor on some drivers)
# we were not doing this in 1.4, however
# test_rowcount -> test_update_rowcount_return_defaults
# is testing this, and psycopg will no longer return
# rowcount after cursor is closed.
result.rowcount
elif self._is_implicit_returning:
rows = result.all()
if rows:
self.returned_default_rows = rows
result.rowcount = len(rows)
self._has_rowcount = True
row = result.fetchone()
if row is not None:
self.returned_default_rows = [row]
result._soft_close()
if self._is_supplemental_returning:
result._rewind(rows)
else:
result._soft_close()
# test that it has a cursor metadata that is accurate.
# the rows have all been fetched however.
@@ -1750,7 +1756,6 @@ class DefaultExecutionContext(ExecutionContext):
elif self.isupdate or self.isdelete:
result.rowcount
self._has_rowcount = True
return result
@util.memoized_property
+21 -1
@@ -109,9 +109,27 @@ class ResultMetaData:
def _for_freeze(self) -> ResultMetaData:
raise NotImplementedError()
@overload
def _key_fallback(
self, key: _KeyType, err: Exception, raiseerr: bool = True
self, key: Any, err: Exception, raiseerr: Literal[True] = ...
) -> NoReturn:
...
@overload
def _key_fallback(
self, key: Any, err: Exception, raiseerr: Literal[False] = ...
) -> None:
...
@overload
def _key_fallback(
self, key: Any, err: Exception, raiseerr: bool = ...
) -> Optional[NoReturn]:
...
def _key_fallback(
self, key: Any, err: Exception, raiseerr: bool = True
) -> Optional[NoReturn]:
assert raiseerr
raise KeyError(key) from err
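The overload set above encodes the ``raiseerr`` flag into the return type; a
self-contained sketch of the same typing pattern (function name invented)::

    from __future__ import annotations

    from typing import Literal, NoReturn, Optional, overload

    @overload
    def lookup(key: str, raiseerr: Literal[True] = ...) -> NoReturn:
        ...

    @overload
    def lookup(key: str, raiseerr: Literal[False] = ...) -> None:
        ...

    @overload
    def lookup(key: str, raiseerr: bool = ...) -> Optional[NoReturn]:
        ...

    def lookup(key: str, raiseerr: bool = True) -> Optional[NoReturn]:
        # with raiseerr=True the type checker knows this never returns
        if raiseerr:
            raise KeyError(key)
        return None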
@@ -2148,6 +2166,7 @@ class IteratorResult(Result[_TP]):
"""
_hard_closed = False
_soft_closed = False
def __init__(
self,
@@ -2168,6 +2187,7 @@ class IteratorResult(Result[_TP]):
self.raw._soft_close(hard=hard, **kw)
self.iterator = iter([])
self._reset_memoizations()
self._soft_closed = True
def _raise_hard_closed(self) -> NoReturn:
raise exc.ResourceClosedError("This result object is closed.")
File diff suppressed because it is too large
+141 -32
@@ -73,6 +73,7 @@ if TYPE_CHECKING:
from .query import Query
from .session import _BindArguments
from .session import Session
from ..engine import Result
from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.interfaces import _ExecuteOptionsParameter
from ..sql._typing import _ColumnsClauseArgument
@@ -203,15 +204,19 @@ _orm_load_exec_options = util.immutabledict(
class AbstractORMCompileState(CompileState):
is_dml_returning = False
@classmethod
def create_for_statement(
cls,
statement: Union[Select, FromStatement],
compiler: Optional[SQLCompiler],
**kw: Any,
) -> ORMCompileState:
) -> AbstractORMCompileState:
"""Create a context for a statement given a :class:`.Compiler`.
This method is always invoked in the context of SQLCompiler.process().
For a Select object, this would be invoked from
SQLCompiler.visit_select(). For the special FromStatement object used
by Query to indicate "Query.from_statement()", this is called by
@@ -232,6 +237,28 @@ class AbstractORMCompileState(CompileState):
):
raise NotImplementedError()
@classmethod
def orm_execute_statement(
cls,
session,
statement,
params,
execution_options,
bind_arguments,
conn,
) -> Result:
result = conn.execute(
statement, params or {}, execution_options=execution_options
)
return cls.orm_setup_cursor_result(
session,
statement,
params,
execution_options,
bind_arguments,
result,
)
@classmethod
def orm_setup_cursor_result(
cls,
@@ -309,6 +336,17 @@ class ORMCompileState(AbstractORMCompileState):
def __init__(self, *arg, **kw):
raise NotImplementedError()
if TYPE_CHECKING:
@classmethod
def create_for_statement(
cls,
statement: Union[Select, FromStatement],
compiler: Optional[SQLCompiler],
**kw: Any,
) -> ORMCompileState:
...
def _append_dedupe_col_collection(self, obj, col_collection):
dedupe = self.dedupe_columns
if obj not in dedupe:
@@ -332,26 +370,6 @@ class ORMCompileState(AbstractORMCompileState):
else:
return SelectState._column_naming_convention(label_style)
@classmethod
def create_for_statement(
cls,
statement: Union[Select, FromStatement],
compiler: Optional[SQLCompiler],
**kw: Any,
) -> ORMCompileState:
"""Create a context for a statement given a :class:`.Compiler`.
This method is always invoked in the context of SQLCompiler.process().
For a Select object, this would be invoked from
SQLCompiler.visit_select(). For the special FromStatement object used
by Query to indicate "Query.from_statement()", this is called by
FromStatement._compiler_dispatch() that would be called by
SQLCompiler.process().
"""
raise NotImplementedError()
@classmethod
def get_column_descriptions(cls, statement):
return _column_descriptions(statement)
@@ -518,6 +536,49 @@ class ORMCompileState(AbstractORMCompileState):
)
class DMLReturningColFilter:
"""an adapter used for the DML RETURNING case.
Has a subset of the interface used by
:class:`.ORMAdapter` and is used for :class:`._QueryEntity`
instances to set up their columns as used in RETURNING for a
DML statement.
"""
__slots__ = ("mapper", "columns", "__weakref__")
def __init__(self, target_mapper, immediate_dml_mapper):
if (
immediate_dml_mapper is not None
and target_mapper.local_table
is not immediate_dml_mapper.local_table
):
# joined inh, or in theory other kinds of multi-table mappings
self.mapper = immediate_dml_mapper
else:
# single inh, normal mappings, etc.
self.mapper = target_mapper
self.columns = util.WeakPopulateDict(
self.adapt_check_present # type: ignore
)
def __call__(self, col, as_filter):
for cc in sql_util._find_columns(col):
c2 = self.adapt_check_present(cc)
if c2 is not None:
return col
else:
return None
def adapt_check_present(self, col):
mapper = self.mapper
prop = mapper._columntoproperty.get(col, None)
if prop is None:
return None
return mapper.local_table.c.corresponding_column(col)
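``adapt_check_present`` relies on ``corresponding_column`` to map an
arbitrary column back to the mapper's local table; a minimal standalone
illustration of that Core behavior (tables invented here)::

    from sqlalchemy import Column, Integer, MetaData, Table

    m = MetaData()
    t = Table("t", m, Column("x", Integer), Column("y", Integer))
    a = t.alias()

    # the alias can trace t.c.x back to its own counterpart column
    assert a.corresponding_column(t.c.x) is a.c.x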
@sql.base.CompileState.plugin_for("orm", "orm_from_statement")
class ORMFromStatementCompileState(ORMCompileState):
_from_obj_alias = None
@@ -525,7 +586,7 @@ class ORMFromStatementCompileState(ORMCompileState):
statement_container: FromStatement
requested_statement: Union[SelectBase, TextClause, UpdateBase]
dml_table: _DMLTableElement
dml_table: Optional[_DMLTableElement] = None
_has_orm_entities = False
multi_row_eager_loaders = False
@@ -541,7 +602,7 @@ class ORMFromStatementCompileState(ORMCompileState):
statement_container: Union[Select, FromStatement],
compiler: Optional[SQLCompiler],
**kw: Any,
) -> ORMCompileState:
) -> ORMFromStatementCompileState:
if compiler is not None:
toplevel = not compiler.stack
@@ -565,6 +626,7 @@ class ORMFromStatementCompileState(ORMCompileState):
if statement.is_dml:
self.dml_table = statement.table
self.is_dml_returning = True
self._entities = []
self._polymorphic_adapters = {}
@@ -674,6 +736,18 @@ class ORMFromStatementCompileState(ORMCompileState):
def _get_current_adapter(self):
return None
def setup_dml_returning_compile_state(self, dml_mapper):
"""used by BulkORMInsert (and Update / Delete?) to set up a handler
for RETURNING to return ORM objects and expressions
"""
target_mapper = self.statement._propagate_attrs.get(
"plugin_subject", None
)
adapter = DMLReturningColFilter(target_mapper, dml_mapper)
for entity in self._entities:
entity.setup_dml_returning_compile_state(self, adapter)
class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]):
"""Core construct that represents a load of ORM objects from various
@@ -813,7 +887,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
statement: Union[Select, FromStatement],
compiler: Optional[SQLCompiler],
**kw: Any,
) -> ORMCompileState:
) -> ORMSelectCompileState:
"""compiler hook, we arrive here from compiler.visit_select() only."""
self = cls.__new__(cls)
@@ -2312,6 +2386,13 @@ class _QueryEntity:
def setup_compile_state(self, compile_state: ORMCompileState) -> None:
raise NotImplementedError()
def setup_dml_returning_compile_state(
self,
compile_state: ORMCompileState,
adapter: DMLReturningColFilter,
) -> None:
raise NotImplementedError()
def row_processor(self, context, result):
raise NotImplementedError()
@@ -2509,8 +2590,24 @@ class _MapperEntity(_QueryEntity):
return _instance, self._label_name, self._extra_entities
def setup_compile_state(self, compile_state):
def setup_dml_returning_compile_state(
self,
compile_state: ORMCompileState,
adapter: DMLReturningColFilter,
) -> None:
loading._setup_entity_query(
compile_state,
self.mapper,
self,
self.path,
adapter,
compile_state.primary_columns,
with_polymorphic=self._with_polymorphic_mappers,
only_load_props=compile_state.compile_options._only_load_props,
polymorphic_discriminator=self._polymorphic_discriminator,
)
def setup_compile_state(self, compile_state):
adapter = self._get_entity_clauses(compile_state)
single_table_crit = self.mapper._single_table_criterion
@@ -2536,7 +2633,6 @@ class _MapperEntity(_QueryEntity):
only_load_props=compile_state.compile_options._only_load_props,
polymorphic_discriminator=self._polymorphic_discriminator,
)
compile_state._fallback_from_clauses.append(self.selectable)
@@ -2743,9 +2839,7 @@ class _ColumnEntity(_QueryEntity):
getter, label_name, extra_entities = self._row_processor
if self.translate_raw_column:
extra_entities += (
result.context.invoked_statement._raw_columns[
self.raw_column_index
],
context.query._raw_columns[self.raw_column_index],
)
return getter, label_name, extra_entities
@@ -2781,9 +2875,7 @@ class _ColumnEntity(_QueryEntity):
if self.translate_raw_column:
extra_entities = self._extra_entities + (
result.context.invoked_statement._raw_columns[
self.raw_column_index
],
context.query._raw_columns[self.raw_column_index],
)
return getter, self._label_name, extra_entities
else:
@@ -2843,6 +2935,8 @@ class _RawColumnEntity(_ColumnEntity):
current_adapter = compile_state._get_current_adapter()
if current_adapter:
column = current_adapter(self.column, False)
if column is None:
return
else:
column = self.column
@@ -2944,10 +3038,25 @@ class _ORMColumnEntity(_ColumnEntity):
self.entity_zero
) and entity.common_parent(self.entity_zero)
def setup_dml_returning_compile_state(
self,
compile_state: ORMCompileState,
adapter: DMLReturningColFilter,
) -> None:
self._fetch_column = self.column
column = adapter(self.column, False)
if column is not None:
compile_state.dedupe_columns.add(column)
compile_state.primary_columns.append(column)
def setup_compile_state(self, compile_state):
current_adapter = compile_state._get_current_adapter()
if current_adapter:
column = current_adapter(self.column, False)
if column is None:
assert compile_state.is_dml_returning
self._fetch_column = self.column
return
else:
column = self.column
+26
@@ -19,6 +19,7 @@ import operator
import typing
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import NoReturn
from typing import Optional
@@ -602,6 +603,31 @@ class Composite(
def _attribute_keys(self) -> Sequence[str]:
return [prop.key for prop in self.props]
def _populate_composite_bulk_save_mappings_fn(
self,
) -> Callable[[Dict[str, Any]], None]:
if self._generated_composite_accessor:
get_values = self._generated_composite_accessor
else:
def get_values(val: Any) -> Tuple[Any]:
return val.__composite_values__() # type: ignore
attrs = [prop.key for prop in self.props]
def populate(dest_dict: Dict[str, Any]) -> None:
dest_dict.update(
{
key: val
for key, val in zip(
attrs, get_values(dest_dict.pop(self.key))
)
}
)
return populate
def get_history(
self,
state: InstanceState[Any],
+62 -14
@@ -9,8 +9,8 @@
from __future__ import annotations
import operator
from .base import LoaderCallableStatus
from .base import PassiveFlag
from .. import exc
from .. import inspect
from .. import util
@@ -32,7 +32,16 @@ class _NoObject(operators.ColumnOperators):
return None
class _ExpiredObject(operators.ColumnOperators):
def operate(self, *arg, **kw):
return self
def reverse_operate(self, *arg, **kw):
return self
_NO_OBJECT = _NoObject()
_EXPIRED_OBJECT = _ExpiredObject()
class EvaluatorCompiler:
@@ -73,6 +82,24 @@ class EvaluatorCompiler:
f"alternate class {parentmapper.class_}"
)
key = parentmapper._columntoproperty[clause].key
impl = parentmapper.class_manager[key].impl
if impl is not None:
def get_corresponding_attr(obj):
if obj is None:
return _NO_OBJECT
state = inspect(obj)
dict_ = state.dict
value = impl.get(
state, dict_, passive=PassiveFlag.PASSIVE_NO_FETCH
)
if value is LoaderCallableStatus.PASSIVE_NO_RESULT:
return _EXPIRED_OBJECT
return value
return get_corresponding_attr
else:
key = clause.key
if (
@@ -85,15 +112,16 @@ class EvaluatorCompiler:
"make use of the actual mapped columns in ORM-evaluated "
"UPDATE / DELETE expressions."
)
else:
raise UnevaluatableError(f"Cannot evaluate column: {clause}")
get_corresponding_attr = operator.attrgetter(key)
return (
lambda obj: get_corresponding_attr(obj)
if obj is not None
else _NO_OBJECT
)
def get_corresponding_attr(obj):
if obj is None:
return _NO_OBJECT
return getattr(obj, key, _EXPIRED_OBJECT)
return get_corresponding_attr
def visit_tuple(self, clause):
return self.visit_clauselist(clause)
@@ -134,7 +162,9 @@ class EvaluatorCompiler:
has_null = False
for sub_evaluate in evaluators:
value = sub_evaluate(obj)
if value:
if value is _EXPIRED_OBJECT:
return _EXPIRED_OBJECT
elif value:
return True
has_null = has_null or value is None
if has_null:
@@ -147,6 +177,9 @@ class EvaluatorCompiler:
def evaluate(obj):
for sub_evaluate in evaluators:
value = sub_evaluate(obj)
if value is _EXPIRED_OBJECT:
return _EXPIRED_OBJECT
if not value:
if value is None or value is _NO_OBJECT:
return None
@@ -160,7 +193,9 @@ class EvaluatorCompiler:
values = []
for sub_evaluate in evaluators:
value = sub_evaluate(obj)
if value is None or value is _NO_OBJECT:
if value is _EXPIRED_OBJECT:
return _EXPIRED_OBJECT
elif value is None or value is _NO_OBJECT:
return None
values.append(value)
return tuple(values)
@@ -183,13 +218,21 @@ class EvaluatorCompiler:
def visit_is_binary_op(self, operator, eval_left, eval_right, clause):
def evaluate(obj):
return eval_left(obj) == eval_right(obj)
left_val = eval_left(obj)
right_val = eval_right(obj)
if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
return _EXPIRED_OBJECT
return left_val == right_val
return evaluate
def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause):
def evaluate(obj):
return eval_left(obj) != eval_right(obj)
left_val = eval_left(obj)
right_val = eval_right(obj)
if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
return _EXPIRED_OBJECT
return left_val != right_val
return evaluate
@@ -197,8 +240,11 @@ class EvaluatorCompiler:
def evaluate(obj):
left_val = eval_left(obj)
right_val = eval_right(obj)
if left_val is None or right_val is None:
if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
return _EXPIRED_OBJECT
elif left_val is None or right_val is None:
return None
return operator(eval_left(obj), eval_right(obj))
return evaluate
@@ -274,7 +320,9 @@ class EvaluatorCompiler:
def evaluate(obj):
value = eval_inner(obj)
if value is None:
if value is _EXPIRED_OBJECT:
return _EXPIRED_OBJECT
elif value is None:
return None
return not value
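All of these evaluators thread the ``_EXPIRED_OBJECT`` sentinel through
composed closures; a tiny generic sketch of that short-circuiting pattern
(standalone illustration, not the library's code)::

    _EXPIRED = object()

    def lift_not(inner):
        def evaluate(obj):
            value = inner(obj)
            if value is _EXPIRED:
                return _EXPIRED  # propagate the sentinel unchanged
            elif value is None:
                return None
            return not value
        return evaluate

    get_flag = lift_not(lambda row: row.get("flag", _EXPIRED))
    assert get_flag({"flag": True}) is False
    assert get_flag({}) is _EXPIRED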
+10
@@ -68,6 +68,11 @@ class IdentityMap:
) -> Optional[_O]:
raise NotImplementedError()
def fast_get_state(
self, key: _IdentityKeyType[_O]
) -> Optional[InstanceState[_O]]:
raise NotImplementedError()
def keys(self) -> Iterable[_IdentityKeyType[Any]]:
return self._dict.keys()
@@ -206,6 +211,11 @@ class WeakInstanceDict(IdentityMap):
self._dict[key] = state
state._instance_dict = self._wr
def fast_get_state(
self, key: _IdentityKeyType[_O]
) -> Optional[InstanceState[_O]]:
return self._dict.get(key)
def get(
self, key: _IdentityKeyType[_O], default: Optional[_O] = None
) -> Optional[_O]:
+3 -2
@@ -29,7 +29,6 @@ from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from sqlalchemy.orm.context import FromStatement
from . import attributes
from . import exc as orm_exc
from . import path_registry
@@ -37,6 +36,7 @@ from .base import _DEFER_FOR_STATE
from .base import _RAISE_FOR_STATE
from .base import _SET_DEFERRED_EXPIRED
from .base import PassiveFlag
from .context import FromStatement
from .util import _none_set
from .util import state_str
from .. import exc as sa_exc
@@ -50,6 +50,7 @@ from ..sql import util as sql_util
from ..sql.selectable import ForUpdateArg
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from ..sql.selectable import SelectState
from ..util import EMPTY_DICT
if TYPE_CHECKING:
from ._typing import _IdentityKeyType
@@ -764,7 +765,7 @@ def _instance_processor(
)
quick_populators = path.get(
context.attributes, "memoized_setups", _none_set
context.attributes, "memoized_setups", EMPTY_DICT
)
todo = []
+15
@@ -854,6 +854,7 @@ class Mapper(
_memoized_values: Dict[Any, Callable[[], Any]]
_inheriting_mappers: util.WeakSequence[Mapper[Any]]
_all_tables: Set[Table]
_polymorphic_attr_key: Optional[str]
_pks_by_table: Dict[FromClause, OrderedSet[ColumnClause[Any]]]
_cols_by_table: Dict[FromClause, OrderedSet[ColumnElement[Any]]]
@@ -1653,6 +1654,7 @@ class Mapper(
"""
setter = False
polymorphic_key: Optional[str] = None
if self.polymorphic_on is not None:
setter = True
@@ -1772,17 +1774,23 @@ class Mapper(
self._set_polymorphic_identity = (
mapper._set_polymorphic_identity
)
self._polymorphic_attr_key = (
mapper._polymorphic_attr_key
)
self._validate_polymorphic_identity = (
mapper._validate_polymorphic_identity
)
else:
self._set_polymorphic_identity = None
self._polymorphic_attr_key = None
return
if setter:
def _set_polymorphic_identity(state):
dict_ = state.dict
# TODO: what happens if polymorphic_on column attribute name
# does not match .key?
state.get_impl(polymorphic_key).set(
state,
dict_,
@@ -1790,6 +1798,8 @@ class Mapper(
None,
)
self._polymorphic_attr_key = polymorphic_key
def _validate_polymorphic_identity(mapper, state, dict_):
if (
polymorphic_key in dict_
@@ -1808,6 +1818,7 @@ class Mapper(
_validate_polymorphic_identity
)
else:
self._polymorphic_attr_key = None
self._set_polymorphic_identity = None
_validate_polymorphic_identity = None
@@ -3561,6 +3572,10 @@ class Mapper(
def _compiled_cache(self):
return util.LRUCache(self._compiled_cache_size)
@HasMemoized.memoized_attribute
def _multiple_persistence_tables(self):
return len(self.tables) > 1
@HasMemoized.memoized_attribute
def _sorted_tables(self):
table_to_mapper: Dict[Table, Mapper[Any]] = {}
+114 -26
@@ -31,6 +31,7 @@ from .. import exc as sa_exc
from .. import future
from .. import sql
from .. import util
from ..engine import cursor as _cursor
from ..sql import operators
from ..sql.elements import BooleanClauseList
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
@@ -398,6 +399,11 @@ def _collect_insert_commands(
None
)
if bulk and mapper._set_polymorphic_identity:
params.setdefault(
mapper._polymorphic_attr_key, mapper.polymorphic_identity
)
yield (
state,
state_dict,
@@ -411,7 +417,11 @@ def _collect_insert_commands(
def _collect_update_commands(
uowtransaction, table, states_to_update, bulk=False
uowtransaction,
table,
states_to_update,
bulk=False,
use_orm_update_stmt=None,
):
"""Identify sets of values to use in UPDATE statements for a
list of states.
@@ -437,7 +447,11 @@ def _collect_update_commands(
pks = mapper._pks_by_table[table]
value_params = {}
if use_orm_update_stmt is not None:
# TODO: ordered values, etc
value_params = use_orm_update_stmt._values
else:
value_params = {}
propkey_to_col = mapper._propkey_to_col[table]
@@ -697,6 +711,7 @@ def _emit_update_statements(
table,
update,
bookkeeping=True,
use_orm_update_stmt=None,
):
"""Emit UPDATE statements corresponding to value lists collected
by _collect_update_commands()."""
@@ -708,7 +723,7 @@ def _emit_update_statements(
execution_options = {"compiled_cache": base_mapper._compiled_cache}
def update_stmt():
def update_stmt(existing_stmt=None):
clauses = BooleanClauseList._construct_raw(operators.and_)
for col in mapper._pks_by_table[table]:
@@ -725,10 +740,17 @@ def _emit_update_statements(
)
)
stmt = table.update().where(clauses)
if existing_stmt is not None:
stmt = existing_stmt.where(clauses)
else:
stmt = table.update().where(clauses)
return stmt
cached_stmt = base_mapper._memo(("update", table), update_stmt)
if use_orm_update_stmt is not None:
cached_stmt = update_stmt(use_orm_update_stmt)
else:
cached_stmt = base_mapper._memo(("update", table), update_stmt)
for (
(connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
@@ -747,6 +769,15 @@ def _emit_update_statements(
records = list(records)
statement = cached_stmt
if use_orm_update_stmt is not None:
statement = statement._annotate(
{
"_emit_update_table": table,
"_emit_update_mapper": mapper,
}
)
return_defaults = False
if not has_all_pks:
@@ -904,16 +935,35 @@ def _emit_insert_statements(
table,
insert,
bookkeeping=True,
use_orm_insert_stmt=None,
execution_options=None,
):
"""Emit INSERT statements corresponding to value lists collected
by _collect_insert_commands()."""
cached_stmt = base_mapper._memo(("insert", table), table.insert)
if use_orm_insert_stmt is not None:
cached_stmt = use_orm_insert_stmt
exec_opt = util.EMPTY_DICT
execution_options = {"compiled_cache": base_mapper._compiled_cache}
# if a user query with RETURNING was passed, we definitely need
# to use RETURNING.
returning_is_required_anyway = bool(use_orm_insert_stmt._returning)
else:
returning_is_required_anyway = False
cached_stmt = base_mapper._memo(("insert", table), table.insert)
exec_opt = {"compiled_cache": base_mapper._compiled_cache}
if execution_options:
execution_options = util.EMPTY_DICT.merge_with(
exec_opt, execution_options
)
else:
execution_options = exec_opt
return_result = None
for (
(connection, pkeys, hasvalue, has_all_pks, has_all_defaults),
(connection, _, hasvalue, has_all_pks, has_all_defaults),
records,
) in groupby(
insert,
@@ -928,17 +978,29 @@ def _emit_insert_statements(
statement = cached_stmt
if (
not bookkeeping
or (
has_all_defaults
or not base_mapper.eager_defaults
or not base_mapper.local_table.implicit_returning
or not connection.dialect.insert_returning
if use_orm_insert_stmt is not None:
statement = statement._annotate(
{
"_emit_insert_table": table,
"_emit_insert_mapper": mapper,
}
)
if (
(
not bookkeeping
or (
has_all_defaults
or not base_mapper.eager_defaults
or not base_mapper.local_table.implicit_returning
or not connection.dialect.insert_returning
)
)
and not returning_is_required_anyway
and has_all_pks
and not hasvalue
):
# the "we don't need newly generated values back" section.
# here we have all the PKs, all the defaults or we don't want
# to fetch them, or the dialect doesn't support RETURNING at all
@@ -946,7 +1008,7 @@ def _emit_insert_statements(
records = list(records)
multiparams = [rec[2] for rec in records]
c = connection.execute(
result = connection.execute(
statement, multiparams, execution_options=execution_options
)
if bookkeeping:
@@ -962,7 +1024,7 @@ def _emit_insert_statements(
has_all_defaults,
),
last_inserted_params,
) in zip(records, c.context.compiled_parameters):
) in zip(records, result.context.compiled_parameters):
if state:
_postfetch(
mapper_rec,
@@ -970,19 +1032,20 @@ def _emit_insert_statements(
table,
state,
state_dict,
c,
result,
last_inserted_params,
value_params,
False,
c.returned_defaults
if not c.context.executemany
result.returned_defaults
if not result.context.executemany
else None,
)
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
else:
# here, we need defaults and/or pk values back.
# here, we need defaults and/or pk values back or we otherwise
# know that we are using RETURNING in any case
records = list(records)
if (
@@ -991,6 +1054,16 @@ def _emit_insert_statements(
and len(records) > 1
):
do_executemany = True
elif returning_is_required_anyway:
if connection.dialect.insert_executemany_returning:
do_executemany = True
else:
raise sa_exc.InvalidRequestError(
f"Can't use explicit RETURNING for bulk INSERT "
f"operation with "
f"{connection.dialect.dialect_description} backend; "
f"executemany is not supported with RETURNING"
)
else:
do_executemany = False
@@ -998,6 +1071,7 @@ def _emit_insert_statements(
statement = statement.return_defaults(
*mapper._server_default_cols[table]
)
if mapper.version_id_col is not None:
statement = statement.return_defaults(mapper.version_id_col)
elif do_executemany:
@@ -1006,10 +1080,16 @@ def _emit_insert_statements(
if do_executemany:
multiparams = [rec[2] for rec in records]
c = connection.execute(
result = connection.execute(
statement, multiparams, execution_options=execution_options
)
if use_orm_insert_stmt is not None:
if return_result is None:
return_result = result
else:
return_result = return_result.splice_vertically(result)
if bookkeeping:
for (
(
@@ -1027,9 +1107,9 @@ def _emit_insert_statements(
returned_defaults,
) in zip_longest(
records,
c.context.compiled_parameters,
c.inserted_primary_key_rows,
c.returned_defaults_rows or (),
result.context.compiled_parameters,
result.inserted_primary_key_rows,
result.returned_defaults_rows or (),
):
if inserted_primary_key is None:
# this is a real problem and means that we didn't
@@ -1062,7 +1142,7 @@ def _emit_insert_statements(
table,
state,
state_dict,
c,
result,
last_inserted_params,
value_params,
False,
@@ -1071,6 +1151,8 @@ def _emit_insert_statements(
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
else:
assert not returning_is_required_anyway
for (
state,
state_dict,
@@ -1132,6 +1214,12 @@ def _emit_insert_statements(
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
if use_orm_insert_stmt is not None:
if return_result is None:
return _cursor.null_dml_result()
else:
return return_result
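This is the machinery behind the commit's headline feature; a hedged sketch
of the user-facing form, assuming a mapped class ``User``::

    from sqlalchemy import insert

    # ORM bulk INSERT via Session.execute()/scalars(), with RETURNING
    # routed back through the statement-emitting logic above
    new_users = session.scalars(
        insert(User).returning(User),
        [
            {"name": "spongebob"},
            {"name": "sandy"},
        ],
    ).all()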
def _emit_post_update_statements(
base_mapper, uowtransaction, mapper, table, update
+2 -2
@@ -2978,7 +2978,7 @@ class Query(
)
def delete(
self, synchronize_session: _SynchronizeSessionArgument = "evaluate"
self, synchronize_session: _SynchronizeSessionArgument = "auto"
) -> int:
r"""Perform a DELETE with an arbitrary WHERE clause.
@@ -3042,7 +3042,7 @@ class Query(
def update(
self,
values: Dict[_DMLColumnArgument, Any],
synchronize_session: _SynchronizeSessionArgument = "evaluate",
synchronize_session: _SynchronizeSessionArgument = "auto",
update_args: Optional[Dict[Any, Any]] = None,
) -> int:
r"""Perform an UPDATE with an arbitrary WHERE clause.
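A brief sketch of the changed default in 1.x ``Query`` terms, assuming a
mapped class ``User``::

    # 'auto' now picks 'fetch' on RETURNING-capable backends and
    # falls back to 'evaluate' elsewhere
    session.query(User).filter(User.name == "old").update(
        {"name": "new"}, synchronize_session="auto"
    )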
+13 -12
@@ -1828,12 +1828,13 @@ class Session(_SessionClassMethods, EventTarget):
statement._propagate_attrs.get("compile_state_plugin", None)
== "orm"
):
# note that even without "future" mode, we need
compile_state_cls = CompileState._get_plugin_class_for_plugin(
statement, "orm"
)
if TYPE_CHECKING:
assert isinstance(compile_state_cls, ORMCompileState)
assert isinstance(
compile_state_cls, context.AbstractORMCompileState
)
else:
compile_state_cls = None
@@ -1897,18 +1898,18 @@ class Session(_SessionClassMethods, EventTarget):
statement, params or {}, execution_options=execution_options
)
result: Result[Any] = conn.execute(
statement, params or {}, execution_options=execution_options
)
if compile_state_cls:
result = compile_state_cls.orm_setup_cursor_result(
result: Result[Any] = compile_state_cls.orm_execute_statement(
self,
statement,
params,
params or {},
execution_options,
bind_arguments,
result,
conn,
)
else:
result = conn.execute(
statement, params or {}, execution_options=execution_options
)
if _scalar_result:
@@ -2066,7 +2067,7 @@ class Session(_SessionClassMethods, EventTarget):
def scalars(
self,
statement: TypedReturnsRows[Tuple[_T]],
params: Optional[_CoreSingleExecuteParams] = None,
params: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
@@ -2078,7 +2079,7 @@ class Session(_SessionClassMethods, EventTarget):
def scalars(
self,
statement: Executable,
params: Optional[_CoreSingleExecuteParams] = None,
params: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
@@ -2089,7 +2090,7 @@ class Session(_SessionClassMethods, EventTarget):
def scalars(
self,
statement: Executable,
params: Optional[_CoreSingleExecuteParams] = None,
params: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
+11
@@ -227,6 +227,11 @@ class ColumnLoader(LoaderStrategy):
fetch = self.columns[0]
if adapter:
fetch = adapter.columns[fetch]
if fetch is None:
# None happens here only for dml bulk_persistence cases
# when context.DMLReturningColFilter is used
return
memoized_populators[self.parent_property] = fetch
def init_class_attribute(self, mapper):
@@ -318,6 +323,12 @@ class ExpressionColumnLoader(ColumnLoader):
fetch = columns[0]
if adapter:
fetch = adapter.columns[fetch]
if fetch is None:
# None is not expected to be the result of any
# adapter implementation here, however there may be theoretical
# usages of returning() with context.DMLReturningColFilter
return
memoized_populators[self.parent_property] = fetch
def create_row_processor(
+2
@@ -552,6 +552,8 @@ def _new_annotation_type(
# e.g. BindParameter, add it if present.
if cls.__dict__.get("inherit_cache", False):
anno_cls.inherit_cache = True # type: ignore
elif "inherit_cache" in cls.__dict__:
anno_cls.inherit_cache = cls.__dict__["inherit_cache"] # type: ignore
anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators)
+14 -10
@@ -5166,6 +5166,8 @@ class SQLCompiler(Compiled):
delete_stmt, delete_stmt.table, extra_froms
)
crud._get_crud_params(self, delete_stmt, compile_state, toplevel, **kw)
if delete_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
delete_stmt, table_text
@@ -5178,13 +5180,14 @@ class SQLCompiler(Compiled):
text += table_text
if delete_stmt._returning:
if self.returning_precedes_values:
text += " " + self.returning_clause(
delete_stmt,
delete_stmt._returning,
populate_result_map=toplevel,
)
if (
self.implicit_returning or delete_stmt._returning
) and self.returning_precedes_values:
text += " " + self.returning_clause(
delete_stmt,
self.implicit_returning or delete_stmt._returning,
populate_result_map=toplevel,
)
if extra_froms:
extra_from_text = self.delete_extra_from_clause(
@@ -5204,10 +5207,12 @@ class SQLCompiler(Compiled):
if t:
text += " WHERE " + t
if delete_stmt._returning and not self.returning_precedes_values:
if (
self.implicit_returning or delete_stmt._returning
) and not self.returning_precedes_values:
text += " " + self.returning_clause(
delete_stmt,
delete_stmt._returning,
self.implicit_returning or delete_stmt._returning,
populate_result_map=toplevel,
)
@@ -5297,7 +5302,6 @@ class StrSQLCompiler(SQLCompiler):
self._label_select_column(None, c, True, False, {})
for c in base._select_iterables(returning_cols)
]
return "RETURNING " + ", ".join(columns)
def update_from_clause(
+101 -18
@@ -150,6 +150,22 @@ def _get_crud_params(
"return_defaults() simultaneously"
)
if compile_state.isdelete:
_setup_delete_return_defaults(
compiler,
stmt,
compile_state,
(),
_getattr_col_key,
_column_as_key,
_col_bind_name,
(),
(),
toplevel,
kw,
)
return _CrudParams([], [])
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if compiler.column_keys is None and compile_state._no_parameters:
@@ -466,13 +482,6 @@ def _scan_insert_from_select_cols(
kw,
):
(
need_pks,
implicit_returning,
implicit_return_defaults,
postfetch_lastrowid,
) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel)
cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names]
assert compiler.stack[-1]["selectable"] is stmt
@@ -537,6 +546,8 @@ def _scan_cols(
postfetch_lastrowid,
) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel)
assert compile_state.isupdate or compile_state.isinsert
if compile_state._parameter_ordering:
parameter_ordering = [
_column_as_key(key) for key in compile_state._parameter_ordering
@@ -563,6 +574,13 @@ def _scan_cols(
else:
autoincrement_col = insert_null_pk_still_autoincrements = None
if stmt._supplemental_returning:
supplemental_returning = set(stmt._supplemental_returning)
else:
supplemental_returning = set()
compiler_implicit_returning = compiler.implicit_returning
for c in cols:
# scan through every column in the target table
@@ -627,11 +645,13 @@ def _scan_cols(
# column has a DDL-level default, and is either not a pk
# column or we don't need the pk.
if implicit_return_defaults and c in implicit_return_defaults:
compiler.implicit_returning.append(c)
compiler_implicit_returning.append(c)
elif not c.primary_key:
compiler.postfetch.append(c)
elif implicit_return_defaults and c in implicit_return_defaults:
compiler.implicit_returning.append(c)
compiler_implicit_returning.append(c)
elif (
c.primary_key
and c is not stmt.table._autoincrement_column
@@ -652,6 +672,59 @@ def _scan_cols(
kw,
)
# adding supplemental cols to implicit_returning in table
# order so that order is maintained between multiple INSERT
# statements which may have different parameters included, but all
# have the same RETURNING clause
if (
c in supplemental_returning
and c not in compiler_implicit_returning
):
compiler_implicit_returning.append(c)
if supplemental_returning:
# we should have gotten every col into implicit_returning,
# however supplemental returning can also have SQL functions etc.
# in it
remaining_supplemental = supplemental_returning.difference(
compiler_implicit_returning
)
compiler_implicit_returning.extend(
c
for c in stmt._supplemental_returning
if c in remaining_supplemental
)
def _setup_delete_return_defaults(
compiler,
stmt,
compile_state,
parameters,
_getattr_col_key,
_column_as_key,
_col_bind_name,
check_columns,
values,
toplevel,
kw,
):
(_, _, implicit_return_defaults, _) = _get_returning_modifiers(
compiler, stmt, compile_state, toplevel
)
if not implicit_return_defaults:
return
if stmt._return_defaults_columns:
compiler.implicit_returning.extend(implicit_return_defaults)
if stmt._supplemental_returning:
ir_set = set(compiler.implicit_returning)
compiler.implicit_returning.extend(
c for c in stmt._supplemental_returning if c not in ir_set
)
def _append_param_parameter(
compiler,
@@ -743,7 +816,7 @@ def _append_param_parameter(
elif compiler.dialect.postfetch_lastrowid:
compiler.postfetch_lastrowid = True
elif implicit_return_defaults and c in implicit_return_defaults:
elif implicit_return_defaults and (c in implicit_return_defaults):
compiler.implicit_returning.append(c)
else:
@@ -1303,6 +1376,7 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
INSERT or UPDATE statement after it's invoked.
"""
need_pks = (
toplevel
and _compile_state_isinsert(compile_state)
@@ -1315,6 +1389,7 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
)
)
and not stmt._returning
# and (not stmt._returning or stmt._return_defaults)
and not compile_state._has_multi_parameters
)
@@ -1357,33 +1432,41 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
or stmt._return_defaults
)
)
if implicit_returning:
postfetch_lastrowid = False
if _compile_state_isinsert(compile_state):
implicit_return_defaults = implicit_returning and stmt._return_defaults
should_implicit_return_defaults = (
implicit_returning and stmt._return_defaults
)
elif compile_state.isupdate:
implicit_return_defaults = (
should_implicit_return_defaults = (
stmt._return_defaults
and compile_state._primary_table.implicit_returning
and compile_state._supports_implicit_returning
and compiler.dialect.update_returning
)
elif compile_state.isdelete:
should_implicit_return_defaults = (
stmt._return_defaults
and compile_state._primary_table.implicit_returning
and compile_state._supports_implicit_returning
and compiler.dialect.delete_returning
)
else:
# this line is unused, currently we are always
# isinsert or isupdate
implicit_return_defaults = False # pragma: no cover
should_implicit_return_defaults = False # pragma: no cover
if implicit_return_defaults:
if should_implicit_return_defaults:
if not stmt._return_defaults_columns:
implicit_return_defaults = set(stmt.table.c)
else:
implicit_return_defaults = set(stmt._return_defaults_columns)
else:
implicit_return_defaults = None
return (
need_pks,
implicit_returning,
implicit_returning or should_implicit_return_defaults,
implicit_return_defaults,
postfetch_lastrowid,
)
+241 -168
@@ -164,16 +164,33 @@ class DMLState(CompileState):
def get_plugin_class(cls, statement: Executable) -> Type[DMLState]:
...
@classmethod
def _get_multi_crud_kv_pairs(
cls,
statement: UpdateBase,
multi_kv_iterator: Iterable[Dict[_DMLColumnArgument, Any]],
) -> List[Dict[_DMLColumnElement, Any]]:
return [
{
coercions.expect(roles.DMLColumnRole, k): v
for k, v in mapping.items()
}
for mapping in multi_kv_iterator
]
@classmethod
def _get_crud_kv_pairs(
cls,
statement: UpdateBase,
kv_iterator: Iterable[Tuple[_DMLColumnArgument, Any]],
needs_to_be_cacheable: bool,
) -> List[Tuple[_DMLColumnElement, Any]]:
return [
(
coercions.expect(roles.DMLColumnRole, k),
coercions.expect(
v
if not needs_to_be_cacheable
else coercions.expect(
roles.ExpressionElementRole,
v,
type_=NullType(),
@@ -269,7 +286,7 @@ class InsertDMLState(DMLState):
def _insert_col_keys(self) -> List[str]:
# this is also done in crud.py -> _key_getters_for_crud_column
return [
coercions.expect_as_key(roles.DMLColumnRole, col)
coercions.expect(roles.DMLColumnRole, col, as_key=True)
for col in self._dict_parameters or ()
]
@@ -326,7 +343,6 @@ class UpdateDMLState(DMLState):
self._extra_froms = ef
self.is_multitable = mt = ef
self.include_table_with_column_exprs = bool(
mt and compiler.render_table_with_column_in_update_from
)
@@ -389,6 +405,7 @@ class UpdateBase(
_return_defaults_columns: Optional[
Tuple[_ColumnsClauseElement, ...]
] = None
_supplemental_returning: Optional[Tuple[_ColumnsClauseElement, ...]] = None
_returning: Tuple[_ColumnsClauseElement, ...] = ()
is_dml = True
@@ -434,6 +451,215 @@ class UpdateBase(
self._validate_dialect_kwargs(opt)
return self
@_generative
def return_defaults(
self: SelfUpdateBase,
*cols: _DMLColumnArgument,
supplemental_cols: Optional[Iterable[_DMLColumnArgument]] = None,
) -> SelfUpdateBase:
"""Make use of a :term:`RETURNING` clause for the purpose
of fetching server-side expressions and defaults, for supporting
backends only.
.. deepalchemy::
The :meth:`.UpdateBase.return_defaults` method is used by the ORM
for its internal work in fetching newly generated primary key
and server default values, in particular to provide the underlying
implementation of the :paramref:`_orm.Mapper.eager_defaults`
ORM feature as well as to allow RETURNING support with bulk
ORM inserts. Its behavior is fairly idiosyncratic
and is not really intended for general use. End users should
stick with using :meth:`.UpdateBase.returning` in order to
add RETURNING clauses to their INSERT, UPDATE and DELETE
statements.
Normally, a single row INSERT statement will automatically populate the
:attr:`.CursorResult.inserted_primary_key` attribute when executed,
which stores the primary key of the row that was just inserted in the
form of a :class:`.Row` object with column names as named tuple keys
(and the :attr:`.Row._mapping` view fully populated as well). The
dialect in use chooses the strategy to use in order to populate this
data; if it was generated using server-side defaults and / or SQL
expressions, dialect-specific approaches such as ``cursor.lastrowid``
or ``RETURNING`` are typically used to acquire the new primary key
value.
However, when the statement is modified by calling
:meth:`.UpdateBase.return_defaults` before executing the statement,
additional behaviors take place **only** for backends that support
RETURNING and for :class:`.Table` objects that maintain the
:paramref:`.Table.implicit_returning` parameter at its default value of
``True``. In these cases, when the :class:`.CursorResult` is returned
from the statement's execution, not only will
:attr:`.CursorResult.inserted_primary_key` be populated as always, the
:attr:`.CursorResult.returned_defaults` attribute will also be
populated with a :class:`.Row` named-tuple representing the full range
of server generated
values from that single row, including values for any columns that
specify :paramref:`_schema.Column.server_default` or which make use of
:paramref:`_schema.Column.default` using a SQL expression.
When invoking INSERT statements with multiple rows using
:ref:`insertmanyvalues <engine_insertmanyvalues>`, the
:meth:`.UpdateBase.return_defaults` modifier will have the effect of
the :attr:`_engine.CursorResult.inserted_primary_key_rows` and
:attr:`_engine.CursorResult.returned_defaults_rows` attributes being
fully populated with lists of :class:`.Row` objects representing newly
inserted primary key values as well as newly inserted server generated
values for each row inserted. The
:attr:`.CursorResult.inserted_primary_key` and
:attr:`.CursorResult.returned_defaults` attributes will also continue
to be populated with the first row of these two collections.
If the backend does not support RETURNING or the :class:`.Table` in use
has disabled :paramref:`.Table.implicit_returning`, then no RETURNING
clause is added and no additional data is fetched; however, the
INSERT, UPDATE or DELETE statement proceeds normally.
E.g.::
stmt = table.insert().values(data='newdata').return_defaults()
result = connection.execute(stmt)
server_created_at = result.returned_defaults['created_at']
When used against an UPDATE statement
:meth:`.UpdateBase.return_defaults` instead looks for columns that
include :paramref:`_schema.Column.onupdate` or
:paramref:`_schema.Column.server_onupdate` parameters assigned, when
constructing the columns that will be included in the RETURNING clause
by default if explicit columns were not specified. When used against a
DELETE statement, no columns are included in RETURNING by default, they
instead must be specified explicitly as there are no columns that
normally change values when a DELETE statement proceeds.
.. versionadded:: 2.0 :meth:`.UpdateBase.return_defaults` is supported
for DELETE statements also and has been moved from
:class:`.ValuesBase` to :class:`.UpdateBase`.
The :meth:`.UpdateBase.return_defaults` method is mutually exclusive
against the :meth:`.UpdateBase.returning` method and errors will be
raised during the SQL compilation process if both are used at the same
time on one statement. The RETURNING clause of the INSERT, UPDATE or
DELETE statement is therefore controlled by only one of these methods
at a time.
The :meth:`.UpdateBase.return_defaults` method differs from
:meth:`.UpdateBase.returning` in these ways:
1. The :meth:`.UpdateBase.return_defaults` method causes the
:attr:`.CursorResult.returned_defaults` collection to be populated
with the first row from the RETURNING result. This attribute is not
populated when using :meth:`.UpdateBase.returning`.
2. :meth:`.UpdateBase.return_defaults` is compatible with existing
logic used to fetch auto-generated primary key values that are then
populated into the :attr:`.CursorResult.inserted_primary_key`
attribute. By contrast, using :meth:`.UpdateBase.returning` will
have the effect of the :attr:`.CursorResult.inserted_primary_key`
attribute being left unpopulated.
3. :meth:`.UpdateBase.return_defaults` can be called against any
backend. Backends that don't support RETURNING will skip the usage
of the feature, rather than raising an exception. The return value
of :attr:`_engine.CursorResult.returned_defaults` will be ``None``
for backends that don't support RETURNING or for which the target
:class:`.Table` sets :paramref:`.Table.implicit_returning` to
``False``.
4. An INSERT statement invoked with executemany() is supported if the
backend database driver supports the
:ref:`insertmanyvalues <engine_insertmanyvalues>`
feature, which is now supported by most SQLAlchemy-included backends.
When executemany is used, the
:attr:`_engine.CursorResult.returned_defaults_rows` and
:attr:`_engine.CursorResult.inserted_primary_key_rows` accessors
will return the inserted defaults and primary keys.
.. versionadded:: 1.4 Added
:attr:`_engine.CursorResult.returned_defaults_rows` and
:attr:`_engine.CursorResult.inserted_primary_key_rows` accessors.
In version 2.0, the underlying implementation which fetches and
populates the data for these attributes was generalized to be
supported by most backends, whereas in 1.4 they were only
supported by the ``psycopg2`` driver.
:param cols: optional list of column key names or
:class:`_schema.Column` that acts as a filter for those columns that
will be fetched.
:param supplemental_cols: optional list of RETURNING expressions,
in the same form as one would pass to the
:meth:`.UpdateBase.returning` method. When present, the additional
columns will be included in the RETURNING clause, and the
:class:`.CursorResult` object will be "rewound" when returned, so
that methods like :meth:`.CursorResult.all` will return new rows
mostly as though the statement used :meth:`.UpdateBase.returning`
directly. However, unlike when using :meth:`.UpdateBase.returning`
directly, the **order of the columns is undefined**, so they can only
be targeted using names or :attr:`.Row._mapping` keys; they cannot
reliably be targeted positionally.
.. versionadded:: 2.0
.. seealso::
:meth:`.UpdateBase.returning`
:attr:`_engine.CursorResult.returned_defaults`
:attr:`_engine.CursorResult.returned_defaults_rows`
:attr:`_engine.CursorResult.inserted_primary_key`
:attr:`_engine.CursorResult.inserted_primary_key_rows`
"""
if self._return_defaults:
# note _return_defaults_columns = () means return all columns,
# so if we have been here before, only update collection if there
# are columns in the collection
if self._return_defaults_columns and cols:
self._return_defaults_columns = tuple(
util.OrderedSet(self._return_defaults_columns).union(
coercions.expect(roles.ColumnsClauseRole, c)
for c in cols
)
)
else:
# set for all columns
self._return_defaults_columns = ()
else:
self._return_defaults_columns = tuple(
coercions.expect(roles.ColumnsClauseRole, c) for c in cols
)
self._return_defaults = True
if supplemental_cols:
# uniquifying while also maintaining order (maintaining order matters
# for test suites but also for vertical splicing)
supplemental_col_tup = (
coercions.expect(roles.ColumnsClauseRole, c)
for c in supplemental_cols
)
if self._supplemental_returning is None:
self._supplemental_returning = tuple(
util.unique_list(supplemental_col_tup)
)
else:
self._supplemental_returning = tuple(
util.unique_list(
self._supplemental_returning
+ tuple(supplemental_col_tup)
)
)
return self
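A usage sketch of the method as revised above (assumes SQLAlchemy 2.0 and a RETURNING-capable backend such as modern SQLite; table and column names are illustrative)::

    from sqlalchemy import (
        Column, DateTime, Integer, MetaData, String, Table,
        create_engine, func, insert,
    )

    metadata = MetaData()
    users = Table(
        "users",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
        Column("created_at", DateTime, server_default=func.now()),
    )

    engine = create_engine("sqlite://")
    metadata.create_all(engine)

    with engine.begin() as conn:
        # fetch server-generated values for the single inserted row
        result = conn.execute(
            insert(users).values(name="spongebob").return_defaults()
        )
        print(result.inserted_primary_key)  # (1,)
        print(result.returned_defaults)     # Row including created_at

        # supplemental_cols adds RETURNING expressions; the result is
        # "rewound", so .all() yields rows, targeted by name only
        result = conn.execute(
            insert(users)
            .values(name="patrick")
            .return_defaults(supplemental_cols=[users.c.name])
        )
        print(result.mappings().all())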
@_generative
def returning(
self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
@@ -500,7 +726,7 @@ class UpdateBase(
.. seealso::
:meth:`.ValuesBase.return_defaults` - an alternative method tailored
:meth:`.UpdateBase.return_defaults` - an alternative method tailored
towards efficient fetching of server-side defaults and triggers
for single-row INSERTs or UPDATEs.
@@ -703,7 +929,6 @@ class ValuesBase(UpdateBase):
_select_names: Optional[List[str]] = None
_inline: bool = False
_returning: Tuple[_ColumnsClauseElement, ...] = ()
def __init__(self, table: _DMLTableArgument):
self.table = coercions.expect(
@@ -859,7 +1084,15 @@ class ValuesBase(UpdateBase):
)
elif isinstance(arg, collections_abc.Sequence):
if arg and isinstance(arg[0], (list, dict, tuple)):
if arg and isinstance(arg[0], dict):
multi_kv_generator = DMLState.get_plugin_class(
self
)._get_multi_crud_kv_pairs
self._multi_values += (multi_kv_generator(self, arg),)
return self
if arg and isinstance(arg[0], (list, tuple)):
self._multi_values += (arg,)
return self
@@ -888,173 +1121,13 @@ class ValuesBase(UpdateBase):
# and ensures they get the "crud"-style name when rendered.
kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs
coerced_arg = {k: v for k, v in kv_generator(self, arg.items())}
coerced_arg = dict(kv_generator(self, arg.items(), True))
if self._values:
self._values = self._values.union(coerced_arg)
else:
self._values = util.immutabledict(coerced_arg)
return self
@_generative
def return_defaults(
self: SelfValuesBase, *cols: _DMLColumnArgument
) -> SelfValuesBase:
"""Make use of a :term:`RETURNING` clause for the purpose
of fetching server-side expressions and defaults, for supporting
backends only.
.. tip::
The :meth:`.ValuesBase.return_defaults` method is used by the ORM
for its internal work in fetching newly generated primary key
and server default values, in particular to provide the underlying
implementation of the :paramref:`_orm.Mapper.eager_defaults`
ORM feature. Its behavior is fairly idiosyncratic
and is not really intended for general use. End users should
stick with using :meth:`.UpdateBase.returning` in order to
add RETURNING clauses to their INSERT, UPDATE and DELETE
statements.
Normally, a single row INSERT statement will automatically populate the
:attr:`.CursorResult.inserted_primary_key` attribute when executed,
which stores the primary key of the row that was just inserted in the
form of a :class:`.Row` object with column names as named tuple keys
(and the :attr:`.Row._mapping` view fully populated as well). The
dialect in use chooses the strategy to use in order to populate this
data; if it was generated using server-side defaults and / or SQL
expressions, dialect-specific approaches such as ``cursor.lastrowid``
or ``RETURNING`` are typically used to acquire the new primary key
value.
However, when the statement is modified by calling
:meth:`.ValuesBase.return_defaults` before executing the statement,
additional behaviors take place **only** for backends that support
RETURNING and for :class:`.Table` objects that maintain the
:paramref:`.Table.implicit_returning` parameter at its default value of
``True``. In these cases, when the :class:`.CursorResult` is returned
from the statement's execution, not only will
:attr:`.CursorResult.inserted_primary_key` be populated as always, the
:attr:`.CursorResult.returned_defaults` attribute will also be
populated with a :class:`.Row` named-tuple representing the full range
of server generated
values from that single row, including values for any columns that
specify :paramref:`_schema.Column.server_default` or which make use of
:paramref:`_schema.Column.default` using a SQL expression.
When invoking INSERT statements with multiple rows using
:ref:`insertmanyvalues <engine_insertmanyvalues>`, the
:meth:`.ValuesBase.return_defaults` modifier will have the effect of
the :attr:`_engine.CursorResult.inserted_primary_key_rows` and
:attr:`_engine.CursorResult.returned_defaults_rows` attributes being
fully populated with lists of :class:`.Row` objects representing newly
inserted primary key values as well as newly inserted server generated
values for each row inserted. The
:attr:`.CursorResult.inserted_primary_key` and
:attr:`.CursorResult.returned_defaults` attributes will also continue
to be populated with the first row of these two collections.
If the backend does not support RETURNING or the :class:`.Table` in use
has disabled :paramref:`.Table.implicit_returning`, then no RETURNING
clause is added and no additional data is fetched; however, the
INSERT or UPDATE statement proceeds normally.
E.g.::
stmt = table.insert().values(data='newdata').return_defaults()
result = connection.execute(stmt)
server_created_at = result.returned_defaults['created_at']
The :meth:`.ValuesBase.return_defaults` method is mutually exclusive
against the :meth:`.UpdateBase.returning` method and errors will be
raised during the SQL compilation process if both are used at the same
time on one statement. The RETURNING clause of the INSERT or UPDATE
statement is therefore controlled by only one of these methods at a
time.
The :meth:`.ValuesBase.return_defaults` method differs from
:meth:`.UpdateBase.returning` in these ways:
1. :meth:`.ValuesBase.return_defaults` method causes the
:attr:`.CursorResult.returned_defaults` collection to be populated
with the first row from the RETURNING result. This attribute is not
populated when using :meth:`.UpdateBase.returning`.
2. :meth:`.ValuesBase.return_defaults` is compatible with existing
logic used to fetch auto-generated primary key values that are then
populated into the :attr:`.CursorResult.inserted_primary_key`
attribute. By contrast, using :meth:`.UpdateBase.returning` will
have the effect of the :attr:`.CursorResult.inserted_primary_key`
attribute being left unpopulated.
3. :meth:`.ValuesBase.return_defaults` can be called against any
backend. Backends that don't support RETURNING will skip the usage
of the feature, rather than raising an exception. The return value
of :attr:`_engine.CursorResult.returned_defaults` will be ``None``
for backends that don't support RETURNING or for which the target
:class:`.Table` sets :paramref:`.Table.implicit_returning` to
``False``.
4. An INSERT statement invoked with executemany() is supported if the
backend database driver supports the
:ref:`insertmanyvalues <engine_insertmanyvalues>`
feature which is now supported by most SQLAlchemy-included backends.
When executemany is used, the
:attr:`_engine.CursorResult.returned_defaults_rows` and
:attr:`_engine.CursorResult.inserted_primary_key_rows` accessors
will return the inserted defaults and primary keys.
.. versionadded:: 1.4 Added
:attr:`_engine.CursorResult.returned_defaults_rows` and
:attr:`_engine.CursorResult.inserted_primary_key_rows` accessors.
In version 2.0, the underlying implementation which fetches and
populates the data for these attributes was generalized to be
supported by most backends, whereas in 1.4 they were only
supported by the ``psycopg2`` driver.
:param cols: optional list of column key names or
:class:`_schema.Column` that acts as a filter for those columns that
will be fetched.
.. seealso::
:meth:`.UpdateBase.returning`
:attr:`_engine.CursorResult.returned_defaults`
:attr:`_engine.CursorResult.returned_defaults_rows`
:attr:`_engine.CursorResult.inserted_primary_key`
:attr:`_engine.CursorResult.inserted_primary_key_rows`
"""
if self._return_defaults:
# note _return_defaults_columns = () means return all columns,
# so if we have been here before, only update collection if there
# are columns in the collection
if self._return_defaults_columns and cols:
self._return_defaults_columns = tuple(
set(self._return_defaults_columns).union(
coercions.expect(roles.ColumnsClauseRole, c)
for c in cols
)
)
else:
# set for all columns
self._return_defaults_columns = ()
else:
self._return_defaults_columns = tuple(
coercions.expect(roles.ColumnsClauseRole, c) for c in cols
)
self._return_defaults = True
return self
SelfInsert = typing.TypeVar("SelfInsert", bound="Insert")
@@ -1459,7 +1532,7 @@ class Update(DMLWhereBase, ValuesBase):
)
kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs
self._ordered_values = kv_generator(self, args)
self._ordered_values = kv_generator(self, args, True)
return self
@_generative
+12 -1
View File
@@ -68,7 +68,7 @@ class CursorSQL(SQLMatchRule):
class CompiledSQL(SQLMatchRule):
def __init__(
self, statement, params=None, dialect="default", enable_returning=False
self, statement, params=None, dialect="default", enable_returning=True
):
self.statement = statement
self.params = params
@@ -90,6 +90,17 @@ class CompiledSQL(SQLMatchRule):
dialect.insert_returning = (
dialect.update_returning
) = dialect.delete_returning = True
dialect.use_insertmanyvalues = True
dialect.supports_multivalues_insert = True
dialect.update_returning_multifrom = True
dialect.delete_returning_multifrom = True
# dialect.favor_returning_over_lastrowid = True
# dialect.insert_null_pk_still_autoincrements = True
# this is calculated but we need it to be True for this
# to look like all the current RETURNING dialects
assert dialect.insert_executemany_returning
return dialect
else:
return url.URL.create(self.dialect).get_dialect()()
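The effect of flipping these flags on, as a hedged sketch: once the dialect advertises support, the generic compiler renders RETURNING in the expected SQL (names illustrative)::

    from sqlalchemy import Column, Integer, MetaData, String, Table, insert
    from sqlalchemy.engine import default

    # mimic the fixture: a default dialect with RETURNING switched on
    dialect = default.DefaultDialect()
    dialect.insert_returning = True

    metadata = MetaData()
    users = Table(
        "users",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
    )

    stmt = insert(users).values(name="u1").returning(users.c.id)
    print(stmt.compile(dialect=dialect))
    # INSERT INTO users (name) VALUES (:name) RETURNING users.id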
+11 -11
View File
@@ -23,7 +23,6 @@ from .util import adict
from .util import drop_all_tables_from_metadata
from .. import event
from .. import util
from ..orm import declarative_base
from ..orm import DeclarativeBase
from ..orm import MappedAsDataclass
from ..orm import registry
@@ -117,7 +116,7 @@ class TestBase:
metadata=metadata,
type_annotation_map={
str: sa.String().with_variant(
sa.String(50), "mysql", "mariadb"
sa.String(50), "mysql", "mariadb", "oracle"
)
},
)
@@ -132,7 +131,7 @@ class TestBase:
metadata = _md
type_annotation_map = {
str: sa.String().with_variant(
sa.String(50), "mysql", "mariadb"
sa.String(50), "mysql", "mariadb", "oracle"
)
}
@@ -780,18 +779,19 @@ class DeclarativeMappedTest(MappedTest):
def _with_register_classes(cls, fn):
cls_registry = cls.classes
class DeclarativeBasic:
class _DeclBase(DeclarativeBase):
__table_cls__ = schema.Table
metadata = cls._tables_metadata
type_annotation_map = {
str: sa.String().with_variant(
sa.String(50), "mysql", "mariadb", "oracle"
)
}
def __init_subclass__(cls) -> None:
def __init_subclass__(cls, **kw) -> None:
assert cls_registry is not None
cls_registry[cls.__name__] = cls
super().__init_subclass__()
_DeclBase = declarative_base(
metadata=cls._tables_metadata,
cls=DeclarativeBasic,
)
super().__init_subclass__(**kw)
cls.DeclarativeBasic = _DeclBase
@@ -89,8 +89,13 @@ class RowCountTest(fixtures.TablesTest):
eq_(r.rowcount, 3)
@testing.requires.update_returning
@testing.requires.sane_rowcount_w_returning
def test_update_rowcount_return_defaults(self, connection):
"""note this test should succeed for all RETURNING backends
as of 2.0. In
Idf28379f8705e403a3c6a937f6a798a042ef2540 we changed rowcount to use
len(rows) when we have implicit returning
"""
employees_table = self.tables.employees
department = employees_table.c.department
+4 -1
View File
@@ -11,6 +11,7 @@ from __future__ import annotations
from itertools import filterfalse
from typing import AbstractSet
from typing import Any
from typing import Callable
from typing import cast
from typing import Collection
from typing import Dict
@@ -481,7 +482,9 @@ class IdentitySet:
return "%s(%r)" % (type(self).__name__, list(self._members.values()))
def unique_list(seq, hashfunc=None):
def unique_list(
seq: Iterable[_T], hashfunc: Optional[Callable[[_T], int]] = None
) -> List[_T]:
seen: Set[Any] = set()
seen_add = seen.add
if not hashfunc:
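The hunk is cut off above; for reference, a plausible completion of the helper consistent with the shown signature (a sketch, not necessarily the verbatim source) that deduplicates while preserving first-seen order::

    from typing import Any, Callable, Iterable, List, Optional, Set, TypeVar

    _T = TypeVar("_T")

    def unique_list(
        seq: Iterable[_T], hashfunc: Optional[Callable[[_T], int]] = None
    ) -> List[_T]:
        seen: Set[Any] = set()
        seen_add = seen.add
        if not hashfunc:
            # seen_add() returns None, so "and not seen_add(x)" records x
            # while keeping the element; first-seen order is preserved
            return [x for x in seq if x not in seen and not seen_add(x)]
        return [
            x
            for x in seq
            if hashfunc(x) not in seen and not seen_add(hashfunc(x))
        ]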
+47 -175
View File
@@ -465,7 +465,11 @@ class ShardTest:
t = get_tokyo(sess2)
eq_(t.city, tokyo.city)
def test_bulk_update_synchronize_evaluate(self):
@testing.combinations(
"fetch", "evaluate", "auto", argnames="synchronize_session"
)
@testing.combinations(True, False, argnames="legacy")
def test_orm_update_synchronize(self, synchronize_session, legacy):
sess = self._fixture_data()
eq_(
@@ -476,199 +480,67 @@ class ShardTest:
temps = sess.query(Report).all()
eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
sess.query(Report).filter(Report.temperature >= 80).update(
{"temperature": Report.temperature + 6},
synchronize_session="evaluate",
)
eq_(
set(row.temperature for row in sess.query(Report.temperature)),
{86.0, 75.0, 91.0},
)
# test synchronize session as well
eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
def test_bulk_update_synchronize_fetch(self):
sess = self._fixture_data()
eq_(
set(row.temperature for row in sess.query(Report.temperature)),
{80.0, 75.0, 85.0},
)
temps = sess.query(Report).all()
eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
sess.query(Report).filter(Report.temperature >= 80).update(
{"temperature": Report.temperature + 6},
synchronize_session="fetch",
)
eq_(
set(row.temperature for row in sess.query(Report.temperature)),
{86.0, 75.0, 91.0},
)
# test synchronize session as well
eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
def test_bulk_delete_synchronize_evaluate(self):
sess = self._fixture_data()
temps = sess.query(Report).all()
eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
sess.query(Report).filter(Report.temperature >= 80).delete(
synchronize_session="evaluate"
)
eq_(
set(row.temperature for row in sess.query(Report.temperature)),
{75.0},
)
# test synchronize session as well
for t in temps:
assert inspect(t).deleted is (t.temperature >= 80)
def test_bulk_delete_synchronize_fetch(self):
sess = self._fixture_data()
temps = sess.query(Report).all()
eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
sess.query(Report).filter(Report.temperature >= 80).delete(
synchronize_session="fetch"
)
eq_(
set(row.temperature for row in sess.query(Report.temperature)),
{75.0},
)
# test synchronize session as well
for t in temps:
assert inspect(t).deleted is (t.temperature >= 80)
def test_bulk_update_future_synchronize_evaluate(self):
sess = self._fixture_data()
eq_(
set(
row.temperature
for row in sess.execute(select(Report.temperature))
),
{80.0, 75.0, 85.0},
)
temps = sess.execute(select(Report)).scalars().all()
eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
sess.execute(
update(Report)
.filter(Report.temperature >= 80)
.values(
if legacy:
sess.query(Report).filter(Report.temperature >= 80).update(
{"temperature": Report.temperature + 6},
synchronize_session=synchronize_session,
)
.execution_options(synchronize_session="evaluate")
else:
sess.execute(
update(Report)
.filter(Report.temperature >= 80)
.values(temperature=Report.temperature + 6)
.execution_options(synchronize_session=synchronize_session)
)
# test synchronize session
def go():
eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
self.assert_sql_count(
sess._ShardedSession__binds["north_america"], go, 0
)
eq_(
set(
row.temperature
for row in sess.execute(select(Report.temperature))
),
set(row.temperature for row in sess.query(Report.temperature)),
{86.0, 75.0, 91.0},
)
# test synchronize session as well
eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
def test_bulk_update_future_synchronize_fetch(self):
@testing.combinations(
"fetch", "evaluate", "auto", argnames="synchronize_session"
)
@testing.combinations(True, False, argnames="legacy")
def test_orm_delete_synchronize(self, synchronize_session, legacy):
sess = self._fixture_data()
eq_(
set(
row.temperature
for row in sess.execute(select(Report.temperature))
),
{80.0, 75.0, 85.0},
)
temps = sess.execute(select(Report)).scalars().all()
temps = sess.query(Report).all()
eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
# omitting the criteria so that the UPDATE affects three out of
# four shards
sess.execute(
update(Report)
.values(
{"temperature": Report.temperature + 6},
if legacy:
sess.query(Report).filter(Report.temperature >= 80).delete(
synchronize_session=synchronize_session
)
.execution_options(synchronize_session="fetch")
else:
sess.execute(
delete(Report)
.filter(Report.temperature >= 80)
.execution_options(synchronize_session=synchronize_session)
)
def go():
# test synchronize session
for t in temps:
assert inspect(t).deleted is (t.temperature >= 80)
self.assert_sql_count(
sess._ShardedSession__binds["north_america"], go, 0
)
eq_(
set(
row.temperature
for row in sess.execute(select(Report.temperature))
),
{86.0, 81.0, 91.0},
)
# test synchronize session as well
eq_(set(t.temperature for t in temps), {86.0, 81.0, 91.0})
def test_bulk_delete_future_synchronize_evaluate(self):
sess = self._fixture_data()
temps = sess.execute(select(Report)).scalars().all()
eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
sess.execute(
delete(Report)
.filter(Report.temperature >= 80)
.execution_options(synchronize_session="evaluate")
)
eq_(
set(
row.temperature
for row in sess.execute(select(Report.temperature))
),
set(row.temperature for row in sess.query(Report.temperature)),
{75.0},
)
# test synchronize session as well
for t in temps:
assert inspect(t).deleted is (t.temperature >= 80)
def test_bulk_delete_future_synchronize_fetch(self):
sess = self._fixture_data()
temps = sess.execute(select(Report)).scalars().all()
eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
sess.execute(
delete(Report)
.filter(Report.temperature >= 80)
.execution_options(synchronize_session="fetch")
)
eq_(
set(
row.temperature
for row in sess.execute(select(Report.temperature))
),
{75.0},
)
# test synchronize session as well
for t in temps:
assert inspect(t).deleted is (t.temperature >= 80)
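The consolidated pattern these tests now parametrize, as a standalone sketch (assumes SQLAlchemy 2.0 and an in-memory SQLite database; the Report mapping here is illustrative, not the sharded fixture itself)::

    from sqlalchemy import Float, Integer, create_engine, select, update
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class Report(Base):
        __tablename__ = "reports"
        id: Mapped[int] = mapped_column(Integer, primary_key=True)
        temperature: Mapped[float] = mapped_column(Float)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as sess:
        sess.add_all([Report(temperature=t) for t in (80.0, 75.0, 85.0)])
        sess.flush()

        # legacy Query form; the strategy passed as an argument
        sess.query(Report).filter(Report.temperature >= 80).update(
            {"temperature": Report.temperature + 6},
            synchronize_session="auto",
        )

        # 2.0-style form; the same strategy as an execution option
        sess.execute(
            update(Report)
            .where(Report.temperature >= 80)
            .values(temperature=Report.temperature + 6)
            .execution_options(synchronize_session="auto")
        )

        print(sess.scalars(select(Report.temperature)).all())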
class DistinctEngineShardTest(ShardTest, fixtures.MappedTest):
def _init_dbs(self):
+32 -3
View File
@@ -3,6 +3,7 @@ from decimal import Decimal
from sqlalchemy import exc
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import insert
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL
@@ -1017,15 +1018,43 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL):
params={"first_name": "Dr."},
)
def test_update_expr(self):
@testing.combinations("attr", "str", "kwarg", argnames="keytype")
def test_update_expr(self, keytype):
Person = self.classes.Person
statement = update(Person).values({Person.name: "Dr. No"})
if keytype == "attr":
statement = update(Person).values({Person.name: "Dr. No"})
elif keytype == "str":
statement = update(Person).values({"name": "Dr. No"})
elif keytype == "kwarg":
statement = update(Person).values(name="Dr. No")
else:
assert False
self.assert_compile(
statement,
"UPDATE person SET first_name=:first_name, last_name=:last_name",
params={"first_name": "Dr.", "last_name": "No"},
checkparams={"first_name": "Dr.", "last_name": "No"},
)
@testing.combinations("attr", "str", "kwarg", argnames="keytype")
def test_insert_expr(self, keytype):
Person = self.classes.Person
if keytype == "attr":
statement = insert(Person).values({Person.name: "Dr. No"})
elif keytype == "str":
statement = insert(Person).values({"name": "Dr. No"})
elif keytype == "kwarg":
statement = insert(Person).values(name="Dr. No")
else:
assert False
self.assert_compile(
statement,
"INSERT INTO person (first_name, last_name) VALUES "
"(:first_name, :last_name)",
checkparams={"first_name": "Dr.", "last_name": "No"},
)
# these tests all run two UPDATES to assert that caching is not
View File
@@ -1,8 +1,11 @@
from sqlalchemy import FetchedValue
from sqlalchemy import ForeignKey
from sqlalchemy import Identity
from sqlalchemy import insert
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy import update
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import mock
@@ -20,6 +23,8 @@ class BulkTest(testing.AssertsExecutionResults):
class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest):
__backend__ = True
@classmethod
def define_tables(cls, metadata):
Table(
@@ -73,6 +78,8 @@ class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest):
class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
__backend__ = True
@classmethod
def setup_mappers(cls):
User, Address, Order = cls.classes("User", "Address", "Order")
@@ -82,22 +89,42 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
cls.mapper_registry.map_imperatively(Address, a)
cls.mapper_registry.map_imperatively(Order, o)
def test_bulk_save_return_defaults(self):
@testing.combinations("save_objects", "insert_mappings", "insert_stmt")
def test_bulk_save_return_defaults(self, statement_type):
(User,) = self.classes("User")
s = fixture_session()
objects = [User(name="u1"), User(name="u2"), User(name="u3")]
assert "id" not in objects[0].__dict__
with self.sql_execution_asserter() as asserter:
s.bulk_save_objects(objects, return_defaults=True)
if statement_type == "save_objects":
objects = [User(name="u1"), User(name="u2"), User(name="u3")]
assert "id" not in objects[0].__dict__
returning_users_id = " RETURNING users.id"
with self.sql_execution_asserter() as asserter:
s.bulk_save_objects(objects, return_defaults=True)
elif statement_type == "insert_mappings":
data = [dict(name="u1"), dict(name="u2"), dict(name="u3")]
returning_users_id = " RETURNING users.id"
with self.sql_execution_asserter() as asserter:
s.bulk_insert_mappings(User, data, return_defaults=True)
elif statement_type == "insert_stmt":
data = [dict(name="u1"), dict(name="u2"), dict(name="u3")]
# for statement, "return_defaults" is heuristic on if we are
# a joined inh mapping if we don't otherwise include
# .returning() on the statement itself
returning_users_id = ""
with self.sql_execution_asserter() as asserter:
s.execute(insert(User), data)
asserter.assert_(
Conditional(
testing.db.dialect.insert_executemany_returning,
testing.db.dialect.insert_executemany_returning
or statement_type == "insert_stmt",
[
CompiledSQL(
"INSERT INTO users (name) VALUES (:name)",
"INSERT INTO users (name) "
f"VALUES (:name){returning_users_id}",
[{"name": "u1"}, {"name": "u2"}, {"name": "u3"}],
),
],
@@ -117,7 +144,8 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
],
)
)
eq_(objects[0].__dict__["id"], 1)
if statement_type == "save_objects":
eq_(objects[0].__dict__["id"], 1)
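A sketch of the new ORM bulk INSERT path exercised here: executing an insert() against the Session with a list of dictionaries (assumes SQLAlchemy 2.0 with a RETURNING-capable backend such as modern SQLite; names illustrative)::

    from sqlalchemy import Integer, String, create_engine, insert
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class User(Base):
        __tablename__ = "users"
        id: Mapped[int] = mapped_column(Integer, primary_key=True)
        name: Mapped[str] = mapped_column(String(30))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as sess:
        # plain bulk insert; per the comment above, RETURNING is added
        # heuristically (e.g. for joined inheritance), not unconditionally
        sess.execute(insert(User), [{"name": "u1"}, {"name": "u2"}])

        # with an explicit .returning(User), ORM objects come back
        users = sess.scalars(
            insert(User).returning(User), [{"name": "u3"}, {"name": "u4"}]
        ).all()
        print([u.name for u in users])  # ['u3', 'u4']
        sess.commit()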
def test_bulk_save_mappings_preserve_order(self):
(User,) = self.classes("User")
@@ -219,8 +247,9 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
)
)
def test_bulk_update(self):
(User,) = self.classes("User")
@testing.combinations("update_mappings", "update_stmt")
def test_bulk_update(self, statement_type):
User = self.classes.User
s = fixture_session(expire_on_commit=False)
objects = [User(name="u1"), User(name="u2"), User(name="u3")]
@@ -228,15 +257,18 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
s.commit()
s = fixture_session()
with self.sql_execution_asserter() as asserter:
s.bulk_update_mappings(
User,
[
{"id": 1, "name": "u1new"},
{"id": 2, "name": "u2"},
{"id": 3, "name": "u3new"},
],
)
data = [
{"id": 1, "name": "u1new"},
{"id": 2, "name": "u2"},
{"id": 3, "name": "u3new"},
]
if statement_type == "update_mappings":
with self.sql_execution_asserter() as asserter:
s.bulk_update_mappings(User, data)
elif statement_type == "update_stmt":
with self.sql_execution_asserter() as asserter:
s.execute(update(User), data)
asserter.assert_(
CompiledSQL(
@@ -303,6 +335,8 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
class BulkUDPostfetchTest(BulkTest, fixtures.MappedTest):
__backend__ = True
@classmethod
def define_tables(cls, metadata):
Table(
@@ -360,6 +394,8 @@ class BulkUDPostfetchTest(BulkTest, fixtures.MappedTest):
class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest):
__backend__ = True
@classmethod
def define_tables(cls, metadata):
Table(
@@ -547,6 +583,8 @@ class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest):
class BulkInheritanceTest(BulkTest, fixtures.MappedTest):
__backend__ = True
@classmethod
def define_tables(cls, metadata):
Table(
@@ -643,6 +681,7 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest):
)
s = fixture_session()
objects = [
Manager(name="m1", status="s1", manager_name="mn1"),
Engineer(name="e1", status="s2", primary_language="l1"),
@@ -669,7 +708,7 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest):
[
CompiledSQL(
"INSERT INTO people (name, type) "
"VALUES (:name, :type)",
"VALUES (:name, :type) RETURNING people.person_id",
[
{"type": "engineer", "name": "e1"},
{"type": "engineer", "name": "e2"},
@@ -798,59 +837,74 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest):
),
)
def test_bulk_insert_joined_inh_return_defaults(self):
@testing.combinations("insert_mappings", "insert_stmt")
def test_bulk_insert_joined_inh_return_defaults(self, statement_type):
Person, Engineer, Manager, Boss = self.classes(
"Person", "Engineer", "Manager", "Boss"
)
s = fixture_session()
with self.sql_execution_asserter() as asserter:
s.bulk_insert_mappings(
Boss,
[
dict(
name="b1",
status="s1",
manager_name="mn1",
golf_swing="g1",
),
dict(
name="b2",
status="s2",
manager_name="mn2",
golf_swing="g2",
),
dict(
name="b3",
status="s3",
manager_name="mn3",
golf_swing="g3",
),
],
return_defaults=True,
)
data = [
dict(
name="b1",
status="s1",
manager_name="mn1",
golf_swing="g1",
),
dict(
name="b2",
status="s2",
manager_name="mn2",
golf_swing="g2",
),
dict(
name="b3",
status="s3",
manager_name="mn3",
golf_swing="g3",
),
]
if statement_type == "insert_mappings":
with self.sql_execution_asserter() as asserter:
s.bulk_insert_mappings(
Boss,
data,
return_defaults=True,
)
elif statement_type == "insert_stmt":
with self.sql_execution_asserter() as asserter:
s.execute(insert(Boss), data)
asserter.assert_(
Conditional(
testing.db.dialect.insert_executemany_returning,
[
CompiledSQL(
"INSERT INTO people (name) VALUES (:name)",
[{"name": "b1"}, {"name": "b2"}, {"name": "b3"}],
"INSERT INTO people (name, type) "
"VALUES (:name, :type) RETURNING people.person_id",
[
{"name": "b1", "type": "boss"},
{"name": "b2", "type": "boss"},
{"name": "b3", "type": "boss"},
],
),
],
[
CompiledSQL(
"INSERT INTO people (name) VALUES (:name)",
[{"name": "b1"}],
"INSERT INTO people (name, type) "
"VALUES (:name, :type)",
[{"name": "b1", "type": "boss"}],
),
CompiledSQL(
"INSERT INTO people (name) VALUES (:name)",
[{"name": "b2"}],
"INSERT INTO people (name, type) "
"VALUES (:name, :type)",
[{"name": "b2", "type": "boss"}],
),
CompiledSQL(
"INSERT INTO people (name) VALUES (:name)",
[{"name": "b3"}],
"INSERT INTO people (name, type) "
"VALUES (:name, :type)",
[{"name": "b3", "type": "boss"}],
),
],
),
@@ -874,15 +928,79 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest):
),
)
@testing.combinations("update_mappings", "update_stmt")
def test_bulk_update(self, statement_type):
Person, Engineer, Manager, Boss = self.classes(
"Person", "Engineer", "Manager", "Boss"
)
s = fixture_session()
b1, b2, b3 = (
Boss(name="b1", status="s1", manager_name="mn1", golf_swing="g1"),
Boss(name="b2", status="s2", manager_name="mn2", golf_swing="g2"),
Boss(name="b3", status="s3", manager_name="mn3", golf_swing="g3"),
)
s.add_all([b1, b2, b3])
s.commit()
# slight inconvenience: we have to fill in boss_id here for the
# UPDATE, as it is not sent along automatically.  this is not new
# behavior for bulk operations
new_data = [
{
"person_id": b1.person_id,
"boss_id": b1.boss_id,
"name": "b1_updated",
"manager_name": "mn1_updated",
},
{
"person_id": b3.person_id,
"boss_id": b3.boss_id,
"manager_name": "mn2_updated",
"golf_swing": "g1_updated",
},
]
if statement_type == "update_mappings":
with self.sql_execution_asserter() as asserter:
s.bulk_update_mappings(Boss, new_data)
elif statement_type == "update_stmt":
with self.sql_execution_asserter() as asserter:
s.execute(update(Boss), new_data)
asserter.assert_(
CompiledSQL(
"UPDATE people SET name=:name WHERE "
"people.person_id = :people_person_id",
[{"name": "b1_updated", "people_person_id": 1}],
),
CompiledSQL(
"UPDATE managers SET manager_name=:manager_name WHERE "
"managers.person_id = :managers_person_id",
[
{"manager_name": "mn1_updated", "managers_person_id": 1},
{"manager_name": "mn2_updated", "managers_person_id": 3},
],
),
CompiledSQL(
"UPDATE boss SET golf_swing=:golf_swing WHERE "
"boss.boss_id = :boss_boss_id",
[{"golf_swing": "g1_updated", "boss_boss_id": 3}],
),
)
class BulkIssue6793Test(BulkTest, fixtures.DeclarativeMappedTest):
__backend__ = True
@classmethod
def setup_classes(cls):
Base = cls.DeclarativeBasic
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
id = Column(Integer, Identity(), primary_key=True)
name = Column(String(255), nullable=False)
def test_issue_6793(self):
@@ -907,7 +1025,8 @@ class BulkIssue6793Test(BulkTest, fixtures.DeclarativeMappedTest):
[{"name": "A"}, {"name": "B"}],
),
CompiledSQL(
"INSERT INTO users (name) VALUES (:name)",
"INSERT INTO users (name) VALUES (:name) "
"RETURNING users.id",
[{"name": "C"}, {"name": "D"}],
),
],
File diff suppressed because it is too large
@@ -324,7 +324,6 @@ class EvaluateTest(fixtures.MappedTest):
"""test #3162"""
User = self.classes.User
with expect_raises_message(
evaluator.UnevaluatableError,
r"Custom operator '\^\^' can't be evaluated in "
@@ -1,3 +1,4 @@
from sqlalchemy import bindparam
from sqlalchemy import Boolean
from sqlalchemy import case
from sqlalchemy import column
@@ -7,6 +8,7 @@ from sqlalchemy import exc
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import insert
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import lambda_stmt
from sqlalchemy import MetaData
@@ -17,6 +19,7 @@ from sqlalchemy import testing
from sqlalchemy import text
from sqlalchemy import update
from sqlalchemy.orm import backref
from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
@@ -26,6 +29,7 @@ from sqlalchemy.orm import with_loader_criteria
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_raises
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import in_
from sqlalchemy.testing import not_in
@@ -123,6 +127,25 @@ class UpdateDeleteTest(fixtures.MappedTest):
},
)
def test_update_dont_use_col_key(self):
User = self.classes.User
s = fixture_session()
# make sure objects are present to synchronize
_ = s.query(User).all()
with expect_raises_message(
exc.InvalidRequestError,
"Attribute name not found, can't be synchronized back "
"to objects: 'age_int'",
):
s.execute(update(User).values(age_int=5))
stmt = update(User).values(age=5)
s.execute(stmt)
eq_(s.scalars(select(User.age)).all(), [5, 5, 5, 5])
@testing.combinations("table", "mapper", "both", argnames="bind_type")
@testing.combinations(
"update", "insert", "delete", argnames="statement_type"
@@ -162,7 +185,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
assert_raises_message(
exc.ArgumentError,
"Valid strategies for session synchronization "
"are 'evaluate', 'fetch', False",
"are 'auto', 'evaluate', 'fetch', False",
s.query(User).update,
{},
synchronize_session="fake",
@@ -351,6 +374,12 @@ class UpdateDeleteTest(fixtures.MappedTest):
def test_evaluate_dont_refresh_expired_objects(
self, expire_jane_age, add_filter_criteria
):
"""test #5664.
approach is revised in SQLAlchemy 2.0 to not pre-emptively
unexpire the involved attributes
"""
User = self.classes.User
sess = fixture_session()
@@ -379,15 +408,10 @@ class UpdateDeleteTest(fixtures.MappedTest):
if add_filter_criteria:
if expire_jane_age:
asserter.assert_(
# it has to unexpire jane.name, because jane is not fully
# expired and the criteria needs to look at this particular
# key
CompiledSQL(
"SELECT users.age_int AS users_age_int, "
"users.name AS users_name FROM users "
"WHERE users.id = :pk_1",
[{"pk_1": 4}],
),
# previously, this would unexpire the attribute and
# cause an additional SELECT. The
# 2.0 approach is that if the object has expired attrs
# we just expire the whole thing, avoiding SQL up front
CompiledSQL(
"UPDATE users "
"SET age_int=(users.age_int + :age_int_1) "
@@ -397,14 +421,10 @@ class UpdateDeleteTest(fixtures.MappedTest):
)
else:
asserter.assert_(
# it has to unexpire jane.name, because jane is not fully
# expired and the criteria needs to look at this particular
# key
CompiledSQL(
"SELECT users.name AS users_name FROM users "
"WHERE users.id = :pk_1",
[{"pk_1": 4}],
),
# previously, this would unexpire the attribute and
# cause an additional SELECT. The
# 2.0 approach is that if the object has expired attrs
# we just expire the whole thing, avoiding SQL up front
CompiledSQL(
"UPDATE users SET "
"age_int=(users.age_int + :age_int_1) "
@@ -443,9 +463,9 @@ class UpdateDeleteTest(fixtures.MappedTest):
),
]
if expire_jane_age and not add_filter_criteria:
if expire_jane_age:
to_assert.append(
# refresh jane
# refresh jane for partial attributes
CompiledSQL(
"SELECT users.age_int AS users_age_int, "
"users.name AS users_name FROM users "
@@ -455,6 +475,75 @@ class UpdateDeleteTest(fixtures.MappedTest):
)
asserter.assert_(*to_assert)
@testing.combinations(True, False, argnames="is_evaluable")
def test_auto_synchronize(self, is_evaluable):
User = self.classes.User
sess = fixture_session()
john, jack, jill, jane = sess.query(User).order_by(User.id).all()
if is_evaluable:
crit = or_(User.name == "jack", User.name == "jane")
else:
crit = case((User.name.in_(["jack", "jane"]), True), else_=False)
with self.sql_execution_asserter() as asserter:
sess.execute(update(User).where(crit).values(age=User.age + 10))
if is_evaluable:
asserter.assert_(
CompiledSQL(
"UPDATE users SET age_int=(users.age_int + :age_int_1) "
"WHERE users.name = :name_1 OR users.name = :name_2",
[{"age_int_1": 10, "name_1": "jack", "name_2": "jane"}],
),
)
elif testing.db.dialect.update_returning:
asserter.assert_(
CompiledSQL(
"UPDATE users SET age_int=(users.age_int + :age_int_1) "
"WHERE CASE WHEN (users.name IN (__[POSTCOMPILE_name_1])) "
"THEN :param_1 ELSE :param_2 END = 1 RETURNING users.id",
[
{
"age_int_1": 10,
"name_1": ["jack", "jane"],
"param_1": True,
"param_2": False,
}
],
),
)
else:
asserter.assert_(
CompiledSQL(
"SELECT users.id FROM users WHERE CASE WHEN "
"(users.name IN (__[POSTCOMPILE_name_1])) "
"THEN :param_1 ELSE :param_2 END = 1",
[
{
"name_1": ["jack", "jane"],
"param_1": True,
"param_2": False,
}
],
),
CompiledSQL(
"UPDATE users SET age_int=(users.age_int + :age_int_1) "
"WHERE CASE WHEN (users.name IN (__[POSTCOMPILE_name_1])) "
"THEN :param_1 ELSE :param_2 END = 1",
[
{
"age_int_1": 10,
"name_1": ["jack", "jane"],
"param_1": True,
"param_2": False,
}
],
),
)
def test_fetch_dont_refresh_expired_objects(self):
User = self.classes.User
@@ -518,17 +607,25 @@ class UpdateDeleteTest(fixtures.MappedTest):
),
)
def test_delete(self):
@testing.combinations(False, None, "auto", "evaluate", "fetch")
def test_delete(self, synchronize_session):
User = self.classes.User
sess = fixture_session()
john, jack, jill, jane = sess.query(User).order_by(User.id).all()
sess.query(User).filter(
or_(User.name == "john", User.name == "jill")
).delete()
assert john not in sess and jill not in sess
stmt = delete(User).filter(
or_(User.name == "john", User.name == "jill")
)
if synchronize_session is not None:
stmt = stmt.execution_options(
synchronize_session=synchronize_session
)
sess.execute(stmt)
if synchronize_session not in (False, None):
assert john not in sess and jill not in sess
eq_(sess.query(User).order_by(User.id).all(), [jack, jane])
@@ -629,6 +726,33 @@ class UpdateDeleteTest(fixtures.MappedTest):
eq_(sess.query(User).order_by(User.id).all(), [jack, jill, jane])
def test_update_multirow_not_supported(self):
User = self.classes.User
sess = fixture_session()
with expect_raises_message(
exc.InvalidRequestError,
"WHERE clause with bulk ORM UPDATE not supported " "right now.",
):
sess.execute(
update(User).where(User.id == bindparam("id")),
[{"id": 1, "age": 27}, {"id": 2, "age": 37}],
)
def test_delete_bulk_not_supported(self):
User = self.classes.User
sess = fixture_session()
with expect_raises_message(
exc.InvalidRequestError, "Bulk ORM DELETE not supported right now."
):
sess.execute(
delete(User),
[{"id": 1}, {"id": 2}],
)
def test_update(self):
User, users = self.classes.User, self.tables.users
@@ -640,6 +764,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
)
eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27])
eq_(
sess.query(User.age).order_by(User.id).all(),
list(zip([25, 37, 29, 27])),
@@ -974,7 +1099,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
)
@testing.requires.update_returning
def test_update_explicit_returning(self):
def test_update_evaluate_w_explicit_returning(self):
User = self.classes.User
sess = fixture_session()
@@ -987,6 +1112,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
.filter(User.age > 29)
.values({"age": User.age - 10})
.returning(User.id)
.execution_options(synchronize_session="evaluate")
)
rows = sess.execute(stmt).all()
@@ -1006,24 +1132,41 @@ class UpdateDeleteTest(fixtures.MappedTest):
)
@testing.requires.update_returning
def test_no_fetch_w_explicit_returning(self):
@testing.combinations("update", "delete", argnames="crud_type")
def test_fetch_w_explicit_returning(self, crud_type):
User = self.classes.User
sess = fixture_session()
stmt = (
update(User)
.filter(User.age > 29)
.values({"age": User.age - 10})
.execution_options(synchronize_session="fetch")
.returning(User.id)
)
with expect_raises_message(
exc.InvalidRequestError,
r"Can't use synchronize_session='fetch' "
r"with explicit returning\(\)",
):
sess.execute(stmt)
if crud_type == "update":
stmt = (
update(User)
.filter(User.age > 29)
.values({"age": User.age - 10})
.execution_options(synchronize_session="fetch")
.returning(User, User.name)
)
expected = [
(User(age=37), "jack"),
(User(age=27), "jane"),
]
elif crud_type == "delete":
stmt = (
delete(User)
.filter(User.age > 29)
.execution_options(synchronize_session="fetch")
.returning(User, User.name)
)
expected = [
(User(age=47), "jack"),
(User(age=37), "jane"),
]
else:
assert False
result = sess.execute(stmt)
eq_(result.all(), expected)
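What the revised test demonstrates, as a standalone sketch (assumes SQLAlchemy 2.0 and a backend with UPDATE..RETURNING such as modern SQLite; names illustrative)::

    from sqlalchemy import Integer, String, create_engine, update
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class User(Base):
        __tablename__ = "users"
        id: Mapped[int] = mapped_column(Integer, primary_key=True)
        name: Mapped[str] = mapped_column(String(30))
        age: Mapped[int] = mapped_column(Integer)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as sess:
        sess.add_all([User(name="jack", age=47), User(name="jane", age=37)])
        sess.commit()

        # 'fetch' now composes with explicit returning(): the RETURNING
        # rows populate the result and drive session synchronization
        stmt = (
            update(User)
            .where(User.age > 29)
            .values(age=User.age - 10)
            .returning(User, User.name)
            .execution_options(synchronize_session="fetch")
        )
        for user, name in sess.execute(stmt):
            print(name, user.age)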
@testing.combinations(True, False, argnames="implicit_returning")
def test_delete_fetch_returning(self, implicit_returning):
@@ -1142,7 +1285,8 @@ class UpdateDeleteTest(fixtures.MappedTest):
list(zip([25, 47, 44, 37])),
)
def test_update_changes_resets_dirty(self):
@testing.combinations("orm", "bulk")
def test_update_changes_resets_dirty(self, update_type):
User = self.classes.User
sess = fixture_session(autoflush=False)
@@ -1155,9 +1299,30 @@ class UpdateDeleteTest(fixtures.MappedTest):
# autoflush is false. therefore our '50' and '37' are getting
# blown away by this operation.
sess.query(User).filter(User.age > 29).update(
{"age": User.age - 10}, synchronize_session="evaluate"
)
if update_type == "orm":
sess.execute(
update(User)
.filter(User.age > 29)
.values({"age": User.age - 10}),
execution_options=dict(synchronize_session="evaluate"),
)
elif update_type == "bulk":
data = [
{"id": john.id, "age": 25},
{"id": jack.id, "age": 37},
{"id": jill.id, "age": 29},
{"id": jane.id, "age": 27},
]
sess.execute(
update(User),
data,
execution_options=dict(synchronize_session="evaluate"),
)
else:
assert False
for x in (john, jack, jill, jane):
assert not sess.is_modified(x)
@@ -1171,6 +1336,93 @@ class UpdateDeleteTest(fixtures.MappedTest):
assert not sess.is_modified(john)
assert not sess.is_modified(jack)
@testing.combinations(
None, False, "evaluate", "fetch", argnames="synchronize_session"
)
@testing.combinations(True, False, argnames="homogeneous_keys")
def test_bulk_update_synchronize_session(
self, synchronize_session, homogeneous_keys
):
User = self.classes.User
sess = fixture_session(expire_on_commit=False)
john, jack, jill, jane = sess.query(User).order_by(User.id).all()
if homogeneous_keys:
data = [
{"id": john.id, "age": 35},
{"id": jack.id, "age": 27},
{"id": jill.id, "age": 30},
]
else:
data = [
{"id": john.id, "age": 35},
{"id": jack.id, "name": "new jack"},
{"id": jill.id, "age": 30, "name": "new jill"},
]
with self.sql_execution_asserter() as asserter:
if synchronize_session is not None:
opts = {"synchronize_session": synchronize_session}
else:
opts = {}
if synchronize_session == "fetch":
with expect_raises_message(
exc.InvalidRequestError,
"The 'fetch' synchronization strategy is not available "
"for 'bulk' ORM updates",
):
sess.execute(update(User), data, execution_options=opts)
return
else:
sess.execute(update(User), data, execution_options=opts)
if homogeneous_keys:
asserter.assert_(
CompiledSQL(
"UPDATE users SET age_int=:age_int "
"WHERE users.id = :users_id",
[
{"age_int": 35, "users_id": 1},
{"age_int": 27, "users_id": 2},
{"age_int": 30, "users_id": 3},
],
)
)
else:
asserter.assert_(
CompiledSQL(
"UPDATE users SET age_int=:age_int "
"WHERE users.id = :users_id",
[{"age_int": 35, "users_id": 1}],
),
CompiledSQL(
"UPDATE users SET name=:name WHERE users.id = :users_id",
[{"name": "new jack", "users_id": 2}],
),
CompiledSQL(
"UPDATE users SET name=:name, age_int=:age_int "
"WHERE users.id = :users_id",
[{"name": "new jill", "age_int": 30, "users_id": 3}],
),
)
if synchronize_session is False:
eq_(jill.name, "jill")
eq_(jack.name, "jack")
eq_(jill.age, 29)
eq_(jack.age, 47)
else:
if not homogeneous_keys:
eq_(jill.name, "new jill")
eq_(jack.name, "new jack")
eq_(jack.age, 47)
else:
eq_(jack.age, 27)
eq_(jill.age, 30)
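The "bulk UPDATE by primary key" path used above, as a standalone sketch (assumes SQLAlchemy 2.0; names illustrative)::

    from sqlalchemy import Integer, String, create_engine, update
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class User(Base):
        __tablename__ = "users"
        id: Mapped[int] = mapped_column(Integer, primary_key=True)
        name: Mapped[str] = mapped_column(String(30))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as sess:
        sess.add_all([User(name="u1"), User(name="u2")])
        sess.commit()

        # each dict must include the primary key; remaining keys become
        # SET clauses.  heterogeneous key sets are batched per key set,
        # as the assertions above show
        sess.execute(
            update(User),
            [{"id": 1, "name": "u1new"}, {"id": 2, "name": "u2new"}],
        )
        sess.commit()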
def test_update_changes_with_autoflush(self):
User = self.classes.User
@@ -1214,7 +1466,8 @@ class UpdateDeleteTest(fixtures.MappedTest):
)
@testing.fails_if(lambda: not testing.db.dialect.supports_sane_rowcount)
def test_update_returns_rowcount(self):
@testing.combinations("auto", "fetch", "evaluate")
def test_update_returns_rowcount(self, synchronize_session):
User = self.classes.User
sess = fixture_session()
@@ -1222,20 +1475,25 @@ class UpdateDeleteTest(fixtures.MappedTest):
rowcount = (
sess.query(User)
.filter(User.age > 29)
.update({"age": User.age + 0})
.update(
{"age": User.age + 0}, synchronize_session=synchronize_session
)
)
eq_(rowcount, 2)
rowcount = (
sess.query(User)
.filter(User.age > 29)
.update({"age": User.age - 10})
.update(
{"age": User.age - 10}, synchronize_session=synchronize_session
)
)
eq_(rowcount, 2)
# test future
result = sess.execute(
update(User).where(User.age > 19).values({"age": User.age - 10})
update(User).where(User.age > 19).values({"age": User.age - 10}),
execution_options={"synchronize_session": synchronize_session},
)
eq_(result.rowcount, 4)
@@ -1327,12 +1585,17 @@ class UpdateDeleteTest(fixtures.MappedTest):
)
assert john not in sess
def test_evaluate_before_update(self):
@testing.combinations(True, False)
def test_evaluate_before_update(self, full_expiration):
User = self.classes.User
sess = fixture_session()
john = sess.query(User).filter_by(name="john").one()
sess.expire(john, ["age"])
if full_expiration:
sess.expire(john)
else:
sess.expire(john, ["age"])
# eval must be before the update. otherwise
# we eval john, age has been expired and doesn't
@@ -1356,17 +1619,47 @@ class UpdateDeleteTest(fixtures.MappedTest):
eq_(john.name, "j2")
eq_(john.age, 40)
def test_evaluate_before_delete(self):
@testing.combinations(True, False)
def test_evaluate_before_delete(self, full_expiration):
User = self.classes.User
sess = fixture_session()
john = sess.query(User).filter_by(name="john").one()
sess.expire(john, ["age"])
jill = sess.query(User).filter_by(name="jill").one()
jane = sess.query(User).filter_by(name="jane").one()
sess.query(User).filter_by(name="john").filter_by(age=25).delete(
if full_expiration:
sess.expire(jill)
sess.expire(john)
else:
sess.expire(jill, ["age"])
sess.expire(john, ["age"])
sess.query(User).filter(or_(User.age == 25, User.age == 37)).delete(
synchronize_session="evaluate"
)
assert john not in sess
# was fully deleted
assert jane not in sess
# non-deleted object was expired, but not otherwise affected
assert jill in sess
# deleted object was expired, so its deletion could not be evaluated
assert john in sess
# partially expired row fully expired
assert inspect(jill).expired
# non-deleted row still present
eq_(jill.age, 29)
# partially expired row fully expired
assert inspect(john).expired
# is deleted
with expect_raises(orm_exc.ObjectDeletedError):
john.name
def test_fetch_before_delete(self):
User = self.classes.User
@@ -1378,6 +1671,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
sess.query(User).filter_by(name="john").filter_by(age=25).delete(
synchronize_session="fetch"
)
assert john not in sess
def test_update_unordered_dict(self):
@@ -1495,6 +1789,60 @@ class UpdateDeleteTest(fixtures.MappedTest):
]
eq_(["name", "age_int"], cols)
@testing.requires.sqlite
def test_sharding_extension_returning_mismatch(self, testing_engine):
"""test one horizontal shard case where the given binds don't match
for RETURNING support; we don't support this.
See test/ext/test_horizontal_shard.py for complete round trip
test cases for ORM update/delete
"""
e1 = testing_engine("sqlite://")
e2 = testing_engine("sqlite://")
e1.connect().close()
e2.connect().close()
e1.dialect.update_returning = True
e2.dialect.update_returning = False
engines = [e1, e2]
# a simulated version of the horizontal sharding extension
def execute_and_instances(orm_context):
execution_options = dict(orm_context.local_execution_options)
partial = []
for engine in engines:
bind_arguments = dict(orm_context.bind_arguments)
bind_arguments["bind"] = engine
result_ = orm_context.invoke_statement(
bind_arguments=bind_arguments,
execution_options=execution_options,
)
partial.append(result_)
return partial[0].merge(*partial[1:])
User = self.classes.User
session = Session()
event.listen(
session, "do_orm_execute", execute_and_instances, retval=True
)
stmt = (
update(User)
.filter(User.id == 15)
.values(age=123)
.execution_options(synchronize_session="fetch")
)
with expect_raises_message(
exc.InvalidRequestError,
"For synchronize_session='fetch', can't mix multiple backends "
"where some support RETURNING and others don't",
):
session.execute(stmt)
class UpdateDeleteIgnoresLoadersTest(fixtures.MappedTest):
@classmethod
@@ -1748,6 +2096,7 @@ class UpdateDeleteFromTest(fixtures.MappedTest):
"Could not evaluate current criteria in Python.",
q.update,
{"samename": "ed"},
synchronize_session="evaluate",
)
@testing.requires.multi_table_update
@@ -1901,7 +2250,7 @@ class ExpressionUpdateTest(fixtures.MappedTest):
sess.commit()
eq_(d1.cnt, 0)
sess.query(Data).update({Data.cnt: Data.cnt + 1})
sess.query(Data).update({Data.cnt: Data.cnt + 1}, "evaluate")
sess.flush()
eq_(d1.cnt, 1)
@@ -2443,7 +2792,8 @@ class LoadFromReturningTest(fixtures.MappedTest):
)
@testing.requires.update_returning
def test_load_from_update(self, connection):
@testing.combinations(True, False, argnames="use_from_statement")
def test_load_from_update(self, connection, use_from_statement):
User = self.classes.User
stmt = (
@@ -2453,7 +2803,16 @@ class LoadFromReturningTest(fixtures.MappedTest):
.returning(User)
)
stmt = select(User).from_statement(stmt)
if use_from_statement:
# this is now a legacy-ish case, because as of 2.0 you can just
# use returning() directly to get the objects back.
#
# when from_statement is used, the UPDATE statement is no
# longer interpreted by
# BulkUDCompileState.orm_pre_session_exec or
# BulkUDCompileState.orm_setup_cursor_result. The compilation
# level routines still take place though
stmt = select(User).from_statement(stmt)
with Session(connection) as sess:
rows = sess.execute(stmt).scalars().all()
@@ -2468,7 +2827,8 @@ class LoadFromReturningTest(fixtures.MappedTest):
("multiple", testing.requires.multivalues_inserts),
argnames="params",
)
def test_load_from_insert(self, connection, params):
@testing.combinations(True, False, argnames="use_from_statement")
def test_load_from_insert(self, connection, params, use_from_statement):
User = self.classes.User
if params == "multiple":
@@ -2484,7 +2844,8 @@ class LoadFromReturningTest(fixtures.MappedTest):
stmt = insert(User).values(values).returning(User)
stmt = select(User).from_statement(stmt)
if use_from_statement:
stmt = select(User).from_statement(stmt)
with Session(connection) as sess:
rows = sess.execute(stmt).scalars().all()
@@ -2505,3 +2866,25 @@ class LoadFromReturningTest(fixtures.MappedTest):
)
else:
assert False
@testing.requires.delete_returning
@testing.combinations(True, False, argnames="use_from_statement")
def test_load_from_delete(self, connection, use_from_statement):
User = self.classes.User
stmt = (
delete(User).where(User.name.in_(["jack", "jill"])).returning(User)
)
if use_from_statement:
stmt = select(User).from_statement(stmt)
with Session(connection) as sess:
rows = sess.execute(stmt).scalars().all()
eq_(
rows,
[User(name="jack", age=47), User(name="jill", age=29)],
)
# TODO: state of above objects should be "deleted"
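The same capability without from_statement(), as a standalone sketch (assumes SQLAlchemy 2.0 and a backend with DELETE..RETURNING such as modern SQLite; names illustrative)::

    from sqlalchemy import Integer, String, create_engine, delete
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class User(Base):
        __tablename__ = "users"
        id: Mapped[int] = mapped_column(Integer, primary_key=True)
        name: Mapped[str] = mapped_column(String(30))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as sess:
        sess.add_all([User(name="jack"), User(name="jill"), User(name="ed")])
        sess.commit()

        # ORM entities come straight back from DELETE .. RETURNING
        rows = sess.scalars(
            delete(User).where(User.name.in_(["jack", "jill"])).returning(User)
        ).all()
        print([u.name for u in rows])  # ['jack', 'jill']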
+2 -1
View File
@@ -2012,7 +2012,8 @@ class JoinedNoFKSortingTest(fixtures.MappedTest):
and testing.db.dialect.supports_default_metavalue,
[
CompiledSQL(
"INSERT INTO a (id) VALUES (DEFAULT)", [{}, {}, {}, {}]
"INSERT INTO a (id) VALUES (DEFAULT) RETURNING a.id",
[{}, {}, {}, {}],
),
],
[
+10 -1
View File
@@ -326,6 +326,7 @@ class BindIntegrationTest(_fixtures.FixtureTest):
),
(
lambda User: update(User)
.execution_options(synchronize_session=False)
.values(name="not ed")
.where(User.name == "ed"),
lambda User: {"clause": mock.ANY, "mapper": inspect(User)},
@@ -392,7 +393,15 @@ class BindIntegrationTest(_fixtures.FixtureTest):
engine = {"e1": e1, "e2": e2, "e3": e3}[expected_engine_name]
with mock.patch(
"sqlalchemy.orm.context.ORMCompileState.orm_setup_cursor_result"
"sqlalchemy.orm.context." "ORMCompileState.orm_setup_cursor_result"
), mock.patch(
"sqlalchemy.orm.context.ORMCompileState.orm_execute_statement"
), mock.patch(
"sqlalchemy.orm.bulk_persistence."
"BulkORMInsert.orm_execute_statement"
), mock.patch(
"sqlalchemy.orm.bulk_persistence."
"BulkUDCompileState.orm_setup_cursor_result"
):
sess.execute(statement)
+204 -3
View File
@@ -1,8 +1,10 @@
import dataclasses
import operator
import random
import sqlalchemy as sa
from sqlalchemy import ForeignKey
from sqlalchemy import insert
from sqlalchemy import Integer
from sqlalchemy import select
from sqlalchemy import String
@@ -233,7 +235,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
is g.edges[1]
)
def test_bulk_update_sql(self):
def test_update_crit_sql(self):
Edge, Point = (self.classes.Edge, self.classes.Point)
sess = self._fixture()
@@ -258,7 +260,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
dialect="default",
)
def test_bulk_update_evaluate(self):
def test_update_crit_evaluate(self):
Edge, Point = (self.classes.Edge, self.classes.Point)
sess = self._fixture()
@@ -287,7 +289,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
eq_(e1.end, Point(17, 8))
def test_bulk_update_fetch(self):
def test_update_crit_fetch(self):
Edge, Point = (self.classes.Edge, self.classes.Point)
sess = self._fixture()
@@ -305,6 +307,205 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
eq_(e1.end, Point(17, 8))
@testing.combinations(
"legacy", "statement", "values", "stmt_returning", "values_returning"
)
def test_bulk_insert(self, type_):
Edge, Point = (self.classes.Edge, self.classes.Point)
Graph = self.classes.Graph
sess = self._fixture()
graph = Graph(id=2)
sess.add(graph)
sess.flush()
graph_id = 2
data = [
{
"start": Point(random.randint(1, 50), random.randint(1, 50)),
"end": Point(random.randint(1, 50), random.randint(1, 50)),
"graph_id": graph_id,
}
for i in range(25)
]
returning = False
if type_ == "statement":
sess.execute(insert(Edge), data)
elif type_ == "stmt_returning":
result = sess.scalars(insert(Edge).returning(Edge), data)
returning = True
elif type_ == "values":
sess.execute(insert(Edge).values(data))
elif type_ == "values_returning":
result = sess.scalars(insert(Edge).values(data).returning(Edge))
returning = True
elif type_ == "legacy":
sess.bulk_insert_mappings(Edge, data)
else:
assert False
if returning:
eq_(result.all(), [Edge(rec["start"], rec["end"]) for rec in data])
edges = self.tables.edges
eq_(
sess.execute(
select(edges.c["x1", "y1", "x2", "y2"])
.where(edges.c.graph_id == graph_id)
.order_by(edges.c.id)
).all(),
[
(e["start"].x, e["start"].y, e["end"].x, e["end"].y)
for e in data
],
)
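For reference, the five invocation forms this test exercises, condensed into a
sketch (Edge and data as built in the fixture above):
sess.execute(insert(Edge), data)                   # "statement": executemany-style
result = sess.scalars(insert(Edge).returning(Edge), data)  # "stmt_returning"
sess.execute(insert(Edge).values(data))            # "values": one multi-row VALUES
result = sess.scalars(insert(Edge).values(data).returning(Edge))  # "values_returning"
sess.bulk_insert_mappings(Edge, data)              # "legacy" bulk API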
@testing.combinations("legacy", "statement")
def test_bulk_insert_heterogeneous(self, type_):
Edge, Point = (self.classes.Edge, self.classes.Point)
Graph = self.classes.Graph
sess = self._fixture()
graph = Graph(id=2)
sess.add(graph)
sess.flush()
graph_id = 2
d1 = [
{
"start": Point(random.randint(1, 50), random.randint(1, 50)),
"end": Point(random.randint(1, 50), random.randint(1, 50)),
"graph_id": graph_id,
}
for i in range(3)
]
d2 = [
{
"start": Point(random.randint(1, 50), random.randint(1, 50)),
"graph_id": graph_id,
}
for i in range(2)
]
d3 = [
{
"x2": random.randint(1, 50),
"y2": random.randint(1, 50),
"graph_id": graph_id,
}
for i in range(2)
]
data = d1 + d2 + d3
random.shuffle(data)
assert_data = [
{
"start": d["start"] if "start" in d else None,
"end": d["end"]
if "end" in d
else Point(d["x2"], d["y2"])
if "x2" in d
else None,
"graph_id": d["graph_id"],
}
for d in data
]
if type_ == "statement":
sess.execute(insert(Edge), data)
elif type_ == "legacy":
sess.bulk_insert_mappings(Edge, data)
else:
assert False
edges = self.tables.edges
eq_(
sess.execute(
select(edges.c["x1", "y1", "x2", "y2"])
.where(edges.c.graph_id == graph_id)
.order_by(edges.c.id)
).all(),
[
(
e["start"].x if e["start"] else None,
e["start"].y if e["start"] else None,
e["end"].x if e["end"] else None,
e["end"].y if e["end"] else None,
)
for e in assert_data
],
)
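Worth noting, as this test exercises: the statement form accepts parameter
dicts with differing key sets, which the ORM sorts into separate executemany
batches of compatible keys. A one-line sketch with the fixture data above:
sess.execute(insert(Edge), d1 + d2 + d3)  # heterogeneous key sets are accepted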
@testing.combinations("legacy", "statement")
def test_bulk_update(self, type_):
Edge, Point = (self.classes.Edge, self.classes.Point)
Graph = self.classes.Graph
sess = self._fixture()
graph = Graph(id=2)
sess.add(graph)
sess.flush()
graph_id = 2
data = [
{
"start": Point(random.randint(1, 50), random.randint(1, 50)),
"end": Point(random.randint(1, 50), random.randint(1, 50)),
"graph_id": graph_id,
}
for i in range(25)
]
sess.execute(insert(Edge), data)
inserted_data = [
dict(row._mapping)
for row in sess.execute(
select(Edge.id, Edge.start, Edge.end, Edge.graph_id)
.where(Edge.graph_id == graph_id)
.order_by(Edge.id)
)
]
to_update = []
updated_pks = {}
for rec in random.choices(inserted_data, k=7):
rec_copy = dict(rec)
updated_pks[rec_copy["id"]] = rec_copy
rec_copy["start"] = Point(
random.randint(1, 50), random.randint(1, 50)
)
rec_copy["end"] = Point(
random.randint(1, 50), random.randint(1, 50)
)
to_update.append(rec_copy)
expected_dataset = [
updated_pks[row["id"]] if row["id"] in updated_pks else row
for row in inserted_data
]
if type_ == "statement":
sess.execute(update(Edge), to_update)
elif type_ == "legacy":
sess.bulk_update_mappings(Edge, to_update)
else:
assert False
edges = self.tables.edges
eq_(
sess.execute(
select(edges.c["x1", "y1", "x2", "y2"])
.where(edges.c.graph_id == graph_id)
.order_by(edges.c.id)
).all(),
[
(e["start"].x, e["start"].y, e["end"].x, e["end"].y)
for e in expected_dataset
],
)
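A condensed sketch of the two spellings compared above: given parameter dicts
that each include the primary key (as in to_update), both emit an executemany
UPDATE keyed on the id column.
sess.execute(update(Edge), to_update)        # 2.0 statement form, UPDATE by pk
sess.bulk_update_mappings(Edge, to_update)   # legacy equivalent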
def test_get_history(self):
Edge = self.classes.Edge
Point = self.classes.Point
+1 -1
View File
@@ -1122,7 +1122,7 @@ class OneToManyManyToOneTest(fixtures.MappedTest):
[
CompiledSQL(
"INSERT INTO ball (person_id, data) "
"VALUES (:person_id, :data)",
"VALUES (:person_id, :data) RETURNING ball.id",
[
{"person_id": None, "data": "some data"},
{"person_id": None, "data": "some data"},
+4
View File
@@ -383,20 +383,24 @@ class ComputedDefaultsOnUpdateTest(fixtures.MappedTest):
CompiledSQL(
"UPDATE test SET foo=:foo WHERE test.id = :test_id",
[{"foo": 5, "test_id": 1}],
enable_returning=False,
),
CompiledSQL(
"UPDATE test SET foo=:foo WHERE test.id = :test_id",
[{"foo": 6, "test_id": 2}],
enable_returning=False,
),
CompiledSQL(
"SELECT test.bar AS test_bar FROM test "
"WHERE test.id = :pk_1",
[{"pk_1": 1}],
enable_returning=False,
),
CompiledSQL(
"SELECT test.bar AS test_bar FROM test "
"WHERE test.id = :pk_1",
[{"pk_1": 2}],
enable_returning=False,
),
)
else:
+11 -2
View File
@@ -661,8 +661,17 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest):
canary = self._flag_fixture(sess)
sess.execute(delete(User).filter_by(id=18))
sess.execute(update(User).filter_by(id=18).values(name="eighteen"))
sess.execute(
delete(User)
.filter_by(id=18)
.execution_options(synchronize_session="evaluate")
)
sess.execute(
update(User)
.filter_by(id=18)
.values(name="eighteen")
.execution_options(synchronize_session="evaluate")
)
eq_(
canary.mock_calls,
+4 -2
View File
@@ -2868,12 +2868,14 @@ class SaveTest2(_fixtures.FixtureTest):
testing.db.dialect.insert_executemany_returning,
[
CompiledSQL(
"INSERT INTO users (name) VALUES (:name)",
"INSERT INTO users (name) VALUES (:name) "
"RETURNING users.id",
[{"name": "u1"}, {"name": "u2"}],
),
CompiledSQL(
"INSERT INTO addresses (user_id, email_address) "
"VALUES (:user_id, :email_address)",
"VALUES (:user_id, :email_address) "
"RETURNING addresses.id",
[
{"user_id": 1, "email_address": "a1"},
{"user_id": 2, "email_address": "a2"},
+27 -7
View File
@@ -98,7 +98,8 @@ class RudimentaryFlushTest(UOWTest):
[
CompiledSQL(
"INSERT INTO addresses (user_id, email_address) "
"VALUES (:user_id, :email_address)",
"VALUES (:user_id, :email_address) "
"RETURNING addresses.id",
lambda ctx: [
{"email_address": "a1", "user_id": u1.id},
{"email_address": "a2", "user_id": u1.id},
@@ -220,7 +221,8 @@ class RudimentaryFlushTest(UOWTest):
[
CompiledSQL(
"INSERT INTO addresses (user_id, email_address) "
"VALUES (:user_id, :email_address)",
"VALUES (:user_id, :email_address) "
"RETURNING addresses.id",
lambda ctx: [
{"email_address": "a1", "user_id": u1.id},
{"email_address": "a2", "user_id": u1.id},
@@ -889,7 +891,7 @@ class SingleCycleTest(UOWTest):
[
CompiledSQL(
"INSERT INTO nodes (parent_id, data) VALUES "
"(:parent_id, :data)",
"(:parent_id, :data) RETURNING nodes.id",
lambda ctx: [
{"parent_id": n1.id, "data": "n2"},
{"parent_id": n1.id, "data": "n3"},
@@ -1003,7 +1005,7 @@ class SingleCycleTest(UOWTest):
[
CompiledSQL(
"INSERT INTO nodes (parent_id, data) VALUES "
"(:parent_id, :data)",
"(:parent_id, :data) RETURNING nodes.id",
lambda ctx: [
{"parent_id": n1.id, "data": "n2"},
{"parent_id": n1.id, "data": "n3"},
@@ -1165,7 +1167,7 @@ class SingleCycleTest(UOWTest):
[
CompiledSQL(
"INSERT INTO nodes (parent_id, data) VALUES "
"(:parent_id, :data)",
"(:parent_id, :data) RETURNING nodes.id",
lambda ctx: [
{"parent_id": n1.id, "data": "n11"},
{"parent_id": n1.id, "data": "n12"},
@@ -1196,7 +1198,7 @@ class SingleCycleTest(UOWTest):
[
CompiledSQL(
"INSERT INTO nodes (parent_id, data) VALUES "
"(:parent_id, :data)",
"(:parent_id, :data) RETURNING nodes.id",
lambda ctx: [
{"parent_id": n12.id, "data": "n121"},
{"parent_id": n12.id, "data": "n122"},
@@ -2099,7 +2101,7 @@ class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults):
testing.db.dialect.insert_executemany_returning,
[
CompiledSQL(
"INSERT INTO t (data) VALUES (:data)",
"INSERT INTO t (data) VALUES (:data) RETURNING t.id",
[{"data": "t1"}, {"data": "t2"}],
),
],
@@ -2472,20 +2474,24 @@ class EagerDefaultsTest(fixtures.MappedTest):
CompiledSQL(
"INSERT INTO test (id, foo) VALUES (:id, 2 + 5)",
[{"id": 1}],
enable_returning=False,
),
CompiledSQL(
"INSERT INTO test (id, foo) VALUES (:id, 5 + 5)",
[{"id": 2}],
enable_returning=False,
),
CompiledSQL(
"SELECT test.foo AS test_foo FROM test "
"WHERE test.id = :pk_1",
[{"pk_1": 1}],
enable_returning=False,
),
CompiledSQL(
"SELECT test.foo AS test_foo FROM test "
"WHERE test.id = :pk_1",
[{"pk_1": 2}],
enable_returning=False,
),
)
@@ -2678,20 +2684,24 @@ class EagerDefaultsTest(fixtures.MappedTest):
CompiledSQL(
"UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
[{"foo": 5, "test2_id": 1}],
enable_returning=False,
),
CompiledSQL(
"UPDATE test2 SET foo=:foo, bar=:bar "
"WHERE test2.id = :test2_id",
[{"foo": 6, "bar": 10, "test2_id": 2}],
enable_returning=False,
),
CompiledSQL(
"UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
[{"foo": 7, "test2_id": 3}],
enable_returning=False,
),
CompiledSQL(
"UPDATE test2 SET foo=:foo, bar=:bar "
"WHERE test2.id = :test2_id",
[{"foo": 8, "bar": 12, "test2_id": 4}],
enable_returning=False,
),
CompiledSQL(
"SELECT test2.bar AS test2_bar FROM test2 "
@@ -2772,31 +2782,37 @@ class EagerDefaultsTest(fixtures.MappedTest):
"UPDATE test4 SET foo=:foo, bar=5 + 3 "
"WHERE test4.id = :test4_id",
[{"foo": 5, "test4_id": 1}],
enable_returning=False,
),
CompiledSQL(
"UPDATE test4 SET foo=:foo, bar=:bar "
"WHERE test4.id = :test4_id",
[{"foo": 6, "bar": 10, "test4_id": 2}],
enable_returning=False,
),
CompiledSQL(
"UPDATE test4 SET foo=:foo, bar=5 + 3 "
"WHERE test4.id = :test4_id",
[{"foo": 7, "test4_id": 3}],
enable_returning=False,
),
CompiledSQL(
"UPDATE test4 SET foo=:foo, bar=:bar "
"WHERE test4.id = :test4_id",
[{"foo": 8, "bar": 12, "test4_id": 4}],
enable_returning=False,
),
CompiledSQL(
"SELECT test4.bar AS test4_bar FROM test4 "
"WHERE test4.id = :pk_1",
[{"pk_1": 1}],
enable_returning=False,
),
CompiledSQL(
"SELECT test4.bar AS test4_bar FROM test4 "
"WHERE test4.id = :pk_1",
[{"pk_1": 3}],
enable_returning=False,
),
],
),
@@ -2871,20 +2887,24 @@ class EagerDefaultsTest(fixtures.MappedTest):
"UPDATE test2 SET foo=:foo, bar=1 + 1 "
"WHERE test2.id = :test2_id",
[{"foo": 5, "test2_id": 1}],
enable_returning=False,
),
CompiledSQL(
"UPDATE test2 SET foo=:foo, bar=:bar "
"WHERE test2.id = :test2_id",
[{"foo": 6, "bar": 10, "test2_id": 2}],
enable_returning=False,
),
CompiledSQL(
"UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
[{"foo": 7, "test2_id": 3}],
enable_returning=False,
),
CompiledSQL(
"UPDATE test2 SET foo=:foo, bar=5 + 7 "
"WHERE test2.id = :test2_id",
[{"foo": 8, "test2_id": 4}],
enable_returning=False,
),
CompiledSQL(
"SELECT test2.bar AS test2_bar FROM test2 "
+6 -4
View File
@@ -1424,12 +1424,10 @@ class ServerVersioningTest(fixtures.MappedTest):
sess.add(f1)
statements = [
# note that the assertsql tests the rule against
# "default" - on a "returning" backend, the statement
# includes "RETURNING"
CompiledSQL(
"INSERT INTO version_table (version_id, value) "
"VALUES (1, :value)",
"VALUES (1, :value) "
"RETURNING version_table.id, version_table.version_id",
lambda ctx: [{"value": "f1"}],
)
]
@@ -1493,6 +1491,7 @@ class ServerVersioningTest(fixtures.MappedTest):
"value": "f2",
}
],
enable_returning=False,
),
CompiledSQL(
"SELECT version_table.version_id "
@@ -1618,6 +1617,7 @@ class ServerVersioningTest(fixtures.MappedTest):
"value": "f1a",
}
],
enable_returning=False,
),
CompiledSQL(
"UPDATE version_table SET version_id=2, value=:value "
@@ -1630,6 +1630,7 @@ class ServerVersioningTest(fixtures.MappedTest):
"value": "f2a",
}
],
enable_returning=False,
),
CompiledSQL(
"UPDATE version_table SET version_id=2, value=:value "
@@ -1642,6 +1643,7 @@ class ServerVersioningTest(fixtures.MappedTest):
"value": "f3a",
}
],
enable_returning=False,
),
CompiledSQL(
"SELECT version_table.version_id "
+46 -1
View File
@@ -100,10 +100,55 @@ class CursorResultTest(fixtures.TablesTest):
Table(
"test",
metadata,
Column("x", Integer, primary_key=True),
Column(
"x", Integer, primary_key=True, test_needs_autoincrement=False
),
Column("y", String(50)),
)
@testing.requires.insert_returning
def test_splice_horizontally(self, connection):
users = self.tables.users
addresses = self.tables.addresses
r1 = connection.execute(
users.insert().returning(users.c.user_name, users.c.user_id),
[
dict(user_id=1, user_name="john"),
dict(user_id=2, user_name="jack"),
],
)
r2 = connection.execute(
addresses.insert().returning(
addresses.c.address_id,
addresses.c.address,
addresses.c.user_id,
),
[
dict(address_id=1, user_id=1, address="foo@bar.com"),
dict(address_id=2, user_id=2, address="bar@bat.com"),
],
)
rows = r1.splice_horizontally(r2).all()
eq_(
rows,
[
("john", 1, 1, "foo@bar.com", 1),
("jack", 2, 2, "bar@bat.com", 2),
],
)
eq_(rows[0]._mapping[users.c.user_id], 1)
eq_(rows[0]._mapping[addresses.c.user_id], 1)
eq_(rows[1].address, "bar@bat.com")
with expect_raises_message(
exc.InvalidRequestError, "Ambiguous column name 'user_id'"
):
rows[0].user_id
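A short sketch of the splice_horizontally() pattern above, assuming fresh r1/r2
CursorResults: rows are joined positionally, and colliding names stay reachable
through their originating Column objects.
row = r1.splice_horizontally(r2).first()
left_uid = row._mapping[users.c.user_id]        # from the users RETURNING
right_uid = row._mapping[addresses.c.user_id]   # from the addresses RETURNING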
def test_keys_no_rows(self, connection):
for i in range(2):
+195
View File
@@ -23,6 +23,7 @@ from sqlalchemy.testing import config
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing import mock
from sqlalchemy.testing import provision
from sqlalchemy.testing.schema import Column
@@ -76,6 +77,7 @@ class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL):
stmt = stmt.returning(t.c.x)
stmt = stmt.return_defaults()
assert_raises_message(
sa_exc.CompileError,
r"Can't compile statement that includes returning\(\) "
@@ -330,6 +332,7 @@ class InsertReturningTest(fixtures.TablesTest, AssertsExecutionResults):
table = self.tables.returning_tbl
exprs = testing.resolve_lambda(testcase, table=table)
result = connection.execute(
table.insert().returning(*exprs),
{"persons": 5, "full": False, "strval": "str1"},
@@ -679,6 +682,30 @@ class InsertReturnDefaultsTest(fixtures.TablesTest):
Column("upddef", Integer, onupdate=IncDefault()),
)
Table(
"table_no_addtl_defaults",
metadata,
Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
),
Column("data", String(50)),
)
class MyType(TypeDecorator):
impl = String(50)
def process_result_value(self, value, dialect):
return f"PROCESSED! {value}"
Table(
"table_datatype_has_result_proc",
metadata,
Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
),
Column("data", MyType()),
)
def test_chained_insert_pk(self, connection):
t1 = self.tables.t1
result = connection.execute(
@@ -758,6 +785,38 @@ class InsertReturnDefaultsTest(fixtures.TablesTest):
)
eq_(result.inserted_primary_key, (1,))
def test_insert_w_defaults_supplemental_cols(self, connection):
t1 = self.tables.t1
result = connection.execute(
t1.insert().return_defaults(supplemental_cols=[t1.c.id]),
{"data": "d1"},
)
eq_(result.all(), [(1, 0, None)])
def test_insert_w_no_defaults_supplemental_cols(self, connection):
t1 = self.tables.table_no_addtl_defaults
result = connection.execute(
t1.insert().return_defaults(supplemental_cols=[t1.c.id]),
{"data": "d1"},
)
eq_(result.all(), [(1,)])
def test_insert_w_defaults_supplemental_processor_cols(self, connection):
"""test that the cursor._rewind() used by supplemental RETURNING
clears out result-row processors as we will have already processed
the rows.
"""
t1 = self.tables.table_datatype_has_result_proc
result = connection.execute(
t1.insert().return_defaults(
supplemental_cols=[t1.c.id, t1.c.data]
),
{"data": "d1"},
)
eq_(result.all(), [(1, "PROCESSED! d1")])
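A minimal sketch of the behavior under test, using the same fixture table: the
result-row processors run when return_defaults() captures the row, so the
rewound cursor must not apply them a second time.
result = connection.execute(
    t1.insert().return_defaults(supplemental_cols=[t1.c.id, t1.c.data]),
    {"data": "d1"},
)
row = result.returned_defaults                 # already processed: "PROCESSED! d1"
assert result.all() == [(1, "PROCESSED! d1")]  # rewind replays, no reprocessing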
class UpdatedReturnDefaultsTest(fixtures.TablesTest):
__requires__ = ("update_returning",)
@@ -792,6 +851,7 @@ class UpdatedReturnDefaultsTest(fixtures.TablesTest):
t1 = self.tables.t1
connection.execute(t1.insert().values(upddef=1))
result = connection.execute(
t1.update().values(upddef=2).return_defaults(t1.c.data)
)
@@ -800,6 +860,72 @@ class UpdatedReturnDefaultsTest(fixtures.TablesTest):
[None],
)
def test_update_values_col_is_excluded(self, connection):
"""columns that are in values() are not returned"""
t1 = self.tables.t1
connection.execute(t1.insert().values(upddef=1))
result = connection.execute(
t1.update().values(data="x", upddef=2).return_defaults(t1.c.data)
)
is_(result.returned_defaults, None)
result = connection.execute(
t1.update()
.values(data="x", upddef=2)
.return_defaults(t1.c.data, t1.c.id)
)
eq_(result.returned_defaults, (1,))
def test_update_supplemental_cols(self, connection):
"""with supplemental_cols, we can get back arbitrary cols."""
t1 = self.tables.t1
connection.execute(t1.insert().values(upddef=1))
result = connection.execute(
t1.update()
.values(data="x", insdef=3)
.return_defaults(supplemental_cols=[t1.c.data, t1.c.insdef])
)
row = result.returned_defaults
# row has all the cols in it
eq_(row, ("x", 3, 1))
eq_(row._mapping[t1.c.upddef], 1)
eq_(row._mapping[t1.c.insdef], 3)
# result is rewound
# but has both return_defaults + supplemental_cols
eq_(result.all(), [("x", 3, 1)])
def test_update_expl_return_defaults_plus_supplemental_cols(
self, connection
):
"""with supplemental_cols, we can get back arbitrary cols."""
t1 = self.tables.t1
connection.execute(t1.insert().values(upddef=1))
result = connection.execute(
t1.update()
.values(data="x", insdef=3)
.return_defaults(
t1.c.id, supplemental_cols=[t1.c.data, t1.c.insdef]
)
)
row = result.returned_defaults
# row has all the cols in it
eq_(row, (1, "x", 3))
eq_(row._mapping[t1.c.id], 1)
eq_(row._mapping[t1.c.insdef], 3)
assert t1.c.upddef not in row._mapping
# result is rewound
# but has both return_defaults + supplemental_cols
eq_(result.all(), [(1, "x", 3)])
def test_update_sql_expr(self, connection):
from sqlalchemy import literal
@@ -833,6 +959,75 @@ class UpdatedReturnDefaultsTest(fixtures.TablesTest):
eq_(dict(result.returned_defaults._mapping), {"upddef": 1})
class DeleteReturnDefaultsTest(fixtures.TablesTest):
__requires__ = ("delete_returning",)
run_define_tables = "each"
__backend__ = True
define_tables = InsertReturnDefaultsTest.define_tables
def test_delete(self, connection):
t1 = self.tables.t1
connection.execute(t1.insert().values(upddef=1))
result = connection.execute(t1.delete().return_defaults(t1.c.upddef))
eq_(
[result.returned_defaults._mapping[k] for k in (t1.c.upddef,)], [1]
)
def test_delete_empty_return_defaults(self, connection):
t1 = self.tables.t1
connection.execute(t1.insert().values(upddef=5))
result = connection.execute(t1.delete().return_defaults())
# there's no "delete" default, so we get None. we have to
# ask for them in all cases
eq_(result.returned_defaults, None)
def test_delete_non_default(self, connection):
"""test that a column not marked at all as a
default works with this feature."""
t1 = self.tables.t1
connection.execute(t1.insert().values(upddef=1))
result = connection.execute(t1.delete().return_defaults(t1.c.data))
eq_(
[result.returned_defaults._mapping[k] for k in (t1.c.data,)],
[None],
)
def test_delete_non_default_plus_default(self, connection):
t1 = self.tables.t1
connection.execute(t1.insert().values(upddef=1))
result = connection.execute(
t1.delete().return_defaults(t1.c.data, t1.c.upddef)
)
eq_(
dict(result.returned_defaults._mapping),
{"data": None, "upddef": 1},
)
def test_delete_supplemental_cols(self, connection):
"""with supplemental_cols, we can get back arbitrary cols."""
t1 = self.tables.t1
connection.execute(t1.insert().values(upddef=1))
result = connection.execute(
t1.delete().return_defaults(
t1.c.id, supplemental_cols=[t1.c.data, t1.c.insdef]
)
)
row = result.returned_defaults
# row has all the cols in it
eq_(row, (1, None, 0))
eq_(row._mapping[t1.c.insdef], 0)
# result is rewound
# but has both return_defaults + supplemental_cols
eq_(result.all(), [(1, None, 0)])
class InsertManyReturnDefaultsTest(fixtures.TablesTest):
__requires__ = ("insert_executemany_returning",)
run_define_tables = "each"
+21
View File
@@ -44,6 +44,7 @@ from sqlalchemy.sql import operators
from sqlalchemy.sql import table
from sqlalchemy.sql import util as sql_util
from sqlalchemy.sql import visitors
from sqlalchemy.sql.dml import Insert
from sqlalchemy.sql.selectable import LABEL_STYLE_NONE
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
@@ -3029,6 +3030,26 @@ class AnnotationsTest(fixtures.TestBase):
eq_(whereclause.left._annotations, {"foo": "bar"})
eq_(whereclause.right._annotations, {"foo": "bar"})
@testing.combinations(True, False, None)
def test_setup_inherit_cache(self, inherit_cache_value):
if inherit_cache_value is None:
class MyInsertThing(Insert):
pass
else:
class MyInsertThing(Insert):
inherit_cache = inherit_cache_value
t = table("t", column("x"))
anno = MyInsertThing(t)._annotate({"foo": "bar"})
if inherit_cache_value is not None:
is_(type(anno).__dict__["inherit_cache"], inherit_cache_value)
else:
assert "inherit_cache" not in type(anno).__dict__
def test_proxy_set_iteration_includes_annotated(self):
from sqlalchemy.schema import Column