mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-06-01 21:38:55 -04:00
343734b95f
as test_customfinder_pass at least does not provide the alternate ext manager and in all cases it's the global function that counts
651 lines
19 KiB
Python
651 lines
19 KiB
Python
from sqlalchemy.testing import eq_, assert_raises, assert_raises_message, ne_
|
|
from sqlalchemy import util
|
|
import sqlalchemy as sa
|
|
from sqlalchemy.orm import class_mapper
|
|
from sqlalchemy.orm import attributes
|
|
from sqlalchemy.orm.attributes import set_attribute, \
|
|
get_attribute, del_attribute
|
|
from sqlalchemy.orm.instrumentation import is_instrumented
|
|
from sqlalchemy.orm import clear_mappers
|
|
from sqlalchemy.testing import fixtures
|
|
from sqlalchemy.ext import instrumentation
|
|
from sqlalchemy.orm.instrumentation import register_class, manager_of_class
|
|
from sqlalchemy.testing.util import decorator
|
|
from sqlalchemy.orm import events
|
|
from sqlalchemy import event
|
|
|
|
|
|
@decorator
def modifies_instrumentation_finders(fn, *args, **kw):
    """Run ``fn`` and then restore ``instrumentation.instrumentation_finders``.

    Tests wrapped with this may insert custom finders into the module-level
    ``instrumentation.instrumentation_finders`` list.  Whatever the list held
    before the call is put back afterwards, whether ``fn`` returns or raises.

    :param fn: the wrapped (test) function.
    :return: whatever ``fn`` returns.  The previous version silently
        discarded it.
    """
    # Shallow snapshot: later in-place edits to the live list don't touch it.
    pristine = list(instrumentation.instrumentation_finders)
    try:
        return fn(*args, **kw)
    finally:
        # Restore in place (slice assignment) instead of rebinding the module
        # attribute, so anything holding a reference to this same list object
        # also sees the restored contents.
        instrumentation.instrumentation_finders[:] = pristine
|
|
|
|
|
|
class _ExtBase(object):
    """Mixin for test classes that exercise the extended instrumentation.

    Once all tests in the class have run, the instrumentation lookups are
    reset by calling ``instrumentation._reinstall_default_lookups()``.
    """

    @classmethod
    def teardown_class(cls):
        # Put the default lookups back so later test classes start clean.
        reinstall_defaults = instrumentation._reinstall_default_lookups
        reinstall_defaults()
|
|
|
|
|
|
class MyTypesManager(instrumentation.InstrumentationManager):
    """Custom instrumentation manager used by the ``MyClass`` test fixture.

    Attribute values live in a per-instance ``_goofy_dict``, the instance
    state is kept under ``'_my_state'`` in ``__dict__``, and every
    collection attribute is instrumented with :class:`MyListLike`.  The
    descriptor-related hooks are deliberately no-ops.
    """

    # -- attribute / descriptor hooks: do nothing --------------------------

    def instrument_attribute(self, class_, key, attr):
        return None

    def install_descriptor(self, class_, key, attr):
        return None

    def uninstall_descriptor(self, class_, key):
        return None

    # -- collections -------------------------------------------------------

    def instrument_collection_class(self, class_, key, collection_class):
        # Whatever collection_class was requested, use MyListLike.
        return MyListLike

    # -- per-instance storage ----------------------------------------------

    def get_instance_dict(self, class_, instance):
        return instance._goofy_dict

    def initialize_instance_dict(self, class_, instance):
        # Written straight into __dict__ to bypass any custom __setattr__.
        instance.__dict__['_goofy_dict'] = {}

    def install_state(self, class_, instance, state):
        instance.__dict__['_my_state'] = state

    def state_getter(self, class_):
        def _read_state(instance):
            return instance.__dict__['_my_state']
        return _read_state
|
|
|
|
|
|
class MyListLike(list):
    """List subclass that implements the ``_sa_*`` collection hooks by hand.

    ``_sa_appender`` / ``_sa_remover`` report to ``self._sa_adapter`` before
    mutating the list.  They are also bound as ``append`` / ``remove`` so
    that ordinary list calls fire the same events.  Passing
    ``_sa_initiator=False`` suppresses the append/remove event, but the
    pre-remove event is always fired.
    """

    # add @appender, @remover decorators as needed
    _sa_iterator = list.__iter__
    _sa_linker = None
    _sa_converter = None

    def _sa_appender(self, item, _sa_initiator=None):
        should_fire = _sa_initiator is not False
        if should_fire:
            self._sa_adapter.fire_append_event(item, _sa_initiator)
        list.append(self, item)

    append = _sa_appender

    def _sa_remover(self, item, _sa_initiator=None):
        adapter = self._sa_adapter
        # Pre-remove notification goes out unconditionally.
        adapter.fire_pre_remove_event(_sa_initiator)
        should_fire = _sa_initiator is not False
        if should_fire:
            adapter.fire_remove_event(item, _sa_initiator)
        list.remove(self, item)

    remove = _sa_remover
