mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-30 12:34:52 -04:00
f4c4f784cd
Fixed bug in compiler where the string identifier of a savepoint would be cached in the identifier quoting dictionary; as these identifiers are arbitrary, a small memory leak could occur if a single :class:`.Connection` had an unbounded number of savepoints used, as well as if the savepoint clause constructs were used directly with an unbounded number of savepoint names. The memory leak does **not** impact the vast majority of cases as normally the :class:`.Connection`, which renders savepoint names with a simple counter starting at "1", is used on a per-transaction or per-fixed-number-of-transactions basis before being discarded. The savepoint name in virtually all cases does not require quoting at all, however to support potential third party use cases the "check for quotes needed" logic is retained, at a small performance cost. Unconditionally quoting the name is another option, but this would turn the name into a case sensitive name which runs the risk of poor interactions with existing deployments that may be looking at these names in other contexts. Change-Id: I6b53c96abf7fdf1840592bbca5da81347911844c Fixes: #3931
3035 lines
105 KiB
Python
3035 lines
105 KiB
Python
# sql/compiler.py
|
|
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
|
|
# <see AUTHORS file>
|
|
#
|
|
# This module is part of SQLAlchemy and is released under
|
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
|
|
|
"""Base SQL and DDL compiler implementations.
|
|
|
|
Classes provided include:
|
|
|
|
:class:`.compiler.SQLCompiler` - renders SQL
|
|
strings
|
|
|
|
:class:`.compiler.DDLCompiler` - renders DDL
|
|
(data definition language) strings
|
|
|
|
:class:`.compiler.GenericTypeCompiler` - renders
|
|
type specification strings.
|
|
|
|
To generate user-defined SQL strings, see
|
|
:doc:`/ext/compiler`.
|
|
|
|
"""
|
|
|
|
import contextlib
|
|
import re
|
|
from . import schema, sqltypes, operators, functions, visitors, \
|
|
elements, selectable, crud
|
|
from .. import util, exc
|
|
import itertools
|
|
|
|
# Lower-case SQL keywords collected as a set for fast membership tests.
RESERVED_WORDS = {
    'all', 'analyse', 'analyze', 'and', 'any', 'array',
    'as', 'asc', 'asymmetric', 'authorization', 'between',
    'binary', 'both', 'case', 'cast', 'check', 'collate',
    'column', 'constraint', 'create', 'cross', 'current_date',
    'current_role', 'current_time', 'current_timestamp',
    'current_user', 'default', 'deferrable', 'desc',
    'distinct', 'do', 'else', 'end', 'except', 'false',
    'for', 'foreign', 'freeze', 'from', 'full', 'grant',
    'group', 'having', 'ilike', 'in', 'initially', 'inner',
    'intersect', 'into', 'is', 'isnull', 'join', 'leading',
    'left', 'like', 'limit', 'localtime', 'localtimestamp',
    'natural', 'new', 'not', 'notnull', 'null', 'off', 'offset',
    'old', 'on', 'only', 'or', 'order', 'outer', 'overlaps',
    'placing', 'primary', 'references', 'right', 'select',
    'session_user', 'set', 'similar', 'some', 'symmetric', 'table',
    'then', 'to', 'trailing', 'true', 'union', 'unique', 'user',
    'using', 'verbose', 'when', 'where',
}

# Matches a string made up solely of letters, digits, "_" and "$"
# (case-insensitive).
LEGAL_CHARACTERS = re.compile(r'^[A-Z0-9_$]+$', re.I)

# The ten decimal digits plus "$", each as a one-character string.
ILLEGAL_INITIAL_CHARACTERS = set('0123456789$')

# Matches ":name" bind markers that are not preceded by ":", a word
# character, "$" or a backslash, and not followed by ":", a word
# character or "$"; group 1 is the bare name.
BIND_PARAMS = re.compile(r'(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])', re.UNICODE)

# Matches a backslash-escaped "\:name"; group 1 is ":name" without the
# backslash.
BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]*)(?![:\w\$])', re.UNICODE)

# Bind parameter rendering templates keyed by paramstyle name.
BIND_TEMPLATES = dict(
    pyformat="%%(%(name)s)s",
    qmark="?",
    format="%%s",
    numeric=":[_POSITION]",
    named=":%(name)s",
)
|
|
|
|
|
|
# SQL text rendered for each operator function.  Binary operators carry
# surrounding spaces, unary operators a trailing space (``neg`` has
# none) and modifiers a leading space.
OPERATORS = {
    # binary
    operators.and_: ' AND ',
    operators.or_: ' OR ',
    operators.add: ' + ',
    operators.mul: ' * ',
    operators.sub: ' - ',
    operators.div: ' / ',
    operators.mod: ' % ',
    operators.truediv: ' / ',
    operators.neg: '-',
    operators.lt: ' < ',
    operators.le: ' <= ',
    operators.ne: ' != ',
    operators.gt: ' > ',
    operators.ge: ' >= ',
    operators.eq: ' = ',
    operators.is_distinct_from: ' IS DISTINCT FROM ',
    operators.isnot_distinct_from: ' IS NOT DISTINCT FROM ',
    operators.concat_op: ' || ',
    operators.match_op: ' MATCH ',
    operators.notmatch_op: ' NOT MATCH ',
    operators.in_op: ' IN ',
    operators.notin_op: ' NOT IN ',
    operators.comma_op: ', ',
    operators.from_: ' FROM ',
    operators.as_: ' AS ',
    operators.is_: ' IS ',
    operators.isnot: ' IS NOT ',
    operators.collate: ' COLLATE ',

    # unary
    operators.exists: 'EXISTS ',
    operators.distinct_op: 'DISTINCT ',
    operators.inv: 'NOT ',
    operators.any_op: 'ANY ',
    operators.all_op: 'ALL ',

    # modifiers
    operators.desc_op: ' DESC',
    operators.asc_op: ' ASC',
    operators.nullsfirst_op: ' NULLS FIRST',
    operators.nullslast_op: ' NULLS LAST',
}

# Rendering templates for known function classes; entries containing
# "%(expr)s" are interpolated with the rendered argument list.
FUNCTIONS = {
    functions.coalesce: 'coalesce%(expr)s',
    functions.current_date: 'CURRENT_DATE',
    functions.current_time: 'CURRENT_TIME',
    functions.current_timestamp: 'CURRENT_TIMESTAMP',
    functions.current_user: 'CURRENT_USER',
    functions.localtime: 'LOCALTIME',
    functions.localtimestamp: 'LOCALTIMESTAMP',
    functions.random: 'random%(expr)s',
    functions.sysdate: 'sysdate',
    functions.session_user: 'SESSION_USER',
    functions.user: 'USER',
}
|
|
|
|
# Identity mapping of the field names understood by EXTRACT; every key
# maps to itself.
EXTRACT_MAP = dict(
    (field, field) for field in (
        'month',
        'day',
        'year',
        'second',
        'hour',
        'doy',
        'minute',
        'quarter',
        'dow',
        'week',
        'epoch',
        'milliseconds',
        'microseconds',
        'timezone_hour',
        'timezone_minute',
    )
)
|
|
|
|
# SQL keyword rendered between the members of each kind of compound
# SELECT.
COMPOUND_KEYWORDS = {
    selectable.CompoundSelect.UNION: 'UNION',
    selectable.CompoundSelect.UNION_ALL: 'UNION ALL',
    selectable.CompoundSelect.EXCEPT: 'EXCEPT',
    selectable.CompoundSelect.EXCEPT_ALL: 'EXCEPT ALL',
    selectable.CompoundSelect.INTERSECT: 'INTERSECT',
    selectable.CompoundSelect.INTERSECT_ALL: 'INTERSECT ALL',
}
|
|
|
|
|
|
class Compiled(object):

    """Represent a compiled SQL or DDL expression.

    The ``__str__`` method of the ``Compiled`` object should produce
    the actual text of the statement.  ``Compiled`` objects are
    specific to their underlying database dialect, and also may
    or may not be specific to the columns referenced within a
    particular set of bind parameters.  In no case should the
    ``Compiled`` object be dependent on the actual values of those
    bind parameters, even though it may reference those values as
    defaults.
    """

    _cached_metadata = None

    # Class-level defaults for the attributes that __init__ only assigns
    # when a statement is given.  They keep __str__() and
    # _execute_on_connection() from failing with AttributeError on a
    # Compiled that was constructed with ``statement=None``.
    statement = None
    string = None
    can_execute = False

    execution_options = util.immutabledict()
    """
    Execution options propagated from the statement.   In some cases,
    sub-elements of the statement can modify these.
    """

    def __init__(self, dialect, statement, bind=None,
                 schema_translate_map=None,
                 compile_kwargs=util.immutabledict()):
        """Construct a new :class:`.Compiled` object.

        :param dialect: :class:`.Dialect` to compile against.

        :param statement: :class:`.ClauseElement` to be compiled.  When
         ``None``, no compilation takes place.

        :param bind: Optional Engine or Connection to compile this
          statement against.

        :param schema_translate_map: dictionary of schema names to be
         translated when forming the resultant SQL

         .. versionadded:: 1.1

         .. seealso::

            :ref:`schema_translating`

        :param compile_kwargs: additional kwargs that will be
         passed to the initial call to :meth:`.Compiled.process`.

        """

        self.dialect = dialect
        self.bind = bind
        self.preparer = self.dialect.identifier_preparer
        if schema_translate_map:
            self.preparer = self.preparer._with_schema_translate(
                schema_translate_map)

        if statement is not None:
            self.statement = statement
            self.can_execute = statement.supports_execution
            if self.can_execute:
                self.execution_options = statement._execution_options
            # compilation happens right here, in the constructor
            self.string = self.process(self.statement, **compile_kwargs)

    @util.deprecated("0.7", ":class:`.Compiled` objects now compile "
                     "within the constructor.")
    def compile(self):
        """Produce the internal string representation of this element.
        """
        pass

    def _execute_on_connection(self, connection, multiparams, params):
        """Execute this object on ``connection`` if it is executable.

        :raises ObjectNotExecutableError: if ``can_execute`` is false.
        """
        if self.can_execute:
            return connection._execute_compiled(self, multiparams, params)
        else:
            raise exc.ObjectNotExecutableError(self.statement)

    @property
    def sql_compiler(self):
        """Return a Compiled that is capable of processing SQL expressions.

        If this compiler is one, it would likely just return 'self'.

        """

        raise NotImplementedError()

    def process(self, obj, **kwargs):
        """Dispatch ``obj`` to this compiler and return the result."""
        return obj._compiler_dispatch(self, **kwargs)

    def __str__(self):
        """Return the string text of the generated SQL or DDL."""

        return self.string or ''

    def construct_params(self, params=None):
        """Return the bind params for this compiled object.

        :param params: a dict of string/object pairs whose values will
                       override bind values compiled in to the
                       statement.
        """

        raise NotImplementedError()

    @property
    def params(self):
        """Return the bind params for this compiled object."""
        return self.construct_params()

    def execute(self, *multiparams, **params):
        """Execute this compiled object.

        :raises UnboundExecutionError: if no ``bind`` was given.
        """

        e = self.bind
        if e is None:
            raise exc.UnboundExecutionError(
                "This Compiled object is not bound to any Engine "
                "or Connection.")
        return e._execute_compiled(self, multiparams, params)

    def scalar(self, *multiparams, **params):
        """Execute this compiled object and return the result's
        scalar value."""

        return self.execute(*multiparams, **params).scalar()
|
|
|
|
|
|
class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)):
    """Produces DDL specification for TypeEngine objects."""

    # name pattern handed to the EnsureKWArgType metaclass
    ensure_kwarg = r'visit_\w+'

    def __init__(self, dialect):
        """Remember the dialect this type compiler renders for."""
        self.dialect = dialect

    def process(self, type_, **kw):
        """Dispatch ``type_`` to this compiler and return the result."""
        rendered = type_._compiler_dispatch(self, **kw)
        return rendered
|
|
|
|
|
|
class _CompileLabel(visitors.Visitable):

    """lightweight label object which acts as an expression.Label."""

    __visit_name__ = 'label'
    __slots__ = ('element', 'name')

    def __init__(self, col, name, alt_names=()):
        """Wrap ``col`` under ``name``; ``col`` itself is placed first
        among the alternate names."""
        self.element = col
        self.name = name
        self._alt_names = (col,) + alt_names

    @property
    def proxy_set(self):
        """The ``proxy_set`` of the wrapped element."""
        wrapped = self.element
        return wrapped.proxy_set

    @property
    def type(self):
        """The ``type`` of the wrapped element."""
        wrapped = self.element
        return wrapped.type

    def self_group(self, **kw):
        """Return this object unchanged."""
        return self
|
|
|
|
|
|
class SQLCompiler(Compiled):
    """Default implementation of :class:`.Compiled`.

    Compiles :class:`.ClauseElement` objects into SQL strings.

    """

    extract_map = EXTRACT_MAP

    compound_keywords = COMPOUND_KEYWORDS

    isdelete = isinsert = isupdate = False
    """class-level defaults which can be set at the instance
    level to define if this Compiled instance represents
    INSERT/UPDATE/DELETE
    """

    isplaintext = False

    returning = None
    """holds the "returning" collection of columns if
    the statement is CRUD and defines returning columns
    either implicitly or explicitly
    """

    returning_precedes_values = False
    """set to True classwide to generate RETURNING
    clauses before the VALUES or WHERE clause (i.e. MSSQL)
    """

    render_table_with_column_in_update_from = False
    """set to True classwide to indicate the SET clause
    in a multi-table UPDATE statement should qualify
    columns with the table name (i.e. MySQL only)
    """

    ansi_bind_rules = False
    """SQL 92 doesn't allow bind parameters to be used
    in the columns clause of a SELECT, nor does it allow
    ambiguous expressions like "? = ?".  A compiler
    subclass can set this flag to True if the target
    driver/DB enforces this
    """

    _textual_ordered_columns = False
    """tell the result object that the column names as rendered are important,
    but they are also "ordered" vs. what is in the compiled object here.
    """

    _ordered_columns = True
    """
    if False, means we can't be sure the list of entries
    in _result_columns is actually the rendered order.  Usually
    True unless using an unordered TextAsFrom.
    """

    insert_prefetch = update_prefetch = ()
|
|
|
|
|
|
def __init__(self, dialect, statement, column_keys=None,
             inline=False, **kwargs):
    """Construct a new :class:`.SQLCompiler` object.

    :param dialect: :class:`.Dialect` to be used

    :param statement: :class:`.ClauseElement` to be compiled

    :param column_keys:  a list of column names to be compiled into an
     INSERT or UPDATE statement.

    :param inline: whether to generate INSERT statements as "inline", e.g.
     not formatted to return any generated defaults

    :param kwargs: additional keyword arguments to be consumed by the
     superclass.

    """
    self.column_keys = column_keys

    # render INSERT/UPDATE defaults and sequences inline (no
    # pre-execution) when asked to here or when the statement says so
    self.inline = inline or getattr(statement, 'inline', False)

    # bind parameter key -> BindParameter instance
    self.binds = {}

    # BindParameter instance -> "compiled" name that is actually
    # present in the generated SQL
    self.bind_names = util.column_dict()

    # one entry per nested SELECT currently being compiled
    self.stack = []

    # relates label names in the final SQL to a tuple of local
    # column/label name, ColumnElement object (if any) and
    # TypeEngine. ResultProxy uses this for type processing and
    # column targeting
    self._result_columns = []

    # whether the paramstyle is positional
    self.positional = dialect.positional
    if self.positional:
        self.positiontup = []
    self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]

    self.ctes = None

    self.label_length = (
        dialect.label_length or dialect.max_identifier_length)

    # tracks "anonymous" identifiers that are created on the fly
    self.anon_map = util.PopulateDict(self._process_anon)

    # tracks "truncated" names based on dialect.label_length or
    # dialect.max_identifier_length
    self.truncated_names = {}

    # all of the state above has to exist before this call, since the
    # base constructor compiles the statement
    Compiled.__init__(self, dialect, statement, **kwargs)

    is_crud = self.isinsert or self.isupdate or self.isdelete
    if is_crud and statement._returning:
        self.returning = statement._returning

    if self.positional and dialect.paramstyle == 'numeric':
        self._apply_numbered_params()
