mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-28 03:26:01 -04:00
1fcbc17b7d
The :meth:`.Query.update` method can now accommodate both hybrid attributes as well as composite attributes as a source of the key to be placed in the SET clause. For hybrids, an additional decorator :meth:`.hybrid_property.update_expression` is supplied for which the user supplies a tuple-returning function. Change-Id: I15e97b02381d553f30b3301308155e19128d2cfb Fixes: #3229
915 lines
27 KiB
Python
915 lines
27 KiB
Python
from sqlalchemy import func, Integer, Numeric, String, ForeignKey
|
|
from sqlalchemy.orm import relationship, Session, aliased, persistence
|
|
from sqlalchemy.testing.schema import Column
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
from sqlalchemy.ext import hybrid
|
|
from sqlalchemy.testing import eq_, is_, AssertsCompiledSQL, \
|
|
assert_raises_message
|
|
from sqlalchemy.testing import fixtures
|
|
from sqlalchemy import inspect
|
|
from decimal import Decimal
|
|
|
|
|
|
class PropertyComparatorTest(fixtures.TestBase, AssertsCompiledSQL):
    """Exercise a hybrid property whose SQL side is a custom comparator.

    The fixture maps ``A.value`` as a hybrid over the ``value`` column:
    instance-level access offsets the stored value by 5, while class-level
    comparisons are routed through ``UCComparator``.
    """
    __dialect__ = 'default'

    def _fixture(self):
        """Build a fresh declarative class ``A`` and return it."""
        Base = declarative_base()

        class UCComparator(hybrid.Comparator):
            # Compares by upper-casing both sides of the equality.

            def __eq__(self, other):
                if other is None:
                    # ``self.expression is None`` would be a Python
                    # identity test and hand back a plain bool; produce
                    # a real SQL IS NULL construct instead.
                    return self.expression.is_(None)
                return func.upper(self.expression) == func.upper(other)

        class A(Base):
            __tablename__ = 'a'
            id = Column(Integer, primary_key=True)

            _value = Column("value", String)

            # NOTE: the docstring below is asserted by test_docstring;
            # keep it unchanged.
            @hybrid.hybrid_property
            def value(self):
                "This is a docstring"
                return self._value - 5

            @value.comparator
            def value(cls):
                return UCComparator(cls._value)

            @value.setter
            def value(self, v):
                self._value = v + 5

        return A

    def test_set_get(self):
        """The setter stores ``v + 5``; the getter undoes the offset."""
        A = self._fixture()
        a1 = A(value=5)
        eq_(a1._value, 10)
        eq_(a1.value, 5)

    def test_value(self):
        """Class-level equality renders through the custom comparator."""
        A = self._fixture()
        eq_(str(A.value == 5), "upper(a.value) = upper(:upper_1)")

    def test_value_none(self):
        """Comparing against None yields a SQL IS NULL expression."""
        A = self._fixture()
        eq_(str(A.value == None), "a.value IS NULL")  # noqa

    def test_aliased_value(self):
        """The comparator is adapted when the class is aliased."""
        A = self._fixture()
        eq_(str(aliased(A).value == 5), "upper(a_1.value) = upper(:upper_1)")

    def test_query(self):
        """Querying the hybrid selects the underlying column."""
        A = self._fixture()
        sess = Session()
        self.assert_compile(
            sess.query(A.value),
            "SELECT a.value AS a_value FROM a"
        )

    def test_aliased_query(self):
        """Querying the hybrid on an alias selects from the alias."""
        A = self._fixture()
        sess = Session()
        self.assert_compile(
            sess.query(aliased(A).value),
            "SELECT a_1.value AS a_1_value FROM a AS a_1"
        )

    def test_aliased_filter(self):
        """filter_by() on an alias goes through the custom comparator."""
        A = self._fixture()
        sess = Session()
        self.assert_compile(
            sess.query(aliased(A)).filter_by(value="foo"),
            "SELECT a_1.value AS a_1_value, a_1.id AS a_1_id "
            "FROM a AS a_1 WHERE upper(a_1.value) = upper(:upper_1)"
        )

    def test_docstring(self):
        """The getter's docstring is visible on the class attribute."""
        A = self._fixture()
        eq_(A.value.__doc__, "This is a docstring")
