Merge pull request #5 from cjw296/pg-ranges

Support for Postgres range types.
This commit is contained in:
mike bayer
2013-06-22 07:47:02 -07:00
7 changed files with 640 additions and 4 deletions
+50 -1
View File
@@ -16,7 +16,8 @@ they originate from :mod:`sqlalchemy.types` or from the local dialect::
ARRAY, BIGINT, BIT, BOOLEAN, BYTEA, CHAR, CIDR, DATE, \
DOUBLE_PRECISION, ENUM, FLOAT, HSTORE, INET, INTEGER, \
INTERVAL, MACADDR, NUMERIC, REAL, SMALLINT, TEXT, TIME, \
TIMESTAMP, UUID, VARCHAR
TIMESTAMP, UUID, VARCHAR, INT4RANGE, INT8RANGE, NUMRANGE, \
DATERANGE, TSRANGE, TSTZRANGE
Types which are specific to PostgreSQL, or have PostgreSQL-specific
construction arguments, are as follows:
@@ -81,6 +82,54 @@ construction arguments, are as follows:
:members: __init__
:show-inheritance:
.. autoclass:: sqlalchemy.dialects.postgresql.ranges.RangeOperators
:members:
.. autoclass:: INT4RANGE
:show-inheritance:
.. autoclass:: INT8RANGE
:show-inheritance:
.. autoclass:: NUMRANGE
:show-inheritance:
.. autoclass:: DATERANGE
:show-inheritance:
.. autoclass:: TSRANGE
:show-inheritance:
.. autoclass:: TSTZRANGE
:show-inheritance:
PostgreSQL Constraint Types
---------------------------
SQLAlchemy supports Postgresql EXCLUDE constraints via the
:class:`ExcludeConstraint` class:
.. autoclass:: ExcludeConstraint
:show-inheritance:
:members: __init__
For example::
from sqlalchemy.dialects.postgresql import (
ExcludeConstraint,
TSRANGE as Range,
)
class RoomBookings(Base):
room = Column(Integer(), primary_key=True)
during = Column(TSRANGE())
__table_args__ = (
ExcludeConstraint(('room', '='), ('during', '&&')),
)
psycopg2
--------------
@@ -12,12 +12,16 @@ from .base import \
INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, \
INET, CIDR, UUID, BIT, MACADDR, DOUBLE_PRECISION, TIMESTAMP, TIME, \
DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect, array, Any, All
from .constraints import ExcludeConstraint
from .hstore import HSTORE, hstore
from .ranges import INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, \
TSTZRANGE
__all__ = (
'INTEGER', 'BIGINT', 'SMALLINT', 'VARCHAR', 'CHAR', 'TEXT', 'NUMERIC',
'FLOAT', 'REAL', 'INET', 'CIDR', 'UUID', 'BIT', 'MACADDR',
'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME', 'DATE', 'BYTEA', 'BOOLEAN',
'INTERVAL', 'ARRAY', 'ENUM', 'dialect', 'Any', 'All', 'array', 'HSTORE',
'hstore'
'hstore', 'INT4RANGE', 'INT8RANGE', 'NUMRANGE', 'DATERANGE',
'TSRANGE', 'TSTZRANGE'
)
+35 -1
View File
@@ -443,7 +443,7 @@ class array(expression.Tuple):
An instance of :class:`.array` will always have the datatype
:class:`.ARRAY`. The "inner" type of the array is inferred from
the values present, unless the "type_" keyword argument is passed::
the values present, unless the ``type_`` keyword argument is passed::
array(['foo', 'bar'], type_=CHAR)
@@ -1141,6 +1141,22 @@ class PGDDLCompiler(compiler.DDLCompiler):
text += " WHERE " + where_compiled
return text
def visit_exclude_constraint(self, constraint):
text = ""
if constraint.name is not None:
text += "CONSTRAINT %s " % \
self.preparer.format_constraint(constraint)
elements = []
for c in constraint.columns:
op = constraint.operators[c.name]
elements.append(self.preparer.quote(c.name, c.quote)+' WITH '+op)
text += "EXCLUDE USING %s (%s)" % (constraint.using, ', '.join(elements))
if constraint.where is not None:
sqltext = sql_util.expression_as_ddl(constraint.where)
text += ' WHERE (%s)' % self.sql_compiler.process(sqltext)
text += self.define_constraint_deferrability(constraint)
return text
class PGTypeCompiler(compiler.GenericTypeCompiler):
def visit_INET(self, type_):
@@ -1167,6 +1183,24 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
def visit_HSTORE(self, type_):
return "HSTORE"
def visit_INT4RANGE(self, type_):
return "INT4RANGE"
def visit_INT8RANGE(self, type_):
return "INT8RANGE"
def visit_NUMRANGE(self, type_):
return "NUMRANGE"
def visit_DATERANGE(self, type_):
return "DATERANGE"
def visit_TSRANGE(self, type_):
return "TSRANGE"
def visit_TSTZRANGE(self, type_):
return "TSTZRANGE"
def visit_datetime(self, type_):
return self.visit_TIMESTAMP(type_)
@@ -0,0 +1,73 @@
# Copyright (C) 2013 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from sqlalchemy.schema import ColumnCollectionConstraint
from sqlalchemy.sql import expression
class ExcludeConstraint(ColumnCollectionConstraint):
"""A table-level EXCLUDE constraint.
Defines an EXCLUDE constraint as described in the `postgres
documentation`__.
__ http://www.postgresql.org/docs/9.0/static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE
"""
__visit_name__ = 'exclude_constraint'
where = None
def __init__(self, *elements, **kw):
"""
:param \*elements:
A sequence of two tuples of the form ``(column, operator)`` where
column must be a column name or Column object and operator must
be a string containing the operator to use.
:param name:
Optional, the in-database name of this constraint.
:param deferrable:
Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when
issuing DDL for this constraint.
:param initially:
Optional string. If set, emit INITIALLY <value> when issuing DDL
for this constraint.
:param using:
Optional string. If set, emit USING <index_method> when issuing DDL
for this constraint. Defaults to 'gist'.
:param where:
Optional string. If set, emit WHERE <predicate> when issuing DDL
for this constraint.
"""
ColumnCollectionConstraint.__init__(
self,
*[col for col, op in elements],
name=kw.get('name'),
deferrable=kw.get('deferrable'),
initially=kw.get('initially')
)
self.operators = {}
for col_or_string, op in elements:
name = getattr(col_or_string, 'name', col_or_string)
self.operators[name] = op
self.using = kw.get('using', 'gist')
where = kw.get('where')
if where:
self.where = expression._literal_as_text(where)
def copy(self, **kw):
elements = [(col, self.operators[col])
for col in self.columns.keys()]
c = self.__class__(*elements,
name=self.name,
deferrable=self.deferrable,
initially=self.initially)
c.dispatch._update(self.dispatch)
return c
@@ -0,0 +1,133 @@
# Copyright (C) 2013 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from .base import ischema_names
from ... import types as sqltypes
__all__ = ('INT4RANGE', 'INT8RANGE', 'NUMRANGE')
class RangeOperators(object):
"""
This mixin provides functionality for the Range Operators
listed in Table 9-44 of the `postgres documentation`__ for Range
Functions and Operators. It is used by all the range types
provided in the ``postgres`` dialect and can likely be used for
any range types you create yourself.
__ http://www.postgresql.org/docs/devel/static/functions-range.html
No extra support is provided for the Range Functions listed in
Table 9-45 of the postgres documentation. For these, the normal
:func:`~sqlalchemy.sql.expression.func` object should be used.
"""
class comparator_factory(sqltypes.Concatenable.Comparator):
"""Define comparison operations for range types."""
def __ne__(self, other):
"Boolean expression. Returns true if two ranges are not equal"
return self.expr.op('<>')(other)
def contains(self, other, **kw):
"""Boolean expression. Returns true if the right hand operand,
which can be an element or a range, is contained within the
column.
"""
return self.expr.op('@>')(other)
def contained_by(self, other):
"""Boolean expression. Returns true if the column is contained
within the right hand operand.
"""
return self.expr.op('<@')(other)
def overlaps(self, other):
"""Boolean expression. Returns true if the column overlaps
(has points in common with) the right hand operand.
"""
return self.expr.op('&&')(other)
def strictly_left_of(self, other):
"""Boolean expression. Returns true if the column is strictly
left of the right hand operand.
"""
return self.expr.op('<<')(other)
__lshift__ = strictly_left_of
def strictly_right_of(self, other):
"""Boolean expression. Returns true if the column is strictly
right of the right hand operand.
"""
return self.expr.op('>>')(other)
__rshift__ = strictly_right_of
def not_extend_right_of(self, other):
"""Boolean expression. Returns true if the range in the column
does not extend right of the range in the operand.
"""
return self.expr.op('&<')(other)
def not_extend_left_of(self, other):
"""Boolean expression. Returns true if the range in the column
does not extend left of the range in the operand.
"""
return self.expr.op('&>')(other)
def adjacent_to(self, other):
"""Boolean expression. Returns true if the range in the column
is adjacent to the range in the operand.
"""
return self.expr.op('-|-')(other)
def __add__(self, other):
"""Range expression. Returns the union of the two ranges.
Will raise an exception if the resulting range is not
contigous.
"""
return self.expr.op('+')(other)
class INT4RANGE(RangeOperators, sqltypes.TypeEngine):
"Represent the Postgresql INT4RANGE type."
__visit_name__ = 'INT4RANGE'
ischema_names['int4range'] = INT4RANGE
class INT8RANGE(RangeOperators, sqltypes.TypeEngine):
"Represent the Postgresql INT8RANGE type."
__visit_name__ = 'INT8RANGE'
ischema_names['int8range'] = INT8RANGE
class NUMRANGE(RangeOperators, sqltypes.TypeEngine):
"Represent the Postgresql NUMRANGE type."
__visit_name__ = 'NUMRANGE'
ischema_names['numrange'] = NUMRANGE
class DATERANGE(RangeOperators, sqltypes.TypeEngine):
"Represent the Postgresql DATERANGE type."
__visit_name__ = 'DATERANGE'
ischema_names['daterange'] = DATERANGE
class TSRANGE(RangeOperators, sqltypes.TypeEngine):
"Represent the Postgresql TSRANGE type."
__visit_name__ = 'TSRANGE'
ischema_names['tsrange'] = TSRANGE
class TSTZRANGE(RangeOperators, sqltypes.TypeEngine):
"Represent the Postgresql TSTZRANGE type."
__visit_name__ = 'TSTZRANGE'
ischema_names['tstzrange'] = TSTZRANGE
+329 -1
View File
@@ -17,7 +17,9 @@ from sqlalchemy import Table, Column, select, MetaData, text, Integer, \
from sqlalchemy.orm import Session, mapper, aliased
from sqlalchemy import exc, schema, types
from sqlalchemy.dialects.postgresql import base as postgresql
from sqlalchemy.dialects.postgresql import HSTORE, hstore, array
from sqlalchemy.dialects.postgresql import HSTORE, hstore, array, \
INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, TSTZRANGE, \
ExcludeConstraint
import decimal
from sqlalchemy import util
from sqlalchemy.testing.util import round_decimal
@@ -182,6 +184,53 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
'USING hash (data)',
dialect=postgresql.dialect())
def test_exclude_constraint_min(self):
m = MetaData()
tbl = Table('testtbl', m,
Column('room', Integer, primary_key=True))
cons = ExcludeConstraint(('room', '='))
tbl.append_constraint(cons)
self.assert_compile(schema.AddConstraint(cons),
'ALTER TABLE testtbl ADD EXCLUDE USING gist '
'(room WITH =)',
dialect=postgresql.dialect())
def test_exclude_constraint_full(self):
m = MetaData()
room = Column('room', Integer, primary_key=True)
tbl = Table('testtbl', m,
room,
Column('during', TSRANGE))
room = Column('room', Integer, primary_key=True)
cons = ExcludeConstraint((room, '='), ('during', '&&'),
name='my_name',
using='gist',
where="room > 100",
deferrable=True,
initially='immediate')
tbl.append_constraint(cons)
self.assert_compile(schema.AddConstraint(cons),
'ALTER TABLE testtbl ADD CONSTRAINT my_name '
'EXCLUDE USING gist '
'(room WITH =, during WITH ''&&) WHERE '
'(room > 100) DEFERRABLE INITIALLY immediate',
dialect=postgresql.dialect())
def test_exclude_constraint_copy(self):
m = MetaData()
cons = ExcludeConstraint(('room', '='))
tbl = Table('testtbl', m,
Column('room', Integer, primary_key=True),
cons)
# apparently you can't copy a ColumnCollectionConstraint until
# after it has been bound to a table...
cons_copy = cons.copy()
tbl.append_constraint(cons_copy)
self.assert_compile(schema.AddConstraint(cons_copy),
'ALTER TABLE testtbl ADD EXCLUDE USING gist '
'(room WITH =)',
dialect=postgresql.dialect())
def test_substring(self):
self.assert_compile(func.substring('abc', 1, 2),
'SUBSTRING(%(substring_1)s FROM %(substring_2)s '
@@ -3242,3 +3291,282 @@ class HStoreRoundTripTest(fixtures.TablesTest):
def test_unicode_round_trip_native(self):
engine = testing.db
self._test_unicode_round_trip(engine)
class _RangeTypeMixin(object):
__requires__ = 'range_types',
__dialect__ = 'postgresql+psycopg2'
@property
def extras(self):
# done this way so we don't get ImportErrors with
# older psycopg2 versions.
from psycopg2 import extras
return extras
@classmethod
def define_tables(cls, metadata):
# no reason ranges shouldn't be primary keys,
# so lets just use them as such
table = Table('data_table', metadata,
Column('range', cls._col_type, primary_key=True),
)
cls.col = table.c.range
def test_actual_type(self):
eq_(str(self._col_type()), self._col_str)
def test_reflect(self):
from sqlalchemy import inspect
insp = inspect(testing.db)
cols = insp.get_columns('data_table')
assert isinstance(cols[0]['type'], self._col_type)
def _assert_data(self):
data = testing.db.execute(
select([self.tables.data_table.c.range])
).fetchall()
eq_(data, [(self._data_obj(), )])
def test_insert_obj(self):
testing.db.engine.execute(
self.tables.data_table.insert(),
{'range': self._data_obj()}
)
self._assert_data()
def test_insert_text(self):
testing.db.engine.execute(
self.tables.data_table.insert(),
{'range': self._data_str}
)
self._assert_data()
# operator tests
def _test_clause(self, colclause, expected):
dialect = postgresql.dialect()
compiled = str(colclause.compile(dialect=dialect))
eq_(compiled, expected)
def test_where_equal(self):
self._test_clause(
self.col==self._data_str,
"data_table.range = %(range_1)s"
)
def test_where_not_equal(self):
self._test_clause(
self.col!=self._data_str,
"data_table.range <> %(range_1)s"
)
def test_where_less_than(self):
self._test_clause(
self.col < self._data_str,
"data_table.range < %(range_1)s"
)
def test_where_greater_than(self):
self._test_clause(
self.col > self._data_str,
"data_table.range > %(range_1)s"
)
def test_where_less_than_or_equal(self):
self._test_clause(
self.col <= self._data_str,
"data_table.range <= %(range_1)s"
)
def test_where_greater_than_or_equal(self):
self._test_clause(
self.col >= self._data_str,
"data_table.range >= %(range_1)s"
)
def test_contains(self):
self._test_clause(
self.col.contains(self._data_str),
"data_table.range @> %(range_1)s"
)
def test_contained_by(self):
self._test_clause(
self.col.contained_by(self._data_str),
"data_table.range <@ %(range_1)s"
)
def test_overlaps(self):
self._test_clause(
self.col.overlaps(self._data_str),
"data_table.range && %(range_1)s"
)
def test_strictly_left_of(self):
self._test_clause(
self.col << self._data_str,
"data_table.range << %(range_1)s"
)
self._test_clause(
self.col.strictly_left_of(self._data_str),
"data_table.range << %(range_1)s"
)
def test_strictly_right_of(self):
self._test_clause(
self.col >> self._data_str,
"data_table.range >> %(range_1)s"
)
self._test_clause(
self.col.strictly_right_of(self._data_str),
"data_table.range >> %(range_1)s"
)
def test_not_extend_right_of(self):
self._test_clause(
self.col.not_extend_right_of(self._data_str),
"data_table.range &< %(range_1)s"
)
def test_not_extend_left_of(self):
self._test_clause(
self.col.not_extend_left_of(self._data_str),
"data_table.range &> %(range_1)s"
)
def test_adjacent_to(self):
self._test_clause(
self.col.adjacent_to(self._data_str),
"data_table.range -|- %(range_1)s"
)
def test_union(self):
self._test_clause(
self.col + self.col,
"data_table.range + data_table.range"
)
def test_union_result(self):
# insert
testing.db.engine.execute(
self.tables.data_table.insert(),
{'range': self._data_str}
)
# select
range = self.tables.data_table.c.range
data = testing.db.execute(
select([range + range])
).fetchall()
eq_(data, [(self._data_obj(), )])
def test_intersection(self):
self._test_clause(
self.col * self.col,
"data_table.range * data_table.range"
)
def test_intersection_result(self):
# insert
testing.db.engine.execute(
self.tables.data_table.insert(),
{'range': self._data_str}
)
# select
range = self.tables.data_table.c.range
data = testing.db.execute(
select([range * range])
).fetchall()
eq_(data, [(self._data_obj(), )])
def test_different(self):
self._test_clause(
self.col - self.col,
"data_table.range - data_table.range"
)
def test_difference_result(self):
# insert
testing.db.engine.execute(
self.tables.data_table.insert(),
{'range': self._data_str}
)
# select
range = self.tables.data_table.c.range
data = testing.db.execute(
select([range - range])
).fetchall()
eq_(data, [(self._data_obj().__class__(empty=True), )])
class Int4RangeTests(_RangeTypeMixin, fixtures.TablesTest):
_col_type = INT4RANGE
_col_str = 'INT4RANGE'
_data_str = '[1,2)'
def _data_obj(self):
return self.extras.NumericRange(1, 2)
class Int8RangeTests(_RangeTypeMixin, fixtures.TablesTest):
_col_type = INT8RANGE
_col_str = 'INT8RANGE'
_data_str = '[9223372036854775806,9223372036854775807)'
def _data_obj(self):
return self.extras.NumericRange(
9223372036854775806, 9223372036854775807
)
class NumRangeTests(_RangeTypeMixin, fixtures.TablesTest):
_col_type = NUMRANGE
_col_str = 'NUMRANGE'
_data_str = '[1.0,2.0)'
def _data_obj(self):
return self.extras.NumericRange(
decimal.Decimal('1.0'), decimal.Decimal('2.0')
)
class DateRangeTests(_RangeTypeMixin, fixtures.TablesTest):
_col_type = DATERANGE
_col_str = 'DATERANGE'
_data_str = '[2013-03-23,2013-03-24)'
def _data_obj(self):
return self.extras.DateRange(
datetime.date(2013, 3, 23), datetime.date(2013, 3, 24)
)
class DateTimeRangeTests(_RangeTypeMixin, fixtures.TablesTest):
_col_type = TSRANGE
_col_str = 'TSRANGE'
_data_str = '[2013-03-23 14:30,2013-03-23 23:30)'
def _data_obj(self):
return self.extras.DateTimeRange(
datetime.datetime(2013, 3, 23, 14, 30),
datetime.datetime(2013, 3, 23, 23, 30)
)
class DateTimeTZRangeTests(_RangeTypeMixin, fixtures.TablesTest):
_col_type = TSTZRANGE
_col_str = 'TSTZRANGE'
# make sure we use one, steady timestamp with timezone pair
# for all parts of all these tests
_tstzs = None
def tstzs(self):
if self._tstzs is None:
lower = testing.db.connect().scalar(
func.current_timestamp().select()
)
upper = lower+datetime.timedelta(1)
self._tstzs = (lower, upper)
return self._tstzs
@property
def _data_str(self):
return '[%s,%s)' % self.tstzs()
def _data_obj(self):
return self.extras.DateTimeTZRange(*self.tstzs())
+15
View File
@@ -602,6 +602,21 @@ class DefaultRequirements(SuiteRequirements):
return only_if(check_hstore)
@property
def range_types(self):
def check_range_types():
if not against("postgresql+psycopg2"):
return False
try:
self.db.execute("select '[1,2)'::int4range;")
# only supported in psycopg 2.5+
from psycopg2.extras import NumericRange
return True
except:
return False
return only_if(check_range_types)
@property
def sqlite(self):
return skip_if(lambda: not self._has_sqlite())