mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-24 09:31:48 -04:00
f94d9efecc
- removed sqlite version warning, all tests pass 100% with py2.5's older sqlite lib - fixed dynamic test for py2.5
383 lines
13 KiB
Python
383 lines
13 KiB
Python
# sqlite.py
|
|
# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com
|
|
#
|
|
# This module is part of SQLAlchemy and is released under
|
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
|
|
|
|
|
import re
|
|
|
|
from sqlalchemy import schema, ansisql, exceptions, pool, PassiveDefault
|
|
import sqlalchemy.engine.default as default
|
|
import sqlalchemy.types as sqltypes
|
|
import datetime,time, warnings
|
|
import sqlalchemy.util as util
|
|
|
|
|
|
class SLNumeric(sqltypes.Numeric):
    """NUMERIC column type for SQLite.

    Emits a bare ``NUMERIC`` when no precision is set, ``NUMERIC(p)`` when
    only a precision is set, and ``NUMERIC(p, s)`` (scale taken from
    ``self.length``) when both are set.
    """

    def get_col_spec(self):
        if self.precision is None:
            return "NUMERIC"
        if self.length is None:
            # "NUMERIC(10, None)" is not valid SQLite DDL; drop the scale.
            return "NUMERIC(%(precision)s)" % {'precision': self.precision}
        return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}
|
|
|
|
class SLInteger(sqltypes.Integer):
    """INTEGER column type for SQLite."""

    def get_col_spec(self):
        spec = "INTEGER"
        return spec
|
|
|
|
class SLSmallInteger(sqltypes.Smallinteger):
    """SMALLINT column type for SQLite."""

    def get_col_spec(self):
        spec = "SMALLINT"
        return spec
|
|
|
|
class DateTimeMixin(object):
    """Shared text conversion for the SQLite date/time types.

    Subclasses set ``__format__`` to a ``strftime``/``strptime`` pattern.
    Values are stored as text in that pattern. When the Python value has a
    ``microsecond`` attribute, it is appended after a '.'.
    """

    def convert_bind_param(self, value, dialect):
        """Render *value* as text using ``self.__format__``; None stays None."""
        if value is None:
            return None
        pattern = self.__format__
        usec = getattr(value, 'microsecond', None)
        if usec is not None:
            # NOTE(review): the microsecond count is not zero-padded, so
            # 5000us is written as ".5000". That round-trips through _cvt,
            # but the text does not sort chronologically and other readers
            # would see it as 0.5s. Changing it alters stored data -- confirm
            # before fixing.
            pattern = pattern + "." + str(usec)
        return value.strftime(pattern)

    def _cvt(self, value, dialect):
        """Parse stored text into a 7-tuple (Y, M, D, h, m, s, microsecond).

        Returns None for None.
        """
        if value is None:
            return None
        usec = 0
        pieces = value.split('.')
        if len(pieces) == 2:
            # "<body>.<digits>" -- a non-numeric suffix still strips the body
            # and falls back to 0 microseconds.
            value = pieces[0]
            try:
                usec = int(pieces[1])
            except ValueError:
                usec = 0
        # With no '.' (or more than one) the whole string is parsed as-is.
        fields = time.strptime(value, self.__format__)[0:6]
        return fields + (usec,)
|
|
|
|
class SLDateTime(DateTimeMixin,sqltypes.DateTime):
    """DateTime stored as TIMESTAMP text "YYYY-MM-DD HH:MM:SS[.usec]"."""

    __format__ = "%Y-%m-%d %H:%M:%S"

    def get_col_spec(self):
        return "TIMESTAMP"

    def convert_result_value(self, value, dialect):
        # _cvt yields None or a full 7-tuple, which maps onto datetime's args.
        parts = self._cvt(value, dialect)
        if parts is None:
            return None
        return datetime.datetime(*parts)
