mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-06-05 07:17:06 -04:00
795 lines
31 KiB
Python
795 lines
31 KiB
Python
# postgres.py
|
|
# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com
|
|
#
|
|
# This module is part of SQLAlchemy and is released under
|
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
|
|
|
"""Support for the PostgreSQL database.
|
|
|
|
PostgreSQL supports partial indexes. To create them pass a postgres_where
|
|
option to the Index constructor::
|
|
|
|
Index('my_index', my_table.c.id, postgres_where=tbl.c.value > 10)
|
|
|
|
PostgreSQL 8.2+ supports returning a result set from inserts and updates.
|
|
To use this pass the column/expression list to the postgres_returning
|
|
parameter when creating the queries::
|
|
|
|
raises = tbl.update(empl.c.sales > 100, values=dict(salary=empl.c.salary * 1.1),
|
|
postgres_returning=[empl.c.id, empl.c.salary]).execute().fetchall()
|
|
"""
|
|
|
|
import decimal, random, re, string
|
|
|
|
from sqlalchemy import sql, schema, exc, util
|
|
from sqlalchemy.engine import base, default
|
|
from sqlalchemy.sql import compiler, expression
|
|
from sqlalchemy.sql import operators as sql_operators
|
|
from sqlalchemy import types as sqltypes
|
|
|
|
|
|
class PGInet(sqltypes.TypeEngine):
    """Column type rendered as PostgreSQL ``INET`` in DDL."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "INET"
|
|
|
|
class PGCidr(sqltypes.TypeEngine):
    """Column type rendered as PostgreSQL ``CIDR`` in DDL."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "CIDR"
|
|
|
|
class PGMacAddr(sqltypes.TypeEngine):
    """Column type rendered as PostgreSQL ``MACADDR`` in DDL."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "MACADDR"
|
|
|
|
class PGNumeric(sqltypes.Numeric):
    """Numeric type rendered as PostgreSQL ``NUMERIC``."""

    def get_col_spec(self):
        """Return the DDL type, including precision/scale when configured.

        A missing scale renders ``NUMERIC(precision)`` rather than
        interpolating ``None`` into the DDL.
        """
        if not self.precision:
            return "NUMERIC"
        elif self.scale is None:
            return "NUMERIC(%(precision)s)" % {'precision': self.precision}
        else:
            return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}

    def bind_processor(self, dialect):
        """Return None: bound values are passed through unconverted."""
        return None

    def result_processor(self, dialect):
        """Return a result converter, or None when ``asdecimal`` is set.

        When ``asdecimal`` is false, ``decimal.Decimal`` results are
        converted to ``float``; every other value is returned unchanged.
        """
        if self.asdecimal:
            return None

        def process(value):
            if isinstance(value, decimal.Decimal):
                return float(value)
            return value
        return process
|
|
|
|
class PGFloat(sqltypes.Float):
    """Float type rendered as PostgreSQL ``FLOAT``."""

    def get_col_spec(self):
        """Return ``FLOAT``, with the precision appended when one is set."""
        if self.precision:
            return "FLOAT(%(precision)s)" % {'precision': self.precision}
        return "FLOAT"
|
|
|
|
|
|
class PGInteger(sqltypes.Integer):
    """Integer type rendered as PostgreSQL ``INTEGER``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "INTEGER"
|
|
|
|
class PGSmallInteger(sqltypes.Smallinteger):
    """Small integer type rendered as PostgreSQL ``SMALLINT``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "SMALLINT"
|
|
|
|
class PGBigInteger(PGInteger):
    """Integer subtype rendered as PostgreSQL ``BIGINT``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "BIGINT"
|
|
|
|
class PGDateTime(sqltypes.DateTime):
    """DateTime type rendered as PostgreSQL ``TIMESTAMP``."""

    def get_col_spec(self):
        """Return the TIMESTAMP spec, WITH or WITHOUT TIME ZONE per ``self.timezone``."""
        if self.timezone:
            zone = "WITH"
        else:
            zone = "WITHOUT"
        return "TIMESTAMP " + zone + " TIME ZONE"
|
|
|
|
class PGDate(sqltypes.Date):
    """Date type rendered as PostgreSQL ``DATE``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "DATE"
|
|
|
|
class PGTime(sqltypes.Time):
    """Time type rendered as PostgreSQL ``TIME``."""

    def get_col_spec(self):
        """Return the TIME spec, WITH or WITHOUT TIME ZONE per ``self.timezone``."""
        if self.timezone:
            zone = "WITH"
        else:
            zone = "WITHOUT"
        return "TIME " + zone + " TIME ZONE"
|
|
|
|
class PGInterval(sqltypes.TypeEngine):
    """Column type rendered as PostgreSQL ``INTERVAL`` in DDL."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "INTERVAL"
