- add QueryContext to load(), refresh()

- add list of attribute names to refresh()
- ensure refresh() only called when attributes actually refreshed
- tests.  [ticket:2011]
This commit is contained in:
Mike Bayer
2010-12-31 11:46:30 -05:00
parent d9b032e36a
commit 9d04eaffcc
7 changed files with 251 additions and 52 deletions
+2 -2
View File
@@ -51,7 +51,7 @@ class Mutable(object):
key = attribute.key
parent_cls = attribute.class_
def load(state):
def load(state, *args):
"""Listen for objects loaded or refreshed.
Wrap the target data member's value with
@@ -230,7 +230,7 @@ class MutableComposite(object):
key = attribute.key
parent_cls = attribute.class_
def load(state):
def load(state, *args):
"""Listen for objects loaded or refreshed.
Wrap the target data member's value with
+1 -1
View File
@@ -83,7 +83,7 @@ class MapperExtension(object):
if me_meth is not ls_meth:
if meth == 'reconstruct_instance':
def go(ls_meth):
def reconstruct(instance):
def reconstruct(instance, ctx):
ls_meth(self, instance)
return reconstruct
event.listen(self.class_manager, 'load',
+1 -1
View File
@@ -155,7 +155,7 @@ class CompositeProperty(DescriptorProperty):
def _setup_event_handlers(self):
"""Establish events that populate/expire the composite attribute."""
def load_handler(state):
def load_handler(state, *args):
dict_ = state.dict
if self.key in dict_:
+41 -8
View File
@@ -121,7 +121,7 @@ class InstanceEvents(event.Events):
"""
def load(self, target):
def load(self, target, context):
"""Receive an object instance after it has been created via
``__new__``, and after initial attribute population has
occurred.
@@ -135,29 +135,59 @@ class InstanceEvents(event.Events):
attributes and collections may or may not be loaded or even
initialized, depending on what's present in the result rows.
:param target: the mapped instance. If
the event is configured with ``raw=True``, this will
instead be the :class:`.InstanceState` state-management
object associated with the instance.
:param context: the :class:`.QueryContext` corresponding to the
current :class:`.Query` in progress.
"""
def refresh(self, target):
def refresh(self, target, context, attrs):
"""Receive an object instance after one or more attributes have
been refreshed.
been refreshed from a query.
This hook is called after expired attributes have been reloaded.
:param target: the mapped instance. If
the event is configured with ``raw=True``, this will
instead be the :class:`.InstanceState` state-management
object associated with the instance.
:param context: the :class:`.QueryContext` corresponding to the
current :class:`.Query` in progress.
:param attrs: iterable collection of attribute names which
were populated, or None if all column-mapped, non-deferred
attributes were populated.
"""
def expire(self, target, keys):
def expire(self, target, attrs):
"""Receive an object instance after its attributes or some subset
have been expired.
'keys' is a list of attribute names. If None, the entire
state was expired.
:param target: the mapped instance. If
the event is configured with ``raw=True``, this will
instead be the :class:`.InstanceState` state-management
object associated with the instance.
:param attrs: iterable collection of attribute
names which were expired, or None if all attributes were
expired.
"""
def resurrect(self, target):
"""Receive an object instance as it is 'resurrected' from
garbage collection, which occurs when a "dirty" state falls
out of scope."""
out of scope.
:param target: the mapped instance. If
the event is configured with ``raw=True``, this will
instead be the :class:`.InstanceState` state-management
object associated with the instance.
"""
class MapperEvents(event.Events):
@@ -412,7 +442,10 @@ class MapperEvents(event.Events):
:param row: the result row being handled. This may be
an actual :class:`.RowProxy` or may be a dictionary containing
:class:`.Column` objects as keys.
:param class\_: the mapped class.
:param target: the mapped instance. If
the event is configured with ``raw=True``, this will
instead be the :class:`.InstanceState` state-management
object associated with the instance.
:return: When configured with ``retval=True``, a return
value of ``EXT_STOP`` will bypass instance population by
the mapper. A value of ``EXT_CONTINUE`` indicates that
+28 -25
View File
@@ -2269,37 +2269,40 @@ class Mapper(object):
else:
populate_state(state, dict_, row, isnew, only_load_props)
else:
if loaded_instance:
state.manager.dispatch.load(state, context)
elif isnew:
state.manager.dispatch.refresh(state, context, only_load_props)
elif state in context.partials or state.unloaded:
# populate attributes on non-loading instances which have
# been expired
# TODO: apply eager loads to un-lazy loaded collections ?
if state in context.partials or state.unloaded:
if state in context.partials:
isnew = False
(d_, attrs) = context.partials[state]
else:
isnew = True
attrs = state.unloaded
# allow query.instances to commit the subset of attrs
context.partials[state] = (dict_, attrs)
if populate_instance:
for fn in populate_instance:
ret = fn(self, context, row, state,
only_load_props=attrs,
instancekey=identitykey, isnew=isnew)
if ret is not EXT_CONTINUE:
break
else:
populate_state(state, dict_, row, isnew, attrs)
if state in context.partials:
isnew = False
(d_, attrs) = context.partials[state]
else:
isnew = True
attrs = state.unloaded
# allow query.instances to commit the subset of attrs
context.partials[state] = (dict_, attrs)
if populate_instance:
for fn in populate_instance:
ret = fn(self, context, row, state,
only_load_props=attrs,
instancekey=identitykey, isnew=isnew)
if ret is not EXT_CONTINUE:
break
else:
populate_state(state, dict_, row, isnew, attrs)
else:
populate_state(state, dict_, row, isnew, attrs)
if isnew:
state.manager.dispatch.refresh(state, context, attrs)
if loaded_instance:
state.manager.dispatch.load(state)
elif isnew:
state.manager.dispatch.refresh(state)
if result is not None:
if append_result:
@@ -2462,7 +2465,7 @@ def validates(*names):
return fn
return wrap
def _event_on_load(state):
def _event_on_load(state, ctx):
instrumenting_mapper = state.manager.info[_INSTRUMENTOR]
if instrumenting_mapper._reconstructor:
instrumenting_mapper._reconstructor(state.obj())
+177 -14
View File
@@ -12,7 +12,18 @@ from test.lib.testing import eq_
from test.orm import _base, _fixtures
from sqlalchemy import event
class MapperEventsTest(_fixtures.FixtureTest):
class _RemoveListeners(object):
def teardown(self):
# TODO: need to get remove() functionality
# going
Mapper.dispatch._clear()
ClassManager.dispatch._clear()
Session.dispatch._clear()
super(_RemoveListeners, self).teardown()
class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest):
run_inserts = None
@testing.resolve_artifact_names
@@ -58,12 +69,6 @@ class MapperEventsTest(_fixtures.FixtureTest):
b = B()
eq_(canary, [('init_a', b), ('init_b', b),('init_e', b)])
def teardown(self):
# TODO: need to get remove() functionality
# going
Mapper.dispatch._clear()
ClassManager.dispatch._clear()
super(MapperEventsTest, self).teardown()
def listen_all(self, mapper, **kw):
canary = []
@@ -223,7 +228,171 @@ class MapperEventsTest(_fixtures.FixtureTest):
eq_(canary, [User, Address])
class SessionEventsTest(_fixtures.FixtureTest):
class LoadTest(_fixtures.FixtureTest):
run_inserts = None
@classmethod
@testing.resolve_artifact_names
def setup_mappers(cls):
mapper(User, users)
@testing.resolve_artifact_names
def _fixture(self):
canary = []
def load(target, ctx):
canary.append("load")
def refresh(target, ctx, attrs):
canary.append(("refresh", attrs))
event.listen(User, "load", load)
event.listen(User, "refresh", refresh)
return canary
@testing.resolve_artifact_names
def test_just_loaded(self):
canary = self._fixture()
sess = Session()
u1 = User(name='u1')
sess.add(u1)
sess.commit()
sess.close()
sess.query(User).first()
eq_(canary, ['load'])
@testing.resolve_artifact_names
def test_repeated_rows(self):
canary = self._fixture()
sess = Session()
u1 = User(name='u1')
sess.add(u1)
sess.commit()
sess.close()
sess.query(User).union_all(sess.query(User)).all()
eq_(canary, ['load'])
class RefreshTest(_fixtures.FixtureTest):
run_inserts = None
@classmethod
@testing.resolve_artifact_names
def setup_mappers(cls):
mapper(User, users)
@testing.resolve_artifact_names
def _fixture(self):
canary = []
def load(target, ctx):
canary.append("load")
def refresh(target, ctx, attrs):
canary.append(("refresh", attrs))
event.listen(User, "load", load)
event.listen(User, "refresh", refresh)
return canary
@testing.resolve_artifact_names
def test_already_present(self):
canary = self._fixture()
sess = Session()
u1 = User(name='u1')
sess.add(u1)
sess.flush()
sess.query(User).first()
eq_(canary, [])
@testing.resolve_artifact_names
def test_repeated_rows(self):
canary = self._fixture()
sess = Session()
u1 = User(name='u1')
sess.add(u1)
sess.commit()
sess.query(User).union_all(sess.query(User)).all()
eq_(canary, [('refresh', set(['id','name']))])
@testing.resolve_artifact_names
def test_via_refresh_state(self):
canary = self._fixture()
sess = Session()
u1 = User(name='u1')
sess.add(u1)
sess.commit()
u1.name
eq_(canary, [('refresh', set(['id','name']))])
@testing.resolve_artifact_names
def test_was_expired(self):
canary = self._fixture()
sess = Session()
u1 = User(name='u1')
sess.add(u1)
sess.flush()
sess.expire(u1)
sess.query(User).first()
eq_(canary, [('refresh', set(['id','name']))])
@testing.resolve_artifact_names
def test_was_expired_via_commit(self):
canary = self._fixture()
sess = Session()
u1 = User(name='u1')
sess.add(u1)
sess.commit()
sess.query(User).first()
eq_(canary, [('refresh', set(['id','name']))])
@testing.resolve_artifact_names
def test_was_expired_attrs(self):
canary = self._fixture()
sess = Session()
u1 = User(name='u1')
sess.add(u1)
sess.flush()
sess.expire(u1, ['name'])
sess.query(User).first()
eq_(canary, [('refresh', set(['name']))])
@testing.resolve_artifact_names
def test_populate_existing(self):
canary = self._fixture()
sess = Session()
u1 = User(name='u1')
sess.add(u1)
sess.commit()
sess.query(User).populate_existing().first()
eq_(canary, [('refresh', None)])
class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest):
run_inserts = None
def test_class_listen(self):
@@ -491,12 +660,6 @@ class SessionEventsTest(_fixtures.FixtureTest):
]
)
def teardown(self):
# TODO: need to get remove() functionality
# going
Session.dispatch._clear()
super(SessionEventsTest, self).teardown()
class MapperExtensionTest(_fixtures.FixtureTest):
+1 -1
View File
@@ -20,7 +20,7 @@ class MergeTest(_fixtures.FixtureTest):
def load_tracker(self, cls, canary=None):
if canary is None:
def canary(instance):
def canary(instance, *args):
canary.called += 1
canary.called = 0