# playhouse/postgres_ext.py -- Postgres-specific extensions for peewee.
import json
import logging
import uuid
from peewee import *
from peewee import ColumnBase
from peewee import Expression
from peewee import FieldDatabaseHook
from peewee import Node
from peewee import NodeList
from peewee import Psycopg2Adapter
from peewee import Psycopg3Adapter
from peewee import __exception_wrapper__
from playhouse.pool import _PooledPostgresqlDatabase
try:
from psycopg2cffi import compat
compat.register()
except ImportError:
pass
try:
from psycopg2.extras import register_hstore
except ImportError:
def register_hstore(*args): pass
try:
from psycopg.types import TypeInfo
from psycopg.types.hstore import register_hstore as register_hstore_pg3
except ImportError:
def register_hstore_pg3(*args): pass
logger = logging.getLogger('peewee')

# hstore operators.
HCONTAINS_DICT = '@>'
HCONTAINS_KEYS = '?&'
HCONTAINS_KEY = '?'
HCONTAINS_ANY_KEY = '?|'
HKEY = '->'
HUPDATE = '||'

# Array operators.
ACONTAINS = '@>'
ACONTAINED_BY = '<@'
ACONTAINS_ANY = '&&'

# Full-text-search match operator.
TS_MATCH = '@@'

# json / jsonb operators.
JSONB_CONTAINS = '@>'
JSONB_CONTAINED_BY = '<@'
JSONB_CONTAINS_KEY = '?'
JSONB_CONTAINS_ANY_KEY = '?|'
JSONB_CONTAINS_ALL_KEYS = '?&'
JSONB_EXISTS = '?'
JSONB_REMOVE = '-'
JSONB_PATH_REMOVE = '#-'
JSONB_PATH = '#>'
class Json(Node):
    """Fallback JSON value wrapper.

    Used when the database adapter does not supply its own JSON type;
    serializes ``value`` with ``dumps`` (default ``json.dumps``) when the
    node is parameterized via ``ctx.value``.
    """
    # Include *both* attributes assigned in __init__ -- the original
    # declaration listed only 'value', so 'dumps' fell through to the
    # inherited __dict__ and the slots declaration was ineffective.
    __slots__ = ('value', 'dumps')

    def __init__(self, value, dumps=None):
        self.value = value
        self.dumps = dumps or json.dumps

    def __sql__(self, ctx):
        # Hand the raw value plus serializer to the context for binding.
        return ctx.value(self.value, self.dumps)
class _LookupNode(ColumnBase):
    """Base class for nodes that represent a key/index lookup applied to
    another node (``node``) via a list of path components (``parts``)."""

    def __init__(self, node, parts):
        self.node = node
        self.parts = parts
        super(_LookupNode, self).__init__()

    def clone(self):
        # Copy the parts list so mutating the clone leaves us untouched.
        klass = type(self)
        return klass(self.node, list(self.parts))

    def __hash__(self):
        # Identity-based hash: two lookups are distinct hash keys even if
        # they reference the same node/parts.
        return hash((type(self).__name__, id(self)))
class ObjectSlice(_LookupNode):
    """Zero-based index/slice lookup rendered as a one-based Postgres
    array subscript, e.g. ``arr[0:2]`` -> ``arr[1:2]``."""

    @classmethod
    def create(cls, node, value):
        if isinstance(value, slice):
            # Python slices exclude the stop index; Postgres includes it.
            start = value.start or 0
            if value.stop is None:
                parts = [start, None]
            else:
                parts = [start, value.stop - 1]
        elif isinstance(value, int):
            parts = [value]
        elif isinstance(value, Node):
            # A dynamic subscript expression; rendered verbatim.
            parts = value
        else:
            # Assumes colon-separated integer indexes.
            parts = [int(piece) for piece in value.split(':')]
        return cls(node, parts)

    def __sql__(self, ctx):
        ctx.sql(self.node)
        if isinstance(self.parts, Node):
            ctx.literal('[').sql(self.parts).literal(']')
        else:
            # Shift zero-based indexes up by one for Postgres; None
            # renders as an empty (open) bound.
            rendered = ['' if part is None else str(part + 1)
                        for part in self.parts]
            ctx.literal('[%s]' % ':'.join(rendered))
        return ctx

    def __getitem__(self, value):
        # Allow chained subscripts for multi-dimensional arrays.
        return ObjectSlice.create(self, value)