|
|
|
|
class PGText(sqltypes.Text):
    """Text type rendered as PostgreSQL ``TEXT``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "TEXT"
|
|
|
|
class PGString(sqltypes.String):
    """String type rendered as PostgreSQL ``VARCHAR``."""

    def get_col_spec(self):
        """Return ``VARCHAR``, with the length appended when one is set."""
        if not self.length:
            return "VARCHAR"
        return "VARCHAR(%(length)d)" % {'length' : self.length}
|
|
|
|
class PGChar(sqltypes.CHAR):
    """CHAR type rendered as PostgreSQL ``CHAR``."""

    def get_col_spec(self):
        """Return ``CHAR``, with the length appended when one is set."""
        if not self.length:
            return "CHAR"
        return "CHAR(%(length)d)" % {'length' : self.length}
|
|
|
|
class PGBinary(sqltypes.Binary):
    """Binary type rendered as PostgreSQL ``BYTEA``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "BYTEA"
|
|
|
|
class PGBoolean(sqltypes.Boolean):
    """Boolean type rendered as PostgreSQL ``BOOLEAN``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "BOOLEAN"
|
|
|
|
class PGArray(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
    """Array type whose DDL is the item type's spec followed by ``[]``.

    Bind and result values are converted element by element, recursing into
    nested sequences, using the item type's own processors.
    """

    def __init__(self, item_type, mutable=True):
        """Store the item type (instantiating it if a class was given) and the mutable flag."""
        if isinstance(item_type, type):
            item_type = item_type()
        self.item_type = item_type
        self.mutable = mutable

    def copy_value(self, value):
        """Return a list copy of ``value`` when mutable; otherwise ``value`` itself."""
        if value is None or not self.mutable:
            return value
        return list(value)

    def compare_values(self, x, y):
        """Compare two values with ``==``."""
        return x == y

    def is_mutable(self):
        """Return the ``mutable`` flag given at construction."""
        return self.mutable

    def dialect_impl(self, dialect, **kwargs):
        """Return a copy of this type whose item type is adapted to ``dialect``."""
        # Build the copy without running __init__, then carry state across.
        clone = self.__class__.__new__(self.__class__)
        clone.__dict__.update(self.__dict__)
        clone.item_type = self.item_type.dialect_impl(dialect)
        return clone

    def bind_processor(self, dialect):
        """Return a function converting a bound value into nested lists of processed items."""
        per_item = self.item_type.bind_processor(dialect)

        def convert(element):
            # Lists and tuples are treated as nested array dimensions.
            if isinstance(element, (list, tuple)):
                return [convert(sub) for sub in element]
            if per_item:
                return per_item(element)
            return element

        def process(value):
            if value is None:
                return value
            return [convert(element) for element in value]
        return process

    def result_processor(self, dialect):
        """Return a function converting a result value into nested lists of processed items."""
        per_item = self.item_type.result_processor(dialect)

        def convert(element):
            # Only lists are treated as nested array dimensions here.
            if isinstance(element, list):
                return [convert(sub) for sub in element]
            if per_item:
                return per_item(element)
            return element

        def process(value):
            if value is None:
                return value
            return [convert(element) for element in value]
        return process

    def get_col_spec(self):
        """Return the item type's column spec with ``[]`` appended."""
        return self.item_type.get_col_spec() + '[]'
|
|
|
|
# Maps generic sqlalchemy types to the PostgreSQL-specific classes above;
# consumed by PGDialect.type_descriptor() via sqltypes.adapt_type().
colspecs = {
    # numeric
    sqltypes.Integer: PGInteger,
    sqltypes.Smallinteger: PGSmallInteger,
    sqltypes.Numeric: PGNumeric,
    sqltypes.Float: PGFloat,
    # date / time
    sqltypes.DateTime: PGDateTime,
    sqltypes.Date: PGDate,
    sqltypes.Time: PGTime,
    # character / binary / boolean
    sqltypes.String: PGString,
    sqltypes.Binary: PGBinary,
    sqltypes.Boolean: PGBoolean,
    sqltypes.Text: PGText,
    sqltypes.CHAR: PGChar,
}
|
|
|
|
# Maps type names, as looked up during table reflection (after any
# "(...)" / "[]" suffix has been stripped), to the classes above.
ischema_names = {
    # integers
    'integer': PGInteger,
    'bigint': PGBigInteger,
    'smallint': PGSmallInteger,
    # character data
    'character varying': PGString,
    'character': PGChar,
    'text': PGText,
    # exact / approximate numerics
    'numeric': PGNumeric,
    'float': PGFloat,
    'real': PGFloat,
    'double precision': PGFloat,
    # network addresses
    'inet': PGInet,
    'cidr': PGCidr,
    'macaddr': PGMacAddr,
    # date / time
    'timestamp': PGDateTime,
    'timestamp with time zone': PGDateTime,
    'timestamp without time zone': PGDateTime,
    'time with time zone': PGTime,
    'time without time zone': PGTime,
    'date': PGDate,
    'time': PGTime,
    'interval': PGInterval,
    # misc
    'bytea': PGBinary,
    'boolean': PGBoolean,
}
|
|
|
|
# Statements beginning with SELECT are candidates for a server-side cursor.
SERVER_SIDE_CURSOR_RE = re.compile(
    r'\s*SELECT',
    re.I | re.UNICODE)

