Files
sqlalchemy/test/ext/test_extendedattr.py
T
Khairi Hafsham 772374735d Make all tests to be PEP8 compliant
tested using pycodestyle version 2.2.0

Fixes: #3885
Change-Id: I5df43adc3aefe318f9eeab72a078247a548ec566
Pull-request: https://github.com/zzzeek/sqlalchemy/pull/343
2017-02-07 11:21:56 -05:00

651 lines
19 KiB
Python

from sqlalchemy.testing import eq_, assert_raises, assert_raises_message, ne_
from sqlalchemy import util
import sqlalchemy as sa
from sqlalchemy.orm import class_mapper
from sqlalchemy.orm import attributes
from sqlalchemy.orm.attributes import set_attribute, \
get_attribute, del_attribute
from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.orm import clear_mappers
from sqlalchemy.testing import fixtures
from sqlalchemy.ext import instrumentation
from sqlalchemy.orm.instrumentation import register_class, manager_of_class
from sqlalchemy.testing.util import decorator
from sqlalchemy.orm import events
from sqlalchemy import event
@decorator
def modifies_instrumentation_finders(fn, *args, **kw):
    """Run a test, then put ``instrumentation.instrumentation_finders`` back.

    A shallow copy of the finder list is taken before ``fn`` runs.  Whether
    ``fn`` returns or raises, the module-level list is refilled in place
    with exactly the saved entries, undoing any finders the test inserted.
    """
    snapshot = list(instrumentation.instrumentation_finders)
    try:
        fn(*args, **kw)
    finally:
        # Slice-assign so the very same list object is repopulated rather
        # than rebound (other code may hold a reference to it).
        instrumentation.instrumentation_finders[:] = snapshot
class _ExtBase(object):
    """Mixin for the extension test classes in this module.

    Once a class's tests have run (``teardown_class`` hook), it calls
    ``instrumentation._reinstall_default_lookups()``.
    """

    @classmethod
    def teardown_class(cls):
        reinstall_defaults = instrumentation._reinstall_default_lookups
        reinstall_defaults()
class MyTypesManager(instrumentation.InstrumentationManager):
    """Custom instrumentation manager used by ``MyClass`` in the tests.

    - Descriptor hooks are no-ops, so nothing is installed on the class.
    - Collections are always built as ``MyListLike``.
    - Attribute values live in ``instance._goofy_dict``.
    - The instance state is kept under ``instance.__dict__['_my_state']``.
    """

    # -- descriptor hooks: intentionally do nothing ------------------------

    def instrument_attribute(self, class_, key, attr):
        pass

    def install_descriptor(self, class_, key, attr):
        pass

    def uninstall_descriptor(self, class_, key):
        pass

    # -- collections --------------------------------------------------------

    def instrument_collection_class(self, class_, key, collection_class):
        return MyListLike

    # -- per-instance storage -----------------------------------------------

    def get_instance_dict(self, class_, instance):
        return instance._goofy_dict

    def initialize_instance_dict(self, class_, instance):
        # Written straight into __dict__ to bypass any custom __setattr__.
        instance.__dict__['_goofy_dict'] = {}

    def install_state(self, class_, instance, state):
        instance.__dict__['_my_state'] = state

    def state_getter(self, class_):
        def fetch_state(instance):
            return instance.__dict__['_my_state']
        return fetch_state
class MyListLike(list):
    """A ``list`` that speaks SQLAlchemy's collection protocol by hand.

    It uses explicit ``_sa_*`` hooks rather than collection decorators.
    ``self._sa_adapter`` must be attached from outside (presumably by the
    ORM's collection adapter) before ``append``/``remove`` are used.  An
    ``_sa_initiator`` of ``False`` suppresses the append/remove events.
    """

    # add @appender, @remover decorators as needed
    _sa_iterator = list.__iter__
    _sa_linker = None
    _sa_converter = None

    def _sa_appender(self, item, _sa_initiator=None):
        should_fire = _sa_initiator is not False
        if should_fire:
            self._sa_adapter.fire_append_event(item, _sa_initiator)
        list.append(self, item)

    def _sa_remover(self, item, _sa_initiator=None):
        adapter = self._sa_adapter
        # The pre-remove hook fires unconditionally...
        adapter.fire_pre_remove_event(_sa_initiator)
        # ...while the remove event itself honours the False initiator.
        if _sa_initiator is not False:
            adapter.fire_remove_event(item, _sa_initiator)
        list.remove(self, item)

    # Public list API routed through the protocol hooks above.
    append = _sa_appender
    remove = _sa_remover