|
|
|
|
|
|
class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL):
    """Hybrid properties that supply an explicit class-level expression."""
    __dialect__ = 'default'

    def _fixture(self):
        """Return a mapped class ``A`` carrying two hybrid properties."""
        Base = declarative_base()

        class A(Base):
            __tablename__ = 'a'
            id = Column(Integer, primary_key=True)
            _value = Column("value", String)

            # NOTE: both docstrings on ``value`` are checked by
            # test_docstring and must stay as they are.
            @hybrid.hybrid_property
            def value(self):
                "This is an instance-level docstring"
                return int(self._value) - 5

            @value.expression
            def value(cls):
                "This is a class-level docstring"
                return func.foo(cls._value) + cls.bar_value

            @value.setter
            def value(self, new_value):
                self._value = new_value + 5

            @hybrid.hybrid_property
            def bar_value(cls):
                return func.bar(cls._value)

        return A

    def _relationship_fixture(self):
        """Return ``(A, B)`` where ``B.as_`` is a relationship to ``A``."""
        Base = declarative_base()

        class A(Base):
            __tablename__ = 'a'
            id = Column(Integer, primary_key=True)
            b_id = Column('bid', Integer, ForeignKey('b.id'))
            _value = Column("value", String)

            @hybrid.hybrid_property
            def value(self):
                return int(self._value) - 5

            @value.expression
            def value(cls):
                return func.foo(cls._value) + cls.bar_value

            @value.setter
            def value(self, new_value):
                self._value = new_value + 5

            @hybrid.hybrid_property
            def bar_value(cls):
                return func.bar(cls._value)

        class B(Base):
            __tablename__ = 'b'
            id = Column(Integer, primary_key=True)

            as_ = relationship("A")

        return A, B

    def test_info(self):
        """The hybrid's ``info`` dict is writable via the inspector."""
        A = self._fixture()
        descriptor_info = inspect(A).all_orm_descriptors.value.info
        descriptor_info["some key"] = "some value"
        # look the descriptor up again to verify the entry persisted
        eq_(
            inspect(A).all_orm_descriptors.value.info,
            {"some key": "some value"}
        )

    def test_set_get(self):
        """Instance-level set/get round trips through the +5/-5 offset."""
        A = self._fixture()
        obj = A(value=5)
        eq_(obj._value, 10)
        eq_(obj.value, 5)

    def test_expression(self):
        """The class-level attribute renders the custom expression."""
        A = self._fixture()
        clause = A.value.__clause_element__()
        self.assert_compile(clause, "foo(a.value) + bar(a.value)")

    def test_any(self):
        """relationship.any() accepts the hybrid as a keyword criterion."""
        A, B = self._relationship_fixture()
        session = Session()
        expected_sql = (
            "SELECT b.id AS b_id FROM b WHERE EXISTS "
            "(SELECT 1 FROM a WHERE b.id = a.bid "
            "AND foo(a.value) + bar(a.value) = :param_1)"
        )
        self.assert_compile(
            session.query(B).filter(B.as_.any(value=5)),
            expected_sql
        )

    def test_aliased_expression(self):
        """The expression is adapted to the alias of the class."""
        A = self._fixture()
        clause = aliased(A).value.__clause_element__()
        self.assert_compile(clause, "foo(a_1.value) + bar(a_1.value)")

    def test_query(self):
        """filter_by() on the hybrid uses the class-level expression."""
        A = self._fixture()
        session = Session()
        expected_sql = (
            "SELECT a.value AS a_value, a.id AS a_id "
            "FROM a WHERE foo(a.value) + bar(a.value) = :param_1"
        )
        self.assert_compile(
            session.query(A).filter_by(value="foo"),
            expected_sql
        )

    def test_aliased_query(self):
        """filter_by() against an alias adapts the expression."""
        A = self._fixture()
        session = Session()
        expected_sql = (
            "SELECT a_1.value AS a_1_value, a_1.id AS a_1_id "
            "FROM a AS a_1 WHERE foo(a_1.value) + bar(a_1.value) = :param_1"
        )
        self.assert_compile(
            session.query(aliased(A)).filter_by(value="foo"),
            expected_sql
        )

    def test_docstring(self):
        """The class attribute exposes the expression's docstring."""
        A = self._fixture()
        eq_(A.value.__doc__, "This is a class-level docstring")

        # instance access yields a plain literal, so there is no
        # docstring to check there; just verify the value
        obj = A(_value=10)
        eq_(obj.value, 5)