# Matches the leading keyword of a statement; group 1 is set only for
# UPDATE/INSERT, which need a further RETURNING check to return rows.
SELECT_RE = re.compile(
    r'\s*(?:SELECT|FETCH|(UPDATE|INSERT))',
    re.I | re.UNICODE)

# Cheap pre-check run before the more expensive RETURNING_QUOTED_RE.
RETURNING_RE = re.compile(
    r'RETURNING',
    re.I | re.UNICODE)

# This finds if the RETURNING is not inside a quoted/commented values. Handles string literals,
# quoted identifiers, dollar quotes, SQL comments and C style multiline comments. This does not
# handle correctly nested C style quotes, lets hope no one does the following:
# UPDATE tbl SET x=y /* foo /* bar */ RETURNING */
#
# The pattern is a raw string so that regex escapes such as \s and \$ are
# not treated as (invalid) Python string escapes.
RETURNING_QUOTED_RE = re.compile(
    r"""\s*(?:UPDATE|INSERT)\s
        (?: # handle quoted and commented tokens separately
            [^'"$/-] # non quote/comment character
            | -(?!-) # a dash that does not begin a comment
            | /(?!\*) # a slash that does not begin a comment
            | "(?:[^"]|"")*" # quoted literal
            | '(?:[^']|'')*' # quoted string
            | \$(?P<dquote>[^$]*)\$.*?\$(?P=dquote)\$ # dollar quotes
            | --[^\n]*(?=\n) # SQL comment, leave out line ending as that counts as whitespace
                             # for the returning token
            | /\*([^*]|\*(?!/))*\*/ # C style comment, doesn't handle nesting
        )*
        \sRETURNING\s""", re.I | re.UNICODE | re.VERBOSE)
|
|
|
|
class PGExecutionContext(default.DefaultExecutionContext):
    """Execution context that detects row-returning statements and can
    open server-side (named) cursors when the dialect enables them."""

    def returns_rows_text(self, statement):
        """Return a true value if the textual ``statement`` returns rows.

        SELECT and FETCH always do; UPDATE and INSERT only when a RETURNING
        keyword appears outside of quotes and comments.
        """
        m = SELECT_RE.match(statement)
        if not m:
            return m
        if not m.group(1):
            # SELECT or FETCH
            return True
        # UPDATE or INSERT: cheap search first, then the full quoted-aware match.
        return (RETURNING_RE.search(statement)
                and RETURNING_QUOTED_RE.match(statement))

    def returns_rows_compiled(self, compiled):
        """Return a true value if the compiled construct returns rows."""
        if isinstance(compiled.statement, expression.Selectable):
            return True
        return (
            (compiled.isupdate or compiled.isinsert)
            and "postgres_returning" in compiled.statement.kwargs
        )

    def create_cursor(self):
        """Create the DBAPI cursor, named (server-side) when applicable.

        A server-side cursor is used only if the dialect has
        ``server_side_cursors`` enabled and the statement is either a compiled
        Selectable, or textual SQL that begins with SELECT.
        """
        compiled = self.compiled
        server_side = False
        if self.dialect.server_side_cursors:
            if compiled and isinstance(compiled.statement, expression.Selectable):
                server_side = True
            elif not compiled or isinstance(compiled.statement, expression._TextClause):
                server_side = bool(
                    self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement))
        self.__is_server_side = server_side

        if not server_side:
            return self._connection.connection.cursor()

        # use server-side cursors:
        # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
        ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:])
        return self._connection.connection.cursor(ident)

    def get_result_proxy(self):
        """Return a buffered result proxy for server-side cursors, a plain one otherwise."""
        if self.__is_server_side:
            return base.BufferedRowResultProxy(self)
        return base.ResultProxy(self)
