mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-20 07:32:05 -04:00
A few fixes to the access dialect
This commit is contained in:
@@ -7,10 +7,9 @@
|
||||
|
||||
import random
|
||||
from sqlalchemy import sql, schema, types, exceptions, pool
|
||||
from sqlalchemy.sql import compiler
|
||||
from sqlalchemy.sql import compiler, expression
|
||||
from sqlalchemy.engine import default, base
|
||||
|
||||
|
||||
class AcNumeric(types.Numeric):
|
||||
def result_processor(self, dialect):
|
||||
return None
|
||||
@@ -149,11 +148,13 @@ class AccessExecutionContext(default.DefaultExecutionContext):
|
||||
break
|
||||
|
||||
if bool(tbl.has_sequence):
|
||||
if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
|
||||
self.cursor.execute("SELECT @@identity AS lastrowid")
|
||||
row = self.cursor.fetchone()
|
||||
self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:]
|
||||
# print "LAST ROW ID", self._last_inserted_ids
|
||||
# TBD: for some reason _last_inserted_ids doesn't exist here
|
||||
# (but it does at corresponding point in mssql???)
|
||||
#if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
|
||||
self.cursor.execute("SELECT @@identity AS lastrowid")
|
||||
row = self.cursor.fetchone()
|
||||
self._last_inserted_ids = [int(row[0])] #+ self._last_inserted_ids[1:]
|
||||
# print "LAST ROW ID", self._last_inserted_ids
|
||||
|
||||
super(AccessExecutionContext, self).post_exec()
|
||||
|
||||
@@ -177,7 +178,7 @@ class AccessDialect(default.DefaultDialect):
|
||||
}
|
||||
|
||||
supports_sane_rowcount = False
|
||||
|
||||
supports_sane_multi_rowcount = False
|
||||
|
||||
def type_descriptor(self, typeobj):
|
||||
newobj = types.adapt_type(typeobj, self.colspecs)
|
||||
@@ -217,21 +218,6 @@ class AccessDialect(default.DefaultDialect):
|
||||
def last_inserted_ids(self):
|
||||
return self.context.last_inserted_ids
|
||||
|
||||
def compiler(self, statement, bindparams, **kwargs):
|
||||
return AccessCompiler(self, statement, bindparams, **kwargs)
|
||||
|
||||
def schemagenerator(self, *args, **kwargs):
|
||||
return AccessSchemaGenerator(self, *args, **kwargs)
|
||||
|
||||
def schemadropper(self, *args, **kwargs):
|
||||
return AccessSchemaDropper(self, *args, **kwargs)
|
||||
|
||||
def defaultrunner(self, connection, **kwargs):
|
||||
return AccessDefaultRunner(connection, **kwargs)
|
||||
|
||||
def preparer(self):
|
||||
return AccessIdentifierPreparer(self)
|
||||
|
||||
def do_execute(self, cursor, statement, params, **kwargs):
|
||||
if params == {}:
|
||||
params = ()
|
||||
@@ -254,7 +240,7 @@ class AccessDialect(default.DefaultDialect):
|
||||
except Exception, e:
|
||||
return False
|
||||
|
||||
def reflecttable(self, connection, table):
|
||||
def reflecttable(self, connection, table, include_columns):
|
||||
# This is defined in the function, as it relies on win32com constants,
|
||||
# that aren't imported until dbapi method is called
|
||||
if not hasattr(self, 'ischema_names'):
|
||||
@@ -364,13 +350,11 @@ class AccessCompiler(compiler.DefaultCompiler):
|
||||
"""Access uses "mod" instead of "%" """
|
||||
return binary.operator == '%' and 'mod' or binary.operator
|
||||
|
||||
def visit_select(self, select):
|
||||
"""Label function calls, so they return a name in cursor.description"""
|
||||
for i,c in enumerate(select._raw_columns):
|
||||
if isinstance(c, sql._Function):
|
||||
select._raw_columns[i] = c.label(c.name + "_" + hex(random.randint(0, 65535))[2:])
|
||||
|
||||
super(AccessCompiler, self).visit_select(select)
|
||||
def label_select_column(self, select, column):
|
||||
if isinstance(column, expression._Function):
|
||||
return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:])
|
||||
else:
|
||||
return super(AccessCompiler, self).label_select_column(select, column)
|
||||
|
||||
function_rewrites = {'current_date': 'now',
|
||||
'current_timestamp': 'now',
|
||||
@@ -418,9 +402,16 @@ class AccessDefaultRunner(base.DefaultRunner):
|
||||
pass
|
||||
|
||||
class AccessIdentifierPreparer(compiler.IdentifierPreparer):
|
||||
reserved_words = compiler.RESERVED_WORDS.copy()
|
||||
reserved_words.update(['value', 'text'])
|
||||
def __init__(self, dialect):
|
||||
super(AccessIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
|
||||
|
||||
|
||||
dialect = AccessDialect
|
||||
dialect.poolclass = pool.SingletonThreadPool
|
||||
dialect.statement_compiler = AccessCompiler
|
||||
dialect.schemagenerator = AccessSchemaGenerator
|
||||
dialect.schemadropper = AccessSchemaDropper
|
||||
dialect.preparer = AccessIdentifierPreparer
|
||||
dialect.defaultrunner = AccessDefaultRunner
|
||||
|
||||
Reference in New Issue
Block a user