|
|
|
|
|
|
class PropertyValueTest(fixtures.TestBase, AssertsCompiledSQL):
    """Instance-level set / delete behavior of a hybrid property."""
    __dialect__ = 'default'

    def _fixture(self, assignable):
        """Return mapped class ``A``.

        :param assignable: when true, the ``value`` hybrid also gets a
         setter; otherwise it is getter-only.
        """
        Base = declarative_base()

        class A(Base):
            __tablename__ = 'a'
            id = Column(Integer, primary_key=True)
            _value = Column("value", String)

            @hybrid.hybrid_property
            def value(self):
                return self._value - 5

            # the setter is attached only on request, so the same
            # fixture serves both the read-only and writable tests
            if assignable:
                @value.setter
                def value(self, new_value):
                    self._value = new_value + 5

        return A

    def test_nonassignable(self):
        """Setting a getter-only hybrid raises AttributeError."""
        A = self._fixture(False)
        obj = A(_value=5)
        assert_raises_message(
            AttributeError,
            "can't set attribute",
            setattr, obj, 'value', 10
        )

    def test_nondeletable(self):
        """Deleting a hybrid with no deleter raises AttributeError."""
        A = self._fixture(False)
        obj = A(_value=5)
        assert_raises_message(
            AttributeError,
            "can't delete attribute",
            delattr, obj, 'value'
        )

    def test_set_get(self):
        """With a setter present, set/get round trips via the offset."""
        A = self._fixture(True)
        obj = A(value=5)
        eq_(obj.value, 5)
        eq_(obj._value, 10)
|
|
|
|
|
|
class PropertyOverrideTest(fixtures.TestBase, AssertsCompiledSQL):
    """Subclasses overriding individual parts of an inherited hybrid."""
    __dialect__ = 'default'

    def _fixture(self):
        """Return ``Person`` plus four subclasses.

        The result is the tuple ``(Person, OverrideSetter, OverrideGetter,
        OverrideExpr, OverrideComparator)``; each subclass replaces one
        piece of the ``Person.name`` hybrid.
        """
        Base = declarative_base()

        class Person(Base):
            __tablename__ = 'person'
            id = Column(Integer, primary_key=True)
            _name = Column(String)

            @hybrid.hybrid_property
            def name(self):
                return self._name

            @name.setter
            def name(self, value):
                self._name = value.title()

        class OverrideSetter(Person):
            __tablename__ = 'override_setter'
            id = Column(Integer, ForeignKey('person.id'), primary_key=True)
            other = Column(String)

            @Person.name.setter
            def name(self, value):
                self._name = value.upper()

        class OverrideGetter(Person):
            __tablename__ = 'override_getter'
            id = Column(Integer, ForeignKey('person.id'), primary_key=True)
            other = Column(String)

            @Person.name.getter
            def name(self):
                return "Hello " + self._name

        class OverrideExpr(Person):
            __tablename__ = 'override_expr'
            id = Column(Integer, ForeignKey('person.id'), primary_key=True)
            other = Column(String)

            # the first argument here receives the class
            @Person.name.overrides.expression
            def name(cls):
                return func.concat("Hello", cls._name)

        class FooComparator(hybrid.Comparator):
            def __clause_element__(self):
                target = self.expression
                return func.concat("Hello", target._name)

        class OverrideComparator(Person):
            __tablename__ = 'override_comp'
            id = Column(Integer, ForeignKey('person.id'), primary_key=True)
            other = Column(String)

            # the first argument here receives the class
            @Person.name.overrides.comparator
            def name(cls):
                return FooComparator(cls)

        return (
            Person, OverrideSetter, OverrideGetter,
            OverrideExpr, OverrideComparator
        )

    def test_property(self):
        """The base class title-cases on set and returns it unchanged."""
        Person = self._fixture()[0]
        person = Person()
        person.name = 'mike'
        eq_(person._name, 'Mike')
        eq_(person.name, 'Mike')

    def test_override_setter(self):
        """An overridden setter upper-cases instead of title-casing."""
        OverrideSetter = self._fixture()[1]
        person = OverrideSetter()
        person.name = 'mike'
        eq_(person._name, 'MIKE')
        eq_(person.name, 'MIKE')

    def test_override_getter(self):
        """An overridden getter keeps the inherited setter."""
        OverrideGetter = self._fixture()[2]
        person = OverrideGetter()
        person.name = 'mike'
        eq_(person._name, 'Mike')
        eq_(person.name, 'Hello Mike')

    def test_override_expr(self):
        """Overriding the expression leaves the base class untouched."""
        mapped = self._fixture()
        Person, OverrideExpr = mapped[0], mapped[3]

        base_clause = Person.name.__clause_element__()
        self.assert_compile(base_clause, "person._name")

        override_clause = OverrideExpr.name.__clause_element__()
        self.assert_compile(
            override_clause,
            "concat(:concat_1, person._name)"
        )

    def test_override_comparator(self):
        """Overriding the comparator leaves the base class untouched."""
        mapped = self._fixture()
        Person, OverrideComparator = mapped[0], mapped[4]

        base_clause = Person.name.__clause_element__()
        self.assert_compile(base_clause, "person._name")

        override_clause = OverrideComparator.name.__clause_element__()
        self.assert_compile(
            override_clause,
            "concat(:concat_1, person._name)"
        )
