A few fixes to the access dialect

This commit is contained in:
Paul Johnston
2007-10-12 23:39:28 +00:00
parent 2585a470c0
commit aafe57ab5d
+22 -31
View File
@@ -7,10 +7,9 @@
import random
from sqlalchemy import sql, schema, types, exceptions, pool
from sqlalchemy.sql import compiler
from sqlalchemy.sql import compiler, expression
from sqlalchemy.engine import default, base
class AcNumeric(types.Numeric):
def result_processor(self, dialect):
return None
@@ -149,11 +148,13 @@ class AccessExecutionContext(default.DefaultExecutionContext):
break
if bool(tbl.has_sequence):
if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
self.cursor.execute("SELECT @@identity AS lastrowid")
row = self.cursor.fetchone()
self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:]
# print "LAST ROW ID", self._last_inserted_ids
# TBD: for some reason _last_inserted_ids doesn't exist here
# (but it does at corresponding point in mssql???)
#if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
self.cursor.execute("SELECT @@identity AS lastrowid")
row = self.cursor.fetchone()
self._last_inserted_ids = [int(row[0])] #+ self._last_inserted_ids[1:]
# print "LAST ROW ID", self._last_inserted_ids
super(AccessExecutionContext, self).post_exec()
@@ -177,7 +178,7 @@ class AccessDialect(default.DefaultDialect):
}
supports_sane_rowcount = False
supports_sane_multi_rowcount = False
def type_descriptor(self, typeobj):
newobj = types.adapt_type(typeobj, self.colspecs)
@@ -217,21 +218,6 @@ class AccessDialect(default.DefaultDialect):
def last_inserted_ids(self):
return self.context.last_inserted_ids
def compiler(self, statement, bindparams, **kwargs):
return AccessCompiler(self, statement, bindparams, **kwargs)
def schemagenerator(self, *args, **kwargs):
return AccessSchemaGenerator(self, *args, **kwargs)
def schemadropper(self, *args, **kwargs):
return AccessSchemaDropper(self, *args, **kwargs)
def defaultrunner(self, connection, **kwargs):
return AccessDefaultRunner(connection, **kwargs)
def preparer(self):
return AccessIdentifierPreparer(self)
def do_execute(self, cursor, statement, params, **kwargs):
if params == {}:
params = ()
@@ -254,7 +240,7 @@ class AccessDialect(default.DefaultDialect):
except Exception, e:
return False
def reflecttable(self, connection, table):
def reflecttable(self, connection, table, include_columns):
# This is defined in the function, as it relies on win32com constants,
# that aren't imported until dbapi method is called
if not hasattr(self, 'ischema_names'):
@@ -364,13 +350,11 @@ class AccessCompiler(compiler.DefaultCompiler):
"""Access uses "mod" instead of "%" """
return binary.operator == '%' and 'mod' or binary.operator
def visit_select(self, select):
"""Label function calls, so they return a name in cursor.description"""
for i,c in enumerate(select._raw_columns):
if isinstance(c, sql._Function):
select._raw_columns[i] = c.label(c.name + "_" + hex(random.randint(0, 65535))[2:])
super(AccessCompiler, self).visit_select(select)
def label_select_column(self, select, column):
if isinstance(column, expression._Function):
return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:])
else:
return super(AccessCompiler, self).label_select_column(select, column)
function_rewrites = {'current_date': 'now',
'current_timestamp': 'now',
@@ -418,9 +402,16 @@ class AccessDefaultRunner(base.DefaultRunner):
pass
class AccessIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = compiler.RESERVED_WORDS.copy()
reserved_words.update(['value', 'text'])
def __init__(self, dialect):
super(AccessIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
dialect = AccessDialect
dialect.poolclass = pool.SingletonThreadPool
dialect.statement_compiler = AccessCompiler
dialect.schemagenerator = AccessSchemaGenerator
dialect.schemadropper = AccessSchemaDropper
dialect.preparer = AccessIdentifierPreparer
dialect.defaultrunner = AccessDefaultRunner