Add sqlalchemy.ext.mutable.MutableSet

from https://bitbucket.org/zzzeek/sqlalchemy/issues/3297
This commit is contained in:
Jeong YunWon
2016-02-13 19:20:12 +09:00
parent 1b6422a603
commit f7354b43e4
2 changed files with 260 additions and 1 deletions
+65
View File
@@ -778,3 +778,68 @@ class MutableList(Mutable, list):
def __setstate__(self, state):
self[:] = state
class MutableSet(Mutable, set):
"""A set type that implements :class:`.Mutable`.
The :class:`.MutableSet` object implements a list that will
emit change events to the underlying mapping when the contents of
the set are altered, including when values are added or removed.
"""
def update(self, *arg):
set.update(self, *arg)
self.changed()
def intersection_update(self, *arg):
set.intersection_update(self, *arg)
self.changed()
def difference_update(self, *arg):
set.difference_update(self, *arg)
self.changed()
def symmetric_difference_update(self, *arg):
set.symmetric_difference_update(self, *arg)
self.changed()
def add(self, elem):
set.add(self, elem)
self.changed()
def remove(self, elem):
set.remove(self, elem)
self.changed()
def discard(self, elem):
set.discard(self, elem)
self.changed()
def pop(self, *arg):
result = set.pop(self, *arg)
self.changed()
return result
def clear(self):
set.clear(self)
self.changed()
@classmethod
def coerce(cls, index, value):
"""Convert plain set to instance of this class."""
if not isinstance(value, cls):
if isinstance(value, set):
return cls(value)
return Mutable.coerce(index, value)
else:
return value
def __getstate__(self):
return set(self)
def __setstate__(self, state):
self.update(state)
def __reduce_ex__(self, proto):
return (self.__class__, (list(self), ))
+195 -1
View File
@@ -8,7 +8,7 @@ from sqlalchemy.testing import eq_, assert_raises_message, assert_raises
from sqlalchemy.testing.util import picklers
from sqlalchemy.testing import fixtures
from sqlalchemy.ext.mutable import MutableComposite
from sqlalchemy.ext.mutable import MutableDict, MutableList
from sqlalchemy.ext.mutable import MutableDict, MutableList, MutableSet
class Foo(fixtures.BasicEntity):
@@ -461,6 +461,183 @@ class _MutableListTestBase(_MutableListTestFixture):
eq_(f1.data[0], 3)
class _MutableSetTestFixture(object):
@classmethod
def _type_fixture(cls):
return MutableSet
def teardown(self):
# clear out mapper events
Mapper.dispatch._clear()
ClassManager.dispatch._clear()
super(_MutableSetTestFixture, self).teardown()
class _MutableSetTestBase(_MutableSetTestFixture):
run_define_tables = 'each'
def setup_mappers(cls):
foo = cls.tables.foo
mapper(Foo, foo)
def test_coerce_none(self):
sess = Session()
f1 = Foo(data=None)
sess.add(f1)
sess.commit()
eq_(f1.data, None)
def test_coerce_raise(self):
assert_raises_message(
ValueError,
"Attribute 'data' does not accept objects of type",
Foo, data=[1, 2, 3]
)
def test_clear(self):
sess = Session()
f1 = Foo(data=set([1, 2]))
sess.add(f1)
sess.commit()
f1.data.clear()
sess.commit()
eq_(f1.data, set())
def test_pop(self):
sess = Session()
f1 = Foo(data=set([1]))
sess.add(f1)
sess.commit()
eq_(f1.data.pop(), 1)
sess.commit()
assert_raises(KeyError, f1.data.pop)
eq_(f1.data, set())
def test_add(self):
sess = Session()
f1 = Foo(data=set([1, 2]))
sess.add(f1)
sess.commit()
f1.data.add(5)
sess.commit()
eq_(f1.data, set([1, 2, 5]))
def test_update(self):
sess = Session()
f1 = Foo(data=set([1, 2]))
sess.add(f1)
sess.commit()
f1.data.update(set([2, 5]))
sess.commit()
eq_(f1.data, set([1, 2, 5]))
def test_intersection_update(self):
sess = Session()
f1 = Foo(data=set([1, 2]))
sess.add(f1)
sess.commit()
f1.data.intersection_update(set([2, 5]))
sess.commit()
eq_(f1.data, set([2]))
def test_difference_update(self):
sess = Session()
f1 = Foo(data=set([1, 2]))
sess.add(f1)
sess.commit()
f1.data.difference_update(set([2, 5]))
sess.commit()
eq_(f1.data, set([1]))
def test_symmetric_difference_update(self):
sess = Session()
f1 = Foo(data=set([1, 2]))
sess.add(f1)
sess.commit()
f1.data.symmetric_difference_update(set([2, 5]))
sess.commit()
eq_(f1.data, set([1, 5]))
def test_remove(self):
sess = Session()
f1 = Foo(data=set([1, 2, 3]))
sess.add(f1)
sess.commit()
f1.data.remove(2)
sess.commit()
eq_(f1.data, set([1, 3]))
def test_discard(self):
sess = Session()
f1 = Foo(data=set([1, 2, 3]))
sess.add(f1)
sess.commit()
f1.data.discard(2)
sess.commit()
eq_(f1.data, set([1, 3]))
f1.data.discard(2)
sess.commit()
eq_(f1.data, set([1, 3]))
def test_pickle_parent(self):
sess = Session()
f1 = Foo(data=set([1, 2]))
sess.add(f1)
sess.commit()
f1.data
sess.close()
for loads, dumps in picklers():
sess = Session()
f2 = loads(dumps(f1))
sess.add(f2)
f2.data.add(3)
assert f2 in sess.dirty
def test_unrelated_flush(self):
sess = Session()
f1 = Foo(data=set([1, 2]), unrelated_data="unrelated")
sess.add(f1)
sess.flush()
f1.unrelated_data = "unrelated 2"
sess.flush()
f1.data.add(3)
sess.commit()
eq_(f1.data, set([1, 2, 3]))
class MutableColumnDefaultTest(_MutableDictTestFixture, fixtures.MappedTest):
@classmethod
def define_tables(cls, metadata):
@@ -566,6 +743,23 @@ class MutableListWithScalarPickleTest(_MutableListTestBase, fixtures.MappedTest)
)
class MutableSetWithScalarPickleTest(_MutableSetTestBase, fixtures.MappedTest):
@classmethod
def define_tables(cls, metadata):
MutableSet = cls._type_fixture()
mutable_pickle = MutableSet.as_mutable(PickleType)
Table('foo', metadata,
Column('id', Integer, primary_key=True,
test_needs_autoincrement=True),
Column('skip', mutable_pickle),
Column('data', mutable_pickle),
Column('non_mutable_data', PickleType),
Column('unrelated_data', String(50))
)
class MutableAssocWithAttrInheritTest(_MutableDictTestBase,
fixtures.MappedTest):