|
|
|
|
@property
def prefetch(self):
    """A new list holding ``insert_prefetch`` followed by
    ``update_prefetch``."""
    combined = self.insert_prefetch + self.update_prefetch
    return list(combined)
|
|
|
|
@util.memoized_instancemethod
def _init_cte_state(self):
    """Set up the CTE-related collections on this compiler.

    Kept out of the constructor so that statements without CTEs do
    not pay for these collections.

    """
    # CTEs collected here get tacked on top of a SELECT
    self.ctes, self.ctes_by_name = util.OrderedDict(), {}
    self.ctes_recursive = False
    if self.positional:
        self.cte_positional = {}
|
|
|
|
@contextlib.contextmanager
|
|
def _nested_result(self):
|
|
"""special API to support the use case of 'nested result sets'"""
|
|
result_columns, ordered_columns = (
|
|
self._result_columns, self._ordered_columns)
|
|
self._result_columns, self._ordered_columns = [], False
|
|
|
|
try:
|
|
if self.stack:
|
|
entry = self.stack[-1]
|
|
entry['need_result_map_for_nested'] = True
|
|
else:
|
|
entry = None
|
|
yield self._result_columns, self._ordered_columns
|
|
finally:
|
|
if entry:
|
|
entry.pop('need_result_map_for_nested')
|
|
self._result_columns, self._ordered_columns = (
|
|
result_columns, ordered_columns)
|
|
|
|
def _apply_numbered_params(self):
    """Replace each ``[_POSITION]`` placeholder in ``self.string`` with
    consecutive integers, counting from 1 in order of appearance."""
    counter = itertools.count(1)

    def next_position(match):
        return str(util.next(counter))

    self.string = re.sub(r'\[_POSITION\]', next_position, self.string)
|
|
|
|
@util.memoized_property
def _bind_processors(self):
    """Map each compiled bind name to its type's bind processor for
    this dialect, leaving out parameters whose processor is ``None``."""
    processors = {}
    for bindparam in self.bind_names:
        processor = bindparam.type._cached_bind_processor(self.dialect)
        if processor is not None:
            processors[self.bind_names[bindparam]] = processor
    return processors
|
|
|
|
def is_subquery(self):
    """Return True when more than one entry is on the compile stack."""
    depth = len(self.stack)
    return depth > 1
|
|
|
|
@property
def sql_compiler(self):
    """This compiler processes SQL expressions itself, so return it."""
    return self
|
|
|
|
def construct_params(self, params=None, _group_number=None, _check=True):
    """return a dictionary of bind parameter keys and values

    For every bind parameter the value is taken, in order of
    preference, from ``params`` under the parameter's key, from
    ``params`` under its compiled name, or from the parameter itself
    (``effective_value`` when it has a callable, ``value`` otherwise).

    :raises InvalidRequestError: when ``_check`` is true and a required
     parameter received no value.
    """
    # an empty/None ``params`` means there is nothing to override with
    overrides = params if params else None

    resolved = {}
    for bindparam in self.bind_names:
        name = self.bind_names[bindparam]

        if overrides is not None:
            if bindparam.key in overrides:
                resolved[name] = overrides[bindparam.key]
                continue
            if name in overrides:
                resolved[name] = overrides[name]
                continue

        if _check and bindparam.required:
            if _group_number:
                message = (
                    "A value is required for bind parameter %r, "
                    "in parameter group %d" %
                    (bindparam.key, _group_number))
            else:
                message = (
                    "A value is required for bind parameter %r"
                    % bindparam.key)
            raise exc.InvalidRequestError(message)

        if bindparam.callable:
            resolved[name] = bindparam.effective_value
        else:
            resolved[name] = bindparam.value
    return resolved
|
|
|
|
@property
def params(self):
    """Return the bind param dictionary embedded into this
    compiled object, for those values that are present.

    Required-parameter checking is turned off for this call."""
    embedded = self.construct_params(_check=False)
    return embedded
|
|
|
|
@util.dependencies("sqlalchemy.engine.result")
def _create_result_map(self, result):
    """utility method used for unit tests only."""
    metadata_cls = result.ResultMetaData
    return metadata_cls._create_result_map(self._result_columns)
|
|
|
|
def default_from(self):
    """Hook invoked for a SELECT that has no froms and for which no
    FROM clause is going to be rendered.

    The base implementation contributes an empty string; a dialect such
    as Oracle can override it to append ``FROM DUAL``.

    """
    return ""
|
|
|
|
def visit_grouping(self, grouping, asfrom=False, **kwargs):
    """Render the grouped element wrapped in parenthesis.

    ``asfrom`` is accepted but not forwarded to the inner element.
    """
    inner = grouping.element._compiler_dispatch(self, **kwargs)
    return "".join(("(", inner, ")"))
|
|
|
|
def visit_label_reference(
        self, element, within_columns_clause=False, **kwargs):
    """Render the element wrapped by a label reference.

    When compiling inside a SELECT on a dialect that supports simple
    ORDER BY labels, and the wrapped element resolves to a label of the
    enclosing selectable, ``render_label_as_label`` is passed along so
    that the label name is rendered.
    """
    if self.stack and self.dialect.supports_simple_order_by_label:
        enclosing = self.stack[-1]['selectable']
        with_cols, only_froms, only_cols = enclosing._label_resolve_dict
        candidates = only_froms if within_columns_clause else only_cols

        # this can be None in the case that a _label_reference()
        # were subject to a replacement operation, in which case
        # the replacement of the Label element may have changed
        # to something else like a ColumnClause expression.
        target = element.element._order_by_label_element

        if (
            target is not None
            and target.name in candidates
            and target.shares_lineage(candidates[target.name])
        ):
            kwargs['render_label_as_label'] = target

    return self.process(
        element.element,
        within_columns_clause=within_columns_clause,
        **kwargs)
|
|
|
|
def visit_textual_label_reference(
        self, element, within_columns_clause=False, **kwargs):
    """Render a label that was referred to by plain string.

    The string is looked up among the labels of the enclosing
    selectable; when it cannot be resolved a warning is emitted and
    the element's text clause is rendered instead.
    """
    if not self.stack:
        # compiling the element outside of the context of a SELECT
        return self.process(element._text_clause)

    enclosing = self.stack[-1]['selectable']
    with_cols, only_froms, only_cols = enclosing._label_resolve_dict
    lookup = only_froms if within_columns_clause else with_cols

    try:
        col = lookup[element.element]
    except KeyError:
        # treat it like text()
        util.warn_limited(
            "Can't resolve label reference %r; converting to text()",
            util.ellipses_string(element.element))
        return self.process(element._text_clause)

    kwargs['render_label_as_label'] = col
    return self.process(
        col, within_columns_clause=within_columns_clause, **kwargs)
|
|
|
|
def visit_label(self, label,
                add_to_result_map=None,
                within_label_clause=False,
                within_columns_clause=False,
                render_label_as_label=None,
                **kw):
    """Render a label as ``<element> AS <name>``, as the bare label
    name, or as the bare element, depending on context."""
    # only render labels within the columns clause
    # or ORDER BY clause of a select.  dialect-specific compilers
    # can modify this behavior.
    with_as = within_columns_clause and not within_label_clause
    name_only = render_label_as_label is label

    if not (name_only or with_as):
        return label.element._compiler_dispatch(
            self, within_columns_clause=False, **kw)

    if isinstance(label.name, elements._truncated_label):
        labelname = self._truncated_identifier("colident", label.name)
    else:
        labelname = label.name

    if not with_as:
        return self.preparer.format_label(label, labelname)

    if add_to_result_map is not None:
        add_to_result_map(
            labelname,
            label.name,
            (label, labelname, ) + label._alt_names,
            label.type
        )

    inner = label.element._compiler_dispatch(
        self, within_columns_clause=True,
        within_label_clause=True, **kw)
    return inner + OPERATORS[operators.as_] + \
        self.preparer.format_label(label, labelname)
|
|
|
|
def _fallback_column_name(self, column):
    """Called by ``visit_column`` for a column whose name is ``None``.

    :raises CompileError: always, in this implementation.
    """
    message = ("Cannot compile Column object until "
               "its 'name' is assigned.")
    raise exc.CompileError(message)
|
|
|
|
def visit_column(self, column, add_to_result_map=None,
                 include_table=True, **kwargs):
    """Render a column name, qualified with its (schema and) table
    name when applicable.

    Literal columns are escaped via ``escape_literal_column``; all
    other names are passed through the preparer's ``quote``.
    """
    orig_name = column.name
    name = orig_name
    if name is None:
        name = self._fallback_column_name(column)

    is_literal = column.is_literal
    if not is_literal and isinstance(name, elements._truncated_label):
        name = self._truncated_identifier("colident", name)

    if add_to_result_map is not None:
        add_to_result_map(
            name, orig_name, (column, name, column.key), column.type)

    if is_literal:
        name = self.escape_literal_column(name)
    else:
        name = self.preparer.quote(name)

    table = column.table
    if table is None or not include_table or not table.named_with_column:
        return name

    effective_schema = self.preparer.schema_for_object(table)
    if effective_schema:
        schema_prefix = self.preparer.quote_schema(effective_schema) + '.'
    else:
        schema_prefix = ''

    tablename = table.name
    if isinstance(tablename, elements._truncated_label):
        tablename = self._truncated_identifier("alias", tablename)

    quoted_table = self.preparer.quote(tablename)
    return schema_prefix + quoted_table + "." + name
|
|
|
|
def escape_literal_column(self, text):
    """provide escaping for the literal_column() construct.

    Every ``%`` in ``text`` is doubled.
    """
    # TODO: some dialects might need different behavior here
    escaped = text.replace('%', '%%')
    return escaped
|
|
|
|
def visit_fromclause(self, fromclause, **kwargs):
    """Render a FROM clause as its ``name`` attribute, unmodified."""
    name = fromclause.name
    return name
|
|
|
|
def visit_index(self, index, **kwargs):
    """Render an index as its ``name`` attribute, unmodified."""
    name = index.name
    return name
|
|
|
|
def visit_typeclause(self, typeclause, **kw):
    """Render the type of ``typeclause`` using the dialect's type
    compiler, passing the clause itself along as ``type_expression``."""
    kw['type_expression'] = typeclause
    type_compiler = self.dialect.type_compiler
    return type_compiler.process(typeclause.type, **kw)
|
|
|
|
def post_process_text(self, text):
    """Hook applied to the text of a textual clause before bind
    parameters are substituted; returns ``text`` unchanged here."""
    return text
|
|
|
|
def visit_textclause(self, textclause, **kw):
    """Render a textual clause, replacing ``:name`` markers with bind
    parameters and un-escaping ``\\:name`` sequences."""

    def render_bind(match):
        name = match.group(1)
        if name in textclause._bindparams:
            return self.process(textclause._bindparams[name], **kw)
        return self.bindparam_string(name, **kw)

    def strip_escape(match):
        return match.group(1)

    if not self.stack:
        self.isplaintext = True

    processed = self.post_process_text(textclause.text)
    with_binds = BIND_PARAMS.sub(render_bind, processed)

    # un-escape any \:params
    return BIND_PARAMS_ESC.sub(strip_escape, with_binds)
|
|
|
|
def visit_text_as_from(self, taf,
                       compound_index=None,
                       asfrom=False,
                       parens=True, **kw):
    """Render a textual statement that exposes result columns.

    When a result map is wanted, the column arguments are processed
    first so that they register themselves; the text itself is then
    rendered, wrapped in parenthesis if used as a FROM with ``parens``.
    """
    toplevel = not self.stack
    entry = self._default_stack_entry if toplevel else self.stack[-1]

    populate_result_map = (
        toplevel
        or (
            compound_index == 0
            and entry.get('need_result_map_for_compound', False)
        )
        or entry.get('need_result_map_for_nested', False)
    )

    if populate_result_map:
        self._ordered_columns = \
            self._textual_ordered_columns = taf.positional
        for column in taf.column_args:
            self.process(column, within_columns_clause=True,
                         add_to_result_map=self._add_to_result_map)

    text = self.process(taf.element, **kw)
    if asfrom and parens:
        return "(%s)" % text
    return text
|
|
|
|
def visit_null(self, expr, **kw):
    """Render the NULL keyword."""
    return 'NULL'
|
|
|
|
def visit_true(self, expr, **kw):
    """Render boolean true: ``true`` on dialects with native boolean
    support, ``1`` otherwise."""
    return 'true' if self.dialect.supports_native_boolean else "1"
|
|
|
|
def visit_false(self, expr, **kw):
    """Render boolean false: ``false`` on dialects with native boolean
    support, ``0`` otherwise."""
    return 'false' if self.dialect.supports_native_boolean else "0"
|
|
|
|
def visit_clauselist(self, clauselist, **kw):
    """Render the clauses of a list joined by the list's operator (a
    single space when there is none), skipping clauses that render to
    an empty string."""
    if clauselist.operator is None:
        sep = " "
    else:
        sep = OPERATORS[clauselist.operator]

    rendered = (
        clause._compiler_dispatch(self, **kw)
        for clause in clauselist.clauses)
    return sep.join(piece for piece in rendered if piece)
|
|
|
|
def visit_case(self, clause, **kwargs):
    """Render a CASE expression with its optional value, its WHEN/THEN
    pairs and its optional ELSE."""
    parts = ["CASE "]
    if clause.value is not None:
        parts.append(clause.value._compiler_dispatch(self, **kwargs) + " ")
    for cond, result in clause.whens:
        when_sql = cond._compiler_dispatch(self, **kwargs)
        then_sql = result._compiler_dispatch(self, **kwargs)
        parts.append("WHEN " + when_sql + " THEN " + then_sql + " ")
    if clause.else_ is not None:
        parts.append(
            "ELSE " + clause.else_._compiler_dispatch(self, **kwargs) + " ")
    parts.append("END")
    return "".join(parts)
|
|
|
|
def visit_type_coerce(self, type_coerce, **kw):
    """Render the ``typed_expression`` of a type_coerce construct."""
    target = type_coerce.typed_expression
    return target._compiler_dispatch(self, **kw)
|
|
|
|
def visit_cast(self, cast, **kwargs):
    """Render ``CAST(<clause> AS <type>)``."""
    expr_sql = cast.clause._compiler_dispatch(self, **kwargs)
    type_sql = cast.typeclause._compiler_dispatch(self, **kwargs)
    return "CAST(%s AS %s)" % (expr_sql, type_sql)
