- apply pep8 to compiler.py

- deprecate Compiled.compile() - have __init__ do compilation
if statement is present.
This commit is contained in:
Mike Bayer
2010-12-21 16:34:00 -05:00
parent b0f48ca2a9
commit dff4e0591e
6 changed files with 234 additions and 137 deletions
+1 -1
View File
@@ -677,8 +677,8 @@ class MSSQLCompiler(compiler.SQLCompiler):
})
def __init__(self, *args, **kwargs):
super(MSSQLCompiler, self).__init__(*args, **kwargs)
self.tablealiases = {}
super(MSSQLCompiler, self).__init__(*args, **kwargs)
def visit_now_func(self, fn, **kw):
return "CURRENT_TIMESTAMP"
+7 -4
View File
@@ -675,14 +675,17 @@ class Compiled(object):
"""
self.dialect = dialect
self.statement = statement
self.bind = bind
self.can_execute = statement.supports_execution
if statement is not None:
self.statement = statement
self.can_execute = statement.supports_execution
self.string = self.process(self.statement)
@util.deprecated("0.7", ":class:`.Compiled` objects now compile "
"within the constructor.")
def compile(self):
"""Produce the internal string representation of this element."""
self.string = self.process(self.statement)
pass
@property
def sql_compiler(self):
+217 -122
View File
@@ -24,7 +24,8 @@ To generate user-defined SQL strings, see
import re
from sqlalchemy import schema, engine, util, exc
from sqlalchemy.sql import operators, functions, util as sql_util, visitors
from sqlalchemy.sql import operators, functions, util as sql_util, \
visitors
from sqlalchemy.sql import expression as sql
import decimal
@@ -197,7 +198,8 @@ class SQLCompiler(engine.Compiled):
# driver/DB enforces this
ansi_bind_rules = False
def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
def __init__(self, dialect, statement, column_keys=None,
inline=False, **kwargs):
"""Construct a new ``DefaultCompiler`` object.
dialect
@@ -211,47 +213,49 @@ class SQLCompiler(engine.Compiled):
statement.
"""
engine.Compiled.__init__(self, dialect, statement, **kwargs)
self.column_keys = column_keys
# compile INSERT/UPDATE defaults/sequences inlined (no pre-execute)
# compile INSERT/UPDATE defaults/sequences inlined (no pre-
# execute)
self.inline = inline or getattr(statement, 'inline', False)
# a dictionary of bind parameter keys to _BindParamClause instances.
# a dictionary of bind parameter keys to _BindParamClause
# instances.
self.binds = {}
# a dictionary of _BindParamClause instances to "compiled" names that are
# actually present in the generated SQL
# a dictionary of _BindParamClause instances to "compiled" names
# that are actually present in the generated SQL
self.bind_names = util.column_dict()
# stack which keeps track of nested SELECT statements
self.stack = []
# relates label names in the final SQL to
# a tuple of local column/label name, ColumnElement object (if any) and TypeEngine.
# ResultProxy uses this for type processing and column targeting
# relates label names in the final SQL to a tuple of local
# column/label name, ColumnElement object (if any) and
# TypeEngine. ResultProxy uses this for type processing and
# column targeting
self.result_map = {}
# true if the paramstyle is positional
self.positional = self.dialect.positional
self.positional = dialect.positional
if self.positional:
self.positiontup = []
self.bindtemplate = BIND_TEMPLATES[self.dialect.paramstyle]
self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
# an IdentifierPreparer that formats the quoting of identifiers
self.preparer = self.dialect.identifier_preparer
self.preparer = dialect.identifier_preparer
self.label_length = dialect.label_length \
or dialect.max_identifier_length
self.label_length = self.dialect.label_length or self.dialect.max_identifier_length
# a map which tracks "anonymous" identifiers that are
# created on the fly here
# a map which tracks "anonymous" identifiers that are created on
# the fly here
self.anon_map = util.PopulateDict(self._process_anon)
# a map which tracks "truncated" names based on dialect.label_length
# or dialect.max_identifier_length
# a map which tracks "truncated" names based on
# dialect.label_length or dialect.max_identifier_length
self.truncated_names = {}
engine.Compiled.__init__(self, dialect, statement, **kwargs)
@util.memoized_property
@@ -284,13 +288,13 @@ class SQLCompiler(engine.Compiled):
elif bindparam.required:
if _group_number:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r, "
"in parameter group %d" %
(bindparam.key, _group_number))
"A value is required for bind parameter %r, "
"in parameter group %d" %
(bindparam.key, _group_number))
else:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r"
% bindparam.key)
"A value is required for bind parameter %r"
% bindparam.key)
elif bindparam.callable:
pd[name] = bindparam.callable()
else:
@@ -311,7 +315,8 @@ class SQLCompiler(engine.Compiled):
""")
def default_from(self):
"""Called when a SELECT statement has no froms, and no FROM clause is to be appended.
"""Called when a SELECT statement has no froms, and no FROM clause is
to be appended.
Gives Oracle a chance to tack on a ``FROM DUAL`` to the string output.
@@ -328,12 +333,15 @@ class SQLCompiler(engine.Compiled):
# or ORDER BY clause of a select. dialect-specific compilers
# can modify this behavior.
if within_columns_clause and not within_label_clause:
labelname = isinstance(label.name, sql._generated_label) and \
self._truncated_identifier("colident", label.name) or label.name
if isinstance(label.name, sql._generated_label):
labelname = self._truncated_identifier("colident", label.name)
else:
labelname = label.name
if result_map is not None:
result_map[labelname.lower()] = \
(label.name, (label, label.element, labelname), label.type)
(label.name, (label, label.element, labelname),\
label.type)
return self.process(label.element,
within_columns_clause=True,
@@ -373,11 +381,12 @@ class SQLCompiler(engine.Compiled):
else:
schema_prefix = ''
tablename = column.table.name
tablename = isinstance(tablename, sql._generated_label) and \
self._truncated_identifier("alias", tablename) or tablename
if isinstance(tablename, sql._generated_label):
tablename = self._truncated_identifier("alias", tablename)
return schema_prefix + \
self.preparer.quote(tablename, column.table.quote) + "." + name
self.preparer.quote(tablename, column.table.quote) + \
"." + name
def escape_literal_column(self, text):
"""provide escaping for the literal_column() construct."""
@@ -411,7 +420,8 @@ class SQLCompiler(engine.Compiled):
# un-escape any \:params
return BIND_PARAMS_ESC.sub(lambda m: m.group(1),
BIND_PARAMS.sub(do_bindparam, self.post_process_text(textclause.text))
BIND_PARAMS.sub(do_bindparam,
self.post_process_text(textclause.text))
)
def visit_null(self, null, **kwargs):
@@ -423,8 +433,11 @@ class SQLCompiler(engine.Compiled):
sep = " "
else:
sep = OPERATORS[clauselist.operator]
return sep.join(s for s in (self.process(c, **kwargs) for c in clauselist.clauses)
if s is not None)
return sep.join(
s for s in
(self.process(c, **kwargs)
for c in clauselist.clauses)
if s is not None)
def visit_case(self, clause, **kwargs):
x = "CASE "
@@ -440,11 +453,13 @@ class SQLCompiler(engine.Compiled):
def visit_cast(self, cast, **kwargs):
return "CAST(%s AS %s)" % \
(self.process(cast.clause, **kwargs), self.process(cast.typeclause, **kwargs))
(self.process(cast.clause, **kwargs),
self.process(cast.typeclause, **kwargs))
def visit_extract(self, extract, **kwargs):
field = self.extract_map.get(extract.field, extract.field)
return "EXTRACT(%s FROM %s)" % (field, self.process(extract.expr, **kwargs))
return "EXTRACT(%s FROM %s)" % (field,
self.process(extract.expr, **kwargs))
def visit_function(self, func, result_map=None, **kwargs):
if result_map is not None:
@@ -461,7 +476,8 @@ class SQLCompiler(engine.Compiled):
def function_argspec(self, func, **kwargs):
return self.process(func.clause_expr, **kwargs)
def visit_compound_select(self, cs, asfrom=False, parens=True, compound_index=1, **kwargs):
def visit_compound_select(self, cs, asfrom=False,
parens=True, compound_index=1, **kwargs):
entry = self.stack and self.stack[-1] or {}
self.stack.append({'from':entry.get('from', None), 'iswrapper':True})
@@ -478,7 +494,8 @@ class SQLCompiler(engine.Compiled):
text += " GROUP BY " + group_by
text += self.order_by_clause(cs, **kwargs)
text += (cs._limit is not None or cs._offset is not None) and self.limit_clause(cs) or ""
text += (cs._limit is not None or cs._offset is not None) and \
self.limit_clause(cs) or ""
self.stack.pop(-1)
if asfrom and parens:
@@ -530,8 +547,8 @@ class SQLCompiler(engine.Compiled):
def visit_ilike_op(self, binary, **kw):
escape = binary.modifiers.get("escape", None)
return 'lower(%s) LIKE lower(%s)' % (
self.process(binary.left, **kw),
self.process(binary.right, **kw)) \
self.process(binary.left, **kw),
self.process(binary.right, **kw)) \
+ (escape and
(' ESCAPE ' + self.render_literal_value(escape, None))
or '')
@@ -539,8 +556,8 @@ class SQLCompiler(engine.Compiled):
def visit_notilike_op(self, binary, **kw):
escape = binary.modifiers.get("escape", None)
return 'lower(%s) NOT LIKE lower(%s)' % (
self.process(binary.left, **kw),
self.process(binary.right, **kw)) \
self.process(binary.left, **kw),
self.process(binary.right, **kw)) \
+ (escape and
(' ESCAPE ' + self.render_literal_value(escape, None))
or '')
@@ -563,7 +580,8 @@ class SQLCompiler(engine.Compiled):
if bindparam.value is None:
raise exc.CompileError("Bind parameter without a "
"renderable value not allowed here.")
return self.render_literal_bindparam(bindparam, within_columns_clause=True, **kwargs)
return self.render_literal_bindparam(bindparam,
within_columns_clause=True, **kwargs)
name = self._truncate_bindparam(bindparam)
if name in self.binds:
@@ -572,17 +590,19 @@ class SQLCompiler(engine.Compiled):
if existing.unique or bindparam.unique:
raise exc.CompileError(
"Bind parameter '%s' conflicts with "
"unique bind parameter of the same name" % bindparam.key
"unique bind parameter of the same name" %
bindparam.key
)
elif getattr(existing, '_is_crud', False):
raise exc.CompileError(
"bindparam() name '%s' is reserved "
"for automatic usage in the VALUES or SET clause of this "
"insert/update statement. Please use a "
"name other than column name when using bindparam() "
"with insert() or update() (for example, 'b_%s')."
% (bindparam.key, bindparam.key)
)
"bindparam() name '%s' is reserved "
"for automatic usage in the VALUES or SET "
"clause of this "
"insert/update statement. Please use a "
"name other than column name when using bindparam() "
"with insert() or update() (for example, 'b_%s')."
% (bindparam.key, bindparam.key)
)
self.binds[bindparam.key] = self.binds[name] = bindparam
return self.bindparam_string(name)
@@ -614,15 +634,17 @@ class SQLCompiler(engine.Compiled):
elif isinstance(value, decimal.Decimal):
return str(value)
else:
raise NotImplementedError("Don't know how to literal-quote value %r" % value)
raise NotImplementedError(
"Don't know how to literal-quote value %r" % value)
def _truncate_bindparam(self, bindparam):
if bindparam in self.bind_names:
return self.bind_names[bindparam]
bind_name = bindparam.key
bind_name = isinstance(bind_name, sql._generated_label) and \
self._truncated_identifier("bindparam", bind_name) or bind_name
if isinstance(bind_name, sql._generated_label):
bind_name = self._truncated_identifier("bindparam", bind_name)
# add to bind_names for translation
self.bind_names[bindparam] = bind_name
@@ -636,7 +658,8 @@ class SQLCompiler(engine.Compiled):
if len(anonname) > self.label_length:
counter = self.truncated_names.get(ident_class, 1)
truncname = anonname[0:max(self.label_length - 6, 0)] + "_" + hex(counter)[2:]
truncname = anonname[0:max(self.label_length - 6, 0)] + \
"_" + hex(counter)[2:]
self.truncated_names[ident_class] = counter + 1
else:
truncname = anonname
@@ -659,14 +682,19 @@ class SQLCompiler(engine.Compiled):
else:
return self.bindtemplate % {'name':name}
def visit_alias(self, alias, asfrom=False, ashint=False, fromhints=None, **kwargs):
def visit_alias(self, alias, asfrom=False, ashint=False,
fromhints=None, **kwargs):
if asfrom or ashint:
alias_name = isinstance(alias.name, sql._generated_label) and \
self._truncated_identifier("alias", alias.name) or alias.name
if isinstance(alias.name, sql._generated_label):
alias_name = self._truncated_identifier("alias", alias.name)
else:
alias_name = alias.name
if ashint:
return self.preparer.format_alias(alias, alias_name)
elif asfrom:
ret = self.process(alias.original, asfrom=True, **kwargs) + " AS " + \
ret = self.process(alias.original, asfrom=True, **kwargs) + \
" AS " + \
self.preparer.format_alias(alias, alias_name)
if fromhints and alias in fromhints:
@@ -695,8 +723,10 @@ class SQLCompiler(engine.Compiled):
not isinstance(column.table, sql.Select):
return _CompileLabel(column, sql._generated_label(column.name))
elif not isinstance(column,
(sql._UnaryExpression, sql._TextClause, sql._BindParamClause)) \
and (not hasattr(column, 'name') or isinstance(column, sql.Function)):
(sql._UnaryExpression, sql._TextClause,
sql._BindParamClause)) \
and (not hasattr(column, 'name') or \
isinstance(column, sql.Function)):
return _CompileLabel(column, column.anon_label)
else:
return column
@@ -719,12 +749,13 @@ class SQLCompiler(engine.Compiled):
correlate_froms = set(sql._from_objects(*froms))
# TODO: might want to propagate existing froms for select(select(select))
# where innermost select should correlate to outermost
# if existingfroms:
# correlate_froms = correlate_froms.union(existingfroms)
# TODO: might want to propagate existing froms for
# select(select(select)) where innermost select should correlate
# to outermost if existingfroms: correlate_froms =
# correlate_froms.union(existingfroms)
self.stack.append({'from':correlate_froms, 'iswrapper':iswrapper})
self.stack.append({'from': correlate_froms, 'iswrapper'
: iswrapper})
if compound_index==1 and not entry or entry.get('iswrapper', False):
column_clause_args = {'result_map':self.result_map}
@@ -747,7 +778,8 @@ class SQLCompiler(engine.Compiled):
if select._hints:
byfrom = dict([
(from_, hinttext % {'name':self.process(from_, ashint=True)})
(from_, hinttext % {
'name':self.process(from_, ashint=True)})
for (from_, dialect), hinttext in
select._hints.iteritems()
if dialect in ('*', self.dialect.name)
@@ -757,7 +789,9 @@ class SQLCompiler(engine.Compiled):
text += hint_text + " "
if select._prefixes:
text += " ".join(self.process(x, **kwargs) for x in select._prefixes) + " "
text += " ".join(
self.process(x, **kwargs)
for x in select._prefixes) + " "
text += self.get_select_precolumns(select)
text += ', '.join(inner_columns)
@@ -806,8 +840,8 @@ class SQLCompiler(engine.Compiled):
return text
def get_select_precolumns(self, select):
"""Called when building a ``SELECT`` statement, position is just before
column list.
"""Called when building a ``SELECT`` statement, position is just
before column list.
"""
return select._distinct and "DISTINCT " or ""
@@ -835,11 +869,14 @@ class SQLCompiler(engine.Compiled):
text += " OFFSET " + self.process(sql.literal(select._offset))
return text
def visit_table(self, table, asfrom=False, ashint=False, fromhints=None, **kwargs):
def visit_table(self, table, asfrom=False, ashint=False,
fromhints=None, **kwargs):
if asfrom or ashint:
if getattr(table, "schema", None):
ret = self.preparer.quote_schema(table.schema, table.quote_schema) + \
"." + self.preparer.quote(table.name, table.quote)
ret = self.preparer.quote_schema(table.schema,
table.quote_schema) + \
"." + self.preparer.quote(table.name,
table.quote)
else:
ret = self.preparer.quote(table.name, table.quote)
if fromhints and table in fromhints:
@@ -887,7 +924,8 @@ class SQLCompiler(engine.Compiled):
if self.returning or insert_stmt._returning:
self.returning = self.returning or insert_stmt._returning
returning_clause = self.returning_clause(insert_stmt, self.returning)
returning_clause = self.returning_clause(
insert_stmt, self.returning)
if self.returning_precedes_values:
text += " " + returning_clause
@@ -913,27 +951,31 @@ class SQLCompiler(engine.Compiled):
text += ' SET ' + \
', '.join(
self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1]
self.preparer.quote(c[0].name, c[0].quote) +
'=' + c[1]
for c in colparams
)
if update_stmt._returning:
self.returning = update_stmt._returning
if self.returning_precedes_values:
text += " " + self.returning_clause(update_stmt, update_stmt._returning)
text += " " + self.returning_clause(
update_stmt, update_stmt._returning)
if update_stmt._whereclause is not None:
text += " WHERE " + self.process(update_stmt._whereclause)
if self.returning and not self.returning_precedes_values:
text += " " + self.returning_clause(update_stmt, update_stmt._returning)
text += " " + self.returning_clause(
update_stmt, update_stmt._returning)
self.stack.pop(-1)
return text
def _create_crud_bind_param(self, col, value, required=False):
bindparam = sql.bindparam(col.key, value, type_=col.type, required=required)
bindparam = sql.bindparam(col.key, value,
type_=col.type, required=required)
bindparam._is_crud = True
if col.key in self.binds:
raise exc.CompileError(
@@ -952,8 +994,8 @@ class SQLCompiler(engine.Compiled):
"""create a set of tuples representing column/string pairs for use
in an INSERT or UPDATE statement.
Also generates the Compiled object's postfetch, prefetch, and returning
column collections, used for default handling and ultimately
Also generates the Compiled object's postfetch, prefetch, and
returning column collections, used for default handling and ultimately
populating the ResultProxy's prefetch_cols() and postfetch_cols()
collections.
@@ -967,7 +1009,8 @@ class SQLCompiler(engine.Compiled):
# compiled params - return binds for all columns
if self.column_keys is None and stmt.parameters is None:
return [
(c, self._create_crud_bind_param(c, None, required=True))
(c, self._create_crud_bind_param(c,
None, required=True))
for c in stmt.table.columns
]
@@ -980,7 +1023,8 @@ class SQLCompiler(engine.Compiled):
else:
parameters = dict((sql._column_as_key(key), required)
for key in self.column_keys
if not stmt.parameters or key not in stmt.parameters)
if not stmt.parameters or
key not in stmt.parameters)
if stmt.parameters is not None:
for k, v in stmt.parameters.iteritems():
@@ -1006,7 +1050,8 @@ class SQLCompiler(engine.Compiled):
if c.key in parameters:
value = parameters[c.key]
if sql._is_literal(value):
value = self._create_crud_bind_param(c, value, required=value is required)
value = self._create_crud_bind_param(
c, value, required=value is required)
else:
self.postfetch.append(c)
value = self.process(value.self_group())
@@ -1029,10 +1074,15 @@ class SQLCompiler(engine.Compiled):
values.append((c, proc))
self.returning.append(c)
elif c.default.is_clause_element:
values.append((c, self.process(c.default.arg.self_group())))
values.append(
(c,
self.process(c.default.arg.self_group()))
)
self.returning.append(c)
else:
values.append((c, self._create_crud_bind_param(c, None)))
values.append(
(c, self._create_crud_bind_param(c, None))
)
self.prefetch.append(c)
else:
self.returning.append(c)
@@ -1043,9 +1093,12 @@ class SQLCompiler(engine.Compiled):
self.dialect.supports_sequences or
not c.default.is_sequence
)
) or self.dialect.preexecute_autoincrement_sequences:
) or \
self.dialect.preexecute_autoincrement_sequences:
values.append((c, self._create_crud_bind_param(c, None)))
values.append(
(c, self._create_crud_bind_param(c, None))
)
self.prefetch.append(c)
elif c.default is not None:
@@ -1056,13 +1109,17 @@ class SQLCompiler(engine.Compiled):
if not c.primary_key:
self.postfetch.append(c)
elif c.default.is_clause_element:
values.append((c, self.process(c.default.arg.self_group())))
values.append(
(c, self.process(c.default.arg.self_group()))
)
if not c.primary_key:
# dont add primary key column to postfetch
self.postfetch.append(c)
else:
values.append((c, self._create_crud_bind_param(c, None)))
values.append(
(c, self._create_crud_bind_param(c, None))
)
self.prefetch.append(c)
elif c.server_default is not None:
if not c.primary_key:
@@ -1071,10 +1128,14 @@ class SQLCompiler(engine.Compiled):
elif self.isupdate:
if c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
values.append((c, self.process(c.onupdate.arg.self_group())))
values.append(
(c, self.process(c.onupdate.arg.self_group()))
)
self.postfetch.append(c)
else:
values.append((c, self._create_crud_bind_param(c, None)))
values.append(
(c, self._create_crud_bind_param(c, None))
)
self.prefetch.append(c)
elif c.server_onupdate is not None:
self.postfetch.append(c)
@@ -1089,13 +1150,15 @@ class SQLCompiler(engine.Compiled):
if delete_stmt._returning:
self.returning = delete_stmt._returning
if self.returning_precedes_values:
text += " " + self.returning_clause(delete_stmt, delete_stmt._returning)
text += " " + self.returning_clause(
delete_stmt, delete_stmt._returning)
if delete_stmt._whereclause is not None:
text += " WHERE " + self.process(delete_stmt._whereclause)
if self.returning and not self.returning_precedes_values:
text += " " + self.returning_clause(delete_stmt, delete_stmt._returning)
text += " " + self.returning_clause(
delete_stmt, delete_stmt._returning)
self.stack.pop(-1)
@@ -1105,17 +1168,19 @@ class SQLCompiler(engine.Compiled):
return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
def visit_rollback_to_savepoint(self, savepoint_stmt):
return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
return "ROLLBACK TO SAVEPOINT %s" % \
self.preparer.format_savepoint(savepoint_stmt)
def visit_release_savepoint(self, savepoint_stmt):
return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
return "RELEASE SAVEPOINT %s" % \
self.preparer.format_savepoint(savepoint_stmt)
class DDLCompiler(engine.Compiled):
@util.memoized_property
def sql_compiler(self):
return self.dialect.statement_compiler(self.dialect, self.statement)
return self.dialect.statement_compiler(self.dialect, None)
@property
def preparer(self):
@@ -1161,11 +1226,13 @@ class DDLCompiler(engine.Compiled):
separator = ", \n"
text += "\t" + self.get_column_specification(
column,
first_pk=column.primary_key and not first_pk
first_pk=column.primary_key and \
not first_pk
)
if column.primary_key:
first_pk = True
const = " ".join(self.process(constraint) for constraint in column.constraints)
const = " ".join(self.process(constraint) \
for constraint in column.constraints)
if const:
text += " " + const
@@ -1184,10 +1251,12 @@ class DDLCompiler(engine.Compiled):
if table.primary_key:
constraints.append(table.primary_key)
constraints.extend([c for c in table.constraints if c is not table.primary_key])
constraints.extend([c for c in table.constraints
if c is not table.primary_key])
return ", \n\t".join(p for p in
(self.process(constraint) for constraint in constraints
(self.process(constraint)
for constraint in constraints
if (
constraint._create_rule is None or
constraint._create_rule(self))
@@ -1230,7 +1299,8 @@ class DDLCompiler(engine.Compiled):
def visit_drop_index(self, drop):
index = drop.element
return "\nDROP INDEX " + \
self.preparer.quote(self._index_identifier(index.name), index.quote)
self.preparer.quote(
self._index_identifier(index.name), index.quote)
def visit_add_constraint(self, create):
preparer = self.preparer
@@ -1240,7 +1310,8 @@ class DDLCompiler(engine.Compiled):
)
def visit_create_sequence(self, create):
text = "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element)
text = "CREATE SEQUENCE %s" % \
self.preparer.format_sequence(create.element)
if create.element.increment is not None:
text += " INCREMENT BY %d" % create.element.increment
if create.element.start is not None:
@@ -1248,7 +1319,8 @@ class DDLCompiler(engine.Compiled):
return text
def visit_drop_sequence(self, drop):
return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
return "DROP SEQUENCE %s" % \
self.preparer.format_sequence(drop.element)
def visit_drop_constraint(self, drop):
preparer = self.preparer
@@ -1301,7 +1373,8 @@ class DDLCompiler(engine.Compiled):
return ''
text = ""
if constraint.name is not None:
text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint)
text += "CONSTRAINT %s " % \
self.preparer.format_constraint(constraint)
text += "PRIMARY KEY "
text += "(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
for c in constraint)
@@ -1318,7 +1391,8 @@ class DDLCompiler(engine.Compiled):
text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
', '.join(preparer.quote(f.parent.name, f.parent.quote)
for f in constraint._elements.values()),
self.define_constraint_remote_table(constraint, remote_table, preparer),
self.define_constraint_remote_table(
constraint, remote_table, preparer),
', '.join(preparer.quote(f.column.name, f.column.quote)
for f in constraint._elements.values())
)
@@ -1334,8 +1408,11 @@ class DDLCompiler(engine.Compiled):
def visit_unique_constraint(self, constraint):
text = ""
if constraint.name is not None:
text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint)
text += "UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint))
text += "CONSTRAINT %s " % \
self.preparer.format_constraint(constraint)
text += "UNIQUE (%s)" % (
', '.join(self.preparer.quote(c.name, c.quote)
for c in constraint))
text += self.define_constraint_deferrability(constraint)
return text
@@ -1373,9 +1450,12 @@ class GenericTypeCompiler(engine.TypeCompiler):
if type_.precision is None:
return "NUMERIC"
elif type_.scale is None:
return "NUMERIC(%(precision)s)" % {'precision': type_.precision}
return "NUMERIC(%(precision)s)" % \
{'precision': type_.precision}
else:
return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale}
return "NUMERIC(%(precision)s, %(scale)s)" % \
{'precision': type_.precision,
'scale' : type_.scale}
def visit_DECIMAL(self, type_):
return "DECIMAL"
@@ -1499,7 +1579,8 @@ class IdentifierPreparer(object):
Character that begins a delimited identifier.
final_quote
Character that ends a delimited identifier. Defaults to `initial_quote`.
Character that ends a delimited identifier. Defaults to
`initial_quote`.
omit_schema
Prevent prepending schema name. Useful for databases that do
@@ -1539,7 +1620,9 @@ class IdentifierPreparer(object):
quoting behavior.
"""
return self.initial_quote + self._escape_identifier(value) + self.final_quote
return self.initial_quote + \
self._escape_identifier(value) + \
self.final_quote
def _requires_quotes(self, value):
"""Return True if the given identifier requires quoting."""
@@ -1574,8 +1657,10 @@ class IdentifierPreparer(object):
def format_sequence(self, sequence, use_schema=True):
name = self.quote(sequence.name, sequence.quote)
if not self.omit_schema and use_schema and sequence.schema is not None:
name = self.quote_schema(sequence.schema, sequence.quote) + "." + name
if not self.omit_schema and use_schema and \
sequence.schema is not None:
name = self.quote_schema(sequence.schema, sequence.quote) + \
"." + name
return name
def format_label(self, label, name=None):
@@ -1596,24 +1681,33 @@ class IdentifierPreparer(object):
if name is None:
name = table.name
result = self.quote(name, table.quote)
if not self.omit_schema and use_schema and getattr(table, "schema", None):
result = self.quote_schema(table.schema, table.quote_schema) + "." + result
if not self.omit_schema and use_schema \
and getattr(table, "schema", None):
result = self.quote_schema(table.schema, table.quote_schema) + \
"." + result
return result
def format_column(self, column, use_table=False, name=None, table_name=None):
def format_column(self, column, use_table=False,
name=None, table_name=None):
"""Prepare a quoted column name."""
if name is None:
name = column.name
if not getattr(column, 'is_literal', False):
if use_table:
return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(name, column.quote)
return self.format_table(
column.table, use_schema=False,
name=table_name) + "." + \
self.quote(name, column.quote)
else:
return self.quote(name, column.quote)
else:
# literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted
# literal textual elements get stuck into ColumnClause alot,
# which shouldnt get quoted
if use_table:
return self.format_table(column.table, use_schema=False, name=table_name) + "." + name
return self.format_table(column.table,
use_schema=False, name=table_name) + '.' + name
else:
return name
@@ -1624,7 +1718,8 @@ class IdentifierPreparer(object):
# ('database', 'owner', etc.) could override this and return
# a longer sequence.
if not self.omit_schema and use_schema and getattr(table, 'schema', None):
if not self.omit_schema and use_schema and \
getattr(table, 'schema', None):
return (self.quote_schema(table.schema, table.quote_schema),
self.format_table(table, use_schema=False))
else:
+4 -4
View File
@@ -1451,10 +1451,10 @@ class ClauseElement(Visitable):
bind = self.bind
else:
dialect = default.DefaultDialect()
compiler = self._compiler(dialect, bind=bind, **kw)
compiler.compile()
return compiler
c= self._compiler(dialect, bind=bind, **kw)
#c.string = c.process(c.statement)
return c
def _compiler(self, dialect, **kw):
"""Return a compiler appropriate for this ClauseElement, given a
Dialect."""
+4 -5
View File
@@ -188,8 +188,7 @@ class TypesTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
table_args.append(Column('c%s' % index, type_(*args, **kw)))
numeric_table = Table(*table_args)
gen = testing.db.dialect.ddl_compiler(
testing.db.dialect, numeric_table)
gen = testing.db.dialect.ddl_compiler(testing.db.dialect, None)
for col in numeric_table.c:
index = int(col.name[1:])
@@ -277,8 +276,7 @@ class TypesTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
table_args.append(Column('c%s' % index, type_(*args, **kw)))
charset_table = Table(*table_args)
gen = testing.db.dialect.ddl_compiler(testing.db.dialect,
charset_table)
gen = testing.db.dialect.ddl_compiler(testing.db.dialect, None)
for col in charset_table.c:
index = int(col.name[1:])
@@ -1471,5 +1469,6 @@ class MatchTest(TestBase, AssertsCompiledSQL):
def colspec(c):
return testing.db.dialect.ddl_compiler(testing.db.dialect, c.table).get_column_specification(c)
return testing.db.dialect.ddl_compiler(
testing.db.dialect, None).get_column_specification(c)
+1 -1
View File
@@ -408,7 +408,7 @@ class DDLExecutionTest(TestBase):
"""test the escaping of % characters in the DDL construct."""
default_from = testing.db.dialect.statement_compiler(
testing.db.dialect, DDL("")).default_from()
testing.db.dialect, None).default_from()
eq_(
testing.db.execute(