Files
sqlalchemy/test/testbase.py
T
Mike Bayer 00a20dd075
2005-10-15 20:47:19 +00:00

76 lines
2.8 KiB
Python

import unittest
import StringIO
import sqlalchemy.engine as engine
# Module-wide verbosity flag: the helper classes in this module check it
# before printing captured engine output or result reprs to stdout.
echo = True
class PersistTest(unittest.TestCase):
    """Base test case adding console echo and engine-output capture helpers."""

    def __init__(self, *args, **params):
        # Straight pass-through to the unittest constructor.
        unittest.TestCase.__init__(self, *args, **params)

    def echo(self, text):
        """Print ``text`` when the module-level ``echo`` flag is true."""
        # Inside a method body the bare name ``echo`` resolves to the module
        # global, not to this method.
        if not echo:
            return
        print(text)

    def capture_exec(self, db, callable_):
        """Call ``callable_`` and return the text collected while it ran.

        ``db.logger`` is temporarily replaced by an in-memory buffer and
        ``db.echo`` is forced to True; both are put back to their previous
        values afterwards, whether or not ``callable_`` raises.  The
        collected text is also printed when the module-level ``echo`` flag
        is true.
        """
        saved_echo, saved_logger = db.echo, db.logger
        captured = StringIO.StringIO()
        db.logger = captured
        db.echo = True
        try:
            callable_()
            output = captured.getvalue()
            if echo:
                print(output)
            return output
        finally:
            db.logger = saved_logger
            db.echo = saved_echo
class AssertMixin(PersistTest):
    """Assertions comparing result objects against nested expected specs."""

    def assert_result(self, result, class_, *objects):
        """Assert ``result`` matches the expected ``objects``, one per row.

        The repr of ``result`` is printed first when the module-level
        ``echo`` flag is true.
        """
        if echo:
            print(repr(result))
        self.assert_list(result, class_, objects)

    def assert_list(self, result, class_, list):
        """Assert ``result`` and ``list`` have equal length and match pairwise.

        Each element of ``result`` is checked against the element of ``list``
        at the same position via ``assert_row``.
        """
        same_size = len(result) == len(list)
        self.assert_(
            same_size,
            "result list is not the same size as test list, for class " + class_.__name__)
        for position, expected in enumerate(list):
            self.assert_row(class_, result[position], expected)

    def assert_row(self, class_, rowobj, desc):
        """Assert ``rowobj`` is exactly a ``class_`` and matches ``desc``.

        ``desc`` maps attribute names to expected values.  A tuple value is
        read as ``(class, spec)``: when ``spec`` is a list the attribute is
        checked with ``assert_list``, otherwise it is checked recursively
        with ``assert_row``.  Any other value is compared with ``==``.
        """
        self.assert_(rowobj.__class__ is class_, "item class is not " + repr(class_))
        for attr, expected in desc.iteritems():
            if not isinstance(expected, tuple):
                # Plain attribute: direct equality comparison.
                actual = getattr(rowobj, attr)
                self.assert_(
                    actual == expected,
                    "attribute %s value %s does not match %s" % (attr, actual, expected))
                continue
            nested_class = expected[0]
            nested_spec = expected[1]
            if isinstance(nested_spec, list):
                self.assert_list(getattr(rowobj, attr), nested_class, nested_spec)
            else:
                self.assert_row(nested_class, getattr(rowobj, attr), nested_spec)

    def assert_enginesql(self, db, callable_, result):
        """Assert the output captured while running ``callable_`` equals ``result``."""
        captured = self.capture_exec(db, callable_)
        self.assert_(captured == result, result)
class EngineAssert(object):
    """Wraps an engine so each executed statement can be checked in order.

    On construction the engine's ``execute`` attribute is replaced by this
    object's ``execute`` method; the original callable is kept in
    ``realexec`` and is still invoked for every statement.  Attributes not
    found on this object are looked up on the wrapped engine.
    """

    def __init__(self, engine):
        self.engine = engine
        # Keep the real implementation before installing the interceptor.
        self.realexec = engine.execute
        engine.execute = self.execute
        # Remember the echo setting so execute() can re-apply it each call.
        self.echo = engine.echo
        self.set_assert_list(None, None)

    def __getattr__(self, key):
        # __getattr__ only runs when normal lookup fails.  If ``engine``
        # itself is missing (instance created without __init__), delegating
        # would recurse forever, so fail with a plain AttributeError instead.
        if key == 'engine':
            raise AttributeError(key)
        return getattr(self.engine, key)

    def set_assert_list(self, unittest, list):
        """Install the expected ``(statement, parameters)`` pairs.

        ``unittest`` is the object whose ``assert_`` method reports
        mismatches.  ``list`` holds the expected pairs in execution order, or
        is None to disable checking.  The caller's sequence is copied, never
        modified.
        """
        self.unittest = unittest
        if list is None:
            self.assert_list = None
        else:
            # Private reversed copy so pop() yields items in original order
            # without consuming or reordering the caller's list.
            pending = [item for item in list]
            pending.reverse()
            self.assert_list = pending

    def execute(self, statement, parameters, **kwargs):
        """Check the statement against the next expectation, then run it.

        When an expectation list is installed, a mismatch or a statement
        beyond the end of the list is reported through
        ``self.unittest.assert_``.  Returns whatever the real ``execute``
        returns.
        """
        self.engine.echo = self.echo
        if self.assert_list is not None:
            if not self.assert_list:
                # More statements ran than were expected: report a test
                # failure rather than an IndexError from pop().
                self.unittest.assert_(
                    False,
                    "unexpected statement %r %r: no more statements were expected"
                    % (statement, parameters))
            else:
                (query, params) = self.assert_list.pop()
                self.unittest.assert_(
                    statement == query and params == parameters,
                    "expected %r %r, got %r %r"
                    % (query, params, statement, parameters))
        return self.realexec(statement, parameters, **kwargs)
def runTests(suite):
    """Run ``suite`` with a verbose text runner; nothing is returned."""
    unittest.TextTestRunner(verbosity=2, descriptions=1).run(suite)