|
|
|
|
def _format_frame_clause(self, range_, **kw):
    """Render the ``<lower> AND <upper>`` portion of a window frame.

    Each bound of ``range_`` is either the unbounded marker, the
    current-row marker, or an expression that gets processed and
    followed by PRECEDING (lower) / FOLLOWING (upper).
    """
    lower, upper = range_[0], range_[1]

    if lower is elements.RANGE_UNBOUNDED:
        start = "UNBOUNDED PRECEDING"
    elif lower is elements.RANGE_CURRENT:
        start = "CURRENT ROW"
    else:
        start = "%s PRECEDING" % (self.process(lower, **kw), )

    if upper is elements.RANGE_UNBOUNDED:
        end = "UNBOUNDED FOLLOWING"
    elif upper is elements.RANGE_CURRENT:
        end = "CURRENT ROW"
    else:
        end = "%s FOLLOWING" % (self.process(upper, **kw), )

    return '%s AND %s' % (start, end)
|
|
|
|
def visit_over(self, over, **kwargs):
    """Render ``<element> OVER (...)`` with the optional PARTITION BY,
    ORDER BY and RANGE/ROWS frame, in that order."""
    if over.range_:
        frame = "RANGE BETWEEN %s" % self._format_frame_clause(
            over.range_, **kwargs)
    elif over.rows:
        frame = "ROWS BETWEEN %s" % self._format_frame_clause(
            over.rows, **kwargs)
    else:
        frame = None

    target = over.element._compiler_dispatch(self, **kwargs)

    pieces = []
    candidates = (
        ('PARTITION', over.partition_by),
        ('ORDER', over.order_by),
    )
    for word, clause in candidates:
        # skip criteria that are absent or empty
        if clause is not None and len(clause):
            pieces.append('%s BY %s' % (
                word, clause._compiler_dispatch(self, **kwargs)))
    if frame:
        pieces.append(frame)

    return "%s OVER (%s)" % (target, ' '.join(pieces))
|
|
|
|
def visit_withingroup(self, withingroup, **kwargs):
    """Render ``<element> WITHIN GROUP (ORDER BY <order_by>)``."""
    target = withingroup.element._compiler_dispatch(self, **kwargs)
    ordering = withingroup.order_by._compiler_dispatch(self, **kwargs)
    return "%s WITHIN GROUP (ORDER BY %s)" % (target, ordering)
|
|
|
|
def visit_funcfilter(self, funcfilter, **kwargs):
    """Render ``<func> FILTER (WHERE <criterion>)``."""
    target = funcfilter.func._compiler_dispatch(self, **kwargs)
    criterion = funcfilter.criterion._compiler_dispatch(self, **kwargs)
    return "%s FILTER (WHERE %s)" % (target, criterion)
|
|
|
|
def visit_extract(self, extract, **kwargs):
    """Render ``EXTRACT(<field> FROM <expr>)``.

    The field is translated through ``self.extract_map``; fields
    missing from the map are rendered as given.
    """
    field = self.extract_map.get(extract.field, extract.field)
    expr_sql = extract.expr._compiler_dispatch(self, **kwargs)
    return "EXTRACT(%s FROM %s)" % (field, expr_sql)
|
|
|
|
def visit_function(self, func, add_to_result_map=None, **kwargs):
    """Render a SQL function call.

    A ``visit_<lowercased name>_func`` method on this compiler takes
    precedence; otherwise the template registered for the function's
    class (or ``<name>%(expr)s``) is used, prefixed with the package
    names and filled in with the rendered argument list.
    """
    if add_to_result_map is not None:
        add_to_result_map(func.name, func.name, (), func.type)

    handler = getattr(self, "visit_%s_func" % func.name.lower(), None)
    if handler:
        return handler(func, **kwargs)

    template = FUNCTIONS.get(func.__class__, func.name + "%(expr)s")
    qualified = ".".join(list(func.packagenames) + [template])
    return qualified % {'expr': self.function_argspec(func, **kwargs)}
|
|
|
|
def visit_next_value_func(self, next_value, **kw):
    """Render a "next value" function by rendering its sequence; the
    keyword arguments are not forwarded."""
    sequence = next_value.sequence
    return self.visit_sequence(sequence)
|
|
|
|
def visit_sequence(self, sequence, **kw):
    """Render a sequence increment; unsupported in this base compiler.

    ``**kw`` is accepted, like every other ``visit_*`` method, so that
    compiling a sequence with compile-time keyword arguments reports
    the intended error rather than a ``TypeError``.

    :raises NotImplementedError: always, naming the dialect.
    """
    raise NotImplementedError(
        "Dialect '%s' does not support sequence increments." %
        self.dialect.name
    )
|
|
|
|
def function_argspec(self, func, **kwargs):
    """Render the argument list of ``func`` from its ``clause_expr``."""
    arguments = func.clause_expr
    return arguments._compiler_dispatch(self, **kwargs)
|
|
|
|
def visit_compound_select(self, cs, asfrom=False,
                          parens=True, compound_index=0, **kwargs):
    """Render a compound SELECT (UNION, EXCEPT, INTERSECT, ...).

    The member selects are joined with the compound keyword, followed
    by optional GROUP BY, ORDER BY and LIMIT/OFFSET; at the top level
    any collected CTEs are prepended.  The result is parenthesized when
    used as a FROM element with ``parens``.
    """
    toplevel = not self.stack
    entry = self._default_stack_entry if toplevel else self.stack[-1]
    need_result_map = toplevel or (
        compound_index == 0
        and entry.get('need_result_map_for_compound', False))

    self.stack.append({
        'correlate_froms': entry['correlate_froms'],
        'asfrom_froms': entry['asfrom_froms'],
        'selectable': cs,
        'need_result_map_for_compound': need_result_map,
    })

    keyword = self.compound_keywords.get(cs.keyword)
    joiner = " " + keyword + " "
    members = [
        select._compiler_dispatch(
            self, asfrom=asfrom, parens=False,
            compound_index=index, **kwargs)
        for index, select in enumerate(cs.selects)
    ]
    text = joiner.join(members)

    group_by = cs._group_by_clause._compiler_dispatch(
        self, asfrom=asfrom, **kwargs)
    if group_by:
        text += " GROUP BY " + group_by

    text += self.order_by_clause(cs, **kwargs)
    if cs._limit_clause is not None or cs._offset_clause is not None:
        text += self.limit_clause(cs, **kwargs) or ""

    if self.ctes and toplevel:
        text = self._render_cte_clause() + text

    self.stack.pop(-1)
    if asfrom and parens:
        return "(" + text + ")"
    return text
|
|
|
|
def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
|
|
attrname = "visit_%s_%s%s" % (
|
|
operator_.__name__, qualifier1,
|
|
"_" + qualifier2 if qualifier2 else "")
|
|
return getattr(self, attrname, None)
|
|
|
|
def visit_unary(self, unary, **kw):
    """Render a unary expression via an operator- or modifier-specific
    visit method when one exists, generically otherwise.

    :raises CompileError: if the expression has both an operator and a
     modifier, or neither.
    """
    if unary.operator:
        if unary.modifier:
            raise exc.CompileError(
                "Unary expression does not support operator "
                "and modifier simultaneously")
        handler = self._get_operator_dispatch(
            unary.operator, "unary", "operator")
        if handler:
            return handler(unary, unary.operator, **kw)
        return self._generate_generic_unary_operator(
            unary, OPERATORS[unary.operator], **kw)

    if unary.modifier:
        handler = self._get_operator_dispatch(
            unary.modifier, "unary", "modifier")
        if handler:
            return handler(unary, unary.modifier, **kw)
        return self._generate_generic_unary_modifier(
            unary, OPERATORS[unary.modifier], **kw)

    raise exc.CompileError(
        "Unary expression has no operator or modifier")
|
|
|
|
def visit_istrue_unary_operator(self, element, operator, **kw):
    """Render an "is true" test: the bare operand on dialects with
    native boolean support, ``<operand> = 1`` otherwise."""
    native = self.dialect.supports_native_boolean
    operand = self.process(element.element, **kw)
    if native:
        return operand
    return "%s = 1" % operand
|
|
|
|
def visit_isfalse_unary_operator(self, element, operator, **kw):
    """Render an "is false" test: ``NOT <operand>`` on dialects with
    native boolean support, ``<operand> = 0`` otherwise."""
    native = self.dialect.supports_native_boolean
    operand = self.process(element.element, **kw)
    template = "NOT %s" if native else "%s = 0"
    return template % operand
|
|
|
|
def visit_notmatch_op_binary(self, binary, operator, **kw):
    """Render the negation of a MATCH expression.

    The binary is rendered through :meth:`visit_binary` with the
    operator overridden to ``operators.match_op``, then prefixed
    with ``NOT``.

    """
    match_sql = self.visit_binary(
        binary, override_operator=operators.match_op)
    return "NOT %s" % match_sql
|
|
|
|
def visit_binary(self, binary, override_operator=None,
                 eager_grouping=False, **kw):
    """Render a binary expression.

    A specialized ``visit_<name>_binary`` method is used when one
    exists; otherwise the operator string is looked up in
    ``OPERATORS`` and the generic renderer is used.  The
    ``eager_grouping`` argument is accepted but not forwarded.

    :raises UnsupportedCompilationError: if the operator has neither a
     specialized visit method nor an entry in ``OPERATORS``.

    """
    # don't allow "? = ?" to render
    both_sides_are_binds = (
        self.ansi_bind_rules and
        isinstance(binary.left, elements.BindParameter) and
        isinstance(binary.right, elements.BindParameter)
    )
    if both_sides_are_binds:
        kw['literal_binds'] = True

    operator_ = override_operator or binary.operator
    specialized = self._get_operator_dispatch(operator_, "binary", None)
    if specialized:
        return specialized(binary, operator_, **kw)

    try:
        opstring = OPERATORS[operator_]
    except KeyError:
        raise exc.UnsupportedCompilationError(self, operator_)
    return self._generate_generic_binary(binary, opstring, **kw)
|
|
|
|
def visit_custom_op_binary(self, element, operator, **kw):
    """Render a binary expression that uses a custom operator.

    The operator string is padded with one space on each side and the
    operator's ``eager_grouping`` setting is passed to the generic
    binary renderer.

    """
    kw['eager_grouping'] = operator.eager_grouping
    padded_op = " " + operator.opstring + " "
    return self._generate_generic_binary(element, padded_op, **kw)
|
|
|
|
def visit_custom_op_unary_operator(self, element, operator, **kw):
    """Render a custom operator placed before its operand.

    The operator string is followed by a single space.

    """
    prefix = operator.opstring + " "
    return self._generate_generic_unary_operator(element, prefix, **kw)
|
|
|
|
def visit_custom_op_unary_modifier(self, element, operator, **kw):
    """Render a custom operator placed after its operand.

    The operator string is preceded by a single space.

    """
    postfix = " " + operator.opstring
    return self._generate_generic_unary_modifier(element, postfix, **kw)
|
|
|
|
def _generate_generic_binary(
        self, binary, opstring, eager_grouping=False, **kw):
    """Render ``<left><opstring><right>`` for a binary expression.

    Both sides are compiled with ``_in_binary=True``.  The result is
    wrapped in parenthesis only when this call was itself made with
    ``_in_binary`` set and ``eager_grouping`` is true.

    """
    nested_in_binary = kw.get('_in_binary', False)
    kw['_in_binary'] = True

    lhs = binary.left._compiler_dispatch(
        self, eager_grouping=eager_grouping, **kw)
    rhs = binary.right._compiler_dispatch(
        self, eager_grouping=eager_grouping, **kw)
    rendered = lhs + opstring + rhs

    if nested_in_binary and eager_grouping:
        return "(%s)" % rendered
    return rendered
|
|
|
|
def _generate_generic_unary_operator(self, unary, opstring, **kw):
    """Render ``opstring`` immediately followed by the compiled operand."""
    operand = unary.element._compiler_dispatch(self, **kw)
    return opstring + operand
|
|
|
|
def _generate_generic_unary_modifier(self, unary, opstring, **kw):
    """Render the compiled operand immediately followed by ``opstring``."""
    operand = unary.element._compiler_dispatch(self, **kw)
    return operand + opstring
|
|
|
|
@util.memoized_property
def _like_percent_literal(self):
    """A string-typed literal column rendering ``'%'``.

    Used by the contains / startswith / endswith visitors to build
    LIKE patterns.

    """
    percent = elements.literal_column("'%'", type_=sqltypes.STRINGTYPE)
    return percent
|
|
|
|
def visit_contains_op_binary(self, binary, operator, **kw):
    """Render "contains" as LIKE with ``'%'`` on both sides of the value.

    Works on a clone so the given expression is not modified.

    """
    clone = binary._clone()
    wildcard = self._like_percent_literal
    prefixed = wildcard.__add__(clone.right)
    clone.right = prefixed.__add__(wildcard)
    return self.visit_like_op_binary(clone, operator, **kw)
|
|
|
|
def visit_notcontains_op_binary(self, binary, operator, **kw):
    """Render "not contains" as NOT LIKE with ``'%'`` on both sides.

    Works on a clone so the given expression is not modified.

    """
    clone = binary._clone()
    wildcard = self._like_percent_literal
    prefixed = wildcard.__add__(clone.right)
    clone.right = prefixed.__add__(wildcard)
    return self.visit_notlike_op_binary(clone, operator, **kw)
|
|
|
|
def visit_startswith_op_binary(self, binary, operator, **kw):
    """Render "startswith" as LIKE with ``'%'`` appended to the value.

    Works on a clone so the given expression is not modified.

    """
    clone = binary._clone()
    wildcard = self._like_percent_literal
    clone.right = wildcard.__radd__(clone.right)
    return self.visit_like_op_binary(clone, operator, **kw)
|
|
|
|
def visit_notstartswith_op_binary(self, binary, operator, **kw):
    """Render "not startswith" as NOT LIKE with ``'%'`` appended.

    Works on a clone so the given expression is not modified.

    """
    clone = binary._clone()
    wildcard = self._like_percent_literal
    clone.right = wildcard.__radd__(clone.right)
    return self.visit_notlike_op_binary(clone, operator, **kw)
|
|
|
|
def visit_endswith_op_binary(self, binary, operator, **kw):
    """Render "endswith" as LIKE with ``'%'`` prepended to the value.

    Works on a clone so the given expression is not modified.

    """
    clone = binary._clone()
    wildcard = self._like_percent_literal
    clone.right = wildcard.__add__(clone.right)
    return self.visit_like_op_binary(clone, operator, **kw)
|
|
|
|
def visit_notendswith_op_binary(self, binary, operator, **kw):
    """Render "not endswith" as NOT LIKE with ``'%'`` prepended.

    Works on a clone so the given expression is not modified.

    """
    clone = binary._clone()
    wildcard = self._like_percent_literal
    clone.right = wildcard.__add__(clone.right)
    return self.visit_notlike_op_binary(clone, operator, **kw)
|
|
|
|
def visit_like_op_binary(self, binary, operator, **kw):
    """Render ``<left> LIKE <right>``.

    An ``ESCAPE`` clause with the literal-rendered escape value is
    appended when the binary's modifiers carry a truthy ``"escape"``.

    """
    escape = binary.modifiers.get("escape", None)
    lhs = binary.left._compiler_dispatch(self, **kw)
    rhs = binary.right._compiler_dispatch(self, **kw)
    sql = '%s LIKE %s' % (lhs, rhs)
    if escape:
        sql += ' ESCAPE ' + \
            self.render_literal_value(escape, sqltypes.STRINGTYPE)
    return sql
|
|
|
|
def visit_notlike_op_binary(self, binary, operator, **kw):
    """Render ``<left> NOT LIKE <right>``.

    An ``ESCAPE`` clause with the literal-rendered escape value is
    appended when the binary's modifiers carry a truthy ``"escape"``.

    """
    escape = binary.modifiers.get("escape", None)
    lhs = binary.left._compiler_dispatch(self, **kw)
    rhs = binary.right._compiler_dispatch(self, **kw)
    sql = '%s NOT LIKE %s' % (lhs, rhs)
    if escape:
        sql += ' ESCAPE ' + \
            self.render_literal_value(escape, sqltypes.STRINGTYPE)
    return sql
