mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-28 03:26:01 -04:00
just a pep8 pass of lib/sqlalchemy/testing/
This commit is contained in:
@@ -10,12 +10,11 @@ from .exclusions import db_spec, _is_excluded, fails_if, skip_if, future,\
|
||||
|
||||
from .assertions import emits_warning, emits_warning_on, uses_deprecated, \
|
||||
eq_, ne_, is_, is_not_, startswith_, assert_raises, \
|
||||
assert_raises_message, AssertsCompiledSQL, ComparesTables, AssertsExecutionResults
|
||||
assert_raises_message, AssertsCompiledSQL, ComparesTables, \
|
||||
AssertsExecutionResults
|
||||
|
||||
from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict
|
||||
|
||||
crashes = skip
|
||||
|
||||
from .config import db, requirements as requires
|
||||
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import itertools
|
||||
from .util import fail
|
||||
import contextlib
|
||||
|
||||
|
||||
def emits_warning(*messages):
|
||||
"""Mark a test as emitting a warning.
|
||||
|
||||
@@ -50,6 +51,7 @@ def emits_warning(*messages):
|
||||
resetwarnings()
|
||||
return decorate
|
||||
|
||||
|
||||
def emits_warning_on(db, *warnings):
|
||||
"""Mark a test as emitting a warning on a specific dialect.
|
||||
|
||||
@@ -115,7 +117,6 @@ def uses_deprecated(*messages):
|
||||
return decorate
|
||||
|
||||
|
||||
|
||||
def global_cleanup_assertions():
|
||||
"""Check things that have to be finalized at the end of a test suite.
|
||||
|
||||
@@ -129,28 +130,32 @@ def global_cleanup_assertions():
|
||||
assert not pool._refs, str(pool._refs)
|
||||
|
||||
|
||||
|
||||
def eq_(a, b, msg=None):
|
||||
"""Assert a == b, with repr messaging on failure."""
|
||||
assert a == b, msg or "%r != %r" % (a, b)
|
||||
|
||||
|
||||
def ne_(a, b, msg=None):
|
||||
"""Assert a != b, with repr messaging on failure."""
|
||||
assert a != b, msg or "%r == %r" % (a, b)
|
||||
|
||||
|
||||
def is_(a, b, msg=None):
|
||||
"""Assert a is b, with repr messaging on failure."""
|
||||
assert a is b, msg or "%r is not %r" % (a, b)
|
||||
|
||||
|
||||
def is_not_(a, b, msg=None):
|
||||
"""Assert a is not b, with repr messaging on failure."""
|
||||
assert a is not b, msg or "%r is %r" % (a, b)
|
||||
|
||||
|
||||
def startswith_(a, fragment, msg=None):
|
||||
"""Assert a.startswith(fragment), with repr messaging on failure."""
|
||||
assert a.startswith(fragment), msg or "%r does not start with %r" % (
|
||||
a, fragment)
|
||||
|
||||
|
||||
def assert_raises(except_cls, callable_, *args, **kw):
|
||||
try:
|
||||
callable_(*args, **kw)
|
||||
@@ -161,6 +166,7 @@ def assert_raises(except_cls, callable_, *args, **kw):
|
||||
# assert outside the block so it works for AssertionError too !
|
||||
assert success, "Callable did not raise an exception"
|
||||
|
||||
|
||||
def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
|
||||
try:
|
||||
callable_(*args, **kwargs)
|
||||
@@ -214,7 +220,9 @@ class AssertsCompiledSQL(object):
|
||||
p = c.construct_params(params)
|
||||
eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
|
||||
|
||||
|
||||
class ComparesTables(object):
|
||||
|
||||
def assert_tables_equal(self, table, reflected_table, strict_types=False):
|
||||
assert len(table.c) == len(reflected_table.c)
|
||||
for c, reflected_c in zip(table.c, reflected_table.c):
|
||||
@@ -224,15 +232,19 @@ class ComparesTables(object):
|
||||
eq_(c.nullable, reflected_c.nullable)
|
||||
|
||||
if strict_types:
|
||||
msg = "Type '%s' doesn't correspond to type '%s'"
|
||||
assert type(reflected_c.type) is type(c.type), \
|
||||
"Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
|
||||
msg % (reflected_c.type, c.type)
|
||||
else:
|
||||
self.assert_types_base(reflected_c, c)
|
||||
|
||||
if isinstance(c.type, sqltypes.String):
|
||||
eq_(c.type.length, reflected_c.type.length)
|
||||
|
||||
eq_(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys]))
|
||||
eq_(
|
||||
set([f.column.name for f in c.foreign_keys]),
|
||||
set([f.column.name for f in reflected_c.foreign_keys])
|
||||
)
|
||||
if c.server_default:
|
||||
assert isinstance(reflected_c.server_default,
|
||||
schema.FetchedValue)
|
||||
@@ -246,6 +258,7 @@ class ComparesTables(object):
|
||||
"On column %r, type '%s' doesn't correspond to type '%s'" % \
|
||||
(c1.name, c1.type, c2.type)
|
||||
|
||||
|
||||
class AssertsExecutionResults(object):
|
||||
def assert_result(self, result, class_, *objects):
|
||||
result = list(result)
|
||||
@@ -296,6 +309,7 @@ class AssertsExecutionResults(object):
|
||||
len(found), len(expected)))
|
||||
|
||||
NOVALUE = object()
|
||||
|
||||
def _compare_item(obj, spec):
|
||||
for key, value in spec.iteritems():
|
||||
if isinstance(value, tuple):
|
||||
@@ -347,7 +361,8 @@ class AssertsExecutionResults(object):
|
||||
self.assert_sql_execution(db, callable_, *newrules)
|
||||
|
||||
def assert_sql_count(self, db, callable_, count):
|
||||
self.assert_sql_execution(db, callable_, assertsql.CountStatements(count))
|
||||
self.assert_sql_execution(
|
||||
db, callable_, assertsql.CountStatements(count))
|
||||
|
||||
@contextlib.contextmanager
|
||||
def assert_execution(self, *rules):
|
||||
@@ -359,4 +374,4 @@ class AssertsExecutionResults(object):
|
||||
assertsql.asserter.clear_rules()
|
||||
|
||||
def assert_statement_count(self, count):
|
||||
return self.assert_execution(assertsql.CountStatements(count))
|
||||
return self.assert_execution(assertsql.CountStatements(count))
|
||||
|
||||
@@ -3,6 +3,7 @@ from ..engine.default import DefaultDialect
|
||||
from .. import util
|
||||
import re
|
||||
|
||||
|
||||
class AssertRule(object):
|
||||
|
||||
def process_execute(self, clauseelement, *multiparams, **params):
|
||||
@@ -40,6 +41,7 @@ class AssertRule(object):
|
||||
assert False, 'Rule has not been consumed'
|
||||
return self.is_consumed()
|
||||
|
||||
|
||||
class SQLMatchRule(AssertRule):
|
||||
def __init__(self):
|
||||
self._result = None
|
||||
@@ -56,6 +58,7 @@ class SQLMatchRule(AssertRule):
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class ExactSQL(SQLMatchRule):
|
||||
|
||||
def __init__(self, sql, params=None):
|
||||
@@ -138,6 +141,7 @@ class RegexSQL(SQLMatchRule):
|
||||
_received_statement,
|
||||
_received_parameters)
|
||||
|
||||
|
||||
class CompiledSQL(SQLMatchRule):
|
||||
|
||||
def __init__(self, statement, params):
|
||||
@@ -217,6 +221,7 @@ class CountStatements(AssertRule):
|
||||
% (self.count, self._statement_count)
|
||||
return True
|
||||
|
||||
|
||||
class AllOf(AssertRule):
|
||||
|
||||
def __init__(self, *rules):
|
||||
@@ -244,6 +249,7 @@ class AllOf(AssertRule):
|
||||
def consume_final(self):
|
||||
return len(self.rules) == 0
|
||||
|
||||
|
||||
def _process_engine_statement(query, context):
|
||||
if util.jython:
|
||||
|
||||
@@ -256,6 +262,7 @@ def _process_engine_statement(query, context):
|
||||
query = re.sub(r'\n', '', query)
|
||||
return query
|
||||
|
||||
|
||||
def _process_assertion_statement(query, context):
|
||||
paramstyle = context.dialect.paramstyle
|
||||
if paramstyle == 'named':
|
||||
@@ -275,6 +282,7 @@ def _process_assertion_statement(query, context):
|
||||
|
||||
return query
|
||||
|
||||
|
||||
class SQLAssert(object):
|
||||
|
||||
rules = None
|
||||
@@ -311,4 +319,3 @@ class SQLAssert(object):
|
||||
executemany)
|
||||
|
||||
asserter = SQLAssert()
|
||||
|
||||
|
||||
@@ -1,3 +1,2 @@
|
||||
requirements = None
|
||||
db = None
|
||||
|
||||
|
||||
@@ -9,7 +9,9 @@ from .. import event, pool
|
||||
import re
|
||||
import warnings
|
||||
|
||||
|
||||
class ConnectionKiller(object):
|
||||
|
||||
def __init__(self):
|
||||
self.proxy_refs = weakref.WeakKeyDictionary()
|
||||
self.testing_engines = weakref.WeakKeyDictionary()
|
||||
@@ -83,12 +85,14 @@ class ConnectionKiller(object):
|
||||
|
||||
testing_reaper = ConnectionKiller()
|
||||
|
||||
|
||||
def drop_all_tables(metadata, bind):
|
||||
testing_reaper.close_all()
|
||||
if hasattr(bind, 'close'):
|
||||
bind.close()
|
||||
metadata.drop_all(bind)
|
||||
|
||||
|
||||
@decorator
|
||||
def assert_conns_closed(fn, *args, **kw):
|
||||
try:
|
||||
@@ -96,6 +100,7 @@ def assert_conns_closed(fn, *args, **kw):
|
||||
finally:
|
||||
testing_reaper.assert_all_closed()
|
||||
|
||||
|
||||
@decorator
|
||||
def rollback_open_connections(fn, *args, **kw):
|
||||
"""Decorator that rolls back all open connections after fn execution."""
|
||||
@@ -105,6 +110,7 @@ def rollback_open_connections(fn, *args, **kw):
|
||||
finally:
|
||||
testing_reaper.rollback_all()
|
||||
|
||||
|
||||
@decorator
|
||||
def close_first(fn, *args, **kw):
|
||||
"""Decorator that closes all connections before fn execution."""
|
||||
@@ -121,6 +127,7 @@ def close_open_connections(fn, *args, **kw):
|
||||
finally:
|
||||
testing_reaper.close_all()
|
||||
|
||||
|
||||
def all_dialects(exclude=None):
|
||||
import sqlalchemy.databases as d
|
||||
for name in d.__all__:
|
||||
@@ -129,10 +136,13 @@ def all_dialects(exclude=None):
|
||||
continue
|
||||
mod = getattr(d, name, None)
|
||||
if not mod:
|
||||
mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
|
||||
mod = getattr(__import__(
|
||||
'sqlalchemy.databases.%s' % name).databases, name)
|
||||
yield mod.dialect()
|
||||
|
||||
|
||||
class ReconnectFixture(object):
|
||||
|
||||
def __init__(self, dbapi):
|
||||
self.dbapi = dbapi
|
||||
self.connections = []
|
||||
@@ -165,6 +175,7 @@ class ReconnectFixture(object):
|
||||
self._safe(c.close)
|
||||
self.connections = []
|
||||
|
||||
|
||||
def reconnecting_engine(url=None, options=None):
|
||||
url = url or config.db_url
|
||||
dbapi = config.db.dialect.dbapi
|
||||
@@ -173,9 +184,11 @@ def reconnecting_engine(url=None, options=None):
|
||||
options['module'] = ReconnectFixture(dbapi)
|
||||
engine = testing_engine(url, options)
|
||||
_dispose = engine.dispose
|
||||
|
||||
def dispose():
|
||||
engine.dialect.dbapi.shutdown()
|
||||
_dispose()
|
||||
|
||||
engine.test_shutdown = engine.dialect.dbapi.shutdown
|
||||
engine.dispose = dispose
|
||||
return engine
|
||||
@@ -209,6 +222,7 @@ def testing_engine(url=None, options=None):
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
def utf8_engine(url=None, options=None):
|
||||
"""Hook for dialects or drivers that don't handle utf8 by default."""
|
||||
|
||||
@@ -226,6 +240,7 @@ def utf8_engine(url=None, options=None):
|
||||
|
||||
return testing_engine(url, options)
|
||||
|
||||
|
||||
def mock_engine(dialect_name=None):
|
||||
"""Provides a mocking engine based on the current testing.db.
|
||||
|
||||
@@ -244,17 +259,21 @@ def mock_engine(dialect_name=None):
|
||||
dialect_name = config.db.name
|
||||
|
||||
buffer = []
|
||||
|
||||
def executor(sql, *a, **kw):
|
||||
buffer.append(sql)
|
||||
|
||||
def assert_sql(stmts):
|
||||
recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
|
||||
assert recv == stmts, recv
|
||||
|
||||
def print_sql():
|
||||
d = engine.dialect
|
||||
return "\n".join(
|
||||
str(s.compile(dialect=d))
|
||||
for s in engine.mock
|
||||
)
|
||||
|
||||
engine = create_engine(dialect_name + '://',
|
||||
strategy='mock', executor=executor)
|
||||
assert not hasattr(engine, 'mock')
|
||||
@@ -263,6 +282,7 @@ def mock_engine(dialect_name=None):
|
||||
engine.print_sql = print_sql
|
||||
return engine
|
||||
|
||||
|
||||
class DBAPIProxyCursor(object):
|
||||
"""Proxy a DBAPI cursor.
|
||||
|
||||
@@ -287,6 +307,7 @@ class DBAPIProxyCursor(object):
|
||||
def __getattr__(self, key):
|
||||
return getattr(self.cursor, key)
|
||||
|
||||
|
||||
class DBAPIProxyConnection(object):
|
||||
"""Proxy a DBAPI connection.
|
||||
|
||||
@@ -308,14 +329,17 @@ class DBAPIProxyConnection(object):
|
||||
def __getattr__(self, key):
|
||||
return getattr(self.conn, key)
|
||||
|
||||
def proxying_engine(conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor):
|
||||
|
||||
def proxying_engine(conn_cls=DBAPIProxyConnection,
|
||||
cursor_cls=DBAPIProxyCursor):
|
||||
"""Produce an engine that provides proxy hooks for
|
||||
common methods.
|
||||
|
||||
"""
|
||||
def mock_conn():
|
||||
return conn_cls(config.db, cursor_cls)
|
||||
return testing_engine(options={'creator':mock_conn})
|
||||
return testing_engine(options={'creator': mock_conn})
|
||||
|
||||
|
||||
class ReplayableSession(object):
|
||||
"""A simple record/playback tool.
|
||||
@@ -427,4 +451,3 @@ class ReplayableSession(object):
|
||||
raise AttributeError(key)
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
@@ -2,7 +2,10 @@ import sqlalchemy as sa
|
||||
from sqlalchemy import exc as sa_exc
|
||||
|
||||
_repr_stack = set()
|
||||
|
||||
|
||||
class BasicEntity(object):
|
||||
|
||||
def __init__(self, **kw):
|
||||
for key, value in kw.iteritems():
|
||||
setattr(self, key, value)
|
||||
@@ -21,7 +24,10 @@ class BasicEntity(object):
|
||||
_repr_stack.remove(id(self))
|
||||
|
||||
_recursion_stack = set()
|
||||
|
||||
|
||||
class ComparableEntity(BasicEntity):
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.__class__)
|
||||
|
||||
|
||||
@@ -61,6 +61,7 @@ class skip_if(object):
|
||||
self._fails_on = skip_if(other, reason)
|
||||
return self
|
||||
|
||||
|
||||
class fails_if(skip_if):
|
||||
def __call__(self, fn):
|
||||
@decorator
|
||||
@@ -69,14 +70,17 @@ class fails_if(skip_if):
|
||||
return fn(*args, **kw)
|
||||
return decorate(fn)
|
||||
|
||||
|
||||
def only_if(predicate, reason=None):
|
||||
predicate = _as_predicate(predicate)
|
||||
return skip_if(NotPredicate(predicate), reason)
|
||||
|
||||
|
||||
def succeeds_if(predicate, reason=None):
|
||||
predicate = _as_predicate(predicate)
|
||||
return fails_if(NotPredicate(predicate), reason)
|
||||
|
||||
|
||||
class Predicate(object):
|
||||
@classmethod
|
||||
def as_predicate(cls, predicate):
|
||||
@@ -93,6 +97,7 @@ class Predicate(object):
|
||||
else:
|
||||
assert False, "unknown predicate type: %s" % predicate
|
||||
|
||||
|
||||
class BooleanPredicate(Predicate):
|
||||
def __init__(self, value, description=None):
|
||||
self.value = value
|
||||
@@ -110,6 +115,7 @@ class BooleanPredicate(Predicate):
|
||||
def __str__(self):
|
||||
return self._as_string()
|
||||
|
||||
|
||||
class SpecPredicate(Predicate):
|
||||
def __init__(self, db, op=None, spec=None, description=None):
|
||||
self.db = db
|
||||
@@ -177,6 +183,7 @@ class SpecPredicate(Predicate):
|
||||
def __str__(self):
|
||||
return self._as_string()
|
||||
|
||||
|
||||
class LambdaPredicate(Predicate):
|
||||
def __init__(self, lambda_, description=None, args=None, kw=None):
|
||||
self.lambda_ = lambda_
|
||||
@@ -201,6 +208,7 @@ class LambdaPredicate(Predicate):
|
||||
def __str__(self):
|
||||
return self._as_string()
|
||||
|
||||
|
||||
class NotPredicate(Predicate):
|
||||
def __init__(self, predicate):
|
||||
self.predicate = predicate
|
||||
@@ -211,6 +219,7 @@ class NotPredicate(Predicate):
|
||||
def __str__(self):
|
||||
return self.predicate._as_string(True)
|
||||
|
||||
|
||||
class OrPredicate(Predicate):
|
||||
def __init__(self, predicates, description=None):
|
||||
self.predicates = predicates
|
||||
@@ -256,9 +265,11 @@ class OrPredicate(Predicate):
|
||||
|
||||
_as_predicate = Predicate.as_predicate
|
||||
|
||||
|
||||
def _is_excluded(db, op, spec):
|
||||
return SpecPredicate(db, op, spec)()
|
||||
|
||||
|
||||
def _server_version(engine):
|
||||
"""Return a server_version_info tuple."""
|
||||
|
||||
@@ -268,24 +279,30 @@ def _server_version(engine):
|
||||
conn.close()
|
||||
return version
|
||||
|
||||
|
||||
def db_spec(*dbs):
|
||||
return OrPredicate(
|
||||
Predicate.as_predicate(db) for db in dbs
|
||||
)
|
||||
|
||||
|
||||
def open():
|
||||
return skip_if(BooleanPredicate(False, "mark as execute"))
|
||||
|
||||
|
||||
def closed():
|
||||
return skip_if(BooleanPredicate(True, "marked as skip"))
|
||||
|
||||
|
||||
@decorator
|
||||
def future(fn, *args, **kw):
|
||||
return fails_if(LambdaPredicate(fn, *args, **kw), "Future feature")
|
||||
|
||||
|
||||
def fails_on(db, reason=None):
|
||||
return fails_if(SpecPredicate(db), reason)
|
||||
|
||||
|
||||
def fails_on_everything_except(*dbs):
|
||||
return succeeds_if(
|
||||
OrPredicate([
|
||||
@@ -293,9 +310,11 @@ def fails_on_everything_except(*dbs):
|
||||
])
|
||||
)
|
||||
|
||||
|
||||
def skip(db, reason=None):
|
||||
return skip_if(SpecPredicate(db), reason)
|
||||
|
||||
|
||||
def only_on(dbs, reason=None):
|
||||
return only_if(
|
||||
OrPredicate([SpecPredicate(db) for db in util.to_list(dbs)])
|
||||
|
||||
@@ -7,6 +7,7 @@ import sys
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta
|
||||
|
||||
|
||||
class TestBase(object):
|
||||
# A sequence of database names to always run, regardless of the
|
||||
# constraints below.
|
||||
@@ -29,6 +30,7 @@ class TestBase(object):
|
||||
def assert_(self, val, msg=None):
|
||||
assert val, msg
|
||||
|
||||
|
||||
class TablesTest(TestBase):
|
||||
|
||||
# 'once', None
|
||||
@@ -208,9 +210,11 @@ class _ORMTest(object):
|
||||
sa.orm.session.Session.close_all()
|
||||
sa.orm.clear_mappers()
|
||||
|
||||
|
||||
class ORMTest(_ORMTest, TestBase):
|
||||
pass
|
||||
|
||||
|
||||
class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
|
||||
# 'once', 'each', None
|
||||
run_setup_classes = 'once'
|
||||
@@ -252,7 +256,6 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
|
||||
cls.classes.clear()
|
||||
_ORMTest.teardown_class()
|
||||
|
||||
|
||||
@classmethod
|
||||
def _setup_once_classes(cls):
|
||||
if cls.run_setup_classes == 'once':
|
||||
@@ -275,18 +278,21 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
|
||||
|
||||
"""
|
||||
cls_registry = cls.classes
|
||||
|
||||
class FindFixture(type):
|
||||
def __init__(cls, classname, bases, dict_):
|
||||
cls_registry[classname] = cls
|
||||
return type.__init__(cls, classname, bases, dict_)
|
||||
|
||||
|
||||
class _Base(object):
|
||||
__metaclass__ = FindFixture
|
||||
|
||||
class Basic(BasicEntity, _Base):
|
||||
pass
|
||||
|
||||
class Comparable(ComparableEntity, _Base):
|
||||
pass
|
||||
|
||||
cls.Basic = Basic
|
||||
cls.Comparable = Comparable
|
||||
fn()
|
||||
@@ -306,6 +312,7 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
|
||||
def setup_mappers(cls):
|
||||
pass
|
||||
|
||||
|
||||
class DeclarativeMappedTest(MappedTest):
|
||||
run_setup_classes = 'once'
|
||||
run_setup_mappers = 'once'
|
||||
@@ -317,17 +324,21 @@ class DeclarativeMappedTest(MappedTest):
|
||||
@classmethod
|
||||
def _with_register_classes(cls, fn):
|
||||
cls_registry = cls.classes
|
||||
|
||||
class FindFixtureDeclarative(DeclarativeMeta):
|
||||
def __init__(cls, classname, bases, dict_):
|
||||
cls_registry[classname] = cls
|
||||
return DeclarativeMeta.__init__(
|
||||
cls, classname, bases, dict_)
|
||||
|
||||
class DeclarativeBasic(object):
|
||||
__table_cls__ = schema.Table
|
||||
|
||||
_DeclBase = declarative_base(metadata=cls.metadata,
|
||||
metaclass=FindFixtureDeclarative,
|
||||
cls=DeclarativeBasic)
|
||||
cls.DeclarativeBasic = _DeclBase
|
||||
fn()
|
||||
|
||||
if cls.metadata.tables:
|
||||
cls.metadata.create_all(config.db)
|
||||
|
||||
Reference in New Issue
Block a user