mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-06-06 07:45:30 -04:00
c0b5a0446b
tests extend from either TestBase or ORMTest, using additional mixins for special assertion methods as needed
281 lines
7.1 KiB
Python
281 lines
7.1 KiB
Python
import testenv; testenv.configure_for_tests()
|
|
import sys, weakref
|
|
from sqlalchemy import create_engine, exceptions, select
|
|
from testlib import *
|
|
|
|
|
|
class MockDisconnect(Exception):
    """Raised by MockCursor.execute() to simulate a dropped database connection."""
|
|
|
|
class MockDBAPI(object):
    """Stand-in DBAPI module that can simulate losing every connection.

    Connections handed out by ``connect()`` are tracked in a weak-keyed
    mapping, so ``connections`` only reflects MockConnection objects that
    are still referenced somewhere else.  ``shutdown()`` flips every such
    connection into a state where its cursors raise MockDisconnect.
    """

    # Exposed as the module's ``Error`` class; presumably consulted by the
    # engine when wrapping raised errors (the tests in this file expect a
    # MockDisconnect to surface as ``exceptions.DBAPIError``).
    Error = MockDisconnect

    def __init__(self):
        self.paramstyle = 'named'
        # Weak keys: a connection drops out once nothing else references it.
        self.connections = weakref.WeakKeyDictionary()

    def connect(self, *args, **kwargs):
        """Return a new MockConnection registered with this module; args are ignored."""
        return MockConnection(self)

    def shutdown(self):
        """Make every currently-tracked connection fail on its next execute."""
        for conn in self.connections:
            conn.explode[0] = True
|
|
|
|
class MockConnection(object):
    """Fake DBAPI connection that registers itself with the owning MockDBAPI.

    ``explode`` is a one-element list rather than a plain bool so that every
    cursor created from this connection shares the same list object and sees
    a later ``MockDBAPI.shutdown()`` flip it to True.
    """

    def __init__(self, dbapi):
        dbapi.connections[self] = True
        self.explode = [False]

    def cursor(self):
        """Return a MockCursor sharing this connection's ``explode`` flag."""
        return MockCursor(self)

    # Transaction control and close are deliberate no-ops.
    def rollback(self):
        pass

    def commit(self):
        pass

    def close(self):
        pass
|
|
|
|
class MockCursor(object):
    """Fake DBAPI cursor whose ``execute`` fails once its parent is shut down."""

    def __init__(self, parent):
        # Share (not copy) the parent's flag list so a later shutdown() is seen here.
        self.explode = parent.explode
        self.description = None

    def execute(self, *args, **kwargs):
        """Ignore the statement; raise MockDisconnect if the flag has been set."""
        if not self.explode[0]:
            return
        raise MockDisconnect("Lost the DB connection")

    def close(self):
        pass
|
|
|
|
class MockReconnectTest(TestBase):
    """Connection invalidation / reconnect behavior exercised against MockDBAPI."""

    _INVALID_TRANS = "Can't reconnect until invalid transaction is rolled back"

    def setUp(self):
        global db, dbapi
        dbapi = MockDBAPI()

        # Engine backed by the mock module; MockDBAPI.connect() ignores its
        # arguments, so the URL's host/credentials are never used.
        db = create_engine('postgres://foo:bar@localhost/test', module=dbapi)

        # Monkeypatch the dialect's disconnect checker to recognize the mock error.
        def _is_disconnect(e):
            return isinstance(e, MockDisconnect)
        db.dialect.is_disconnect = _is_disconnect

    def _expect(self, exc_cls, fn, message=None):
        """Call ``fn``; it must raise ``exc_cls`` (whose str() must equal
        ``message`` when one is given).  Fails if nothing is raised."""
        try:
            fn()
        except exc_cls:
            if message is not None:
                assert str(sys.exc_info()[1]) == message
        else:
            assert False

    def test_reconnect(self):
        """test that an 'is_disconnect' condition will invalidate the connection, and additionally
        dispose the previous connection pool and recreate."""
        original_pool_id = id(db.pool)

        conn = db.connect()
        conn.execute(select([1]))  # the connection works

        # A second pooled connection, which must also go away on disconnect.
        second = db.connect()
        second.close()
        assert len(dbapi.connections) == 2

        dbapi.shutdown()
        self._expect(exceptions.DBAPIError, lambda: conn.execute(select([1])))

        # Invalidated, but not closed.
        assert not conn.closed
        assert conn.invalidated

        conn.close()  # closing an invalidated connection must not raise

        # The pool was replaced and all of its connections discarded.
        assert id(db.pool) != original_pool_id
        assert len(dbapi.connections) == 0

        conn = db.connect()
        conn.execute(select([1]))
        conn.close()
        assert len(dbapi.connections) == 1

    def test_invalidate_trans(self):
        conn = db.connect()
        trans = conn.begin()
        dbapi.shutdown()

        self._expect(exceptions.DBAPIError, lambda: conn.execute(select([1])))

        # Invalidated; the transaction is still considered active.
        assert len(dbapi.connections) == 0
        assert not conn.closed
        assert conn.invalidated
        assert trans.is_active

        # Neither executing nor committing is allowed until rollback.
        self._expect(exceptions.InvalidRequestError,
                     lambda: conn.execute(select([1])),
                     self._INVALID_TRANS)
        assert trans.is_active

        self._expect(exceptions.InvalidRequestError,
                     lambda: trans.commit(),
                     self._INVALID_TRANS)
        assert trans.is_active

        trans.rollback()
        assert not trans.is_active

        # After rollback the connection reconnects on demand.
        conn.execute(select([1]))
        assert not conn.invalidated
        assert len(dbapi.connections) == 1

    def test_conn_reusable(self):
        conn = db.connect()
        conn.execute(select([1]))
        assert len(dbapi.connections) == 1

        dbapi.shutdown()

        # The next execute fails with a wrapped DBAPI error.
        self._expect(exceptions.DBAPIError, lambda: conn.execute(select([1])))

        assert not conn.closed
        assert conn.invalidated

        # All connections are gone (the pool was recycled).
        assert len(dbapi.connections) == 0

        # The same Connection object transparently reconnects.
        conn.execute(select([1]))
        assert not conn.invalidated
        assert len(dbapi.connections) == 1