|
|
|
|
|
|
class PropertyMirrorTest(fixtures.TestBase, AssertsCompiledSQL):
    """A hybrid that returns a mapped attribute directly.

    These tests compare ``A.value`` (the hybrid) against ``A._value``
    (the column attribute it returns).
    """
    __dialect__ = 'default'

    def _fixture(self):
        """Return mapped class ``A`` whose ``value`` returns ``_value``."""
        Base = declarative_base()

        class A(Base):
            __tablename__ = 'a'
            id = Column(Integer, primary_key=True)
            _value = Column("value", String)

            @hybrid.hybrid_property
            def value(self):
                "This is an instance-level docstring"
                return self._value

        return A

    def test_property(self):
        """``.property`` is the same object for hybrid and column."""
        A = self._fixture()
        hybrid_attr, column_attr = A.value, A._value

        is_(hybrid_attr.property, column_attr.property)

    def test_key(self):
        """Each attribute reports its own key."""
        A = self._fixture()
        eq_(A.value.key, "value")
        eq_(A._value.key, "_value")

    def test_class(self):
        """``.class_`` is the same for hybrid and column."""
        A = self._fixture()
        hybrid_attr, column_attr = A.value, A._value
        is_(hybrid_attr.class_, column_attr.class_)

    def test_get_history(self):
        """get_history() gives equal results through either attribute."""
        A = self._fixture()
        obj = A(_value=5)
        hybrid_history = A.value.get_history(obj)
        column_history = A._value.get_history(obj)
        eq_(hybrid_history, column_history)

    def test_info_not_mirrored(self):
        """The two ``info`` dictionaries are kept separate."""
        A = self._fixture()
        A._value.info['foo'] = 'bar'
        A.value.info['bar'] = 'hoho'

        eq_(A._value.info, {'foo': 'bar'})
        eq_(A.value.info, {'bar': 'hoho'})

    def test_info_from_hybrid(self):
        """The inspector's descriptor shares ``info`` with ``A.value``."""
        A = self._fixture()
        A._value.info['foo'] = 'bar'
        A.value.info['bar'] = 'hoho'

        mapper_insp = inspect(A)
        descriptor = mapper_insp.all_orm_descriptors['value']
        is_(descriptor.info, A.value.info)