|
|
|
|
|
|
# Module-level placeholders.  UserDefinedExtensionTest.setup_class rebinds
# both names (via ``global``) to the real fixture base classes.
MyBaseClass = None
MyClass = None
|
|
|
|
|
|
class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
    """Attribute instrumentation on classes that use extended managers.

    Most tests repeat one scenario for three bases:
    - a plain ``object``;
    - ``MyBaseClass``, which names ``InstrumentationManager`` as its manager
      factory;
    - ``MyClass``, which uses ``MyTypesManager`` and keeps attribute values
      in ``_goofy_dict``.
    """

    @classmethod
    def setup_class(cls):
        # Define the two fixture bases and bind them to the module-level
        # names, so every test method (and its nested classes) can subclass
        # them.
        global MyBaseClass, MyClass

        class MyBaseClass(object):
            # Opt in to extended instrumentation, with the stock
            # InstrumentationManager class as the factory.
            __sa_instrumentation_manager__ = \
                instrumentation.InstrumentationManager

        class MyClass(object):

            # This proves that a staticmethod will work here; don't
            # flatten this back to a class assignment!
            def __sa_instrumentation_manager__(cls):
                return MyTypesManager(cls)

            __sa_instrumentation_manager__ = staticmethod(
                __sa_instrumentation_manager__)

            # This proves SA can handle a class with non-string dict keys.
            # In CPython, locals() in a class body is the class namespace,
            # so this adds the int key 42 to it.  Skipped on PyPy/Jython.
            if not util.pypy and not util.jython:
                locals()[42] = 99  # Don't remove this line!

            def __init__(self, **kwargs):
                # Each keyword goes through __setattr__ below.
                for k in kwargs:
                    setattr(self, k, kwargs[k])

            def __getattr__(self, key):
                # Only reached when normal lookup fails.  Instrumented keys
                # are read via the attributes API; anything else comes from
                # _goofy_dict, and a miss becomes AttributeError.
                if is_instrumented(self, key):
                    return get_attribute(self, key)
                else:
                    try:
                        return self._goofy_dict[key]
                    except KeyError:
                        raise AttributeError(key)

            def __setattr__(self, key, value):
                # Instrumented keys go through set_attribute(); anything else
                # is stored in _goofy_dict rather than the real __dict__.
                if is_instrumented(self, key):
                    set_attribute(self, key, value)
                else:
                    self._goofy_dict[key] = value

            # NOTE(review): Python defines no ``__hasattr__`` hook; hasattr()
            # goes through __getattr__ instead, so this method appears to
            # be unused.  Kept unchanged.
            def __hasattr__(self, key):
                if is_instrumented(self, key):
                    return True
                else:
                    return key in self._goofy_dict

            def __delattr__(self, key):
                # Mirror of __setattr__ for deletion.
                if is_instrumented(self, key):
                    del_attribute(self, key)
                else:
                    del self._goofy_dict[key]

    def teardown(self):
        # Drop any mappers configured during the test.
        clear_mappers()

    def test_instance_dict(self):
        """Instrumented values on a MyClass subclass end up in _goofy_dict.

        The real ``__dict__`` holds only ``_my_state`` and ``_goofy_dict``.
        """
        class User(MyClass):
            pass

        register_class(User)
        attributes.register_attribute(
            User, 'user_id', uselist=False, useobject=False)
        attributes.register_attribute(
            User, 'user_name', uselist=False, useobject=False)
        attributes.register_attribute(
            User, 'email_address', uselist=False, useobject=False)

        u = User()
        u.user_id = 7
        u.user_name = 'john'
        u.email_address = 'lala@123.com'
        eq_(
            u.__dict__,
            {
                '_my_state': u._my_state,
                '_goofy_dict': {
                    'user_id': 7, 'user_name': 'john',
                    'email_address': 'lala@123.com'}}
        )

    def test_basic(self):
        """Scalar attributes set/read correctly before and after a commit."""
        for base in (object, MyBaseClass, MyClass):
            class User(base):
                pass

            register_class(User)
            attributes.register_attribute(
                User, 'user_id', uselist=False, useobject=False)
            attributes.register_attribute(
                User, 'user_name', uselist=False, useobject=False)
            attributes.register_attribute(
                User, 'email_address', uselist=False, useobject=False)

            u = User()
            u.user_id = 7
            u.user_name = 'john'
            u.email_address = 'lala@123.com'

            eq_(u.user_id, 7)
            eq_(u.user_name, "john")
            eq_(u.email_address, "lala@123.com")
            attributes.instance_state(u)._commit_all(
                attributes.instance_dict(u))
            eq_(u.user_id, 7)
            eq_(u.user_name, "john")
            eq_(u.email_address, "lala@123.com")

            u.user_name = 'heythere'
            u.email_address = 'foo@bar.com'
            eq_(u.user_id, 7)
            eq_(u.user_name, "heythere")
            eq_(u.email_address, "foo@bar.com")

    def test_deferred(self):
        """Expired attributes are reloaded through deferred_scalar_loader.

        After each ``_expire(...)`` the next read returns the value from
        ``data``.  The class is tracked in the factory's ``_state_finders``
        only for the non-``object`` bases.
        """
        for base in (object, MyBaseClass, MyClass):
            class Foo(base):
                pass

            data = {'a': 'this is a', 'b': 12}

            # Copies the requested keys from `data` into the state's dict.
            def loader(state, keys):
                for k in keys:
                    state.dict[k] = data[k]
                return attributes.ATTR_WAS_SET

            manager = register_class(Foo)
            manager.deferred_scalar_loader = loader
            attributes.register_attribute(
                Foo, 'a', uselist=False, useobject=False)
            attributes.register_attribute(
                Foo, 'b', uselist=False, useobject=False)

            if base is object:
                assert Foo not in \
                    instrumentation._instrumentation_factory._state_finders
            else:
                assert Foo in \
                    instrumentation._instrumentation_factory._state_finders

            f = Foo()
            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set())
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            # A local change is discarded by the next expire.
            f.a = "this is some new a"
            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set())
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            # Setting after an expire wins over the loader value.
            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set())
            f.a = "this is another new a"
            eq_(f.a, "this is another new a")
            eq_(f.b, 12)

            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set())
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            # A deleted attribute reads back as None, before and after commit.
            del f.a
            eq_(f.a, None)
            eq_(f.b, 12)

            attributes.instance_state(f)._commit_all(
                attributes.instance_dict(f))
            eq_(f.a, None)
            eq_(f.b, 12)

    def test_inheritance(self):
        """tests that attributes are polymorphic"""

        for base in (object, MyBaseClass, MyClass):
            class Foo(base):
                pass

            class Bar(Foo):
                pass

            register_class(Foo)
            register_class(Bar)

            # Value callables: Bar overrides 'element'; 'element2' is
            # registered only on Foo and must be inherited by Bar.
            def func1(state, passive):
                return "this is the foo attr"

            def func2(state, passive):
                return "this is the bar attr"

            def func3(state, passive):
                return "this is the shared attr"
            attributes.register_attribute(Foo, 'element',
                                          uselist=False, callable_=func1,
                                          useobject=True)
            attributes.register_attribute(Foo, 'element2',
                                          uselist=False, callable_=func3,
                                          useobject=True)
            attributes.register_attribute(Bar, 'element',
                                          uselist=False, callable_=func2,
                                          useobject=True)

            x = Foo()
            y = Bar()
            assert x.element == 'this is the foo attr'
            assert y.element == 'this is the bar attr', y.element
            assert x.element2 == 'this is the shared attr'
            assert y.element2 == 'this is the shared attr'

    def test_collection_with_backref(self):
        """Blog.posts and Post.blog stay in sync through their backrefs."""
        for base in (object, MyBaseClass, MyClass):
            class Post(base):
                pass

            class Blog(base):
                pass

            register_class(Post)
            register_class(Blog)
            attributes.register_attribute(
                Post, 'blog', uselist=False,
                backref='posts', trackparent=True, useobject=True)
            attributes.register_attribute(
                Blog, 'posts', uselist=True,
                backref='blog', trackparent=True, useobject=True)
            b = Blog()
            (p1, p2, p3) = (Post(), Post(), Post())
            b.posts.append(p1)
            b.posts.append(p2)
            b.posts.append(p3)
            self.assert_(b.posts == [p1, p2, p3])
            self.assert_(p2.blog is b)

            p3.blog = None
            self.assert_(b.posts == [p1, p2])
            p4 = Post()
            p4.blog = b
            self.assert_(b.posts == [p1, p2, p4])

            # Re-assigning the same parent must not duplicate the child.
            p4.blog = b
            p4.blog = b
            self.assert_(b.posts == [p1, p2, p4])

            # assert no failure removing None
            p5 = Post()
            p5.blog = None
            del p5.blog

    def test_history(self):
        """Scalar and collection history is tracked across commits.

        The expected triples read as ``(added, unchanged, deleted)``.
        """
        for base in (object, MyBaseClass, MyClass):
            class Foo(base):
                pass

            class Bar(base):
                pass

            register_class(Foo)
            register_class(Bar)
            attributes.register_attribute(
                Foo, "name", uselist=False, useobject=False)
            attributes.register_attribute(
                Foo, "bars", uselist=True, trackparent=True, useobject=True)
            attributes.register_attribute(
                Bar, "name", uselist=False, useobject=False)

            f1 = Foo()
            f1.name = 'f1'

            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'name'),
                (['f1'], (), ()))

            b1 = Bar()
            b1.name = 'b1'
            f1.bars.append(b1)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'bars'),
                ([b1], [], []))

            attributes.instance_state(f1)._commit_all(
                attributes.instance_dict(f1))
            attributes.instance_state(b1)._commit_all(
                attributes.instance_dict(b1))

            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1),
                    'name'),
                ((), ['f1'], ()))
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1),
                    'bars'),
                ((), [b1], ()))

            f1.name = 'f1mod'
            b2 = Bar()
            b2.name = 'b2'
            f1.bars.append(b2)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'name'),
                (['f1mod'], (), ['f1']))
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'bars'),
                ([b2], [b1], []))
            f1.bars.remove(b1)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'bars'),
                ([b2], [], [b1]))

    def test_null_instrumentation(self):
        """Class-level access on a MyBaseClass subclass returns the
        manager's registered attribute objects."""
        class Foo(MyBaseClass):
            pass
        register_class(Foo)
        attributes.register_attribute(
            Foo, "name", uselist=False, useobject=False)
        attributes.register_attribute(
            Foo, "bars", uselist=True, trackparent=True, useobject=True)

        assert Foo.name == attributes.manager_of_class(Foo)['name']
        assert Foo.bars == attributes.manager_of_class(Foo)['bars']

    def test_alternate_finders(self):
        """Ensure the generic finder front-end deals with edge cases."""

        class Unknown(object):
            pass

        class Known(MyBaseClass):
            pass

        register_class(Known)
        k, u = Known(), Unknown()

        # Unregistered classes and None yield no manager.
        assert instrumentation.manager_of_class(Unknown) is None
        assert instrumentation.manager_of_class(Known) is not None
        assert instrumentation.manager_of_class(None) is None

        assert attributes.instance_state(k) is not None
        assert_raises((AttributeError, KeyError),
                      attributes.instance_state, u)
        assert_raises((AttributeError, KeyError),
                      attributes.instance_state, None)

    def test_unmapped_not_type_error(self):
        """extension version of the same test in test_mapper.

        fixes #3408
        """
        assert_raises_message(
            sa.exc.ArgumentError,
            "Class object expected, got '5'.",
            class_mapper, 5
        )

    def test_unmapped_not_type_error_iter_ok(self):
        """extension version of the same test in test_mapper.

        fixes #3408
        """
        assert_raises_message(
            sa.exc.ArgumentError,
            r"Class object expected, got '\(5, 6\)'.",
            class_mapper, (5, 6)
        )
