mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-17 22:22:13 -04:00
repaired oracle savepoint implementation
This commit is contained in:
@@ -745,13 +745,13 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
|
||||
return text
|
||||
|
||||
def visit_savepoint(self, savepoint_stmt):
|
||||
return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
|
||||
return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
|
||||
|
||||
def visit_rollback_to_savepoint(self, savepoint_stmt):
|
||||
return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
|
||||
return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
|
||||
|
||||
def visit_release_savepoint(self, savepoint_stmt):
|
||||
return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
|
||||
return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
|
||||
|
||||
def __str__(self):
|
||||
return self.string
|
||||
@@ -1052,8 +1052,8 @@ class ANSIIdentifierPreparer(object):
|
||||
def format_alias(self, alias, name=None):
|
||||
return self.__generic_obj_format(alias, name or alias.name)
|
||||
|
||||
def format_savepoint(self, savepoint):
|
||||
return self.__generic_obj_format(savepoint, savepoint)
|
||||
def format_savepoint(self, savepoint, name=None):
|
||||
return self.__generic_obj_format(savepoint, name or savepoint.ident)
|
||||
|
||||
def format_constraint(self, constraint):
|
||||
return self.__generic_obj_format(constraint, constraint.name)
|
||||
|
||||
@@ -280,12 +280,19 @@ class OracleDialect(ansisql.ANSIDialect):
|
||||
else:
|
||||
return "rowid"
|
||||
|
||||
def do_release_savepoint(self, connection, name):
|
||||
# Oracle does not support RELEASE SAVEPOINT
|
||||
pass
|
||||
|
||||
def create_execution_context(self, *args, **kwargs):
|
||||
return OracleExecutionContext(self, *args, **kwargs)
|
||||
|
||||
def compiler(self, statement, bindparams, **kwargs):
|
||||
return OracleCompiler(self, statement, bindparams, **kwargs)
|
||||
|
||||
def preparer(self):
|
||||
return OracleIdentifierPreparer(self)
|
||||
|
||||
def schemagenerator(self, *args, **kwargs):
|
||||
return OracleSchemaGenerator(self, *args, **kwargs)
|
||||
|
||||
@@ -662,4 +669,10 @@ class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
|
||||
def visit_sequence(self, seq):
|
||||
return self.connection.execute("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL").scalar()
|
||||
|
||||
class OracleIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
|
||||
def format_savepoint(self, savepoint):
|
||||
name = re.sub(r'^_+', '', savepoint.ident)
|
||||
return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name)
|
||||
|
||||
|
||||
dialect = OracleDialect
|
||||
|
||||
@@ -173,7 +173,7 @@ class TransactionTest(PersistTest):
|
||||
)
|
||||
connection.close()
|
||||
|
||||
@testing.supported('postgres', 'mysql')
|
||||
@testing.supported('postgres', 'mysql', 'oracle')
|
||||
@testing.exclude('mysql', '<', (5, 0, 3))
|
||||
def testtwophasetransaction(self):
|
||||
connection = testbase.db.connect()
|
||||
@@ -301,7 +301,7 @@ class TLTransactionTest(PersistTest):
|
||||
tlengine = create_engine(testbase.db.url, strategy='threadlocal')
|
||||
metadata = MetaData()
|
||||
users = Table('query_users', metadata,
|
||||
Column('user_id', INT, primary_key = True),
|
||||
Column('user_id', INT, Sequence('query_users_id_seq', optional=True), primary_key=True),
|
||||
Column('user_name', VARCHAR(20)),
|
||||
test_needs_acid=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user