mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-06-04 23:06:24 -04:00
6f60e76883
to 'scale'. 'length' is deprecated and is still accepted with a warning. [ticket:827] - The 'length' argument to MSInteger, MSBigInteger, MSTinyInteger, MSSmallInteger and MSYear has been renamed to 'display_width'. [ticket:827] - mysql._Numeric now consumes 'unsigned' and 'zerofill' from the given kw, so that the same kw can be passed along to Numeric and allow the 'length' deprecation logic to still take effect - added testlib.engines.all_dialects() to return a dialect for every db module - informix added to sqlalchemy.databases.__all__. Since other "experimental" dbs like access and sybase are there, informix should be as well.
248 lines
8.0 KiB
Python
248 lines
8.0 KiB
Python
import sys, types, weakref
|
|
from collections import deque
|
|
from testlib import config
|
|
from testlib.compat import _function_named
|
|
|
|
class ConnectionKiller(object):
    """Registry of connection proxies that can be rolled back or closed.

    Every proxy passed to ``checkout`` is remembered through a weak
    reference, so tracking a proxy never keeps it alive.  The bulk
    operations act on whichever proxies are still tracked and report
    ``is_valid`` as true.
    """

    def __init__(self):
        # Weak keys: an entry disappears once the proxy is garbage collected.
        self.proxy_refs = weakref.WeakKeyDictionary()

    def checkout(self, dbapi_con, con_record, con_proxy):
        """Start tracking ``con_proxy``; the other arguments are unused."""
        self.proxy_refs[con_proxy] = True

    def _live_proxies(self):
        """Return a list snapshot of the proxies currently tracked.

        Iterating the weak dictionary directly is unsafe: closing a proxy
        (or a garbage collection pass) can drop entries while the loop is
        running, which raises "dictionary changed size during iteration".
        """
        return list(self.proxy_refs.keys())

    def _apply_all(self, methods):
        """Apply ``methods`` to every tracked proxy that is still valid.

        Each entry of ``methods`` is either a callable, which is invoked
        with the proxy, or the name of a zero-argument method looked up on
        the proxy.  This is best effort: an error stops the remaining
        entries for that proxy only and is written to stderr, then
        processing continues with the next proxy.  ``SystemExit`` and
        ``KeyboardInterrupt`` are re-raised.
        """
        for rec in self._live_proxies():
            if rec is not None and rec.is_valid:
                try:
                    for name in methods:
                        if callable(name):
                            name(rec)
                        else:
                            getattr(rec, name)()
                except (SystemExit, KeyboardInterrupt):
                    raise
                except Exception:
                    # fixme
                    # sys.exc_info() is spelled the same on every Python
                    # version, unlike the "except X, e" / "except X as e"
                    # forms.
                    e = sys.exc_info()[1]
                    sys.stderr.write("\n" + str(e) + "\n")

    def rollback_all(self):
        """Call ``rollback()`` on every valid tracked proxy."""
        self._apply_all(('rollback',))

    def close_all(self):
        """Call ``rollback()`` then ``close()`` on every valid tracked proxy."""
        self._apply_all(('rollback', 'close'))

    def assert_all_closed(self):
        """Raise ``AssertionError`` if any tracked proxy is still valid."""
        for rec in self._live_proxies():
            if rec.is_valid:
                # Raised explicitly so the check survives ``python -O``.
                raise AssertionError(
                    "connection proxy is still open: %r" % (rec,))
|
|
|
|
# Module-wide ConnectionKiller shared by the decorators and by testing_engine().
testing_reaper = ConnectionKiller()
|
|
|
|
def assert_conns_closed(fn):
    """Decorator that asserts all connections are closed after fn execution."""

    def decorated(*args, **kw):
        try:
            # Hand the wrapped function's result back to the caller instead
            # of silently discarding it.
            return fn(*args, **kw)
        finally:
            testing_reaper.assert_all_closed()
    return _function_named(decorated, fn.__name__)
|
|
|
|
def rollback_open_connections(fn):
    """Decorator that rolls back all open connections after fn execution."""

    def decorated(*args, **kw):
        try:
            # Hand the wrapped function's result back to the caller instead
            # of silently discarding it.
            return fn(*args, **kw)
        finally:
            # Runs whether fn returned or raised.
            testing_reaper.rollback_all()
    return _function_named(decorated, fn.__name__)
|
|
|
|
def close_open_connections(fn):
    """Decorator that closes all connections after fn execution."""

    def decorated(*args, **kw):
        try:
            # Hand the wrapped function's result back to the caller instead
            # of silently discarding it.
            return fn(*args, **kw)
        finally:
            # Runs whether fn returned or raised.
            testing_reaper.close_all()
    return _function_named(decorated, fn.__name__)