|
|
|
|
class SLDate(DateTimeMixin, sqltypes.Date):
    """Date stored as DATE text "YYYY-MM-DD"."""

    __format__ = "%Y-%m-%d"

    def get_col_spec(self):
        return "DATE"

    def convert_result_value(self, value, dialect):
        parts = self._cvt(value, dialect)
        if parts is None:
            return None
        # Only year, month and day are meaningful for a date.
        year, month, day = parts[0:3]
        return datetime.date(year, month, day)
|
|
|
|
class SLTime(DateTimeMixin, sqltypes.Time):
    """Time stored as TIME text "HH:MM:SS[.usec]"."""

    __format__ = "%H:%M:%S"

    def get_col_spec(self):
        return "TIME"

    def convert_result_value(self, value, dialect):
        parts = self._cvt(value, dialect)
        if parts is None:
            return None
        # Skip the (placeholder) date fields; keep hour, minute, second, usec.
        return datetime.time(*parts[3:7])
|
|
|
|
class SLText(sqltypes.TEXT):
    """Unbounded TEXT column type for SQLite."""

    def get_col_spec(self):
        spec = "TEXT"
        return spec
|
|
|
|
class SLString(sqltypes.String):
    """VARCHAR column type for SQLite.

    Emits ``VARCHAR(n)`` when a length is set and a bare ``VARCHAR`` when
    ``length`` is None. SQLite's type-name grammar only accepts numbers
    inside the parentheses, so ``VARCHAR(None)`` is a syntax error. The None
    case is reachable, e.g. when reflection maps a bare VARCHAR to SLString
    without length arguments.
    """

    def get_col_spec(self):
        if self.length is None:
            return "VARCHAR"
        return "VARCHAR(%(length)s)" % {'length' : self.length}
|
|
|
|
class SLChar(sqltypes.CHAR):
    """CHAR column type for SQLite.

    Emits ``CHAR(n)`` when a length is set and a bare ``CHAR`` when
    ``length`` is None, since ``CHAR(None)`` is not valid SQLite DDL.
    """

    def get_col_spec(self):
        if self.length is None:
            return "CHAR"
        return "CHAR(%(length)s)" % {'length' : self.length}
|
|
|
|
class SLBinary(sqltypes.Binary):
    """Binary data stored in a BLOB column."""

    def get_col_spec(self):
        spec = "BLOB"
        return spec
|
|
|
|
class SLBoolean(sqltypes.Boolean):
    """BOOLEAN column stored as the integers 1/0; None passes through."""

    def get_col_spec(self):
        return "BOOLEAN"

    def convert_bind_param(self, value, dialect):
        # Any truthy value becomes 1, any falsy non-None value becomes 0.
        if value is None:
            return None
        if value:
            return 1
        return 0

    def convert_result_value(self, value, dialect):
        # Stored integers come back as True/False by their truthiness.
        if value is None:
            return None
        return bool(value)
|
|
|
|
# Generic SQLAlchemy type -> SQLite-specific implementation, consumed by
# SQLiteDialect.type_descriptor() through sqltypes.adapt_type().
colspecs = {
    # numeric
    sqltypes.Integer: SLInteger,
    sqltypes.Smallinteger: SLSmallInteger,
    sqltypes.Numeric: SLNumeric,
    sqltypes.Float: SLNumeric,
    # date / time (stored as text)
    sqltypes.DateTime: SLDateTime,
    sqltypes.Date: SLDate,
    sqltypes.Time: SLTime,
    # character data
    sqltypes.String: SLString,
    sqltypes.CHAR: SLChar,
    sqltypes.TEXT: SLText,
    # everything else
    sqltypes.Binary: SLBinary,
    sqltypes.Boolean: SLBoolean,
}
|
|
|
|
# Leading word of a declared column type -> type class used for a reflected
# Column. The keys are the upper-cased first word that reflecttable extracts
# from PRAGMA table_info. 'TIME' and 'BOOLEAN' are listed because SLTime and
# SLBoolean emit exactly those names in DDL; without them, tables created by
# this dialect reflected those columns as NULLTYPE with a warning.
pragma_names = {
    'INTEGER' : SLInteger,
    'INT' : SLInteger,
    'SMALLINT' : SLSmallInteger,
    'VARCHAR' : SLString,
    'CHAR' : SLChar,
    'TEXT' : SLText,
    'NUMERIC' : SLNumeric,
    'FLOAT' : SLNumeric,
    'TIMESTAMP' : SLDateTime,
    'DATETIME' : SLDateTime,
    'DATE' : SLDate,
    'TIME' : SLTime,
    'BOOLEAN' : SLBoolean,
    'BLOB' : SLBinary,
}
|
|
|
|
def descriptor():
    """Describe this dialect: its name, display label, and connect arguments."""
    database_arg = ('database', "Database Filename", None)
    return dict(name='sqlite',
                description='SQLite',
                arguments=[database_arg])
