# Files
# sqlalchemy/lib/sqlalchemy/databases/postgres.py
# T
#
# 261 lines
# 9.1 KiB
# Python

# postgres.py
# Copyright (C) 2005 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import sys, StringIO, string, types, re
import sqlalchemy.sql as sql
import sqlalchemy.engine as engine
import sqlalchemy.schema as schema
import sqlalchemy.ansisql as ansisql
import sqlalchemy.types as sqltypes
from sqlalchemy import *
import sqlalchemy.databases.information_schema as ischema
# Prefer psycopg2; fall back to psycopg (v1).  Leave psycopg bound to None
# when neither driver is installed so PGSQLEngine can raise a helpful error
# at construction time instead of at import time.
try:
    import psycopg2 as psycopg
except ImportError:
    # was a bare `except:`, which would also swallow KeyboardInterrupt,
    # SystemExit and genuine bugs inside the driver's own import
    try:
        import psycopg
    except ImportError:
        psycopg = None
class PGNumeric(sqltypes.Numeric):
    """PostgreSQL NUMERIC with explicit precision and scale."""
    def get_col_spec(self):
        return "NUMERIC(%s, %s)" % (self.precision, self.length)
class PGFloat(sqltypes.Float):
    """PostgreSQL FLOAT with explicit precision."""
    def get_col_spec(self):
        return "FLOAT(%s)" % self.precision
class PGInteger(sqltypes.Integer):
    """Plain PostgreSQL INTEGER."""
    def get_col_spec(self):
        return "INTEGER"
class PG2DateTime(sqltypes.DateTime):
    """TIMESTAMP type used with psycopg2 connections."""
    def get_col_spec(self):
        return "TIMESTAMP"
class PG1DateTime(sqltypes.DateTime):
    """TIMESTAMP type used with psycopg (v1) connections.

    Bind/result conversion is currently a pass-through.
    """
    def get_col_spec(self):
        return "TIMESTAMP"
    def convert_bind_param(self, value):
        # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
        return value
    def convert_result_value(self, value):
        # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
        return value
class PGText(sqltypes.TEXT):
    """Unbounded PostgreSQL TEXT."""
    def get_col_spec(self):
        return "TEXT"
class PGString(sqltypes.String):
    """PostgreSQL VARCHAR with explicit length."""
    def get_col_spec(self):
        return "VARCHAR(%s)" % self.length
class PGChar(sqltypes.CHAR):
    """Fixed-length PostgreSQL CHAR."""
    def get_col_spec(self):
        return "CHAR(%s)" % self.length
class PGBinary(sqltypes.Binary):
    """Binary data maps to PostgreSQL BYTEA."""
    def get_col_spec(self):
        return "BYTEA"
class PGBoolean(sqltypes.Boolean):
    """Native PostgreSQL BOOLEAN."""
    def get_col_spec(self):
        return "BOOLEAN"
# Generic sqlalchemy type -> PostgreSQL implementation, for psycopg2.
pg2_colspecs = {
    sqltypes.Integer: PGInteger,
    sqltypes.Numeric: PGNumeric,
    sqltypes.Float: PGFloat,
    sqltypes.DateTime: PG2DateTime,
    sqltypes.String: PGString,
    sqltypes.Binary: PGBinary,
    sqltypes.Boolean: PGBoolean,
    sqltypes.TEXT: PGText,
    sqltypes.CHAR: PGChar,
}
# psycopg (v1) shares the mapping except for its DateTime handling.
pg1_colspecs = pg2_colspecs.copy()
pg1_colspecs[sqltypes.DateTime] = PG1DateTime
# information_schema type-name -> type class, used for table reflection
# against a psycopg2 connection.
pg2_ischema_names = {
    'integer': PGInteger,
    'bigint': PGInteger,
    'character varying': PGString,
    'character': PGChar,
    'text': PGText,
    'numeric': PGNumeric,
    'float': PGFloat,
    'real': PGFloat,
    'double precision': PGFloat,
    'timestamp with time zone': PG2DateTime,
    'timestamp without time zone': PG2DateTime,
    'bytea': PGBinary,
    'boolean': PGBoolean,
}
# the psycopg (v1) map differs only in its timestamp handling
pg1_ischema_names = pg2_ischema_names.copy()
pg1_ischema_names['timestamp with time zone'] = \
    pg1_ischema_names['timestamp without time zone'] = PG1DateTime
