- removed enhance_classes from scoped_session, replaced with

scoped_session(...).mapper.  'mapper' essentially does the same
thing as assign_mapper less verbosely.
- adapted assignmapper unit tests into scoped_session tests
This commit is contained in:
Mike Bayer
2007-08-03 19:31:38 +00:00
parent fdc58f4141
commit e7c83bb371
3 changed files with 118 additions and 27 deletions
+38 -16
View File
@@ -1,4 +1,4 @@
from sqlalchemy.util import ScopedRegistry, warn_deprecated
from sqlalchemy.util import ScopedRegistry, warn_deprecated, to_list
from sqlalchemy.orm import MapperExtension, EXT_CONTINUE
from sqlalchemy.orm.session import Session
from sqlalchemy.orm.mapper import global_extensions
@@ -13,16 +13,21 @@ class ScopedSession(object):
Usage::
Session = scoped_session(sessionmaker(autoflush=True), enhance_classes=True)
Session = scoped_session(sessionmaker(autoflush=True))
To map classes so that new instances are saved in the current
Session automatically, as well as to provide session-aware
class attributes such as "query":
mapper = Session.mapper
mapper(Class, table, ...)
"""
def __init__(self, session_factory, scopefunc=None, enhance_classes=False):
def __init__(self, session_factory, scopefunc=None):
self.session_factory = session_factory
self.enhance_classes = enhance_classes
self.registry = ScopedRegistry(session_factory, scopefunc)
if self.enhance_classes:
global_extensions.append(_ScopedExt(self))
self.extension = _ScopedExt(self)
def __call__(self, **kwargs):
if kwargs:
@@ -39,15 +44,28 @@ class ScopedSession(object):
else:
return self.registry()
def mapper(self, *args, **kwargs):
"""return a mapper() function which associates this ScopedSession with the Mapper."""
from sqlalchemy.orm import mapper
validate = kwargs.pop('validate', False)
extension = to_list(kwargs.setdefault('extension', []))
if validate:
extension.append(self.extension.validating())
else:
extension.append(self.extension)
return mapper(*args, **kwargs)
def configure(self, **kwargs):
"""reconfigure the sessionmaker used by this SessionContext"""
"""reconfigure the sessionmaker used by this ScopedSession."""
self.session_factory.configure(**kwargs)
def instrument(name):
def do(self, *args, **kwargs):
return getattr(self.registry(), name)(*args, **kwargs)
return do
for meth in ('get', 'close', 'save', 'commit', 'update', 'flush', 'query', 'delete'):
for meth in ('get', 'close', 'save', 'commit', 'update', 'flush', 'query', 'delete', 'clear'):
setattr(ScopedSession, meth, instrument(meth))
def makeprop(name):
@@ -67,18 +85,22 @@ for prop in ('close_all',):
setattr(ScopedSession, prop, clslevel(prop))
class _ScopedExt(MapperExtension):
def __init__(self, context):
def __init__(self, context, validate=False):
self.context = context
self.validate = validate
def validating(self):
return _ScopedExt(self.context, validate=True)
def get_session(self):
return self.context.registry()
def instrument_class(self, mapper, class_):
class query(object):
def __getattr__(self, key):
return getattr(registry().query(class_), key)
def __call__(self):
return registry().query(class_)
def __getattr__(s, key):
return getattr(self.context.registry().query(class_), key)
def __call__(s):
return self.context.registry().query(class_)
if not hasattr(class_, 'query'):
class_.query = query()
@@ -87,9 +109,9 @@ class _ScopedExt(MapperExtension):
session = kwargs.pop('_sa_session', self.context.registry())
if not isinstance(oldinit, types.MethodType):
for key, value in kwargs.items():
#if validate:
# if not self.mapper.get_property(key, resolve_synonyms=False, raiseerr=False):
# raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
if self.validate:
if not mapper.get_property(key, resolve_synonyms=False, raiseerr=False):
raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
setattr(instance, key, value)
session._save_impl(instance, entity_name=kwargs.pop('_sa_entity_name', None))
return EXT_CONTINUE
+77 -9
View File
@@ -4,7 +4,6 @@ from sqlalchemy.orm import *
from testlib import *
from testlib.tables import *
import testlib.tables as tables
from sqlalchemy.orm.session import Session
class SessionTest(AssertMixin):
def setUpAll(self):
@@ -98,7 +97,7 @@ class SessionTest(AssertMixin):
conn1 = testbase.db.connect()
conn2 = testbase.db.connect()
sess = Session(bind=conn1, transactional=True, autoflush=True)
sess = create_session(bind=conn1, transactional=True, autoflush=True)
u = User()
u.user_name='ed'
sess.save(u)
@@ -116,7 +115,7 @@ class SessionTest(AssertMixin):
mapper(User, users)
try:
sess = Session(transactional=True, autoflush=True)
sess = create_session(transactional=True, autoflush=True)
u = User()
u.user_name='ed'
sess.save(u)
@@ -137,7 +136,7 @@ class SessionTest(AssertMixin):
conn1 = testbase.db.connect()
conn2 = testbase.db.connect()
sess = Session(bind=conn1, transactional=True, autoflush=True)
sess = create_session(bind=conn1, transactional=True, autoflush=True)
u = User()
u.user_name='ed'
sess.save(u)
@@ -153,7 +152,7 @@ class SessionTest(AssertMixin):
'addresses':relation(Address)
})
sess = Session(transactional=True, autoflush=True)
sess = create_session(transactional=True, autoflush=True)
u = sess.query(User).get(8)
newad = Address()
newad.email_address == 'something new'
@@ -173,7 +172,7 @@ class SessionTest(AssertMixin):
mapper(User, users)
conn = testbase.db.connect()
trans = conn.begin()
sess = Session(conn, transactional=True, autoflush=True)
sess = create_session(bind=conn, transactional=True, autoflush=True)
sess.begin()
u = User()
sess.save(u)
@@ -189,7 +188,7 @@ class SessionTest(AssertMixin):
try:
conn = testbase.db.connect()
trans = conn.begin()
sess = Session(conn, transactional=True, autoflush=True)
sess = create_session(bind=conn, transactional=True, autoflush=True)
u1 = User()
sess.save(u1)
sess.flush()
@@ -217,7 +216,7 @@ class SessionTest(AssertMixin):
mapper(Address, addresses)
engine2 = create_engine(testbase.db.url)
sess = Session(transactional=False, autoflush=False, twophase=True)
sess = create_session(transactional=False, autoflush=False, twophase=True)
sess.bind_mapper(User, testbase.db)
sess.bind_mapper(Address, engine2)
sess.begin()
@@ -234,7 +233,7 @@ class SessionTest(AssertMixin):
def test_joined_transaction(self):
class User(object):pass
mapper(User, users)
sess = Session(transactional=True, autoflush=True)
sess = create_session(transactional=True, autoflush=True)
sess.begin()
u = User()
sess.save(u)
@@ -440,6 +439,75 @@ class SessionTest(AssertMixin):
key = s.identity_key(User, row=row, entity_name="en")
self._assert_key(key, (User, (1,), "en"))
class ScopedSessionTest(PersistTest):
def setUpAll(self):
global metadata, table, table2
metadata = MetaData(testbase.db)
table = Table('sometable', metadata,
Column('id', Integer, primary_key=True),
Column('data', String(30)))
table2 = Table('someothertable', metadata,
Column('id', Integer, primary_key=True),
Column('someid', None, ForeignKey('sometable.id'))
)
metadata.create_all()
def setUp(self):
global SomeObject, SomeOtherObject
class SomeObject(object):pass
class SomeOtherObject(object):pass
global Session
Session = scoped_session(create_session)
Session.mapper(SomeObject, table, properties={
'options':relation(SomeOtherObject)
})
Session.mapper(SomeOtherObject, table2)
s = SomeObject()
s.id = 1
s.data = 'hello'
sso = SomeOtherObject()
s.options.append(sso)
Session.flush()
Session.clear()
def tearDownAll(self):
metadata.drop_all()
def tearDown(self):
for table in metadata.table_iterator(reverse=True):
table.delete().execute()
clear_mappers()
def test_query(self):
sso = SomeOtherObject.query().first()
assert SomeObject.query.filter_by(id=1).one().options[0].id == sso.id
def test_validating_constructor(self):
s2 = SomeObject(someid=12)
s3 = SomeOtherObject(someid=123, bogus=345)
class ValidatedOtherObject(object):pass
Session.mapper(ValidatedOtherObject, table2, validate=True)
v1 = ValidatedOtherObject(someid=12)
try:
v2 = ValidatedOtherObject(someid=12, bogus=345)
assert False
except exceptions.ArgumentError:
pass
def test_dont_clobber_methods(self):
class MyClass(object):
def expunge(self):
return "an expunge !"
Session.mapper(MyClass, table2)
assert MyClass().expunge() == "an expunge !"
if __name__ == "__main__":
testbase.main()
+3 -2
View File
@@ -12,8 +12,9 @@ from testlib import tables
class UnitOfWorkTest(AssertMixin):
def setUpAll(self):
global Session
Session = scoped_session(sessionmaker(autoflush=True, transactional=True), enhance_classes=True)
global Session, mapper
Session = scoped_session(sessionmaker(autoflush=True, transactional=True))
mapper = Session.mapper
def tearDownAll(self):
global_extensions[:] = []
def tearDown(self):