Files
sqlalchemy/test/engine/transaction.py
T
Jason Kirtland 5b779d30c3 - Added testbase.Table and testbase.Column, interceptors that can set up
test-run- and dialect-specific options on those objects
  All tests re-pointed to go through the interceptors
- Removed mysql_engine= from table declarations, replaced with a general
  flag indicating storage requirements
- Added ability to choose a global MySQL storage engine for all tests
  --mysql-engine=<whatever>
  If none is specified, tests use the old db-default/InnoDB behavior
- Added ability to append arbitrary table creation params
  --table-option=KEY=VALUE
  For MySQL 3, use this to set mysql_type instead of --mysql-engine
- Removed a couple dead test modules
2007-06-15 22:35:53 +00:00

476 lines
18 KiB
Python

import testbase
import unittest, sys, datetime, random, time, threading
import tables
db = testbase.db
from sqlalchemy import *
from sqlalchemy.orm import *
from testbase import Table, Column
class TransactionTest(testbase.PersistTest):
def setUpAll(self):
global users, metadata
metadata = MetaData()
users = Table('query_users', metadata,
Column('user_id', INT, primary_key = True),
Column('user_name', VARCHAR(20)),
test_needs_acid=True,
)
users.create(testbase.db)
def tearDown(self):
testbase.db.connect().execute(users.delete())
def tearDownAll(self):
users.drop(testbase.db)
def testcommits(self):
connection = testbase.db.connect()
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name='user1')
transaction.commit()
transaction = connection.begin()
connection.execute(users.insert(), user_id=2, user_name='user2')
connection.execute(users.insert(), user_id=3, user_name='user3')
transaction.commit()
transaction = connection.begin()
result = connection.execute("select * from query_users")
assert len(result.fetchall()) == 3
transaction.commit()
def testrollback(self):
"""test a basic rollback"""
connection = testbase.db.connect()
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name='user1')
connection.execute(users.insert(), user_id=2, user_name='user2')
connection.execute(users.insert(), user_id=3, user_name='user3')
transaction.rollback()
result = connection.execute("select * from query_users")
assert len(result.fetchall()) == 0
connection.close()
def testraise(self):
connection = testbase.db.connect()
transaction = connection.begin()
try:
connection.execute(users.insert(), user_id=1, user_name='user1')
connection.execute(users.insert(), user_id=2, user_name='user2')
connection.execute(users.insert(), user_id=1, user_name='user3')
transaction.commit()
assert False
except Exception , e:
print "Exception: ", e
transaction.rollback()
result = connection.execute("select * from query_users")
assert len(result.fetchall()) == 0
connection.close()
def testnestedrollback(self):
connection = testbase.db.connect()
try:
transaction = connection.begin()
try:
connection.execute(users.insert(), user_id=1, user_name='user1')
connection.execute(users.insert(), user_id=2, user_name='user2')
connection.execute(users.insert(), user_id=3, user_name='user3')
trans2 = connection.begin()
try:
connection.execute(users.insert(), user_id=4, user_name='user4')
connection.execute(users.insert(), user_id=5, user_name='user5')
raise Exception("uh oh")
trans2.commit()
except:
trans2.rollback()
raise
transaction.rollback()
except Exception, e:
transaction.rollback()
raise
except Exception, e:
try:
assert str(e) == 'uh oh' # and not "This transaction is inactive"
finally:
connection.close()
def testnesting(self):
connection = testbase.db.connect()
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name='user1')
connection.execute(users.insert(), user_id=2, user_name='user2')
connection.execute(users.insert(), user_id=3, user_name='user3')
trans2 = connection.begin()
connection.execute(users.insert(), user_id=4, user_name='user4')
connection.execute(users.insert(), user_id=5, user_name='user5')
trans2.commit()
transaction.rollback()
self.assert_(connection.scalar("select count(1) from query_users") == 0)
result = connection.execute("select * from query_users")
assert len(result.fetchall()) == 0
connection.close()
class AutoRollbackTest(testbase.PersistTest):
    """Checks that a connection returned to the pool gives up its locks."""

    def setUpAll(self):
        """Create the module-level MetaData that the test table registers with."""
        global metadata
        metadata = MetaData()

    def tearDownAll(self):
        """Drop whatever tables are still registered on the MetaData."""
        metadata.drop_all(testbase.db)

    @testbase.unsupported('sqlite')
    def testrollback_deadlock(self):
        """test that returning connections to the pool clears any object locks."""
        # Both connections are checked out up front, in this order.
        first_conn = testbase.db.connect()
        second_conn = testbase.db.connect()

        deadlock_table = Table(
            'deadlock_users', metadata,
            Column('user_id', INT, primary_key=True),
            Column('user_name', VARCHAR(20)),
            test_needs_acid=True,
        )
        deadlock_table.create(first_conn)

        # Touch the table on the first connection, then return it to the pool.
        first_conn.execute("select * from deadlock_users")
        first_conn.close()

        # If the pool did not roll back on return, Postgres would deadlock
        # here: the pooled first connection would still hold a lock on
        # "deadlock_users". Commenting out the rollback in
        # pool/ConnectionFairy._close() reproduces that.
        deadlock_table.drop(second_conn)
        second_conn.close()