class IndexedFieldMixin(object):
    """Mixin that makes a field indexed by default (GIN index type)."""
    default_index_type = 'GIN'

    def __init__(self, *args, **kwargs):
        # Use an index by default unless the caller said otherwise.
        if 'index' not in kwargs:
            kwargs['index'] = True
        super(IndexedFieldMixin, self).__init__(*args, **kwargs)
class ArrayField(IndexedFieldMixin, Field):
    """Postgres ARRAY column.

    An instance of ``field_class`` supplies the element data-type; the
    DDL appends one ``[]`` per ``dimensions``.  When ``convert_values``
    is true, the wrapped field's db/python conversions are applied to
    every leaf element.
    """
    passthrough = True

    def __init__(self, field_class=IntegerField, field_kwargs=None,
                 dimensions=1, convert_values=False, *args, **kwargs):
        # Private element-field instance, used only for type information
        # and value conversion (name-mangled: _ArrayField__field).
        self.__field = field_class(**(field_kwargs or {}))
        self.dimensions = dimensions
        self.convert_values = convert_values
        # Expose the element field's type as this field's type.
        self.field_type = self.__field.field_type
        super(ArrayField, self).__init__(*args, **kwargs)

    def bind(self, model, name, set_attribute=True):
        ret = super(ArrayField, self).bind(model, name, set_attribute)
        # Bind the element field too (without setting a model attribute)
        # so its conversions see model/database state.
        self.__field.bind(model, '__array_%s' % name, False)
        return ret

    def ddl_datatype(self, ctx):
        # Element type followed by "[]" repeated once per dimension.
        data_type = self.__field.ddl_datatype(ctx)
        return NodeList((data_type, SQL('[]' * self.dimensions)), glue='')

    def db_value(self, value):
        if value is None or isinstance(value, Node):
            return value
        elif self.convert_values:
            # Recursively convert each element via the wrapped field.
            return self._process(self.__field.db_value, value, self.dimensions)
        else:
            # Always hand the adapter a list.
            return value if isinstance(value, list) else list(value)

    def python_value(self, value):
        if self.convert_values and value is not None:
            conv = self.__field.python_value
            if isinstance(value, list):
                return self._process(conv, value, self.dimensions)
            else:
                return conv(value)
        else:
            return value

    def _process(self, conv, value, dimensions):
        # Apply ``conv`` to every leaf of a nested list with the given
        # number of dimensions.
        dimensions -= 1
        if dimensions == 0:
            return [conv(v) for v in value]
        else:
            return [self._process(conv, v, dimensions) for v in value]

    def __getitem__(self, value):
        # arr[1] / arr[0:2] -> Postgres array subscript expression.
        return ObjectSlice.create(self, value)

    def _e(op):
        # Operator factory: comparisons whose right-hand side is cast to
        # this array's DDL type (via ArrayValue).
        def inner(self, rhs):
            return Expression(self, op, ArrayValue(self, rhs))
        return inner
    __eq__ = _e(OP.EQ)
    __ne__ = _e(OP.NE)
    __gt__ = _e(OP.GT)
    __ge__ = _e(OP.GTE)
    __lt__ = _e(OP.LT)
    __le__ = _e(OP.LTE)
    __hash__ = Field.__hash__

    def contains(self, *items):
        # "@>": array contains all of the given items.
        return Expression(self, ACONTAINS, ArrayValue(self, items))

    def contains_any(self, *items):
        # "&&": array overlaps any of the given items.
        return Expression(self, ACONTAINS_ANY, ArrayValue(self, items))

    def contained_by(self, *items):
        # "<@": array is contained by the given items.
        return Expression(self, ACONTAINED_BY, ArrayValue(self, items))
class ArrayValue(Node):
    """Parameterized array value with an explicit cast to the owning
    field's DDL type (e.g. ``%s::integer[]``)."""

    def __init__(self, field, value):
        self.field = field
        self.value = value

    def __sql__(self, ctx):
        ctx.sql(AsIs(self.value))
        ctx.literal('::')
        return ctx.sql(self.field.ddl_datatype(ctx))
class DateTimeTZField(DateTimeField):
    # Timezone-aware timestamp column.
    field_type = 'TIMESTAMPTZ'
