mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-14 12:47:22 -04:00
- named_with_column becomes an attribute
- cleanup within compiler visit_select(), column labeling - is_select() removed from dialects, replaced with returns_rows_text(), returns_rows_compiled() - should_autocommit() removed from dialects, replaced with should_autocommit_text() and should_autocommit_compiled() - typemap and column_labels collections removed from Compiler, replaced with single "result_map" collection. - ResultProxy uses more succinct logic in combination with result_map to target columns
This commit is contained in:
@@ -356,11 +356,11 @@ class AccessCompiler(compiler.DefaultCompiler):
|
||||
"""Access uses "mod" instead of "%" """
|
||||
return binary.operator == '%' and 'mod' or binary.operator
|
||||
|
||||
def label_select_column(self, select, column):
|
||||
def label_select_column(self, select, column, asfrom):
|
||||
if isinstance(column, expression._Function):
|
||||
return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:])
|
||||
return column.label()
|
||||
else:
|
||||
return super(AccessCompiler, self).label_select_column(select, column)
|
||||
return super(AccessCompiler, self).label_select_column(select, column, asfrom)
|
||||
|
||||
function_rewrites = {'current_date': 'now',
|
||||
'current_timestamp': 'now',
|
||||
|
||||
@@ -409,15 +409,6 @@ class InfoCompiler(compiler.DefaultCompiler):
|
||||
def limit_clause(self, select):
|
||||
return ""
|
||||
|
||||
def __visit_label(self, label):
|
||||
# TODO: whats this method for ?
|
||||
if self.select_stack:
|
||||
self.typemap.setdefault(label.name.lower(), label.obj.type)
|
||||
if self.strings[label.obj]:
|
||||
self.strings[label] = self.strings[label.obj] + " AS " + label.name
|
||||
else:
|
||||
self.strings[label] = None
|
||||
|
||||
def visit_function( self , func ):
|
||||
if func.name.lower() == 'current_date':
|
||||
return "today"
|
||||
|
||||
@@ -339,8 +339,8 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
|
||||
_ms_is_select = re.compile(r'\s*(?:SELECT|sp_columns)',
|
||||
re.I | re.UNICODE)
|
||||
|
||||
def is_select(self):
|
||||
return self._ms_is_select.match(self.statement) is not None
|
||||
def returns_rows_text(self, statement):
|
||||
return self._ms_is_select.match(statement) is not None
|
||||
|
||||
|
||||
class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext):
|
||||
@@ -910,11 +910,11 @@ class MSSQLCompiler(compiler.DefaultCompiler):
|
||||
else:
|
||||
return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
|
||||
|
||||
def label_select_column(self, select, column):
|
||||
def label_select_column(self, select, column, asfrom):
|
||||
if isinstance(column, expression._Function):
|
||||
return column.label(None)
|
||||
else:
|
||||
return super(MSSQLCompiler, self).label_select_column(select, column)
|
||||
return super(MSSQLCompiler, self).label_select_column(select, column, asfrom)
|
||||
|
||||
function_rewrites = {'current_date': 'getdate',
|
||||
'length': 'len',
|
||||
|
||||
@@ -1378,9 +1378,6 @@ def descriptor():
|
||||
|
||||
|
||||
class MySQLExecutionContext(default.DefaultExecutionContext):
|
||||
_my_is_select = re.compile(r'\s*(?:SELECT|SHOW|DESCRIBE|XA +RECOVER)',
|
||||
re.I | re.UNICODE)
|
||||
|
||||
def post_exec(self):
|
||||
if self.compiled.isinsert and not self.executemany:
|
||||
if (not len(self._last_inserted_ids) or
|
||||
@@ -1388,11 +1385,11 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
|
||||
self._last_inserted_ids = ([self.cursor.lastrowid] +
|
||||
self._last_inserted_ids[1:])
|
||||
|
||||
def is_select(self):
|
||||
return SELECT_RE.match(self.statement)
|
||||
def returns_rows_text(self, statement):
|
||||
return SELECT_RE.match(statement)
|
||||
|
||||
def should_autocommit(self):
|
||||
return AUTOCOMMIT_RE.match(self.statement)
|
||||
def should_autocommit_text(self, statement):
|
||||
return AUTOCOMMIT_RE.match(statement)
|
||||
|
||||
|
||||
class MySQLDialect(default.DefaultDialect):
|
||||
@@ -1873,9 +1870,6 @@ class MySQLCompiler(compiler.DefaultCompiler):
|
||||
if type_ is None:
|
||||
return self.process(cast.clause)
|
||||
|
||||
if self.stack and self.stack[-1].get('select'):
|
||||
# not sure if we want to set the typemap here...
|
||||
self.typemap.setdefault("CAST", cast.type)
|
||||
return 'CAST(%s AS %s)' % (self.process(cast.clause), type_)
|
||||
|
||||
|
||||
|
||||
@@ -233,16 +233,24 @@ RETURNING_QUOTED_RE = re.compile(
|
||||
|
||||
class PGExecutionContext(default.DefaultExecutionContext):
|
||||
|
||||
def is_select(self):
|
||||
m = SELECT_RE.match(self.statement)
|
||||
return m and (not m.group(1) or (RETURNING_RE.search(self.statement)
|
||||
and RETURNING_QUOTED_RE.match(self.statement)))
|
||||
def returns_rows_text(self, statement):
|
||||
m = SELECT_RE.match(statement)
|
||||
return m and (not m.group(1) or (RETURNING_RE.search(statement)
|
||||
and RETURNING_QUOTED_RE.match(statement)))
|
||||
|
||||
def returns_rows_compiled(self, compiled):
|
||||
return isinstance(compiled.statement, expression.Selectable) or \
|
||||
(
|
||||
(compiled.isupdate or compiled.isinsert) and "postgres_returning" in compiled.statement.kwargs
|
||||
)
|
||||
|
||||
def create_cursor(self):
|
||||
# executing a default or Sequence standalone creates an execution context without a statement.
|
||||
# so slightly hacky "if no statement assume we're server side" logic
|
||||
# TODO: dont use regexp if Compiled is used ?
|
||||
self.__is_server_side = \
|
||||
self.dialect.server_side_cursors and (self.statement is None or \
|
||||
self.dialect.server_side_cursors and \
|
||||
(self.statement is None or \
|
||||
(SELECT_RE.match(self.statement) and not re.search(r'FOR UPDATE(?: NOWAIT)?\s*$', self.statement, re.I))
|
||||
)
|
||||
|
||||
|
||||
@@ -185,8 +185,8 @@ class SQLiteExecutionContext(default.DefaultExecutionContext):
|
||||
if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
|
||||
self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
|
||||
|
||||
def is_select(self):
|
||||
return SELECT_REGEXP.match(self.statement)
|
||||
def returns_rows_text(self, statement):
|
||||
return SELECT_REGEXP.match(statement)
|
||||
|
||||
class SQLiteDialect(default.DefaultDialect):
|
||||
supports_alter = False
|
||||
@@ -343,9 +343,6 @@ class SQLiteCompiler(compiler.DefaultCompiler):
|
||||
if self.dialect.supports_cast:
|
||||
return super(SQLiteCompiler, self).visit_cast(cast)
|
||||
else:
|
||||
if self.stack and self.stack[-1].get('select'):
|
||||
# not sure if we want to set the typemap here...
|
||||
self.typemap.setdefault("CAST", cast.type)
|
||||
return self.process(cast.clause)
|
||||
|
||||
def limit_clause(self, select):
|
||||
|
||||
@@ -778,11 +778,11 @@ class SybaseSQLCompiler(compiler.DefaultCompiler):
|
||||
else:
|
||||
return super(SybaseSQLCompiler, self).visit_binary(binary)
|
||||
|
||||
def label_select_column(self, select, column):
|
||||
def label_select_column(self, select, column, asfrom):
|
||||
if isinstance(column, expression._Function):
|
||||
return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:])
|
||||
return column.label(None)
|
||||
else:
|
||||
return super(SybaseSQLCompiler, self).label_select_column(select, column)
|
||||
return super(SybaseSQLCompiler, self).label_select_column(select, column, asfrom)
|
||||
|
||||
function_rewrites = {'current_date': 'getdate',
|
||||
}
|
||||
@@ -795,13 +795,7 @@ class SybaseSQLCompiler(compiler.DefaultCompiler):
|
||||
cast = expression._Cast(func, SybaseDate_mxodbc)
|
||||
# infinite recursion
|
||||
# res = self.visit_cast(cast)
|
||||
if self.stack and self.stack[-1].get('select'):
|
||||
# not sure if we want to set the typemap here...
|
||||
self.typemap.setdefault("CAST", cast.type)
|
||||
# res = "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))
|
||||
res = "CAST(%s AS %s)" % (res, self.process(cast.typeclause))
|
||||
# elif func.name.lower() == 'count':
|
||||
# res = 'count(*)'
|
||||
return res
|
||||
|
||||
def for_update_clause(self, select):
|
||||
|
||||
@@ -315,6 +315,12 @@ class ExecutionContext(object):
|
||||
isupdate
|
||||
True if the statement is an UPDATE.
|
||||
|
||||
should_autocommit
|
||||
True if the statement is a "committable" statement
|
||||
|
||||
returns_rows
|
||||
True if the statement should return result rows
|
||||
|
||||
The Dialect should provide an ExecutionContext via the
|
||||
create_execution_context() method. The `pre_exec` and `post_exec`
|
||||
methods will be called for compiled statements.
|
||||
@@ -363,8 +369,13 @@ class ExecutionContext(object):
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
def should_autocommit(self):
|
||||
"""Return True if this context's statement should be 'committed' automatically in a non-transactional context"""
|
||||
def should_autocommit_compiled(self, compiled):
|
||||
"""return True if the given Compiled object refers to a "committable" statement."""
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
def should_autocommit_text(self, statement):
|
||||
"""Parse the given textual statement and return True if it refers to a "committable" statement"""
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -750,7 +761,7 @@ class Connection(Connectable):
|
||||
|
||||
# TODO: have the dialect determine if autocommit can be set on
|
||||
# the connection directly without this extra step
|
||||
if not self.in_transaction() and context.should_autocommit():
|
||||
if not self.in_transaction() and context.should_autocommit:
|
||||
self._commit_impl()
|
||||
|
||||
def _autorollback(self):
|
||||
@@ -1305,7 +1316,7 @@ class ResultProxy(object):
|
||||
self.cursor = context.cursor
|
||||
self.connection = context.root_connection
|
||||
self.__echo = context.engine._should_log_info
|
||||
if context.is_select():
|
||||
if context.returns_rows:
|
||||
self._init_metadata()
|
||||
self._rowcount = None
|
||||
else:
|
||||
@@ -1322,8 +1333,6 @@ class ResultProxy(object):
|
||||
out_parameters = property(lambda s:s.context.out_parameters)
|
||||
|
||||
def _init_metadata(self):
|
||||
if hasattr(self, '_ResultProxy__props'):
|
||||
return
|
||||
self.__props = {}
|
||||
self._key_cache = self._create_key_cache()
|
||||
self.__keys = []
|
||||
@@ -1336,20 +1345,24 @@ class ResultProxy(object):
|
||||
# sqlite possibly prepending table name to colnames so strip
|
||||
colname = (item[0].split('.')[-1]).decode(self.dialect.encoding)
|
||||
|
||||
if self.context.typemap is not None:
|
||||
type = self.context.typemap.get(colname.lower(), typemap.get(item[1], types.NULLTYPE))
|
||||
if self.context.result_map:
|
||||
try:
|
||||
(name, obj, type_) = self.context.result_map[colname]
|
||||
except KeyError:
|
||||
(name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE))
|
||||
else:
|
||||
type = typemap.get(item[1], types.NULLTYPE)
|
||||
(name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE))
|
||||
|
||||
rec = (type, type.dialect_impl(self.dialect).result_processor(self.dialect), i)
|
||||
rec = (type_, type_.dialect_impl(self.dialect).result_processor(self.dialect), i)
|
||||
|
||||
if rec[0] is None:
|
||||
raise exceptions.InvalidRequestError(
|
||||
"None for metadata " + colname)
|
||||
if self.__props.setdefault(colname.lower(), rec) is not rec:
|
||||
self.__props[colname.lower()] = (type, self.__ambiguous_processor(colname), 0)
|
||||
if self.__props.setdefault(name.lower(), rec) is not rec:
|
||||
self.__props[name.lower()] = (type_, self.__ambiguous_processor(colname), 0)
|
||||
|
||||
self.__keys.append(colname)
|
||||
self.__props[i] = rec
|
||||
if obj:
|
||||
for o in obj:
|
||||
self.__props[o] = rec
|
||||
|
||||
if self.__echo:
|
||||
self.context.engine.logger.debug("Col " + repr(tuple([x[0] for x in metadata])))
|
||||
@@ -1362,16 +1375,19 @@ class ResultProxy(object):
|
||||
"""Given a key, which could be a ColumnElement, string, etc.,
|
||||
matches it to the appropriate key we got from the result set's
|
||||
metadata; then cache it locally for quick re-access."""
|
||||
|
||||
if isinstance(key, int) and key in props:
|
||||
|
||||
if isinstance(key, basestring):
|
||||
key = key.lower()
|
||||
|
||||
try:
|
||||
rec = props[key]
|
||||
elif isinstance(key, basestring) and key.lower() in props:
|
||||
rec = props[key.lower()]
|
||||
elif isinstance(key, expression.ColumnElement):
|
||||
label = context.column_labels.get(key._label, key.name).lower()
|
||||
if label in props:
|
||||
rec = props[label]
|
||||
if not "rec" in locals():
|
||||
except KeyError:
|
||||
# fallback for targeting a ColumnElement to a textual expression
|
||||
if isinstance(key, expression.ColumnElement):
|
||||
if key._label.lower() in props:
|
||||
return props[key._label.lower()]
|
||||
elif key.name.lower() in props:
|
||||
return props[key.name.lower()]
|
||||
raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % (str(key)))
|
||||
|
||||
return rec
|
||||
@@ -1470,18 +1486,20 @@ class ResultProxy(object):
|
||||
|
||||
def _get_col(self, row, key):
|
||||
try:
|
||||
rec = self._key_cache[key]
|
||||
type_, processor, index = self._key_cache[key]
|
||||
except TypeError:
|
||||
# the 'slice' use case is very infrequent,
|
||||
# so we use an exception catch to reduce conditionals in _get_col
|
||||
if isinstance(key, slice):
|
||||
indices = key.indices(len(row))
|
||||
return tuple([self._get_col(row, i) for i in xrange(*indices)])
|
||||
|
||||
if rec[1]:
|
||||
return rec[1](row[rec[2]])
|
||||
else:
|
||||
raise
|
||||
|
||||
if processor:
|
||||
return processor(row[index])
|
||||
else:
|
||||
return row[rec[2]]
|
||||
return row[index]
|
||||
|
||||
def _fetchone_impl(self):
|
||||
return self.cursor.fetchone()
|
||||
|
||||
@@ -146,9 +146,8 @@ class DefaultExecutionContext(base.ExecutionContext):
|
||||
if value is not None
|
||||
])
|
||||
|
||||
self.typemap = compiled.typemap
|
||||
self.column_labels = compiled.column_labels
|
||||
|
||||
self.result_map = compiled.result_map
|
||||
|
||||
if not dialect.supports_unicode_statements:
|
||||
self.statement = unicode(compiled).encode(self.dialect.encoding)
|
||||
else:
|
||||
@@ -156,6 +155,12 @@ class DefaultExecutionContext(base.ExecutionContext):
|
||||
|
||||
self.isinsert = compiled.isinsert
|
||||
self.isupdate = compiled.isupdate
|
||||
if isinstance(compiled.statement, expression._TextClause):
|
||||
self.returns_rows = self.returns_rows_text(self.statement)
|
||||
self.should_autocommit = self.should_autocommit_text(self.statement)
|
||||
else:
|
||||
self.returns_rows = self.returns_rows_compiled(compiled)
|
||||
self.should_autocommit = self.should_autocommit_compiled(compiled)
|
||||
|
||||
if not parameters:
|
||||
self.compiled_parameters = [compiled.construct_params()]
|
||||
@@ -170,7 +175,7 @@ class DefaultExecutionContext(base.ExecutionContext):
|
||||
|
||||
elif statement is not None:
|
||||
# plain text statement.
|
||||
self.typemap = self.column_labels = None
|
||||
self.result_map = None
|
||||
self.parameters = self.__encode_param_keys(parameters)
|
||||
self.executemany = len(parameters) > 1
|
||||
if not dialect.supports_unicode_statements:
|
||||
@@ -179,10 +184,12 @@ class DefaultExecutionContext(base.ExecutionContext):
|
||||
self.statement = statement
|
||||
self.isinsert = self.isupdate = False
|
||||
self.cursor = self.create_cursor()
|
||||
self.returns_rows = self.returns_rows_text(statement)
|
||||
self.should_autocommit = self.should_autocommit_text(statement)
|
||||
else:
|
||||
# no statement. used for standalone ColumnDefault execution.
|
||||
self.statement = None
|
||||
self.isinsert = self.isupdate = self.executemany = False
|
||||
self.isinsert = self.isupdate = self.executemany = self.returns_rows = self.should_autocommit = False
|
||||
self.cursor = self.create_cursor()
|
||||
|
||||
connection = property(lambda s:s._connection._branch())
|
||||
@@ -244,10 +251,18 @@ class DefaultExecutionContext(base.ExecutionContext):
|
||||
parameters.append(param)
|
||||
return parameters
|
||||
|
||||
def is_select(self):
|
||||
"""return TRUE if the statement is expected to have result rows."""
|
||||
def returns_rows_compiled(self, compiled):
|
||||
return isinstance(compiled.statement, expression.Selectable)
|
||||
|
||||
return SELECT_REGEXP.match(self.statement)
|
||||
def returns_rows_text(self, statement):
|
||||
return SELECT_REGEXP.match(statement)
|
||||
|
||||
def should_autocommit_compiled(self, compiled):
|
||||
return isinstance(compiled.statement, expression._UpdateBase)
|
||||
|
||||
def should_autocommit_text(self, statement):
|
||||
return AUTOCOMMIT_REGEXP.match(statement)
|
||||
|
||||
|
||||
def create_cursor(self):
|
||||
return self._connection.connection.cursor()
|
||||
@@ -261,9 +276,6 @@ class DefaultExecutionContext(base.ExecutionContext):
|
||||
def result(self):
|
||||
return self.get_result_proxy()
|
||||
|
||||
def should_autocommit(self):
|
||||
return AUTOCOMMIT_REGEXP.match(self.statement)
|
||||
|
||||
def pre_exec(self):
|
||||
pass
|
||||
|
||||
|
||||
@@ -249,7 +249,7 @@ class Query(object):
|
||||
# alias non-labeled column elements.
|
||||
if isinstance(column, sql.ColumnElement) and not hasattr(column, '_label'):
|
||||
column = column.label(None)
|
||||
|
||||
|
||||
q._entities = q._entities + [(column, None, id)]
|
||||
return q
|
||||
|
||||
@@ -887,7 +887,7 @@ class Query(object):
|
||||
context.exec_with_path(self.select_mapper, value.key, value.setup, context, parentclauses=clauses)
|
||||
elif isinstance(m, sql.ColumnElement):
|
||||
if clauses is not None:
|
||||
m = clauses.adapt_clause(m)
|
||||
m = clauses.aliased_column(m)
|
||||
context.secondary_columns.append(m)
|
||||
|
||||
if self._eager_loaders and self._nestable(**self._select_args()):
|
||||
|
||||
@@ -456,7 +456,7 @@ class Column(SchemaItem, expression._ColumnClause):
|
||||
|
||||
def __str__(self):
|
||||
if self.table is not None:
|
||||
if self.table.named_with_column():
|
||||
if self.table.named_with_column:
|
||||
return (self.table.description + "." + self.description)
|
||||
else:
|
||||
return self.description
|
||||
|
||||
@@ -130,13 +130,11 @@ class DefaultCompiler(engine.Compiled):
|
||||
# a stack. what recursive compiler doesn't have a stack ? :)
|
||||
self.stack = []
|
||||
|
||||
# a dictionary of result-set column names (strings) to TypeEngine instances,
|
||||
# which will be passed to a ResultProxy and used for resultset-level value conversion
|
||||
self.typemap = {}
|
||||
|
||||
# a dictionary of select columns labels mapped to their "generated" label
|
||||
self.column_labels = {}
|
||||
|
||||
# relates label names in the final SQL to
|
||||
# a tuple of local column/label name, ColumnElement object (if any) and TypeEngine.
|
||||
# ResultProxy uses this for type processing and column targeting
|
||||
self.result_map = {}
|
||||
|
||||
# a dictionary of ClauseElement subclasses to counters, which are used to
|
||||
# generate truncated identifier names or "anonymous" identifiers such as
|
||||
# for aliases
|
||||
@@ -213,19 +211,15 @@ class DefaultCompiler(engine.Compiled):
|
||||
def visit_grouping(self, grouping, **kwargs):
|
||||
return "(" + self.process(grouping.elem) + ")"
|
||||
|
||||
def visit_label(self, label, typemap=None, column_labels=None):
|
||||
def visit_label(self, label, result_map=None):
|
||||
labelname = self._truncated_identifier("colident", label.name)
|
||||
|
||||
if typemap is not None:
|
||||
self.typemap.setdefault(labelname.lower(), label.obj.type)
|
||||
if result_map is not None:
|
||||
result_map[labelname] = (label.name, (label, label.obj), label.obj.type)
|
||||
|
||||
if column_labels is not None:
|
||||
if isinstance(label.obj, sql._ColumnClause):
|
||||
column_labels[label.obj._label] = labelname
|
||||
column_labels[label.name] = labelname
|
||||
return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)])
|
||||
|
||||
def visit_column(self, column, typemap=None, column_labels=None, **kwargs):
|
||||
def visit_column(self, column, result_map=None, **kwargs):
|
||||
# there is actually somewhat of a ruleset when you would *not* necessarily
|
||||
# want to truncate a column identifier, if its mapped to the name of a
|
||||
# physical column. but thats very hard to identify at this point, and
|
||||
@@ -236,15 +230,13 @@ class DefaultCompiler(engine.Compiled):
|
||||
else:
|
||||
name = column.name
|
||||
|
||||
if typemap is not None:
|
||||
typemap.setdefault(name.lower(), column.type)
|
||||
if column_labels is not None:
|
||||
self.column_labels.setdefault(column._label, name.lower())
|
||||
if result_map is not None:
|
||||
result_map[name] = (name, (column, ), column.type)
|
||||
|
||||
if column._is_oid:
|
||||
n = self.dialect.oid_column_name(column)
|
||||
if n is not None:
|
||||
if column.table is None or not column.table.named_with_column():
|
||||
if column.table is None or not column.table.named_with_column:
|
||||
return n
|
||||
else:
|
||||
return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + n
|
||||
@@ -254,7 +246,7 @@ class DefaultCompiler(engine.Compiled):
|
||||
return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + self.preparer.quote(pk, pkname)
|
||||
else:
|
||||
return None
|
||||
elif column.table is None or not column.table.named_with_column():
|
||||
elif column.table is None or not column.table.named_with_column:
|
||||
if getattr(column, "is_literal", False):
|
||||
return name
|
||||
else:
|
||||
@@ -277,8 +269,9 @@ class DefaultCompiler(engine.Compiled):
|
||||
|
||||
def visit_textclause(self, textclause, **kwargs):
|
||||
if textclause.typemap is not None:
|
||||
self.typemap.update(textclause.typemap)
|
||||
|
||||
for colname, type_ in textclause.typemap.iteritems():
|
||||
self.result_map[colname] = (colname, None, type_)
|
||||
|
||||
def do_bindparam(m):
|
||||
name = m.group(1)
|
||||
if name in textclause.bindparams:
|
||||
@@ -302,7 +295,7 @@ class DefaultCompiler(engine.Compiled):
|
||||
sep = ', '
|
||||
else:
|
||||
sep = " " + self.operator_string(clauselist.operator) + " "
|
||||
return string.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None], sep)
|
||||
return sep.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None])
|
||||
|
||||
def apply_function_parens(self, func):
|
||||
return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0
|
||||
@@ -310,12 +303,13 @@ class DefaultCompiler(engine.Compiled):
|
||||
def visit_calculatedclause(self, clause, **kwargs):
|
||||
return self.process(clause.clause_expr)
|
||||
|
||||
def visit_cast(self, cast, typemap=None, **kwargs):
|
||||
def visit_cast(self, cast, **kwargs):
|
||||
return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))
|
||||
|
||||
def visit_function(self, func, typemap=None, **kwargs):
|
||||
if typemap is not None:
|
||||
typemap.setdefault(func.name, func.type)
|
||||
def visit_function(self, func, result_map=None, **kwargs):
|
||||
if result_map is not None:
|
||||
result_map[func.name] = (func.name, None, func.type)
|
||||
|
||||
if not self.apply_function_parens(func):
|
||||
return ".".join(func.packagenames + [func.name])
|
||||
else:
|
||||
@@ -325,7 +319,7 @@ class DefaultCompiler(engine.Compiled):
|
||||
stack_entry = {'select':cs}
|
||||
|
||||
if asfrom:
|
||||
stack_entry['is_selected_from'] = stack_entry['is_subquery'] = True
|
||||
stack_entry['is_subquery'] = True
|
||||
elif self.stack and self.stack[-1].get('select'):
|
||||
stack_entry['is_subquery'] = True
|
||||
self.stack.append(stack_entry)
|
||||
@@ -353,7 +347,7 @@ class DefaultCompiler(engine.Compiled):
|
||||
s = s + " " + self.operator_string(unary.modifier)
|
||||
return s
|
||||
|
||||
def visit_binary(self, binary, typemap=None, **kwargs):
|
||||
def visit_binary(self, binary, **kwargs):
|
||||
op = self.operator_string(binary.operator)
|
||||
if callable(op):
|
||||
return op(self.process(binary.left), self.process(binary.right))
|
||||
@@ -438,22 +432,17 @@ class DefaultCompiler(engine.Compiled):
|
||||
else:
|
||||
return self.process(alias.original, **kwargs)
|
||||
|
||||
def label_select_column(self, select, column):
|
||||
"""convert a column from a select's "columns" clause.
|
||||
def label_select_column(self, select, column, asfrom):
|
||||
"""label columns present in a select()."""
|
||||
|
||||
given a select() and a column element from its inner_columns collection, return a
|
||||
Label object if this column should be labeled in the columns clause. Otherwise,
|
||||
return None and the column will be used as-is.
|
||||
|
||||
The calling method will traverse the returned label to acquire its string
|
||||
representation.
|
||||
"""
|
||||
|
||||
# SQLite doesnt like selecting from a subquery where the column
|
||||
# names look like table.colname. so if column is in a "selected from"
|
||||
# subquery, label it synoymously with its column name
|
||||
if isinstance(column, sql._Label):
|
||||
return column
|
||||
|
||||
if select.use_labels and column._label:
|
||||
return column.label(column._label)
|
||||
|
||||
if \
|
||||
(self.stack and self.stack[-1].get('is_selected_from')) and \
|
||||
asfrom and \
|
||||
isinstance(column, sql._ColumnClause) and \
|
||||
not column.is_literal and \
|
||||
column.table is not None and \
|
||||
@@ -462,20 +451,20 @@ class DefaultCompiler(engine.Compiled):
|
||||
elif not isinstance(column, (sql._UnaryExpression, sql._TextClause)) and not hasattr(column, 'name'):
|
||||
return column.label(None)
|
||||
else:
|
||||
return None
|
||||
return column
|
||||
|
||||
def visit_select(self, select, asfrom=False, parens=True, **kwargs):
|
||||
|
||||
stack_entry = {'select':select}
|
||||
|
||||
if asfrom:
|
||||
stack_entry['is_selected_from'] = stack_entry['is_subquery'] = True
|
||||
stack_entry['is_subquery'] = True
|
||||
column_clause_args = {}
|
||||
elif self.stack and 'select' in self.stack[-1]:
|
||||
stack_entry['is_subquery'] = True
|
||||
column_clause_args = {}
|
||||
else:
|
||||
column_clause_args = {'typemap':self.typemap, 'column_labels':self.column_labels}
|
||||
column_clause_args = {'result_map':self.result_map}
|
||||
|
||||
if self.stack and 'from' in self.stack[-1]:
|
||||
existingfroms = self.stack[-1]['from']
|
||||
@@ -487,8 +476,7 @@ class DefaultCompiler(engine.Compiled):
|
||||
correlate_froms = util.Set()
|
||||
for f in froms:
|
||||
correlate_froms.add(f)
|
||||
for f2 in f._get_from_objects():
|
||||
correlate_froms.add(f2)
|
||||
correlate_froms.update(f._get_from_objects())
|
||||
|
||||
# TODO: might want to propigate existing froms for select(select(select))
|
||||
# where innermost select should correlate to outermost
|
||||
@@ -501,19 +489,8 @@ class DefaultCompiler(engine.Compiled):
|
||||
inner_columns = util.OrderedSet()
|
||||
|
||||
for co in select.inner_columns:
|
||||
if select.use_labels:
|
||||
labelname = co._label
|
||||
if labelname is not None:
|
||||
l = co.label(labelname)
|
||||
inner_columns.add(self.process(l, **column_clause_args))
|
||||
else:
|
||||
inner_columns.add(self.process(co, **column_clause_args))
|
||||
else:
|
||||
l = self.label_select_column(select, co)
|
||||
if l is not None:
|
||||
inner_columns.add(self.process(l, **column_clause_args))
|
||||
else:
|
||||
inner_columns.add(self.process(co, **column_clause_args))
|
||||
l = self.label_select_column(select, co, asfrom=asfrom)
|
||||
inner_columns.add(self.process(l, **column_clause_args))
|
||||
|
||||
collist = string.join(inner_columns.difference(util.Set([None])), ', ')
|
||||
|
||||
|
||||
@@ -1522,6 +1522,7 @@ class FromClause(Selectable):
|
||||
"""Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement."""
|
||||
|
||||
__visit_name__ = 'fromclause'
|
||||
named_with_column=False
|
||||
|
||||
def __init__(self):
|
||||
self.oid_column = None
|
||||
@@ -1562,13 +1563,6 @@ class FromClause(Selectable):
|
||||
|
||||
return Alias(self, name)
|
||||
|
||||
def named_with_column(self):
|
||||
"""True if the name of this FromClause may be prepended to a
|
||||
column in a generated SQL statement.
|
||||
"""
|
||||
|
||||
return False
|
||||
|
||||
def is_derived_from(self, fromclause):
|
||||
"""Return True if this FromClause is 'derived' from the given FromClause.
|
||||
|
||||
@@ -2379,6 +2373,8 @@ class Alias(FromClause):
|
||||
``FromClause`` subclasses.
|
||||
"""
|
||||
|
||||
named_with_column = True
|
||||
|
||||
def __init__(self, selectable, alias=None):
|
||||
baseselectable = selectable
|
||||
while isinstance(baseselectable, Alias):
|
||||
@@ -2386,7 +2382,7 @@ class Alias(FromClause):
|
||||
self.original = baseselectable
|
||||
self.selectable = selectable
|
||||
if alias is None:
|
||||
if self.original.named_with_column():
|
||||
if self.original.named_with_column:
|
||||
alias = getattr(self.original, 'name', None)
|
||||
alias = '{ANON %d %s}' % (id(self), alias or 'anon')
|
||||
self.name = alias
|
||||
@@ -2408,9 +2404,6 @@ class Alias(FromClause):
|
||||
def _table_iterator(self):
|
||||
return self.original._table_iterator()
|
||||
|
||||
def named_with_column(self):
|
||||
return True
|
||||
|
||||
def _exportable_columns(self):
|
||||
#return self.selectable._exportable_columns()
|
||||
return self.selectable.columns
|
||||
@@ -2602,7 +2595,7 @@ class _ColumnClause(ColumnElement):
|
||||
if self.is_literal:
|
||||
return None
|
||||
if self.__label is None:
|
||||
if self.table is not None and self.table.named_with_column():
|
||||
if self.table is not None and self.table.named_with_column:
|
||||
self.__label = self.table.name + "_" + self.name
|
||||
counter = 1
|
||||
while self.__label in self.table.c:
|
||||
@@ -2652,6 +2645,8 @@ class TableClause(FromClause):
|
||||
functionality.
|
||||
"""
|
||||
|
||||
named_with_column = True
|
||||
|
||||
def __init__(self, name, *columns):
|
||||
super(TableClause, self).__init__()
|
||||
self.name = self.fullname = name
|
||||
@@ -2666,9 +2661,6 @@ class TableClause(FromClause):
|
||||
# TableClause is immutable
|
||||
return self
|
||||
|
||||
def named_with_column(self):
|
||||
return True
|
||||
|
||||
def append_column(self, c):
|
||||
self._columns[c.name] = c
|
||||
c.table = self
|
||||
@@ -3041,16 +3033,14 @@ class Select(_SelectBaseMixin, FromClause):
|
||||
froms = froms.difference(hide_froms)
|
||||
|
||||
if len(froms) > 1:
|
||||
corr = self.__correlate
|
||||
if self.__correlate:
|
||||
froms = froms.difference(self.__correlate)
|
||||
if self._should_correlate and existing_froms is not None:
|
||||
corr.update(existing_froms)
|
||||
froms = froms.difference(existing_froms)
|
||||
|
||||
f = froms.difference(corr)
|
||||
if not f:
|
||||
if not froms:
|
||||
raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate()))
|
||||
return f
|
||||
else:
|
||||
return froms
|
||||
return froms
|
||||
|
||||
froms = property(_get_display_froms, doc="""Return a list of all FromClause elements which will be applied to the FROM clause of the resulting statement.""")
|
||||
|
||||
|
||||
@@ -101,6 +101,9 @@ class ReturningTest(AssertMixin):
|
||||
|
||||
result3 = table.insert(postgres_returning=[(table.c.id*2).label('double_id')]).execute({'persons': 4, 'full': False})
|
||||
self.assertEqual([dict(row) for row in result3], [{'double_id':8}])
|
||||
|
||||
result4 = testbase.db.execute('insert into tables (id, persons, "full") values (5, 10, true) returning persons')
|
||||
self.assertEqual([dict(row) for row in result4], [{'persons': 10}])
|
||||
finally:
|
||||
table.drop()
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ class CompileTest(AssertMixin):
|
||||
t1.update().compile()
|
||||
|
||||
# TODO: this is alittle high
|
||||
@profiling.profiled('ctest_select', call_range=(130, 150), always=True)
|
||||
@profiling.profiled('ctest_select', call_range=(110, 130), always=True)
|
||||
def test_select(self):
|
||||
s = select([t1], t1.c.c2==t2.c.c1)
|
||||
s.compile()
|
||||
|
||||
@@ -50,7 +50,7 @@ class ZooMarkTest(testing.AssertMixin):
|
||||
metadata.create_all()
|
||||
|
||||
@testing.supported('postgres')
|
||||
@profiling.profiled('populate', call_range=(2800, 3700), always=True)
|
||||
@profiling.profiled('populate', call_range=(2700, 3700), always=True)
|
||||
def test_1a_populate(self):
|
||||
Zoo = metadata.tables['Zoo']
|
||||
Animal = metadata.tables['Animal']
|
||||
@@ -126,7 +126,7 @@ class ZooMarkTest(testing.AssertMixin):
|
||||
tick = i.execute(Species='Tick', Name='Tick %d' % x, Legs=8)
|
||||
|
||||
@testing.supported('postgres')
|
||||
@profiling.profiled('properties', call_range=(2900, 3330), always=True)
|
||||
@profiling.profiled('properties', call_range=(2300, 3030), always=True)
|
||||
def test_3_properties(self):
|
||||
Zoo = metadata.tables['Zoo']
|
||||
Animal = metadata.tables['Animal']
|
||||
@@ -149,7 +149,7 @@ class ZooMarkTest(testing.AssertMixin):
|
||||
ticks = fullobject(Animal.select(Animal.c.Species=='Tick'))
|
||||
|
||||
@testing.supported('postgres')
|
||||
@profiling.profiled('expressions', call_range=(10350, 12200), always=True)
|
||||
@profiling.profiled('expressions', call_range=(9200, 12050), always=True)
|
||||
def test_4_expressions(self):
|
||||
Zoo = metadata.tables['Zoo']
|
||||
Animal = metadata.tables['Animal']
|
||||
@@ -203,7 +203,7 @@ class ZooMarkTest(testing.AssertMixin):
|
||||
assert len(fulltable(Animal.select(func.date_part('day', Animal.c.LastEscape) == 21))) == 1
|
||||
|
||||
@testing.supported('postgres')
|
||||
@profiling.profiled('aggregates', call_range=(960, 1170), always=True)
|
||||
@profiling.profiled('aggregates', call_range=(800, 1170), always=True)
|
||||
def test_5_aggregates(self):
|
||||
Animal = metadata.tables['Animal']
|
||||
Zoo = metadata.tables['Zoo']
|
||||
@@ -245,7 +245,7 @@ class ZooMarkTest(testing.AssertMixin):
|
||||
legs.sort()
|
||||
|
||||
@testing.supported('postgres')
|
||||
@profiling.profiled('editing', call_range=(1150, 1280), always=True)
|
||||
@profiling.profiled('editing', call_range=(1050, 1180), always=True)
|
||||
def test_6_editing(self):
|
||||
Zoo = metadata.tables['Zoo']
|
||||
|
||||
@@ -274,7 +274,7 @@ class ZooMarkTest(testing.AssertMixin):
|
||||
assert SDZ['Founded'] == datetime.date(1935, 9, 13)
|
||||
|
||||
@testing.supported('postgres')
|
||||
@profiling.profiled('multiview', call_range=(2300, 2500), always=True)
|
||||
@profiling.profiled('multiview', call_range=(1900, 2300), always=True)
|
||||
def test_7_multiview(self):
|
||||
Zoo = metadata.tables['Zoo']
|
||||
Animal = metadata.tables['Animal']
|
||||
|
||||
Reference in New Issue
Block a user