def engine(opts, **params):
    """Construct a PGSQLEngine from a dictionary of connect options."""
    return PGSQLEngine(opts, **params)
def descriptor():
    """Describe this dialect for the engine-plugin registry.

    Returns a dict with the dialect name, a human-readable description,
    and the list of supported connect arguments as (name, doc, default).
    """
    arguments = [
        ('user', "Database Username", None),
        ('password', "Database Password", None),
        ('database', "Database Name", None),
        ('host', "Hostname", None),
    ]
    return {
        'name': 'postgres',
        'description': 'PostGres',
        'arguments': arguments,
    }
class PGSQLEngine(ansisql.ANSISQLEngine):
    """SQLEngine for PostgreSQL, working with either psycopg (v1) or psycopg2.

    opts -- dictionary of DBAPI connect arguments (user, password, ...).
    module -- optional DBAPI module; defaults to whichever psycopg variant
    was importable at module load time.
    use_oids -- when True, use the row OID (cursor.lastrowid) to re-select
    primary-key values after an INSERT.
    """
    def __init__(self, opts, module=None, use_oids=False, **params):
        self.use_oids = use_oids
        if module is None:
            if psycopg is None:
                # was `raise "..."` -- string exceptions are deprecated and
                # cannot be raised in modern Python; raise a real exception
                raise ImportError("Couldnt locate psycopg1 or psycopg2: specify postgres module argument")
            self.module = psycopg
        else:
            self.module = module
        # figure psycopg version 1 or 2
        if self.module.__name__ == 'psycopg2':
            self.version = 2
        else:
            self.version = 1
        self.opts = opts or {}
        ansisql.ANSISQLEngine.__init__(self, **params)
    def connect_args(self):
        # no positional args; all connect options are keyword arguments
        return [[], self.opts]
    def type_descriptor(self, typeobj):
        # adapt generic types to the psycopg-version-specific implementations
        if self.version == 2:
            return sqltypes.adapt_type(typeobj, pg2_colspecs)
        else:
            return sqltypes.adapt_type(typeobj, pg1_colspecs)
    def last_inserted_ids(self):
        # NOTE: the original defined this method twice, identically; the
        # redundant second definition has been removed.
        return self.context.last_inserted_ids
    def compiler(self, statement, bindparams, **kwargs):
        return PGCompiler(self, statement, bindparams, **kwargs)
    def schemagenerator(self, proxy, **params):
        return PGSchemaGenerator(proxy, **params)
    def schemadropper(self, proxy, **params):
        return PGSchemaDropper(proxy, **params)
    def defaultrunner(self, proxy):
        return PGDefaultRunner(self, proxy)
    def get_default_schema_name(self):
        # queried once and cached on the instance
        if not hasattr(self, '_default_schema_name'):
            self._default_schema_name = text("select current_schema()", self).scalar()
        return self._default_schema_name
    def rowid_column_name(self):
        # the implicit row id is the OID, but only when OIDs are enabled
        if self.use_oids:
            return "oid"
        else:
            return None
    def pre_exec(self, proxy, statement, parameters, **kwargs):
        return
    def post_exec(self, proxy, compiled, parameters, **kwargs):
        """After an INSERT with no known primary-key values, re-select them
        by the OID that cursor.lastrowid reports."""
        if getattr(compiled, "isinsert", False) and self.context.last_inserted_ids is None:
            if not self.use_oids:
                # was `raise "..."` -- use a real exception instance
                raise Exception("cant use cursor.lastrowid without OIDs enabled")
            else:
                table = compiled.statement.table
                cursor = proxy()
                if cursor.lastrowid is not None and table is not None and len(table.primary_key):
                    s = sql.select(table.primary_key, table.rowid_column == cursor.lastrowid)
                    c = s.compile()
                    cursor = proxy(str(c), c.get_params())
                    row = cursor.fetchone()
                    self.context.last_inserted_ids = [v for v in row]
    def _executemany(self, c, statement, parameters):
        """we need accurate rowcounts for updates, inserts and deletes. psycopg2 is not nice enough
        to produce this correctly for an executemany, so we do our own executemany here."""
        rowcount = 0
        for param in parameters:
            c.execute(statement, param)
            rowcount += c.rowcount
        self.context.rowcount = rowcount
    def dbapi(self):
        return self.module
    def reflecttable(self, table):
        # pick the reflection type map matching the driver version
        if self.version == 2:
            ischema_names = pg2_ischema_names
        else:
            ischema_names = pg1_ischema_names
        ischema.reflecttable(self, table, ischema_names)
