This commit is contained in:
Mike Bayer
2005-10-02 19:50:11 +00:00
parent bb274e941e
commit 24fe959df6
8 changed files with 187 additions and 91 deletions
+67 -17
View File
@@ -54,6 +54,10 @@ class PropHistory(object):
self.obj = obj
self.key = key
self.orig = PropHistory.NONE
def gethistory(self, *args, **kwargs):
return self
def __call__(self, *args, **kwargs):
return self.obj.__dict__[self.key]
def history_contains(self, obj):
return self.orig is obj or self.obj.__dict__[self.key] is obj
def setattr_clean(self, value):
@@ -90,15 +94,30 @@ class PropHistory(object):
class ListElement(util.HistoryArraySet):
"""manages the value of a particular list-based attribute on a particular object instance."""
def __init__(self, obj, key, items = None):
def __init__(self, obj, key, data=None):
self.obj = obj
self.key = key
util.HistoryArraySet.__init__(self, items)
obj.__dict__[key] = self.data
try:
list_ = obj.__dict__[key]
if data is not None:
list_.clear()
for d in data:
list_.append(d)
except KeyError:
if data is not None:
list_ = data
else:
list_ = []
obj.__dict__[key] = []
util.HistoryArraySet.__init__(self, list_)
def gethistory(self, *args, **kwargs):
return self
def __call__(self, *args, **kwargs):
return self
def list_value_changed(self, obj, key, listval):
pass
def setattr(self, value):
self.obj.__dict__[self.key] = value
self.set_data(value)
@@ -115,9 +134,37 @@ class ListElement(util.HistoryArraySet):
self.list_value_changed(self.obj, self.key, self)
return res
class CallableProp(object):
"""allows the attaching of a callable item, representing the future value
of a particular attribute on a particular object instance, to
the AttributeManager. When the attributemanager
accesses the object attribute, either to get its history or its real value, the __call__ method
is invoked which runs the underlying callable_ and sets the new value to the object attribute
via the manager."""
def __init__(self, callable_, obj, key, uselist = False):
self.callable_ = callable_
self.obj = obj
self.key = key
self.uselist = uselist
def gethistory(self, manager, *args, **kwargs):
self.__call__(manager, *args, **kwargs)
return manager.attribute_history[self.obj][self.key]
def __call__(self, manager, passive=False):
if passive:
return None
value = self.callable_()
if self.uselist:
p = manager.create_list(self.obj, self.key, value)
manager.attribute_history[self.obj][self.key] = p
return p
else:
self.obj.__dict__[self.key] = value
p = PropHistory(self.obj, self.key)
manager.attribute_history[self.obj][self.key] = p
return p
class AttributeManager(object):
"""maintains a set of per-attribute history objects for a set of objects."""
"""maintains a set of per-attribute callable/history manager objects for a set of objects."""
def __init__(self):
self.attribute_history = {}
@@ -130,13 +177,13 @@ class AttributeManager(object):
def get_attribute(self, obj, key):
try:
v = obj.__dict__[key]
return self.get_history(obj, key)(self)
except KeyError:
pass
try:
return obj.__dict__[key]
except KeyError:
raise AttributeError(key)
if (callable(v)):
v = v()
obj.__dict__[key] = v
return v
def get_list_attribute(self, obj, key):
return self.get_list_history(obj, key)
@@ -152,6 +199,13 @@ class AttributeManager(object):
self.get_history(obj, key).delattr()
self.value_changed(obj, key, value)
def set_callable(self, obj, key, func, uselist):
try:
d = self.attribute_history[obj]
except KeyError, e:
d = {}
self.attribute_history[obj] = d
d[key] = CallableProp(func, obj, key, uselist)
def delete_list_attribute(self, obj, key):
pass
@@ -190,7 +244,7 @@ class AttributeManager(object):
def get_history(self, obj, key):
try:
return self.attribute_history[obj][key]
return self.attribute_history[obj][key].gethistory(self)
except KeyError, e:
if e.args[0] is obj:
d = {}
@@ -205,14 +259,10 @@ class AttributeManager(object):
def get_list_history(self, obj, key, passive = False):
try:
return self.attribute_history[obj][key]
return self.attribute_history[obj][key].gethistory(self, passive)
except KeyError, e:
# TODO: when an callable is re-set on an existing list element
list_ = obj.__dict__.get(key, None)
if callable(list_):
if passive:
return None
list_ = list_()
if e.args[0] is obj:
d = {}
self.attribute_history[obj] = d
+3 -2
View File
@@ -22,7 +22,7 @@ import sqlalchemy.schema as schema
import sqlalchemy.pool
import sqlalchemy.util as util
import sqlalchemy.sql as sql
import StringIO
import StringIO, sys
import sqlalchemy.types as types
def create_engine(name, *args ,**kwargs):
@@ -61,6 +61,7 @@ class SQLEngine(schema.SchemaEngine):
self.context = util.ThreadLocal()
self.tables = {}
self.notes = {}
self.logger = sys.stdout
def type_descriptor(self, typeobj):
@@ -206,7 +207,7 @@ class SQLEngine(schema.SchemaEngine):
return ResultProxy(c, self.echo, typemap = typemap)
def log(self, msg):
print msg
self.logger.write(msg + "\n")
class ResultProxy:
+19 -36
View File
@@ -694,12 +694,15 @@ class PropertyLoader(MapperProperty):
return (obj2, obj1)
def process_dependencies(self, deplist, uowcommit, delete = False):
#print self.mapper.table.name + " " + repr(deplist.map.values()) + " process_dep isdelete " + repr(delete)
print self.mapper.table.name + " " + repr([str(v) for v in deplist.map.values()]) + " process_dep isdelete " + repr(delete)
# fucntion to set properties across a parent/child object plus an "association row",
# based on a join condition
def sync_foreign_keys(binary):
self._sync_foreign_keys(binary, obj, child, associationrow, clearkeys)
if self.direction == PropertyLoader.RIGHT:
self._sync_foreign_keys(binary, child, obj, associationrow, clearkeys)
else:
self._sync_foreign_keys(binary, obj, child, associationrow, clearkeys)
setter = BinaryVisitor(sync_foreign_keys)
def getlist(obj, passive=True):
@@ -744,8 +747,8 @@ class PropertyLoader(MapperProperty):
if len(secondary_insert):
statement = self.secondary.insert()
statement.execute(*secondary_insert)
elif self.direction == PropertyLoader.LEFT:
if delete and not self.private:
elif self.direction == PropertyLoader.LEFT and delete:
if not self.private:
updates = []
clearkeys = True
for obj in deplist:
@@ -763,33 +766,18 @@ class PropertyLoader(MapperProperty):
values[bind.shortname] = None
statement = self.target.update(self.lazywhere, values = values)
statement.execute(*updates)
else:
for obj in deplist:
childlist = getlist(obj)
if childlist is None: return
uowcommit.register_saved_list(childlist)
clearkeys = False
for child in childlist.added_items():
self.primaryjoin.accept_visitor(setter)
clearkeys = True
for child in childlist.deleted_items():
self.primaryjoin.accept_visitor(setter)
elif self.direction == PropertyLoader.RIGHT:
for child in deplist:
childlist = getlist(child)
else:
for obj in deplist:
childlist = getlist(obj)
if childlist is None: return
uowcommit.register_saved_list(childlist)
clearkeys = False
added = childlist.added_items()
if len(added):
for obj in added:
self.primaryjoin.accept_visitor(setter)
else:
for child in childlist.added_items():
self.primaryjoin.accept_visitor(setter)
if self.direction != PropertyLoader.RIGHT or len(childlist.added_items()) == 0:
clearkeys = True
for obj in childlist.deleted_items():
for child in childlist.deleted_items():
self.primaryjoin.accept_visitor(setter)
else:
raise " no foreign key ?"
#print self.mapper.table.name + " postdep " + repr([str(v) for v in deplist.map.values()]) + " process_dep isdelete " + repr(delete)
@@ -797,6 +785,8 @@ class PropertyLoader(MapperProperty):
"""given a binary clause with an = operator joining two table columns, synchronizes the values
of the corresponding attributes within a parent object and a child object, or the attributes within an
an "association row" that represents an association link between the 'parent' and 'child' object."""
if obj is child:
raise "wha?"
if binary.operator == '=':
if binary.left.table == binary.right.table:
if binary.right is self.foreignkey:
@@ -805,8 +795,9 @@ class PropertyLoader(MapperProperty):
source = binary.right
else:
raise "Cant determine direction for relationship %s = %s" % (binary.left.fullname, binary.right.fullname)
#print "set " + repr(child) + ":" + self.foreignkey.key + " to " + repr(obj) + ":" + source.key
self.mapper._setattrbycolumn(child, self.foreignkey, self.parent._getattrbycolumn(obj, source))
print "set " + repr(id(child)) + child.__dict__['name'] + ":" + self.foreignkey.key + " to " + repr(id(obj)) + obj.__dict__['name'] + ":" + source.key
#+ "\n" + repr(child.__dict__)
else:
colmap = {binary.left.table : binary.left, binary.right.table : binary.right}
if colmap.has_key(self.parent.primarytable) and colmap.has_key(self.target):
@@ -820,18 +811,10 @@ class PropertyLoader(MapperProperty):
elif colmap.has_key(self.target) and colmap.has_key(self.secondary):
associationrow[colmap[self.secondary].key] = self.mapper._getattrbycolumn(child, colmap[self.target])
# TODO: break out the lazywhere capability so that the main PropertyLoader can use it
# to do child deletes
class LazyLoader(PropertyLoader):
def execute(self, instance, row, identitykey, imap, isnew):
if isnew:
# TODO: get lazy callables to be stored within the unit of work?
# allows serializable ? still need lazyload state to exist in the application
# when u deserialize tho
objectstore.uow().attribute_set_callable(instance, self.key, LazyLoadInstance(self, row))
objectstore.uow().register_callable(instance, self.key, LazyLoadInstance(self, row), uselist=self.uselist)
def create_lazy_clause(table, primaryjoin, secondaryjoin, thiscol):
binds = {}
+54 -23
View File
@@ -107,7 +107,9 @@ class UnitOfWork(object):
self.new = util.HashSet(ordered = True)
self.dirty = util.HashSet()
self.modified_lists = util.HashSet()
self.deleted = util.HashSet()
# the delete list is ordered mostly so the unit tests can predict the argument list ordering.
# TODO: need stronger unit test fixtures....
self.deleted = util.HashSet(ordered = True)
self.parent = parent
def get(self, class_, *id):
@@ -136,17 +138,10 @@ class UnitOfWork(object):
def register_attribute(self, class_, key, uselist):
self.attributes.register_attribute(class_, key, uselist)
def attribute_set_callable(self, obj, key, func):
# TODO: gotta work this out when a list element is already there,
# etc.
obj.__dict__[key] = func
try:
del self.attributes.attribute_history[obj][key]
except KeyError:
pass
def register_callable(self, obj, key, func, uselist):
self.attributes.set_callable(obj, key, func, uselist)
def register_clean(self, obj):
try:
del self.dirty[obj]
@@ -405,7 +400,32 @@ class UOWTask(object):
def sort_circular_dependencies(self, trans):
allobjects = self.objects
tuples = []
d = {}
def get_task(obj):
try:
return d[obj]
except KeyError:
t = UOWTask(self.mapper, self.isdelete, self.listonly)
t.taskhash = d
d[obj] = t
return t
dependencies = {}
def get_dependency_task(obj, processor):
try:
dp = dependencies[obj]
except KeyError:
dp = {}
dependencies[obj] = dp
try:
l = dp[processor]
except KeyError:
l = UOWTask(None, None, None)
dp[processor] = l
return l
for obj in self.objects:
parenttask = get_task(obj)
for dep in self.dependencies:
(processor, targettask) = dep
if targettask is self:
@@ -414,29 +434,40 @@ class UOWTask(object):
whosdep = processor.whose_dependent_on_who(obj, o, trans)
if whosdep is not None:
tuples.append(whosdep)
if whosdep[0] is obj:
get_dependency_task(whosdep[0], processor).objects.append(whosdep[0])
else:
get_dependency_task(whosdep[0], processor).objects.append(whosdep[1])
head = TupleSorter(tuples, allobjects).sort()
if head is None:
return None
d = {}
def make_task():
t = UOWTask(self.mapper, self.isdelete, self.listonly)
t.dependencies = self.dependencies
t.taskhash = d
return t
def make_task_tree(node, parenttask):
if node is None:
return
parenttask.objects.append(node.item)
t = make_task()
d[node.item] = t
if dependencies.has_key(node.item):
for processor, deptask in dependencies[node.item].iteritems():
parenttask.dependencies.append((processor, deptask))
t = d[node.item]
for n in node.children:
make_task_tree(n, t)
t = make_task()
t2 = make_task_tree(n, t)
return t
t = UOWTask(self.mapper, self.isdelete, self.listonly)
t.taskhash = d
make_task_tree(head, t)
t._print_circular()
return t
def _print_circular(t):
print "-----------------------------"
print "task objects: " + repr([str(v) for v in t.objects])
print "task depends: " + repr([(dt[0].key, [str(o) for o in dt[1].objects]) for dt in t.dependencies])
for o in t.objects:
t.taskhash[o]._print_circular()
def __str__(self):
if self.isdelete:
+6 -6
View File
@@ -220,12 +220,9 @@ class ClauseElement(object):
def compile(self, engine = None, bindparams = None):
"""compiles this SQL expression using its underlying SQLEngine to produce
a Compiled object. The actual SQL statement is the Compiled object's string representation.
bindparams is an optional dictionary representing the bind parameters to be used with
the statement. Currently, only the compilations of INSERT and UPDATE statements
use the bind parameters, in order to determine which
table columns should be used in the statement."""
a Compiled object. If no engine can be found, an ansisql engine is used.
bindparams is a dictionary representing the default bind parameters to be used with
the statement. """
if engine is None:
for f in self._get_from_objects():
engine = f.engine
@@ -237,6 +234,9 @@ class ClauseElement(object):
return engine.compile(self, bindparams = bindparams)
def __str__(self):
return str(self.compile())
def execute(self, *multiparams, **params):
"""compiles and executes this SQL expression using its underlying SQLEngine.
the given **params are used as bind parameters when compiling and executing the expression.
+19 -3
View File
@@ -1,6 +1,7 @@
from testbase import PersistTest, AssertMixin
import unittest, sys, os
from sqlalchemy.mapper import *
import StringIO
import sqlalchemy.objectstore as objectstore
from tables import *
@@ -207,7 +208,18 @@ class SaveTest(AssertMixin):
objectstore.uow().register_deleted(l[0])
objectstore.uow().register_deleted(l[2])
objectstore.uow().commit()
res = self.capture_exec(db, lambda: objectstore.uow().commit())
state = None
for line in res.split('\n'):
if line == "DELETE FROM items WHERE items.item_id = :item_id":
self.assert_(state is None or state == 'addresses')
elif line == "DELETE FROM orders WHERE orders.order_id = :order_id":
state = 'orders'
elif line == "DELETE FROM email_addresses WHERE email_addresses.address_id = :address_id":
if state is None:
state = 'addresses'
elif line == "DELETE FROM users WHERE users.user_id = :user_id":
self.assert_(state is not None)
def testbackwardsonetoone(self):
# test 'backwards'
@@ -238,8 +250,12 @@ class SaveTest(AssertMixin):
objects[3].user = User()
objects[3].user.user_name = 'imnewlyadded'
objectstore.uow().commit()
return
self.assert_enginesql(db, lambda: objectstore.uow().commit(),
"""INSERT INTO users (user_id, user_name) VALUES (:user_id, :user_name)
{'user_id': None, 'user_name': 'imnewlyadded'}
UPDATE email_addresses SET address_id=:address_id, user_id=:user_id, email_address=:email_address WHERE email_addresses.address_id = :address_id
[{'email_address': 'imnew@foo.bar', 'address_id': 3, 'user_id': 3}, {'email_address': 'adsd5@llala.net', 'address_id': 4, 'user_id': None}]
""")
l = sql.select([users, addresses], sql.and_(users.c.user_id==addresses.c.address_id, addresses.c.address_id==a.address_id)).execute()
self.echo( repr(l.fetchone().row))
+1 -2
View File
@@ -57,7 +57,6 @@ User.mapper = assignmapper(users, properties = dict(
# select
user = User.mapper.select(User.c.user_name == 'fred jones')[0]
print repr(user.__dict__['addresses'])
address = user.addresses[0]
# modify
@@ -129,4 +128,4 @@ user.preferences.stylename = 'bluesteel'
user.addresses.append(Address('freddy@hi.org'))
# commit
objectstore.commit()
objectstore.commit()
+18 -2
View File
@@ -1,4 +1,5 @@
import unittest
import StringIO
echo = True
@@ -8,7 +9,20 @@ class PersistTest(unittest.TestCase):
def echo(self, text):
if echo:
print text
def capture_exec(self, db, callable_):
e = db.echo
b = db.logger
buffer = StringIO.StringIO()
db.logger = buffer
db.echo = True
try:
callable_()
if echo:
print buffer.getvalue()
return buffer.getvalue()
finally:
db.logger = b
db.echo = e
class AssertMixin(PersistTest):
def assert_result(self, result, class_, *objects):
@@ -29,7 +43,9 @@ class AssertMixin(PersistTest):
self.assert_row(value[0], getattr(rowobj, key), value[1])
else:
self.assert_(getattr(rowobj, key) == value, "attribute %s value %s does not match %s" % (key, getattr(rowobj, key), value))
def assert_enginesql(self, db, callable_, result):
self.assert_(self.capture_exec(db, callable_) == result, result)
def runTests(suite):
runner = unittest.TextTestRunner(verbosity = 2, descriptions =1)
runner.run(suite)