mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-06-05 15:23:48 -04:00
778 lines
30 KiB
Python
778 lines
30 KiB
Python
# oracle.py
|
|
# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com
|
|
#
|
|
# This module is part of SQLAlchemy and is released under
|
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
|
|
|
|
|
import datetime, random, re
|
|
|
|
from sqlalchemy import util, sql, schema, log
|
|
from sqlalchemy.engine import default, base
|
|
from sqlalchemy.sql import compiler, visitors
|
|
from sqlalchemy.sql import operators as sql_operators, functions as sql_functions
|
|
from sqlalchemy import types as sqltypes
|
|
|
|
|
|
class OracleNumeric(sqltypes.Numeric):
|
|
def get_col_spec(self):
|
|
if self.precision is None:
|
|
return "NUMERIC"
|
|
else:
|
|
return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}
|
|
|
|
class OracleInteger(sqltypes.Integer):
|
|
def get_col_spec(self):
|
|
return "INTEGER"
|
|
|
|
class OracleSmallInteger(sqltypes.Smallinteger):
|
|
def get_col_spec(self):
|
|
return "SMALLINT"
|
|
|
|
class OracleDate(sqltypes.Date):
|
|
def get_col_spec(self):
|
|
return "DATE"
|
|
def bind_processor(self, dialect):
|
|
return None
|
|
|
|
def result_processor(self, dialect):
|
|
def process(value):
|
|
if not isinstance(value, datetime.datetime):
|
|
return value
|
|
else:
|
|
return value.date()
|
|
return process
|
|
|
|
class OracleDateTime(sqltypes.DateTime):
|
|
def get_col_spec(self):
|
|
return "DATE"
|
|
|
|
def result_processor(self, dialect):
|
|
def process(value):
|
|
if value is None or isinstance(value, datetime.datetime):
|
|
return value
|
|
else:
|
|
# convert cx_oracle datetime object returned pre-python 2.4
|
|
return datetime.datetime(value.year, value.month,
|
|
value.day,value.hour, value.minute, value.second)
|
|
return process
|
|
|
|
# Note:
|
|
# Oracle DATE == DATETIME
|
|
# Oracle does not allow milliseconds in DATE
|
|
# Oracle does not support TIME columns
|
|
|
|
# only if cx_oracle contains TIMESTAMP
|
|
class OracleTimestamp(sqltypes.TIMESTAMP):
|
|
def get_col_spec(self):
|
|
return "TIMESTAMP"
|
|
|
|
def get_dbapi_type(self, dialect):
|
|
return dialect.TIMESTAMP
|
|
|
|
def result_processor(self, dialect):
|
|
def process(value):
|
|
if value is None or isinstance(value, datetime.datetime):
|
|
return value
|
|
else:
|
|
# convert cx_oracle datetime object returned pre-python 2.4
|
|
return datetime.datetime(value.year, value.month,
|
|
value.day,value.hour, value.minute, value.second)
|
|
return process
|
|
|
|
class OracleString(sqltypes.String):
|
|
def get_col_spec(self):
|
|
return "VARCHAR(%(length)s)" % {'length' : self.length}
|
|
|
|
class OracleText(sqltypes.Text):
|
|
def get_dbapi_type(self, dbapi):
|
|
return dbapi.CLOB
|
|
|
|
def get_col_spec(self):
|
|
return "CLOB"
|
|
|
|
def result_processor(self, dialect):
|
|
super_process = super(OracleText, self).result_processor(dialect)
|
|
lob = dialect.dbapi.LOB
|
|
def process(value):
|
|
if isinstance(value, lob):
|
|
if super_process:
|
|
return super_process(value.read())
|
|
else:
|
|
return value.read()
|
|
else:
|
|
if super_process:
|
|
return super_process(value)
|
|
else:
|
|
return value
|
|
return process
|
|
|
|
|
|
class OracleChar(sqltypes.CHAR):
|
|
def get_col_spec(self):
|
|
return "CHAR(%(length)s)" % {'length' : self.length}
|
|
|
|
class OracleBinary(sqltypes.Binary):
|
|
def get_dbapi_type(self, dbapi):
|
|
return dbapi.BLOB
|
|
|
|
def get_col_spec(self):
|
|
return "BLOB"
|
|
|
|
def bind_processor(self, dialect):
|
|
return None
|
|
|
|
def result_processor(self, dialect):
|
|
lob = dialect.dbapi.LOB
|
|
def process(value):
|
|
if isinstance(value, lob):
|
|
return value.read()
|
|
else:
|
|
return value
|
|
return process
|
|
|
|
class OracleRaw(OracleBinary):
|
|
def get_col_spec(self):
|
|
return "RAW(%(length)s)" % {'length' : self.length}
|
|
|
|
class OracleBoolean(sqltypes.Boolean):
|
|
def get_col_spec(self):
|
|
return "SMALLINT"
|
|
|
|
def result_processor(self, dialect):
|
|
def process(value):
|
|
if value is None:
|
|
return None
|
|
return value and True or False
|
|
return process
|
|
|
|
def bind_processor(self, dialect):
|
|
def process(value):
|
|
if value is True:
|
|
return 1
|
|
elif value is False:
|
|
return 0
|
|
elif value is None:
|
|
return None
|
|
else:
|
|
return value and True or False
|
|
return process
|
|
|
|
colspecs = {
|
|
sqltypes.Integer : OracleInteger,
|
|
sqltypes.Smallinteger : OracleSmallInteger,
|
|
sqltypes.Numeric : OracleNumeric,
|
|
sqltypes.Float : OracleNumeric,
|
|
sqltypes.DateTime : OracleDateTime,
|
|
sqltypes.Date : OracleDate,
|
|
sqltypes.String : OracleString,
|
|
sqltypes.Binary : OracleBinary,
|
|
sqltypes.Boolean : OracleBoolean,
|
|
sqltypes.Text : OracleText,
|
|
sqltypes.TIMESTAMP : OracleTimestamp,
|
|
sqltypes.CHAR: OracleChar,
|
|
}
|
|
|
|
ischema_names = {
|
|
'VARCHAR2' : OracleString,
|
|
'CHAR' : OracleString,
|
|
'DATE' : OracleDateTime,
|
|
'DATETIME' : OracleDateTime,
|
|
'NUMBER' : OracleNumeric,
|
|
'BLOB' : OracleBinary,
|
|
'BFILE' : OracleBinary,
|
|
'CLOB' : OracleText,
|
|
'TIMESTAMP' : OracleTimestamp,
|
|
'RAW' : OracleRaw,
|
|
'FLOAT' : OracleNumeric,
|
|
'DOUBLE PRECISION' : OracleNumeric,
|
|
'LONG' : OracleText,
|
|
}
|
|
|
|
class OracleExecutionContext(default.DefaultExecutionContext):
|
|
def pre_exec(self):
|
|
super(OracleExecutionContext, self).pre_exec()
|
|
if self.dialect.auto_setinputsizes:
|
|
self.set_input_sizes()
|
|
if self.compiled_parameters is not None and len(self.compiled_parameters) == 1:
|
|
for key in self.compiled.binds:
|
|
bindparam = self.compiled.binds[key]
|
|
name = self.compiled.bind_names[bindparam]
|
|
value = self.compiled_parameters[0][name]
|
|
if bindparam.isoutparam:
|
|
dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
|
|
if not hasattr(self, 'out_parameters'):
|
|
self.out_parameters = {}
|
|
self.out_parameters[name] = self.cursor.var(dbtype)
|
|
self.parameters[0][name] = self.out_parameters[name]
|
|
|
|
def create_cursor(self):
|
|
c = self._connection.connection.cursor()
|
|
if self.dialect.arraysize:
|
|
c.cursor.arraysize = self.dialect.arraysize
|
|
return c
|
|
|
|
def get_result_proxy(self):
|
|
if hasattr(self, 'out_parameters'):
|
|
if self.compiled_parameters is not None and len(self.compiled_parameters) == 1:
|
|
for bind, name in self.compiled.bind_names.iteritems():
|
|
if name in self.out_parameters:
|
|
type = bind.type
|
|
self.out_parameters[name] = type.dialect_impl(self.dialect).result_processor(self.dialect)(self.out_parameters[name].getvalue())
|
|
else:
|
|
for k in self.out_parameters:
|
|
self.out_parameters[k] = self.out_parameters[k].getvalue()
|
|
|
|
if self.cursor.description is not None:
|
|
for column in self.cursor.description:
|
|
type_code = column[1]
|
|
if type_code in self.dialect.ORACLE_BINARY_TYPES:
|
|
return base.BufferedColumnResultProxy(self)
|
|
|
|
return base.ResultProxy(self)
|
|
|
|
class OracleDialect(default.DefaultDialect):
|
|
name = 'oracle'
|
|
supports_alter = True
|
|
supports_unicode_statements = False
|
|
max_identifier_length = 30
|
|
supports_sane_rowcount = True
|
|
supports_sane_multi_rowcount = False
|
|
preexecute_pk_sequences = True
|
|
supports_pk_autoincrement = False
|
|
default_paramstyle = 'named'
|
|
|
|
def __init__(self, use_ansi=True, auto_setinputsizes=True, auto_convert_lobs=True, threaded=True, allow_twophase=True, arraysize=50, **kwargs):
|
|
default.DefaultDialect.__init__(self, **kwargs)
|
|
self.use_ansi = use_ansi
|
|
self.threaded = threaded
|
|
self.arraysize = arraysize
|
|
self.allow_twophase = allow_twophase
|
|
self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' )
|
|
self.auto_setinputsizes = auto_setinputsizes
|
|
self.auto_convert_lobs = auto_convert_lobs
|
|
if self.dbapi is None or not self.auto_convert_lobs or not 'CLOB' in self.dbapi.__dict__:
|
|
self.dbapi_type_map = {}
|
|
self.ORACLE_BINARY_TYPES = []
|
|
else:
|
|
# only use this for LOB objects. using it for strings, dates
|
|
# etc. leads to a little too much magic, reflection doesn't know if it should
|
|
# expect encoded strings or unicodes, etc.
|
|
self.dbapi_type_map = {
|
|
self.dbapi.CLOB: OracleText(),
|
|
self.dbapi.BLOB: OracleBinary(),
|
|
self.dbapi.BINARY: OracleRaw(),
|
|
}
|
|
self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB"] if hasattr(self.dbapi, k)]
|
|
|
|
def dbapi(cls):
|
|
import cx_Oracle
|
|
return cx_Oracle
|
|
dbapi = classmethod(dbapi)
|
|
|
|
def create_connect_args(self, url):
|
|
dialect_opts = dict(url.query)
|
|
for opt in ('use_ansi', 'auto_setinputsizes', 'auto_convert_lobs',
|
|
'threaded', 'allow_twophase'):
|
|
if opt in dialect_opts:
|
|
util.coerce_kw_type(dialect_opts, opt, bool)
|
|
setattr(self, opt, dialect_opts[opt])
|
|
|
|
if url.database:
|
|
# if we have a database, then we have a remote host
|
|
port = url.port
|
|
if port:
|
|
port = int(port)
|
|
else:
|
|
port = 1521
|
|
dsn = self.dbapi.makedsn(url.host, port, url.database)
|
|
else:
|
|
# we have a local tnsname
|
|
dsn = url.host
|
|
|
|
opts = dict(
|
|
user=url.username,
|
|
password=url.password,
|
|
dsn=dsn,
|
|
threaded=self.threaded,
|
|
twophase=self.allow_twophase,
|
|
)
|
|
if 'mode' in url.query:
|
|
opts['mode'] = url.query['mode']
|
|
if isinstance(opts['mode'], basestring):
|
|
mode = opts['mode'].upper()
|
|
if mode == 'SYSDBA':
|
|
opts['mode'] = self.dbapi.SYSDBA
|
|
elif mode == 'SYSOPER':
|
|
opts['mode'] = self.dbapi.SYSOPER
|
|
else:
|
|
util.coerce_kw_type(opts, 'mode', int)
|
|
# Can't set 'handle' or 'pool' via URL query args, use connect_args
|
|
|
|
return ([], opts)
|
|
|
|
def is_disconnect(self, e):
|
|
if isinstance(e, self.dbapi.InterfaceError):
|
|
return "not connected" in str(e)
|
|
else:
|
|
return "ORA-03114" in str(e) or "ORA-03113" in str(e)
|
|
|
|
def type_descriptor(self, typeobj):
|
|
return sqltypes.adapt_type(typeobj, colspecs)
|
|
|
|
def oid_column_name(self, column):
|
|
if not isinstance(column.table, (sql.TableClause, sql.Select)):
|
|
return None
|
|
else:
|
|
return "rowid"
|
|
|
|
def create_xid(self):
|
|
"""create a two-phase transaction ID.
|
|
|
|
this id will be passed to do_begin_twophase(), do_rollback_twophase(),
|
|
do_commit_twophase(). its format is unspecified."""
|
|
|
|
id = random.randint(0, 2 ** 128)
|
|
return (0x1234, "%032x" % 9, "%032x" % id)
|
|
|
|
def do_release_savepoint(self, connection, name):
|
|
# Oracle does not support RELEASE SAVEPOINT
|
|
pass
|
|
|
|
def do_begin_twophase(self, connection, xid):
|
|
connection.connection.begin(*xid)
|
|
|
|
def do_prepare_twophase(self, connection, xid):
|
|
connection.connection.prepare()
|
|
|
|
def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
|
|
self.do_rollback(connection.connection)
|
|
|
|
def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
|
|
self.do_commit(connection.connection)
|
|
|
|
def do_recover_twophase(self, connection):
|
|
pass
|
|
|
|
def create_execution_context(self, *args, **kwargs):
|
|
return OracleExecutionContext(self, *args, **kwargs)
|
|
|
|
def has_table(self, connection, table_name, schema=None):
|
|
if not schema:
|
|
schema = self.get_default_schema_name(connection)
|
|
cursor = connection.execute("""select table_name from all_tables where table_name=:name and owner=:schema_name""", {'name':self._denormalize_name(table_name), 'schema_name':self._denormalize_name(schema)})
|
|
return cursor.fetchone() is not None
|
|
|
|
def has_sequence(self, connection, sequence_name, schema=None):
|
|
if not schema:
|
|
schema = self.get_default_schema_name(connection)
|
|
cursor = connection.execute("""select sequence_name from all_sequences where sequence_name=:name and sequence_owner=:schema_name""", {'name':self._denormalize_name(sequence_name), 'schema_name':self._denormalize_name(schema)})
|
|
return cursor.fetchone() is not None
|
|
|
|
def _normalize_name(self, name):
|
|
if name is None:
|
|
return None
|
|
elif name.upper() == name and not self.identifier_preparer._requires_quotes(name.lower().decode(self.encoding)):
|
|
return name.lower().decode(self.encoding)
|
|
else:
|
|
return name.decode(self.encoding)
|
|
|
|
def _denormalize_name(self, name):
|
|
if name is None:
|
|
return None
|
|
elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower()):
|
|
return name.upper().encode(self.encoding)
|
|
else:
|
|
return name.encode(self.encoding)
|
|
|
|
def get_default_schema_name(self, connection):
|
|
return self._normalize_name(connection.execute('SELECT USER FROM DUAL').scalar())
|
|
get_default_schema_name = base.connection_memoize(
|
|
('dialect', 'default_schema_name'))(get_default_schema_name)
|
|
|
|
def table_names(self, connection, schema):
|
|
# note that table_names() isnt loading DBLINKed or synonym'ed tables
|
|
if schema is None:
|
|
s = "select table_name from all_tables where nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX')"
|
|
cursor = connection.execute(s)
|
|
else:
|
|
s = "select table_name from all_tables where nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM','SYSAUX') AND OWNER = :owner"
|
|
cursor = connection.execute(s, {'owner': self._denormalize_name(schema)})
|
|
return [self._normalize_name(row[0]) for row in cursor]
|
|
|
|
def _resolve_synonym(self, connection, desired_owner=None, desired_synonym=None, desired_table=None):
|
|
"""search for a local synonym matching the given desired owner/name.
|
|
|
|
if desired_owner is None, attempts to locate a distinct owner.
|
|
|
|
returns the actual name, owner, dblink name, and synonym name if found.
|
|
"""
|
|
|
|
sql = """select OWNER, TABLE_OWNER, TABLE_NAME, DB_LINK, SYNONYM_NAME
|
|
from ALL_SYNONYMS WHERE """
|
|
|
|
clauses = []
|
|
params = {}
|
|
if desired_synonym:
|
|
clauses.append("SYNONYM_NAME=:synonym_name")
|
|
params['synonym_name'] = desired_synonym
|
|
if desired_owner:
|
|
clauses.append("TABLE_OWNER=:desired_owner")
|
|
params['desired_owner'] = desired_owner
|
|
if desired_table:
|
|
clauses.append("TABLE_NAME=:tname")
|
|
params['tname'] = desired_table
|
|
|
|
sql += " AND ".join(clauses)
|
|
|
|
result = connection.execute(sql, **params)
|
|
if desired_owner:
|
|
row = result.fetchone()
|
|
if row:
|
|
return row['TABLE_NAME'], row['TABLE_OWNER'], row['DB_LINK'], row['SYNONYM_NAME']
|
|
else:
|
|
return None, None, None, None
|
|
else:
|
|
rows = result.fetchall()
|
|
if len(rows) > 1:
|
|
raise AssertionError("There are multiple tables visible to the schema, you must specify owner")
|
|
elif len(rows) == 1:
|
|
row = rows[0]
|
|
return row['TABLE_NAME'], row['TABLE_OWNER'], row['DB_LINK'], row['SYNONYM_NAME']
|
|
else:
|
|
return None, None, None, None
|
|
|
|
def reflecttable(self, connection, table, include_columns):
|
|
preparer = self.identifier_preparer
|
|
|
|
resolve_synonyms = table.kwargs.get('oracle_resolve_synonyms', False)
|
|
|
|
if resolve_synonyms:
|
|
actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(table.schema), desired_synonym=self._denormalize_name(table.name))
|
|
else:
|
|
actual_name, owner, dblink, synonym = None, None, None, None
|
|
|
|
if not actual_name:
|
|
actual_name = self._denormalize_name(table.name)
|
|
if not dblink:
|
|
dblink = ''
|
|
if not owner:
|
|
owner = self._denormalize_name(table.schema or self.get_default_schema_name(connection))
|
|
|
|
c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name and OWNER = :owner" % {'dblink':dblink}, {'table_name':actual_name, 'owner':owner})
|
|
|
|
while True:
|
|
row = c.fetchone()
|
|
if row is None:
|
|
break
|
|
|
|
(colname, coltype, length, precision, scale, nullable, default) = (self._normalize_name(row[0]), row[1], row[2], row[3], row[4], row[5]=='Y', row[6])
|
|
|
|
if include_columns and colname not in include_columns:
|
|
continue
|
|
|
|
# INTEGER if the scale is 0 and precision is null
|
|
# NUMBER if the scale and precision are both null
|
|
# NUMBER(9,2) if the precision is 9 and the scale is 2
|
|
# NUMBER(3) if the precision is 3 and scale is 0
|
|
#length is ignored except for CHAR and VARCHAR2
|
|
if coltype == 'NUMBER' :
|
|
if precision is None and scale is None:
|
|
coltype = OracleNumeric
|
|
elif precision is None and scale == 0 :
|
|
coltype = OracleInteger
|
|
else :
|
|
coltype = OracleNumeric(precision, scale)
|
|
elif coltype=='CHAR' or coltype=='VARCHAR2':
|
|
coltype = ischema_names.get(coltype, OracleString)(length)
|
|
else:
|
|
coltype = re.sub(r'\(\d+\)', '', coltype)
|
|
try:
|
|
coltype = ischema_names[coltype]
|
|
except KeyError:
|
|
util.warn("Did not recognize type '%s' of column '%s'" %
|
|
(coltype, colname))
|
|
coltype = sqltypes.NULLTYPE
|
|
|
|
colargs = []
|
|
if default is not None:
|
|
colargs.append(schema.DefaultClause(sql.text(default)))
|
|
|
|
table.append_column(schema.Column(colname, coltype, nullable=nullable, *colargs))
|
|
|
|
if not table.columns:
|
|
raise AssertionError("Couldn't find any column information for table %s" % actual_name)
|
|
|
|
c = connection.execute("""SELECT
|
|
ac.constraint_name,
|
|
ac.constraint_type,
|
|
loc.column_name AS local_column,
|
|
rem.table_name AS remote_table,
|
|
rem.column_name AS remote_column,
|
|
rem.owner AS remote_owner
|
|
FROM all_constraints%(dblink)s ac,
|
|
all_cons_columns%(dblink)s loc,
|
|
all_cons_columns%(dblink)s rem
|
|
WHERE ac.table_name = :table_name
|
|
AND ac.constraint_type IN ('R','P')
|
|
AND ac.owner = :owner
|
|
AND ac.owner = loc.owner
|
|
AND ac.constraint_name = loc.constraint_name
|
|
AND ac.r_owner = rem.owner(+)
|
|
AND ac.r_constraint_name = rem.constraint_name(+)
|
|
-- order multiple primary keys correctly
|
|
ORDER BY ac.constraint_name, loc.position, rem.position"""
|
|
% {'dblink':dblink}, {'table_name' : actual_name, 'owner' : owner})
|
|
|
|
fks = {}
|
|
while True:
|
|
row = c.fetchone()
|
|
if row is None:
|
|
break
|
|
#print "ROW:" , row
|
|
(cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row[0:2] + tuple([self._normalize_name(x) for x in row[2:]])
|
|
if cons_type == 'P':
|
|
table.primary_key.add(table.c[local_column])
|
|
elif cons_type == 'R':
|
|
try:
|
|
fk = fks[cons_name]
|
|
except KeyError:
|
|
fk = ([], [])
|
|
fks[cons_name] = fk
|
|
if remote_table is None:
|
|
# ticket 363
|
|
util.warn(
|
|
("Got 'None' querying 'table_name' from "
|
|
"all_cons_columns%(dblink)s - does the user have "
|
|
"proper rights to the table?") % {'dblink':dblink})
|
|
continue
|
|
|
|
if resolve_synonyms:
|
|
ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(remote_owner), desired_table=self._denormalize_name(remote_table))
|
|
if ref_synonym:
|
|
remote_table = self._normalize_name(ref_synonym)
|
|
remote_owner = self._normalize_name(ref_remote_owner)
|
|
|
|
if not table.schema and self._denormalize_name(remote_owner) == owner:
|
|
refspec = ".".join([remote_table, remote_column])
|
|
t = schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection, oracle_resolve_synonyms=resolve_synonyms, useexisting=True)
|
|
else:
|
|
refspec = ".".join([x for x in [remote_owner, remote_table, remote_column] if x])
|
|
t = schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection, schema=remote_owner, oracle_resolve_synonyms=resolve_synonyms, useexisting=True)
|
|
|
|
if local_column not in fk[0]:
|
|
fk[0].append(local_column)
|
|
if refspec not in fk[1]:
|
|
fk[1].append(refspec)
|
|
|
|
for name, value in fks.iteritems():
|
|
table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name))
|
|
|
|
|
|
class _OuterJoinColumn(sql.ClauseElement):
|
|
__visit_name__ = 'outer_join_column'
|
|
def __init__(self, column):
|
|
self.column = column
|
|
def _get_from_objects(self, **kwargs):
|
|
return []
|
|
|
|
class OracleCompiler(compiler.DefaultCompiler):
|
|
"""Oracle compiler modifies the lexical structure of Select
|
|
statements to work under non-ANSI configured Oracle databases, if
|
|
the use_ansi flag is False.
|
|
"""
|
|
|
|
operators = compiler.DefaultCompiler.operators.copy()
|
|
operators.update(
|
|
{
|
|
sql_operators.mod : lambda x, y:"mod(%s, %s)" % (x, y),
|
|
sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y)
|
|
}
|
|
)
|
|
|
|
functions = compiler.DefaultCompiler.functions.copy()
|
|
functions.update (
|
|
{
|
|
sql_functions.now : 'CURRENT_TIMESTAMP'
|
|
}
|
|
)
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
super(OracleCompiler, self).__init__(*args, **kwargs)
|
|
self.__wheres = {}
|
|
|
|
def default_from(self):
|
|
"""Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended.
|
|
|
|
The Oracle compiler tacks a "FROM DUAL" to the statement.
|
|
"""
|
|
|
|
return " FROM DUAL"
|
|
|
|
def apply_function_parens(self, func):
|
|
return len(func.clauses) > 0
|
|
|
|
def visit_join(self, join, **kwargs):
|
|
if self.dialect.use_ansi:
|
|
return compiler.DefaultCompiler.visit_join(self, join, **kwargs)
|
|
else:
|
|
return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True)
|
|
|
|
def _get_nonansi_join_whereclause(self, froms):
|
|
clauses = []
|
|
|
|
def visit_join(join):
|
|
if join.isouter:
|
|
def visit_binary(binary):
|
|
if binary.operator == sql_operators.eq:
|
|
if binary.left.table is join.right:
|
|
binary.left = _OuterJoinColumn(binary.left)
|
|
elif binary.right.table is join.right:
|
|
binary.right = _OuterJoinColumn(binary.right)
|
|
clauses.append(visitors.cloned_traverse(join.onclause, {}, {'binary':visit_binary}))
|
|
else:
|
|
clauses.append(join.onclause)
|
|
|
|
for f in froms:
|
|
visitors.traverse(f, {}, {'join':visit_join})
|
|
return sql.and_(*clauses)
|
|
|
|
def visit_outer_join_column(self, vc):
|
|
return self.process(vc.column) + "(+)"
|
|
|
|
def visit_sequence(self, seq):
|
|
return self.dialect.identifier_preparer.format_sequence(seq) + ".nextval"
|
|
|
|
def visit_alias(self, alias, asfrom=False, **kwargs):
|
|
"""Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??"""
|
|
|
|
if asfrom:
|
|
return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + self.preparer.format_alias(alias, self._anonymize(alias.name))
|
|
else:
|
|
return self.process(alias.original, **kwargs)
|
|
|
|
def _TODO_visit_compound_select(self, select):
|
|
"""Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""
|
|
pass
|
|
|
|
def visit_select(self, select, **kwargs):
|
|
"""Look for ``LIMIT`` and OFFSET in a select statement, and if
|
|
so tries to wrap it in a subquery with ``rownum`` criterion.
|
|
"""
|
|
|
|
if not getattr(select, '_oracle_visit', None):
|
|
if not self.dialect.use_ansi:
|
|
if self.stack and 'from' in self.stack[-1]:
|
|
existingfroms = self.stack[-1]['from']
|
|
else:
|
|
existingfroms = None
|
|
|
|
froms = select._get_display_froms(existingfroms)
|
|
whereclause = self._get_nonansi_join_whereclause(froms)
|
|
if whereclause:
|
|
select = select.where(whereclause)
|
|
select._oracle_visit = True
|
|
|
|
if select._limit is not None or select._offset is not None:
|
|
# See http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html
|
|
#
|
|
# Generalized form of an Oracle pagination query:
|
|
# select ... from (
|
|
# select /*+ FIRST_ROWS(N) */ ...., rownum as ora_rn from (
|
|
# select distinct ... where ... order by ...
|
|
# ) where ROWNUM <= :limit+:offset
|
|
# ) where ora_rn > :offset
|
|
# Outer select and "ROWNUM as ora_rn" can be dropped if limit=0
|
|
|
|
# TODO: use annotations instead of clone + attr set ?
|
|
select = select._generate()
|
|
select._oracle_visit = True
|
|
|
|
# Wrap the middle select and add the hint
|
|
limitselect = sql.select([c for c in select.c])
|
|
if select._limit:
|
|
limitselect = limitselect.prefix_with("/*+ FIRST_ROWS(%d) */" % select._limit)
|
|
|
|
limitselect._oracle_visit = True
|
|
limitselect._is_wrapper = True
|
|
|
|
# If needed, add the limiting clause
|
|
if select._limit is not None:
|
|
max_row = select._limit
|
|
if select._offset is not None:
|
|
max_row += select._offset
|
|
limitselect.append_whereclause(
|
|
sql.literal_column("ROWNUM")<=max_row)
|
|
|
|
# If needed, add the ora_rn, and wrap again with offset.
|
|
if select._offset is None:
|
|
select = limitselect
|
|
else:
|
|
limitselect = limitselect.column(
|
|
sql.literal_column("ROWNUM").label("ora_rn"))
|
|
limitselect._oracle_visit = True
|
|
limitselect._is_wrapper = True
|
|
|
|
offsetselect = sql.select(
|
|
[c for c in limitselect.c if c.key!='ora_rn'])
|
|
offsetselect._oracle_visit = True
|
|
offsetselect._is_wrapper = True
|
|
|
|
offsetselect.append_whereclause(
|
|
sql.literal_column("ora_rn")>select._offset)
|
|
|
|
select = offsetselect
|
|
|
|
kwargs['iswrapper'] = getattr(select, '_is_wrapper', False)
|
|
return compiler.DefaultCompiler.visit_select(self, select, **kwargs)
|
|
|
|
def limit_clause(self, select):
|
|
return ""
|
|
|
|
def for_update_clause(self, select):
|
|
if select.for_update == "nowait":
|
|
return " FOR UPDATE NOWAIT"
|
|
else:
|
|
return super(OracleCompiler, self).for_update_clause(select)
|
|
|
|
|
|
class OracleSchemaGenerator(compiler.SchemaGenerator):
|
|
def get_column_specification(self, column, **kwargs):
|
|
colspec = self.preparer.format_column(column)
|
|
colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
|
|
default = self.get_column_default_string(column)
|
|
if default is not None:
|
|
colspec += " DEFAULT " + default
|
|
|
|
if not column.nullable:
|
|
colspec += " NOT NULL"
|
|
return colspec
|
|
|
|
def visit_sequence(self, sequence):
|
|
if not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name, sequence.schema):
|
|
self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
|
|
self.execute()
|
|
|
|
class OracleSchemaDropper(compiler.SchemaDropper):
|
|
def visit_sequence(self, sequence):
|
|
if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name, sequence.schema):
|
|
self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence))
|
|
self.execute()
|
|
|
|
class OracleDefaultRunner(base.DefaultRunner):
|
|
def visit_sequence(self, seq):
|
|
return self.execute_string("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL", {})
|
|
|
|
class OracleIdentifierPreparer(compiler.IdentifierPreparer):
|
|
def format_savepoint(self, savepoint):
|
|
name = re.sub(r'^_+', '', savepoint.ident)
|
|
return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name)
|
|
|
|
|
|
dialect = OracleDialect
|
|
dialect.statement_compiler = OracleCompiler
|
|
dialect.schemagenerator = OracleSchemaGenerator
|
|
dialect.schemadropper = OracleSchemaDropper
|
|
dialect.preparer = OracleIdentifierPreparer
|
|
dialect.defaultrunner = OracleDefaultRunner
|