|
|
|
|
class PGDialect(default.DefaultDialect):
    """Dialect implementation for PostgreSQL using the psycopg2 DBAPI."""

    name = 'postgres'
    supports_alter = True
    supports_unicode_statements = False
    max_identifier_length = 63
    supports_sane_rowcount = True
    supports_sane_multi_rowcount = False
    preexecute_pk_sequences = True
    supports_pk_autoincrement = False
    default_paramstyle = 'pyformat'
    supports_default_values = True
    supports_empty_insert = False

    def __init__(self, server_side_cursors=False, **kwargs):
        """Initialize the dialect; ``server_side_cursors`` enables named cursors for SELECTs."""
        default.DefaultDialect.__init__(self, **kwargs)
        self.server_side_cursors = server_side_cursors

    def dbapi(cls):
        """Import and return the psycopg2 module."""
        import psycopg2 as psycopg
        return psycopg
    dbapi = classmethod(dbapi)

    def create_connect_args(self, url):
        """Return ``([], opts)`` connect arguments built from ``url``.

        The URL's username is passed as ``user``, the port is converted to
        int, and URL query arguments are merged into the options.
        """
        opts = url.translate_connect_args(username='user')
        if 'port' in opts:
            opts['port'] = int(opts['port'])
        opts.update(url.query)
        return ([], opts)

    def create_execution_context(self, *args, **kwargs):
        """Return a PGExecutionContext bound to this dialect."""
        return PGExecutionContext(self, *args, **kwargs)

    def type_descriptor(self, typeobj):
        """Adapt ``typeobj`` to its PostgreSQL-specific type via ``colspecs``."""
        return sqltypes.adapt_type(typeobj, colspecs)

    def do_begin_twophase(self, connection, xid):
        """Begin a two-phase transaction (a plain begin on the DBAPI connection)."""
        self.do_begin(connection.connection)

    def do_prepare_twophase(self, connection, xid):
        """Issue PREPARE TRANSACTION for ``xid``."""
        connection.execute(sql.text("PREPARE TRANSACTION :tid", bindparams=[sql.bindparam('tid', xid)]))

    def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
        """Roll back a two-phase transaction, using ROLLBACK PREPARED if it was prepared."""
        if is_prepared:
            if recover:
                #FIXME: ugly hack to get out of transaction context when committing recoverable transactions
                # Must find out a way how to make the dbapi not open a transaction.
                connection.execute(sql.text("ROLLBACK"))
            connection.execute(sql.text("ROLLBACK PREPARED :tid", bindparams=[sql.bindparam('tid', xid)]))
            connection.execute(sql.text("BEGIN"))
            self.do_rollback(connection.connection)
        else:
            self.do_rollback(connection.connection)

    def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
        """Commit a two-phase transaction, using COMMIT PREPARED if it was prepared."""
        if is_prepared:
            if recover:
                # Same transaction-context workaround as in do_rollback_twophase.
                connection.execute(sql.text("ROLLBACK"))
            connection.execute(sql.text("COMMIT PREPARED :tid", bindparams=[sql.bindparam('tid', xid)]))
            connection.execute(sql.text("BEGIN"))
            self.do_rollback(connection.connection)
        else:
            self.do_commit(connection.connection)

    def do_recover_twophase(self, connection):
        """Return the list of prepared transaction ids from ``pg_prepared_xacts``."""
        resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts"))
        return [row[0] for row in resultset]

    def get_default_schema_name(self, connection):
        """Return the result of ``current_schema()``; wrapped with connection_memoize below."""
        return connection.scalar("select current_schema()", None)
    get_default_schema_name = base.connection_memoize(
        ('dialect', 'default_schema_name'))(get_default_schema_name)

    def last_inserted_ids(self):
        """Return ``self.context.last_inserted_ids``.

        Raises InvalidRequestError when that value is None.
        """
        if self.context.last_inserted_ids is None:
            raise exc.InvalidRequestError("no INSERT executed, or can't use cursor.lastrowid without Postgres OIDs enabled")
        else:
            return self.context.last_inserted_ids

    def has_table(self, connection, table_name, schema=None):
        """Return True if a relation named ``table_name`` exists.

        Looks in ``schema`` when given, otherwise in ``current_schema()``.
        The name is compared lower-cased.
        """
        # seems like case gets folded in pg_class...
        if schema is None:
            cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=current_schema() and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding)})
        else:
            cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=%(schema)s and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding), 'schema':schema})
        return bool(cursor.rowcount)

    def has_sequence(self, connection, sequence_name):
        """Return True if a sequence named ``sequence_name`` exists outside the system schemas."""
        cursor = connection.execute('''SELECT relname FROM pg_class WHERE relkind = 'S' AND relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%%' AND nspname != 'information_schema' AND relname = %(seqname)s);''', {'seqname': sequence_name.encode(self.encoding)})
        return bool(cursor.rowcount)

    def is_disconnect(self, e):
        """Return True if the DBAPI exception ``e`` indicates a lost connection.

        Detection is by substring match on the exception message.
        """
        if isinstance(e, self.dbapi.OperationalError):
            return 'closed the connection' in str(e) or 'connection not open' in str(e)
        elif isinstance(e, self.dbapi.InterfaceError):
            return 'connection already closed' in str(e) or 'cursor already closed' in str(e)
        elif isinstance(e, self.dbapi.ProgrammingError):
            # yes, it really says "losed", not "closed"
            return "losed the connection unexpectedly" in str(e)
        else:
            return False

    def table_names(self, connection, schema):
        """Return the names of ordinary tables (relkind 'r') in ``schema``."""
        # NOTE(review): the schema name is interpolated directly into the SQL
        # text; a name containing a single quote will break the statement.
        # Consider switching to a bound parameter.
        s = """
        SELECT relname
        FROM pg_class c
        WHERE relkind = 'r'
          AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace)
        """ % {'schema': schema}
        return [row[0].decode(self.encoding) for row in connection.execute(s)]

    def server_version_info(self, connection):
        """Return the server version as a tuple of ints, e.g. ``(8, 2, 4)``.

        Both three-part ("8.2.4") and two-part ("10.3") version numbers are
        accepted; a two-part version yields a two-element tuple.

        Raises AssertionError if the version string cannot be parsed.
        """
        v = connection.execute("select version()").scalar()
        m = re.match(r'PostgreSQL (\d+)\.(\d+)(?:\.(\d+))?', v)
        if not m:
            raise AssertionError("Could not determine version from string '%s'" % v)
        return tuple([int(x) for x in m.group(1, 2, 3) if x is not None])

    def reflecttable(self, connection, table, include_columns):
        """Populate ``table`` with columns, primary key and foreign keys
        read from the PostgreSQL system catalogs.

        ``include_columns``, when non-empty, restricts reflection to the
        named columns.  Raises NoSuchTableError if no columns are found.
        Tables referred to by foreign keys are autoloaded into
        ``table.metadata``.
        """
        preparer = self.identifier_preparer
        if table.schema is not None:
            schema_where_clause = "n.nspname = :schema"
            schemaname = table.schema
            if isinstance(schemaname, str):
                schemaname = schemaname.decode(self.encoding)
        else:
            schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)"
            schemaname = None

        SQL_COLS = """
            SELECT a.attname,
              pg_catalog.format_type(a.atttypid, a.atttypmod),
              (SELECT substring(d.adsrc for 128) FROM pg_catalog.pg_attrdef d
               WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef)
              AS DEFAULT,
              a.attnotnull, a.attnum, a.attrelid as table_oid
            FROM pg_catalog.pg_attribute a
            WHERE a.attrelid = (
                SELECT c.oid
                FROM pg_catalog.pg_class c
                     LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
                     WHERE (%s)
                     AND c.relname = :table_name AND c.relkind in ('r','v')
            ) AND a.attnum > 0 AND NOT a.attisdropped
            ORDER BY a.attnum
        """ % schema_where_clause

        s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode, 'default':sqltypes.Unicode})
        tablename = table.name
        if isinstance(tablename, str):
            tablename = tablename.decode(self.encoding)
        c = connection.execute(s, table_name=tablename, schema=schemaname)
        rows = c.fetchall()

        if not rows:
            raise exc.NoSuchTableError(table.name)

        domains = self._load_domains(connection)

        for name, format_type, default, notnull, attnum, table_oid in rows:
            if include_columns and name not in include_columns:
                continue

            ## strip (30) from character varying(30)
            attype = re.search(r'([^\([]+)', format_type).group(1)
            nullable = not notnull
            is_array = format_type.endswith('[]')

            # Length / precision arguments, e.g. "30" or "10,2"; False when absent.
            length_match = re.search(r'\(([\d,]+)\)', format_type)
            if length_match:
                charlen = length_match.group(1)
            else:
                charlen = False

            numericprec = False
            numericscale = False
            if attype == 'numeric':
                if charlen is False:
                    numericprec, numericscale = (None, None)
                else:
                    numericprec, numericscale = charlen.split(',')
                charlen = False
            if attype == 'double precision':
                numericprec, numericscale = (53, False)
                charlen = False
            if attype == 'integer':
                numericprec, numericscale = (32, 0)
                charlen = False

            # False means "not applicable" and is dropped; None is passed through.
            args = []
            for a in (charlen, numericprec, numericscale):
                if a is None:
                    args.append(None)
                elif a is not False:
                    args.append(int(a))

            kwargs = {}
            if attype == 'timestamp with time zone':
                kwargs['timezone'] = True
            elif attype == 'timestamp without time zone':
                kwargs['timezone'] = False

            # Start from None so an unresolvable type (including a domain
            # whose base type is unknown) always falls through to the
            # warning below instead of reusing a previous column's type.
            coltype = None
            if attype in ischema_names:
                coltype = ischema_names[attype]
            elif attype in domains:
                domain = domains[attype]
                if domain['attype'] in ischema_names:
                    # A table can't override whether the domain is nullable.
                    nullable = domain['nullable']

                    if domain['default'] and not default:
                        # It can, however, override the default value, but can't set it to null.
                        default = domain['default']
                    coltype = ischema_names[domain['attype']]

            if coltype:
                coltype = coltype(*args, **kwargs)
                if is_array:
                    coltype = PGArray(coltype)
            else:
                util.warn("Did not recognize type '%s' of column '%s'" %
                          (attype, name))
                coltype = sqltypes.NULLTYPE

            colargs = []
            if default is not None:
                match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
                if match is not None:
                    # the default is related to a Sequence
                    sch = table.schema
                    if '.' not in match.group(2) and sch is not None:
                        # unconditionally quote the schema name.  this could
                        # later be enhanced to obey quoting rules / "quote schema"
                        default = match.group(1) + ('"%s"' % sch) + '.' + match.group(2) + match.group(3)
                colargs.append(schema.DefaultClause(sql.text(default)))
            table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))


        # Primary keys
        PK_SQL = """
          SELECT attname FROM pg_attribute
          WHERE attrelid = (
             SELECT indexrelid FROM pg_index i
             WHERE i.indrelid = :table
             AND i.indisprimary = 't')
          ORDER BY attnum
        """
        t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode})
        c = connection.execute(t, table=table_oid)
        for row in c.fetchall():
            pk = row[0]
            if include_columns and pk not in include_columns:
                # This column was skipped above, so it is not in table.c.
                continue
            col = table.c[pk]
            table.primary_key.add(col)
            if col.default is None:
                col.autoincrement = False

        # Foreign keys
        FK_SQL = """
          SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef
          FROM  pg_catalog.pg_constraint r
          WHERE r.conrelid = :table AND r.contype = 'f'
          ORDER BY 1
        """

        t = sql.text(FK_SQL, typemap={'conname':sqltypes.Unicode, 'condef':sqltypes.Unicode})
        c = connection.execute(t, table=table_oid)
        for conname, condef in c.fetchall():
            m = re.search(r'FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)', condef).groups()
            (constrained_columns, referred_schema, referred_table, referred_columns) = m
            constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)]
            if referred_schema:
                referred_schema = preparer._unquote_identifier(referred_schema)
            elif table.schema is not None and table.schema == self.get_default_schema_name(connection):
                # no schema (i.e. its the default schema), and the table we're
                # reflecting has the default schema explicit, then use that.
                # i.e. try to use the user's conventions
                referred_schema = table.schema
            referred_table = preparer._unquote_identifier(referred_table)
            # Same separator pattern as constrained_columns: whitespace around
            # the comma is optional on both sides.
            referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', referred_columns)]

            refspec = []
            if referred_schema is not None:
                schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema,
                            autoload_with=connection)
                for column in referred_columns:
                    refspec.append(".".join([referred_schema, referred_table, column]))
            else:
                schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection)
                for column in referred_columns:
                    refspec.append(".".join([referred_table, column]))

            table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname))

    def _load_domains(self, connection):
        """Return a dict describing all domain types in the database.

        Keys are the domain name, schema-qualified when the domain is not
        visible on the search path.  Values are dicts with the keys
        ``attype`` (base type name without length), ``nullable`` and
        ``default``.
        """
        ## Load data types for domains:
        SQL_DOMAINS = """
            SELECT t.typname as "name",
                   pg_catalog.format_type(t.typbasetype, t.typtypmod) as "attype",
                   not t.typnotnull as "nullable",
                   t.typdefault as "default",
                   pg_catalog.pg_type_is_visible(t.oid) as "visible",
                   n.nspname as "schema"
            FROM pg_catalog.pg_type t
                 LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
                 LEFT JOIN pg_catalog.pg_constraint r ON t.oid = r.contypid
            WHERE t.typtype = 'd'
        """

        s = sql.text(SQL_DOMAINS, typemap={'attname':sqltypes.Unicode})
        c = connection.execute(s)

        domains = {}
        for domain in c.fetchall():
            ## strip (30) from character varying(30)
            attype = re.search(r'([^\(]+)', domain['attype']).group(1)
            if domain['visible']:
                # 'visible' just means whether or not the domain is in a
                # schema that's on the search path -- or not overridden by
                # a schema with higher precedence. If it's not visible,
                # it will be prefixed with the schema-name when it's used.
                name = domain['name']
            else:
                name = "%s.%s" % (domain['schema'], domain['name'])

            domains[name] = {'attype':attype, 'nullable': domain['nullable'], 'default': domain['default']}

        return domains