|
|
|
|
def visit_ilike_op_binary(self, binary, operator, **kw):
    """Render a case-insensitive LIKE as ``lower(..) LIKE lower(..)``.

    An ``ESCAPE`` clause with the literal-rendered escape value is
    appended when the binary's modifiers carry a truthy ``"escape"``.

    """
    escape = binary.modifiers.get("escape", None)
    lhs = binary.left._compiler_dispatch(self, **kw)
    rhs = binary.right._compiler_dispatch(self, **kw)
    sql = 'lower(%s) LIKE lower(%s)' % (lhs, rhs)
    if escape:
        sql += ' ESCAPE ' + \
            self.render_literal_value(escape, sqltypes.STRINGTYPE)
    return sql
|
|
|
|
def visit_notilike_op_binary(self, binary, operator, **kw):
    """Render a case-insensitive NOT LIKE via ``lower()`` on both sides.

    An ``ESCAPE`` clause with the literal-rendered escape value is
    appended when the binary's modifiers carry a truthy ``"escape"``.

    """
    escape = binary.modifiers.get("escape", None)
    lhs = binary.left._compiler_dispatch(self, **kw)
    rhs = binary.right._compiler_dispatch(self, **kw)
    sql = 'lower(%s) NOT LIKE lower(%s)' % (lhs, rhs)
    if escape:
        sql += ' ESCAPE ' + \
            self.render_literal_value(escape, sqltypes.STRINGTYPE)
    return sql
|
|
|
|
def visit_between_op_binary(self, binary, operator, **kw):
    """Render BETWEEN, or BETWEEN SYMMETRIC when the binary's
    ``"symmetric"`` modifier is truthy.

    """
    if binary.modifiers.get("symmetric", False):
        keyword = " BETWEEN SYMMETRIC "
    else:
        keyword = " BETWEEN "
    return self._generate_generic_binary(binary, keyword, **kw)
|
|
|
|
def visit_notbetween_op_binary(self, binary, operator, **kw):
    """Render NOT BETWEEN, or NOT BETWEEN SYMMETRIC when the binary's
    ``"symmetric"`` modifier is truthy.

    """
    if binary.modifiers.get("symmetric", False):
        keyword = " NOT BETWEEN SYMMETRIC "
    else:
        keyword = " NOT BETWEEN "
    return self._generate_generic_binary(binary, keyword, **kw)
|
|
|
|
def visit_bindparam(self, bindparam, within_columns_clause=False,
                    literal_binds=False,
                    skip_bind_expression=False,
                    **kwargs):
    """Render a bound parameter.

    If the parameter's type provides a bind expression, that
    expression is compiled instead, with ``skip_bind_expression=True``
    so the parameter nested inside it is not wrapped again.  The
    caller's ``literal_binds``, ``within_columns_clause`` and remaining
    keyword arguments are forwarded to that compilation so the inner
    parameter is rendered under the same rules as an unwrapped one.

    Otherwise the parameter is rendered inline as a literal when
    ``literal_binds`` is set, or when it is within the columns clause
    and the compiler has ``ansi_bind_rules``; in all other cases it is
    recorded in ``self.binds`` and rendered as a placeholder.

    :raises CompileError: if a literal rendering is required but the
     parameter has neither a value nor a callable, or if the name
     conflicts with an existing unique or CRUD-generated parameter.

    """
    if not skip_bind_expression and bindparam.type._has_bind_expression:
        bind_expression = bindparam.type.bind_expression(bindparam)
        # forward the rendering flags; dropping them here would make the
        # wrapped parameter ignore literal_binds / positional_names.
        return self.process(
            bind_expression,
            skip_bind_expression=True,
            within_columns_clause=within_columns_clause,
            literal_binds=literal_binds,
            **kwargs)

    if literal_binds or \
            (within_columns_clause and
             self.ansi_bind_rules):
        if bindparam.value is None and bindparam.callable is None:
            raise exc.CompileError("Bind parameter '%s' without a "
                                   "renderable value not allowed here."
                                   % bindparam.key)
        return self.render_literal_bindparam(
            bindparam, within_columns_clause=True, **kwargs)

    name = self._truncate_bindparam(bindparam)

    if name in self.binds:
        existing = self.binds[name]
        if existing is not bindparam:
            if (existing.unique or bindparam.unique) and \
                    not existing.proxy_set.intersection(
                        bindparam.proxy_set):
                raise exc.CompileError(
                    "Bind parameter '%s' conflicts with "
                    "unique bind parameter of the same name" %
                    bindparam.key
                )
            elif existing._is_crud or bindparam._is_crud:
                raise exc.CompileError(
                    "bindparam() name '%s' is reserved "
                    "for automatic usage in the VALUES or SET "
                    "clause of this "
                    "insert/update statement. Please use a "
                    "name other than column name when using bindparam() "
                    "with insert() or update() (for example, 'b_%s')." %
                    (bindparam.key, bindparam.key)
                )

    # register under both the original key and the rendered name
    self.binds[bindparam.key] = self.binds[name] = bindparam

    return self.bindparam_string(name, **kwargs)
|
|
|
|
def render_literal_bindparam(self, bindparam, **kw):
    """Render a bind parameter's effective value as an inline literal,
    using the parameter's own type.

    """
    return self.render_literal_value(
        bindparam.effective_value, bindparam.type)
|
|
|
|
def render_literal_value(self, value, type_):
    """Render the value of a bind parameter as a quoted literal.

    This is used for statement sections that do not accept bind parameters
    on the target driver/database.

    This should be implemented by subclasses using the quoting services
    of the DBAPI.

    :param value: the Python value to render.
    :param type_: type object whose cached literal processor for this
     compiler's dialect is used to render the value.
    :raises NotImplementedError: if the type has no literal processor
     for the dialect.

    """

    processor = type_._cached_literal_processor(self.dialect)
    if processor:
        return processor(value)

    # wrap in a one-tuple so that a tuple ``value`` is shown with %r
    # rather than being unpacked as the format arguments.
    raise NotImplementedError(
        "Don't know how to literal-quote value %r" % (value,))
|
|
|
|
def _truncate_bindparam(self, bindparam):
    """Return the rendered name for a bind parameter, memoized in
    ``self.bind_names``.

    Keys that are ``_truncated_label`` instances are passed through
    :meth:`_truncated_identifier`; other keys are used unchanged.

    """
    known = self.bind_names
    if bindparam in known:
        return known[bindparam]

    label = bindparam.key
    if isinstance(label, elements._truncated_label):
        label = self._truncated_identifier("bindparam", label)

    # add to bind_names for translation
    known[bindparam] = label
    return label
|
|
|
|
def _truncated_identifier(self, ident_class, name):
    """Return an identifier for ``name`` that fits ``self.label_length``.

    Results are memoized in ``self.truncated_names`` under
    ``(ident_class, name)``.  When the anonymized name is longer than
    ``label_length - 6`` it is cut to that length and suffixed with
    ``_`` plus a per-``ident_class`` counter in hexadecimal.

    """
    cache = self.truncated_names
    cache_key = (ident_class, name)
    if cache_key in cache:
        return cache[cache_key]

    anonname = name.apply_map(self.anon_map)
    budget = self.label_length - 6

    if len(anonname) > budget:
        counter = cache.get(ident_class, 1)
        result = anonname[0:max(budget, 0)] + "_" + hex(counter)[2:]
        cache[ident_class] = counter + 1
    else:
        result = anonname

    cache[cache_key] = result
    return result
|
|
|
|
def _anonymize(self, name):
    """Interpolate ``name`` with ``self.anon_map`` using ``%`` formatting."""
    anonymized = name % self.anon_map
    return anonymized
|
|
|
|
def _process_anon(self, key):
    """Produce a numbered name for an anonymous key.

    ``key`` has the form ``"<ident> <derived>"``; the result is
    ``"<derived>_<n>"`` where ``n`` is a per-``derived`` counter kept
    in ``self.anon_map``, starting at 1.

    """
    _, derived = key.split(' ', 1)
    sequence = self.anon_map.get(derived, 1)
    self.anon_map[derived] = sequence + 1
    return "%s_%s" % (derived, sequence)
|
|
|
|
def bindparam_string(self, name, positional_names=None, **kw):
    """Return the placeholder text for the bind parameter ``name``.

    For positional compilers the name is also appended to
    ``positional_names`` when given, else to ``self.positiontup``.

    """
    if self.positional:
        if positional_names is None:
            ordering = self.positiontup
        else:
            ordering = positional_names
        ordering.append(name)
    return self.bindtemplate % {'name': name}
|
|
|
|
def visit_cte(self, cte, asfrom=False, ashint=False,
              fromhints=None,
              **kwargs):
    """Compile a CTE, registering its definition for the WITH clause.

    The text of the CTE definition is stored in ``self.ctes`` keyed by
    the CTE object, and the CTE is recorded by name in
    ``self.ctes_by_name``.  When ``asfrom`` is true, the text returned
    is the reference to the CTE (its name, or "alias AS name" for an
    alias of a CTE).

    :raises CompileError: if a different, unrelated CTE with the same
     name was already compiled.

    """
    self._init_cte_state()

    if isinstance(cte.name, elements._truncated_label):
        cte_name = self._truncated_identifier("alias", cte.name)
    else:
        cte_name = cte.name

    if cte_name in self.ctes_by_name:
        existing_cte = self.ctes_by_name[cte_name]
        # we've generated a same-named CTE that we are enclosed in,
        # or this is the same CTE. just return the name.
        if cte in existing_cte._restates or cte is existing_cte:
            return self.preparer.format_alias(cte, cte_name)
        elif existing_cte in cte._restates:
            # we've generated a same-named CTE that is
            # enclosed in us - we take precedence, so
            # discard the text for the "inner".
            del self.ctes[existing_cte]
        else:
            raise exc.CompileError(
                "Multiple, unrelated CTEs found with "
                "the same name: %r" %
                cte_name)

    self.ctes_by_name[cte_name] = cte

    # look for embedded DML ctes and propagate autocommit
    if 'autocommit' in cte.element._execution_options and \
            'autocommit' not in self.execution_options:
        self.execution_options = self.execution_options.union(
            {"autocommit": cte.element._execution_options['autocommit']})

    if cte._cte_alias is not None:
        # this CTE is an alias of another one; make sure the aliased
        # CTE has been compiled, then resolve the alias name.
        orig_cte = cte._cte_alias
        if orig_cte not in self.ctes:
            self.visit_cte(orig_cte, **kwargs)
        cte_alias_name = cte._cte_alias.name
        if isinstance(cte_alias_name, elements._truncated_label):
            cte_alias_name = self._truncated_identifier(
                "alias", cte_alias_name)
    else:
        orig_cte = cte
        cte_alias_name = None
    if not cte_alias_name and cte not in self.ctes:
        # first time this (non-alias) CTE is seen: build its definition
        if cte.recursive:
            self.ctes_recursive = True
        text = self.preparer.format_alias(cte, cte_name)
        if cte.recursive:
            # a recursive CTE also renders a column list, taken from
            # the select itself or from the first select of a compound
            if isinstance(cte.original, selectable.Select):
                col_source = cte.original
            elif isinstance(cte.original, selectable.CompoundSelect):
                col_source = cte.original.selects[0]
            else:
                assert False
            recur_cols = [c for c in
                          util.unique_list(col_source.inner_columns)
                          if c is not None]

            text += "(%s)" % (", ".join(
                self.preparer.format_column(ident)
                for ident in recur_cols))

        if self.positional:
            # collect this CTE's positional bind names in its own list
            kwargs['positional_names'] = self.cte_positional[cte] = []

        text += " AS \n" + \
            cte.original._compiler_dispatch(
                self, asfrom=True, **kwargs
            )

        if cte._suffixes:
            text += " " + self._generate_prefixes(
                cte, cte._suffixes, **kwargs)

        self.ctes[cte] = text

    if asfrom:
        if cte_alias_name:
            text = self.preparer.format_alias(cte, cte_alias_name)
            text += self.get_render_as_alias_suffix(cte_name)
        else:
            return self.preparer.format_alias(cte, cte_name)
        return text
|
|
|
|
def visit_alias(self, alias, asfrom=False, ashint=False,
                iscrud=False,
                fromhints=None, **kwargs):
    """Render an alias.

    With ``ashint`` only the formatted alias name is returned.  With
    ``asfrom`` the aliased element is rendered followed by the "AS
    name" suffix, plus hint text when ``fromhints`` has an entry for
    the alias.  With neither, the aliased element is rendered as is.

    """
    if not (asfrom or ashint):
        return alias.original._compiler_dispatch(self, **kwargs)

    if isinstance(alias.name, elements._truncated_label):
        alias_name = self._truncated_identifier("alias", alias.name)
    else:
        alias_name = alias.name

    if ashint:
        return self.preparer.format_alias(alias, alias_name)

    inner = alias.original._compiler_dispatch(
        self, asfrom=True, **kwargs)
    rendered = inner + self.get_render_as_alias_suffix(
        self.preparer.format_alias(alias, alias_name))

    if fromhints and alias in fromhints:
        rendered = self.format_from_hint_text(
            rendered, alias, fromhints[alias], iscrud)

    return rendered
|
|
|
|
def visit_lateral(self, lateral, **kw):
    """Render a LATERAL construct as ``LATERAL`` plus its alias form.

    ``lateral=True`` is added to the keyword arguments given to
    :meth:`visit_alias`.

    """
    kw['lateral'] = True
    aliased = self.visit_alias(lateral, **kw)
    return "LATERAL %s" % aliased
|
|
|
|
def visit_tablesample(self, tablesample, asfrom=False, **kw):
    """Render ``<alias> TABLESAMPLE <method>``.

    A ``REPEATABLE (<seed>)`` clause is appended when the construct
    has a seed.  The alias portion is always rendered with
    ``asfrom=True``.

    """
    aliased = self.visit_alias(tablesample, asfrom=True, **kw)
    method = tablesample._get_method()._compiler_dispatch(self, **kw)
    sql = "%s TABLESAMPLE %s" % (aliased, method)

    if tablesample.seed is not None:
        seed = tablesample.seed._compiler_dispatch(self, **kw)
        sql += " REPEATABLE (%s)" % seed

    return sql
|
|
|
|
def get_render_as_alias_suffix(self, alias_name_text):
    """Return the `` AS <name>`` suffix placed after an aliased element."""
    suffix = " AS " + alias_name_text
    return suffix
|
|
|
|
def _add_to_result_map(self, keyname, name, objects, type_):
    """Append a ``(keyname, name, objects, type_)`` entry to
    ``self._result_columns``.

    """
    entry = (keyname, name, objects, type_)
    self._result_columns.append(entry)