class TLTransactionTest(testbase.PersistTest):
    """Transaction tests against an engine built with strategy='threadlocal'."""

    def setUpAll(self):
        """Build the threadlocal engine and create ``query_users`` on it."""
        global users, metadata, tlengine
        tlengine = create_engine(testbase.db_uri, strategy='threadlocal')
        metadata = MetaData()
        users = Table(
            'query_users', metadata,
            Column('user_id', INT, primary_key=True),
            Column('user_name', VARCHAR(20)),
            test_needs_acid=True,
        )
        users.create(tlengine)

    def tearDown(self):
        """Remove every row after each test."""
        tlengine.execute(users.delete())

    def tearDownAll(self):
        """Drop the table and dispose of the threadlocal engine."""
        users.drop(tlengine)
        tlengine.dispose()

    def _add_users(self, executor, *ids):
        """Insert one ``userN`` row per id, in order, through ``executor``.

        ``executor`` is anything with an ``execute`` method: the threadlocal
        engine itself or a connection obtained from it.
        """
        for n in ids:
            executor.execute(users.insert(), user_id=n, user_name='user%d' % n)

    def testrollback(self):
        """test a basic rollback"""
        tlengine.begin()
        self._add_users(tlengine, 1, 2, 3)
        tlengine.rollback()

        outside = tlengine.connect()
        result = outside.execute("select * from query_users")
        try:
            rows = result.fetchall()
            assert len(rows) == 0
        finally:
            outside.close()

    def testcommit(self):
        """test a basic commit"""
        tlengine.begin()
        self._add_users(tlengine, 1, 2, 3)
        tlengine.commit()

        outside = tlengine.connect()
        result = outside.execute("select * from query_users")
        try:
            rows = result.fetchall()
            assert len(rows) == 3
        finally:
            outside.close()

    def testcommits(self):
        """Successive transactions on one contextual connection each commit."""
        connection = tlengine.contextual_connect()

        transaction = connection.begin()
        self._add_users(connection, 1)
        transaction.commit()

        transaction = connection.begin()
        self._add_users(connection, 2, 3)
        transaction.commit()

        transaction = connection.begin()
        result = connection.execute("select * from query_users")
        rows = result.fetchall()
        assert len(rows) == 3
        transaction.commit()

    def testrollback_off_conn(self):
        # A transaction begun from a contextual connection must make that
        # same connection aware of the transactional context.
        conn = tlengine.contextual_connect()
        trans = conn.begin()
        self._add_users(conn, 1, 2, 3)
        trans.rollback()

        outside = tlengine.connect()
        result = outside.execute("select * from query_users")
        try:
            rows = result.fetchall()
            assert len(rows) == 0
        finally:
            outside.close()

    def testmorerollback_off_conn(self):
        # A contextual connection obtained earlier must automatically take
        # part in a transaction begun on a second contextual connection.
        conn = tlengine.contextual_connect()
        conn2 = tlengine.contextual_connect()
        trans = conn2.begin()
        self._add_users(conn, 1, 2, 3)
        trans.rollback()

        outside = tlengine.connect()
        result = outside.execute("select * from query_users")
        try:
            rows = result.fetchall()
            assert len(rows) == 0
        finally:
            outside.close()

    def testcommit_off_conn(self):
        """Rows committed via a contextual connection's transaction persist."""
        conn = tlengine.contextual_connect()
        trans = conn.begin()
        self._add_users(conn, 1, 2, 3)
        trans.commit()

        outside = tlengine.connect()
        result = outside.execute("select * from query_users")
        try:
            rows = result.fetchall()
            assert len(rows) == 3
        finally:
            outside.close()

    @testbase.unsupported('sqlite')
    def testnesting(self):
        """tests nesting of transactions"""
        outside = tlengine.connect()
        self.assert_(outside.connection is not tlengine.contextual_connect().connection)

        tlengine.begin()
        self._add_users(tlengine, 1, 2, 3)
        tlengine.begin()
        self._add_users(tlengine, 4, 5)
        tlengine.commit()
        tlengine.rollback()

        try:
            self.assert_(outside.scalar("select count(1) from query_users") == 0)
        finally:
            outside.close()

    def testmixednesting(self):
        """tests nesting of transactions off the TLEngine directly inside of
        transactions off the connection from the TLEngine"""
        outside = tlengine.connect()
        self.assert_(outside.connection is not tlengine.contextual_connect().connection)

        conn = tlengine.contextual_connect()
        trans = conn.begin()
        trans2 = conn.begin()
        self._add_users(tlengine, 1, 2, 3)
        tlengine.begin()
        self._add_users(tlengine, 4)
        tlengine.begin()
        self._add_users(tlengine, 5, 6, 7)
        tlengine.commit()
        self._add_users(tlengine, 8)
        tlengine.commit()
        trans2.commit()
        trans.rollback()
        conn.close()

        try:
            self.assert_(outside.scalar("select count(1) from query_users") == 0)
        finally:
            outside.close()

    def testmoremixednesting(self):
        """tests nesting of transactions off the connection from the TLEngine
        inside of transactions off the TLEngine directly."""
        outside = tlengine.connect()
        self.assert_(outside.connection is not tlengine.contextual_connect().connection)

        tlengine.begin()
        connection = tlengine.contextual_connect()
        self._add_users(connection, 1)
        tlengine.begin()
        self._add_users(connection, 2, 3)
        trans = connection.begin()
        self._add_users(connection, 4, 5)
        trans.commit()
        tlengine.commit()
        tlengine.rollback()
        connection.close()

        try:
            self.assert_(outside.scalar("select count(1) from query_users") == 0)
        finally:
            outside.close()

    def testsessionnesting(self):
        """Flush an ORM session bound to the engine inside an engine-level transaction."""
        class User(object):
            pass

        try:
            mapper(User, users)
            sess = create_session(bind_to=tlengine)
            tlengine.begin()
            u = User()
            sess.save(u)
            sess.flush()
            tlengine.commit()
        finally:
            # Always unregister the mapping set up above.
            clear_mappers()

    def testconnections(self):
        """tests that contextual_connect is threadlocal"""
        first = tlengine.contextual_connect()
        second = tlengine.contextual_connect()
        assert first.connection is second.connection
        second.close()
        # Closing one handle must leave the other's underlying connection set.
        assert first.connection.connection is not None
