some tests, should be OK

This commit is contained in:
Mike Bayer
2010-12-13 20:23:24 -05:00
parent 68fb34ba57
commit bfaa97dbce
4 changed files with 103 additions and 45 deletions
+5 -3
View File
@@ -175,8 +175,9 @@ class REAL(sqltypes.Float):
__visit_name__ = 'REAL'
def __init__(self):
super(REAL, self).__init__(precision=24)
def __init__(self, **kw):
kw.setdefault('precision', 24)
super(REAL, self).__init__(**kw)
class TINYINT(sqltypes.Integer):
__visit_name__ = 'TINYINT'
@@ -258,7 +259,8 @@ class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime):
class DATETIME2(_DateTimeBase, sqltypes.DateTime):
__visit_name__ = 'DATETIME2'
def __init__(self, precision=None, **kwargs):
def __init__(self, precision=None, **kw):
super(DATETIME2, self).__init__(**kw)
self.precision = precision
+1 -1
View File
@@ -771,7 +771,7 @@ class CHAR(_StringType, sqltypes.CHAR):
__visit_name__ = 'CHAR'
def __init__(self, length, **kwargs):
def __init__(self, length=None, **kwargs):
"""Construct a CHAR.
:param length: Maximum data length, in characters.
+25 -22
View File
@@ -131,7 +131,7 @@ class TypeEngine(AbstractType):
else:
return self.__class__
def dialect_impl(self, dialect, **kwargs):
def dialect_impl(self, dialect):
"""Return a dialect-specific implementation for this type."""
try:
@@ -149,22 +149,6 @@ class TypeEngine(AbstractType):
d['bind'] = bp = d['impl'].bind_processor(dialect)
return bp
def _dialect_info(self, dialect):
"""Return a dialect-specific registry containing bind/result processors."""
if self in dialect._type_memos:
return dialect._type_memos[self]
else:
impl = self._gen_dialect_impl(dialect)
# the impl we put in here
# must not have any references to self.
if impl is self:
impl = self.adapt(type(self))
dialect._type_memos[self] = d = {
'impl':impl,
}
return d
def _cached_result_processor(self, dialect, coltype):
"""Return a dialect-specific result processor for this type."""
@@ -172,11 +156,28 @@ class TypeEngine(AbstractType):
return dialect._type_memos[self][coltype]
except KeyError:
d = self._dialect_info(dialect)
# another key assumption. DBAPI type codes are
# constants.
# key assumption: DBAPI type codes are
# constants. Else this dictionary would
# grow unbounded.
d[coltype] = rp = d['impl'].result_processor(dialect, coltype)
return rp
def _dialect_info(self, dialect):
"""Return a dialect-specific registry which
caches a dialect-specific implementation, bind processing
function, and one or more result processing functions."""
if self in dialect._type_memos:
return dialect._type_memos[self]
else:
impl = self._gen_dialect_impl(dialect)
if impl is self:
impl = self.adapt(type(self))
# this can't be self, else we create a cycle
assert impl is not self
dialect._type_memos[self] = d = {'impl':impl}
return d
def _gen_dialect_impl(self, dialect):
return dialect.type_descriptor(self)
@@ -792,7 +793,7 @@ class String(Concatenable, TypeEngine):
length=self.length,
convert_unicode=self.convert_unicode,
unicode_error=self.unicode_error,
_warn_on_bytestring=True,
_warn_on_bytestring=self._warn_on_bytestring,
**kw
)
@@ -1171,7 +1172,9 @@ class Float(Numeric):
"""
__visit_name__ = 'float'
scale = None
def __init__(self, precision=None, asdecimal=False, **kwargs):
"""
Construct a Float.
@@ -1787,7 +1790,7 @@ class Interval(_DateAffinity, TypeDecorator):
self.day_precision = day_precision
def adapt(self, cls, **kw):
if self.native:
if self.native and hasattr(cls, '_adapt_from_generic_interval'):
return cls._adapt_from_generic_interval(self, **kw)
else:
return cls(**kw)
+72 -19
View File
@@ -3,19 +3,44 @@ from test.lib.testing import eq_, assert_raises, assert_raises_message
import decimal
import datetime, os, re
from sqlalchemy import *
from sqlalchemy import exc, types, util, schema
from sqlalchemy import exc, types, util, schema, dialects
for name in dialects.__all__:
__import__("sqlalchemy.dialects.%s" % name)
from sqlalchemy.sql import operators, column, table
from test.lib.testing import eq_
import sqlalchemy.engine.url as url
from sqlalchemy.databases import *
from test.lib.schema import Table, Column
from test.lib import *
from test.lib.util import picklers
from sqlalchemy.util.compat import decimal
from test.lib.util import round_decimal
class AdaptTest(TestBase):
def _all_dialect_modules(self):
return [
getattr(dialects, d)
for d in dialects.__all__
if not d.startswith('_')
]
def _all_dialects(self):
return [d.base.dialect() for d in
self._all_dialect_modules()]
def _all_types(self):
def types_for_mod(mod):
for key in dir(mod):
typ = getattr(mod, key)
if not isinstance(typ, type) or not issubclass(typ, types.TypeEngine):
continue
yield typ
for typ in types_for_mod(types):
yield typ
for dialect in self._all_dialect_modules():
for typ in types_for_mod(dialect):
yield typ
def test_uppercase_rendering(self):
"""Test that uppercase types from types.py always render as their
type.
@@ -27,12 +52,7 @@ class AdaptTest(TestBase):
"""
for dialect in [
oracle.dialect(),
mysql.dialect(),
postgresql.dialect(),
sqlite.dialect(),
mssql.dialect()]:
for dialect in self._all_dialects():
for type_, expected in (
(FLOAT, "FLOAT"),
(NUMERIC, "NUMERIC"),
@@ -49,7 +69,7 @@ class AdaptTest(TestBase):
"NVARCHAR2(10)")),
(CHAR, "CHAR"),
(NCHAR, ("NCHAR", "NATIONAL CHAR")),
(BLOB, "BLOB"),
(BLOB, ("BLOB", "BLOB SUB_TYPE 0")),
(BOOLEAN, ("BOOLEAN", "BOOL"))
):
if isinstance(expected, str):
@@ -65,7 +85,40 @@ class AdaptTest(TestBase):
assert str(types.to_instance(type_)) in expected, \
"default str() of type %r not expected, %r" % \
(type_, expected)
@testing.uses_deprecated()
def test_adapt_method(self):
"""ensure all types have a working adapt() method,
which creates a distinct copy.
The distinct copy ensures that when we cache
the adapted() form of a type against the original
in a weak key dictionary, a cycle is not formed.
This test doesn't test type-specific arguments of
adapt() beyond their defaults.
"""
for typ in self._all_types():
if typ in (types.TypeDecorator, types.TypeEngine):
continue
elif typ is dialects.postgresql.ARRAY:
t1 = typ(String)
else:
t1 = typ()
for cls in [typ] + typ.__subclasses__():
if not issubclass(typ, types.Enum) and \
issubclass(cls, types.Enum):
continue
t2 = t1.adapt(cls)
assert t1 is not t2
for k in t1.__dict__:
if k == 'impl':
continue
eq_(getattr(t2, k), t1.__dict__[k])
class TypeAffinityTest(TestBase):
def test_type_affinity(self):
for type_, affin in [
@@ -155,7 +208,7 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL):
(Float(2), "FLOAT(2)", {'precision':4}),
(Numeric(19, 2), "NUMERIC(19, 2)", {}),
]:
for dialect_ in (postgresql, mssql, mysql):
for dialect_ in (dialects.postgresql, dialects.mssql, dialects.mysql):
dialect_ = dialect_.dialect()
raw_impl = types.to_instance(impl_, **kw)
@@ -188,8 +241,8 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL):
else:
return super(MyType, self).load_dialect_impl(dialect)
sl = sqlite.dialect()
pg = postgresql.dialect()
sl = dialects.sqlite.dialect()
pg = dialects.postgresql.dialect()
t = MyType()
self.assert_compile(t, "VARCHAR(50)", dialect=sl)
self.assert_compile(t, "FLOAT", dialect=pg)
@@ -1082,12 +1135,12 @@ class CompileTest(TestBase, AssertsCompiledSQL):
for type_, expected in (
(String(), "VARCHAR"),
(Integer(), "INTEGER"),
(postgresql.INET(), "INET"),
(postgresql.FLOAT(), "FLOAT"),
(mysql.REAL(precision=8, scale=2), "REAL(8, 2)"),
(postgresql.REAL(), "REAL"),
(dialects.postgresql.INET(), "INET"),
(dialects.postgresql.FLOAT(), "FLOAT"),
(dialects.mysql.REAL(precision=8, scale=2), "REAL(8, 2)"),
(dialects.postgresql.REAL(), "REAL"),
(INTEGER(), "INTEGER"),
(mysql.INTEGER(display_width=5), "INTEGER(5)")
(dialects.mysql.INTEGER(display_width=5), "INTEGER(5)")
):
self.assert_compile(type_, expected)