|
|
|
|
def _label_select_column(self, select, column,
                         populate_result_map,
                         asfrom, column_clause_args,
                         name=None,
                         within_columns_clause=True):
    """produce labeled columns present in a select().

    Chooses the expression to render for ``column`` (possibly wrapped
    in a ``_CompileLabel``), updates ``column_clause_args`` in place
    with ``within_columns_clause`` and ``add_to_result_map``, and
    returns the compiled text.

    """

    if column.type._has_column_expression and \
            populate_result_map:
        # the type wraps the column in a SQL expression; result map
        # entries additionally target the original column.
        col_expr = column.type.column_expression(column)
        add_to_result_map = lambda keyname, name, objects, type_: \
            self._add_to_result_map(
                keyname, name,
                (column,) + objects, type_)
    else:
        col_expr = column
        if populate_result_map:
            add_to_result_map = self._add_to_result_map
        else:
            add_to_result_map = None

    if not within_columns_clause:
        result_expr = col_expr
    elif isinstance(column, elements.Label):
        # already labeled: re-label only if the expression was replaced
        if col_expr is not column:
            result_expr = _CompileLabel(
                col_expr,
                column.name,
                alt_names=(column.element,)
            )
        else:
            result_expr = col_expr

    elif select is not None and name:
        # an explicit name was supplied for this column
        result_expr = _CompileLabel(
            col_expr,
            name,
            alt_names=(column._key_label,)
        )

    elif \
        asfrom and \
        isinstance(column, elements.ColumnClause) and \
        not column.is_literal and \
        column.table is not None and \
            not isinstance(column.table, selectable.Select):
        result_expr = _CompileLabel(col_expr,
                                    elements._as_truncated(column.name),
                                    alt_names=(column.key,))
    elif (
        not isinstance(column, elements.TextClause) and
        (
            not isinstance(column, elements.UnaryExpression) or
            column.wraps_column_expression
        ) and
        (
            not hasattr(column, 'name') or
            isinstance(column, functions.Function)
        )
    ):
        # expression without a usable name: use its anonymous label
        result_expr = _CompileLabel(col_expr, column.anon_label)
    elif col_expr is not column:
        # TODO: are we sure "column" has a .name and .key here ?
        # assert isinstance(column, elements.ColumnClause)
        result_expr = _CompileLabel(col_expr,
                                    elements._as_truncated(column.name),
                                    alt_names=(column.key,))
    else:
        result_expr = col_expr

    column_clause_args.update(
        within_columns_clause=within_columns_clause,
        add_to_result_map=add_to_result_map
    )
    return result_expr._compiler_dispatch(
        self,
        **column_clause_args
    )
|
|
|
|
def format_from_hint_text(self, sqltext, table, hint, iscrud):
    """Append the dialect's hint text for ``table`` to ``sqltext``.

    ``sqltext`` is returned unchanged when :meth:`get_from_hint_text`
    produces nothing.

    """
    hinttext = self.get_from_hint_text(table, hint)
    if not hinttext:
        return sqltext
    return sqltext + " " + hinttext
|
|
|
|
def get_select_hint_text(self, byfroms):
    """Return hint text for a SELECT; this implementation returns None."""
    return None
|
|
|
|
def get_from_hint_text(self, table, text):
    """Return hint text for a FROM element; this implementation
    returns None.

    """
    return None
|
|
|
|
def get_crud_hint_text(self, table, text):
    """Return hint text for an INSERT/UPDATE/DELETE table; this
    implementation returns None.

    """
    return None
|
|
|
|
def get_statement_hint_text(self, hint_texts):
    """Combine statement-level hint strings, separated by single spaces."""
    separator = " "
    return separator.join(hint_texts)
|
|
|
|
def _transform_select_for_nested_joins(self, select):
    """Rewrite any "a JOIN (b JOIN c)" expression as
    "a JOIN (select * from b JOIN c) AS anon", to support
    databases that can't parse a parenthesized join correctly
    (i.e. sqlite < 3.7.16).

    Returns a transformed copy produced by cloning; elements are
    cloned via ``_clone()`` rather than modified in place.

    """
    # memo of original element -> clone
    cloned = {}
    # stack of translation maps; the last entry is the one in effect
    # for the select currently being traversed
    column_translate = [{}]

    def visit(element, **kw):
        if element in column_translate[-1]:
            return column_translate[-1][element]

        elif element in cloned:
            return cloned[element]

        newelem = cloned[element] = element._clone()

        if newelem.is_selectable and newelem._is_join and \
                isinstance(newelem.right, selectable.FromGrouping):

            newelem._reset_exported()
            newelem.left = visit(newelem.left, **kw)

            right = visit(newelem.right, **kw)

            # wrap the grouped right side in an aliased SELECT
            selectable_ = selectable.Select(
                [right.element],
                use_labels=True).alias()

            for c in selectable_.c:
                c._key_label = c.key
                c._label = c.name

            translate_dict = dict(
                zip(newelem.right.element.c, selectable_.c)
            )

            # translating from both the old and the new
            # because different select() structures will lead us
            # to traverse differently
            translate_dict[right.element.left] = selectable_
            translate_dict[right.element.right] = selectable_
            translate_dict[newelem.right.element.left] = selectable_
            translate_dict[newelem.right.element.right] = selectable_

            # propagate translations that we've gained
            # from nested visit(newelem.right) outwards
            # to the enclosing select here. this happens
            # only when we have more than one level of right
            # join nesting, i.e. "a JOIN (b JOIN (c JOIN d))"
            for k, v in list(column_translate[-1].items()):
                if v in translate_dict:
                    # remarkably, no current ORM tests (May 2013)
                    # hit this condition, only test_join_rewriting
                    # does.
                    column_translate[-1][k] = translate_dict[v]

            column_translate[-1].update(translate_dict)

            newelem.right = selectable_

            newelem.onclause = visit(newelem.onclause, **kw)

        elif newelem._is_from_container:
            # if we hit an Alias, CompoundSelect or ScalarSelect, put a
            # marker in the stack.
            kw['transform_clue'] = 'select_container'
            newelem._copy_internals(clone=visit, **kw)
        elif newelem.is_selectable and newelem._is_select:
            barrier_select = kw.get('transform_clue', None) == \
                'select_container'
            # if we're still descended from an
            # Alias/CompoundSelect/ScalarSelect, we're
            # in a FROM clause, so start with a new translate collection
            if barrier_select:
                column_translate.append({})
            kw['transform_clue'] = 'inside_select'
            newelem._copy_internals(clone=visit, **kw)
            if barrier_select:
                del column_translate[-1]
        else:
            newelem._copy_internals(clone=visit, **kw)

        return newelem

    return visit(select)
|
|
|
|
def _transform_result_map_for_nested_joins(
        self, select, transformed_select):
    """Point result-column entries back at the original select's columns.

    Columns of ``transformed_select`` are matched to those of
    ``select`` by ``_key_label``; every object in
    ``self._result_columns`` that is a transformed column is replaced
    by its original counterpart.

    """
    transformed_by_label = {}
    for col in transformed_select.inner_columns:
        transformed_by_label[col._key_label] = col

    to_original = {}
    for col in select.inner_columns:
        to_original[transformed_by_label[col._key_label]] = col

    rewritten = []
    for key, name, objs, typ in self._result_columns:
        mapped = tuple([to_original.get(obj, obj) for obj in objs])
        rewritten.append((key, name, mapped, typ))
    self._result_columns = rewritten
|
|
|
|
# Immutable stand-in for a stack entry, used when ``self.stack`` is empty
# (i.e. a top-level statement): both FROM collections are empty.
_default_stack_entry = util.immutabledict([
    ('correlate_froms', frozenset()),
    ('asfrom_froms', frozenset())
])
|
|
|
|
def _display_froms_for_select(self, select, asfrom, lateral=False):
    """Return the FROM list to display for ``select``.

    This is a utility method to help external dialects get the
    correct FROM list for a select; specifically the oracle dialect
    needs this feature right now.

    """
    if self.stack:
        entry = self.stack[-1]
    else:
        entry = self._default_stack_entry

    correlate_froms = entry['correlate_froms']
    asfrom_froms = entry['asfrom_froms']

    if asfrom and not lateral:
        explicit = correlate_froms.difference(asfrom_froms)
        implicit = ()
    else:
        explicit = correlate_froms
        implicit = asfrom_froms

    return select._get_display_froms(
        explicit_correlate_froms=explicit,
        implicit_correlate_froms=implicit)
|
|
|
|
def visit_select(self, select, asfrom=False, parens=True,
                 fromhints=None,
                 compound_index=0,
                 nested_join_translation=False,
                 select_wraps_for=None,
                 lateral=False,
                 **kwargs):
    """Compile a SELECT statement to a string.

    The result is wrapped in parenthesis when ``parens`` is true and
    either ``asfrom`` or ``lateral`` is set.  A new entry is pushed on
    ``self.stack`` for the duration of the compilation and popped
    before returning.

    """

    # a top-level, labeled select on a dialect without right-nested
    # join support is first rewritten and then compiled recursively
    needs_nested_translation = \
        select.use_labels and \
        not nested_join_translation and \
        not self.stack and \
        not self.dialect.supports_right_nested_joins

    if needs_nested_translation:
        transformed_select = self._transform_select_for_nested_joins(
            select)
        text = self.visit_select(
            transformed_select, asfrom=asfrom, parens=parens,
            fromhints=fromhints,
            compound_index=compound_index,
            nested_join_translation=True, **kwargs
        )

    toplevel = not self.stack
    entry = self._default_stack_entry if toplevel else self.stack[-1]

    populate_result_map = toplevel or \
        (
            compound_index == 0 and entry.get(
                'need_result_map_for_compound', False)
        ) or entry.get('need_result_map_for_nested', False)

    # this was first proposed as part of #3372; however, it is not
    # reached in current tests and could possibly be an assertion
    # instead.
    if not populate_result_map and 'add_to_result_map' in kwargs:
        del kwargs['add_to_result_map']

    if needs_nested_translation:
        # the text was produced by the recursive call above; only the
        # result map needs to be pointed back at the original columns
        if populate_result_map:
            self._transform_result_map_for_nested_joins(
                select, transformed_select)
        return text

    froms = self._setup_select_stack(select, entry, asfrom, lateral)

    column_clause_args = kwargs.copy()
    column_clause_args.update({
        'within_label_clause': False,
        'within_columns_clause': False
    })

    text = "SELECT "  # we're off to a good start !

    if select._hints:
        hint_text, byfrom = self._setup_select_hints(select)
        if hint_text:
            text += hint_text + " "
    else:
        byfrom = None

    if select._prefixes:
        text += self._generate_prefixes(
            select, select._prefixes, **kwargs)

    text += self.get_select_precolumns(select, **kwargs)
    # the actual list of columns to print in the SELECT column list.
    inner_columns = [
        c for c in [
            self._label_select_column(
                select,
                column,
                populate_result_map, asfrom,
                column_clause_args,
                name=name)
            for name, column in select._columns_plus_names
        ]
        if c is not None
    ]

    if populate_result_map and select_wraps_for is not None:
        # if this select is a compiler-generated wrapper,
        # rewrite the targeted columns in the result map

        translate = dict(
            zip(
                [name for (key, name) in select._columns_plus_names],
                [name for (key, name) in
                 select_wraps_for._columns_plus_names])
        )

        self._result_columns = [
            (key, name, tuple(translate.get(o, o) for o in obj), type_)
            for key, name, obj, type_ in self._result_columns
        ]

    text = self._compose_select_body(
        text, select, inner_columns, froms, byfrom, kwargs)

    if select._statement_hints:
        # keep hints that target any dialect ('*') or this dialect
        per_dialect = [
            ht for (dialect_name, ht)
            in select._statement_hints
            if dialect_name in ('*', self.dialect.name)
        ]
        if per_dialect:
            text += " " + self.get_statement_hint_text(per_dialect)

    if self.ctes and toplevel:
        text = self._render_cte_clause() + text

    if select._suffixes:
        text += " " + self._generate_prefixes(
            select, select._suffixes, **kwargs)

    self.stack.pop(-1)

    if (asfrom or lateral) and parens:
        return "(" + text + ")"
    else:
        return text
|
|
|
|
def _setup_select_hints(self, select):
    """Collect per-FROM hint text for a SELECT.

    Returns ``(hint_text, byfrom)`` where ``byfrom`` maps each FROM
    element to its hint string, with ``%(name)s`` replaced by the
    element rendered with ``ashint=True``.  Only hints registered for
    ``'*'`` or this dialect's name are included.

    """
    byfrom = {}
    for (from_, dialect), hinttext in select._hints.items():
        if dialect not in ('*', self.dialect.name):
            continue
        rendered_name = from_._compiler_dispatch(self, ashint=True)
        byfrom[from_] = hinttext % {'name': rendered_name}

    return self.get_select_hint_text(byfrom), byfrom
|
|
|
|
def _setup_select_stack(self, select, entry, asfrom, lateral):
    """Determine the FROM list for ``select`` and push a stack entry.

    The new entry records the FROM objects of this select under
    ``'asfrom_froms'``, their union with the enclosing entry's
    ``'correlate_froms'`` under ``'correlate_froms'``, and the select
    itself under ``'selectable'``.  Returns the FROM list.

    """
    correlate_froms = entry['correlate_froms']
    asfrom_froms = entry['asfrom_froms']

    if asfrom and not lateral:
        explicit = correlate_froms.difference(asfrom_froms)
        implicit = ()
    else:
        explicit = correlate_froms
        implicit = asfrom_froms

    froms = select._get_display_froms(
        explicit_correlate_froms=explicit,
        implicit_correlate_froms=implicit)

    own_froms = set(selectable._from_objects(*froms))

    self.stack.append({
        'asfrom_froms': own_froms,
        'correlate_froms': own_froms.union(correlate_froms),
        'selectable': select,
    })

    return froms
|
|
|
|
def _compose_select_body(
        self, text, select, inner_columns, froms, byfrom, kwargs):
    """Append the column list and the clauses that follow it to ``text``.

    Clauses are added in this order when present: FROM (or the
    compiler's default FROM), WHERE, GROUP BY, HAVING, ORDER BY,
    LIMIT/OFFSET and FOR UPDATE.

    """
    text += ', '.join(inner_columns)

    if froms:
        text += " \nFROM "

        if select._hints:
            from_texts = [
                f._compiler_dispatch(
                    self, asfrom=True, fromhints=byfrom, **kwargs)
                for f in froms]
        else:
            from_texts = [
                f._compiler_dispatch(self, asfrom=True, **kwargs)
                for f in froms]
        text += ', '.join(from_texts)
    else:
        text += self.default_from()

    if select._whereclause is not None:
        where_text = select._whereclause._compiler_dispatch(self, **kwargs)
        if where_text:
            text += " \nWHERE " + where_text

    if select._group_by_clause.clauses:
        group_text = select._group_by_clause._compiler_dispatch(
            self, **kwargs)
        if group_text:
            text += " GROUP BY " + group_text

    if select._having is not None:
        having_text = select._having._compiler_dispatch(self, **kwargs)
        if having_text:
            text += " \nHAVING " + having_text

    if select._order_by_clause.clauses:
        text += self.order_by_clause(select, **kwargs)

    has_limit = select._limit_clause is not None
    if has_limit or select._offset_clause is not None:
        text += self.limit_clause(select, **kwargs)

    if select._for_update_arg is not None:
        text += self.for_update_clause(select, **kwargs)

    return text
|
|
|
|
def _generate_prefixes(self, stmt, prefixes, **kw):
    """Render prefix (or suffix) elements that apply to this dialect.

    ``prefixes`` is a sequence of ``(element, dialect_name)`` pairs; a
    pair is included when its dialect name is None or equals this
    dialect's name.  The rendered elements are joined with spaces and
    a trailing space is added when the result is non-empty.

    """
    rendered = []
    for prefix, dialect_name in prefixes:
        if dialect_name is None or dialect_name == self.dialect.name:
            rendered.append(prefix._compiler_dispatch(self, **kw))

    clause = " ".join(rendered)
    return clause + " " if clause else clause