class HStoreField(IndexedFieldMixin, Field):
    """Postgres HSTORE (string key/value mapping) column, exposing the
    hstore operators and functions as query-building helpers."""
    field_type = 'HSTORE'
    __hash__ = Field.__hash__

    def __getitem__(self, key):
        # h['k'] -> "h -> 'k'" (value for key).
        return Expression(self, HKEY, Value(key))

    def keys(self):
        return fn.akeys(self)

    def values(self):
        return fn.avals(self)

    def items(self):
        # Key/value pairs as a two-column matrix.
        return fn.hstore_to_matrix(self)

    def slice(self, *args):
        # Subset of the hstore containing only the given keys.
        return fn.slice(self, AsIs(list(args)))

    def exists(self, key):
        return fn.exist(self, key)

    def defined(self, key):
        # Key is present with a non-NULL value.
        return fn.defined(self, key)

    def update(self, __data=None, **data):
        # Merge the given mapping into the stored hstore ("||").
        if __data is not None:
            data.update(__data)
        return Expression(self, HUPDATE, data)

    def delete(self, *keys):
        # Remove the given keys from the hstore.
        value = Cast(AsIs(list(keys)), 'text[]')
        return fn.delete(self, value)

    def contains(self, value):
        # dict -> "@>" (contains these key/value pairs);
        # list/tuple -> "?&" (contains all keys);
        # scalar -> "?" (contains this key).
        if isinstance(value, dict):
            rhs = AsIs(value)
            return Expression(self, HCONTAINS_DICT, rhs)
        elif isinstance(value, (list, tuple)):
            rhs = AsIs(value)
            return Expression(self, HCONTAINS_KEYS, rhs)
        return Expression(self, HCONTAINS_KEY, value)

    def contains_any(self, *keys):
        # "?|": contains any of the given keys.
        return Expression(self, HCONTAINS_ANY_KEY, AsIs(list(keys)))
class _JsonLookupBase(_LookupNode):
    """Base for JSON key/path lookups, providing jsonb-style operator
    helpers that apply to the looked-up value."""

    def __init__(self, node, parts, as_json=False):
        super(_JsonLookupBase, self).__init__(node, parts)
        # Whether the source column is jsonb (vs plain json); nodes that
        # do not declare a _json_type are treated as jsonb.
        self._jsonb = getattr(node, '_json_type', 'jsonb') == 'jsonb'
        # When true, render with the json-returning operator ("->"/"#>")
        # instead of the text-returning one ("->>"/"#>>").
        self._as_json = as_json

    def clone(self):
        return type(self)(self.node, list(self.parts), self._as_json)

    @Node.copy
    def as_json(self, as_json=True):
        # Returns a copy (via Node.copy) with the as_json flag set.
        self._as_json = as_json

    def concat(self, rhs):
        # "||": concatenate with another json value; non-Node values are
        # wrapped in the owning field's json type.
        if not isinstance(rhs, Node):
            rhs = self.node.json_type(rhs)
        return Expression(self.as_json(True), OP.CONCAT, rhs)

    def contains(self, other):
        # "@>": looked-up value contains ``other``.
        if not isinstance(other, Node):
            other = self.node.json_type(other)
        return Expression(self.as_json(True), JSONB_CONTAINS, other)

    def contained_by(self, other):
        # "<@": looked-up value is contained by ``other``.
        if not isinstance(other, Node):
            other = self.node.json_type(other)
        return Expression(self.as_json(True), JSONB_CONTAINED_BY, other)

    def contains_any(self, *keys):
        # "?|": any of the given keys exist.
        return Expression(
            self.as_json(True),
            JSONB_CONTAINS_ANY_KEY,
            AsIs(list(keys), False))

    def contains_all(self, *keys):
        # "?&": all of the given keys exist.
        return Expression(
            self.as_json(True),
            JSONB_CONTAINS_ALL_KEYS,
            AsIs(list(keys), False))

    def has_key(self, key):
        # "?": the given key exists.
        return Expression(self.as_json(True), JSONB_CONTAINS_KEY, key)

    def remove(self):
        # "#-": remove the value at this path; path elements are text.
        parts = [str(p) if isinstance(p, int) else p for p in self.parts]
        value = AsIs(parts, False)
        return Expression(self.node, JSONB_PATH_REMOVE, value)

    def length(self):
        # Array length of the looked-up value; picks the json/jsonb
        # variant of the function based on the column type.
        func = fn.jsonb_array_length if self._jsonb else fn.json_array_length
        return func(self.as_json(True))

    def extract(self, *path):
        # Extract a sub-path; path elements must be text.
        path = [str(p) if isinstance(p, int) else p for p in path]
        func = fn.jsonb_extract_path if self._jsonb else fn.json_extract_path
        return func(self.as_json(True), *path)

    def path(self, *keys):
        # Continue with a "#>"-style path lookup from here.
        return JsonPath(self.as_json(True), keys, as_json=True)