|
|
|
|
|
|
class MethodExpressionTest(fixtures.TestBase, AssertsCompiledSQL):
    """Hybrid methods with separate instance and class implementations."""
    __dialect__ = 'default'

    def _fixture(self):
        """Return mapped class ``A`` with two hybrid methods.

        ``value`` documents both its instance and expression functions;
        ``other_value`` documents only the instance function.
        """
        Base = declarative_base()

        class A(Base):
            __tablename__ = 'a'
            id = Column(Integer, primary_key=True)
            _value = Column("value", String)

            # NOTE: the presence and absence of docstrings on the
            # functions below is what test_docstring verifies.
            @hybrid.hybrid_method
            def value(self, x):
                "This is an instance-level docstring"
                return int(self._value) + x

            @value.expression
            def value(cls, value):
                "This is a class-level docstring"
                return func.foo(cls._value, value) + value

            @hybrid.hybrid_method
            def other_value(self, x):
                "This is an instance-level docstring"
                return int(self._value) + x

            @other_value.expression
            def other_value(cls, value):
                return func.foo(cls._value, value) + value

        return A

    def test_call(self):
        """Calling on an instance runs the plain Python function."""
        A = self._fixture()
        obj = A(_value=10)
        eq_(obj.value(7), 17)

    def test_expression(self):
        """Calling on the class produces the SQL expression."""
        A = self._fixture()
        self.assert_compile(A.value(5), "foo(a.value, :foo_1) + :foo_2")

    def test_info(self):
        """The hybrid method's ``info`` dict is writable."""
        A = self._fixture()
        descriptor_info = inspect(A).all_orm_descriptors.value.info
        descriptor_info["some key"] = "some value"
        # look the descriptor up again to verify the entry persisted
        eq_(
            inspect(A).all_orm_descriptors.value.info,
            {"some key": "some value"}
        )

    def test_aliased_expression(self):
        """Calling on an alias renders against the aliased table."""
        A = self._fixture()
        self.assert_compile(
            aliased(A).value(5),
            "foo(a_1.value, :foo_1) + :foo_2"
        )

    def test_query(self):
        """The hybrid method can be used as filter criteria."""
        A = self._fixture()
        session = Session()
        expected_sql = (
            "SELECT a.value AS a_value, a.id AS a_id "
            "FROM a WHERE foo(a.value, :foo_1) + :foo_2 = :param_1"
        )
        self.assert_compile(
            session.query(A).filter(A.value(5) == "foo"),
            expected_sql
        )

    def test_aliased_query(self):
        """Filter criteria from an alias render against the alias."""
        A = self._fixture()
        session = Session()
        a_alias = aliased(A)
        expected_sql = (
            "SELECT a_1.value AS a_1_value, a_1.id AS a_1_id "
            "FROM a AS a_1 WHERE foo(a_1.value, :foo_1) + :foo_2 = :param_1"
        )
        self.assert_compile(
            session.query(a_alias).filter(a_alias.value(5) == "foo"),
            expected_sql
        )

    def test_query_col(self):
        """The hybrid method can be selected as a column expression."""
        A = self._fixture()
        session = Session()
        self.assert_compile(
            session.query(A.value(5)),
            "SELECT foo(a.value, :foo_1) + :foo_2 AS anon_1 FROM a"
        )

    def test_aliased_query_col(self):
        """Selecting the hybrid method from an alias uses the alias."""
        A = self._fixture()
        session = Session()
        self.assert_compile(
            session.query(aliased(A).value(5)),
            "SELECT foo(a_1.value, :foo_1) + :foo_2 AS anon_1 FROM a AS a_1"
        )

    def test_docstring(self):
        """Check ``__doc__`` at both class and instance level."""
        A = self._fixture()
        eq_(A.value.__doc__, "This is a class-level docstring")
        eq_(A.other_value.__doc__, "This is an instance-level docstring")
        obj = A(_value=10)

        # on an instance, ``value`` is still a method and therefore
        # carries a docstring
        eq_(obj.value.__doc__, "This is an instance-level docstring")

        eq_(obj.other_value.__doc__, "This is an instance-level docstring")