# Module-level placeholders; UserDefinedExtensionTest.setup_class rebinds
# both names (via ``global``) to the real fixture classes.
MyBaseClass = MyClass = None
class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
    """Attribute instrumentation under user-defined instrumentation managers.

    Most tests repeat the same scenario for three base classes, all set
    up in ``setup_class``:

    - ``object``: no instrumentation hook.
    - ``MyBaseClass``: hook is the stock ``InstrumentationManager``.
    - ``MyClass``: hook is ``MyTypesManager``.
    """

    @classmethod
    def setup_class(cls):
        # Rebind the module-level placeholders so the test methods below
        # can refer to these classes by name.
        global MyBaseClass, MyClass

        class MyBaseClass(object):
            # Opt in to extended instrumentation with the stock manager.
            __sa_instrumentation_manager__ = \
                instrumentation.InstrumentationManager

        class MyClass(object):
            # Managed by MyTypesManager, so instance values are stored in
            # ``self._goofy_dict`` and reached through the dunder hooks below.

            # This proves that a staticmethod will work here; don't
            # flatten this back to a class assignment!
            def __sa_instrumentation_manager__(cls):
                return MyTypesManager(cls)

            __sa_instrumentation_manager__ = staticmethod(
                __sa_instrumentation_manager__)

            # This proves SA can handle a class with non-string dict keys
            if not util.pypy and not util.jython:
                locals()[42] = 99  # Don't remove this line!

            def __init__(self, **kwargs):
                # Each keyword goes through __setattr__ below.
                for k in kwargs:
                    setattr(self, k, kwargs[k])

            def __getattr__(self, key):
                # Only called when normal lookup fails.  Instrumented keys
                # are read via the attributes API; other keys come from
                # ``_goofy_dict``, and a missing key raises AttributeError.
                if is_instrumented(self, key):
                    return get_attribute(self, key)
                else:
                    try:
                        return self._goofy_dict[key]
                    except KeyError:
                        raise AttributeError(key)

            def __setattr__(self, key, value):
                # Instrumented keys go through the attributes API; all
                # other keys are stored directly in ``_goofy_dict``.
                if is_instrumented(self, key):
                    set_attribute(self, key, value)
                else:
                    self._goofy_dict[key] = value

            def __hasattr__(self, key):
                # NOTE(review): the built-in hasattr() does not call a
                # ``__hasattr__`` method; this only runs if invoked directly.
                if is_instrumented(self, key):
                    return True
                else:
                    return key in self._goofy_dict

            def __delattr__(self, key):
                # Mirrors __setattr__: the attributes API for instrumented
                # keys, ``_goofy_dict`` otherwise.
                if is_instrumented(self, key):
                    del_attribute(self, key)
                else:
                    del self._goofy_dict[key]

    def teardown(self):
        # Per-test teardown hook: reset mapper configuration.
        clear_mappers()

    def test_instance_dict(self):
        # With MyTypesManager, attribute values land in ``_goofy_dict``
        # and the state in ``_my_state``.  Those are the only two keys
        # expected in the real ``__dict__``.
        class User(MyClass):
            pass

        register_class(User)
        attributes.register_attribute(
            User, 'user_id', uselist=False, useobject=False)
        attributes.register_attribute(
            User, 'user_name', uselist=False, useobject=False)
        attributes.register_attribute(
            User, 'email_address', uselist=False, useobject=False)

        u = User()
        u.user_id = 7
        u.user_name = 'john'
        u.email_address = 'lala@123.com'

        eq_(
            u.__dict__,
            {
                '_my_state': u._my_state,
                '_goofy_dict': {
                    'user_id': 7, 'user_name': 'john',
                    'email_address': 'lala@123.com'}}
        )

    def test_basic(self):
        # Scalar get/set must behave the same before and after a commit,
        # for every instrumentation flavour.
        for base in (object, MyBaseClass, MyClass):
            class User(base):
                pass

            register_class(User)
            attributes.register_attribute(
                User, 'user_id', uselist=False, useobject=False)
            attributes.register_attribute(
                User, 'user_name', uselist=False, useobject=False)
            attributes.register_attribute(
                User, 'email_address', uselist=False, useobject=False)

            u = User()
            u.user_id = 7
            u.user_name = 'john'
            u.email_address = 'lala@123.com'

            eq_(u.user_id, 7)
            eq_(u.user_name, "john")
            eq_(u.email_address, "lala@123.com")
            attributes.instance_state(u)._commit_all(
                attributes.instance_dict(u))
            eq_(u.user_id, 7)
            eq_(u.user_name, "john")
            eq_(u.email_address, "lala@123.com")

            u.user_name = 'heythere'
            u.email_address = 'foo@bar.com'
            eq_(u.user_id, 7)
            eq_(u.user_name, "heythere")
            eq_(u.email_address, "foo@bar.com")

    def test_deferred(self):
        # After an expire, reads go through the manager's
        # deferred_scalar_loader, here a loader that copies from ``data``.
        for base in (object, MyBaseClass, MyClass):
            class Foo(base):
                pass

            data = {'a': 'this is a', 'b': 12}

            def loader(state, keys):
                for k in keys:
                    state.dict[k] = data[k]
                return attributes.ATTR_WAS_SET

            manager = register_class(Foo)
            manager.deferred_scalar_loader = loader
            attributes.register_attribute(
                Foo, 'a', uselist=False, useobject=False)
            attributes.register_attribute(
                Foo, 'b', uselist=False, useobject=False)

            # Only subclasses of the hooked bases appear in the extended
            # factory's ``_state_finders``.
            if base is object:
                assert Foo not in \
                    instrumentation._instrumentation_factory._state_finders
            else:
                assert Foo in \
                    instrumentation._instrumentation_factory._state_finders

            f = Foo()
            # Expire, then read: values come from ``loader``.
            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set())
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            # A local change made before an expire is discarded.
            f.a = "this is some new a"
            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set())
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            # A value set after the expire is kept for ``a``.
            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set())
            f.a = "this is another new a"
            eq_(f.a, "this is another new a")
            eq_(f.b, 12)

            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set())
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            # After ``del`` the attribute reads as None, both before and
            # after a commit.
            del f.a
            eq_(f.a, None)
            eq_(f.b, 12)

            attributes.instance_state(f)._commit_all(
                attributes.instance_dict(f))
            eq_(f.a, None)
            eq_(f.b, 12)

    def test_inheritance(self):
        """tests that attributes are polymorphic"""
        for base in (object, MyBaseClass, MyClass):
            class Foo(base):
                pass

            class Bar(Foo):
                pass

            register_class(Foo)
            register_class(Bar)

            # Callables that supply the attribute values.  Bar re-registers
            # 'element' with its own callable and inherits 'element2'.
            def func1(state, passive):
                return "this is the foo attr"

            def func2(state, passive):
                return "this is the bar attr"

            def func3(state, passive):
                return "this is the shared attr"

            attributes.register_attribute(Foo, 'element',
                                          uselist=False, callable_=func1,
                                          useobject=True)
            attributes.register_attribute(Foo, 'element2',
                                          uselist=False, callable_=func3,
                                          useobject=True)
            attributes.register_attribute(Bar, 'element',
                                          uselist=False, callable_=func2,
                                          useobject=True)

            x = Foo()
            y = Bar()
            assert x.element == 'this is the foo attr'
            assert y.element == 'this is the bar attr', y.element
            assert x.element2 == 'this is the shared attr'
            assert y.element2 == 'this is the shared attr'

    def test_collection_with_backref(self):
        # Blog.posts <-> Post.blog are backrefs of each other: changing
        # either side is reflected on the other.
        for base in (object, MyBaseClass, MyClass):
            class Post(base):
                pass

            class Blog(base):
                pass

            register_class(Post)
            register_class(Blog)
            attributes.register_attribute(
                Post, 'blog', uselist=False,
                backref='posts', trackparent=True, useobject=True)
            attributes.register_attribute(
                Blog, 'posts', uselist=True,
                backref='blog', trackparent=True, useobject=True)

            b = Blog()
            (p1, p2, p3) = (Post(), Post(), Post())
            b.posts.append(p1)
            b.posts.append(p2)
            b.posts.append(p3)
            self.assert_(b.posts == [p1, p2, p3])
            self.assert_(p2.blog is b)

            # Clearing the scalar side removes the post from the list.
            p3.blog = None
            self.assert_(b.posts == [p1, p2])

            # Setting the scalar side appends to the list...
            p4 = Post()
            p4.blog = b
            self.assert_(b.posts == [p1, p2, p4])

            # ...and re-assigning the same parent does not duplicate it.
            p4.blog = b
            p4.blog = b
            self.assert_(b.posts == [p1, p2, p4])

            # assert no failure removing None
            p5 = Post()
            p5.blog = None
            del p5.blog

    def test_history(self):
        # get_state_history results are compared as
        # (added, unchanged, deleted) triples.
        for base in (object, MyBaseClass, MyClass):
            class Foo(base):
                pass

            class Bar(base):
                pass

            register_class(Foo)
            register_class(Bar)
            attributes.register_attribute(
                Foo, "name", uselist=False, useobject=False)
            attributes.register_attribute(
                Foo, "bars", uselist=True, trackparent=True, useobject=True)
            attributes.register_attribute(
                Bar, "name", uselist=False, useobject=False)

            # Fresh values show up as "added".
            f1 = Foo()
            f1.name = 'f1'

            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'name'),
                (['f1'], (), ()))

            b1 = Bar()
            b1.name = 'b1'
            f1.bars.append(b1)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'bars'),
                ([b1], [], []))

            # After commit, current values become "unchanged".
            attributes.instance_state(f1)._commit_all(
                attributes.instance_dict(f1))
            attributes.instance_state(b1)._commit_all(
                attributes.instance_dict(b1))

            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1),
                    'name'),
                ((), ['f1'], ()))
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1),
                    'bars'),
                ((), [b1], ()))

            # Replacing a scalar records the old value as "deleted";
            # appending records only an "added" item.
            f1.name = 'f1mod'
            b2 = Bar()
            b2.name = 'b2'
            f1.bars.append(b2)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'name'),
                (['f1mod'], (), ['f1']))
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'bars'),
                ([b2], [b1], []))

            # Removing a committed member moves it to "deleted".
            f1.bars.remove(b1)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'bars'),
                ([b2], [], [b1]))

    def test_null_instrumentation(self):
        # With the stock InstrumentationManager, the class attributes are
        # the same objects the class manager holds for those keys.
        class Foo(MyBaseClass):
            pass

        register_class(Foo)
        attributes.register_attribute(
            Foo, "name", uselist=False, useobject=False)
        attributes.register_attribute(
            Foo, "bars", uselist=True, trackparent=True, useobject=True)

        assert Foo.name == attributes.manager_of_class(Foo)['name']
        assert Foo.bars == attributes.manager_of_class(Foo)['bars']

    def test_alternate_finders(self):
        """Ensure the generic finder front-end deals with edge cases."""

        class Unknown(object):
            pass

        class Known(MyBaseClass):
            pass

        register_class(Known)
        k, u = Known(), Unknown()

        # Unregistered classes and None have no manager.
        assert instrumentation.manager_of_class(Unknown) is None
        assert instrumentation.manager_of_class(Known) is not None
        assert instrumentation.manager_of_class(None) is None

        # State lookup on untracked objects raises instead of returning.
        assert attributes.instance_state(k) is not None
        assert_raises((AttributeError, KeyError),
                      attributes.instance_state, u)
        assert_raises((AttributeError, KeyError),
                      attributes.instance_state, None)

    def test_unmapped_not_type_error(self):
        """extension version of the same test in test_mapper.

        fixes #3408
        """
        assert_raises_message(
            sa.exc.ArgumentError,
            "Class object expected, got '5'.",
            class_mapper, 5
        )

    def test_unmapped_not_type_error_iter_ok(self):
        """extension version of the same test in test_mapper.

        fixes #3408
        """
        assert_raises_message(
            sa.exc.ArgumentError,
            r"Class object expected, got '\(5, 6\)'.",
            class_mapper, (5, 6)
        )