|
|
|
|
class SQLiteExecutionContext(default.DefaultExecutionContext):
    """Execution context for SQLite statements."""

    def post_exec(self):
        # After an INSERT, fill in a missing (None) first entry of
        # _last_inserted_ids -- presumably the primary key -- from the
        # cursor's lastrowid; later entries are kept.
        if not self.compiled.isinsert:
            return
        ids = self._last_inserted_ids
        if len(ids) == 0 or ids[0] is None:
            self._last_inserted_ids = [self.cursor.lastrowid] + ids[1:]

    def is_select(self):
        # SELECT and PRAGMA (any case, leading whitespace ignored) count as selects.
        stmt = self.statement.lstrip()
        return re.match(r'SELECT|PRAGMA', stmt, re.I) is not None
|
|
|
|
class SQLiteDialect(ansisql.ANSIDialect):
|
|
|
|
def __init__(self, **kwargs):
|
|
ansisql.ANSIDialect.__init__(self, default_paramstyle='qmark', **kwargs)
|
|
def vers(num):
|
|
return tuple([int(x) for x in num.split('.')])
|
|
if self.dbapi is not None:
|
|
sqlite_ver = self.dbapi.version_info
|
|
if sqlite_ver < (2,1,'3'):
|
|
warnings.warn(RuntimeWarning("The installed version of pysqlite2 (%s) is out-dated, and will cause errors in some cases. Version 2.1.3 or greater is recommended." % '.'.join([str(subver) for subver in sqlite_ver])))
|
|
self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3"))
|
|
|
|
def dbapi(cls):
|
|
try:
|
|
from pysqlite2 import dbapi2 as sqlite
|
|
except ImportError, e:
|
|
try:
|
|
from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name.
|
|
except ImportError:
|
|
raise e
|
|
return sqlite
|
|
dbapi = classmethod(dbapi)
|
|
|
|
def compiler(self, statement, bindparams, **kwargs):
|
|
return SQLiteCompiler(self, statement, bindparams, **kwargs)
|
|
|
|
def schemagenerator(self, *args, **kwargs):
|
|
return SQLiteSchemaGenerator(self, *args, **kwargs)
|
|
|
|
def schemadropper(self, *args, **kwargs):
|
|
return SQLiteSchemaDropper(self, *args, **kwargs)
|
|
|
|
def supports_alter(self):
|
|
return False
|
|
|
|
def preparer(self):
|
|
return SQLiteIdentifierPreparer(self)
|
|
|
|
def create_connect_args(self, url):
|
|
filename = url.database or ':memory:'
|
|
|
|
opts = url.query.copy()
|
|
util.coerce_kw_type(opts, 'timeout', float)
|
|
util.coerce_kw_type(opts, 'isolation_level', str)
|
|
util.coerce_kw_type(opts, 'detect_types', int)
|
|
util.coerce_kw_type(opts, 'check_same_thread', bool)
|
|
util.coerce_kw_type(opts, 'cached_statements', int)
|
|
|
|
return ([filename], opts)
|
|
|
|
def type_descriptor(self, typeobj):
|
|
return sqltypes.adapt_type(typeobj, colspecs)
|
|
|
|
def create_execution_context(self, **kwargs):
|
|
return SQLiteExecutionContext(self, **kwargs)
|
|
|
|
def supports_unicode_statements(self):
|
|
return True
|
|
|
|
def last_inserted_ids(self):
|
|
return self.context.last_inserted_ids
|
|
|
|
def oid_column_name(self, column):
|
|
return "oid"
|
|
|
|
def table_names(self, connection, schema):
|
|
s = "SELECT name FROM sqlite_master WHERE type='table'"
|
|
return [row[0] for row in connection.execute(s)]
|
|
|
|
def has_table(self, connection, table_name, schema=None):
|
|
cursor = connection.execute("PRAGMA table_info(%s)" %
|
|
self.identifier_preparer.quote_identifier(table_name), {})
|
|
row = cursor.fetchone()
|
|
|
|
# consume remaining rows, to work around: http://www.sqlite.org/cvstrac/tktview?tn=1884
|
|
while cursor.fetchone() is not None:pass
|
|
|
|
return (row is not None)
|
|
|
|
def reflecttable(self, connection, table, include_columns):
|
|
c = connection.execute("PRAGMA table_info(%s)" % self.preparer().format_table(table), {})
|
|
found_table = False
|
|
while True:
|
|
row = c.fetchone()
|
|
if row is None:
|
|
break
|
|
#print "row! " + repr(row)
|
|
found_table = True
|
|
(name, type, nullable, has_default, primary_key) = (row[1], row[2].upper(), not row[3], row[4] is not None, row[5])
|
|
name = re.sub(r'^\"|\"$', '', name)
|
|
if include_columns and name not in include_columns:
|
|
continue
|
|
match = re.match(r'(\w+)(\(.*?\))?', type)
|
|
if match:
|
|
coltype = match.group(1)
|
|
args = match.group(2)
|
|
else:
|
|
coltype = "VARCHAR"
|
|
args = ''
|
|
|
|
#print "coltype: " + repr(coltype) + " args: " + repr(args)
|
|
try:
|
|
coltype = pragma_names[coltype]
|
|
except KeyError:
|
|
warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (coltype, name)))
|
|
coltype = sqltypes.NULLTYPE
|
|
|
|
if args is not None:
|
|
args = re.findall(r'(\d+)', args)
|
|
#print "args! " +repr(args)
|
|
coltype = coltype(*[int(a) for a in args])
|
|
|
|
colargs= []
|
|
if has_default:
|
|
colargs.append(PassiveDefault('?'))
|
|
table.append_column(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs))
|
|
|
|
if not found_table:
|
|
raise exceptions.NoSuchTableError(table.name)
|
|
|
|
c = connection.execute("PRAGMA foreign_key_list(%s)" % self.preparer().format_table(table), {})
|
|
fks = {}
|
|
while True:
|
|
row = c.fetchone()
|
|
if row is None:
|
|
break
|
|
(constraint_name, tablename, localcol, remotecol) = (row[0], row[2], row[3], row[4])
|
|
tablename = re.sub(r'^\"|\"$', '', tablename)
|
|
localcol = re.sub(r'^\"|\"$', '', localcol)
|
|
remotecol = re.sub(r'^\"|\"$', '', remotecol)
|
|
try:
|
|
fk = fks[constraint_name]
|
|
except KeyError:
|
|
fk = ([],[])
|
|
fks[constraint_name] = fk
|
|
|
|
#print "row! " + repr([key for key in row.keys()]), repr(row)
|
|
# look up the table based on the given table's engine, not 'self',
|
|
# since it could be a ProxyEngine
|
|
remotetable = schema.Table(tablename, table.metadata, autoload=True, autoload_with=connection)
|
|
constrained_column = table.c[localcol].name
|
|
refspec = ".".join([tablename, remotecol])
|
|
if constrained_column not in fk[0]:
|
|
fk[0].append(constrained_column)
|
|
if refspec not in fk[1]:
|
|
fk[1].append(refspec)
|
|
for name, value in fks.iteritems():
|
|
table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1]))
|
|
# check for UNIQUE indexes
|
|
c = connection.execute("PRAGMA index_list(%s)" % self.preparer().format_table(table), {})
|
|
unique_indexes = []
|
|
while True:
|
|
row = c.fetchone()
|
|
if row is None:
|
|
break
|
|
if (row[2] == 1):
|
|
unique_indexes.append(row[1])
|
|
# loop thru unique indexes for one that includes the primary key
|
|
for idx in unique_indexes:
|
|
c = connection.execute("PRAGMA index_info(" + idx + ")", {})
|
|
cols = []
|
|
while True:
|
|
row = c.fetchone()
|
|
if row is None:
|
|
break
|
|
cols.append(row[2])
|
|
col = table.columns[row[2]]
|
|
|
|
class SQLiteCompiler(ansisql.ANSICompiler):
    """ANSI compiler adjusted for SQLite's CAST, LIMIT/OFFSET and FOR UPDATE."""

    def visit_cast(self, cast):
        # When the dialect reports no CAST support, render only the inner
        # expression. Inside a SELECT, record cast.type under "CAST" in
        # self.typemap (unless already present).
        if not self.dialect.supports_cast:
            if len(self.select_stack) > 0:
                # not sure if we want to set the typemap here...
                self.typemap.setdefault("CAST", cast.type)
            return self.process(cast.clause)
        return super(SQLiteCompiler, self).visit_cast(cast)

    def limit_clause(self, select):
        # SQLite only allows OFFSET after a LIMIT, so an offset without a
        # limit is written as "LIMIT -1" (a negative limit means unbounded).
        # NOTE(review): with no offset, " OFFSET 0" is appended unconditionally;
        # presumably callers only invoke this when a limit or offset is set.
        limit = select._limit
        offset = select._offset
        text = ""
        if limit is not None:
            text = text + " \n LIMIT " + str(limit)
        if offset is None:
            text = text + " OFFSET 0"
        else:
            if limit is None:
                text = text + " \n LIMIT -1"
            text = text + " OFFSET " + str(offset)
        return text

    def for_update_clause(self, select):
        # SQLite has no "FOR UPDATE" clause; emit nothing.
        return ''
