mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-15 05:07:16 -04:00
705f308452
still should better define 'default_ordering'/'rowid_column' relationship since it's a little kludgy.
261 lines
9.1 KiB
Python
261 lines
9.1 KiB
Python
# postgres.py
|
|
# Copyright (C) 2005 Michael Bayer mike_mp@zzzcomputing.com
|
|
#
|
|
# This module is part of SQLAlchemy and is released under
|
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
|
|
|
import sys, StringIO, string, types, re
|
|
|
|
import sqlalchemy.sql as sql
|
|
import sqlalchemy.engine as engine
|
|
import sqlalchemy.schema as schema
|
|
import sqlalchemy.ansisql as ansisql
|
|
import sqlalchemy.types as sqltypes
|
|
from sqlalchemy import *
|
|
import sqlalchemy.databases.information_schema as ischema
|
|
|
|
# Prefer psycopg2 and fall back to the original psycopg (version 1) driver.
# When neither can be imported, ``psycopg`` is left as None; PGSQLEngine
# checks for that and requires an explicit ``module`` argument instead.
# Only ImportError is caught so that unrelated failures (and interrupts)
# raised while importing are not silently discarded.
try:
    import psycopg2 as psycopg
except ImportError:
    try:
        import psycopg
    except ImportError:
        psycopg = None
|
|
|
|
class PGNumeric(sqltypes.Numeric):
    """Numeric type rendered with the Postgres ``NUMERIC`` column spec."""

    def get_col_spec(self):
        """Return the DDL type string built from ``precision`` and ``length``."""
        return "NUMERIC(%s, %s)" % (self.precision, self.length)
|
|
class PGFloat(sqltypes.Float):
    """Float type rendered with the Postgres ``FLOAT`` column spec."""

    def get_col_spec(self):
        """Return the DDL type string, parameterized by ``self.precision``."""
        return "FLOAT(%s)" % (self.precision,)
|
|
class PGInteger(sqltypes.Integer):
    """Integer type rendered with the Postgres ``INTEGER`` column spec."""

    def get_col_spec(self):
        """Return the fixed DDL type string for integer columns."""
        return "INTEGER"
|
|
class PG2DateTime(sqltypes.DateTime):
    """DateTime type registered in the psycopg2 type maps."""

    def get_col_spec(self):
        """Return the fixed DDL type string for datetime columns."""
        return "TIMESTAMP"
|
|
class PG1DateTime(sqltypes.DateTime):
    """DateTime type registered in the psycopg1 type maps.

    The two conversion hooks are currently pass-throughs; see the TODO
    notes on each.
    """

    def get_col_spec(self):
        """Return the fixed DDL type string for datetime columns."""
        return "TIMESTAMP"

    def convert_bind_param(self, value):
        """Return the bind parameter ``value`` unchanged."""
        # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
        return value

    def convert_result_value(self, value):
        """Return the result ``value`` unchanged."""
        # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
        return value
|
|
class PGText(sqltypes.TEXT):
    """TEXT type rendered with the Postgres ``TEXT`` column spec."""

    def get_col_spec(self):
        """Return the fixed DDL type string for text columns."""
        return "TEXT"
|
|
class PGString(sqltypes.String):
    """String type rendered with the Postgres ``VARCHAR`` column spec."""

    def get_col_spec(self):
        """Return the DDL type string, parameterized by ``self.length``."""
        return "VARCHAR(%s)" % (self.length,)
|
|
class PGChar(sqltypes.CHAR):
    """CHAR type rendered with the Postgres ``CHAR`` column spec."""

    def get_col_spec(self):
        """Return the DDL type string, parameterized by ``self.length``."""
        return "CHAR(%s)" % (self.length,)
|
|
class PGBinary(sqltypes.Binary):
    """Binary type rendered with the Postgres ``BYTEA`` column spec."""

    def get_col_spec(self):
        """Return the fixed DDL type string for binary columns."""
        return "BYTEA"
|
|
class PGBoolean(sqltypes.Boolean):
    """Boolean type rendered with the Postgres ``BOOLEAN`` column spec."""

    def get_col_spec(self):
        """Return the fixed DDL type string for boolean columns."""
        return "BOOLEAN"