class ForUpdateTest(testbase.PersistTest):
def setUpAll(self):
global counters, metadata
metadata = MetaData()
counters = Table('forupdate_counters', metadata,
Column('counter_id', INT, primary_key = True),
Column('counter_value', INT),
test_needs_acid=True,
)
counters.create(testbase.db)
def tearDown(self):
testbase.db.connect().execute(counters.delete())
def tearDownAll(self):
counters.drop(testbase.db)
def increment(self, count, errors, update_style=True, delay=0.005):
con = db.connect()
sel = counters.select(for_update=update_style,
whereclause=counters.c.counter_id==1)
for i in xrange(count):
trans = con.begin()
try:
existing = con.execute(sel).fetchone()
incr = existing['counter_value'] + 1
time.sleep(delay)
con.execute(counters.update(counters.c.counter_id==1,
values={'counter_value':incr}))
time.sleep(delay)
readback = con.execute(sel).fetchone()
if (readback['counter_value'] != incr):
raise AssertionError("Got %s post-update, expected %s" %
(readback['counter_value'], incr))
trans.commit()
except Exception, e:
trans.rollback()
errors.append(e)
break
con.close()
@testbase.supported('mysql', 'oracle', 'postgres')
def testqueued_update(self):
"""Test SELECT FOR UPDATE with concurrent modifications.
Runs concurrent modifications on a single row in the users table,
with each mutator trying to increment a value stored in user_name.
"""
db = testbase.db
db.execute(counters.insert(), counter_id=1, counter_value=0)
iterations, thread_count = 10, 5
threads, errors = [], []
for i in xrange(thread_count):
thread = threading.Thread(target=self.increment,
args=(iterations,),
kwargs={'errors': errors,
'update_style': True})
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
for e in errors:
sys.stderr.write("Failure: %s\n" % e)
self.assert_(len(errors) == 0)
sel = counters.select(whereclause=counters.c.counter_id==1)
final = db.execute(sel).fetchone()
self.assert_(final['counter_value'] == iterations * thread_count)
def overlap(self, ids, errors, update_style):
sel = counters.select(for_update=update_style,
whereclause=counters.c.counter_id.in_(*ids))
con = db.connect()
trans = con.begin()
try:
rows = con.execute(sel).fetchall()
time.sleep(0.25)
trans.commit()
except Exception, e:
trans.rollback()
errors.append(e)
def _threaded_overlap(self, thread_count, groups, update_style=True, pool=5):
db = testbase.db
for cid in range(pool - 1):
db.execute(counters.insert(), counter_id=cid + 1, counter_value=0)
errors, threads = [], []
for i in xrange(thread_count):
thread = threading.Thread(target=self.overlap,
args=(groups.pop(0), errors, update_style))
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
return errors
@testbase.supported('mysql', 'oracle', 'postgres')
def testqueued_select(self):
"""Simple SELECT FOR UPDATE conflict test"""
errors = self._threaded_overlap(2, [(1,2,3),(3,4,5)])
for e in errors:
sys.stderr.write("Failure: %s\n" % e)
self.assert_(len(errors) == 0)
@testbase.supported('oracle', 'postgres')
def testnowait_select(self):
"""Simple SELECT FOR UPDATE NOWAIT conflict test"""
errors = self._threaded_overlap(2, [(1,2,3),(3,4,5)],
update_style='nowait')
self.assert_(len(errors) != 0)
# When executed as a script, hand control to testbase.main().
if __name__ == "__main__":
    testbase.main()