diff --git a/CHANGES b/CHANGES index 53eb7683ee..a83db182c1 100644 --- a/CHANGES +++ b/CHANGES @@ -196,8 +196,13 @@ CHANGES symptom. - The case() function now also takes a dictionary as its whens - parameter. But beware that it doesn't escape literals, use - the literal construct for that. + parameter. It also interprets the "THEN" expressions + as values by default, meaning case([(x==y, "foo")]) will + interpret "foo" as a bound value, not a SQL expression. + use text(expr) for literal SQL expressions in this case. + For the criterion itself, these may be literal strings + only if the "value" keyword is present, otherwise SA + will force explicit usage of either text() or literal(). - declarative extension - The "synonym" function is now directly usable with diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index cc97227a70..39a2ae3eb9 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -392,7 +392,7 @@ def not_(clause): result. """ - return operators.inv(clause) + return operators.inv(_literal_as_binds(clause)) def distinct(expr): """Return a ``DISTINCT`` clause.""" @@ -416,24 +416,45 @@ def case(whens, value=None, else_=None): """Produce a ``CASE`` statement. whens - A sequence of pairs or a dict to be translated into "when / then" clauses. + A sequence of pairs, or alternatively a dict, + to be translated into "WHEN / THEN" clauses. value - Optional for simple case statements. + Optional for simple case statements, produces + a column expression as in "CASE WHEN ..." else\_ - Optional as well, for case defaults. + Optional as well, for case defaults produces + the "ELSE" portion of the "CASE" statement. + + The expressions used for THEN and ELSE, + when specified as strings, will be interpreted + as bound values. To specify textual SQL expressions + for these, use the text() construct. + + The expressions used for the WHEN criterion + may only be literal strings when "value" is + present, i.e. CASE table.somecol WHEN "x" THEN "y". + Otherwise, literal strings are not accepted + in this position, and either the text() + or literal() constructs must be used to + interpret raw string values. + """ - try: whens = util.dictlike_iteritems(whens) except TypeError: pass - - whenlist = [ClauseList('WHEN', c, 'THEN', r, operator=None) + + if value: + crit_filter = _literal_as_binds + else: + crit_filter = _no_literals + + whenlist = [ClauseList('WHEN', crit_filter(c), 'THEN', _literal_as_binds(r), operator=None) for (c,r) in whens] - if not else_ is None: - whenlist.append(ClauseList('ELSE', else_, operator=None)) + if else_ is not None: + whenlist.append(ClauseList('ELSE', _literal_as_binds(else_), operator=None)) if whenlist: type = list(whenlist[-1])[-1].type else: @@ -842,6 +863,14 @@ def _literal_as_binds(element, name=None, type_=None): else: return element +def _no_literals(element): + if isinstance(element, Operators): + return element.expression_element() + elif _is_literal(element): + raise exceptions.ArgumentError("Ambiguous literal: %r. Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element) + else: + return element + def _corresponding_column_or_error(fromclause, column, require_embedded=False): c = fromclause.corresponding_column(column, require_embedded=require_embedded) if not c: diff --git a/test/sql/case_statement.py b/test/sql/case_statement.py index 730517b210..257298c8e5 100644 --- a/test/sql/case_statement.py +++ b/test/sql/case_statement.py @@ -2,10 +2,11 @@ import testenv; testenv.configure_for_tests() import sys from sqlalchemy import * from testlib import * -from sqlalchemy import util +from sqlalchemy import util, exceptions +from sqlalchemy.sql import table, column -class CaseTest(TestBase): +class CaseTest(TestBase, AssertsCompiledSQL): def setUpAll(self): metadata = MetaData(testing.db) @@ -30,9 +31,9 @@ class CaseTest(TestBase): def testcase(self): inner = select([case([ [info_table.c.pk < 3, - literal('lessthan3', type_=String)], + 'lessthan3'], [and_(info_table.c.pk >= 3, info_table.c.pk < 7), - literal('gt3', type_=String)]]).label('x'), + 'gt3']]).label('x'), info_table.c.pk, info_table.c.info], from_obj=[info_table]).alias('q_inner') @@ -69,9 +70,9 @@ class CaseTest(TestBase): w_else = select([case([ [info_table.c.pk < 3, - literal(3, type_=Integer)], + 3], [and_(info_table.c.pk >= 3, info_table.c.pk < 6), - literal(6, type_=Integer)]], + 6]], else_ = 0).label('x'), info_table.c.pk, info_table.c.info], from_obj=[info_table]).alias('q_inner') @@ -87,12 +88,21 @@ class CaseTest(TestBase): (0, 6, 'pk_6_data') ] + def test_literal_interpretation(self): + t = table('test', column('col1')) + + self.assertRaises(exceptions.ArgumentError, case, [("x", "y")]) + + self.assert_compile(case([("x", "y")], value=t.c.col1), "CASE test.col1 WHEN :param_1 THEN :param_2 END") + self.assert_compile(case([(t.c.col1==7, "y")], else_="z"), "CASE WHEN (test.col1 = :test_col1_1) THEN :param_1 ELSE :param_2 END") + + @testing.fails_on('maxdb') def testcase_with_dict(self): query = select([case({ - info_table.c.pk < 3: literal('lessthan3'), - info_table.c.pk >= 3: literal('gt3'), - }, else_=literal('other')), + info_table.c.pk < 3: 'lessthan3', + info_table.c.pk >= 3: 'gt3', + }, else_='other'), info_table.c.pk, info_table.c.info ], from_obj=[info_table]) @@ -106,13 +116,14 @@ class CaseTest(TestBase): ] simple_query = select([case({ - 1: literal('one'), - 2: literal('two'), - }, value=info_table.c.pk, else_=literal('other')), + 1: 'one', + 2: 'two', + }, value=info_table.c.pk, else_='other'), info_table.c.pk ], whereclause=info_table.c.pk < 4, from_obj=[info_table]) + assert simple_query.execute().fetchall() == [ ('one', 1), ('two', 2),