class FinderTest(_ExtBase, fixtures.ORMTest):
    """Which manager type ``register_class`` ends up using for a class."""

    def test_standard(self):
        # No hook and no custom finder: a plain ClassManager.
        class A(object):
            pass

        register_class(A)
        manager_type = type(manager_of_class(A))
        eq_(manager_type, instrumentation.ClassManager)

    def test_nativeext_interfaceexact(self):
        # The stock InstrumentationManager hook yields something other
        # than a plain ClassManager.
        class A(object):
            __sa_instrumentation_manager__ = (
                instrumentation.InstrumentationManager)

        register_class(A)
        manager_type = type(manager_of_class(A))
        ne_(manager_type, instrumentation.ClassManager)

    def test_nativeext_submanager(self):
        # A ClassManager subclass named by the hook is used as-is.
        class Mine(instrumentation.ClassManager):
            pass

        class A(object):
            __sa_instrumentation_manager__ = Mine

        register_class(A)
        manager_type = type(manager_of_class(A))
        eq_(manager_type, Mine)

    @modifies_instrumentation_finders
    def test_customfinder_greedy(self):
        # A finder placed first that claims every class wins.
        class Mine(instrumentation.ClassManager):
            pass

        class A(object):
            pass

        def claim_everything(cls):
            return Mine

        instrumentation.instrumentation_finders.insert(0, claim_everything)
        register_class(A)
        manager_type = type(manager_of_class(A))
        eq_(manager_type, Mine)

    @modifies_instrumentation_finders
    def test_customfinder_pass(self):
        # A finder returning None defers to the remaining lookups.
        class A(object):
            pass

        def decline(cls):
            return None

        instrumentation.instrumentation_finders.insert(0, decline)
        register_class(A)
        manager_type = type(manager_of_class(A))
        eq_(manager_type, instrumentation.ClassManager)
