Major refactoring of the MSSQL dialect. Thanks zzzeek.

Includes simplifying the IDENTITY handling and the exception handling. Also
includes a cleanup of the connection string handling for pyodbc to favor
the DSN syntax.
This commit is contained in:
Michael Trier
2008-12-22 20:20:55 +00:00
parent 4bb8489073
commit 886ddcd12d
6 changed files with 387 additions and 211 deletions
+4
View File
@@ -196,6 +196,10 @@ CHANGES
new doc section "Custom Comparators".
- mssql
- Changes to the connection string parameters favor DSN as the
default specification for pyodbc. See the mssql.py docstring
for detailed usage instructions.
- Added experimental support of savepoints. It
currently does not work fully with sessions.
+297 -195
View File
@@ -1,48 +1,192 @@
# mssql.py
"""MSSQL backend, thru either pymssq, adodbapi or pyodbc interfaces.
"""Support for the Microsoft SQL Server database.
* ``IDENTITY`` columns are supported by using SA ``schema.Sequence()``
objects. In other words::
Driver
------
The MSSQL dialect will work with three different available drivers:
* *pymssql* - http://pymssql.sourceforge.net/
* *pyodbc* - http://pyodbc.sourceforge.net/. This is the recommeded
driver.
* *adodbapi* - http://adodbapi.sourceforge.net/
Drivers are loaded in the order listed above based on availability.
Currently the pyodbc driver offers the greatest level of
compatibility.
Connecting
----------
Connecting with create_engine() uses the standard URL approach of
``mssql://user:pass@host/dbname[?key=value&key=value...]``.
If the database name is present, the tokens are converted to a
connection string with the specified values. If the database is not
present, then the host token is taken directly as the DSN name.
Examples of pyodbc connection string URLs:
* *mssql://mydsn* - connects using the specified DSN named ``mydsn``.
The connection string that is created will appear like::
dsn=mydsn;TrustedConnection=Yes
* *mssql://user:pass@mydsn* - connects using the DSN named
``mydsn`` passing in the ``UID`` and ``PWD`` information. The
connection string that is created will appear like::
dsn=mydsn;UID=user;PWD=pass
* *mssql://user:pass@mydsn/?LANGUAGE=us_english* - connects
using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD``
information, plus the additional connection configuration option
``LANGUAGE``. The connection string that is created will appear
like::
dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english
* *mssql://user:pass@host/db* - connects using a connection string
dynamically created that would appear like::
DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass
* *mssql://user:pass@host:123/db* - connects using a connection
string that is dynamically created, which also includes the port
information using the comma syntax. If your connection string
requires the port information to be passed as a ``port`` keyword
see the next example. This will create the following connection
string::
DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass
* *mssql://user:pass@host/db?port=123* - connects using a connection
string that is dynamically created that includes the port
information as a separate ``port`` keyword. This will create the
following connection string::
DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass;port=123
If you require a connection string that is outside the options
presented above, use the ``odbc_connect`` keyword to pass in a
urlencoded connection string. What gets passed in will be urldecoded
and passed directly.
For example::
mssql:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb
would create the following connection string::
dsn=mydsn;Database=db
Encoding your connection string can be easily accomplished through
the python shell. For example::
>>> import urllib
>>> urllib.quote_plus('dsn=mydsn;Database=db')
'dsn%3Dmydsn%3BDatabase%3Ddb'
Additional arguments which may be specified either as query string
arguments on the URL, or as keyword argument to
:func:`~sqlalchemy.create_engine()` are:
* *auto_identity_insert* - enables support for IDENTITY inserts by
automatically turning IDENTITY INSERT ON and OFF as required.
Defaults to ``True`.
* *query_timeout* - allows you to override the default query timeout.
Defaults to ``None``. This is only supported on pymssql.
* *text_as_varchar* - if enabled this will treat all TEXT column
types as their equivalent VARCHAR(max) type. This is often used if
you need to compare a VARCHAR to a TEXT field, which is not
supported directly on MSSQL. Defaults to ``False``.
* *use_scope_identity* - allows you to specify that SCOPE_IDENTITY
should be used in place of the non-scoped version @@IDENTITY.
Defaults to ``False``. On pymssql this defaults to ``True``, and on
pyodbc this defaults to ``True`` if the version of pyodbc being
used supports it.
* *has_window_funcs* - indicates whether or not window functions
(LIMIT and OFFSET) are supported on the version of MSSQL being
used. If you're running MSSQL 2005 or later turn this on to get
OFFSET support. Defaults to ``False``.
* *max_identifier_length* - allows you to se the maximum length of
identfiers supported by the database. Defaults to 128. For pymssql
the default is 30.
* *schema_name* - use to set the schema name. Defaults to ``dbo``.
Auto Increment Behavior
-----------------------
``IDENTITY`` columns are supported by using SQLAlchemy
``schema.Sequence()`` objects. In other words::
Table('test', mss_engine,
Column('id', Integer, Sequence('blah',100,10), primary_key=True),
Column('id', Integer,
Sequence('blah',100,10), primary_key=True),
Column('name', String(20))
).create()
would yield::
would yield::
CREATE TABLE test (
id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY,
name VARCHAR(20) NULL,
)
Note that the start & increment values for sequences are optional
and will default to 1,1.
Note that the ``start`` and ``increment`` values for sequences are
optional and will default to 1,1.
* Support for ``SET IDENTITY_INSERT ON`` mode (automagic on / off for
``INSERT`` s)
* Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on ``INSERT``
* Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on
``INSERT``
* ``select._limit`` implemented as ``SELECT TOP n``
LIMIT/OFFSET Support
--------------------
* Experimental implemention of LIMIT / OFFSET with row_number()
MSSQL has no support for the LIMIT or OFFSET keysowrds. LIMIT is
supported directly through the ``TOP`` Transact SQL keyword::
* Support for three levels of column nullability provided. The default
nullability allows nulls::
select.limit
will yield::
SELECT TOP n
If the ``has_window_funcs`` flag is set then LIMIT with OFFSET
support is available through the ``ROW_NUMBER OVER`` construct. This
construct requires an ``ORDER BY`` to be specified as well and is
only available on MSSQL 2005 and later.
Nullability
-----------
MSSQL has support for three levels of column nullability. The default
nullability allows nulls and is explicit in the CREATE TABLE
construct::
name VARCHAR(20) NULL
If ``nullable=None`` is specified then no specification is made. In other
words the database's configured default is used. This will render::
If ``nullable=None`` is specified then no specification is made. In
other words the database's configured default is used. This will
render::
name VARCHAR(20)
If ``nullable`` is True or False then the column will be ``NULL` or
``NOT NULL`` respectively.
If ``nullable`` is ``True`` or ``False`` then the column will be
``NULL` or ``NOT NULL`` respectively.
Known issues / TODO:
Known Issues
------------
* No support for more than one ``IDENTITY`` column per table
@@ -50,7 +194,7 @@ Known issues / TODO:
does **not** work around
"""
import datetime, operator, re, sys
import datetime, operator, re, sys, urllib
from sqlalchemy import sql, schema, exc, util
from sqlalchemy.sql import compiler, expression, operators as sqlops, functions as sql_functions
@@ -299,77 +443,92 @@ class MSVariant(sqltypes.TypeEngine):
def get_col_spec(self):
return "SQL_VARIANT"
class MSSQLExecutionContext(default.DefaultExecutionContext):
def __init__(self, *args, **kwargs):
self.IINSERT = self.HASIDENT = False
super(MSSQLExecutionContext, self).__init__(*args, **kwargs)
def _has_implicit_sequence(self, column):
if column.primary_key and column.autoincrement:
if isinstance(column.type, sqltypes.Integer) and not column.foreign_keys:
if column.default is None or (isinstance(column.default, schema.Sequence) and \
column.default.optional):
return True
return False
def _has_implicit_sequence(column):
return column.primary_key and \
column.autoincrement and \
isinstance(column.type, sqltypes.Integer) and \
not column.foreign_keys and \
(
column.default is None or
(
isinstance(column.default, schema.Sequence) and
column.default.optional)
)
def _table_sequence_column(tbl):
if not hasattr(tbl, '_ms_has_sequence'):
tbl._ms_has_sequence = None
for column in tbl.c:
if getattr(column, 'sequence', False) or _has_implicit_sequence(column):
tbl._ms_has_sequence = column
break
return tbl._ms_has_sequence
class MSSQLExecutionContext(default.DefaultExecutionContext):
IINSERT = False
HASIDENT = False
def pre_exec(self):
"""MS-SQL has a special mode for inserting non-NULL values
into IDENTITY columns.
"""Activate IDENTITY_INSERT if needed."""
Activate it if the feature is turned on and needed.
"""
if self.compiled.isinsert:
tbl = self.compiled.statement.table
if not hasattr(tbl, 'has_sequence'):
tbl.has_sequence = None
for column in tbl.c:
if getattr(column, 'sequence', False) or self._has_implicit_sequence(column):
tbl.has_sequence = column
break
self.HASIDENT = bool(tbl.has_sequence)
seq_column = _table_sequence_column(tbl)
self.HASIDENT = bool(seq_column)
if self.dialect.auto_identity_insert and self.HASIDENT:
if isinstance(self.compiled_parameters, list):
self.IINSERT = tbl.has_sequence.key in self.compiled_parameters[0]
else:
self.IINSERT = tbl.has_sequence.key in self.compiled_parameters
self.IINSERT = tbl._ms_has_sequence.key in self.compiled_parameters[0]
else:
self.IINSERT = False
if self.IINSERT:
self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
self.cursor.execute("SET IDENTITY_INSERT %s ON" %
self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
super(MSSQLExecutionContext, self).pre_exec()
def handle_dbapi_exception(self, e):
if self.IINSERT:
try:
self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
except:
pass
def post_exec(self):
"""Turn off the INDENTITY_INSERT mode if it's been activated,
and fetch recently inserted IDENTIFY values (works only for
one column).
"""
"""Disable IDENTITY_INSERT if enabled."""
if self.compiled.isinsert and (not self.executemany) and self.HASIDENT and not self.IINSERT:
if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
if self.compiled.isinsert and not self.executemany and self.HASIDENT and not self.IINSERT:
if not self._last_inserted_ids or self._last_inserted_ids[0] is None:
if self.dialect.use_scope_identity:
self.cursor.execute("SELECT scope_identity() AS lastrowid")
else:
self.cursor.execute("SELECT @@identity AS lastrowid")
row = self.cursor.fetchone()
self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:]
super(MSSQLExecutionContext, self).post_exec()
if self.IINSERT:
self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext):
def pre_exec(self):
"""where appropriate, issue "select scope_identity()" in the same statement"""
super(MSSQLExecutionContext_pyodbc, self).pre_exec()
if self.compiled.isinsert and self.HASIDENT and (not self.IINSERT) \
if self.compiled.isinsert and self.HASIDENT and not self.IINSERT \
and len(self.parameters) == 1 and self.dialect.use_scope_identity:
self.statement += "; select scope_identity()"
def post_exec(self):
if self.compiled.isinsert and self.HASIDENT and (not self.IINSERT) and self.dialect.use_scope_identity:
# do nothing - id was fetched in dialect.do_execute()
pass
if self.HASIDENT and not self.IINSERT and self.dialect.use_scope_identity and not self.executemany:
import pyodbc
# Fetch the last inserted id from the manipulated statement
# We may have to skip over a number of result sets with no data (due to triggers, etc.)
while True:
try:
row = self.cursor.fetchone()
break
except pyodbc.Error, e:
self.cursor.nextset()
self._last_inserted_ids = [int(row[0])]
else:
super(MSSQLExecutionContext_pyodbc, self).post_exec()
@@ -377,7 +536,13 @@ class MSSQLDialect(default.DefaultDialect):
name = 'mssql'
supports_default_values = True
supports_empty_insert = False
auto_identity_insert = True
execution_ctx_cls = MSSQLExecutionContext
text_as_varchar = False
use_scope_identity = False
has_window_funcs = False
max_identifier_length = 128
schema_name = "dbo"
colspecs = {
sqltypes.Unicode : MSNVarchar,
@@ -426,23 +591,33 @@ class MSSQLDialect(default.DefaultDialect):
'sql_variant': MSVariant,
}
def __new__(cls, dbapi=None, *args, **kwargs):
if cls != MSSQLDialect:
def __new__(cls, *args, **kwargs):
if cls is not MSSQLDialect:
# this gets called with the dialect specific class
return super(MSSQLDialect, cls).__new__(cls, *args, **kwargs)
dbapi = kwargs.get('dbapi', None)
if dbapi:
dialect = dialect_mapping.get(dbapi.__name__)
return dialect(*args, **kwargs)
return dialect(**kwargs)
else:
return object.__new__(cls, *args, **kwargs)
def __init__(self, auto_identity_insert=True, **params):
super(MSSQLDialect, self).__init__(**params)
self.auto_identity_insert = auto_identity_insert
self.text_as_varchar = False
self.use_scope_identity = False
self.has_window_funcs = False
self.set_default_schema_name("dbo")
def __init__(self,
auto_identity_insert=True, query_timeout=None, text_as_varchar=False,
use_scope_identity=False, has_window_funcs=False, max_identifier_length=None,
schema_name="dbo", **opts):
self.auto_identity_insert = bool(auto_identity_insert)
self.query_timeout = int(query_timeout or 0)
self.schema_name = schema_name
# to-do: the options below should use server version introspection to set themselves on connection
self.text_as_varchar = bool(text_as_varchar)
self.use_scope_identity = bool(use_scope_identity)
self.has_window_funcs = bool(has_window_funcs)
self.max_identifier_length = int(max_identifier_length or 0) or 128
super(MSSQLDialect, self).__init__(**opts)
@classmethod
def dbapi(cls, module_name=None):
if module_name:
try:
@@ -458,8 +633,8 @@ class MSSQLDialect(default.DefaultDialect):
pass
else:
raise ImportError('No DBAPI module detected for MSSQL - please install pyodbc, pymssql, or adodbapi')
dbapi = classmethod(dbapi)
@base.connection_memoize(('mssql', 'server_version_info'))
def server_version_info(self, connection):
"""A tuple of the database server version.
@@ -472,14 +647,11 @@ class MSSQLDialect(default.DefaultDialect):
cached per-Connection.
"""
return connection.dialect._server_version_info(connection.connection)
server_version_info = base.connection_memoize(
('mssql', 'server_version_info'))(server_version_info)
def _server_version_info(self, dbapi_con):
"""Return a tuple of the database's version number."""
raise NotImplementedError()
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
opts.update(url.query)
@@ -493,7 +665,7 @@ class MSSQLDialect(default.DefaultDialect):
self.use_scope_identity = bool(int(opts.pop('use_scope_identity')))
if 'has_window_funcs' in opts:
self.has_window_funcs = bool(int(opts.pop('has_window_funcs')))
return self.make_connect_string(opts)
return self.make_connect_string(opts, url.query)
def type_descriptor(self, typeobj):
newobj = sqltypes.adapt_type(typeobj, self.colspecs)
@@ -505,51 +677,10 @@ class MSSQLDialect(default.DefaultDialect):
def get_default_schema_name(self, connection):
return self.schema_name
def set_default_schema_name(self, schema_name):
self.schema_name = schema_name
def last_inserted_ids(self):
return self.context.last_inserted_ids
def do_execute(self, cursor, statement, params, context=None, **kwargs):
if params == {}:
params = ()
try:
super(MSSQLDialect, self).do_execute(cursor, statement, params, context=context, **kwargs)
finally:
if context.IINSERT:
cursor.execute("SET IDENTITY_INSERT %s OFF" % self.identifier_preparer.format_table(context.compiled.statement.table))
def do_executemany(self, cursor, statement, params, context=None, **kwargs):
try:
super(MSSQLDialect, self).do_executemany(cursor, statement, params, context=context, **kwargs)
finally:
if context.IINSERT:
cursor.execute("SET IDENTITY_INSERT %s OFF" % self.identifier_preparer.format_table(context.compiled.statement.table))
def _execute(self, c, statement, parameters):
try:
if parameters == {}:
parameters = ()
c.execute(statement, parameters)
self.context.rowcount = c.rowcount
c.DBPROP_COMMITPRESERVE = "Y"
except Exception, e:
raise exc.DBAPIError.instance(statement, parameters, e)
def table_names(self, connection, schema):
from sqlalchemy.databases import information_schema as ischema
return ischema.table_names(connection, schema)
def raw_connection(self, connection):
"""Pull the raw pymmsql connection out--sensative to "pool.ConnectionFairy" and pymssql.pymssqlCnx Classes"""
try:
# TODO: probably want to move this to individual dialect subclasses to
# save on the exception throw + simplify
return connection.connection.__dict__['_pymssqlCnx__cnx']
except:
return connection.connection.adoConn
def uppercase_table(self, t):
# convert all names to uppercase -- fixes refs to INFORMATION_SCHEMA for case-senstive DBs, and won't matter for case-insensitive
t.name = t.name.upper()
@@ -559,6 +690,7 @@ class MSSQLDialect(default.DefaultDialect):
c.name = c.name.upper()
return t
def has_table(self, connection, tablename, schema=None):
import sqlalchemy.databases.information_schema as ischema
@@ -645,7 +777,7 @@ class MSSQLDialect(default.DefaultDialect):
ic = table.c[col_name]
ic.autoincrement = True
# setup a psuedo-sequence to represent the identity attribute - we interpret this at table.create() time as the identity attribute
ic.sequence = schema.Sequence(ic.name + '_identity')
ic.sequence = schema.Sequence(ic.name + '_identity', 1, 1)
# MSSQL: only one identity per table allowed
cursor.close()
break
@@ -722,16 +854,13 @@ class MSSQLDialect_pymssql(MSSQLDialect):
supports_sane_rowcount = False
max_identifier_length = 30
@classmethod
def import_dbapi(cls):
import pymssql as module
# pymmsql doesn't have a Binary method. we use string
# TODO: monkeypatching here is less than ideal
module.Binary = lambda st: str(st)
return module
import_dbapi = classmethod(import_dbapi)
ischema_names = MSSQLDialect.ischema_names.copy()
def __init__(self, **params):
super(MSSQLDialect_pymssql, self).__init__(**params)
@@ -739,23 +868,16 @@ class MSSQLDialect_pymssql(MSSQLDialect):
# pymssql understands only ascii
if self.convert_unicode:
util.warn("pymssql does not support unicode")
self.encoding = params.get('encoding', 'ascii')
def do_rollback(self, connection):
# pymssql throws an error on repeated rollbacks. Ignore it.
# TODO: this is normal behavior for most DBs. are we sure we want to ignore it ?
try:
connection.rollback()
except:
pass
def create_connect_args(self, url):
r = super(MSSQLDialect_pymssql, self).create_connect_args(url)
if hasattr(self, 'query_timeout'):
self.dbapi._mssql.set_query_timeout(self.query_timeout)
return r
def make_connect_string(self, keys):
def make_connect_string(self, keys, query):
if keys.get('port'):
# pymssql expects port as host:port, not a separate arg
keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])])
@@ -776,6 +898,7 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
def __init__(self, **params):
super(MSSQLDialect_pyodbc, self).__init__(**params)
# FIXME: scope_identity sniff should look at server version, not the ODBC driver
# whether use_scope_identity will work depends on the version of pyodbc
try:
import pyodbc
@@ -783,10 +906,10 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
except:
pass
@classmethod
def import_dbapi(cls):
import pyodbc as module
return module
import_dbapi = classmethod(import_dbapi)
colspecs = MSSQLDialect.colspecs.copy()
if supports_unicode:
@@ -800,45 +923,41 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
ischema_names['smalldatetime'] = MSDate_pyodbc
ischema_names['datetime'] = MSDateTime_pyodbc
def make_connect_string(self, keys):
def make_connect_string(self, keys, query):
if 'max_identifier_length' in keys:
self.max_identifier_length = int(keys.pop('max_identifier_length'))
if 'dsn' in keys:
connectors = ['dsn=%s' % keys.pop('dsn')]
if 'odbc_connect' in keys:
connectors = [urllib.unquote_plus(keys.pop('odbc_connect'))]
else:
port = ''
if 'port' in keys and (
keys.get('driver', 'SQL Server') == 'SQL Server'):
port = ',%d' % int(keys.pop('port'))
dsn_connection = 'dsn' in keys or ('host' in keys and 'database' not in keys)
if dsn_connection:
connectors= ['dsn=%s' % (keys.pop('host', '') or keys.pop('dsn', ''))]
else:
port = ''
if 'port' in keys and not 'port' in query:
port = ',%d' % int(keys.pop('port'))
connectors = ["DRIVER={%s}" % keys.pop('driver', 'SQL Server'),
'Server=%s%s' % (keys.pop('host', ''), port),
'Database=%s' % keys.pop('database', '') ]
connectors = ["DRIVER={%s}" % keys.pop('driver', 'SQL Server'),
'Server=%s%s' % (keys.pop('host', ''), port),
'Database=%s' % keys.pop('database', '') ]
if 'port' in keys and not port:
connectors.append('Port=%d' % int(keys.pop('port')))
user = keys.pop("user", None)
if user:
connectors.append("UID=%s" % user)
connectors.append("PWD=%s" % keys.pop('password', ''))
else:
connectors.append("TrustedConnection=Yes")
user = keys.pop("user", None)
if user:
connectors.append("UID=%s" % user)
connectors.append("PWD=%s" % keys.pop('password', ''))
else:
connectors.append("TrustedConnection=Yes")
# if set to 'Yes', the ODBC layer will try to automagically convert
# textual data from your database encoding to your client encoding
# This should obviously be set to 'No' if you query a cp1253 encoded
# database from a latin1 client...
if 'odbc_autotranslate' in keys:
connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate"))
# if set to 'Yes', the ODBC layer will try to automagically convert
# textual data from your database encoding to your client encoding
# This should obviously be set to 'No' if you query a cp1253 encoded
# database from a latin1 client...
if 'odbc_autotranslate' in keys:
connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate"))
connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()])
# Allow specification of partial ODBC connect string
if 'odbc_options' in keys:
odbc_options=keys.pop('odbc_options')
if odbc_options[0]=="'" and odbc_options[-1]=="'":
odbc_options=odbc_options[1:-1]
connectors.append(odbc_options)
connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()])
return [[";".join (connectors)], {}]
def is_disconnect(self, e):
@@ -850,23 +969,8 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
return False
def do_execute(self, cursor, statement, parameters, context=None, **kwargs):
super(MSSQLDialect_pyodbc, self).do_execute(cursor, statement, parameters, context=context, **kwargs)
if context and context.HASIDENT and (not context.IINSERT) and context.dialect.use_scope_identity:
import pyodbc
# Fetch the last inserted id from the manipulated statement
# We may have to skip over a number of result sets with no data (due to triggers, etc.)
while True:
try:
row = cursor.fetchone()
break
except pyodbc.Error, e:
cursor.nextset()
context._last_inserted_ids = [int(row[0])]
def _server_version_info(self, dbapi_con):
"""Convert a pyodbc SQL_DBMS_VER string into a tuple."""
version = []
r = re.compile('[.\-]')
for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
@@ -882,10 +986,10 @@ class MSSQLDialect_adodbapi(MSSQLDialect):
supports_unicode = sys.maxunicode == 65535
supports_unicode_statements = True
@classmethod
def import_dbapi(cls):
import adodbapi as module
return module
import_dbapi = classmethod(import_dbapi)
colspecs = MSSQLDialect.colspecs.copy()
colspecs[sqltypes.Unicode] = AdoMSNVarchar
@@ -895,7 +999,7 @@ class MSSQLDialect_adodbapi(MSSQLDialect):
ischema_names['nvarchar'] = AdoMSNVarchar
ischema_names['datetime'] = MSDateTime_adodbapi
def make_connect_string(self, keys):
def make_connect_string(self, keys, query):
connectors = ["Provider=SQLOLEDB"]
if 'port' in keys:
connectors.append ("Data Source=%s, %s" % (keys.get("host"), keys.get("port")))
@@ -963,7 +1067,7 @@ class MSSQLCompiler(compiler.DefaultCompiler):
so tries to wrap it in a subquery with ``row_number()`` criterion.
"""
if self.dialect.has_window_funcs and (not getattr(select, '_mssql_visit', None)) and select._offset:
if self.dialect.has_window_funcs and not getattr(select, '_mssql_visit', None) and select._offset:
# to use ROW_NUMBER(), an ORDER BY is required.
orderby = self.process(select._order_by_clause)
if not orderby:
@@ -1073,21 +1177,25 @@ class MSSQLSchemaGenerator(compiler.SchemaGenerator):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
# install a IDENTITY Sequence if we have an implicit IDENTITY column
if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
column.autoincrement and isinstance(column.type, sqltypes.Integer) and not column.foreign_keys:
if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional):
column.sequence = schema.Sequence(column.name + '_seq')
if column.nullable is not None:
if not column.nullable:
colspec += " NOT NULL"
else:
colspec += " NULL"
if not column.table:
raise exc.InvalidRequestError("mssql requires Table-bound columns in order to generate DDL")
seq_col = _table_sequence_column(column.table)
if hasattr(column, 'sequence'):
column.table.has_sequence = column
colspec += " IDENTITY(%s,%s)" % (column.sequence.start or 1, column.sequence.increment or 1)
# install a IDENTITY Sequence if we have an implicit IDENTITY column
if seq_col is column:
sequence = getattr(column, 'sequence', None)
if sequence:
start, increment = sequence.start or 1, sequence.increment or 1
else:
start, increment = 1, 1
colspec += " IDENTITY(%s,%s)" % (start, increment)
else:
default = self.get_column_default_string(column)
if default is not None:
@@ -1104,11 +1212,6 @@ class MSSQLSchemaDropper(compiler.SchemaDropper):
self.execute()
class MSSQLDefaultRunner(base.DefaultRunner):
# TODO: does ms-sql have standalone sequences ?
# A: No, only auto-incrementing IDENTITY property of a column
pass
class MSSQLIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = compiler.IdentifierPreparer.reserved_words.union(MSSQL_RESERVED_WORDS)
@@ -1116,7 +1219,7 @@ class MSSQLIdentifierPreparer(compiler.IdentifierPreparer):
super(MSSQLIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
def _escape_identifier(self, value):
#TODO: determin MSSQL's escapeing rules
#TODO: determine MSSQL's escaping rules
return value
dialect = MSSQLDialect
@@ -1124,4 +1227,3 @@ dialect.statement_compiler = MSSQLCompiler
dialect.schemagenerator = MSSQLSchemaGenerator
dialect.schemadropper = MSSQLSchemaDropper
dialect.preparer = MSSQLIdentifierPreparer
dialect.defaultrunner = MSSQLDefaultRunner
+20 -11
View File
@@ -350,6 +350,11 @@ class ExecutionContext(object):
raise NotImplementedError()
def handle_dbapi_exception(self, e):
"""Receive a DBAPI exception which occured upon execute, result fetch, etc."""
raise NotImplementedError()
def should_autocommit_text(self, statement):
"""Parse the given textual statement and return True if it refers to a "committable" statement"""
@@ -714,7 +719,7 @@ class Connection(Connectable):
try:
self.engine.dialect.do_begin(self.connection)
except Exception, e:
self._handle_dbapi_exception(e, None, None, None)
self._handle_dbapi_exception(e, None, None, None, None)
raise
def _rollback_impl(self):
@@ -725,7 +730,7 @@ class Connection(Connectable):
self.engine.dialect.do_rollback(self.connection)
self.__transaction = None
except Exception, e:
self._handle_dbapi_exception(e, None, None, None)
self._handle_dbapi_exception(e, None, None, None, None)
raise
else:
self.__transaction = None
@@ -737,7 +742,7 @@ class Connection(Connectable):
self.engine.dialect.do_commit(self.connection)
self.__transaction = None
except Exception, e:
self._handle_dbapi_exception(e, None, None, None)
self._handle_dbapi_exception(e, None, None, None, None)
raise
def _savepoint_impl(self, name=None):
@@ -897,13 +902,17 @@ class Connection(Connectable):
schema_item = None
return ddl(None, schema_item, self, *params, **multiparams)
def _handle_dbapi_exception(self, e, statement, parameters, cursor):
def _handle_dbapi_exception(self, e, statement, parameters, cursor, context):
if getattr(self, '_reentrant_error', False):
raise exc.DBAPIError.instance(None, None, e)
self._reentrant_error = True
try:
if not isinstance(e, self.dialect.dbapi.Error):
return
if context:
context.handle_dbapi_exception(e)
is_disconnect = self.dialect.is_disconnect(e)
if is_disconnect:
self.invalidate(e)
@@ -923,7 +932,7 @@ class Connection(Connectable):
dialect = self.engine.dialect
return dialect.execution_ctx_cls(dialect, connection=self, **kwargs)
except Exception, e:
self._handle_dbapi_exception(e, kwargs.get('statement', None), kwargs.get('parameters', None), None)
self._handle_dbapi_exception(e, kwargs.get('statement', None), kwargs.get('parameters', None), None, None)
raise
def _cursor_execute(self, cursor, statement, parameters, context=None):
@@ -933,7 +942,7 @@ class Connection(Connectable):
try:
self.dialect.do_execute(cursor, statement, parameters, context=context)
except Exception, e:
self._handle_dbapi_exception(e, statement, parameters, cursor)
self._handle_dbapi_exception(e, statement, parameters, cursor, context)
raise
def _cursor_executemany(self, cursor, statement, parameters, context=None):
@@ -943,7 +952,7 @@ class Connection(Connectable):
try:
self.dialect.do_executemany(cursor, statement, parameters, context=context)
except Exception, e:
self._handle_dbapi_exception(e, statement, parameters, cursor)
self._handle_dbapi_exception(e, statement, parameters, cursor, context)
raise
# poor man's multimethod/generic function thingy
@@ -1623,7 +1632,7 @@ class ResultProxy(object):
self.close()
return l
except Exception, e:
self.connection._handle_dbapi_exception(e, None, None, self.cursor)
self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context)
raise
def fetchmany(self, size=None):
@@ -1636,7 +1645,7 @@ class ResultProxy(object):
self.close()
return l
except Exception, e:
self.connection._handle_dbapi_exception(e, None, None, self.cursor)
self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context)
raise
def fetchone(self):
@@ -1649,7 +1658,7 @@ class ResultProxy(object):
self.close()
return None
except Exception, e:
self.connection._handle_dbapi_exception(e, None, None, self.cursor)
self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context)
raise
def scalar(self):
@@ -1657,7 +1666,7 @@ class ResultProxy(object):
try:
row = self._fetchone_impl()
except Exception, e:
self.connection._handle_dbapi_exception(e, None, None, self.cursor)
self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context)
raise
try:
+5 -2
View File
@@ -259,6 +259,9 @@ class DefaultExecutionContext(base.ExecutionContext):
def post_exec(self):
pass
def handle_dbapi_exception(self, e):
pass
def get_result_proxy(self):
return base.ResultProxy(self)
@@ -306,7 +309,7 @@ class DefaultExecutionContext(base.ExecutionContext):
try:
self.cursor.setinputsizes(*inputsizes)
except Exception, e:
self._connection._handle_dbapi_exception(e, None, None, None)
self._connection._handle_dbapi_exception(e, None, None, None, self)
raise
else:
inputsizes = {}
@@ -318,7 +321,7 @@ class DefaultExecutionContext(base.ExecutionContext):
try:
self.cursor.setinputsizes(**inputsizes)
except Exception, e:
self._connection._handle_dbapi_exception(e, None, None, None)
self._connection._handle_dbapi_exception(e, None, None, None, self)
raise
def __process_defaults(self):
+60 -2
View File
@@ -251,7 +251,10 @@ class GenerativeQueryTest(TestBase):
class SchemaTest(TestBase):
def setUp(self):
self.column = Column('test_column', Integer)
t = Table('sometable', MetaData(),
Column('test_column', Integer)
)
self.column = t.c.test_column
def test_that_mssql_default_nullability_emits_null(self):
schemagenerator = \
@@ -399,18 +402,73 @@ class MatchTest(TestBase, AssertsCompiledSQL):
class ParseConnectTest(TestBase, AssertsCompiledSQL):
__only_on__ = 'mssql'
def test_pyodbc_connect_dsn_trusted(self):
u = url.make_url('mssql://mydsn')
dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
self.assertEquals([['dsn=mydsn;TrustedConnection=Yes'], {}], connection)
def test_pyodbc_connect_old_style_dsn_trusted(self):
u = url.make_url('mssql:///?dsn=mydsn')
dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
self.assertEquals([['dsn=mydsn;TrustedConnection=Yes'], {}], connection)
def test_pyodbc_connect_dsn_non_trusted(self):
u = url.make_url('mssql://username:password@mydsn')
dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
self.assertEquals([['dsn=mydsn;UID=username;PWD=password'], {}], connection)
def test_pyodbc_connect_dsn_extra(self):
u = url.make_url('mssql://username:password@mydsn/?LANGUAGE=us_english&foo=bar')
dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
self.assertEquals([['dsn=mydsn;UID=username;PWD=password;LANGUAGE=us_english;foo=bar'], {}], connection)
def test_pyodbc_connect(self):
u = url.make_url('mssql://username:password@hostspec/database')
dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
def test_pyodbc_connect_comma_port(self):
u = url.make_url('mssql://username:password@hostspec:12345/database')
dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
self.assertEquals([['DRIVER={SQL Server};Server=hostspec,12345;Database=database;UID=username;PWD=password'], {}], connection)
def test_pyodbc_connect_config_port(self):
u = url.make_url('mssql://username:password@hostspec/database?port=12345')
dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;port=12345'], {}], connection)
def test_pyodbc_extra_connect(self):
u = url.make_url('mssql://username:password@hostspec/database?LANGUAGE=us_english&foo=bar')
dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;foo=bar;LANGUAGE=us_english'], {}], connection)
def test_pyodbc_odbc_connect(self):
u = url.make_url('mssql:///?odbc_connect=DRIVER%3D%7BSQL+Server%7D%3BServer%3Dhostspec%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword')
dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
def test_pyodbc_odbc_connect_with_dsn(self):
u = url.make_url('mssql:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword')
dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
self.assertEquals([['dsn=mydsn;Database=database;UID=username;PWD=password'], {}], connection)
def test_pyodbc_odbc_connect_ignores_other_values(self):
u = url.make_url('mssql://userdiff:passdiff@localhost/dbdiff?odbc_connect=DRIVER%3D%7BSQL+Server%7D%3BServer%3Dhostspec%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword')
dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
class TypesTest(TestBase):
__only_on__ = 'mssql'
@@ -443,7 +501,7 @@ class TypesTest(TestBase):
numeric_table.insert().execute(numericcol=Decimal('1E-7'))
numeric_table.insert().execute(numericcol=Decimal('1E-8'))
except:
assert False
assert False
if __name__ == "__main__":
testenv.main()
+1 -1
View File
@@ -59,7 +59,7 @@ class QueryTest(TestBase):
result = table.insert().execute(**values)
ret = values.copy()
for col, id in zip(table.primary_key, result.last_inserted_ids()):
ret[col.key] = id