repaired oracle savepoint implementation

This commit is contained in:
Mike Bayer
2007-08-11 00:03:26 +00:00
parent b852fcbce0
commit 1391efea78
3 changed files with 20 additions and 7 deletions
+5 -5
View File
@@ -745,13 +745,13 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
return text
def visit_savepoint(self, savepoint_stmt):
return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
def visit_rollback_to_savepoint(self, savepoint_stmt):
return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
def visit_release_savepoint(self, savepoint_stmt):
return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
def __str__(self):
return self.string
@@ -1052,8 +1052,8 @@ class ANSIIdentifierPreparer(object):
def format_alias(self, alias, name=None):
return self.__generic_obj_format(alias, name or alias.name)
def format_savepoint(self, savepoint):
return self.__generic_obj_format(savepoint, savepoint)
def format_savepoint(self, savepoint, name=None):
return self.__generic_obj_format(savepoint, name or savepoint.ident)
def format_constraint(self, constraint):
return self.__generic_obj_format(constraint, constraint.name)
+13
View File
@@ -280,12 +280,19 @@ class OracleDialect(ansisql.ANSIDialect):
else:
return "rowid"
def do_release_savepoint(self, connection, name):
# Oracle does not support RELEASE SAVEPOINT
pass
def create_execution_context(self, *args, **kwargs):
return OracleExecutionContext(self, *args, **kwargs)
def compiler(self, statement, bindparams, **kwargs):
return OracleCompiler(self, statement, bindparams, **kwargs)
def preparer(self):
return OracleIdentifierPreparer(self)
def schemagenerator(self, *args, **kwargs):
return OracleSchemaGenerator(self, *args, **kwargs)
@@ -662,4 +669,10 @@ class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
def visit_sequence(self, seq):
return self.connection.execute("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL").scalar()
class OracleIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
def format_savepoint(self, savepoint):
name = re.sub(r'^_+', '', savepoint.ident)
return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name)
dialect = OracleDialect
+2 -2
View File
@@ -173,7 +173,7 @@ class TransactionTest(PersistTest):
)
connection.close()
@testing.supported('postgres', 'mysql')
@testing.supported('postgres', 'mysql', 'oracle')
@testing.exclude('mysql', '<', (5, 0, 3))
def testtwophasetransaction(self):
connection = testbase.db.connect()
@@ -301,7 +301,7 @@ class TLTransactionTest(PersistTest):
tlengine = create_engine(testbase.db.url, strategy='threadlocal')
metadata = MetaData()
users = Table('query_users', metadata,
Column('user_id', INT, primary_key = True),
Column('user_id', INT, Sequence('query_users_id_seq', optional=True), primary_key=True),
Column('user_name', VARCHAR(20)),
test_needs_acid=True,
)