|
|
|
|
def _render_cte_clause(self):
    """Render the WITH clause for all collected CTEs.

    For positional compilers, the positional names collected for each
    CTE are placed ahead of the names already in ``self.positiontup``.

    """
    if self.positional:
        cte_names = []
        for cte in self.ctes:
            cte_names = cte_names + self.cte_positional[cte]
        self.positiontup = cte_names + self.positiontup

    preamble = self.get_cte_preamble(self.ctes_recursive)
    definitions = ", \n".join(list(self.ctes.values()))
    return preamble + " " + definitions + "\n "
|
|
|
|
def get_cte_preamble(self, recursive):
    """Return the keyword(s) that open a CTE clause: ``WITH RECURSIVE``
    when ``recursive`` is truthy, else ``WITH``.

    """
    return "WITH RECURSIVE" if recursive else "WITH"
|
|
|
|
def get_select_precolumns(self, select, **kw):
    """Called when building a ``SELECT`` statement, position is just
    before column list.

    Returns ``"DISTINCT "`` when the select's ``_distinct`` is truthy,
    otherwise the empty string.

    """
    if select._distinct:
        return "DISTINCT "
    return ""
|
|
|
|
def order_by_clause(self, select, **kw):
    """Return the `` ORDER BY ...`` text for ``select``, or the empty
    string when its ORDER BY clause renders to nothing.

    """
    ordering = select._order_by_clause._compiler_dispatch(self, **kw)
    if not ordering:
        return ""
    return " ORDER BY " + ordering
|
|
|
|
def for_update_clause(self, select, **kw):
    """Return the FOR UPDATE text; this implementation always renders
    the plain form.

    """
    return " FOR UPDATE"
|
|
|
|
def returning_clause(self, stmt, returning_cols):
    """Render a RETURNING clause.

    This implementation does not support RETURNING and always raises.

    :raises CompileError: always.

    """
    message = (
        "RETURNING is not supported by this "
        "dialect's statement compiler.")
    raise exc.CompileError(message)
|
|
|
|
def limit_clause(self, select, **kw):
    """Render LIMIT / OFFSET for ``select``.

    When only an offset is present, ``LIMIT -1`` is rendered ahead of
    the OFFSET.

    """
    limit = select._limit_clause
    offset = select._offset_clause

    pieces = []
    if limit is not None:
        pieces.append("\n LIMIT " + self.process(limit, **kw))
    if offset is not None:
        if limit is None:
            pieces.append("\n LIMIT -1")
        pieces.append(" OFFSET " + self.process(offset, **kw))
    return "".join(pieces)
|
|
|
|
def visit_table(self, table, asfrom=False, iscrud=False, ashint=False,
                fromhints=None, use_schema=True, **kwargs):
    """Render a table name.

    Returns the empty string unless ``asfrom`` or ``ashint`` is set.
    The name is schema-qualified when ``use_schema`` is true and the
    preparer reports a schema for the table; hint text is appended
    when ``fromhints`` has an entry for the table.

    """
    if not (asfrom or ashint):
        return ""

    preparer = self.preparer
    effective_schema = preparer.schema_for_object(table)

    if use_schema and effective_schema:
        rendered = preparer.quote_schema(effective_schema) + \
            "." + preparer.quote(table.name)
    else:
        rendered = preparer.quote(table.name)

    if fromhints and table in fromhints:
        rendered = self.format_from_hint_text(
            rendered, table, fromhints[table], iscrud)
    return rendered
|
|
|
|
def visit_join(self, join, asfrom=False, **kwargs):
    """Render ``<left> <join keyword> <right> ON <onclause>``.

    The keyword is FULL OUTER JOIN when ``join.full`` is set, LEFT
    OUTER JOIN when ``join.isouter`` is set, and JOIN otherwise.  Both
    sides are always rendered with ``asfrom=True``.

    """
    if join.full:
        join_type = " FULL OUTER JOIN "
    else:
        join_type = " LEFT OUTER JOIN " if join.isouter else " JOIN "

    left_sql = join.left._compiler_dispatch(self, asfrom=True, **kwargs)
    right_sql = join.right._compiler_dispatch(self, asfrom=True, **kwargs)
    on_sql = join.onclause._compiler_dispatch(self, **kwargs)
    return left_sql + join_type + right_sql + " ON " + on_sql
|
|
|
|
def _setup_crud_hints(self, stmt, table_text):
    """Collect the statement's hints that apply to this dialect.

    Returns ``(dialect_hints, table_text)`` where ``dialect_hints``
    maps tables to hint text registered for ``'*'`` or this dialect's
    name, and ``table_text`` has the hint for ``stmt.table`` applied
    when one is present.

    """
    dialect_hints = {}
    for (table, dialect), hint_text in stmt._hints.items():
        if dialect in ('*', self.dialect.name):
            dialect_hints[table] = hint_text

    target = stmt.table
    if target in dialect_hints:
        table_text = self.format_from_hint_text(
            table_text, target, dialect_hints[target], True)
    return dialect_hints, table_text
|
|
|
|
def visit_insert(self, insert_stmt, asfrom=False, **kw):
    """Compile an INSERT statement to a string.

    A stack entry is pushed for the duration of the compilation and
    popped before returning.  The result is wrapped in parenthesis
    when ``asfrom`` is true.

    :raises CompileError: if there are no parameters and the dialect
     supports neither DEFAULT VALUES nor empty inserts, or if the
     statement has multiple parameter sets and the dialect does not
     support multi-values inserts.

    """
    toplevel = not self.stack

    self.stack.append(
        {'correlate_froms': set(),
         "asfrom_froms": set(),
         "selectable": insert_stmt})

    crud_params = crud._setup_crud_params(
        self, insert_stmt, crud.ISINSERT, **kw)

    if not crud_params and \
            not self.dialect.supports_default_values and \
            not self.dialect.supports_empty_insert:
        raise exc.CompileError("The '%s' dialect with current database "
                               "version settings does not support empty "
                               "inserts." %
                               self.dialect.name)

    if insert_stmt._has_multi_parameters:
        if not self.dialect.supports_multivalues_insert:
            raise exc.CompileError(
                "The '%s' dialect with current database "
                "version settings does not support "
                "in-place multirow inserts." %
                self.dialect.name)
        # with multiple parameter sets, crud_params is a list of sets;
        # the first one supplies the column list
        crud_params_single = crud_params[0]
    else:
        crud_params_single = crud_params

    preparer = self.preparer
    supports_default_values = self.dialect.supports_default_values

    text = "INSERT "

    if insert_stmt._prefixes:
        text += self._generate_prefixes(insert_stmt,
                                        insert_stmt._prefixes, **kw)

    text += "INTO "
    table_text = preparer.format_table(insert_stmt.table)

    if insert_stmt._hints:
        dialect_hints, table_text = self._setup_crud_hints(
            insert_stmt, table_text)
    else:
        dialect_hints = None

    text += table_text

    if crud_params_single or not supports_default_values:
        text += " (%s)" % ', '.join([preparer.format_column(c[0])
                                     for c in crud_params_single])

    if self.returning or insert_stmt._returning:
        returning_clause = self.returning_clause(
            insert_stmt, self.returning or insert_stmt._returning)

        # some dialects place RETURNING ahead of VALUES; otherwise it
        # is appended after the values below
        if self.returning_precedes_values:
            text += " " + returning_clause
    else:
        returning_clause = None

    if insert_stmt.select is not None:
        text += " %s" % self.process(self._insert_from_select, **kw)
    elif not crud_params and supports_default_values:
        text += " DEFAULT VALUES"
    elif insert_stmt._has_multi_parameters:
        text += " VALUES %s" % (
            ", ".join(
                "(%s)" % (
                    ', '.join(c[1] for c in crud_param_set)
                )
                for crud_param_set in crud_params
            )
        )
    else:
        text += " VALUES (%s)" % \
            ', '.join([c[1] for c in crud_params])

    if insert_stmt._post_values_clause is not None:
        post_values_clause = self.process(
            insert_stmt._post_values_clause, **kw)
        if post_values_clause:
            text += " " + post_values_clause

    if returning_clause and not self.returning_precedes_values:
        text += " " + returning_clause

    if self.ctes and toplevel:
        text = self._render_cte_clause() + text

    self.stack.pop(-1)

    if asfrom:
        return "(" + text + ")"
    else:
        return text
|
|
|
|
def update_limit_clause(self, update_stmt):
    """Hook for dialects that support a LIMIT clause on UPDATE.

    The base implementation renders nothing and returns ``None``; the
    original note says MySQL uses this hook to add LIMIT to the UPDATE.

    """
    limit_text = None
    return limit_text
|
|
|
|
def update_tables_clause(self, update_stmt, from_table,
                         extra_froms, **kw):
    """Hook producing the initial table clause of an UPDATE statement.

    The target table is dispatched to this compiler with ``asfrom=True``
    (overriding any ``asfrom`` passed in ``kw``) and ``iscrud=True``.

    MySQL overrides this.

    """
    dispatch_kw = dict(kw, asfrom=True)
    return from_table._compiler_dispatch(self, iscrud=True, **dispatch_kw)
|
|
|
|
def update_from_clause(self, update_stmt,
                       from_table, extra_froms,
                       from_hints,
                       **kw):
    """Hook producing the FROM portion of an UPDATE..FROM statement.

    Every element of ``extra_froms`` is dispatched to this compiler with
    ``asfrom=True`` and ``fromhints=from_hints``; the rendered pieces are
    comma-joined after the ``FROM`` keyword.

    MySQL and MSSQL override this.

    """
    rendered_froms = [
        extra._compiler_dispatch(
            self, asfrom=True, fromhints=from_hints, **kw)
        for extra in extra_froms
    ]
    return "FROM " + ', '.join(rendered_froms)
|
|
|
|
def visit_update(self, update_stmt, asfrom=False, **kw):
    """Render an UPDATE statement as a string.

    :param update_stmt: the UPDATE construct being compiled.
    :param asfrom: when true, the finished text is wrapped in
     parentheses before being returned.
    :param kw: passed along to the prefix, table, CRUD-parameter,
     FROM and WHERE rendering calls below.
    :return: the UPDATE statement text.

    """
    # the statement is "top level" only if nothing is being compiled
    # around it, i.e. the stack is empty before our own entry is pushed
    toplevel = not self.stack

    self.stack.append(
        {'correlate_froms': set([update_stmt.table]),
         "asfrom_froms": set([update_stmt.table]),
         "selectable": update_stmt})

    extra_froms = update_stmt._extra_froms

    text = "UPDATE "

    if update_stmt._prefixes:
        text += self._generate_prefixes(update_stmt,
                                        update_stmt._prefixes, **kw)

    table_text = self.update_tables_clause(update_stmt, update_stmt.table,
                                           extra_froms, **kw)

    # sequence of pairs; element [0] is dispatched as a column below and
    # element [1] is the already-rendered value text
    crud_params = crud._setup_crud_params(
        self, update_stmt, crud.ISUPDATE, **kw)

    if update_stmt._hints:
        dialect_hints, table_text = self._setup_crud_hints(
            update_stmt, table_text)
    else:
        dialect_hints = None

    text += table_text

    text += ' SET '
    # column names in SET are table-qualified only when extra FROM
    # tables are present and the compiler's flag asks for it
    include_table = extra_froms and \
        self.render_table_with_column_in_update_from
    text += ', '.join(
        c[0]._compiler_dispatch(self,
                                include_table=include_table) +
        '=' + c[1] for c in crud_params
    )

    # RETURNING is emitted here, right after SET, when
    # returning_precedes_values is set; otherwise it is appended after
    # the WHERE / limit clauses further down
    if self.returning or update_stmt._returning:
        if self.returning_precedes_values:
            text += " " + self.returning_clause(
                update_stmt, self.returning or update_stmt._returning)

    if extra_froms:
        extra_from_text = self.update_from_clause(
            update_stmt,
            update_stmt.table,
            extra_froms,
            dialect_hints, **kw)
        if extra_from_text:
            text += " " + extra_from_text

    if update_stmt._whereclause is not None:
        t = self.process(update_stmt._whereclause, **kw)
        if t:
            text += " WHERE " + t

    limit_clause = self.update_limit_clause(update_stmt)
    if limit_clause:
        text += " " + limit_clause

    if (self.returning or update_stmt._returning) and \
            not self.returning_precedes_values:
        text += " " + self.returning_clause(
            update_stmt, self.returning or update_stmt._returning)

    # the CTE clause is prepended only for the outermost statement
    if self.ctes and toplevel:
        text = self._render_cte_clause() + text

    # remove the entry pushed at the top of this method
    self.stack.pop(-1)

    if asfrom:
        return "(" + text + ")"
    else:
        return text
|
|
|
|
@util.memoized_property
def _key_getters_for_crud_column(self):
    """Key getters for CRUD columns, computed for ``self.statement``.

    Delegates to ``crud._key_getters_for_crud_column``.

    """
    getters = crud._key_getters_for_crud_column(self, self.statement)
    return getters
|
|
|
|
def visit_delete(self, delete_stmt, asfrom=False, **kw):
    """Render a DELETE statement as a string.

    :param delete_stmt: the DELETE construct being compiled.
    :param asfrom: when true, the finished text is wrapped in
     parentheses before being returned.
    :param kw: passed along to the CRUD setup, prefix and WHERE
     rendering calls below.
    :return: the DELETE statement text.

    """
    # the statement is "top level" only if nothing is being compiled
    # around it, i.e. the stack is empty before our own entry is pushed
    toplevel = not self.stack

    self.stack.append({'correlate_froms': set([delete_stmt.table]),
                       "asfrom_froms": set([delete_stmt.table]),
                       "selectable": delete_stmt})

    # called for its effect on the compiler; the return value is unused
    crud._setup_crud_params(self, delete_stmt, crud.ISDELETE, **kw)

    text = "DELETE "

    if delete_stmt._prefixes:
        text += self._generate_prefixes(delete_stmt,
                                        delete_stmt._prefixes, **kw)

    text += "FROM "
    # note that **kw is not forwarded to the table dispatch here
    table_text = delete_stmt.table._compiler_dispatch(
        self, asfrom=True, iscrud=True)

    if delete_stmt._hints:
        # only the rewritten table text is used; dialect_hints is not
        # read again in this method
        dialect_hints, table_text = self._setup_crud_hints(
            delete_stmt, table_text)

    text += table_text

    # RETURNING is emitted before WHERE when returning_precedes_values
    # is set; otherwise it is appended after WHERE further down
    if delete_stmt._returning:
        if self.returning_precedes_values:
            text += " " + self.returning_clause(
                delete_stmt, delete_stmt._returning)

    if delete_stmt._whereclause is not None:
        t = delete_stmt._whereclause._compiler_dispatch(self, **kw)
        if t:
            text += " WHERE " + t

    if delete_stmt._returning and not self.returning_precedes_values:
        text += " " + self.returning_clause(
            delete_stmt, delete_stmt._returning)

    # the CTE clause is prepended only for the outermost statement
    if self.ctes and toplevel:
        text = self._render_cte_clause() + text

    # remove the entry pushed at the top of this method
    self.stack.pop(-1)

    if asfrom:
        return "(" + text + ")"
    else:
        return text
|
|
|
|
def visit_savepoint(self, savepoint_stmt):
    """Render ``SAVEPOINT <name>`` for the given savepoint construct."""
    savepoint_name = self.preparer.format_savepoint(savepoint_stmt)
    return "SAVEPOINT %s" % savepoint_name