class PGCompiler(ansisql.ANSICompiler):
    """Statement compiler emitting PostgreSQL-flavored LIMIT/OFFSET."""
    def visit_insert_column(self, column):
        # Postgres advises against OID usage and turns it off in 8.1,
        # effectively making cursor.lastrowid
        # useless, effectively making reliance upon SERIAL useless.
        # so all column primary key inserts must be explicitly present
        if column.primary_key:
            self.parameters[column.key] = None
    def limit_clause(self, select):
        # renamed local from `text` to avoid shadowing the text() construct
        # imported via `from sqlalchemy import *`
        clause = ""
        if select.limit is not None:
            clause += " \n LIMIT " + str(select.limit)
        if select.offset is not None:
            if select.limit is None:
                # OFFSET without LIMIT requires an explicit LIMIT ALL
                clause += " \n LIMIT ALL"
            clause += " OFFSET " + str(select.offset)
        return clause
class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
    """DDL generator rendering SERIAL primary keys and CREATE SEQUENCE."""
    def get_column_specification(self, column, override_pk=False, **kwargs):
        """Return the column clause of a CREATE TABLE statement.

        An integer primary key with no default (or only an optional
        Sequence) renders as SERIAL; override_pk suppresses the inline
        PRIMARY KEY keyword (for composite keys).
        """
        colspec = column.name
        # bugfix: was `types.Integer`; `types` is the stdlib module imported
        # at the top of this file and has no Integer attribute -- the
        # sqlalchemy type module is bound as `sqltypes`
        if column.primary_key and isinstance(column.type, sqltypes.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
            colspec += " SERIAL"
        else:
            colspec += " " + column.type.get_col_spec()
        if not column.nullable:
            colspec += " NOT NULL"
        if column.primary_key and not override_pk:
            colspec += " PRIMARY KEY"
        if column.foreign_key:
            # bugfix: was `column.column.foreign_key...`; `column` is already
            # the Column object, matching the guard on the line above
            colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.name, column.foreign_key.column.name)
        return colspec
    def visit_sequence(self, sequence):
        # optional sequences are handled implicitly by SERIAL
        if not sequence.optional:
            self.append("CREATE SEQUENCE %s" % sequence.name)
            self.execute()
class PGSchemaDropper(ansisql.ANSISchemaDropper):
    """DDL dropper emitting DROP SEQUENCE for explicit sequences."""
    def visit_sequence(self, sequence):
        # optional sequences are implicit (SERIAL) and were never created here
        if sequence.optional:
            return
        self.append("DROP SEQUENCE %s" % sequence.name)
        self.execute()
class PGDefaultRunner(ansisql.ANSIDefaultRunner):
    """Pre-fetches column defaults, pulling SERIAL/sequence values via nextval()."""
    def get_column_default(self, column):
        """Return the next value for a SERIAL-style primary key, otherwise
        defer to the standard ANSI default handling."""
        # bugfix: was `types.Integer`; `types` is the stdlib module imported
        # at the top of this file and has no Integer attribute -- the
        # sqlalchemy type module is bound as `sqltypes`
        if column.primary_key and isinstance(column.type, sqltypes.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
            # SERIAL columns get an implicit <table>_<column>_seq sequence
            c = self.proxy("select nextval('%s_%s_seq')" % (column.table.name, column.name))
            return c.fetchone()[0]
        else:
            return ansisql.ANSIDefaultRunner.get_column_default(self, column)
    def visit_sequence(self, seq):
        """Return the next value of an explicit sequence, or None when the
        sequence is optional (handled by SERIAL)."""
        if not seq.optional:
            c = self.proxy("select nextval('%s')" % seq.name)
            return c.fetchone()[0]
        else:
            return None