class JsonLookup(_JsonLookupBase):
    """Chained "->" key lookup; the final hop renders "->>" (text)
    unless _as_json is set."""

    def __getitem__(self, value):
        # Chain another lookup level, preserving the as_json flag.
        return JsonLookup(self.node, self.parts + [value], self._as_json)

    def __sql__(self, ctx):
        ctx.sql(self.node)
        # All intermediate hops stay json-typed ("->").
        for part in self.parts[:-1]:
            ctx.literal('->').sql(part)
        if self.parts:
            # Final hop: "->" keeps json, "->>" extracts text.
            (ctx
             .literal('->' if self._as_json else '->>')
             .sql(self.parts[-1]))
        return ctx
class JsonPath(_JsonLookupBase):
    """Path lookup rendered with "#>" (json) or "#>>" (text) and a
    Postgres path literal like '{a,b,0}'."""

    def __sql__(self, ctx):
        operator = '#>' if self._as_json else '#>>'
        path_literal = '{%s}' % ','.join(str(part) for part in self.parts)
        ctx.sql(self.node)
        ctx.literal(operator)
        return ctx.sql(Value(path_literal))
class JSONField(FieldDatabaseHook, Field):
    """Postgres JSON column.

    The concrete json wrapper type is selected by the database adapter
    through the ``_db_hook`` callback when the field is bound.
    """
    field_type = 'JSON'
    _json_datatype = 'json'

    def __init__(self, dumps=None, **kwargs):
        # Optional custom serializer; None defers to the adapter's choice.
        self._dumps = dumps
        super(JSONField, self).__init__(**kwargs)

    def _db_hook(self, database):
        # Called when the field is (un)bound to a database: pick the json
        # wrapper type and cast behavior from the adapter, falling back
        # to the generic Json node when no adapter is available.
        if database is None or not hasattr(database, '_adapter'):
            self.json_type = Json
            self.cast_json_case = True
        else:
            self.json_type = database._adapter.json_type
            self.cast_json_case = database._adapter.cast_json_case
        if self._dumps:
            # Wrap the adapter's json type so the custom serializer is
            # applied to every value.
            dumps = self._dumps
            class _Json(self.json_type):
                def __init__(self, value):
                    super(_Json, self).__init__(value, dumps=dumps)
            self.json_type = _Json

    def db_value(self, value):
        if value is None or isinstance(value, (Node, self.json_type)):
            return value
        return self.json_type(value)

    def to_value(self, value, case=False):
        # CASE WHEN id = 123 THEN x.json_data fails because the expression is
        # untyped, so we need an explicit cast with psycopg2.
        if case and self.cast_json_case:
            return Cast(self.json_type(value), self._json_datatype)
        return self.db_value(value)

    def __getitem__(self, value):
        # f['key'] -> chainable JsonLookup.
        return JsonLookup(self, [value])

    def path(self, *keys):
        return JsonPath(self, keys, as_json=True)

    def concat(self, value):
        # "||": wrap plain Python values before concatenation.
        if not isinstance(value, Node):
            value = self.json_type(value)
        return super(JSONField, self).concat(value)

    def length(self):
        return fn.json_array_length(self)

    def extract(self, *path):
        # json_extract_path requires text path elements.
        path = [str(p) if isinstance(p, int) else p for p in path]
        return fn.json_extract_path(self, *path)
class BinaryJSONField(IndexedFieldMixin, JSONField):
    """Postgres JSONB column with containment / key-existence operators."""
    field_type = 'JSONB'
    _json_datatype = 'jsonb'
    __hash__ = Field.__hash__

    def _db_hook(self, database):
        # Same shape as JSONField._db_hook, but selects the adapter's
        # jsonb wrapper type instead of the plain json one.
        if database is None or not hasattr(database, '_adapter'):
            self.json_type = Json
            self.cast_json_case = True
        else:
            self.json_type = database._adapter.jsonb_type
            self.cast_json_case = database._adapter.cast_json_case
        if self._dumps:
            # Apply the custom serializer via a subclass of the json type.
            dumps = self._dumps
            class _Json(self.json_type):
                def __init__(self, value):
                    super(_Json, self).__init__(value, dumps=dumps)
            self.json_type = _Json

    def contains(self, other):
        # "@>": jsonb contains the given value.
        if not isinstance(other, Node):
            other = self.json_type(other)
        return Expression(self, JSONB_CONTAINS, other)

    def contained_by(self, other):
        # "<@": jsonb is contained by the given value.
        if not isinstance(other, Node):
            other = self.json_type(other)
        return Expression(self, JSONB_CONTAINED_BY, other)

    def contains_any(self, *items):
        # "?|": any of the given top-level keys exist.
        return Expression(
            self,
            JSONB_CONTAINS_ANY_KEY,
            AsIs(list(items), False))

    def contains_all(self, *items):
        # "?&": all of the given top-level keys exist.
        return Expression(
            self,
            JSONB_CONTAINS_ALL_KEYS,
            AsIs(list(items), False))

    def has_key(self, key):
        # "?": the given top-level key exists.
        return Expression(self, JSONB_CONTAINS_KEY, Value(key, False))

    def remove(self, *items):
        # "-": remove the given top-level keys.
        value = Cast(AsIs(list(items), False), 'text[]')
        return Expression(self, JSONB_REMOVE, value)

    def length(self):
        return fn.jsonb_array_length(self)

    def extract(self, *path):
        # jsonb_extract_path requires text path elements.
        path = [str(p) if isinstance(p, int) else p for p in path]
        return fn.jsonb_extract_path(self, *path)