|
|
|
|
def visit_rollback_to_savepoint(self, savepoint_stmt):
    """Render ``ROLLBACK TO SAVEPOINT <name>`` for the given construct."""
    savepoint_name = self.preparer.format_savepoint(savepoint_stmt)
    return "ROLLBACK TO SAVEPOINT %s" % savepoint_name
|
|
|
|
def visit_release_savepoint(self, savepoint_stmt):
    """Render ``RELEASE SAVEPOINT <name>`` for the given construct."""
    savepoint_name = self.preparer.format_savepoint(savepoint_stmt)
    return "RELEASE SAVEPOINT %s" % savepoint_name
|
|
|
|
|
|
class StrSQLCompiler(SQLCompiler):
    """A compiler subclass with a few non-standard SQL features allowed.

    Used for stringification of SQL statements when a real dialect is not
    available.

    """

    def _fallback_column_name(self, column):
        """Return the placeholder text used for a column with no name."""
        return "<name unknown>"

    def visit_getitem_binary(self, binary, operator, **kw):
        """Render an index operation as ``left[right]``."""
        left_text = self.process(binary.left, **kw)
        right_text = self.process(binary.right, **kw)
        return "%s[%s]" % (left_text, right_text)

    def visit_json_getitem_op_binary(self, binary, operator, **kw):
        """Render JSON index access the same way as a plain getitem."""
        return self.visit_getitem_binary(binary, operator, **kw)

    def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
        """Render JSON path access the same way as a plain getitem."""
        return self.visit_getitem_binary(binary, operator, **kw)

    def returning_clause(self, stmt, returning_cols):
        """Render a generic ``RETURNING col, col, ...`` clause."""
        rendered = []
        for col in elements._select_iterables(returning_cols):
            rendered.append(
                self._label_select_column(None, col, True, False, {}))

        return 'RETURNING ' + ', '.join(rendered)
|
|
|
|
|
|
class DDLCompiler(Compiled):
    """Render DDL (data definition language) strings.

    Each ``visit_*`` method receives one DDL or schema construct and
    returns its string form.  Identifier quoting goes through
    ``self.preparer``; embedded SQL expressions (index expressions,
    CHECK conditions, server defaults) are rendered with
    :attr:`.sql_compiler`.

    """

    @util.memoized_property
    def sql_compiler(self):
        """A statement compiler for this dialect, built with no statement.

        Used to render SQL expressions that appear inside DDL.

        """
        return self.dialect.statement_compiler(self.dialect, None)

    @util.memoized_property
    def type_compiler(self):
        """The dialect's type compiler."""
        return self.dialect.type_compiler

    def construct_params(self, params=None):
        """Return ``None``; this compiler produces no parameters."""
        return None

    def visit_ddl(self, ddl, **kwargs):
        """Render a DDL construct by %-interpolating its statement.

        When the target is a :class:`.Table`, the keys ``table``,
        ``schema`` and ``fullname`` are added to a copy of the context,
        unless the context already supplies them.

        """
        # table events can substitute table and schema name
        context = ddl.context
        if isinstance(ddl.target, schema.Table):
            # copy so that the defaults below don't leak into ddl.context
            context = context.copy()

            preparer = self.preparer
            path = preparer.format_table_seq(ddl.target)
            if len(path) == 1:
                table, sch = path[0], ''
            else:
                table, sch = path[-1], path[0]

            context.setdefault('table', table)
            context.setdefault('schema', sch)
            context.setdefault('fullname', preparer.format_table(ddl.target))

        return self.sql_compiler.post_process_text(ddl.statement % context)

    def visit_create_schema(self, create):
        """Render ``CREATE SCHEMA <name>``."""
        schema_name = self.preparer.format_schema(create.element)
        return "CREATE SCHEMA " + schema_name

    def visit_drop_schema(self, drop):
        """Render ``DROP SCHEMA <name>``, adding CASCADE if requested."""
        schema_name = self.preparer.format_schema(drop.element)
        text = "DROP SCHEMA " + schema_name
        if drop.cascade:
            text += " CASCADE"
        return text

    def visit_create_table(self, create):
        """Render a complete ``CREATE TABLE`` statement.

        Columns are rendered one per line, followed by the table-level
        constraints.  A :class:`.CompileError` raised while rendering a
        column is re-raised with the table and column name prepended to
        its message.

        """
        table = create.element
        preparer = self.preparer

        text = "\nCREATE "
        if table._prefixes:
            text += " ".join(table._prefixes) + " "
        text += "TABLE " + preparer.format_table(table) + " "

        create_table_suffix = self.create_table_suffix(table)
        if create_table_suffix:
            text += create_table_suffix + " "

        text += "("

        # the first element is preceded by a newline only; later ones
        # by a comma and a newline
        separator = "\n"

        # if only one primary key, specify it along with the column
        first_pk = False
        for create_column in create.columns:
            column = create_column.element
            try:
                processed = self.process(create_column,
                                         first_pk=column.primary_key
                                         and not first_pk)
                if processed is not None:
                    text += separator
                    separator = ", \n"
                    text += "\t" + processed
                if column.primary_key:
                    first_pk = True
            except exc.CompileError as ce:
                util.raise_from_cause(
                    exc.CompileError(
                        util.u("(in table '%s', column '%s'): %s") %
                        (table.description, column.name, ce.args[0])
                    ))

        const = self.create_table_constraints(
            table,
            _include_foreign_key_constraints=(
                create.include_foreign_key_constraints))
        if const:
            text += separator + "\t" + const

        text += "\n)%s\n\n" % self.post_create_table(table)
        return text

    def visit_create_column(self, create, first_pk=False):
        """Render one column definition inside CREATE TABLE.

        Returns ``None`` for a column flagged ``system``; such columns
        are left out of the statement by the caller.

        """
        column = create.element

        if column.system:
            return None

        text = self.get_column_specification(
            column,
            first_pk=first_pk
        )
        const = " ".join(self.process(constraint)
                         for constraint in column.constraints)
        if const:
            text += " " + const

        return text

    def create_table_constraints(
            self, table,
            _include_foreign_key_constraints=None):
        """Render the table-level constraints of a CREATE TABLE.

        :param table: the table whose constraints are rendered.
        :param _include_foreign_key_constraints: when not ``None``, only
         the foreign key constraints in this collection are rendered;
         the table's other foreign key constraints are omitted.
        :return: the rendered constraints joined by ``", \\n\\t"``.

        """
        # On some DB order is significant: visit PK first, then the
        # other constraints (engine.ReflectionTest.testbasic failed on FB2)
        constraints = []
        if table.primary_key:
            constraints.append(table.primary_key)

        all_fkcs = table.foreign_key_constraints
        if _include_foreign_key_constraints is not None:
            omit_fkcs = all_fkcs.difference(_include_foreign_key_constraints)
        else:
            omit_fkcs = set()

        constraints.extend([c for c in table._sorted_constraints
                            if c is not table.primary_key and
                            c not in omit_fkcs])

        # a constraint is rendered inline when its create rule (if any)
        # passes, and either the dialect cannot ALTER or the constraint
        # is not marked use_alter; visitors returning None are dropped
        return ", \n\t".join(
            p for p in
            (self.process(constraint)
             for constraint in constraints
             if (
                 constraint._create_rule is None or
                 constraint._create_rule(self))
             and (
                 not self.dialect.supports_alter or
                 not getattr(constraint, 'use_alter', False)
             )) if p is not None
        )

    def visit_drop_table(self, drop):
        """Render ``DROP TABLE <name>``."""
        return "\nDROP TABLE " + self.preparer.format_table(drop.element)

    def visit_drop_view(self, drop):
        """Render ``DROP VIEW <name>``."""
        return "\nDROP VIEW " + self.preparer.format_table(drop.element)

    def _verify_index_table(self, index):
        """Raise :class:`.CompileError` if the index has no table."""
        if index.table is None:
            raise exc.CompileError("Index '%s' is not associated "
                                   "with any table." % index.name)

    def visit_create_index(self, create, include_schema=False,
                           include_table_schema=True):
        """Render ``CREATE [UNIQUE] INDEX <name> ON <table> (<exprs>)``.

        :param include_schema: schema-qualify the index name.
        :param include_table_schema: schema-qualify the table name.

        """
        index = create.element
        self._verify_index_table(index)
        preparer = self.preparer
        text = "CREATE "
        if index.unique:
            text += "UNIQUE "
        text += "INDEX %s ON %s (%s)" \
            % (
                self._prepared_index_name(index,
                                          include_schema=include_schema),
                preparer.format_table(index.table,
                                      use_schema=include_table_schema),
                ', '.join(
                    self.sql_compiler.process(
                        expr, include_table=False, literal_binds=True) for
                    expr in index.expressions)
            )
        return text

    def visit_drop_index(self, drop):
        """Render ``DROP INDEX <name>`` with a schema-qualified name."""
        index = drop.element
        return "\nDROP INDEX " + self._prepared_index_name(
            index, include_schema=True)

    def _prepared_index_name(self, index, include_schema=False):
        """Return the quoted, optionally schema-qualified, index name.

        A ``_truncated_label`` name longer than the dialect's limit is
        shortened to ``limit - 8`` characters plus ``"_"`` and the last
        four characters of its md5 hex digest; any other name is passed
        to the dialect's ``validate_identifier``.

        """
        if index.table is not None:
            effective_schema = self.preparer.schema_for_object(index.table)
        else:
            effective_schema = None
        if include_schema and effective_schema:
            schema_name = self.preparer.quote_schema(effective_schema)
        else:
            schema_name = None

        ident = index.name
        if isinstance(ident, elements._truncated_label):
            max_ = self.dialect.max_index_name_length or \
                self.dialect.max_identifier_length
            if len(ident) > max_:
                ident = ident[0:max_ - 8] + \
                    "_" + util.md5_hex(ident)[-4:]
        else:
            self.dialect.validate_identifier(ident)

        index_name = self.preparer.quote(ident)

        if schema_name:
            index_name = schema_name + "." + index_name
        return index_name

    def visit_add_constraint(self, create):
        """Render ``ALTER TABLE <table> ADD <constraint>``."""
        return "ALTER TABLE %s ADD %s" % (
            self.preparer.format_table(create.element.table),
            self.process(create.element)
        )

    def visit_create_sequence(self, create):
        """Render ``CREATE SEQUENCE`` with its optional clauses.

        Numeric options (increment, start, minvalue, maxvalue) are
        rendered whenever they are not ``None``.  The boolean options
        (nominvalue, nomaxvalue, cycle) are rendered only when true, so
        that an explicit ``False`` does not emit the clause it negates.

        """
        sequence = create.element
        text = "CREATE SEQUENCE %s" % \
            self.preparer.format_sequence(sequence)
        if sequence.increment is not None:
            text += " INCREMENT BY %d" % sequence.increment
        if sequence.start is not None:
            text += " START WITH %d" % sequence.start
        if sequence.minvalue is not None:
            text += " MINVALUE %d" % sequence.minvalue
        if sequence.maxvalue is not None:
            text += " MAXVALUE %d" % sequence.maxvalue
        # flags: test truthiness rather than "is not None", otherwise
        # e.g. cycle=False would still render CYCLE
        if sequence.nominvalue:
            text += " NO MINVALUE"
        if sequence.nomaxvalue:
            text += " NO MAXVALUE"
        if sequence.cycle:
            text += " CYCLE"
        return text

    def visit_drop_sequence(self, drop):
        """Render ``DROP SEQUENCE <name>``."""
        return "DROP SEQUENCE %s" % \
            self.preparer.format_sequence(drop.element)

    def visit_drop_constraint(self, drop):
        """Render ``ALTER TABLE <table> DROP CONSTRAINT <name>``.

        :raises CompileError: if the constraint has no usable name.

        """
        constraint = drop.element
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
        else:
            formatted_name = None

        if formatted_name is None:
            raise exc.CompileError(
                "Can't emit DROP CONSTRAINT for constraint %r; "
                "it has no name" % drop.element)
        return "ALTER TABLE %s DROP CONSTRAINT %s%s" % (
            self.preparer.format_table(drop.element.table),
            formatted_name,
            " CASCADE" if drop.cascade else ""
        )

    def get_column_specification(self, column, **kwargs):
        """Render ``<name> <type> [DEFAULT ...] [NOT NULL]`` for a column."""
        colspec = self.preparer.format_column(column) + " " + \
            self.dialect.type_compiler.process(
                column.type, type_expression=column)
        default = self.get_column_default_string(column)
        if default is not None:
            colspec += " DEFAULT " + default

        if not column.nullable:
            colspec += " NOT NULL"
        return colspec

    def create_table_suffix(self, table):
        """Hook: text placed between the table name and the opening paren."""
        return ''

    def post_create_table(self, table):
        """Hook: text placed right after the closing paren of CREATE TABLE."""
        return ''

    def get_column_default_string(self, column):
        """Return the rendered server default of a column, or ``None``.

        Only a ``DefaultClause`` server default is rendered: a plain
        string is rendered as a string literal, anything else is
        compiled as a SQL expression with literal binds.

        """
        server_default = column.server_default
        if not isinstance(server_default, schema.DefaultClause):
            return None

        if isinstance(server_default.arg, util.string_types):
            return self.sql_compiler.render_literal_value(
                server_default.arg, sqltypes.STRINGTYPE)
        return self.sql_compiler.process(
            server_default.arg, literal_binds=True)

    def _constraint_name_clause(self, constraint):
        """Return ``"CONSTRAINT <name> "`` for a constraint, or ``""``.

        The empty string is returned both when the constraint has no
        name and when the preparer's ``format_constraint`` yields
        ``None``, so callers never render ``CONSTRAINT None``.

        """
        if constraint.name is None:
            return ""
        formatted_name = self.preparer.format_constraint(constraint)
        if formatted_name is None:
            return ""
        return "CONSTRAINT %s " % formatted_name

    def visit_check_constraint(self, constraint):
        """Render a table-level ``CHECK`` constraint."""
        text = self._constraint_name_clause(constraint)
        text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext,
                                                         include_table=False,
                                                         literal_binds=True)
        text += self.define_constraint_deferrability(constraint)
        return text

    def visit_column_check_constraint(self, constraint):
        """Render a column-level ``CHECK`` constraint.

        The expression is compiled through :attr:`.sql_compiler` with
        the same flags as the table-level variant, rather than being
        interpolated with plain string formatting.

        """
        text = self._constraint_name_clause(constraint)
        text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext,
                                                         include_table=False,
                                                         literal_binds=True)
        text += self.define_constraint_deferrability(constraint)
        return text

    def visit_primary_key_constraint(self, constraint):
        """Render a ``PRIMARY KEY`` constraint; ``''`` if it has no columns.

        For an implicitly generated constraint the columns are taken
        from ``columns_autoinc_first``, otherwise from ``columns``.

        """
        if len(constraint) == 0:
            return ''
        text = self._constraint_name_clause(constraint)
        text += "PRIMARY KEY "
        if constraint._implicit_generated:
            pk_columns = constraint.columns_autoinc_first
        else:
            pk_columns = constraint.columns
        text += "(%s)" % ', '.join(self.preparer.quote(c.name)
                                   for c in pk_columns)
        text += self.define_constraint_deferrability(constraint)
        return text

    def visit_foreign_key_constraint(self, constraint):
        """Render a ``FOREIGN KEY ... REFERENCES ...`` constraint.

        The referenced table is taken from the first element of the
        constraint.

        """
        preparer = self.preparer
        text = self._constraint_name_clause(constraint)
        remote_table = list(constraint.elements)[0].column.table
        text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
            ', '.join(preparer.quote(f.parent.name)
                      for f in constraint.elements),
            self.define_constraint_remote_table(
                constraint, remote_table, preparer),
            ', '.join(preparer.quote(f.column.name)
                      for f in constraint.elements)
        )
        text += self.define_constraint_match(constraint)
        text += self.define_constraint_cascades(constraint)
        text += self.define_constraint_deferrability(constraint)
        return text

    def define_constraint_remote_table(self, constraint, table, preparer):
        """Format the remote table clause of a CREATE CONSTRAINT clause."""

        return preparer.format_table(table)

    def visit_unique_constraint(self, constraint):
        """Render a ``UNIQUE`` constraint; ``''`` if it has no columns."""
        if len(constraint) == 0:
            return ''
        text = self._constraint_name_clause(constraint)
        text += "UNIQUE (%s)" % (
            ', '.join(self.preparer.quote(c.name)
                      for c in constraint))
        text += self.define_constraint_deferrability(constraint)
        return text

    def define_constraint_cascades(self, constraint):
        """Render the ``ON DELETE`` / ``ON UPDATE`` clauses, if any."""
        text = ""
        if constraint.ondelete is not None:
            text += " ON DELETE %s" % constraint.ondelete
        if constraint.onupdate is not None:
            text += " ON UPDATE %s" % constraint.onupdate
        return text

    def define_constraint_deferrability(self, constraint):
        """Render the ``[NOT] DEFERRABLE`` / ``INITIALLY`` clauses, if any."""
        text = ""
        if constraint.deferrable is not None:
            if constraint.deferrable:
                text += " DEFERRABLE"
            else:
                text += " NOT DEFERRABLE"
        if constraint.initially is not None:
            text += " INITIALLY %s" % constraint.initially
        return text

    def define_constraint_match(self, constraint):
        """Render the ``MATCH`` clause of a foreign key, if any."""
        text = ""
        if constraint.match is not None:
            text += " MATCH %s" % constraint.match
        return text
