mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-17 14:17:29 -04:00
Add sqlalchemy.ext.mutable.MutableSet
from https://bitbucket.org/zzzeek/sqlalchemy/issues/3297
This commit is contained in:
@@ -778,3 +778,68 @@ class MutableList(Mutable, list):
|
||||
|
||||
def __setstate__(self, state):
|
||||
self[:] = state
|
||||
|
||||
|
||||
class MutableSet(Mutable, set):
|
||||
"""A set type that implements :class:`.Mutable`.
|
||||
|
||||
The :class:`.MutableSet` object implements a list that will
|
||||
emit change events to the underlying mapping when the contents of
|
||||
the set are altered, including when values are added or removed.
|
||||
"""
|
||||
|
||||
def update(self, *arg):
|
||||
set.update(self, *arg)
|
||||
self.changed()
|
||||
|
||||
def intersection_update(self, *arg):
|
||||
set.intersection_update(self, *arg)
|
||||
self.changed()
|
||||
|
||||
def difference_update(self, *arg):
|
||||
set.difference_update(self, *arg)
|
||||
self.changed()
|
||||
|
||||
def symmetric_difference_update(self, *arg):
|
||||
set.symmetric_difference_update(self, *arg)
|
||||
self.changed()
|
||||
|
||||
def add(self, elem):
|
||||
set.add(self, elem)
|
||||
self.changed()
|
||||
|
||||
def remove(self, elem):
|
||||
set.remove(self, elem)
|
||||
self.changed()
|
||||
|
||||
def discard(self, elem):
|
||||
set.discard(self, elem)
|
||||
self.changed()
|
||||
|
||||
def pop(self, *arg):
|
||||
result = set.pop(self, *arg)
|
||||
self.changed()
|
||||
return result
|
||||
|
||||
def clear(self):
|
||||
set.clear(self)
|
||||
self.changed()
|
||||
|
||||
@classmethod
|
||||
def coerce(cls, index, value):
|
||||
"""Convert plain set to instance of this class."""
|
||||
if not isinstance(value, cls):
|
||||
if isinstance(value, set):
|
||||
return cls(value)
|
||||
return Mutable.coerce(index, value)
|
||||
else:
|
||||
return value
|
||||
|
||||
def __getstate__(self):
|
||||
return set(self)
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.update(state)
|
||||
|
||||
def __reduce_ex__(self, proto):
|
||||
return (self.__class__, (list(self), ))
|
||||
|
||||
+195
-1
@@ -8,7 +8,7 @@ from sqlalchemy.testing import eq_, assert_raises_message, assert_raises
|
||||
from sqlalchemy.testing.util import picklers
|
||||
from sqlalchemy.testing import fixtures
|
||||
from sqlalchemy.ext.mutable import MutableComposite
|
||||
from sqlalchemy.ext.mutable import MutableDict, MutableList
|
||||
from sqlalchemy.ext.mutable import MutableDict, MutableList, MutableSet
|
||||
|
||||
|
||||
class Foo(fixtures.BasicEntity):
|
||||
@@ -461,6 +461,183 @@ class _MutableListTestBase(_MutableListTestFixture):
|
||||
eq_(f1.data[0], 3)
|
||||
|
||||
|
||||
class _MutableSetTestFixture(object):
|
||||
@classmethod
|
||||
def _type_fixture(cls):
|
||||
return MutableSet
|
||||
|
||||
def teardown(self):
|
||||
# clear out mapper events
|
||||
Mapper.dispatch._clear()
|
||||
ClassManager.dispatch._clear()
|
||||
super(_MutableSetTestFixture, self).teardown()
|
||||
|
||||
|
||||
class _MutableSetTestBase(_MutableSetTestFixture):
|
||||
run_define_tables = 'each'
|
||||
|
||||
def setup_mappers(cls):
|
||||
foo = cls.tables.foo
|
||||
|
||||
mapper(Foo, foo)
|
||||
|
||||
def test_coerce_none(self):
|
||||
sess = Session()
|
||||
f1 = Foo(data=None)
|
||||
sess.add(f1)
|
||||
sess.commit()
|
||||
eq_(f1.data, None)
|
||||
|
||||
def test_coerce_raise(self):
|
||||
assert_raises_message(
|
||||
ValueError,
|
||||
"Attribute 'data' does not accept objects of type",
|
||||
Foo, data=[1, 2, 3]
|
||||
)
|
||||
|
||||
def test_clear(self):
|
||||
sess = Session()
|
||||
|
||||
f1 = Foo(data=set([1, 2]))
|
||||
sess.add(f1)
|
||||
sess.commit()
|
||||
|
||||
f1.data.clear()
|
||||
sess.commit()
|
||||
|
||||
eq_(f1.data, set())
|
||||
|
||||
def test_pop(self):
|
||||
sess = Session()
|
||||
|
||||
f1 = Foo(data=set([1]))
|
||||
sess.add(f1)
|
||||
sess.commit()
|
||||
|
||||
eq_(f1.data.pop(), 1)
|
||||
sess.commit()
|
||||
|
||||
assert_raises(KeyError, f1.data.pop)
|
||||
|
||||
eq_(f1.data, set())
|
||||
|
||||
def test_add(self):
|
||||
sess = Session()
|
||||
|
||||
f1 = Foo(data=set([1, 2]))
|
||||
sess.add(f1)
|
||||
sess.commit()
|
||||
|
||||
f1.data.add(5)
|
||||
sess.commit()
|
||||
|
||||
eq_(f1.data, set([1, 2, 5]))
|
||||
|
||||
def test_update(self):
|
||||
sess = Session()
|
||||
|
||||
f1 = Foo(data=set([1, 2]))
|
||||
sess.add(f1)
|
||||
sess.commit()
|
||||
|
||||
f1.data.update(set([2, 5]))
|
||||
sess.commit()
|
||||
|
||||
eq_(f1.data, set([1, 2, 5]))
|
||||
|
||||
def test_intersection_update(self):
|
||||
sess = Session()
|
||||
|
||||
f1 = Foo(data=set([1, 2]))
|
||||
sess.add(f1)
|
||||
sess.commit()
|
||||
|
||||
f1.data.intersection_update(set([2, 5]))
|
||||
sess.commit()
|
||||
|
||||
eq_(f1.data, set([2]))
|
||||
|
||||
def test_difference_update(self):
|
||||
sess = Session()
|
||||
|
||||
f1 = Foo(data=set([1, 2]))
|
||||
sess.add(f1)
|
||||
sess.commit()
|
||||
|
||||
f1.data.difference_update(set([2, 5]))
|
||||
sess.commit()
|
||||
|
||||
eq_(f1.data, set([1]))
|
||||
|
||||
def test_symmetric_difference_update(self):
|
||||
sess = Session()
|
||||
|
||||
f1 = Foo(data=set([1, 2]))
|
||||
sess.add(f1)
|
||||
sess.commit()
|
||||
|
||||
f1.data.symmetric_difference_update(set([2, 5]))
|
||||
sess.commit()
|
||||
|
||||
eq_(f1.data, set([1, 5]))
|
||||
|
||||
def test_remove(self):
|
||||
sess = Session()
|
||||
|
||||
f1 = Foo(data=set([1, 2, 3]))
|
||||
sess.add(f1)
|
||||
sess.commit()
|
||||
|
||||
f1.data.remove(2)
|
||||
sess.commit()
|
||||
|
||||
eq_(f1.data, set([1, 3]))
|
||||
|
||||
def test_discard(self):
|
||||
sess = Session()
|
||||
|
||||
f1 = Foo(data=set([1, 2, 3]))
|
||||
sess.add(f1)
|
||||
sess.commit()
|
||||
|
||||
f1.data.discard(2)
|
||||
sess.commit()
|
||||
|
||||
eq_(f1.data, set([1, 3]))
|
||||
|
||||
f1.data.discard(2)
|
||||
sess.commit()
|
||||
|
||||
eq_(f1.data, set([1, 3]))
|
||||
|
||||
def test_pickle_parent(self):
|
||||
sess = Session()
|
||||
|
||||
f1 = Foo(data=set([1, 2]))
|
||||
sess.add(f1)
|
||||
sess.commit()
|
||||
f1.data
|
||||
sess.close()
|
||||
|
||||
for loads, dumps in picklers():
|
||||
sess = Session()
|
||||
f2 = loads(dumps(f1))
|
||||
sess.add(f2)
|
||||
f2.data.add(3)
|
||||
assert f2 in sess.dirty
|
||||
|
||||
def test_unrelated_flush(self):
|
||||
sess = Session()
|
||||
f1 = Foo(data=set([1, 2]), unrelated_data="unrelated")
|
||||
sess.add(f1)
|
||||
sess.flush()
|
||||
f1.unrelated_data = "unrelated 2"
|
||||
sess.flush()
|
||||
f1.data.add(3)
|
||||
sess.commit()
|
||||
eq_(f1.data, set([1, 2, 3]))
|
||||
|
||||
|
||||
class MutableColumnDefaultTest(_MutableDictTestFixture, fixtures.MappedTest):
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
@@ -566,6 +743,23 @@ class MutableListWithScalarPickleTest(_MutableListTestBase, fixtures.MappedTest)
|
||||
)
|
||||
|
||||
|
||||
class MutableSetWithScalarPickleTest(_MutableSetTestBase, fixtures.MappedTest):
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
MutableSet = cls._type_fixture()
|
||||
|
||||
mutable_pickle = MutableSet.as_mutable(PickleType)
|
||||
Table('foo', metadata,
|
||||
Column('id', Integer, primary_key=True,
|
||||
test_needs_autoincrement=True),
|
||||
Column('skip', mutable_pickle),
|
||||
Column('data', mutable_pickle),
|
||||
Column('non_mutable_data', PickleType),
|
||||
Column('unrelated_data', String(50))
|
||||
)
|
||||
|
||||
|
||||
class MutableAssocWithAttrInheritTest(_MutableDictTestBase,
|
||||
fixtures.MappedTest):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user