|
|
|
|
|
|
|
|
class PGCompiler(compiler.DefaultCompiler):
    """Statement compiler adding PostgreSQL operators, DISTINCT ON,
    LIMIT/OFFSET, FOR UPDATE NOWAIT and RETURNING support."""

    # PostgreSQL-specific operator renderings layered over the defaults.
    operators = compiler.DefaultCompiler.operators.copy()
    operators[sql_operators.mod] = '%%'
    operators[sql_operators.ilike_op] = (
        lambda x, y, escape=None:
            '%s ILIKE %s' % (x, y) + (escape and " ESCAPE '%s'" % escape or ''))
    operators[sql_operators.notilike_op] = (
        lambda x, y, escape=None:
            '%s NOT ILIKE %s' % (x, y) + (escape and " ESCAPE '%s'" % escape or ''))
    operators[sql_operators.match_op] = (
        lambda x, y: '%s @@ to_tsquery(%s)' % (x, y))

    # PostgreSQL-specific function renderings layered over the defaults.
    functions = compiler.DefaultCompiler.functions.copy()
    functions['TIMESTAMP'] = lambda x: 'TIMESTAMP %s' % x

    def visit_sequence(self, seq):
        """Render ``nextval('<sequence>')``, or None for an optional sequence."""
        if seq.optional:
            return None
        return "nextval('%s')" % self.preparer.format_sequence(seq)

    def limit_clause(self, select):
        """Render LIMIT/OFFSET; an OFFSET without a limit gets ``LIMIT ALL``."""
        pieces = []
        if select._limit is not None:
            pieces.append(" \n LIMIT " + str(select._limit))
        if select._offset is not None:
            if select._limit is None:
                pieces.append(" \n LIMIT ALL")
            pieces.append(" OFFSET " + str(select._offset))
        return "".join(pieces)

    def get_select_precolumns(self, select):
        """Render ``DISTINCT`` or ``DISTINCT ON (...)`` according to ``select._distinct``."""
        distinct = select._distinct
        if not distinct:
            return ""
        if isinstance(distinct, bool):
            return "DISTINCT "
        if isinstance(distinct, (list, tuple)):
            rendered = []
            for col in distinct:
                # Non-empty strings are used verbatim; anything else is compiled.
                if isinstance(col, basestring) and col:
                    rendered.append(col)
                else:
                    rendered.append(self.process(col))
            return "DISTINCT ON (" + ', '.join(rendered) + ") "
        return "DISTINCT ON (" + unicode(distinct) + ") "

    def for_update_clause(self, select):
        """Render ``FOR UPDATE NOWAIT`` for ``for_update='nowait'``; otherwise defer to the base class."""
        if select.for_update != 'nowait':
            return super(PGCompiler, self).for_update_clause(select)
        return " FOR UPDATE NOWAIT"

    def _append_returning(self, text, stmt):
        """Return ``text`` with a RETURNING clause built from ``postgres_returning``.

        Selectables in the list are expanded into their individual columns.
        """
        rendered = []
        for item in stmt.kwargs['postgres_returning']:
            if isinstance(item, expression.Selectable):
                targets = item.columns
            else:
                targets = [item]
            for col in targets:
                rendered.append(self.process(col, within_columns_clause=True))
        return text + ' RETURNING ' + ', '.join(rendered)

    def visit_update(self, update_stmt):
        """Compile an UPDATE, appending RETURNING when ``postgres_returning`` is given."""
        text = super(PGCompiler, self).visit_update(update_stmt)
        if 'postgres_returning' not in update_stmt.kwargs:
            return text
        return self._append_returning(text, update_stmt)

    def visit_insert(self, insert_stmt):
        """Compile an INSERT, appending RETURNING when ``postgres_returning`` is given."""
        text = super(PGCompiler, self).visit_insert(insert_stmt)
        if 'postgres_returning' not in insert_stmt.kwargs:
            return text
        return self._append_returning(text, insert_stmt)