|
|
|
|
class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
    """CREATE TABLE generator for SQLite."""

    def get_column_specification(self, column, **kwargs):
        """Return "<name> <type>[ DEFAULT <expr>][ NOT NULL]" for one column."""
        parts = [self.preparer.format_column(column), " ",
                 column.type.dialect_impl(self.dialect).get_col_spec()]
        default_sql = self.get_column_default_string(column)
        if default_sql is not None:
            parts.extend([" DEFAULT ", default_sql])
        if not column.nullable:
            parts.append(" NOT NULL")
        return "".join(parts)

    # This doesn't seem to be needed, although older versions of SQLite may
    # still not support composite primary keys directly:
    #def visit_primary_key_constraint(self, constraint):
    #    if len(constraint) > 1:
    #        self.append(", \n")
    #        # put all PRIMARY KEYS in a UNIQUE index
    #        self.append("\tUNIQUE (%s)" % string.join([c.name for c in constraint],', '))
    #    else:
    #        super(SQLiteSchemaGenerator, self).visit_primary_key_constraint(constraint)
|
|
|
|
class SQLiteSchemaDropper(ansisql.ANSISchemaDropper):
    """DROP statement generator for SQLite; uses the ANSI behavior unchanged."""
|
|
|
|
class SQLiteIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
    """ANSI identifier quoting with schema names omitted."""

    def __init__(self, dialect):
        parent = super(SQLiteIdentifierPreparer, self)
        parent.__init__(dialect, omit_schema=True)
|
|
|
|
# Connections for this dialect are pooled with SingletonThreadPool
# (presumably one connection per thread -- see the pool module).
SQLiteDialect.poolclass = pool.SingletonThreadPool
# Module-level name exported for the engine to locate this dialect.
dialect = SQLiteDialect
|