|
|
|
|
|
|
class GenericTypeCompiler(TypeCompiler):
    """Render generic type specification strings.

    The upper-case ``visit_*`` methods render one specific SQL type
    name; the lower-case ones stand for generic types and delegate to
    an upper-case method.

    """

    # -- specific SQL types ------------------------------------------

    def visit_FLOAT(self, type_, **kw):
        return "FLOAT"

    def visit_REAL(self, type_, **kw):
        return "REAL"

    def visit_NUMERIC(self, type_, **kw):
        """Render NUMERIC, with precision and scale when they are set."""
        if type_.precision is None:
            return "NUMERIC"
        if type_.scale is None:
            return "NUMERIC(%(precision)s)" % \
                {'precision': type_.precision}
        return "NUMERIC(%(precision)s, %(scale)s)" % \
            {'precision': type_.precision,
             'scale': type_.scale}

    def visit_DECIMAL(self, type_, **kw):
        """Render DECIMAL, with precision and scale when they are set."""
        if type_.precision is None:
            return "DECIMAL"
        if type_.scale is None:
            return "DECIMAL(%(precision)s)" % \
                {'precision': type_.precision}
        return "DECIMAL(%(precision)s, %(scale)s)" % \
            {'precision': type_.precision,
             'scale': type_.scale}

    def visit_INTEGER(self, type_, **kw):
        return "INTEGER"

    def visit_SMALLINT(self, type_, **kw):
        return "SMALLINT"

    def visit_BIGINT(self, type_, **kw):
        return "BIGINT"

    def visit_TIMESTAMP(self, type_, **kw):
        return 'TIMESTAMP'

    def visit_DATETIME(self, type_, **kw):
        return "DATETIME"

    def visit_DATE(self, type_, **kw):
        return "DATE"

    def visit_TIME(self, type_, **kw):
        return "TIME"

    def visit_CLOB(self, type_, **kw):
        return "CLOB"

    def visit_NCLOB(self, type_, **kw):
        return "NCLOB"

    def _render_string_type(self, type_, name):
        """Render a string type name with optional length and collation."""
        pieces = [name]
        if type_.length:
            pieces.append("(%d)" % type_.length)
        if type_.collation:
            pieces.append(' COLLATE "%s"' % type_.collation)
        return "".join(pieces)

    def visit_CHAR(self, type_, **kw):
        return self._render_string_type(type_, "CHAR")

    def visit_NCHAR(self, type_, **kw):
        return self._render_string_type(type_, "NCHAR")

    def visit_VARCHAR(self, type_, **kw):
        return self._render_string_type(type_, "VARCHAR")

    def visit_NVARCHAR(self, type_, **kw):
        return self._render_string_type(type_, "NVARCHAR")

    def visit_TEXT(self, type_, **kw):
        return self._render_string_type(type_, "TEXT")

    def visit_BLOB(self, type_, **kw):
        return "BLOB"

    def visit_BINARY(self, type_, **kw):
        """Render BINARY, followed by ``(length)`` when a length is set."""
        length_text = ("(%d)" % type_.length) if type_.length else ""
        return "BINARY" + length_text

    def visit_VARBINARY(self, type_, **kw):
        """Render VARBINARY, followed by ``(length)`` when a length is set."""
        length_text = ("(%d)" % type_.length) if type_.length else ""
        return "VARBINARY" + length_text

    def visit_BOOLEAN(self, type_, **kw):
        return "BOOLEAN"

    # -- generic types, mapped onto the specific ones above ----------

    def visit_large_binary(self, type_, **kw):
        return self.visit_BLOB(type_, **kw)

    def visit_boolean(self, type_, **kw):
        return self.visit_BOOLEAN(type_, **kw)

    def visit_time(self, type_, **kw):
        return self.visit_TIME(type_, **kw)

    def visit_datetime(self, type_, **kw):
        return self.visit_DATETIME(type_, **kw)

    def visit_date(self, type_, **kw):
        return self.visit_DATE(type_, **kw)

    def visit_big_integer(self, type_, **kw):
        return self.visit_BIGINT(type_, **kw)

    def visit_small_integer(self, type_, **kw):
        return self.visit_SMALLINT(type_, **kw)

    def visit_integer(self, type_, **kw):
        return self.visit_INTEGER(type_, **kw)

    def visit_real(self, type_, **kw):
        return self.visit_REAL(type_, **kw)

    def visit_float(self, type_, **kw):
        return self.visit_FLOAT(type_, **kw)

    def visit_numeric(self, type_, **kw):
        return self.visit_NUMERIC(type_, **kw)

    def visit_string(self, type_, **kw):
        return self.visit_VARCHAR(type_, **kw)

    def visit_unicode(self, type_, **kw):
        return self.visit_VARCHAR(type_, **kw)

    def visit_text(self, type_, **kw):
        return self.visit_TEXT(type_, **kw)

    def visit_unicode_text(self, type_, **kw):
        return self.visit_TEXT(type_, **kw)

    def visit_enum(self, type_, **kw):
        return self.visit_VARCHAR(type_, **kw)

    # -- special cases -----------------------------------------------

    def visit_null(self, type_, **kw):
        """Always raise: a null type cannot be rendered as DDL."""
        raise exc.CompileError("Can't generate DDL for %r; "
                               "did you forget to specify a "
                               "type on this Column?" % type_)

    def visit_type_decorator(self, type_, **kw):
        """Render the type engine the decorator yields for this dialect."""
        impl_type = type_.type_engine(self.dialect)
        return self.process(impl_type, **kw)

    def visit_user_defined(self, type_, **kw):
        """Render a user-defined type via its ``get_col_spec()``."""
        return type_.get_col_spec(**kw)
|
|
|
|
|
|
class StrSQLTypeCompiler(GenericTypeCompiler):
    """Type compiler that has a fallback for any unknown ``visit_*`` name.

    A ``visit_*`` attribute that normal lookup does not find resolves
    to :meth:`._visit_unknown`, which renders the type's class name.

    """

    def __getattr__(self, key):
        # only reached when normal attribute lookup has failed
        if not key.startswith("visit_"):
            raise AttributeError(key)
        return self._visit_unknown

    def _visit_unknown(self, type_, **kw):
        """Render the class name of the given type object."""
        type_name = type_.__class__.__name__
        return "%s" % type_name
|
|
|
|
|
|
class IdentifierPreparer(object):

    """Handle quoting and case-folding of identifiers based on options."""

    reserved_words = RESERVED_WORDS

    legal_characters = LEGAL_CHARACTERS

    illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS

    schema_for_object = schema._schema_getter(None)

    def __init__(self, dialect, initial_quote='"',
                 final_quote=None, escape_quote='"', omit_schema=False):
        """Construct a new ``IdentifierPreparer`` object.

        initial_quote
          Character that begins a delimited identifier.

        final_quote
          Character that ends a delimited identifier. Defaults to
          `initial_quote`.

        escape_quote
          Character that is doubled when it occurs inside an identifier
          being quoted.

        omit_schema
          Prevent prepending schema name. Useful for databases that do
          not support schemas.
        """

        self.dialect = dialect
        self.initial_quote = initial_quote
        self.final_quote = final_quote or self.initial_quote
        self.escape_quote = escape_quote
        self.escape_to_quote = self.escape_quote * 2
        self.omit_schema = omit_schema
        # cache of quote() results, keyed by the unquoted identifier
        self._strings = {}

    def _with_schema_translate(self, schema_translate_map):
        """Return a copy of this preparer using the given schema map.

        The copy is made without calling ``__init__``; it shares this
        preparer's attribute values and differs only in
        ``schema_for_object``.

        """
        cls = self.__class__
        clone = cls.__new__(cls)
        clone.__dict__.update(self.__dict__)
        clone.schema_for_object = schema._schema_getter(schema_translate_map)
        return clone

    def _escape_identifier(self, value):
        """Escape an identifier.

        Subclasses should override this to provide database-dependent
        escaping behavior.
        """

        escaped = value.replace(self.escape_quote, self.escape_to_quote)
        return escaped

    def _unescape_identifier(self, value):
        """Canonicalize an escaped identifier.

        Subclasses should override this to provide database-dependent
        unescaping behavior that reverses _escape_identifier.
        """

        unescaped = value.replace(self.escape_to_quote, self.escape_quote)
        return unescaped

    def quote_identifier(self, value):
        """Quote an identifier.

        Subclasses should override this to provide database-dependent
        quoting behavior.
        """

        escaped = self._escape_identifier(value)
        return self.initial_quote + escaped + self.final_quote

    def _requires_quotes(self, value):
        """Return True if the given identifier requires quoting."""
        lc_value = value.lower()
        if lc_value in self.reserved_words:
            return True
        if value[0] in self.illegal_initial_characters:
            return True
        if not self.legal_characters.match(util.text_type(value)):
            return True
        # anything that is not all lower case needs quoting
        return lc_value != value

    def quote_schema(self, schema, force=None):
        """Conditionally quote a schema.

        Subclasses can override this to provide database-dependent
        quoting behavior for schema names.

        the 'force' flag should be considered deprecated.

        """
        return self.quote(schema, force)

    def quote(self, ident, force=None):
        """Conditionally quote an identifier.

        the 'force' flag should be considered deprecated.
        """

        # the argument is replaced by the identifier's own "quote"
        # attribute, if it has one
        force = getattr(ident, "quote", None)

        if force is None:
            cache = self._strings
            if ident not in cache:
                if self._requires_quotes(ident):
                    cache[ident] = self.quote_identifier(ident)
                else:
                    cache[ident] = ident
            return cache[ident]

        if force:
            return self.quote_identifier(ident)
        return ident

    def format_sequence(self, sequence, use_schema=True):
        """Prepare a quoted sequence name, schema-qualified if applicable."""
        name = self.quote(sequence.name)

        effective_schema = self.schema_for_object(sequence)

        with_schema = (not self.omit_schema and use_schema and
                       effective_schema is not None)
        if with_schema:
            name = self.quote_schema(effective_schema) + "." + name
        return name

    def format_label(self, label, name=None):
        """Prepare a quoted label name; ``name`` overrides ``label.name``."""
        return self.quote(name or label.name)

    def format_alias(self, alias, name=None):
        """Prepare a quoted alias name; ``name`` overrides ``alias.name``."""
        return self.quote(name or alias.name)

    def format_savepoint(self, savepoint, name=None):
        """Prepare a savepoint name, quoting it only when required.

        The result is deliberately not stored in the quoting cache.

        """
        # Running the savepoint name through quoting is unnecessary
        # for all known dialects.  This is here to support potential
        # third party use cases
        ident = name or savepoint.ident
        if not self._requires_quotes(ident):
            return ident
        return self.quote_identifier(ident)

    @util.dependencies("sqlalchemy.sql.naming")
    def format_constraint(self, naming, constraint):
        """Prepare a quoted constraint name.

        A deferred name is resolved through the naming convention
        first.  Returns ``None`` when that yields no name and the
        constraint's name is a ``_defer_none_name``.

        """
        constraint_name = constraint.name
        if isinstance(constraint_name, elements._defer_name):
            resolved = naming._constraint_name_for_table(
                constraint, constraint.table)
            if resolved:
                return self.quote(resolved)
            elif isinstance(constraint.name, elements._defer_none_name):
                return None
        return self.quote(constraint.name)

    def format_table(self, table, use_schema=True, name=None):
        """Prepare a quoted table and schema name."""

        if name is None:
            name = table.name
        result = self.quote(name)

        effective_schema = self.schema_for_object(table)

        with_schema = not self.omit_schema and use_schema \
            and effective_schema
        if with_schema:
            result = self.quote_schema(effective_schema) + "." + result
        return result

    def format_schema(self, name, quote=None):
        """Prepare a quoted schema name."""

        return self.quote(name, quote)

    def format_column(self, column, use_table=False,
                      name=None, table_name=None):
        """Prepare a quoted column name."""

        if name is None:
            name = column.name

        # literal textual elements get stuck into ColumnClause a lot,
        # which shouldn't get quoted
        is_literal = getattr(column, 'is_literal', False)

        if use_table:
            qualifier = self.format_table(
                column.table, use_schema=False,
                name=table_name)
            if is_literal:
                return qualifier + '.' + name
            return qualifier + "." + self.quote(name)

        if is_literal:
            return name
        return self.quote(name)

    def format_table_seq(self, table, use_schema=True):
        """Format table name and schema as a tuple."""

        # Dialects with more levels in their fully qualified references
        # ('database', 'owner', etc.) could override this and return
        # a longer sequence.

        effective_schema = self.schema_for_object(table)

        with_schema = not self.omit_schema and use_schema and \
            effective_schema
        if not with_schema:
            return (self.format_table(table, use_schema=False), )
        return (self.quote_schema(effective_schema),
                self.format_table(table, use_schema=False))

    @util.memoized_property
    def _r_identifiers(self):
        """Compiled regex used by :meth:`.unformat_identifiers`.

        Built from the regex-escaped initial quote, final quote and
        escaped final quote of this preparer.

        """
        raw_tokens = (self.initial_quote, self.final_quote,
                      self._escape_identifier(self.final_quote))
        initial, final, escaped_final = [re.escape(s) for s in raw_tokens]
        pattern = (
            r'(?:'
            r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s'
            r'|([^\.]+))(?=\.|$))+' %
            {'initial': initial,
             'final': final,
             'escaped': escaped_final})
        return re.compile(pattern)

    def unformat_identifiers(self, identifiers):
        """Unpack 'schema.table.column'-like strings into components."""

        matches = self._r_identifiers.findall(identifiers)
        # each match is (quoted-part, plain-part); at most one is non-empty
        return [self._unescape_identifier(quoted or plain)
                for quoted, plain in matches]
|