|
|
|
|
# Generic sqlalchemy type class -> Postgres-specific type class, handed to
# sqltypes.adapt_type() when the engine runs on psycopg2.
pg2_colspecs = {
    sqltypes.Integer: PGInteger,
    sqltypes.Numeric: PGNumeric,
    sqltypes.Float: PGFloat,
    sqltypes.DateTime: PG2DateTime,
    sqltypes.String: PGString,
    sqltypes.Binary: PGBinary,
    sqltypes.Boolean: PGBoolean,
    sqltypes.TEXT: PGText,
    sqltypes.CHAR: PGChar,
}

# The psycopg1 mapping is identical apart from the DateTime implementation.
pg1_colspecs = dict(pg2_colspecs)
pg1_colspecs[sqltypes.DateTime] = PG1DateTime
|
|
|
|
# Type name string -> Postgres type class, passed to ischema.reflecttable()
# when reflecting a table with psycopg2.
pg2_ischema_names = {
    'integer': PGInteger,
    'bigint': PGInteger,
    'character varying': PGString,
    'character': PGChar,
    'text': PGText,
    'numeric': PGNumeric,
    'float': PGFloat,
    'real': PGFloat,
    'double precision': PGFloat,
    'timestamp with time zone': PG2DateTime,
    'timestamp without time zone': PG2DateTime,
    'bytea': PGBinary,
    'boolean': PGBoolean,
}

# The psycopg1 mapping differs only in the class used for both timestamp names.
pg1_ischema_names = dict(pg2_ischema_names)
pg1_ischema_names.update({
    'timestamp with time zone': PG1DateTime,
    'timestamp without time zone': PG1DateTime,
})
|
|
|
|
def engine(opts, **params):
    """Module-level factory: build a PGSQLEngine from ``opts`` and ``params``.

    Both arguments are forwarded to the PGSQLEngine constructor unchanged.
    """
    pg_engine = PGSQLEngine(opts, **params)
    return pg_engine
|
|
|
|
def descriptor():
    """Describe this database module and the connection arguments it takes.

    Returns a dictionary with the keys ``name``, ``description`` and
    ``arguments``.  ``arguments`` is a list of three-element tuples holding
    the argument key, a human-readable label, and None (presumably the
    argument's default value -- confirm against the consumer).
    """
    labelled_keys = (
        ('user', "Database Username"),
        ('password', "Database Password"),
        ('database', "Database Name"),
        ('host', "Hostname"),
    )
    arguments = [(key, label, None) for key, label in labelled_keys]
    return dict(name='postgres', description='PostGres', arguments=arguments)
