render INSERT/UPDATE column expressions up front; pass state

Fixes related to rendering of complex UPDATE DML
which was not correctly preserving positional parameter
order in conjunction with DML features that are only known
to work on the PostgreSQL database.    Both pg8000
and asyncpg use positional parameters which is why these
issues are suddenly apparent.

crud.py now takes on the task of rendering the column
expressions for SET or VALUES so that for the very unusual
case that the column expression is a compound expression
that includes a bound parameter (namely an array index),
the bound parameter order is preserved.

Additionally, crud.py passes through the positional_names
keyword argument into bindparam_string() which is necessary
when CTEs are being rendered, as PG supports complex
CTE / INSERT / UPDATE scenarios.

Change-Id: I7f03920500e19b721636b84594de78a5bfdcbc82
This commit is contained in:
Mike Bayer
2020-08-08 13:03:17 -04:00
parent 302e8dee82
commit c0685e5f41
4 changed files with 234 additions and 56 deletions
+7 -12
View File
@@ -3286,7 +3286,7 @@ class SQLCompiler(Compiled):
if crud_params_single or not supports_default_values:
text += " (%s)" % ", ".join(
[preparer.format_column(c[0]) for c in crud_params_single]
[expr for c, expr, value in crud_params_single]
)
if self.returning or insert_stmt._returning:
@@ -3311,12 +3311,15 @@ class SQLCompiler(Compiled):
elif compile_state._has_multi_parameters:
text += " VALUES %s" % (
", ".join(
"(%s)" % (", ".join(c[1] for c in crud_param_set))
"(%s)"
% (", ".join(value for c, expr, value in crud_param_set))
for crud_param_set in crud_params
)
)
else:
insert_single_values_expr = ", ".join([c[1] for c in crud_params])
insert_single_values_expr = ", ".join(
[value for c, expr, value in crud_params]
)
text += " VALUES (%s)" % insert_single_values_expr
if toplevel:
self.insert_single_values_expr = insert_single_values_expr
@@ -3424,15 +3427,7 @@ class SQLCompiler(Compiled):
text += table_text
text += " SET "
include_table = (
is_multitable and self.render_table_with_column_in_update_from
)
text += ", ".join(
c[0]._compiler_dispatch(self, include_table=include_table)
+ "="
+ c[1]
for c in crud_params
)
text += ", ".join(expr + "=" + value for c, expr, value in crud_params)
if self.returning or update_stmt._returning:
if self.returning_precedes_values:
+151 -42
View File
@@ -65,7 +65,11 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
# compiled params - return binds for all columns
if compiler.column_keys is None and compile_state._no_parameters:
return [
(c, _create_bind_param(compiler, c, None, required=True))
(
c,
compiler.preparer.format_column(c),
_create_bind_param(compiler, c, None, required=True),
)
for c in stmt.table.columns
]
@@ -90,18 +94,20 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
if stmt_parameters is not None:
_get_stmt_parameters_params(
compiler, parameters, stmt_parameters, _column_as_key, values, kw
compiler,
compile_state,
parameters,
stmt_parameters,
_column_as_key,
values,
kw,
)
check_columns = {}
# special logic that only occurs for multi-table UPDATE
# statements
if (
compile_state.isupdate
and compile_state._extra_froms
and stmt_parameters
):
if compile_state.isupdate and compile_state.is_multitable:
_get_multitable_params(
compiler,
stmt,
@@ -162,7 +168,13 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
# into INSERT (firstcol) VALUES (DEFAULT) which can be turned
# into an in-place multi values. This supports
# insert_executemany_returning mode :)
values = [(stmt.table.columns[0], "DEFAULT")]
values = [
(
stmt.table.columns[0],
compiler.preparer.format_column(stmt.table.columns[0]),
"DEFAULT",
)
]
return values
@@ -286,7 +298,7 @@ def _scan_insert_from_select_cols(
col_key = _getattr_col_key(c)
if col_key in parameters and col_key not in check_columns:
parameters.pop(col_key)
values.append((c, None))
values.append((c, compiler.preparer.format_column(c), None))
else:
_append_param_insert_select_hasdefault(
compiler, stmt, c, add_select_cols, kw
@@ -297,7 +309,7 @@ def _scan_insert_from_select_cols(
compiler._insert_from_select = compiler._insert_from_select._generate()
compiler._insert_from_select._raw_columns = tuple(
compiler._insert_from_select._raw_columns
) + tuple(expr for col, expr in add_select_cols)
) + tuple(expr for col, col_expr, expr in add_select_cols)
def _scan_cols(
@@ -390,7 +402,13 @@ def _scan_cols(
elif compile_state.isupdate:
_append_param_update(
compiler, stmt, c, implicit_return_defaults, values, kw
compiler,
compile_state,
stmt,
c,
implicit_return_defaults,
values,
kw,
)
@@ -410,6 +428,10 @@ def _append_param_parameter(
value = parameters.pop(col_key)
col_value = compiler.preparer.format_column(
c, use_table=compile_state.include_table_with_column_exprs
)
if coercions._is_literal(value):
value = _create_bind_param(
compiler,
@@ -446,7 +468,7 @@ def _append_param_parameter(
if not c.primary_key:
compiler.postfetch.append(c)
value = compiler.process(value.self_group(), **kw)
values.append((c, value))
values.append((c, col_value, value))
def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
@@ -472,16 +494,31 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
not c.default.optional
or not compiler.dialect.sequences_optional
):
proc = compiler.process(c.default, **kw)
values.append((c, proc))
values.append(
(
c,
compiler.preparer.format_column(c),
compiler.process(c.default, **kw),
)
)
compiler.returning.append(c)
elif c.default.is_clause_element:
values.append(
(c, compiler.process(c.default.arg.self_group(), **kw))
(
c,
compiler.preparer.format_column(c),
compiler.process(c.default.arg.self_group(), **kw),
)
)
compiler.returning.append(c)
else:
values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
values.append(
(
c,
compiler.preparer.format_column(c),
_create_insert_prefetch_bind_param(compiler, c, **kw),
)
)
elif c is stmt.table._autoincrement_column or c.server_default is not None:
compiler.returning.append(c)
elif not c.nullable:
@@ -490,14 +527,22 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
_warn_pk_with_no_anticipated_value(c)
def _create_insert_prefetch_bind_param(compiler, c, process=True, name=None):
param = _create_bind_param(compiler, c, None, process=process, name=name)
def _create_insert_prefetch_bind_param(
compiler, c, process=True, name=None, **kw
):
param = _create_bind_param(
compiler, c, None, process=process, name=name, **kw
)
compiler.insert_prefetch.append(c)
return param
def _create_update_prefetch_bind_param(compiler, c, process=True, name=None):
param = _create_bind_param(compiler, c, None, process=process, name=name)
def _create_update_prefetch_bind_param(
compiler, c, process=True, name=None, **kw
):
param = _create_bind_param(
compiler, c, None, process=process, name=name, **kw
)
compiler.update_prefetch.append(c)
return param
@@ -539,9 +584,9 @@ def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
else:
col = _multiparam_column(c, index)
if isinstance(stmt, dml.Insert):
return _create_insert_prefetch_bind_param(compiler, col)
return _create_insert_prefetch_bind_param(compiler, col, **kw)
else:
return _create_update_prefetch_bind_param(compiler, col)
return _create_update_prefetch_bind_param(compiler, col, **kw)
def _append_param_insert_pk(compiler, stmt, c, values, kw):
@@ -582,7 +627,13 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw):
or compiler.dialect.preexecute_autoincrement_sequences
)
):
values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
values.append(
(
c,
compiler.preparer.format_column(c),
_create_insert_prefetch_bind_param(compiler, c, **kw),
)
)
elif c.default is None and c.server_default is None and not c.nullable:
# no .default, no .server_default, not autoincrement, we have
# no indication this primary key column will have any value
@@ -597,15 +648,25 @@ def _append_param_insert_hasdefault(
if compiler.dialect.supports_sequences and (
not c.default.optional or not compiler.dialect.sequences_optional
):
proc = compiler.process(c.default, **kw)
values.append((c, proc))
values.append(
(
c,
compiler.preparer.format_column(c),
compiler.process(c.default, **kw),
)
)
if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
elif not c.primary_key:
compiler.postfetch.append(c)
elif c.default.is_clause_element:
proc = compiler.process(c.default.arg.self_group(), **kw)
values.append((c, proc))
values.append(
(
c,
compiler.preparer.format_column(c),
compiler.process(c.default.arg.self_group(), **kw),
)
)
if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
@@ -613,7 +674,13 @@ def _append_param_insert_hasdefault(
# don't add primary key column to postfetch
compiler.postfetch.append(c)
else:
values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
values.append(
(
c,
compiler.preparer.format_column(c),
_create_insert_prefetch_bind_param(compiler, c, **kw),
)
)
def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw):
@@ -622,32 +689,55 @@ def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw):
if compiler.dialect.supports_sequences and (
not c.default.optional or not compiler.dialect.sequences_optional
):
proc = c.default
values.append((c, proc.next_value()))
values.append(
(c, compiler.preparer.format_column(c), c.default.next_value())
)
elif c.default.is_clause_element:
proc = c.default.arg.self_group()
values.append((c, proc))
values.append(
(c, compiler.preparer.format_column(c), c.default.arg.self_group())
)
else:
values.append(
(c, _create_insert_prefetch_bind_param(compiler, c, process=False))
(
c,
compiler.preparer.format_column(c),
_create_insert_prefetch_bind_param(
compiler, c, process=False, **kw
),
)
)
def _append_param_update(
compiler, stmt, c, implicit_return_defaults, values, kw
compiler, compile_state, stmt, c, implicit_return_defaults, values, kw
):
include_table = compile_state.include_table_with_column_exprs
if c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
values.append(
(c, compiler.process(c.onupdate.arg.self_group(), **kw))
(
c,
compiler.preparer.format_column(
c, use_table=include_table,
),
compiler.process(c.onupdate.arg.self_group(), **kw),
)
)
if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
else:
compiler.postfetch.append(c)
else:
values.append((c, _create_update_prefetch_bind_param(compiler, c)))
values.append(
(
c,
compiler.preparer.format_column(
c, use_table=include_table,
),
_create_update_prefetch_bind_param(compiler, c, **kw),
)
)
elif c.server_onupdate is not None:
if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
@@ -676,6 +766,9 @@ def _get_multitable_params(
(coercions.expect(roles.DMLColumnRole, c), param)
for c, param in stmt_parameters.items()
)
include_table = compile_state.include_table_with_column_exprs
affected_tables = set()
for t in compile_state._extra_froms:
for c in t.c:
@@ -683,6 +776,8 @@ def _get_multitable_params(
affected_tables.add(t)
check_columns[_getattr_col_key(c)] = c
value = normalized_params[c]
col_value = compiler.process(c, include_table=include_table)
if coercions._is_literal(value):
value = _create_bind_param(
compiler,
@@ -699,7 +794,7 @@ def _get_multitable_params(
else:
compiler.postfetch.append(c)
value = compiler.process(value.self_group(), **kw)
values.append((c, value))
values.append((c, col_value, value))
# determine tables which are actually to be updated - process onupdate
# and server_onupdate for these
for t in affected_tables:
@@ -711,6 +806,7 @@ def _get_multitable_params(
values.append(
(
c,
compiler.process(c, include_table=include_table),
compiler.process(
c.onupdate.arg.self_group(), **kw
),
@@ -721,8 +817,9 @@ def _get_multitable_params(
values.append(
(
c,
compiler.process(c, include_table=include_table),
_create_update_prefetch_bind_param(
compiler, c, name=_col_bind_name(c)
compiler, c, name=_col_bind_name(c), **kw
),
)
)
@@ -736,7 +833,7 @@ def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw):
for i, row in enumerate(compile_state._multi_parameters[1:]):
extension = []
for (col, param) in values_0:
for (col, col_expr, param) in values_0:
if col in row or col.key in row:
key = col if col in row else col.key
@@ -755,7 +852,7 @@ def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw):
compiler, stmt, col, i, kw
)
extension.append((col, new_param))
extension.append((col, col_expr, new_param))
values.append(extension)
@@ -763,8 +860,15 @@ def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw):
def _get_stmt_parameters_params(
compiler, parameters, stmt_parameters, _column_as_key, values, kw
compiler,
compile_state,
parameters,
stmt_parameters,
_column_as_key,
values,
kw,
):
for k, v in stmt_parameters.items():
colkey = _column_as_key(k)
if colkey is not None:
@@ -773,6 +877,11 @@ def _get_stmt_parameters_params(
# a non-Column expression on the left side;
# add it to values() in an "as-is" state,
# coercing right side to bound param
col_expr = compiler.process(
k, include_table=compile_state.include_table_with_column_exprs
)
if coercions._is_literal(v):
v = compiler.process(
elements.BindParameter(None, v, type_=k.type), **kw
@@ -780,7 +889,7 @@ def _get_stmt_parameters_params(
else:
v = compiler.process(v.self_group(), **kw)
values.append((k, v))
values.append((k, col_expr, v))
def _get_returning_modifiers(compiler, stmt, compile_state):
+9 -1
View File
@@ -133,6 +133,8 @@ class DMLState(CompileState):
class InsertDMLState(DMLState):
isinsert = True
include_table_with_column_exprs = False
def __init__(self, statement, compiler, **kw):
self.statement = statement
@@ -149,6 +151,8 @@ class InsertDMLState(DMLState):
class UpdateDMLState(DMLState):
isupdate = True
include_table_with_column_exprs = False
def __init__(self, statement, compiler, **kw):
self.statement = statement
self.isupdate = True
@@ -159,7 +163,11 @@ class UpdateDMLState(DMLState):
self._process_values(statement)
elif statement._multi_values:
self._process_multi_values(statement)
self._extra_froms = self._make_extra_froms(statement)
self._extra_froms = ef = self._make_extra_froms(statement)
self.is_multitable = mt = ef and self._dict_parameters
self.include_table_with_column_exprs = (
mt and compiler.render_table_with_column_in_update_from
)
@CompileState.plugin_for("default", "delete")
+67 -1
View File
@@ -840,7 +840,7 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
dialect=mysql.dialect(),
)
def test_update_to_expression(self):
def test_update_to_expression_one(self):
"""test update from an expression.
this logic is triggered currently by a left side that doesn't
@@ -856,6 +856,72 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
"UPDATE mytable SET foo(myid)=:param_1",
)
def test_update_to_expression_two(self):
"""test update from an expression.
this logic is triggered currently by a left side that doesn't
have a key. The current supported use case is updating the index
of a PostgreSQL ARRAY type.
"""
from sqlalchemy import ARRAY
t = table(
"foo",
column("data1", ARRAY(Integer)),
column("data2", ARRAY(Integer)),
)
stmt = t.update().values({t.c.data1[5]: 7, t.c.data2[10]: 18})
dialect = default.StrCompileDialect()
dialect.paramstyle = "qmark"
dialect.positional = True
self.assert_compile(
stmt,
"UPDATE foo SET data1[?]=?, data2[?]=?",
dialect=dialect,
checkpositional=(5, 7, 10, 18),
)
def test_update_to_expression_three(self):
# this test is from test_defaults but exercises a particular
# parameter ordering issue
metadata = MetaData()
q = Table(
"q",
metadata,
Column("x", Integer, default=2),
Column("y", Integer, onupdate=5),
Column("z", Integer),
)
p = Table(
"p",
metadata,
Column("s", Integer),
Column("t", Integer),
Column("u", Integer, onupdate=1),
)
cte = (
q.update().where(q.c.z == 1).values(x=7).returning(q.c.z).cte("c")
)
stmt = select([p.c.s, cte.c.z]).where(p.c.s == cte.c.z)
dialect = default.StrCompileDialect()
dialect.paramstyle = "qmark"
dialect.positional = True
self.assert_compile(
stmt,
"WITH c AS (UPDATE q SET x=?, y=? WHERE q.z = ? RETURNING q.z) "
"SELECT p.s, c.z FROM p, c WHERE p.s = c.z",
checkpositional=(7, None, 1),
dialect=dialect,
)
def test_update_bound_ordering(self):
"""test that bound parameters between the UPDATE and FROM clauses
order correctly in different SQL compilation scenarios.