Factor out the column resolution into helper func.

This commit is contained in:
Charles Leifer
2026-05-05 18:50:09 -05:00
parent 4fedb8f4d0
commit d7b42efdb3
+95 -79
View File
@@ -8016,6 +8016,97 @@ def safe_python_value(conv_func):
return validate
def _resolve_model_columns(cursor, model, select):
"""Resolve cursor columns against a model's selected nodes.
Returns ``(columns, fields, converters, no_convert, convert)``:
``columns`` and ``fields`` are aligned per-column lists, ``converters``
is a per-column ``python_value`` callable or ``None``, and
``no_convert``/``convert`` are the index partitions of ``converters``.
"""
combined = model._meta.combined
table = model._meta.table
description = cursor.description
ncols = len(description)
columns = []
converters = [None] * ncols
fields = [None] * ncols
for idx, description_item in enumerate(description):
column = orig_column = description_item[0]
# Try to clean-up messy column descriptions when people do not
# provide an alias. The idea is that we take something like:
# SUM("t1"."price") -> "price") -> price
dot_index = column.rfind('.')
if dot_index != -1:
column = column[dot_index + 1:]
column = column.strip('()"`')
columns.append(column)
# Now we'll see what they selected and see if we can improve the
# column-name being returned - e.g. by mapping it to the selected
# field's name.
try:
raw_node = select[idx]
except IndexError:
if column in combined:
raw_node = node = combined[column]
else:
continue
else:
node = raw_node.unwrap()
# If this column was given an alias, then we will use whatever
# alias was returned by the cursor.
is_alias = raw_node.is_alias()
if is_alias:
columns[idx] = orig_column
# Heuristics used to attempt to get the field associated with a
# given SELECT column, so that we can accurately convert the value
# returned by the database-cursor into a Python object.
if isinstance(node, Field):
if raw_node._coerce:
converters[idx] = node.python_value
fields[idx] = node
if not is_alias:
columns[idx] = node.name
elif isinstance(node, ColumnBase) and raw_node._converter:
converters[idx] = raw_node._converter
elif isinstance(node, Function) and node._coerce:
if node._python_value is not None:
converters[idx] = node._python_value
elif node.arguments and isinstance(node.arguments[0], Node):
# If the first argument is a field or references a column
# on a Model, try using that field's conversion function.
# This usually works, but we use "safe_python_value()" so
# that if a TypeError or ValueError occurs during
# conversion we can just fall-back to the raw cursor value.
first = node.arguments[0].unwrap()
if isinstance(first, Entity):
path = first._path[-1] # Try to look-up by name.
first = combined.get(path)
if isinstance(first, Field):
converters[idx] = safe_python_value(first.python_value)
elif column in combined:
if node._coerce:
converters[idx] = combined[column].python_value
if isinstance(node, Column) and node.source == table:
fields[idx] = combined[column]
no_convert = []
convert = []
for i in range(ncols):
if converters[i] is not None:
convert.append(i)
else:
no_convert.append(i)
return columns, fields, converters, no_convert, convert
class BaseModelCursorWrapper(DictCursorWrapper):
def __init__(self, cursor, model, columns):
super(BaseModelCursorWrapper, self).__init__(cursor)
@@ -8023,85 +8114,10 @@ class BaseModelCursorWrapper(DictCursorWrapper):
self.select = columns or []
def initialize(self):
combined = self.model._meta.combined
table = self.model._meta.table
description = self.cursor.description
self.ncols = len(self.cursor.description)
self.columns = []
self.converters = converters = [None] * self.ncols
self.fields = fields = [None] * self.ncols
for idx, description_item in enumerate(description):
column = orig_column = description_item[0]
# Try to clean-up messy column descriptions when people do not
# provide an alias. The idea is that we take something like:
# SUM("t1"."price") -> "price") -> price
dot_index = column.rfind('.')
if dot_index != -1:
column = column[dot_index + 1:]
column = column.strip('()"`')
self.columns.append(column)
# Now we'll see what they selected and see if we can improve the
# column-name being returned - e.g. by mapping it to the selected
# field's name.
try:
raw_node = self.select[idx]
except IndexError:
if column in combined:
raw_node = node = combined[column]
else:
continue
else:
node = raw_node.unwrap()
# If this column was given an alias, then we will use whatever
# alias was returned by the cursor.
is_alias = raw_node.is_alias()
if is_alias:
self.columns[idx] = orig_column
# Heuristics used to attempt to get the field associated with a
# given SELECT column, so that we can accurately convert the value
# returned by the database-cursor into a Python object.
if isinstance(node, Field):
if raw_node._coerce:
converters[idx] = node.python_value
fields[idx] = node
if not is_alias:
self.columns[idx] = node.name
elif isinstance(node, ColumnBase) and raw_node._converter:
converters[idx] = raw_node._converter
elif isinstance(node, Function) and node._coerce:
if node._python_value is not None:
converters[idx] = node._python_value
elif node.arguments and isinstance(node.arguments[0], Node):
# If the first argument is a field or references a column
# on a Model, try using that field's conversion function.
# This usually works, but we use "safe_python_value()" so
# that if a TypeError or ValueError occurs during
# conversion we can just fall-back to the raw cursor value.
first = node.arguments[0].unwrap()
if isinstance(first, Entity):
path = first._path[-1] # Try to look-up by name.
first = combined.get(path)
if isinstance(first, Field):
converters[idx] = safe_python_value(first.python_value)
elif column in combined:
if node._coerce:
converters[idx] = combined[column].python_value
if isinstance(node, Column) and node.source == table:
fields[idx] = combined[column]
self.no_convert = []
self.convert = []
for i in range(self.ncols):
if converters[i] is not None:
self.convert.append(i)
else:
self.no_convert.append(i)
(self.columns, self.fields, self.converters,
self.no_convert, self.convert) = _resolve_model_columns(
self.cursor, self.model, self.select)
self.ncols = len(self.columns)
def process_row(self, row):
raise NotImplementedError