|
|
|
|
|
|
class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL):
    """Hybrid attributes used as keys of ``Query.update()``."""
    __dialect__ = 'default'

    @classmethod
    def setup_classes(cls):
        """Declare the ``Person`` mapping used by every test."""
        Base = cls.DeclarativeBasic

        class Person(Base):
            __tablename__ = 'person'

            id = Column(Integer, primary_key=True)
            first_name = Column(String(10))
            last_name = Column(String(10))

            @hybrid.hybrid_property
            def name(self):
                return self.first_name + ' ' + self.last_name

            @name.setter
            def name(self, value):
                self.first_name, self.last_name = value.split(' ', 1)

            @name.expression
            def name(cls):
                return func.concat(cls.first_name, ' ', cls.last_name)

            # splits the full name into one (column, value) pair per
            # underlying column
            @name.update_expression
            def name(cls, value):
                first, last = value.split(' ', 1)
                return [(cls.first_name, first), (cls.last_name, last)]

            # hybrid that returns another hybrid which has an
            # update_expression
            @hybrid.hybrid_property
            def uname(self):
                return self.name

            # hybrid that returns a column attribute directly
            @hybrid.hybrid_property
            def fname(self):
                return self.first_name

            # hybrid that returns another plain hybrid
            @hybrid.hybrid_property
            def fname2(self):
                return self.fname

    @classmethod
    def insert_data(cls):
        """Insert the single ``Person`` row (id 3) the tests load."""
        session = Session()
        jill = cls.classes.Person(id=3, first_name='jill')
        session.add(jill)
        session.commit()

    def test_update_plain(self):
        """A hybrid returning a column compiles to a SET on that column."""
        Person = self.classes.Person

        session = Session()
        query = session.query(Person)

        values = {Person.fname: "Dr."}
        bulk_ud = persistence.BulkUpdate.factory(query, False, values, {})

        self.assert_compile(
            bulk_ud,
            "UPDATE person SET first_name=:first_name",
            params={'first_name': 'Dr.'}
        )

    def test_update_expr(self):
        """update_expression expands one key into two SET clauses."""
        Person = self.classes.Person

        session = Session()
        query = session.query(Person)

        values = {Person.name: "Dr. No"}
        bulk_ud = persistence.BulkUpdate.factory(query, False, values, {})

        self.assert_compile(
            bulk_ud,
            "UPDATE person SET first_name=:first_name, last_name=:last_name",
            params={'first_name': 'Dr.', 'last_name': 'No'}
        )

    def test_evaluate_hybrid_attr_indirect(self):
        """'evaluate' sync with a hybrid that returns another hybrid."""
        Person = self.classes.Person

        session = Session()
        jill = session.query(Person).get(3)

        session.query(Person).update(
            {Person.fname2: 'moonbeam'},
            synchronize_session='evaluate')
        eq_(jill.fname2, 'moonbeam')

    def test_evaluate_hybrid_attr_plain(self):
        """'evaluate' sync with a hybrid that returns a column."""
        Person = self.classes.Person

        session = Session()
        jill = session.query(Person).get(3)

        session.query(Person).update(
            {Person.fname: 'moonbeam'},
            synchronize_session='evaluate')
        eq_(jill.fname, 'moonbeam')

    def test_fetch_hybrid_attr_indirect(self):
        """'fetch' sync with a hybrid that returns another hybrid."""
        Person = self.classes.Person

        session = Session()
        jill = session.query(Person).get(3)

        session.query(Person).update(
            {Person.fname2: 'moonbeam'},
            synchronize_session='fetch')
        eq_(jill.fname2, 'moonbeam')

    def test_fetch_hybrid_attr_plain(self):
        """'fetch' sync with a hybrid that returns a column."""
        Person = self.classes.Person

        session = Session()
        jill = session.query(Person).get(3)

        session.query(Person).update(
            {Person.fname: 'moonbeam'},
            synchronize_session='fetch')
        eq_(jill.fname, 'moonbeam')

    def test_evaluate_hybrid_attr_w_update_expr(self):
        """'evaluate' sync with a hybrid that has an update_expression."""
        Person = self.classes.Person

        session = Session()
        jill = session.query(Person).get(3)

        session.query(Person).update(
            {Person.name: 'moonbeam sunshine'},
            synchronize_session='evaluate')
        eq_(jill.name, 'moonbeam sunshine')

    def test_fetch_hybrid_attr_w_update_expr(self):
        """'fetch' sync with a hybrid that has an update_expression."""
        Person = self.classes.Person

        session = Session()
        jill = session.query(Person).get(3)

        session.query(Person).update(
            {Person.name: 'moonbeam sunshine'},
            synchronize_session='fetch')
        eq_(jill.name, 'moonbeam sunshine')

    def test_evaluate_hybrid_attr_indirect_w_update_expr(self):
        """'evaluate' sync via a hybrid wrapping the update_expression one."""
        Person = self.classes.Person

        session = Session()
        jill = session.query(Person).get(3)

        session.query(Person).update(
            {Person.uname: 'moonbeam sunshine'},
            synchronize_session='evaluate')
        eq_(jill.uname, 'moonbeam sunshine')