class TSVectorField(IndexedFieldMixin, TextField):
    """Full-text search vector column."""
    field_type = 'TSVECTOR'
    __hash__ = Field.__hash__

    def match(self, query, language=None, plain=False):
        # plainto_tsquery parses plain text; to_tsquery expects tsquery
        # operator syntax.
        func = fn.plainto_tsquery if plain else fn.to_tsquery
        if language is not None:
            rhs = func(language, query)
        else:
            rhs = func(query)
        return Expression(self, TS_MATCH, rhs)
def Match(field, query, language=None):
    """Full-text match ("@@") over an arbitrary field, converting the
    field to a tsvector and the query to a tsquery on the fly."""
    if language is not None:
        vector = fn.to_tsvector(language, field)
        ts_query = fn.to_tsquery(language, query)
    else:
        vector = fn.to_tsvector(field)
        ts_query = fn.to_tsquery(query)
    return Expression(vector, TS_MATCH, ts_query)
class IntervalField(Field):
    # Postgres INTERVAL (duration) column.
    field_type = 'INTERVAL'
class FetchManyCursor(object):
    """Wrap a DB cursor, pulling rows in ``fetchmany`` batches and
    exposing a simple ``fetchone``-style interface.

    The underlying cursor is closed once the result-set is drained (or
    the row generator is finalized).
    """
    __slots__ = ('cursor', 'array_size', 'exhausted', 'iterable')

    def __init__(self, cursor, array_size=None):
        self.cursor = cursor
        # Fall back to the cursor's own batch size when not specified.
        if not array_size:
            array_size = cursor.itersize
        self.array_size = array_size
        self.exhausted = False
        self.iterable = self.row_gen()

    @property
    def description(self):
        return self.cursor.description

    def close(self):
        # Close the underlying cursor at most once.
        cursor = self.cursor
        if cursor is not None and not cursor.closed:
            cursor.close()

    def row_gen(self):
        # Yield rows batch-by-batch until fetchmany() comes back empty.
        try:
            batch = self.cursor.fetchmany(self.array_size)
            while batch:
                for row in batch:
                    yield row
                batch = self.cursor.fetchmany(self.array_size)
        finally:
            self.close()

    def fetchone(self):
        if self.exhausted:
            return None
        try:
            return next(self.iterable)
        except StopIteration:
            # Remember exhaustion so later calls skip the generator.
            self.exhausted = True
            return None
class ServerSideQuery(Node):
    """Wrap a SELECT query so it executes with a named (server-side)
    cursor, streaming rows in batches instead of loading them all."""

    def __init__(self, query, array_size=None):
        self.query = query
        self.array_size = array_size
        # Lazily populated on first iteration.
        self._cursor_wrapper = None

    def __sql__(self, ctx):
        # Delegate SQL generation to the wrapped query.
        return self.query.__sql__(ctx)

    def __iter__(self):
        if self._cursor_wrapper is None:
            self._execute(self.query._database)
        return iter(self._cursor_wrapper.iterator())

    def close(self):
        # Close the server-side cursor, if one was opened; returns True
        # when something was actually closed.
        if self._cursor_wrapper is not None:
            self._cursor_wrapper.cursor.close()
            self._cursor_wrapper = None
            return True
        return False

    def iterator(self):
        if self._cursor_wrapper is None:
            self._execute(self.query._database)
        return self._cursor_wrapper.iterator()

    def _execute(self, database):
        # Execute with a named cursor and wrap it using the query's own
        # cursor-wrapper type (idempotent).
        if self._cursor_wrapper is None:
            cursor = database.execute(self.query, named_cursor=True,
                                      array_size=self.array_size)
            self._cursor_wrapper = self.query._get_cursor_wrapper(cursor)
        return self._cursor_wrapper