|
|
|
|
class PGSchemaGenerator(compiler.SchemaGenerator):
    """Emits CREATE statements with SERIAL columns, sequences and partial indexes."""

    def get_column_specification(self, column, **kwargs):
        """Return the DDL fragment for ``column``.

        Auto-incrementing integer primary keys (non-small, no foreign keys,
        no default or only an optional Sequence) become SERIAL/BIGSERIAL;
        all other columns get their type spec plus any DEFAULT clause.
        """
        colspec = self.preparer.format_column(column)
        use_serial = (
            column.primary_key
            and len(column.foreign_keys) == 0
            and column.autoincrement
            and isinstance(column.type, sqltypes.Integer)
            and not isinstance(column.type, sqltypes.SmallInteger)
            and (column.default is None or
                 (isinstance(column.default, schema.Sequence)
                  and column.default.optional))
        )
        if use_serial:
            if isinstance(column.type, PGBigInteger):
                colspec += " BIGSERIAL"
            else:
                colspec += " SERIAL"
        else:
            colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
            default = self.get_column_default_string(column)
            if default is not None:
                colspec += " DEFAULT " + default

        if not column.nullable:
            colspec += " NOT NULL"
        return colspec

    def visit_sequence(self, sequence):
        """Emit CREATE SEQUENCE for a non-optional sequence.

        With ``checkfirst`` set, the statement is skipped when the sequence
        already exists.
        """
        if sequence.optional:
            return
        if self.checkfirst and self.dialect.has_sequence(self.connection, sequence.name):
            return
        self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
        self.execute()

    def visit_index(self, index):
        """Emit CREATE [UNIQUE] INDEX, with a WHERE clause when ``postgres_where`` is given."""
        preparer = self.preparer
        self.append("CREATE ")
        if index.unique:
            self.append("UNIQUE ")
        index_name = preparer.quote(self._validate_identifier(index.name, True), index.quote)
        column_list = ', '.join([preparer.format_column(c) for c in index.columns])
        self.append("INDEX %s ON %s (%s)" \
                    % (index_name, preparer.format_table(index.table), column_list))
        whereclause = index.kwargs.get('postgres_where', None)
        if whereclause is not None:
            compiled = self._compile(whereclause, None)
            # this might belong to the compiler class
            # Bind values are substituted into the clause text directly.
            bind_values = dict(
                [(key, bind.value) for key, bind in compiled.binds.iteritems()])
            inlined_clause = str(compiled) % bind_values
            self.append(" WHERE " + inlined_clause)
        self.execute()