|
|
|
|
def all_dialects():
    """Yield one ``dialect()`` instance per name in ``sqlalchemy.databases.__all__``."""
    import sqlalchemy.databases as d

    for name in d.__all__:
        # __import__ with a dotted path returns the top-level package, so
        # walk down through ``.databases`` to reach the actual module.
        top_package = __import__('sqlalchemy.databases.%s' % name)
        dialect_module = getattr(top_package.databases, name)
        yield dialect_module.dialect()
|
|
|
|
class ReconnectFixture(object):
    """Wrapper around a DB-API module that remembers every connection made.

    Attributes not defined here are looked up on the wrapped module, so the
    fixture can stand in for it; ``shutdown`` closes everything handed out
    so far.
    """

    def __init__(self, dbapi):
        self.dbapi = dbapi
        # Connections returned by connect(), in creation order.
        self.connections = []

    def __getattr__(self, key):
        # Only reached when normal lookup fails: defer to the real module.
        return getattr(self.dbapi, key)

    def connect(self, *args, **kwargs):
        """Open a connection through the wrapped module and record it."""
        new_connection = self.dbapi.connect(*args, **kwargs)
        self.connections.append(new_connection)
        return new_connection

    def shutdown(self):
        """Close every recorded connection, then forget them all."""
        # Walk a copy so the loop is independent of the stored list.
        for open_connection in self.connections[:]:
            open_connection.close()
        self.connections = []
|
|
|
|
def reconnecting_engine(url=None, options=None):
    """Produce a testing engine whose DB-API module is a ReconnectFixture.

    ``url`` defaults to ``config.db_url``.  ``options`` is an optional dict
    of engine options; it is copied, never modified.  The returned engine
    gains a ``test_shutdown`` attribute bound to
    ``engine.dialect.dbapi.shutdown`` (presumably the fixture's
    ``shutdown``, given the ``module`` option set here).
    """
    url = url or config.db_url
    dbapi = config.db.dialect.dbapi
    # Copy so the 'module' entry does not leak into the caller's dict.
    options = dict(options or {})
    options['module'] = ReconnectFixture(dbapi)
    engine = testing_engine(url, options)
    engine.test_shutdown = engine.dialect.dbapi.shutdown
    return engine
|
|
|
|
def testing_engine(url=None, options=None):
    """Produce an engine configured by --options with optional overrides.

    ``url`` defaults to ``config.db_url`` and ``options`` to
    ``config.db_opts``.  ``testing_reaper`` is added to the engine's
    ``listeners``, and the dialect's ``create_execution_context`` is
    wrapped so every context is an ``ExecutionContextWrapper``.
    """

    from sqlalchemy import create_engine
    from testlib.testing import ExecutionContextWrapper

    url = url or config.db_url

    # Work on copies: mutating the caller's dict (or the shared
    # config.db_opts) in place would append another testing_reaper to the
    # same 'listeners' list on every call.
    options = dict(options or config.db_opts)
    listeners = list(options.get('listeners') or [])
    listeners.append(testing_reaper)
    options['listeners'] = listeners

    engine = create_engine(url, **options)

    create_context = engine.dialect.create_execution_context

    def create_exec_context(*args, **kwargs):
        return ExecutionContextWrapper(create_context(*args, **kwargs))
    engine.dialect.create_execution_context = create_exec_context
    return engine
|
|
|
|
def utf8_engine(url=None, options=None):
    """Hook for dialects or drivers that don't handle utf8 by default."""

    from sqlalchemy.engine import url as engine_url

    if config.db.name != 'mysql':
        # Only mysql gets special treatment; pass everything through as-is.
        return testing_engine(url, options)

    dbapi_ver = config.db.dialect.dbapi.version_info
    # Specific 1.2.1 pre-releases rejected alongside anything older than 1.2.1.
    rejected_gammas = ((1, 2, 1, 'gamma', 1), (1, 2, 1, 'gamma', 2),
                       (1, 2, 1, 'gamma', 3), (1, 2, 1, 'gamma', 5))
    too_old = dbapi_ver < (1, 2, 1)
    if too_old or dbapi_ver in rejected_gammas:
        raise RuntimeError('Character set support unavailable with this '
                           'driver version: %s' % repr(dbapi_ver))

    # Force the utf8 connection settings through the URL query string.
    mysql_url = engine_url.make_url(url or config.db_url)
    mysql_url.query['charset'] = 'utf8'
    mysql_url.query['use_unicode'] = '0'
    return testing_engine(str(mysql_url), options)