|
|
|
|
|
|
class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL):
    """tests against hybrids that return a non-ClauseElement.

    use cases derived from the example at
    http://techspot.zzzeek.org/2011/10/21/hybrids-and-value-agnostic-types/

    """
    __dialect__ = 'default'

    @classmethod
    def setup_class(cls):
        """Define ``Amount`` and ``BankAccount`` and attach them to ``cls``."""
        from sqlalchemy import literal

        symbols = ('usd', 'gbp', 'cad', 'eur', 'aud')

        # one row per target currency; each column is the rate from the
        # source currency at the same position in ``symbols``
        rate_rows = [
            (1, 1.59009, 0.988611, 1.37979, 1.02962),
            (0.628895, 1, 0.621732, 0.867748, 0.647525),
            (1.01152, 1.6084, 1, 1.39569, 1.04148),
            (0.724743, 1.1524, 0.716489, 1, 0.746213),
            (0.971228, 1.54434, 0.960166, 1.34009, 1),
        ]

        # maps (currency_from, currency_to) -> Decimal conversion rate
        currency_lookup = {}
        for currency_to, values in zip(symbols, rate_rows):
            for currency_from, rate in zip(symbols, values):
                pair = (currency_from, currency_to)
                currency_lookup[pair] = Decimal(str(rate))

        class Amount(object):
            # Pairs an amount with a currency symbol.  The amount is
            # either a plain number or a SQL expression; arithmetic and
            # comparisons convert the other operand to this currency.

            def __init__(self, amount, currency):
                self.currency = currency
                self.amount = amount

            def __add__(self, other):
                converted = other.as_currency(self.currency)
                return Amount(self.amount + converted.amount, self.currency)

            def __sub__(self, other):
                converted = other.as_currency(self.currency)
                return Amount(self.amount - converted.amount, self.currency)

            def __lt__(self, other):
                converted = other.as_currency(self.currency)
                return self.amount < converted.amount

            def __gt__(self, other):
                converted = other.as_currency(self.currency)
                return self.amount > converted.amount

            def __eq__(self, other):
                converted = other.as_currency(self.currency)
                return self.amount == converted.amount

            def as_currency(self, other_currency):
                # the rate stays on the left-hand side of the product
                rate = currency_lookup[(self.currency, other_currency)]
                return Amount(rate * self.amount, other_currency)

            def __clause_element__(self):
                # helper method for SQLAlchemy to interpret
                # the Amount object as a SQL element
                if not isinstance(self.amount, (float, int, Decimal)):
                    return self.amount
                return literal(self.amount)

            def __str__(self):
                return "%2.4f %s" % (self.amount, self.currency)

            def __repr__(self):
                return "Amount(%r, %r)" % (self.amount, self.currency)

        Base = declarative_base()

        class BankAccount(Base):
            __tablename__ = 'bank_account'
            id = Column(Integer, primary_key=True)

            _balance = Column('balance', Numeric)

            # NOTE: this docstring is asserted by test_docstring.
            @hybrid.hybrid_property
            def balance(self):
                """Return an Amount view of the current balance."""
                return Amount(self._balance, "usd")

            @balance.setter
            def balance(self, value):
                self._balance = value.as_currency("usd").amount

        cls.Amount = Amount
        cls.BankAccount = BankAccount

    def test_instance_one(self):
        """A USD balance reads back with the same amount."""
        BankAccount, Amount = self.BankAccount, self.Amount
        account = BankAccount(balance=Amount(4000, "usd"))

        # 3b. print balance in usd
        eq_(account.balance.amount, 4000)

    def test_instance_two(self):
        """The balance can be converted to another currency."""
        BankAccount, Amount = self.BankAccount, self.Amount
        account = BankAccount(balance=Amount(4000, "usd"))

        # 3c. print balance in gbp
        in_gbp = account.balance.as_currency("gbp")
        eq_(in_gbp.amount, Decimal('2515.58'))

    def test_instance_three(self):
        """Comparison works across currencies."""
        BankAccount, Amount = self.BankAccount, self.Amount
        account = BankAccount(balance=Amount(4000, "usd"))

        # 3d. perform currency-agnostic comparisons, math
        is_(account.balance > Amount(500, "cad"), True)

    def test_instance_four(self):
        """Addition and subtraction work across currencies."""
        BankAccount, Amount = self.BankAccount, self.Amount
        account = BankAccount(balance=Amount(4000, "usd"))
        total = account.balance + Amount(500, "cad") - Amount(50, "eur")
        eq_(total, Amount(Decimal("4425.316"), "usd"))

    def test_query_one(self):
        """Equality against an Amount binds the converted value."""
        BankAccount, Amount = self.BankAccount, self.Amount
        session = Session()

        criterion = BankAccount.balance == Amount(10000, "cad")
        query = session.query(BankAccount).filter(criterion)

        self.assert_compile(
            query,
            "SELECT bank_account.balance AS bank_account_balance, "
            "bank_account.id AS bank_account_id FROM bank_account "
            "WHERE bank_account.balance = :balance_1",
            checkparams={'balance_1': Decimal('9886.110000')}
        )

    def test_query_two(self):
        """The conversion can be rendered on the SQL side instead."""
        BankAccount, Amount = self.BankAccount, self.Amount
        session = Session()

        # alternatively we can do the calc on the DB side.
        # as_currency() is invoked once per criterion on purpose: each
        # call creates its own bound parameter
        above = BankAccount.balance.as_currency("cad") > Amount(9999, "cad")
        below = BankAccount.balance.as_currency("cad") < Amount(10001, "cad")
        query = session.query(BankAccount).filter(above).filter(below)
        self.assert_compile(
            query,
            "SELECT bank_account.balance AS bank_account_balance, "
            "bank_account.id AS bank_account_id "
            "FROM bank_account "
            "WHERE :balance_1 * bank_account.balance > :param_1 "
            "AND :balance_2 * bank_account.balance < :param_2",
            checkparams={
                'balance_1': Decimal('1.01152'),
                'balance_2': Decimal('1.01152'),
                'param_1': Decimal('9999'),
                'param_2': Decimal('10001')}
        )

    def test_query_three(self):
        """Two converted column expressions can be compared."""
        BankAccount = self.BankAccount
        session = Session()

        in_cad = BankAccount.balance.as_currency("cad")
        in_eur = BankAccount.balance.as_currency("eur")
        query = session.query(BankAccount).filter(in_cad > in_eur)
        self.assert_compile(
            query,
            "SELECT bank_account.balance AS bank_account_balance, "
            "bank_account.id AS bank_account_id FROM bank_account "
            "WHERE :balance_1 * bank_account.balance > "
            ":param_1 * :balance_2 * bank_account.balance",
            checkparams={
                'balance_1': Decimal('1.01152'),
                'balance_2': Decimal('0.724743'),
                'param_1': Decimal('1.39569')}
        )

    def test_query_four(self):
        """The converted amount can be selected as a column."""
        BankAccount = self.BankAccount
        session = Session()

        # 4c. query all amounts, converting to "CAD" on the DB side
        cad_amount = BankAccount.balance.as_currency("cad").amount
        query = session.query(cad_amount)
        self.assert_compile(
            query,
            "SELECT :balance_1 * bank_account.balance AS anon_1 "
            "FROM bank_account",
            checkparams={'balance_1': Decimal('1.01152')}
        )

    def test_query_five(self):
        """An Amount can be passed straight into a SQL function."""
        BankAccount = self.BankAccount
        session = Session()

        # 4d. average balance in EUR
        average = func.avg(BankAccount.balance.as_currency("eur"))
        query = session.query(average)
        self.assert_compile(
            query,
            "SELECT avg(:balance_1 * bank_account.balance) AS avg_1 "
            "FROM bank_account",
            checkparams={'balance_1': Decimal('0.724743')}
        )

    def test_docstring(self):
        """The hybrid's docstring is available on the class attribute."""
        BankAccount = self.BankAccount
        eq_(
            BankAccount.balance.__doc__,
            "Return an Amount view of the current balance.")
|