mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-06-01 13:28:30 -04:00
- removed enhance_classes from scoped_session, replaced with
scoped_session(...).mapper. 'mapper' essentially does the same thing as assign_mapper less verbosely. - adapted assignmapper unit tests into scoped_session tests
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from sqlalchemy.util import ScopedRegistry, warn_deprecated
|
||||
from sqlalchemy.util import ScopedRegistry, warn_deprecated, to_list
|
||||
from sqlalchemy.orm import MapperExtension, EXT_CONTINUE
|
||||
from sqlalchemy.orm.session import Session
|
||||
from sqlalchemy.orm.mapper import global_extensions
|
||||
@@ -13,16 +13,21 @@ class ScopedSession(object):
|
||||
|
||||
Usage::
|
||||
|
||||
Session = scoped_session(sessionmaker(autoflush=True), enhance_classes=True)
|
||||
Session = scoped_session(sessionmaker(autoflush=True))
|
||||
|
||||
To map classes so that new instances are saved in the current
|
||||
Session automatically, as well as to provide session-aware
|
||||
class attributes such as "query":
|
||||
|
||||
mapper = Session.mapper
|
||||
mapper(Class, table, ...)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, session_factory, scopefunc=None, enhance_classes=False):
|
||||
def __init__(self, session_factory, scopefunc=None):
|
||||
self.session_factory = session_factory
|
||||
self.enhance_classes = enhance_classes
|
||||
self.registry = ScopedRegistry(session_factory, scopefunc)
|
||||
if self.enhance_classes:
|
||||
global_extensions.append(_ScopedExt(self))
|
||||
self.extension = _ScopedExt(self)
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
if kwargs:
|
||||
@@ -39,15 +44,28 @@ class ScopedSession(object):
|
||||
else:
|
||||
return self.registry()
|
||||
|
||||
def mapper(self, *args, **kwargs):
|
||||
"""return a mapper() function which associates this ScopedSession with the Mapper."""
|
||||
|
||||
from sqlalchemy.orm import mapper
|
||||
validate = kwargs.pop('validate', False)
|
||||
extension = to_list(kwargs.setdefault('extension', []))
|
||||
if validate:
|
||||
extension.append(self.extension.validating())
|
||||
else:
|
||||
extension.append(self.extension)
|
||||
return mapper(*args, **kwargs)
|
||||
|
||||
def configure(self, **kwargs):
|
||||
"""reconfigure the sessionmaker used by this SessionContext"""
|
||||
"""reconfigure the sessionmaker used by this ScopedSession."""
|
||||
|
||||
self.session_factory.configure(**kwargs)
|
||||
|
||||
def instrument(name):
|
||||
def do(self, *args, **kwargs):
|
||||
return getattr(self.registry(), name)(*args, **kwargs)
|
||||
return do
|
||||
for meth in ('get', 'close', 'save', 'commit', 'update', 'flush', 'query', 'delete'):
|
||||
for meth in ('get', 'close', 'save', 'commit', 'update', 'flush', 'query', 'delete', 'clear'):
|
||||
setattr(ScopedSession, meth, instrument(meth))
|
||||
|
||||
def makeprop(name):
|
||||
@@ -67,18 +85,22 @@ for prop in ('close_all',):
|
||||
setattr(ScopedSession, prop, clslevel(prop))
|
||||
|
||||
class _ScopedExt(MapperExtension):
|
||||
def __init__(self, context):
|
||||
def __init__(self, context, validate=False):
|
||||
self.context = context
|
||||
self.validate = validate
|
||||
|
||||
def validating(self):
|
||||
return _ScopedExt(self.context, validate=True)
|
||||
|
||||
def get_session(self):
|
||||
return self.context.registry()
|
||||
|
||||
def instrument_class(self, mapper, class_):
|
||||
class query(object):
|
||||
def __getattr__(self, key):
|
||||
return getattr(registry().query(class_), key)
|
||||
def __call__(self):
|
||||
return registry().query(class_)
|
||||
def __getattr__(s, key):
|
||||
return getattr(self.context.registry().query(class_), key)
|
||||
def __call__(s):
|
||||
return self.context.registry().query(class_)
|
||||
|
||||
if not hasattr(class_, 'query'):
|
||||
class_.query = query()
|
||||
@@ -87,9 +109,9 @@ class _ScopedExt(MapperExtension):
|
||||
session = kwargs.pop('_sa_session', self.context.registry())
|
||||
if not isinstance(oldinit, types.MethodType):
|
||||
for key, value in kwargs.items():
|
||||
#if validate:
|
||||
# if not self.mapper.get_property(key, resolve_synonyms=False, raiseerr=False):
|
||||
# raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
|
||||
if self.validate:
|
||||
if not mapper.get_property(key, resolve_synonyms=False, raiseerr=False):
|
||||
raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
|
||||
setattr(instance, key, value)
|
||||
session._save_impl(instance, entity_name=kwargs.pop('_sa_entity_name', None))
|
||||
return EXT_CONTINUE
|
||||
|
||||
+77
-9
@@ -4,7 +4,6 @@ from sqlalchemy.orm import *
|
||||
from testlib import *
|
||||
from testlib.tables import *
|
||||
import testlib.tables as tables
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
class SessionTest(AssertMixin):
|
||||
def setUpAll(self):
|
||||
@@ -98,7 +97,7 @@ class SessionTest(AssertMixin):
|
||||
conn1 = testbase.db.connect()
|
||||
conn2 = testbase.db.connect()
|
||||
|
||||
sess = Session(bind=conn1, transactional=True, autoflush=True)
|
||||
sess = create_session(bind=conn1, transactional=True, autoflush=True)
|
||||
u = User()
|
||||
u.user_name='ed'
|
||||
sess.save(u)
|
||||
@@ -116,7 +115,7 @@ class SessionTest(AssertMixin):
|
||||
mapper(User, users)
|
||||
|
||||
try:
|
||||
sess = Session(transactional=True, autoflush=True)
|
||||
sess = create_session(transactional=True, autoflush=True)
|
||||
u = User()
|
||||
u.user_name='ed'
|
||||
sess.save(u)
|
||||
@@ -137,7 +136,7 @@ class SessionTest(AssertMixin):
|
||||
conn1 = testbase.db.connect()
|
||||
conn2 = testbase.db.connect()
|
||||
|
||||
sess = Session(bind=conn1, transactional=True, autoflush=True)
|
||||
sess = create_session(bind=conn1, transactional=True, autoflush=True)
|
||||
u = User()
|
||||
u.user_name='ed'
|
||||
sess.save(u)
|
||||
@@ -153,7 +152,7 @@ class SessionTest(AssertMixin):
|
||||
'addresses':relation(Address)
|
||||
})
|
||||
|
||||
sess = Session(transactional=True, autoflush=True)
|
||||
sess = create_session(transactional=True, autoflush=True)
|
||||
u = sess.query(User).get(8)
|
||||
newad = Address()
|
||||
newad.email_address == 'something new'
|
||||
@@ -173,7 +172,7 @@ class SessionTest(AssertMixin):
|
||||
mapper(User, users)
|
||||
conn = testbase.db.connect()
|
||||
trans = conn.begin()
|
||||
sess = Session(conn, transactional=True, autoflush=True)
|
||||
sess = create_session(bind=conn, transactional=True, autoflush=True)
|
||||
sess.begin()
|
||||
u = User()
|
||||
sess.save(u)
|
||||
@@ -189,7 +188,7 @@ class SessionTest(AssertMixin):
|
||||
try:
|
||||
conn = testbase.db.connect()
|
||||
trans = conn.begin()
|
||||
sess = Session(conn, transactional=True, autoflush=True)
|
||||
sess = create_session(bind=conn, transactional=True, autoflush=True)
|
||||
u1 = User()
|
||||
sess.save(u1)
|
||||
sess.flush()
|
||||
@@ -217,7 +216,7 @@ class SessionTest(AssertMixin):
|
||||
mapper(Address, addresses)
|
||||
|
||||
engine2 = create_engine(testbase.db.url)
|
||||
sess = Session(transactional=False, autoflush=False, twophase=True)
|
||||
sess = create_session(transactional=False, autoflush=False, twophase=True)
|
||||
sess.bind_mapper(User, testbase.db)
|
||||
sess.bind_mapper(Address, engine2)
|
||||
sess.begin()
|
||||
@@ -234,7 +233,7 @@ class SessionTest(AssertMixin):
|
||||
def test_joined_transaction(self):
|
||||
class User(object):pass
|
||||
mapper(User, users)
|
||||
sess = Session(transactional=True, autoflush=True)
|
||||
sess = create_session(transactional=True, autoflush=True)
|
||||
sess.begin()
|
||||
u = User()
|
||||
sess.save(u)
|
||||
@@ -440,6 +439,75 @@ class SessionTest(AssertMixin):
|
||||
key = s.identity_key(User, row=row, entity_name="en")
|
||||
self._assert_key(key, (User, (1,), "en"))
|
||||
|
||||
class ScopedSessionTest(PersistTest):
|
||||
def setUpAll(self):
|
||||
global metadata, table, table2
|
||||
metadata = MetaData(testbase.db)
|
||||
table = Table('sometable', metadata,
|
||||
Column('id', Integer, primary_key=True),
|
||||
Column('data', String(30)))
|
||||
table2 = Table('someothertable', metadata,
|
||||
Column('id', Integer, primary_key=True),
|
||||
Column('someid', None, ForeignKey('sometable.id'))
|
||||
)
|
||||
metadata.create_all()
|
||||
|
||||
def setUp(self):
|
||||
global SomeObject, SomeOtherObject
|
||||
class SomeObject(object):pass
|
||||
class SomeOtherObject(object):pass
|
||||
|
||||
global Session
|
||||
|
||||
Session = scoped_session(create_session)
|
||||
Session.mapper(SomeObject, table, properties={
|
||||
'options':relation(SomeOtherObject)
|
||||
})
|
||||
Session.mapper(SomeOtherObject, table2)
|
||||
|
||||
s = SomeObject()
|
||||
s.id = 1
|
||||
s.data = 'hello'
|
||||
sso = SomeOtherObject()
|
||||
s.options.append(sso)
|
||||
Session.flush()
|
||||
Session.clear()
|
||||
|
||||
def tearDownAll(self):
|
||||
metadata.drop_all()
|
||||
|
||||
def tearDown(self):
|
||||
for table in metadata.table_iterator(reverse=True):
|
||||
table.delete().execute()
|
||||
clear_mappers()
|
||||
|
||||
def test_query(self):
|
||||
sso = SomeOtherObject.query().first()
|
||||
assert SomeObject.query.filter_by(id=1).one().options[0].id == sso.id
|
||||
|
||||
def test_validating_constructor(self):
|
||||
s2 = SomeObject(someid=12)
|
||||
s3 = SomeOtherObject(someid=123, bogus=345)
|
||||
|
||||
class ValidatedOtherObject(object):pass
|
||||
Session.mapper(ValidatedOtherObject, table2, validate=True)
|
||||
|
||||
v1 = ValidatedOtherObject(someid=12)
|
||||
try:
|
||||
v2 = ValidatedOtherObject(someid=12, bogus=345)
|
||||
assert False
|
||||
except exceptions.ArgumentError:
|
||||
pass
|
||||
|
||||
def test_dont_clobber_methods(self):
|
||||
class MyClass(object):
|
||||
def expunge(self):
|
||||
return "an expunge !"
|
||||
|
||||
Session.mapper(MyClass, table2)
|
||||
|
||||
assert MyClass().expunge() == "an expunge !"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
testbase.main()
|
||||
|
||||
@@ -12,8 +12,9 @@ from testlib import tables
|
||||
|
||||
class UnitOfWorkTest(AssertMixin):
|
||||
def setUpAll(self):
|
||||
global Session
|
||||
Session = scoped_session(sessionmaker(autoflush=True, transactional=True), enhance_classes=True)
|
||||
global Session, mapper
|
||||
Session = scoped_session(sessionmaker(autoflush=True, transactional=True))
|
||||
mapper = Session.mapper
|
||||
def tearDownAll(self):
|
||||
global_extensions[:] = []
|
||||
def tearDown(self):
|
||||
|
||||
Reference in New Issue
Block a user