|
|
|
|
class PGSchemaDropper(compiler.SchemaDropper):
    """Emits DROP SEQUENCE statements for PostgreSQL."""

    def visit_sequence(self, sequence):
        """Emit DROP SEQUENCE for a non-optional sequence.

        With ``checkfirst`` set, the statement is skipped when the sequence
        does not exist.
        """
        if sequence.optional:
            return
        if self.checkfirst and not self.dialect.has_sequence(self.connection, sequence.name):
            return
        self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence))
        self.execute()
|
|
|
|
class PGDefaultRunner(base.DefaultRunner):
    """Pre-executes primary key defaults and sequences on a dedicated cursor."""

    def __init__(self, context):
        base.DefaultRunner.__init__(self, context)
        # create cursor which won't conflict with a server-side cursor
        self.cursor = context._connection.connection.cursor()

    def get_column_default(self, column, isinsert=True):
        """Return the pre-executed default value for ``column``.

        For primary keys, a server default is selected directly, and an
        auto-incrementing integer column draws from its implicit
        ``<table>_<column>_seq`` sequence.  Everything else is handled by the
        base class.
        """
        if column.primary_key:
            server_default = column.server_default
            # pre-execute passive defaults on primary keys
            if (isinstance(server_default, schema.DefaultClause) and
                    server_default.arg is not None):
                return self.execute_string("select %s" % server_default.arg)
            elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
                sch = column.table.schema
                # TODO: this has to build into the Sequence object so we can get the quoting
                # logic from it
                if sch is None:
                    stmt = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
                else:
                    stmt = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
                return self.execute_string(stmt.encode(self.dialect.encoding))

        return super(PGDefaultRunner, self).get_column_default(column)

    def visit_sequence(self, seq):
        """Return the next value of ``seq``, or None for an optional sequence."""
        if seq.optional:
            return None
        return self.execute_string(("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)))
|
|
|
|
class PGIdentifierPreparer(compiler.IdentifierPreparer):
    """Identifier preparer with support for un-quoting reflected names."""

    def _unquote_identifier(self, value):
        """Strip surrounding quotes from ``value`` and collapse doubled quotes.

        Values that do not start with the initial quote character are
        returned unchanged.
        """
        if value[0] != self.initial_quote:
            return value
        return value[1:-1].replace('""', '"')
|
|
|
|
# Attach the PostgreSQL-specific helper classes to the dialect class, then
# bind it to the module-level name ``dialect``.
PGDialect.statement_compiler = PGCompiler
PGDialect.schemagenerator = PGSchemaGenerator
PGDialect.schemadropper = PGSchemaDropper
PGDialect.preparer = PGIdentifierPreparer
PGDialect.defaultrunner = PGDefaultRunner
dialect = PGDialect
|