|
|
|
|
|
|
class FinderTest(_ExtBase, fixtures.ORMTest):
    """Which ClassManager type ``register_class()`` ends up using.

    The outcome depends on a class's ``__sa_instrumentation_manager__``
    hook and on custom entries in ``instrumentation_finders``.
    """

    def test_standard(self):
        # No hook at all: the stock ClassManager is used.
        class A(object):
            pass

        register_class(A)

        manager_type = type(manager_of_class(A))
        eq_(manager_type, instrumentation.ClassManager)

    def test_nativeext_interfaceexact(self):
        # The hook names InstrumentationManager itself; the resulting
        # manager must not be a plain ClassManager.
        class A(object):
            __sa_instrumentation_manager__ = (
                instrumentation.InstrumentationManager)

        register_class(A)
        manager_type = type(manager_of_class(A))
        ne_(manager_type, instrumentation.ClassManager)

    def test_nativeext_submanager(self):
        # The hook names a ClassManager subclass, which is used as-is.
        class Mine(instrumentation.ClassManager):
            pass

        class A(object):
            __sa_instrumentation_manager__ = Mine

        register_class(A)
        manager_type = type(manager_of_class(A))
        eq_(manager_type, Mine)

    @modifies_instrumentation_finders
    def test_customfinder_greedy(self):
        # A finder at the front of the list that claims every class wins.
        class Mine(instrumentation.ClassManager):
            pass

        class A(object):
            pass

        instrumentation.instrumentation_finders.insert(0, lambda cls: Mine)
        register_class(A)
        manager_type = type(manager_of_class(A))
        eq_(manager_type, Mine)

    @modifies_instrumentation_finders
    def test_customfinder_pass(self):
        # A finder that declines (returns None) leaves the default in place.
        class A(object):
            pass

        instrumentation.instrumentation_finders.insert(0, lambda cls: None)
        register_class(A)

        manager_type = type(manager_of_class(A))
        eq_(manager_type, instrumentation.ClassManager)