def ServerSide(query, array_size=None):
    """Generator yielding the rows of ``query`` via a server-side
    (named) cursor; execution is deferred until first iteration."""
    wrapped = ServerSideQuery(query, array_size=array_size)
    for row in wrapped:
        yield row
class _empty_object(object):
__slots__ = ()
def __nonzero__(self):
return False
__bool__ = __nonzero__
class Psycopg2ExtAdapter(Psycopg2Adapter):
    """psycopg2 adapter hooks: hstore registration and named cursors."""

    def register_hstore(self, conn):
        # Module-level fallback makes this a no-op when psycopg2.extras
        # is unavailable.
        register_hstore(conn)

    def server_side_cursor(self, conn):
        # psycopg2 does not allow us to use these in autocommit, even if we ARE
        # inside a transaction - so specify withhold (not desirable!).
        return conn.cursor(name=str(uuid.uuid1()), withhold=True)
class Psycopg3ExtAdapter(Psycopg3Adapter):
    """psycopg (3.x) adapter hooks: hstore registration and named cursors."""

    def register_hstore(self, conn):
        # Fetch the hstore type info for this connection, then register.
        info = TypeInfo.fetch(conn, 'hstore')
        register_hstore_pg3(info, conn)

    def server_side_cursor(self, conn):
        return conn.cursor(name=str(uuid.uuid1()))
class PostgresqlExtDatabase(PostgresqlDatabase):
    """PostgresqlDatabase with optional hstore registration and
    server-side (named) cursor support."""
    psycopg2_adapter = Psycopg2ExtAdapter
    psycopg3_adapter = Psycopg3ExtAdapter

    def __init__(self, *args, **kwargs):
        # register_hstore: register the hstore type on each new connection.
        # server_side_cursors: use a named cursor for every SELECT.
        self._register_hstore = kwargs.pop('register_hstore', False)
        self._server_side_cursors = kwargs.pop('server_side_cursors', False)
        super(PostgresqlExtDatabase, self).__init__(*args, **kwargs)

    def _connect(self):
        conn = super(PostgresqlExtDatabase, self)._connect()
        if self._register_hstore:
            self._adapter.register_hstore(conn)
        return conn

    def cursor(self, named_cursor=None):
        # Like the base cursor(), but can hand back a named (server-side)
        # cursor from the adapter.
        if self.is_closed():
            if self.autoconnect:
                self.connect()
            else:
                raise InterfaceError('Error, database connection not opened.')
        if named_cursor:
            return self._adapter.server_side_cursor(self._state.conn)
        return self._state.conn.cursor()

    def execute(self, query, named_cursor=False, array_size=None,
                **context_options):
        ctx = self.get_sql_context(**context_options)
        sql, params = ctx.sql(query).query()
        # Force a named cursor for SELECT statements when the database
        # was configured with server_side_cursors=True.
        named_cursor = named_cursor or (self._server_side_cursors and
                                        sql[:6].lower() == 'select')
        cursor = self.execute_sql(sql, params, named_cursor=named_cursor)
        if named_cursor:
            # Wrap so rows stream in fetchmany() batches.
            cursor = FetchManyCursor(cursor, array_size)
        return cursor

    def execute_sql(self, sql, params=None, named_cursor=None):
        logger.debug((sql, params))
        # __exception_wrapper__ translates driver errors to peewee's.
        with __exception_wrapper__:
            cursor = self.cursor(named_cursor=named_cursor)
            cursor.execute(sql, params or ())
        return cursor
class PooledPostgresqlExtDatabase(_PooledPostgresqlDatabase,
                                  PostgresqlExtDatabase):
    # Connection-pooled variant of PostgresqlExtDatabase.
    pass
class Psycopg3Database(PostgresqlExtDatabase):
    """PostgresqlExtDatabase that prefers the psycopg (3.x) driver."""

    def __init__(self, *args, **kwargs):
        kwargs['prefer_psycopg3'] = True
        super(Psycopg3Database, self).__init__(*args, **kwargs)
class PooledPsycopg3Database(_PooledPostgresqlDatabase, Psycopg3Database):
    # Connection-pooled variant of Psycopg3Database.
    pass