|
|
|
|
class PGSQLEngine(ansisql.ANSISQLEngine):
    """SQL engine for Postgres, driven by psycopg version 1 or 2."""

    def __init__(self, opts, module=None, use_oids=False, **params):
        """Set up the engine.

        opts -- connection options; returned as the second element of
            connect_args().  A false value is replaced by an empty dict.
        module -- DBAPI module to use.  Defaults to whichever psycopg
            driver could be imported when this file was loaded.
        use_oids -- when true, "oid" is reported as the row id column and
            post_exec() may look up inserted primary keys through
            cursor.lastrowid.
        params -- forwarded to ANSISQLEngine.__init__.

        Raises ImportError when no module is given and neither psycopg
        driver could be imported.
        """
        self.use_oids = use_oids
        if module is None:
            if psycopg is None:
                # A real exception class is required here: string
                # exceptions are rejected by Python 2.6 and later.
                raise ImportError("Couldnt locate psycopg1 or psycopg2: specify postgres module argument")
            self.module = psycopg
        else:
            self.module = module
        # figure psycopg version 1 or 2
        if self.module.__name__ == 'psycopg2':
            self.version = 2
        else:
            self.version = 1
        self.opts = opts or {}
        ansisql.ANSISQLEngine.__init__(self, **params)

    def connect_args(self):
        """Return a two-element list: an empty list and the options dict."""
        return [[], self.opts]

    def type_descriptor(self, typeobj):
        """Adapt ``typeobj`` using the colspec map for the driver version."""
        if self.version == 2:
            colspecs = pg2_colspecs
        else:
            colspecs = pg1_colspecs
        return sqltypes.adapt_type(typeobj, colspecs)

    def last_inserted_ids(self):
        """Return the ``last_inserted_ids`` value stored on the context."""
        return self.context.last_inserted_ids

    def compiler(self, statement, bindparams, **kwargs):
        """Return a PGCompiler for ``statement``."""
        return PGCompiler(self, statement, bindparams, **kwargs)

    def schemagenerator(self, proxy, **params):
        """Return a PGSchemaGenerator built on ``proxy``."""
        return PGSchemaGenerator(proxy, **params)

    def schemadropper(self, proxy, **params):
        """Return a PGSchemaDropper built on ``proxy``."""
        return PGSchemaDropper(proxy, **params)

    def defaultrunner(self, proxy):
        """Return a PGDefaultRunner bound to this engine and ``proxy``."""
        return PGDefaultRunner(self, proxy)

    def get_default_schema_name(self):
        """Return the result of ``select current_schema()``.

        The query runs on the first call only; the value is then cached on
        the instance as ``_default_schema_name``.
        """
        if not hasattr(self, '_default_schema_name'):
            self._default_schema_name = text("select current_schema()", self).scalar()
        return self._default_schema_name

    def rowid_column_name(self):
        """Return "oid" when the engine was created with use_oids, else None."""
        if self.use_oids:
            return "oid"
        return None

    def pre_exec(self, proxy, statement, parameters, **kwargs):
        """Do nothing before execution."""
        return

    def post_exec(self, proxy, compiled, parameters, **kwargs):
        """After an INSERT, fill in ``context.last_inserted_ids`` if unset.

        Nothing happens unless ``compiled`` has a true ``isinsert``
        attribute and ``context.last_inserted_ids`` is still None.  In that
        case the primary key values are selected from the table by matching
        its rowid column against ``cursor.lastrowid``.

        Raises RuntimeError when that lookup is needed but the engine was
        created without use_oids.
        """
        if not getattr(compiled, "isinsert", False):
            return
        if self.context.last_inserted_ids is not None:
            return
        if not self.use_oids:
            # A real exception class is required here: string exceptions
            # are rejected by Python 2.6 and later.
            raise RuntimeError("cant use cursor.lastrowid without OIDs enabled")

        table = compiled.statement.table
        cursor = proxy()
        if cursor.lastrowid is not None and table is not None and len(table.primary_key):
            s = sql.select(table.primary_key, table.rowid_column == cursor.lastrowid)
            c = s.compile()
            cursor = proxy(str(c), c.get_params())
            row = cursor.fetchone()
            self.context.last_inserted_ids = [v for v in row]

    def _executemany(self, c, statement, parameters):
        """Execute ``statement`` once per parameter set, summing rowcounts.

        We need accurate rowcounts for updates, inserts and deletes.
        psycopg2 is not nice enough to produce this correctly for an
        executemany, so we do our own executemany here.  The total is
        stored in ``self.context.rowcount``.
        """
        rowcount = 0
        for param in parameters:
            c.execute(statement, param)
            rowcount += c.rowcount
        self.context.rowcount = rowcount

    def dbapi(self):
        """Return the DBAPI module in use."""
        return self.module

    def reflecttable(self, table):
        """Reflect ``table`` using the type-name map for the driver version."""
        if self.version == 2:
            ischema_names = pg2_ischema_names
        else:
            ischema_names = pg1_ischema_names

        ischema.reflecttable(self, table, ischema_names)
