mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-06 08:56:51 -04:00
moves the binding of a TypeEngine object from "schema/statement creation" time into "compilation" time
This commit is contained in:
@@ -189,7 +189,10 @@ class ANSICompiler(sql.Compiled):
|
||||
|
||||
def visit_index(self, index):
|
||||
self.strings[index] = index.name
|
||||
|
||||
|
||||
def visit_typeclause(self, typeclause):
|
||||
self.strings[typeclause] = typeclause.type.engine_impl(self.engine).get_col_spec()
|
||||
|
||||
def visit_textclause(self, textclause):
|
||||
if textclause.parens and len(textclause.text):
|
||||
self.strings[textclause] = "(" + textclause.text + ")"
|
||||
|
||||
@@ -238,7 +238,7 @@ class FBCompiler(ansisql.ANSICompiler):
|
||||
class FBSchemaGenerator(ansisql.ANSISchemaGenerator):
|
||||
def get_column_specification(self, column, override_pk=False, **kwargs):
|
||||
colspec = column.name
|
||||
colspec += " " + column.type.get_col_spec()
|
||||
colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
|
||||
default = self.get_column_default_string(column)
|
||||
if default is not None:
|
||||
colspec += " DEFAULT " + default
|
||||
|
||||
@@ -460,7 +460,7 @@ class MSSQLCompiler(ansisql.ANSICompiler):
|
||||
|
||||
class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
|
||||
def get_column_specification(self, column, override_pk=False, first_pk=False):
|
||||
colspec = column.name + " " + column.type.get_col_spec()
|
||||
colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec()
|
||||
|
||||
# install a IDENTITY Sequence if we have an implicit IDENTITY column
|
||||
if column.primary_key and isinstance(column.type, types.Integer):
|
||||
|
||||
@@ -263,7 +263,7 @@ class MySQLCompiler(ansisql.ANSICompiler):
|
||||
|
||||
class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
|
||||
def get_column_specification(self, column, override_pk=False, first_pk=False):
|
||||
colspec = column.name + " " + column.type.get_col_spec()
|
||||
colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec()
|
||||
default = self.get_column_default_string(column)
|
||||
if default is not None:
|
||||
colspec += " DEFAULT " + default
|
||||
|
||||
@@ -306,7 +306,7 @@ class OracleCompiler(ansisql.ANSICompiler):
|
||||
class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
|
||||
def get_column_specification(self, column, override_pk=False, **kwargs):
|
||||
colspec = column.name
|
||||
colspec += " " + column.type.get_col_spec()
|
||||
colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
|
||||
default = self.get_column_default_string(column)
|
||||
if default is not None:
|
||||
colspec += " DEFAULT " + default
|
||||
|
||||
@@ -305,7 +305,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
|
||||
if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
|
||||
colspec += " SERIAL"
|
||||
else:
|
||||
colspec += " " + column.type.get_col_spec()
|
||||
colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
|
||||
default = self.get_column_default_string(column)
|
||||
if default is not None:
|
||||
colspec += " DEFAULT " + default
|
||||
|
||||
@@ -241,7 +241,7 @@ class SQLiteCompiler(ansisql.ANSICompiler):
|
||||
|
||||
class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
|
||||
def get_column_specification(self, column, override_pk=False, **kwargs):
|
||||
colspec = column.name + " " + column.type.get_col_spec()
|
||||
colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec()
|
||||
default = self.get_column_default_string(column)
|
||||
if default is not None:
|
||||
colspec += " DEFAULT " + default
|
||||
|
||||
@@ -319,7 +319,7 @@ class SQLEngine(schema.SchemaEngine):
|
||||
self.positional = True
|
||||
else:
|
||||
raise DBAPIError("Unsupported paramstyle '%s'" % self._paramstyle)
|
||||
|
||||
|
||||
def type_descriptor(self, typeobj):
|
||||
"""provides a database-specific TypeEngine object, given the generic object
|
||||
which comes from the types module. Subclasses will usually use the adapt_type()
|
||||
@@ -808,7 +808,7 @@ class ResultProxy:
|
||||
rec = self.props[key.lower()]
|
||||
else:
|
||||
rec = self.props[key]
|
||||
return rec[0].convert_result_value(row[rec[1]], self.engine)
|
||||
return rec[0].engine_impl(self.engine).convert_result_value(row[rec[1]], self.engine)
|
||||
|
||||
def __iter__(self):
|
||||
while True:
|
||||
|
||||
@@ -36,11 +36,6 @@ class BaseProxyEngine(schema.SchemaEngine):
|
||||
return None
|
||||
return e.oid_column_name()
|
||||
|
||||
def type_descriptor(self, typeobj):
|
||||
"""Proxy point: return a ProxyTypeEngine
|
||||
"""
|
||||
return ProxyTypeEngine(self, typeobj)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
# call get_engine() to give subclasses a chance to change
|
||||
# connection establishment behavior
|
||||
@@ -116,37 +111,3 @@ class ProxyEngine(BaseProxyEngine):
|
||||
self.storage.engine = engine
|
||||
|
||||
|
||||
class ProxyType(object):
|
||||
"""ProxyType base class; used by ProxyTypeEngine to construct proxying
|
||||
types
|
||||
"""
|
||||
def __init__(self, engine, typeobj):
|
||||
self._engine = engine
|
||||
self.typeobj = typeobj
|
||||
|
||||
def __getattribute__(self, attr):
|
||||
if attr.startswith('__') and attr.endswith('__'):
|
||||
return object.__getattribute__(self, attr)
|
||||
|
||||
engine = object.__getattribute__(self, '_engine').engine
|
||||
typeobj = object.__getattribute__(self, 'typeobj')
|
||||
return getattr(engine.type_descriptor(typeobj), attr)
|
||||
|
||||
def __repr__(self):
|
||||
return '<Proxy %s>' % (object.__getattribute__(self, 'typeobj'))
|
||||
|
||||
class ProxyTypeEngine(object):
|
||||
"""Proxy type engine; creates dynamic proxy type subclass that is instance
|
||||
of actual type, but proxies engine-dependant operations through the proxy
|
||||
engine.
|
||||
"""
|
||||
def __new__(cls, engine, typeobj):
|
||||
"""Create a new subclass of ProxyType and typeobj
|
||||
so that internal isinstance() calls will get the expected result.
|
||||
"""
|
||||
if isinstance(typeobj, type):
|
||||
typeclass = typeobj
|
||||
else:
|
||||
typeclass = typeobj.__class__
|
||||
typed = type('ProxyTypeHelper', (ProxyType, typeclass), {})
|
||||
return typed(engine, typeobj)
|
||||
|
||||
@@ -163,7 +163,6 @@ class Table(sql.TableClause, SchemaItem):
|
||||
if column.primary_key:
|
||||
self.primary_key.append(column)
|
||||
column.table = self
|
||||
column.type = self.engine.type_descriptor(column.type)
|
||||
|
||||
def append_index(self, index):
|
||||
self.indexes[index.name] = index
|
||||
|
||||
+22
-28
@@ -139,17 +139,11 @@ def cast(clause, totype, **kwargs):
|
||||
or
|
||||
cast(table.c.timestamp, DATE)
|
||||
"""
|
||||
engine = kwargs.get('engine', None)
|
||||
if engine is None:
|
||||
engine = getattr(clause, 'engine', None)
|
||||
if engine is not None:
|
||||
totype_desc = engine.type_descriptor(totype)
|
||||
# handle non-column clauses (e.g. cast(1234, TEXT)
|
||||
if not hasattr(clause, 'label'):
|
||||
clause = literal(clause)
|
||||
return Function('CAST', clause.label(totype_desc.get_col_spec()), type=totype, **kwargs)
|
||||
else:
|
||||
raise InvalidRequestError("No engine available, cannot generate cast for " + str(clause) + " to type " + str(totype))
|
||||
# handle non-column clauses (e.g. cast(1234, TEXT)
|
||||
if not hasattr(clause, 'label'):
|
||||
clause = literal(clause)
|
||||
totype = sqltypes.to_instance(totype)
|
||||
return Function('CAST', CompoundClause("AS", clause, TypeClause(totype)), type=totype, **kwargs)
|
||||
|
||||
def exists(*args, **params):
|
||||
params['correlate'] = True
|
||||
@@ -295,7 +289,8 @@ class ClauseVisitor(object):
|
||||
def visit_clauselist(self, list):pass
|
||||
def visit_function(self, func):pass
|
||||
def visit_label(self, label):pass
|
||||
|
||||
def visit_typeclause(self, typeclause):pass
|
||||
|
||||
class Compiled(ClauseVisitor):
|
||||
"""represents a compiled SQL expression. the __str__ method of the Compiled object
|
||||
should produce the actual text of the statement. Compiled objects are specific to the
|
||||
@@ -671,13 +666,7 @@ class BindParamClause(ClauseElement, CompareMixin):
|
||||
self.key = key
|
||||
self.value = value
|
||||
self.shortname = shortname
|
||||
self.type = type or sqltypes.NULLTYPE
|
||||
def _get_convert_type(self, engine):
|
||||
try:
|
||||
return self._converted_type
|
||||
except AttributeError:
|
||||
self._converted_type = engine.type_descriptor(self.type)
|
||||
return self._converted_type
|
||||
self.type = sqltypes.to_instance(type)
|
||||
def accept_visitor(self, visitor):
|
||||
visitor.visit_bindparam(self)
|
||||
def _get_from_objects(self):
|
||||
@@ -685,7 +674,7 @@ class BindParamClause(ClauseElement, CompareMixin):
|
||||
def copy_container(self):
|
||||
return BindParamClause(self.key, self.value, self.shortname, self.type)
|
||||
def typeprocess(self, value, engine):
|
||||
return self._get_convert_type(engine).convert_bind_param(value, engine)
|
||||
return self.type.engine_impl(engine).convert_bind_param(value, engine)
|
||||
def compare(self, other):
|
||||
"""compares this BindParamClause to the given clause.
|
||||
|
||||
@@ -695,7 +684,14 @@ class BindParamClause(ClauseElement, CompareMixin):
|
||||
def _make_proxy(self, selectable, name = None):
|
||||
return self
|
||||
# return self.obj._make_proxy(selectable, name=self.name)
|
||||
|
||||
|
||||
class TypeClause(ClauseElement):
|
||||
"""handles a type keyword in a SQL statement"""
|
||||
def __init__(self, type):
|
||||
self.type = type
|
||||
def accept_visitor(self, visitor):
|
||||
visitor.visit_typeclause(self)
|
||||
|
||||
class TextClause(ClauseElement):
|
||||
"""represents literal a SQL text fragment. public constructor is the
|
||||
text() function.
|
||||
@@ -714,7 +710,7 @@ class TextClause(ClauseElement):
|
||||
self.typemap = typemap
|
||||
if typemap is not None:
|
||||
for key in typemap.keys():
|
||||
typemap[key] = engine.type_descriptor(typemap[key])
|
||||
typemap[key] = sqltypes.to_instance(typemap[key])
|
||||
def repl(m):
|
||||
self.bindparams[m.group(1)] = bindparam(m.group(1))
|
||||
return ":%s" % m.group(1)
|
||||
@@ -820,11 +816,9 @@ class Function(ClauseList, ColumnElement):
|
||||
"""describes a SQL function. extends ClauseList to provide comparison operators."""
|
||||
def __init__(self, name, *clauses, **kwargs):
|
||||
self.name = name
|
||||
self.type = kwargs.get('type', sqltypes.NULLTYPE)
|
||||
self.type = sqltypes.to_instance(kwargs.get('type', None))
|
||||
self.packagenames = kwargs.get('packagenames', None) or []
|
||||
self._engine = kwargs.get('engine', None)
|
||||
if self._engine is not None:
|
||||
self.type = self._engine.type_descriptor(self.type)
|
||||
ClauseList.__init__(self, parens=True, *clauses)
|
||||
key = property(lambda self:self.name)
|
||||
def append(self, clause):
|
||||
@@ -873,7 +867,7 @@ class BinaryClause(ClauseElement):
|
||||
self.left = left
|
||||
self.right = right
|
||||
self.operator = operator
|
||||
self.type = type
|
||||
self.type = sqltypes.to_instance(type)
|
||||
self.parens = False
|
||||
if isinstance(self.left, BinaryClause):
|
||||
self.left.parens = True
|
||||
@@ -1028,7 +1022,7 @@ class Label(ColumnElement):
|
||||
while isinstance(obj, Label):
|
||||
obj = obj.obj
|
||||
self.obj = obj
|
||||
self.type = type or sqltypes.NullTypeEngine()
|
||||
self.type = sqltypes.to_instance(type)
|
||||
obj.parens=True
|
||||
key = property(lambda s: s.name)
|
||||
|
||||
@@ -1049,7 +1043,7 @@ class ColumnClause(ColumnElement):
|
||||
def __init__(self, text, selectable=None, type=None):
|
||||
self.key = self.name = self.text = text
|
||||
self.table = selectable
|
||||
self.type = type or sqltypes.NullTypeEngine()
|
||||
self.type = sqltypes.to_instance(type)
|
||||
self.__label = None
|
||||
def _get_label(self):
|
||||
if self.__label is None:
|
||||
|
||||
+21
-3
@@ -16,11 +16,22 @@ try:
|
||||
import cPickle as pickle
|
||||
except:
|
||||
import pickle
|
||||
|
||||
|
||||
class TypeEngine(object):
|
||||
basetypes = []
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
def _get_impl_dict(self):
|
||||
try:
|
||||
return self._impl_dict
|
||||
except AttributeError:
|
||||
self._impl_dict = {}
|
||||
return self._impl_dict
|
||||
impl_dict = property(_get_impl_dict)
|
||||
def engine_impl(self, engine):
|
||||
try:
|
||||
return self.impl_dict[engine]
|
||||
except:
|
||||
return self.impl_dict.setdefault(engine, engine.type_descriptor(self))
|
||||
def _get_impl(self):
|
||||
if hasattr(self, '_impl'):
|
||||
return self._impl
|
||||
@@ -41,7 +52,14 @@ class TypeEngine(object):
|
||||
return {}
|
||||
def adapt_args(self):
|
||||
return self
|
||||
|
||||
|
||||
def to_instance(typeobj):
|
||||
if typeobj is None:
|
||||
return NULLTYPE
|
||||
elif isinstance(typeobj, type):
|
||||
return typeobj()
|
||||
else:
|
||||
return typeobj
|
||||
def adapt_type(typeobj, colspecs):
|
||||
if isinstance(typeobj, type):
|
||||
typeobj = typeobj()
|
||||
|
||||
@@ -194,7 +194,7 @@ class ProxyEngineTest2(PersistTest):
|
||||
return 'a'
|
||||
|
||||
def type_descriptor(self, typeobj):
|
||||
if typeobj == types.Integer:
|
||||
if isinstance(typeobj, types.Integer):
|
||||
return TypeEngineX2()
|
||||
else:
|
||||
return TypeEngineSTR()
|
||||
@@ -224,16 +224,16 @@ class ProxyEngineTest2(PersistTest):
|
||||
engine = ProxyEngine()
|
||||
engine.storage.engine = EngineA()
|
||||
|
||||
a = engine.type_descriptor(sqltypes.Integer)
|
||||
a = sqltypes.Integer().engine_impl(engine)
|
||||
assert a.convert_bind_param(12, engine) == 24
|
||||
assert a.convert_bind_param([1,2,3], engine) == [1, 2, 3, 1, 2, 3]
|
||||
|
||||
a2 = engine.type_descriptor(sqltypes.String)
|
||||
a2 = sqltypes.String().engine_impl(engine)
|
||||
assert a2.convert_bind_param(12, engine) == "'12'"
|
||||
assert a2.convert_bind_param([1,2,3], engine) == "'[1, 2, 3]'"
|
||||
|
||||
engine.storage.engine = EngineB()
|
||||
b = engine.type_descriptor(sqltypes.Integer)
|
||||
b = sqltypes.Integer().engine_impl(engine)
|
||||
assert b.convert_bind_param(12, engine) == 'monkey'
|
||||
assert b.convert_bind_param([1,2,3], engine) == 'monkey'
|
||||
|
||||
|
||||
Reference in New Issue
Block a user