|
|
|
|
|
|
class InstrumentationCollisionTest(_ExtBase, fixtures.ORMTest):
    """Registering classes whose hierarchy mixes manager factories.

    Each conflicting case must be rejected with a TypeError.
    """

    # Substring expected in every collision error below.
    _conflict_msg = "multiple instrumentation implementations"

    def test_none(self):
        # Three unrelated classes, each with a different kind of factory,
        # are expected to register without error.
        make_manager = lambda cls: instrumentation.ClassManager(cls)

        class A(object):
            pass
        register_class(A)

        class B(object):
            __sa_instrumentation_manager__ = staticmethod(make_manager)
        register_class(B)

        class C(object):
            __sa_instrumentation_manager__ = instrumentation.ClassManager
        register_class(C)

    def test_single_down(self):
        # The base is registered first without a custom factory; a subclass
        # that then asks for its own factory is refused.
        make_manager = lambda cls: instrumentation.ClassManager(cls)

        class A(object):
            pass
        register_class(A)

        class B(A):
            __sa_instrumentation_manager__ = staticmethod(make_manager)

        assert_raises_message(
            TypeError, self._conflict_msg, register_class, B)

    def test_single_up(self):
        # The subclass with the custom factory is registered first;
        # registering its plain base afterwards is refused.
        make_manager = lambda cls: instrumentation.ClassManager(cls)

        class A(object):
            pass
        # delay registration

        class B(A):
            __sa_instrumentation_manager__ = staticmethod(make_manager)
        register_class(B)

        assert_raises_message(
            TypeError, self._conflict_msg, register_class, A)

    def test_diamond_b1(self):
        # B1 and B2 share base A and only B2 names a factory.  Registering B1
        # is expected to fail even though B2 was never registered.
        make_manager = lambda cls: instrumentation.ClassManager(cls)

        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(make_manager)

        class C(object):
            pass

        assert_raises_message(
            TypeError, self._conflict_msg, register_class, B1)

    def test_diamond_b2(self):
        # Same shape as above, but B2 is registered before B1 is attempted.
        make_manager = lambda cls: instrumentation.ClassManager(cls)

        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(make_manager)

        class C(object):
            pass

        register_class(B2)
        assert_raises_message(
            TypeError, self._conflict_msg, register_class, B1)

    def test_diamond_c_b(self):
        # Registering the unrelated class C first doesn't change the
        # outcome for B1.
        make_manager = lambda cls: instrumentation.ClassManager(cls)

        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(make_manager)

        class C(object):
            pass

        register_class(C)

        assert_raises_message(
            TypeError, self._conflict_msg, register_class, B1)
|
|
|
|
|
|
class ExtendedEventsTest(_ExtBase, fixtures.ORMTest):
    """A ClassManager subclass can plug in its own Events implementation."""

    @modifies_instrumentation_finders
    def test_subclassed(self):
        class MyEvents(events.InstanceEvents):
            pass

        class MyClassManager(instrumentation.ClassManager):
            dispatch = event.dispatcher(MyEvents)

        # Finder placed first that hands every class MyClassManager.
        def always_my_manager(cls):
            return MyClassManager

        instrumentation.instrumentation_finders.insert(0, always_my_manager)

        class A(object):
            pass

        register_class(A)
        dispatch = instrumentation.manager_of_class(A).dispatch
        assert issubclass(dispatch._events, MyEvents)
|