|
|
|
|
def mock_engine(db=None):
    """Provides a mocking engine based on the current testing.db."""

    from sqlalchemy import create_engine

    if db:
        source = db
    else:
        source = config.db

    # Everything handed to the executor is collected here, in order.
    statements = []

    def record(sql, *a, **kw):
        statements.append(sql)

    engine = create_engine(source.name + '://',
                           strategy='mock', executor=record)
    # Refuse to clobber an attribute the engine already defines.
    assert not hasattr(engine, 'mock')
    engine.mock = statements
    return engine
|
|
|
|
class ReplayableSession(object):
    """A simple record/playback tool.

    This is *not* a mock testing class.  It only records a session for later
    playback and makes no assertions on call consistency whatsoever.  It's
    unlikely to be suitable for anything other than DB-API recording.

    """

    # Buffer markers: "the result was a non-native object" and "the
    # attribute lookup raised AttributeError".
    Callable = object()
    NoAttribute = object()

    # Function-like types are never treated as native values: results of
    # these types are wrapped so that calling them is recorded as well.
    _wrapped_type_names = ('FunctionType', 'BuiltinFunctionType',
                           'MethodType', 'BuiltinMethodType',
                           'LambdaType', 'UnboundMethodType',)

    # Types whose instances are stored in the buffer verbatim.
    Natives = set([getattr(types, t)
                   for t in dir(types) if not t.startswith('_')])
    if sys.version_info[0] >= 3:
        # Python 3 removed types.IntType, types.StringType and friends;
        # the plain value types now live only in the builtins module.
        Natives.update([getattr(sys.modules['builtins'], t)
                        for t in dir(sys.modules['builtins'])
                        if not t.startswith('_')])
    # hasattr() guard: not every interpreter defines every name listed
    # (types.UnboundMethodType does not exist on Python 3).
    Natives.difference_update([getattr(types, t)
                               for t in _wrapped_type_names
                               if hasattr(types, t)])

    def __init__(self):
        # Shared FIFO of recorded results; recorders append, players pop.
        self.buffer = deque()

    def recorder(self, base):
        """Return a Recorder wrapping ``base`` that appends to this session."""
        return self.Recorder(self.buffer, base)

    def player(self):
        """Return a Player that replays this session's buffer."""
        return self.Player(self.buffer)

    class Recorder(object):
        """Proxy that forwards to a real object and logs every result."""

        def __init__(self, buffer, subject):
            self._buffer = buffer
            self._subject = subject

        def __call__(self, *args, **kw):
            # Go through object.__getattribute__ so the proxy's own state
            # is read without passing through the recording logic below.
            subject = object.__getattribute__(self, '_subject')
            buffer = object.__getattribute__(self, '_buffer')

            result = subject(*args, **kw)
            if type(result) in ReplayableSession.Natives:
                buffer.append(result)
                return result
            # Non-native result: log a marker and keep recording through
            # a new proxy around it.
            buffer.append(ReplayableSession.Callable)
            return type(self)(buffer, result)

        def __getattribute__(self, key):
            # Attributes of the proxy itself win and are not recorded.
            try:
                return object.__getattribute__(self, key)
            except AttributeError:
                pass

            subject = object.__getattribute__(self, '_subject')
            buffer = object.__getattribute__(self, '_buffer')
            try:
                result = type(subject).__getattribute__(subject, key)
            except AttributeError:
                # Record the failure so playback raises at the same point.
                buffer.append(ReplayableSession.NoAttribute)
                raise
            else:
                if type(result) in ReplayableSession.Natives:
                    buffer.append(result)
                    return result
                buffer.append(ReplayableSession.Callable)
                return type(self)(buffer, result)

    class Player(object):
        """Stand-in that answers every call or lookup from the buffer."""

        def __init__(self, buffer):
            self._buffer = buffer

        def __call__(self, *args, **kw):
            # Arguments are ignored: playback is purely positional in time.
            buffer = object.__getattribute__(self, '_buffer')
            result = buffer.popleft()
            if result is ReplayableSession.Callable:
                return self
            else:
                return result

        def __getattribute__(self, key):
            # Attributes of the player itself win and consume nothing.
            try:
                return object.__getattribute__(self, key)
            except AttributeError:
                pass
            buffer = object.__getattribute__(self, '_buffer')
            result = buffer.popleft()
            if result is ReplayableSession.Callable:
                return self
            elif result is ReplayableSession.NoAttribute:
                raise AttributeError(key)
            else:
                return result
|