class InstrumentationCollisionTest(_ExtBase, fixtures.ORMTest):
    """Conflicting instrumentation factories within one class hierarchy.

    Unrelated classes may use different factories (``test_none``).
    ``register_class`` must raise ``TypeError`` ("multiple instrumentation
    implementations") when inheritance-related classes would end up with
    different factories.  This includes siblings under a common base, as
    the diamond tests show.
    """

    @staticmethod
    def _mgr_factory(cls):
        """Shared ``__sa_instrumentation_manager__`` hook for these tests.

        Builds a plain ``ClassManager`` for ``cls``.  Because it is a
        callable distinct from ``ClassManager`` itself, it counts as a
        separate factory.
        """
        return instrumentation.ClassManager(cls)

    def test_none(self):
        # Unrelated classes, with or without hooks, register cleanly.
        class A(object):
            pass

        register_class(A)

        class B(object):
            __sa_instrumentation_manager__ = staticmethod(self._mgr_factory)

        register_class(B)

        class C(object):
            __sa_instrumentation_manager__ = instrumentation.ClassManager

        register_class(C)

    def test_single_down(self):
        # A is registered without a hook; a subclass bringing its own hook
        # then conflicts.
        class A(object):
            pass

        register_class(A)

        class B(A):
            __sa_instrumentation_manager__ = staticmethod(self._mgr_factory)

        assert_raises_message(
            TypeError, "multiple instrumentation implementations",
            register_class, B)

    def test_single_up(self):
        # The hooked subclass is registered first; registering the
        # hook-less base afterwards conflicts.
        class A(object):
            pass

        # delay registration

        class B(A):
            __sa_instrumentation_manager__ = staticmethod(self._mgr_factory)

        register_class(B)
        assert_raises_message(
            TypeError, "multiple instrumentation implementations",
            register_class, A)

    def test_diamond_b1(self):
        # Nothing registered yet: B1's sibling B2 carries a hook, so
        # registering B1 conflicts.
        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(self._mgr_factory)

        assert_raises_message(
            TypeError, "multiple instrumentation implementations",
            register_class, B1)

    def test_diamond_b2(self):
        # Same as above, but the hooked sibling B2 is registered first.
        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(self._mgr_factory)

        register_class(B2)
        assert_raises_message(
            TypeError, "multiple instrumentation implementations",
            register_class, B1)

    def test_diamond_c_b(self):
        # Registering an unrelated class C first does not prevent the
        # B1/B2 conflict from being detected.
        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(self._mgr_factory)

        class C(object):
            pass

        register_class(C)
        assert_raises_message(
            TypeError, "multiple instrumentation implementations",
            register_class, B1)
class ExtendedEventsTest(_ExtBase, fixtures.ORMTest):
    """Allow custom Events implementations.

    A ClassManager subclass can declare its own ``dispatch`` built from an
    ``InstanceEvents`` subclass; managers created for it use those events.
    """

    @modifies_instrumentation_finders
    def test_subclassed(self):
        class MyEvents(events.InstanceEvents):
            pass

        class MyClassManager(instrumentation.ClassManager):
            dispatch = event.dispatcher(MyEvents)

        # Route every class to MyClassManager by putting this finder first.
        def use_my_manager(cls):
            return MyClassManager

        instrumentation.instrumentation_finders.insert(0, use_my_manager)

        class A(object):
            pass

        register_class(A)
        dispatch = instrumentation.manager_of_class(A).dispatch
        assert issubclass(dispatch._events, MyEvents)