Files
sqlalchemy/test/orm/collection.py
T
Jason Kirtland 8fa48edbf9 - Removed 2.3 set emulations/enhancements.
(sets.Set-based collections & DB-API returns still work.)
2008-07-15 19:23:52 +00:00

1800 lines
53 KiB
Python

import testenv; testenv.configure_for_tests()
import sys
from operator import and_
import sqlalchemy.orm.collections as collections
from sqlalchemy.orm.collections import collection
from testlib import sa, testing
from testlib.sa import Table, Column, Integer, String, ForeignKey
from testlib.sa import util, exc as sa_exc
from testlib.sa.orm import create_session, mapper, relation, \
attributes
from orm import _base
class Canary(sa.orm.interfaces.AttributeExtension):
def __init__(self):
self.data = set()
self.added = set()
self.removed = set()
def append(self, obj, value, initiator):
assert value not in self.added
self.data.add(value)
self.added.add(value)
def remove(self, obj, value, initiator):
assert value not in self.removed
self.data.remove(value)
self.removed.add(value)
def set(self, obj, value, oldvalue, initiator):
if oldvalue is not None:
self.remove(obj, oldvalue, None)
self.append(obj, value, None)
class CollectionsTest(_base.ORMTest):
class Entity(object):
def __init__(self, a=None, b=None, c=None):
self.a = a
self.b = b
self.c = c
def __repr__(self):
return str((id(self), self.a, self.b, self.c))
def setUpAll(self):
attributes.register_class(self.Entity)
def tearDownAll(self):
attributes.unregister_class(self.Entity)
_base.ORMTest.tearDownAll(self)
_entity_id = 1
@classmethod
def entity_maker(cls):
cls._entity_id += 1
return cls.Entity(cls._entity_id)
@classmethod
def dictable_entity(cls, a=None, b=None, c=None):
id = cls._entity_id = (cls._entity_id + 1)
return cls.Entity(a or str(id), b or 'value %s' % id, c)
def _test_adapter(self, typecallable, creator=None, to_set=None):
if creator is None:
creator = self.entity_maker
class Foo(object):
pass
canary = Canary()
attributes.register_class(Foo)
attributes.register_attribute(Foo, 'attr', True, extension=canary,
typecallable=typecallable, useobject=True)
obj = Foo()
adapter = collections.collection_adapter(obj.attr)
direct = obj.attr
if to_set is None:
to_set = lambda col: set(col)
def assert_eq():
self.assert_(to_set(direct) == canary.data)
self.assert_(set(adapter) == canary.data)
assert_ne = lambda: self.assert_(to_set(direct) != canary.data)
e1, e2 = creator(), creator()
adapter.append_with_event(e1)
assert_eq()
adapter.append_without_event(e2)
assert_ne()
canary.data.add(e2)
assert_eq()
adapter.remove_without_event(e2)
assert_ne()
canary.data.remove(e2)
assert_eq()
adapter.remove_with_event(e1)
assert_eq()
def _test_list(self, typecallable, creator=None):
if creator is None:
creator = self.entity_maker
class Foo(object):
pass
canary = Canary()
attributes.register_class(Foo)
attributes.register_attribute(Foo, 'attr', True, extension=canary,
typecallable=typecallable, useobject=True)
obj = Foo()
adapter = collections.collection_adapter(obj.attr)
direct = obj.attr
control = list()
def assert_eq():
self.assert_(set(direct) == canary.data)
self.assert_(set(adapter) == canary.data)
self.assert_(direct == control)
# assume append() is available for list tests
e = creator()
direct.append(e)
control.append(e)
assert_eq()
if hasattr(direct, 'pop'):
direct.pop()
control.pop()
assert_eq()
if hasattr(direct, '__setitem__'):
e = creator()
direct.append(e)
control.append(e)
e = creator()
direct[0] = e
control[0] = e
assert_eq()
if reduce(and_, [hasattr(direct, a) for a in
('__delitem__', 'insert', '__len__')], True):
values = [creator(), creator(), creator(), creator()]
direct[slice(0,1)] = values
control[slice(0,1)] = values
assert_eq()
values = [creator(), creator()]
direct[slice(0,-1,2)] = values
control[slice(0,-1,2)] = values
assert_eq()
values = [creator()]
direct[slice(0,-1)] = values
control[slice(0,-1)] = values
assert_eq()
if hasattr(direct, '__delitem__'):
e = creator()
direct.append(e)
control.append(e)
del direct[-1]
del control[-1]
assert_eq()
if hasattr(direct, '__getslice__'):
for e in [creator(), creator(), creator(), creator()]:
direct.append(e)
control.append(e)
del direct[:-3]
del control[:-3]
assert_eq()
del direct[0:1]
del control[0:1]
assert_eq()
del direct[::2]
del control[::2]
assert_eq()
if hasattr(direct, 'remove'):
e = creator()
direct.append(e)
control.append(e)
direct.remove(e)
control.remove(e)
assert_eq()
if hasattr(direct, '__setslice__'):
values = [creator(), creator()]
direct[0:1] = values
control[0:1] = values
assert_eq()
values = [creator()]
direct[0:] = values
control[0:] = values
assert_eq()
values = [creator()]
direct[:1] = values
control[:1] = values
assert_eq()
values = [creator()]
direct[-1::2] = values
control[-1::2] = values
assert_eq()
values = [creator()] * len(direct[1::2])
direct[1::2] = values
control[1::2] = values
assert_eq()
if hasattr(direct, '__delslice__'):
for i in range(1, 4):
e = creator()
direct.append(e)
control.append(e)
del direct[-1:]
del control[-1:]
assert_eq()
del direct[1:2]
del control[1:2]
assert_eq()
del direct[:]
del control[:]
assert_eq()
if hasattr(direct, 'extend'):
values = [creator(), creator(), creator()]
direct.extend(values)
control.extend(values)
assert_eq()
if hasattr(direct, '__iadd__'):
values = [creator(), creator(), creator()]
direct += values
control += values
assert_eq()
direct += []
control += []
assert_eq()
values = [creator(), creator()]
obj.attr += values
control += values
assert_eq()
if hasattr(direct, '__imul__'):
direct *= 2
control *= 2
assert_eq()
obj.attr *= 2
control *= 2
assert_eq()
def _test_list_bulk(self, typecallable, creator=None):
if creator is None:
creator = self.entity_maker
class Foo(object):
pass
canary = Canary()
attributes.register_class(Foo)
attributes.register_attribute(Foo, 'attr', True, extension=canary,
typecallable=typecallable, useobject=True)
obj = Foo()
direct = obj.attr
e1 = creator()
obj.attr.append(e1)
like_me = typecallable()
e2 = creator()
like_me.append(e2)
self.assert_(obj.attr is direct)
obj.attr = like_me
self.assert_(obj.attr is not direct)
self.assert_(obj.attr is not like_me)
self.assert_(set(obj.attr) == set([e2]))
self.assert_(e1 in canary.removed)
self.assert_(e2 in canary.added)
e3 = creator()
real_list = [e3]
obj.attr = real_list
self.assert_(obj.attr is not real_list)
self.assert_(set(obj.attr) == set([e3]))
self.assert_(e2 in canary.removed)
self.assert_(e3 in canary.added)
e4 = creator()
try:
obj.attr = set([e4])
self.assert_(False)
except TypeError:
self.assert_(e4 not in canary.data)
self.assert_(e3 in canary.data)
e5 = creator()
e6 = creator()
e7 = creator()
obj.attr = [e5, e6, e7]
self.assert_(e5 in canary.added)
self.assert_(e6 in canary.added)
self.assert_(e7 in canary.added)
obj.attr = [e6, e7]
self.assert_(e5 in canary.removed)
self.assert_(e6 in canary.added)
self.assert_(e7 in canary.added)
self.assert_(e6 not in canary.removed)
self.assert_(e7 not in canary.removed)
def test_list(self):
self._test_adapter(list)
self._test_list(list)
self._test_list_bulk(list)
def test_list_subclass(self):
class MyList(list):
pass
self._test_adapter(MyList)
self._test_list(MyList)
self._test_list_bulk(MyList)
self.assert_(getattr(MyList, '_sa_instrumented') == id(MyList))
def test_list_duck(self):
class ListLike(object):
def __init__(self):
self.data = list()
def append(self, item):
self.data.append(item)
def remove(self, item):
self.data.remove(item)
def insert(self, index, item):
self.data.insert(index, item)
def pop(self, index=-1):
return self.data.pop(index)
def extend(self):
assert False
def __iter__(self):
return iter(self.data)
def __eq__(self, other):
return self.data == other
def __repr__(self):
return 'ListLike(%s)' % repr(self.data)
self._test_adapter(ListLike)
self._test_list(ListLike)
self._test_list_bulk(ListLike)
self.assert_(getattr(ListLike, '_sa_instrumented') == id(ListLike))
def test_list_emulates(self):
class ListIsh(object):
__emulates__ = list
def __init__(self):
self.data = list()
def append(self, item):
self.data.append(item)
def remove(self, item):
self.data.remove(item)
def insert(self, index, item):
self.data.insert(index, item)
def pop(self, index=-1):
return self.data.pop(index)
def extend(self):
assert False
def __iter__(self):
return iter(self.data)
def __eq__(self, other):
return self.data == other
def __repr__(self):
return 'ListIsh(%s)' % repr(self.data)
self._test_adapter(ListIsh)
self._test_list(ListIsh)
self._test_list_bulk(ListIsh)
self.assert_(getattr(ListIsh, '_sa_instrumented') == id(ListIsh))
def _test_set(self, typecallable, creator=None):
if creator is None:
creator = self.entity_maker
class Foo(object):
pass
canary = Canary()
attributes.register_class(Foo)
attributes.register_attribute(Foo, 'attr', True, extension=canary,
typecallable=typecallable, useobject=True)
obj = Foo()
adapter = collections.collection_adapter(obj.attr)
direct = obj.attr
control = set()
def assert_eq():
self.assert_(set(direct) == canary.data)
self.assert_(set(adapter) == canary.data)
self.assert_(direct == control)
def addall(*values):
for item in values:
direct.add(item)
control.add(item)
assert_eq()
def zap():
for item in list(direct):
direct.remove(item)
control.clear()
addall(creator())
e = creator()
addall(e)
addall(e)
if hasattr(direct, 'pop'):
direct.pop()
control.pop()
assert_eq()
if hasattr(direct, 'remove'):
e = creator()
addall(e)
direct.remove(e)
control.remove(e)
assert_eq()
e = creator()
try:
direct.remove(e)
except KeyError:
assert_eq()
self.assert_(e not in canary.removed)
else:
self.assert_(False)
if hasattr(direct, 'discard'):
e = creator()
addall(e)
direct.discard(e)
control.discard(e)
assert_eq()
e = creator()
direct.discard(e)
self.assert_(e not in canary.removed)
assert_eq()
if hasattr(direct, 'update'):
zap()
e = creator()
addall(e)
values = set([e, creator(), creator()])
direct.update(values)
control.update(values)
assert_eq()
if hasattr(direct, '__ior__'):
zap()
e = creator()
addall(e)
values = set([e, creator(), creator()])
direct |= values
control |= values
assert_eq()
# cover self-assignment short-circuit
values = set([e, creator(), creator()])
obj.attr |= values
control |= values
assert_eq()
values = frozenset([e, creator()])
obj.attr |= values
control |= values
assert_eq()
try:
direct |= [e, creator()]
assert False
except TypeError:
assert True
if hasattr(direct, 'clear'):
addall(creator(), creator())
direct.clear()
control.clear()
assert_eq()
if hasattr(direct, 'difference_update'):
zap()
e = creator()
addall(creator(), creator())
values = set([creator()])
direct.difference_update(values)
control.difference_update(values)
assert_eq()
values.update(set([e, creator()]))
direct.difference_update(values)
control.difference_update(values)
assert_eq()
if hasattr(direct, '__isub__'):
zap()
e = creator()
addall(creator(), creator())
values = set([creator()])
direct -= values
control -= values
assert_eq()
values.update(set([e, creator()]))
direct -= values
control -= values
assert_eq()
values = set([creator()])
obj.attr -= values
control -= values
assert_eq()
values = frozenset([creator()])
obj.attr -= values
control -= values
assert_eq()
try:
direct -= [e, creator()]
assert False
except TypeError:
assert True
if hasattr(direct, 'intersection_update'):
zap()
e = creator()
addall(e, creator(), creator())
values = set(control)
direct.intersection_update(values)
control.intersection_update(values)
assert_eq()
values.update(set([e, creator()]))
direct.intersection_update(values)
control.intersection_update(values)
assert_eq()
if hasattr(direct, '__iand__'):
zap()
e = creator()
addall(e, creator(), creator())
values = set(control)
direct &= values
control &= values
assert_eq()
values.update(set([e, creator()]))
direct &= values
control &= values
assert_eq()
values.update(set([creator()]))
obj.attr &= values
control &= values
assert_eq()
try:
direct &= [e, creator()]
assert False
except TypeError:
assert True
if hasattr(direct, 'symmetric_difference_update'):
zap()
e = creator()
addall(e, creator(), creator())
values = set([e, creator()])
direct.symmetric_difference_update(values)
control.symmetric_difference_update(values)
assert_eq()
e = creator()
addall(e)
values = set([e])
direct.symmetric_difference_update(values)
control.symmetric_difference_update(values)
assert_eq()
values = set()
direct.symmetric_difference_update(values)
control.symmetric_difference_update(values)
assert_eq()
if hasattr(direct, '__ixor__'):
zap()
e = creator()
addall(e, creator(), creator())
values = set([e, creator()])
direct ^= values
control ^= values
assert_eq()
e = creator()
addall(e)
values = set([e])
direct ^= values
control ^= values
assert_eq()
values = set()
direct ^= values
control ^= values
assert_eq()
values = set([creator()])
obj.attr ^= values
control ^= values
assert_eq()
try:
direct ^= [e, creator()]
assert False
except TypeError:
assert True
def _test_set_bulk(self, typecallable, creator=None):
if creator is None:
creator = self.entity_maker
class Foo(object):
pass
canary = Canary()
attributes.register_class(Foo)
attributes.register_attribute(Foo, 'attr', True, extension=canary,
typecallable=typecallable, useobject=True)
obj = Foo()
direct = obj.attr
e1 = creator()
obj.attr.add(e1)
like_me = typecallable()
e2 = creator()
like_me.add(e2)
self.assert_(obj.attr is direct)
obj.attr = like_me
self.assert_(obj.attr is not direct)
self.assert_(obj.attr is not like_me)
self.assert_(obj.attr == set([e2]))
self.assert_(e1 in canary.removed)
self.assert_(e2 in canary.added)
e3 = creator()
real_set = set([e3])
obj.attr = real_set
self.assert_(obj.attr is not real_set)
self.assert_(obj.attr == set([e3]))
self.assert_(e2 in canary.removed)
self.assert_(e3 in canary.added)
e4 = creator()
try:
obj.attr = [e4]
self.assert_(False)
except TypeError:
self.assert_(e4 not in canary.data)
self.assert_(e3 in canary.data)
def test_set(self):
self._test_adapter(set)
self._test_set(set)
self._test_set_bulk(set)
def test_set_subclass(self):
class MySet(set):
pass
self._test_adapter(MySet)
self._test_set(MySet)
self._test_set_bulk(MySet)
self.assert_(getattr(MySet, '_sa_instrumented') == id(MySet))
def test_set_duck(self):
class SetLike(object):
def __init__(self):
self.data = set()
def add(self, item):
self.data.add(item)
def remove(self, item):
self.data.remove(item)
def discard(self, item):
self.data.discard(item)
def pop(self):
return self.data.pop()
def update(self, other):
self.data.update(other)
def __iter__(self):
return iter(self.data)
def __eq__(self, other):
return self.data == other
self._test_adapter(SetLike)
self._test_set(SetLike)
self._test_set_bulk(SetLike)
self.assert_(getattr(SetLike, '_sa_instrumented') == id(SetLike))
def test_set_emulates(self):
class SetIsh(object):
__emulates__ = set
def __init__(self):
self.data = set()
def add(self, item):
self.data.add(item)
def remove(self, item):
self.data.remove(item)
def discard(self, item):
self.data.discard(item)
def pop(self):
return self.data.pop()
def update(self, other):
self.data.update(other)
def __iter__(self):
return iter(self.data)
def __eq__(self, other):
return self.data == other
self._test_adapter(SetIsh)
self._test_set(SetIsh)
self._test_set_bulk(SetIsh)
self.assert_(getattr(SetIsh, '_sa_instrumented') == id(SetIsh))
def _test_dict(self, typecallable, creator=None):
if creator is None:
creator = self.dictable_entity
class Foo(object):
pass
canary = Canary()
attributes.register_class(Foo)
attributes.register_attribute(Foo, 'attr', True, extension=canary,
typecallable=typecallable, useobject=True)
obj = Foo()
adapter = collections.collection_adapter(obj.attr)
direct = obj.attr
control = dict()
def assert_eq():
self.assert_(set(direct.values()) == canary.data)
self.assert_(set(adapter) == canary.data)
self.assert_(direct == control)
def addall(*values):
for item in values:
direct.set(item)
control[item.a] = item
assert_eq()
def zap():
for item in list(adapter):
direct.remove(item)
control.clear()
# assume an 'set' method is available for tests
addall(creator())
if hasattr(direct, '__setitem__'):
e = creator()
direct[e.a] = e
control[e.a] = e
assert_eq()
e = creator(e.a, e.b)
direct[e.a] = e
control[e.a] = e
assert_eq()
if hasattr(direct, '__delitem__'):
e = creator()
addall(e)
del direct[e.a]
del control[e.a]
assert_eq()
e = creator()
try:
del direct[e.a]
except KeyError:
self.assert_(e not in canary.removed)
if hasattr(direct, 'clear'):
addall(creator(), creator(), creator())
direct.clear()
control.clear()
assert_eq()
direct.clear()
control.clear()
assert_eq()
if hasattr(direct, 'pop'):
e = creator()
addall(e)
direct.pop(e.a)
control.pop(e.a)
assert_eq()
e = creator()
try:
direct.pop(e.a)
except KeyError:
self.assert_(e not in canary.removed)
if hasattr(direct, 'popitem'):
zap()
e = creator()
addall(e)
direct.popitem()
control.popitem()
assert_eq()
if hasattr(direct, 'setdefault'):
e = creator()
val_a = direct.setdefault(e.a, e)
val_b = control.setdefault(e.a, e)
assert_eq()
self.assert_(val_a is val_b)
val_a = direct.setdefault(e.a, e)
val_b = control.setdefault(e.a, e)
assert_eq()
self.assert_(val_a is val_b)
if hasattr(direct, 'update'):
e = creator()
d = dict([(ee.a, ee) for ee in [e, creator(), creator()]])
addall(e, creator())
direct.update(d)
control.update(d)
assert_eq()
if sys.version_info >= (2, 4):
kw = dict([(ee.a, ee) for ee in [e, creator()]])
direct.update(**kw)
control.update(**kw)
assert_eq()
def _test_dict_bulk(self, typecallable, creator=None):
if creator is None:
creator = self.dictable_entity
class Foo(object):
pass
canary = Canary()
attributes.register_class(Foo)
attributes.register_attribute(Foo, 'attr', True, extension=canary,
typecallable=typecallable, useobject=True)
obj = Foo()
direct = obj.attr
e1 = creator()
collections.collection_adapter(direct).append_with_event(e1)
like_me = typecallable()
e2 = creator()
like_me.set(e2)
self.assert_(obj.attr is direct)
obj.attr = like_me
self.assert_(obj.attr is not direct)
self.assert_(obj.attr is not like_me)
self.assert_(set(collections.collection_adapter(obj.attr)) == set([e2]))
self.assert_(e1 in canary.removed)
self.assert_(e2 in canary.added)
# key validity on bulk assignment is a basic feature of MappedCollection
# but is not present in basic, @converter-less dict collections.
e3 = creator()
if isinstance(obj.attr, collections.MappedCollection):
real_dict = dict(badkey=e3)
try:
obj.attr = real_dict
self.assert_(False)
except TypeError:
pass
self.assert_(obj.attr is not real_dict)
self.assert_('badkey' not in obj.attr)
self.assertEquals(set(collections.collection_adapter(obj.attr)),
set([e2]))
self.assert_(e3 not in canary.added)
else:
real_dict = dict(keyignored1=e3)
obj.attr = real_dict
self.assert_(obj.attr is not real_dict)
self.assert_('keyignored1' not in obj.attr)
self.assertEquals(set(collections.collection_adapter(obj.attr)),
set([e3]))
self.assert_(e2 in canary.removed)
self.assert_(e3 in canary.added)
obj.attr = typecallable()
self.assertEquals(list(collections.collection_adapter(obj.attr)), [])
e4 = creator()
try:
obj.attr = [e4]
self.assert_(False)
except TypeError:
self.assert_(e4 not in canary.data)
def test_dict(self):
try:
self._test_adapter(dict, self.dictable_entity,
to_set=lambda c: set(c.values()))
self.assert_(False)
except sa_exc.ArgumentError, e:
self.assert_(e.args[0] == 'Type InstrumentedDict must elect an appender method to be a collection class')
try:
self._test_dict(dict)
self.assert_(False)
except sa_exc.ArgumentError, e:
self.assert_(e.args[0] == 'Type InstrumentedDict must elect an appender method to be a collection class')
def test_dict_subclass(self):
class MyDict(dict):
@collection.appender
@collection.internally_instrumented
def set(self, item, _sa_initiator=None):
self.__setitem__(item.a, item, _sa_initiator=_sa_initiator)
@collection.remover
@collection.internally_instrumented
def _remove(self, item, _sa_initiator=None):
self.__delitem__(item.a, _sa_initiator=_sa_initiator)
self._test_adapter(MyDict, self.dictable_entity,
to_set=lambda c: set(c.values()))
self._test_dict(MyDict)
self._test_dict_bulk(MyDict)
self.assert_(getattr(MyDict, '_sa_instrumented') == id(MyDict))
def test_dict_subclass2(self):
class MyEasyDict(collections.MappedCollection):
def __init__(self):
super(MyEasyDict, self).__init__(lambda e: e.a)
self._test_adapter(MyEasyDict, self.dictable_entity,
to_set=lambda c: set(c.values()))
self._test_dict(MyEasyDict)
self._test_dict_bulk(MyEasyDict)
self.assert_(getattr(MyEasyDict, '_sa_instrumented') == id(MyEasyDict))
def test_dict_subclass3(self):
class MyOrdered(util.OrderedDict, collections.MappedCollection):
def __init__(self):
collections.MappedCollection.__init__(self, lambda e: e.a)
util.OrderedDict.__init__(self)
self._test_adapter(MyOrdered, self.dictable_entity,
to_set=lambda c: set(c.values()))
self._test_dict(MyOrdered)
self._test_dict_bulk(MyOrdered)
self.assert_(getattr(MyOrdered, '_sa_instrumented') == id(MyOrdered))
def test_dict_duck(self):
class DictLike(object):
def __init__(self):
self.data = dict()
@collection.appender
@collection.replaces(1)
def set(self, item):
current = self.data.get(item.a, None)
self.data[item.a] = item
return current
@collection.remover
def _remove(self, item):
del self.data[item.a]
def __setitem__(self, key, value):
self.data[key] = value
def __getitem__(self, key):
return self.data[key]
def __delitem__(self, key):
del self.data[key]
def values(self):
return self.data.values()
def __contains__(self, key):
return key in self.data
@collection.iterator
def itervalues(self):
return self.data.itervalues()
def __eq__(self, other):
return self.data == other
def __repr__(self):
return 'DictLike(%s)' % repr(self.data)
self._test_adapter(DictLike, self.dictable_entity,
to_set=lambda c: set(c.itervalues()))
self._test_dict(DictLike)
self._test_dict_bulk(DictLike)
self.assert_(getattr(DictLike, '_sa_instrumented') == id(DictLike))
def test_dict_emulates(self):
class DictIsh(object):
__emulates__ = dict
def __init__(self):
self.data = dict()
@collection.appender
@collection.replaces(1)
def set(self, item):
current = self.data.get(item.a, None)
self.data[item.a] = item
return current
@collection.remover
def _remove(self, item):
del self.data[item.a]
def __setitem__(self, key, value):
self.data[key] = value
def __getitem__(self, key):
return self.data[key]
def __delitem__(self, key):
del self.data[key]
def values(self):
return self.data.values()
def __contains__(self, key):
return key in self.data
@collection.iterator
def itervalues(self):
return self.data.itervalues()
def __eq__(self, other):
return self.data == other
def __repr__(self):
return 'DictIsh(%s)' % repr(self.data)
self._test_adapter(DictIsh, self.dictable_entity,
to_set=lambda c: set(c.itervalues()))
self._test_dict(DictIsh)
self._test_dict_bulk(DictIsh)
self.assert_(getattr(DictIsh, '_sa_instrumented') == id(DictIsh))
def _test_object(self, typecallable, creator=None):
if creator is None:
creator = self.entity_maker
class Foo(object):
pass
canary = Canary()
attributes.register_class(Foo)
attributes.register_attribute(Foo, 'attr', True, extension=canary,
typecallable=typecallable, useobject=True)
obj = Foo()
adapter = collections.collection_adapter(obj.attr)
direct = obj.attr
control = set()
def assert_eq():
self.assert_(set(direct) == canary.data)
self.assert_(set(adapter) == canary.data)
self.assert_(direct == control)
# There is no API for object collections. We'll make one up
# for the purposes of the test.
e = creator()
direct.push(e)
control.add(e)
assert_eq()
direct.zark(e)
control.remove(e)
assert_eq()
e = creator()
direct.maybe_zark(e)
control.discard(e)
assert_eq()
e = creator()
direct.push(e)
control.add(e)
assert_eq()
e = creator()
direct.maybe_zark(e)
control.discard(e)
assert_eq()
def test_object_duck(self):
class MyCollection(object):
def __init__(self):
self.data = set()
@collection.appender
def push(self, item):
self.data.add(item)
@collection.remover
def zark(self, item):
self.data.remove(item)
@collection.removes_return()
def maybe_zark(self, item):
if item in self.data:
self.data.remove(item)
return item
@collection.iterator
def __iter__(self):
return iter(self.data)
def __eq__(self, other):
return self.data == other
self._test_adapter(MyCollection)
self._test_object(MyCollection)
self.assert_(getattr(MyCollection, '_sa_instrumented') ==
id(MyCollection))
def test_object_emulates(self):
class MyCollection2(object):
__emulates__ = None
def __init__(self):
self.data = set()
# looks like a list
def append(self, item):
assert False
@collection.appender
def push(self, item):
self.data.add(item)
@collection.remover
def zark(self, item):
self.data.remove(item)
@collection.removes_return()
def maybe_zark(self, item):
if item in self.data:
self.data.remove(item)
return item
@collection.iterator
def __iter__(self):
return iter(self.data)
def __eq__(self, other):
return self.data == other
self._test_adapter(MyCollection2)
self._test_object(MyCollection2)
self.assert_(getattr(MyCollection2, '_sa_instrumented') ==
id(MyCollection2))
def test_recipes(self):
class Custom(object):
def __init__(self):
self.data = []
@collection.appender
@collection.adds('entity')
def put(self, entity):
self.data.append(entity)
@collection.remover
@collection.removes(1)
def remove(self, entity):
self.data.remove(entity)
@collection.adds(1)
def push(self, *args):
self.data.append(args[0])
@collection.removes('entity')
def yank(self, entity, arg):
self.data.remove(entity)
@collection.replaces(2)
def replace(self, arg, entity, **kw):
self.data.insert(0, entity)
return self.data.pop()
@collection.removes_return()
def pop(self, key):
return self.data.pop()
@collection.iterator
def __iter__(self):
return iter(self.data)
class Foo(object):
pass
canary = Canary()
attributes.register_class(Foo)
attributes.register_attribute(Foo, 'attr', True, extension=canary,
typecallable=Custom, useobject=True)
obj = Foo()
adapter = collections.collection_adapter(obj.attr)
direct = obj.attr
control = list()
def assert_eq():
self.assert_(set(direct) == canary.data)
self.assert_(set(adapter) == canary.data)
self.assert_(list(direct) == control)
creator = self.entity_maker
e1 = creator()
direct.put(e1)
control.append(e1)
assert_eq()
e2 = creator()
direct.put(entity=e2)
control.append(e2)
assert_eq()
direct.remove(e2)
control.remove(e2)
assert_eq()
direct.remove(entity=e1)
control.remove(e1)
assert_eq()
e3 = creator()
direct.push(e3)
control.append(e3)
assert_eq()
direct.yank(e3, 'blah')
control.remove(e3)
assert_eq()
e4, e5, e6, e7 = creator(), creator(), creator(), creator()
direct.put(e4)
direct.put(e5)
control.append(e4)
control.append(e5)
dr1 = direct.replace('foo', e6, bar='baz')
control.insert(0, e6)
cr1 = control.pop()
assert_eq()
self.assert_(dr1 is cr1)
dr2 = direct.replace(arg=1, entity=e7)
control.insert(0, e7)
cr2 = control.pop()
assert_eq()
self.assert_(dr2 is cr2)
dr3 = direct.pop('blah')
cr3 = control.pop()
assert_eq()
self.assert_(dr3 is cr3)
def test_lifecycle(self):
class Foo(object):
pass
canary = Canary()
creator = self.entity_maker
attributes.register_class(Foo)
attributes.register_attribute(Foo, 'attr', True, extension=canary, useobject=True)
obj = Foo()
col1 = obj.attr
e1 = creator()
obj.attr.append(e1)
e2 = creator()
bulk1 = [e2]
# empty & sever col1 from obj
obj.attr = bulk1
self.assert_(len(col1) == 0)
self.assert_(len(canary.data) == 1)
self.assert_(obj.attr is not col1)
self.assert_(obj.attr is not bulk1)
self.assert_(obj.attr == bulk1)
e3 = creator()
col1.append(e3)
self.assert_(e3 not in canary.data)
self.assert_(collections.collection_adapter(col1) is None)
obj.attr[0] = e3
self.assert_(e3 in canary.data)
class DictHelpersTest(_base.MappedTest):
def define_tables(self, metadata):
Table('parents', metadata,
Column('id', Integer, primary_key=True),
Column('label', String(128)))
Table('children', metadata,
Column('id', Integer, primary_key=True),
Column('parent_id', Integer, ForeignKey('parents.id'),
nullable=False),
Column('a', String(128)),
Column('b', String(128)),
Column('c', String(128)))
def setup_classes(self):
class Parent(_base.BasicEntity):
def __init__(self, label=None):
self.label = label
class Child(_base.BasicEntity):
def __init__(self, a=None, b=None, c=None):
self.a = a
self.b = b
self.c = c
@testing.resolve_artifact_names
def _test_scalar_mapped(self, collection_class):
mapper(Child, children)
mapper(Parent, parents, properties={
'children': relation(Child, collection_class=collection_class,
cascade="all, delete-orphan")})
p = Parent()
p.children['foo'] = Child('foo', 'value')
p.children['bar'] = Child('bar', 'value')
session = create_session()
session.add(p)
session.flush()
pid = p.id
session.clear()
p = session.query(Parent).get(pid)
self.assert_(set(p.children.keys()) == set(['foo', 'bar']))
cid = p.children['foo'].id
collections.collection_adapter(p.children).append_with_event(
Child('foo', 'newvalue'))
session.flush()
session.clear()
p = session.query(Parent).get(pid)
self.assert_(set(p.children.keys()) == set(['foo', 'bar']))
self.assert_(p.children['foo'].id != cid)
self.assert_(len(list(collections.collection_adapter(p.children))) == 2)
session.flush()
session.clear()
p = session.query(Parent).get(pid)
self.assert_(len(list(collections.collection_adapter(p.children))) == 2)
collections.collection_adapter(p.children).remove_with_event(
p.children['foo'])
self.assert_(len(list(collections.collection_adapter(p.children))) == 1)
session.flush()
session.clear()
p = session.query(Parent).get(pid)
self.assert_(len(list(collections.collection_adapter(p.children))) == 1)
del p.children['bar']
self.assert_(len(list(collections.collection_adapter(p.children))) == 0)
session.flush()
session.clear()
p = session.query(Parent).get(pid)
self.assert_(len(list(collections.collection_adapter(p.children))) == 0)
@testing.resolve_artifact_names
def _test_composite_mapped(self, collection_class):
mapper(Child, children)
mapper(Parent, parents, properties={
'children': relation(Child, collection_class=collection_class,
cascade="all, delete-orphan")
})
p = Parent()
p.children[('foo', '1')] = Child('foo', '1', 'value 1')
p.children[('foo', '2')] = Child('foo', '2', 'value 2')
session = create_session()
session.add(p)
session.flush()
pid = p.id
session.clear()
p = session.query(Parent).get(pid)
self.assert_(set(p.children.keys()) == set([('foo', '1'), ('foo', '2')]))
cid = p.children[('foo', '1')].id
collections.collection_adapter(p.children).append_with_event(
Child('foo', '1', 'newvalue'))
session.flush()
session.clear()
p = session.query(Parent).get(pid)
self.assert_(set(p.children.keys()) == set([('foo', '1'), ('foo', '2')]))
self.assert_(p.children[('foo', '1')].id != cid)
self.assert_(len(list(collections.collection_adapter(p.children))) == 2)
def test_mapped_collection(self):
collection_class = collections.mapped_collection(lambda c: c.a)
self._test_scalar_mapped(collection_class)
def test_mapped_collection2(self):
collection_class = collections.mapped_collection(lambda c: (c.a, c.b))
self._test_composite_mapped(collection_class)
def test_attr_mapped_collection(self):
collection_class = collections.attribute_mapped_collection('a')
self._test_scalar_mapped(collection_class)
@testing.resolve_artifact_names
def test_column_mapped_collection(self):
collection_class = collections.column_mapped_collection(
children.c.a)
self._test_scalar_mapped(collection_class)
@testing.resolve_artifact_names
def test_column_mapped_collection2(self):
collection_class = collections.column_mapped_collection(
(children.c.a, children.c.b))
self._test_composite_mapped(collection_class)
def test_mixin(self):
class Ordered(util.OrderedDict, collections.MappedCollection):
def __init__(self):
collections.MappedCollection.__init__(self, lambda v: v.a)
util.OrderedDict.__init__(self)
collection_class = Ordered
self._test_scalar_mapped(collection_class)
def test_mixin2(self):
class Ordered2(util.OrderedDict, collections.MappedCollection):
def __init__(self, keyfunc):
collections.MappedCollection.__init__(self, keyfunc)
util.OrderedDict.__init__(self)
collection_class = lambda: Ordered2(lambda v: (v.a, v.b))
self._test_composite_mapped(collection_class)
# TODO: are these tests redundant vs. the above tests ?
# remove if so
class CustomCollectionsTest(_base.MappedTest):
def define_tables(self, metadata):
Table('sometable', metadata,
Column('col1',Integer, primary_key=True),
Column('data', String(30)))
Table('someothertable', metadata,
Column('col1', Integer, primary_key=True),
Column('scol1', Integer,
ForeignKey('sometable.col1')),
Column('data', String(20)))
@testing.resolve_artifact_names
def test_basic(self):
class MyList(list):
pass
class Foo(object):
pass
class Bar(object):
pass
mapper(Foo, sometable, properties={
'bars':relation(Bar, collection_class=MyList)
})
mapper(Bar, someothertable)
f = Foo()
assert isinstance(f.bars, MyList)
@testing.resolve_artifact_names
def test_lazyload(self):
"""test that a 'set' can be used as a collection and can lazyload."""
class Foo(object):
pass
class Bar(object):
pass
mapper(Foo, sometable, properties={
'bars':relation(Bar, collection_class=set)
})
mapper(Bar, someothertable)
f = Foo()
f.bars.add(Bar())
f.bars.add(Bar())
sess = create_session()
sess.add(f)
sess.flush()
sess.clear()
f = sess.query(Foo).get(f.col1)
assert len(list(f.bars)) == 2
f.bars.clear()
@testing.resolve_artifact_names
def test_dict(self):
"""test that a 'dict' can be used as a collection and can lazyload."""
class Foo(object):
pass
class Bar(object):
pass
class AppenderDict(dict):
@collection.appender
def set(self, item):
self[id(item)] = item
@collection.remover
def remove(self, item):
if id(item) in self:
del self[id(item)]
mapper(Foo, sometable, properties={
'bars':relation(Bar, collection_class=AppenderDict)
})
mapper(Bar, someothertable)
f = Foo()
f.bars.set(Bar())
f.bars.set(Bar())
sess = create_session()
sess.add(f)
sess.flush()
sess.clear()
f = sess.query(Foo).get(f.col1)
assert len(list(f.bars)) == 2
f.bars.clear()
@testing.resolve_artifact_names
def test_dict_wrapper(self):
"""test that the supplied 'dict' wrapper can be used as a collection and can lazyload."""
class Foo(object):
pass
class Bar(object):
def __init__(self, data): self.data = data
mapper(Foo, sometable, properties={
'bars':relation(Bar,
collection_class=collections.column_mapped_collection(
someothertable.c.data))
})
mapper(Bar, someothertable)
f = Foo()
col = collections.collection_adapter(f.bars)
col.append_with_event(Bar('a'))
col.append_with_event(Bar('b'))
sess = create_session()
sess.add(f)
sess.flush()
sess.clear()
f = sess.query(Foo).get(f.col1)
assert len(list(f.bars)) == 2
existing = set([id(b) for b in f.bars.values()])
col = collections.collection_adapter(f.bars)
col.append_with_event(Bar('b'))
f.bars['a'] = Bar('a')
sess.flush()
sess.clear()
f = sess.query(Foo).get(f.col1)
assert len(list(f.bars)) == 2
replaced = set([id(b) for b in f.bars.values()])
self.assert_(existing != replaced)
@testing.resolve_artifact_names
def test_list(self):
class Parent(object):
pass
class Child(object):
pass
mapper(Parent, sometable, properties={
'children':relation(Child, collection_class=list)
})
mapper(Child, someothertable)
control = list()
p = Parent()
o = Child()
control.append(o)
p.children.append(o)
assert control == p.children
assert control == list(p.children)
o = [Child(), Child(), Child(), Child()]
control.extend(o)
p.children.extend(o)
assert control == p.children
assert control == list(p.children)
assert control[0] == p.children[0]
assert control[-1] == p.children[-1]
assert control[1:3] == p.children[1:3]
del control[1]
del p.children[1]
assert control == p.children
assert control == list(p.children)
o = [Child()]
control[1:3] = o
p.children[1:3] = o
assert control == p.children
assert control == list(p.children)
o = [Child(), Child(), Child(), Child()]
control[1:3] = o
p.children[1:3] = o
assert control == p.children
assert control == list(p.children)
o = [Child(), Child(), Child(), Child()]
control[-1:-2] = o
p.children[-1:-2] = o
assert control == p.children
assert control == list(p.children)
o = [Child(), Child(), Child(), Child()]
control[4:] = o
p.children[4:] = o
assert control == p.children
assert control == list(p.children)
o = Child()
control.insert(0, o)
p.children.insert(0, o)
assert control == p.children
assert control == list(p.children)
o = Child()
control.insert(3, o)
p.children.insert(3, o)
assert control == p.children
assert control == list(p.children)
o = Child()
control.insert(999, o)
p.children.insert(999, o)
assert control == p.children
assert control == list(p.children)
del control[0:1]
del p.children[0:1]
assert control == p.children
assert control == list(p.children)
del control[1:1]
del p.children[1:1]
assert control == p.children
assert control == list(p.children)
del control[1:3]
del p.children[1:3]
assert control == p.children
assert control == list(p.children)
del control[7:]
del p.children[7:]
assert control == p.children
assert control == list(p.children)
assert control.pop() == p.children.pop()
assert control == p.children
assert control == list(p.children)
assert control.pop(0) == p.children.pop(0)
assert control == p.children
assert control == list(p.children)
assert control.pop(2) == p.children.pop(2)
assert control == p.children
assert control == list(p.children)
o = Child()
control.insert(2, o)
p.children.insert(2, o)
assert control == p.children
assert control == list(p.children)
control.remove(o)
p.children.remove(o)
assert control == p.children
assert control == list(p.children)
@testing.resolve_artifact_names
def test_custom(self):
class Parent(object):
pass
class Child(object):
pass
class MyCollection(object):
def __init__(self):
self.data = []
@collection.appender
def append(self, value):
self.data.append(value)
@collection.remover
def remove(self, value):
self.data.remove(value)
@collection.iterator
def __iter__(self):
return iter(self.data)
mapper(Parent, sometable, properties={
'children':relation(Child, collection_class=MyCollection)
})
mapper(Child, someothertable)
control = list()
p1 = Parent()
o = Child()
control.append(o)
p1.children.append(o)
assert control == list(p1.children)
o = Child()
control.append(o)
p1.children.append(o)
assert control == list(p1.children)
o = Child()
control.append(o)
p1.children.append(o)
assert control == list(p1.children)
sess = create_session()
sess.add(p1)
sess.flush()
sess.clear()
p2 = sess.query(Parent).get(p1.col1)
o = list(p2.children)
assert len(o) == 3
class InstrumentationTest(_base.ORMTest):
def test_uncooperative_descriptor_in_sweep(self):
class DoNotTouch(object):
def __get__(self, obj, owner):
raise AttributeError
class Touchy(list):
no_touch = DoNotTouch()
assert 'no_touch' in Touchy.__dict__
assert not hasattr(Touchy, 'no_touch')
assert 'no_touch' in dir(Touchy)
instrumented = collections._instrument_class(Touchy)
assert True
if __name__ == "__main__":
testenv.main()