|
|
|
|
|
|
class RealReconnectTest(TestBase):
    """Reconnect behavior against a live database via ``engines.reconnecting_engine()``.

    Each test closes the connections it checks out.  ``engine.dispose()`` in
    tearDown only discards the pool; it does not close checked-out
    connections.  Without the explicit close, cleanup would rely on garbage
    collection.
    """

    def setUp(self):
        global engine
        engine = engines.reconnecting_engine()

    def tearDown(self):
        engine.dispose()

    def _assert_execute_invalidates(self, conn):
        """Execute ``select 1`` on ``conn``, which must fail with a DBAPIError
        flagged ``connection_invalidated``.  Any other DBAPIError is re-raised;
        success is a test failure."""
        try:
            conn.execute(select([1]))
        except exceptions.DBAPIError:
            # sys.exc_info() keeps this valid on every Python version
            # (no `except X, e` / `except X as e` dependency).
            if not sys.exc_info()[1].connection_invalidated:
                raise
        else:
            assert False

    def _assert_invalid_trans_error(self, fn):
        """Call ``fn``; it must raise InvalidRequestError refusing to reconnect
        while an invalidated transaction is still pending."""
        try:
            fn()
        except exceptions.InvalidRequestError:
            assert str(sys.exc_info()[1]) == \
                "Can't reconnect until invalid transaction is rolled back"
        else:
            assert False

    def test_reconnect(self):
        conn = engine.connect()

        self.assertEquals(conn.execute(select([1])).scalar(), 1)
        assert not conn.closed

        # test_shutdown() is provided by the testlib reconnecting engine;
        # presumably it severs the underlying DBAPI connections.
        engine.test_shutdown()
        self._assert_execute_invalidates(conn)

        assert not conn.closed
        assert conn.invalidated

        # The invalidated Connection reconnects on next use.
        self.assertEquals(conn.execute(select([1])).scalar(), 1)
        assert not conn.invalidated

        # one more time
        engine.test_shutdown()
        self._assert_execute_invalidates(conn)
        assert conn.invalidated
        self.assertEquals(conn.execute(select([1])).scalar(), 1)
        assert not conn.invalidated

        conn.close()

    def test_close(self):
        conn = engine.connect()
        self.assertEquals(conn.execute(select([1])).scalar(), 1)
        assert not conn.closed

        engine.test_shutdown()
        self._assert_execute_invalidates(conn)

        # Closing the invalidated connection works, and a fresh one is usable.
        conn.close()
        conn = engine.connect()
        self.assertEquals(conn.execute(select([1])).scalar(), 1)
        conn.close()

    def test_with_transaction(self):
        conn = engine.connect()
        trans = conn.begin()

        self.assertEquals(conn.execute(select([1])).scalar(), 1)
        assert not conn.closed

        engine.test_shutdown()
        self._assert_execute_invalidates(conn)

        assert not conn.closed
        assert conn.invalidated
        assert trans.is_active

        # Neither executing nor committing is allowed until rollback.
        self._assert_invalid_trans_error(lambda: conn.execute(select([1])))
        assert trans.is_active

        self._assert_invalid_trans_error(lambda: trans.commit())
        assert trans.is_active

        trans.rollback()
        assert not trans.is_active

        # After rollback the connection reconnects on demand.
        assert conn.invalidated
        self.assertEquals(conn.execute(select([1])).scalar(), 1)
        assert not conn.invalidated

        conn.close()
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: hand control to the shared test runner.
    testenv.main()
|