|
|
|
|
class PGCompiler(ansisql.ANSICompiler):
    """Statement compiler with Postgres-specific INSERT and LIMIT handling."""

    def visit_insert_column(self, column):
        """Register a None parameter for every primary key column.

        Postgres advises against OID usage and turns it off in 8.1,
        effectively making cursor.lastrowid useless, effectively making
        reliance upon SERIAL useless.  So all primary key columns must be
        explicitly present in an insert.
        """
        if column.primary_key:
            self.parameters[column.key] = None

    def limit_clause(self, select):
        """Return the LIMIT/OFFSET text for ``select`` ("" if neither is set).

        An offset without a limit is rendered with ``LIMIT ALL`` in front
        of the ``OFFSET``.
        """
        pieces = []
        if select.limit is not None:
            pieces.append(" \n LIMIT " + str(select.limit))
        if select.offset is not None:
            if select.limit is None:
                pieces.append(" \n LIMIT ALL")
            pieces.append(" OFFSET " + str(select.offset))
        return "".join(pieces)
|
|
|
|
class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
    """Schema generator emitting Postgres column and sequence DDL."""

    def get_column_specification(self, column, override_pk=False, **kwargs):
        """Return the DDL fragment describing ``column``.

        An integer primary key with no default, or with an optional
        Sequence as its default, is declared as ``SERIAL``; every other
        column uses its type's own column spec.  ``NOT NULL``,
        ``PRIMARY KEY`` (unless ``override_pk`` is true) and a
        ``REFERENCES`` clause are appended where they apply.
        """
        colspec = column.name
        # sqltypes is this file's explicit alias for sqlalchemy.types; the
        # bare name ``types`` would only resolve correctly if the star
        # import happened to shadow the standard library module.
        if (column.primary_key
                and isinstance(column.type, sqltypes.Integer)
                and (column.default is None
                     or (isinstance(column.default, schema.Sequence)
                         and column.default.optional))):
            colspec += " SERIAL"
        else:
            colspec += " " + column.type.get_col_spec()

        if not column.nullable:
            colspec += " NOT NULL"
        if column.primary_key and not override_pk:
            colspec += " PRIMARY KEY"
        if column.foreign_key:
            # NOTE(review): the target is read through column.column rather
            # than column itself -- presumably ``column`` is a wrapper
            # object here; confirm against the caller.
            target = column.column.foreign_key.column
            colspec += " REFERENCES %s(%s)" % (target.table.name, target.name)
        return colspec

    def visit_sequence(self, sequence):
        """Emit and execute ``CREATE SEQUENCE`` unless the sequence is optional."""
        if sequence.optional:
            return
        self.append("CREATE SEQUENCE %s" % sequence.name)
        self.execute()
|
|
|
|
class PGSchemaDropper(ansisql.ANSISchemaDropper):
    """Schema dropper emitting Postgres sequence DDL."""

    def visit_sequence(self, sequence):
        """Emit and execute ``DROP SEQUENCE`` unless the sequence is optional."""
        if sequence.optional:
            return
        self.append("DROP SEQUENCE %s" % sequence.name)
        self.execute()
|
|
|
|
class PGDefaultRunner(ansisql.ANSIDefaultRunner):
    """Default-value runner that draws values from Postgres sequences."""

    def get_column_default(self, column):
        """Return the default value to use for ``column``.

        For an integer primary key with no default, or with an optional
        Sequence as its default (the same condition under which
        PGSchemaGenerator declares the column ``SERIAL``), the next value
        of the sequence named ``<table>_<column>_seq`` is fetched.  All
        other columns defer to ANSIDefaultRunner.
        """
        # sqltypes is this file's explicit alias for sqlalchemy.types; the
        # bare name ``types`` would only resolve correctly if the star
        # import happened to shadow the standard library module.
        if (column.primary_key
                and isinstance(column.type, sqltypes.Integer)
                and (column.default is None
                     or (isinstance(column.default, schema.Sequence)
                         and column.default.optional))):
            c = self.proxy("select nextval('%s_%s_seq')" % (column.table.name, column.name))
            return c.fetchone()[0]
        else:
            return ansisql.ANSIDefaultRunner.get_column_default(self, column)

    def visit_sequence(self, seq):
        """Return the next value of ``seq``, or None if it is optional."""
        if seq.optional:
            return None
        c = self.proxy("select nextval('%s')" % seq.name)
        return c.fetchone()[0]
|