moves the binding of a TypeEngine object from "schema/statement creation" time into "compilation" time

This commit is contained in:
Mike Bayer
2006-04-06 01:15:46 +00:00
parent 753b7c2d3e
commit 680c276073
13 changed files with 59 additions and 84 deletions
+4 -1
View File
@@ -189,7 +189,10 @@ class ANSICompiler(sql.Compiled):
def visit_index(self, index):
self.strings[index] = index.name
def visit_typeclause(self, typeclause):
self.strings[typeclause] = typeclause.type.engine_impl(self.engine).get_col_spec()
def visit_textclause(self, textclause):
if textclause.parens and len(textclause.text):
self.strings[textclause] = "(" + textclause.text + ")"
+1 -1
View File
@@ -238,7 +238,7 @@ class FBCompiler(ansisql.ANSICompiler):
class FBSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, **kwargs):
colspec = column.name
colspec += " " + column.type.get_col_spec()
colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
+1 -1
View File
@@ -460,7 +460,7 @@ class MSSQLCompiler(ansisql.ANSICompiler):
class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, first_pk=False):
colspec = column.name + " " + column.type.get_col_spec()
colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec()
# install a IDENTITY Sequence if we have an implicit IDENTITY column
if column.primary_key and isinstance(column.type, types.Integer):
+1 -1
View File
@@ -263,7 +263,7 @@ class MySQLCompiler(ansisql.ANSICompiler):
class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, first_pk=False):
colspec = column.name + " " + column.type.get_col_spec()
colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
+1 -1
View File
@@ -306,7 +306,7 @@ class OracleCompiler(ansisql.ANSICompiler):
class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, **kwargs):
colspec = column.name
colspec += " " + column.type.get_col_spec()
colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
+1 -1
View File
@@ -305,7 +305,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
colspec += " SERIAL"
else:
colspec += " " + column.type.get_col_spec()
colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
+1 -1
View File
@@ -241,7 +241,7 @@ class SQLiteCompiler(ansisql.ANSICompiler):
class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, **kwargs):
colspec = column.name + " " + column.type.get_col_spec()
colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
+2 -2
View File
@@ -319,7 +319,7 @@ class SQLEngine(schema.SchemaEngine):
self.positional = True
else:
raise DBAPIError("Unsupported paramstyle '%s'" % self._paramstyle)
def type_descriptor(self, typeobj):
"""provides a database-specific TypeEngine object, given the generic object
which comes from the types module. Subclasses will usually use the adapt_type()
@@ -808,7 +808,7 @@ class ResultProxy:
rec = self.props[key.lower()]
else:
rec = self.props[key]
return rec[0].convert_result_value(row[rec[1]], self.engine)
return rec[0].engine_impl(self.engine).convert_result_value(row[rec[1]], self.engine)
def __iter__(self):
while True:
-39
View File
@@ -36,11 +36,6 @@ class BaseProxyEngine(schema.SchemaEngine):
return None
return e.oid_column_name()
def type_descriptor(self, typeobj):
"""Proxy point: return a ProxyTypeEngine
"""
return ProxyTypeEngine(self, typeobj)
def __getattr__(self, attr):
# call get_engine() to give subclasses a chance to change
# connection establishment behavior
@@ -116,37 +111,3 @@ class ProxyEngine(BaseProxyEngine):
self.storage.engine = engine
class ProxyType(object):
"""ProxyType base class; used by ProxyTypeEngine to construct proxying
types
"""
def __init__(self, engine, typeobj):
self._engine = engine
self.typeobj = typeobj
def __getattribute__(self, attr):
if attr.startswith('__') and attr.endswith('__'):
return object.__getattribute__(self, attr)
engine = object.__getattribute__(self, '_engine').engine
typeobj = object.__getattribute__(self, 'typeobj')
return getattr(engine.type_descriptor(typeobj), attr)
def __repr__(self):
return '<Proxy %s>' % (object.__getattribute__(self, 'typeobj'))
class ProxyTypeEngine(object):
"""Proxy type engine; creates dynamic proxy type subclass that is instance
of actual type, but proxies engine-dependant operations through the proxy
engine.
"""
def __new__(cls, engine, typeobj):
"""Create a new subclass of ProxyType and typeobj
so that internal isinstance() calls will get the expected result.
"""
if isinstance(typeobj, type):
typeclass = typeobj
else:
typeclass = typeobj.__class__
typed = type('ProxyTypeHelper', (ProxyType, typeclass), {})
return typed(engine, typeobj)
-1
View File
@@ -163,7 +163,6 @@ class Table(sql.TableClause, SchemaItem):
if column.primary_key:
self.primary_key.append(column)
column.table = self
column.type = self.engine.type_descriptor(column.type)
def append_index(self, index):
self.indexes[index.name] = index
+22 -28
View File
@@ -139,17 +139,11 @@ def cast(clause, totype, **kwargs):
or
cast(table.c.timestamp, DATE)
"""
engine = kwargs.get('engine', None)
if engine is None:
engine = getattr(clause, 'engine', None)
if engine is not None:
totype_desc = engine.type_descriptor(totype)
# handle non-column clauses (e.g. cast(1234, TEXT)
if not hasattr(clause, 'label'):
clause = literal(clause)
return Function('CAST', clause.label(totype_desc.get_col_spec()), type=totype, **kwargs)
else:
raise InvalidRequestError("No engine available, cannot generate cast for " + str(clause) + " to type " + str(totype))
# handle non-column clauses (e.g. cast(1234, TEXT)
if not hasattr(clause, 'label'):
clause = literal(clause)
totype = sqltypes.to_instance(totype)
return Function('CAST', CompoundClause("AS", clause, TypeClause(totype)), type=totype, **kwargs)
def exists(*args, **params):
params['correlate'] = True
@@ -295,7 +289,8 @@ class ClauseVisitor(object):
def visit_clauselist(self, list):pass
def visit_function(self, func):pass
def visit_label(self, label):pass
def visit_typeclause(self, typeclause):pass
class Compiled(ClauseVisitor):
"""represents a compiled SQL expression. the __str__ method of the Compiled object
should produce the actual text of the statement. Compiled objects are specific to the
@@ -671,13 +666,7 @@ class BindParamClause(ClauseElement, CompareMixin):
self.key = key
self.value = value
self.shortname = shortname
self.type = type or sqltypes.NULLTYPE
def _get_convert_type(self, engine):
try:
return self._converted_type
except AttributeError:
self._converted_type = engine.type_descriptor(self.type)
return self._converted_type
self.type = sqltypes.to_instance(type)
def accept_visitor(self, visitor):
visitor.visit_bindparam(self)
def _get_from_objects(self):
@@ -685,7 +674,7 @@ class BindParamClause(ClauseElement, CompareMixin):
def copy_container(self):
return BindParamClause(self.key, self.value, self.shortname, self.type)
def typeprocess(self, value, engine):
return self._get_convert_type(engine).convert_bind_param(value, engine)
return self.type.engine_impl(engine).convert_bind_param(value, engine)
def compare(self, other):
"""compares this BindParamClause to the given clause.
@@ -695,7 +684,14 @@ class BindParamClause(ClauseElement, CompareMixin):
def _make_proxy(self, selectable, name = None):
return self
# return self.obj._make_proxy(selectable, name=self.name)
class TypeClause(ClauseElement):
"""handles a type keyword in a SQL statement"""
def __init__(self, type):
self.type = type
def accept_visitor(self, visitor):
visitor.visit_typeclause(self)
class TextClause(ClauseElement):
"""represents literal a SQL text fragment. public constructor is the
text() function.
@@ -714,7 +710,7 @@ class TextClause(ClauseElement):
self.typemap = typemap
if typemap is not None:
for key in typemap.keys():
typemap[key] = engine.type_descriptor(typemap[key])
typemap[key] = sqltypes.to_instance(typemap[key])
def repl(m):
self.bindparams[m.group(1)] = bindparam(m.group(1))
return ":%s" % m.group(1)
@@ -820,11 +816,9 @@ class Function(ClauseList, ColumnElement):
"""describes a SQL function. extends ClauseList to provide comparison operators."""
def __init__(self, name, *clauses, **kwargs):
self.name = name
self.type = kwargs.get('type', sqltypes.NULLTYPE)
self.type = sqltypes.to_instance(kwargs.get('type', None))
self.packagenames = kwargs.get('packagenames', None) or []
self._engine = kwargs.get('engine', None)
if self._engine is not None:
self.type = self._engine.type_descriptor(self.type)
ClauseList.__init__(self, parens=True, *clauses)
key = property(lambda self:self.name)
def append(self, clause):
@@ -873,7 +867,7 @@ class BinaryClause(ClauseElement):
self.left = left
self.right = right
self.operator = operator
self.type = type
self.type = sqltypes.to_instance(type)
self.parens = False
if isinstance(self.left, BinaryClause):
self.left.parens = True
@@ -1028,7 +1022,7 @@ class Label(ColumnElement):
while isinstance(obj, Label):
obj = obj.obj
self.obj = obj
self.type = type or sqltypes.NullTypeEngine()
self.type = sqltypes.to_instance(type)
obj.parens=True
key = property(lambda s: s.name)
@@ -1049,7 +1043,7 @@ class ColumnClause(ColumnElement):
def __init__(self, text, selectable=None, type=None):
self.key = self.name = self.text = text
self.table = selectable
self.type = type or sqltypes.NullTypeEngine()
self.type = sqltypes.to_instance(type)
self.__label = None
def _get_label(self):
if self.__label is None:
+21 -3
View File
@@ -16,11 +16,22 @@ try:
import cPickle as pickle
except:
import pickle
class TypeEngine(object):
basetypes = []
def __init__(self, *args, **kwargs):
pass
def _get_impl_dict(self):
try:
return self._impl_dict
except AttributeError:
self._impl_dict = {}
return self._impl_dict
impl_dict = property(_get_impl_dict)
def engine_impl(self, engine):
try:
return self.impl_dict[engine]
except:
return self.impl_dict.setdefault(engine, engine.type_descriptor(self))
def _get_impl(self):
if hasattr(self, '_impl'):
return self._impl
@@ -41,7 +52,14 @@ class TypeEngine(object):
return {}
def adapt_args(self):
return self
def to_instance(typeobj):
if typeobj is None:
return NULLTYPE
elif isinstance(typeobj, type):
return typeobj()
else:
return typeobj
def adapt_type(typeobj, colspecs):
if isinstance(typeobj, type):
typeobj = typeobj()
+4 -4
View File
@@ -194,7 +194,7 @@ class ProxyEngineTest2(PersistTest):
return 'a'
def type_descriptor(self, typeobj):
if typeobj == types.Integer:
if isinstance(typeobj, types.Integer):
return TypeEngineX2()
else:
return TypeEngineSTR()
@@ -224,16 +224,16 @@ class ProxyEngineTest2(PersistTest):
engine = ProxyEngine()
engine.storage.engine = EngineA()
a = engine.type_descriptor(sqltypes.Integer)
a = sqltypes.Integer().engine_impl(engine)
assert a.convert_bind_param(12, engine) == 24
assert a.convert_bind_param([1,2,3], engine) == [1, 2, 3, 1, 2, 3]
a2 = engine.type_descriptor(sqltypes.String)
a2 = sqltypes.String().engine_impl(engine)
assert a2.convert_bind_param(12, engine) == "'12'"
assert a2.convert_bind_param([1,2,3], engine) == "'[1, 2, 3]'"
engine.storage.engine = EngineB()
b = engine.type_descriptor(sqltypes.Integer)
b = sqltypes.Integer().engine_impl(engine)
assert b.convert_bind_param(12, engine) == 'monkey'
assert b.convert_bind_param([1,2,3], engine) == 'monkey'