- case() interprets the "THEN" expressions

as values by default, meaning case([(x==y, "foo")]) will
interpret "foo" as a bound value, not a SQL expression.
use text(expr) for literal SQL expressions in this case.
For the criterion itself, these may be literal strings
only if the "value" keyword is present, otherwise SA
will force explicit usage of either text() or literal().
This commit is contained in:
Mike Bayer
2008-04-03 16:34:03 +00:00
parent a27d6be28a
commit abb10856dc
3 changed files with 68 additions and 23 deletions
+7 -2
View File
@@ -196,8 +196,13 @@ CHANGES
symptom.
- The case() function now also takes a dictionary as its whens
parameter. But beware that it doesn't escape literals, use
the literal construct for that.
parameter. It also interprets the "THEN" expressions
as values by default, meaning case([(x==y, "foo")]) will
interpret "foo" as a bound value, not a SQL expression.
use text(expr) for literal SQL expressions in this case.
For the criterion itself, these may be literal strings
only if the "value" keyword is present, otherwise SA
will force explicit usage of either text() or literal().
- declarative extension
- The "synonym" function is now directly usable with
+38 -9
View File
@@ -392,7 +392,7 @@ def not_(clause):
result.
"""
return operators.inv(clause)
return operators.inv(_literal_as_binds(clause))
def distinct(expr):
"""Return a ``DISTINCT`` clause."""
@@ -416,24 +416,45 @@ def case(whens, value=None, else_=None):
"""Produce a ``CASE`` statement.
whens
A sequence of pairs or a dict to be translated into "when / then" clauses.
A sequence of pairs, or alternatively a dict,
to be translated into "WHEN / THEN" clauses.
value
Optional for simple case statements.
Optional for simple case statements, produces
a column expression as in "CASE <expr> WHEN ..."
else\_
Optional as well, for case defaults.
Optional as well, for case defaults produces
the "ELSE" portion of the "CASE" statement.
The expressions used for THEN and ELSE,
when specified as strings, will be interpreted
as bound values. To specify textual SQL expressions
for these, use the text(<string>) construct.
The expressions used for the WHEN criterion
may only be literal strings when "value" is
present, i.e. CASE table.somecol WHEN "x" THEN "y".
Otherwise, literal strings are not accepted
in this position, and either the text(<string>)
or literal(<string>) constructs must be used to
interpret raw string values.
"""
try:
whens = util.dictlike_iteritems(whens)
except TypeError:
pass
whenlist = [ClauseList('WHEN', c, 'THEN', r, operator=None)
if value:
crit_filter = _literal_as_binds
else:
crit_filter = _no_literals
whenlist = [ClauseList('WHEN', crit_filter(c), 'THEN', _literal_as_binds(r), operator=None)
for (c,r) in whens]
if not else_ is None:
whenlist.append(ClauseList('ELSE', else_, operator=None))
if else_ is not None:
whenlist.append(ClauseList('ELSE', _literal_as_binds(else_), operator=None))
if whenlist:
type = list(whenlist[-1])[-1].type
else:
@@ -842,6 +863,14 @@ def _literal_as_binds(element, name=None, type_=None):
else:
return element
def _no_literals(element):
if isinstance(element, Operators):
return element.expression_element()
elif _is_literal(element):
raise exceptions.ArgumentError("Ambiguous literal: %r. Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element)
else:
return element
def _corresponding_column_or_error(fromclause, column, require_embedded=False):
c = fromclause.corresponding_column(column, require_embedded=require_embedded)
if not c:
+23 -12
View File
@@ -2,10 +2,11 @@ import testenv; testenv.configure_for_tests()
import sys
from sqlalchemy import *
from testlib import *
from sqlalchemy import util
from sqlalchemy import util, exceptions
from sqlalchemy.sql import table, column
class CaseTest(TestBase):
class CaseTest(TestBase, AssertsCompiledSQL):
def setUpAll(self):
metadata = MetaData(testing.db)
@@ -30,9 +31,9 @@ class CaseTest(TestBase):
def testcase(self):
inner = select([case([
[info_table.c.pk < 3,
literal('lessthan3', type_=String)],
'lessthan3'],
[and_(info_table.c.pk >= 3, info_table.c.pk < 7),
literal('gt3', type_=String)]]).label('x'),
'gt3']]).label('x'),
info_table.c.pk, info_table.c.info],
from_obj=[info_table]).alias('q_inner')
@@ -69,9 +70,9 @@ class CaseTest(TestBase):
w_else = select([case([
[info_table.c.pk < 3,
literal(3, type_=Integer)],
3],
[and_(info_table.c.pk >= 3, info_table.c.pk < 6),
literal(6, type_=Integer)]],
6]],
else_ = 0).label('x'),
info_table.c.pk, info_table.c.info],
from_obj=[info_table]).alias('q_inner')
@@ -87,12 +88,21 @@ class CaseTest(TestBase):
(0, 6, 'pk_6_data')
]
def test_literal_interpretation(self):
t = table('test', column('col1'))
self.assertRaises(exceptions.ArgumentError, case, [("x", "y")])
self.assert_compile(case([("x", "y")], value=t.c.col1), "CASE test.col1 WHEN :param_1 THEN :param_2 END")
self.assert_compile(case([(t.c.col1==7, "y")], else_="z"), "CASE WHEN (test.col1 = :test_col1_1) THEN :param_1 ELSE :param_2 END")
@testing.fails_on('maxdb')
def testcase_with_dict(self):
query = select([case({
info_table.c.pk < 3: literal('lessthan3'),
info_table.c.pk >= 3: literal('gt3'),
}, else_=literal('other')),
info_table.c.pk < 3: 'lessthan3',
info_table.c.pk >= 3: 'gt3',
}, else_='other'),
info_table.c.pk, info_table.c.info
],
from_obj=[info_table])
@@ -106,13 +116,14 @@ class CaseTest(TestBase):
]
simple_query = select([case({
1: literal('one'),
2: literal('two'),
}, value=info_table.c.pk, else_=literal('other')),
1: 'one',
2: 'two',
}, value=info_table.c.pk, else_='other'),
info_table.c.pk
],
whereclause=info_table.c.pk < 4,
from_obj=[info_table])
assert simple_query.execute().fetchall() == [
('one', 1),
('two', 2),