some fixes to IN clauses, literal text clauses displaying text/numeric properly including

longs
This commit is contained in:
Mike Bayer
2005-11-27 05:31:22 +00:00
parent c67359157e
commit cdcd74cb39
2 changed files with 23 additions and 13 deletions
+5 -5
View File
@@ -151,10 +151,7 @@ class ANSICompiler(sql.Compiled):
if compound.operator is None:
sep = " "
else:
if compound.spaces:
sep = compound.operator
else:
sep = " " + compound.operator + " "
sep = " " + compound.operator + " "
if compound.parens:
self.strings[compound] = "(" + string.join([self.get_str(c) for c in compound.clauses], sep) + ")"
@@ -162,7 +159,10 @@ class ANSICompiler(sql.Compiled):
self.strings[compound] = string.join([self.get_str(c) for c in compound.clauses], sep)
def visit_clauselist(self, list):
self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ')
if list.parens:
self.strings[list] = "(" + string.join([self.get_str(c) for c in list.clauses], ', ') + ")"
else:
self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ')
def visit_binary(self, binary):
result = self.get_str(binary.left)
+18 -8
View File
@@ -21,7 +21,7 @@
import sqlalchemy.schema as schema
import sqlalchemy.util as util
import sqlalchemy.types as types
import string
import string, re
__ALL__ = ['textclause', 'select', 'join', 'and_', 'or_', 'not_', 'union', 'unionall', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'bindparam', 'sequence']
@@ -328,8 +328,11 @@ class CompareMixin(object):
elif len(other) == 1 and not isinstance(other[0], Selectable):
return self.__eq__(other[0])
elif _is_literal(other[0]):
return self._compare('IN', CompoundClause(',', spaces=False, parens=True, *other))
return self._compare('IN', ClauseList(parens=True, *[TextClause(o, isliteral=True) for o in other]))
else:
# assume *other is a list of selects.
# so put them in a UNION. if theres only one, you just get one SELECT
# statement out of it.
return self._compare('IN', union(*other))
def startswith(self, other):
@@ -421,12 +424,19 @@ class BindParamClause(ClauseElement):
return self.type.convert_bind_param(value)
class TextClause(ClauseElement):
"""represents any plain text WHERE clause or full SQL statement"""
"""represents literal text, including SQL fragments as well
as literal (non bind-param) values."""
def __init__(self, text = "", engine=None):
def __init__(self, text = "", engine=None, isliteral=False):
self.text = text
self.parens = False
self.engine = engine
if isliteral:
if isinstance(text, int) or isinstance(text, long):
self.text = str(text)
else:
text = re.sub(r"'", r"''", text)
self.text = "'" + text + "'"
def accept_visitor(self, visitor):
visitor.visit_textclause(self)
def hash_key(self):
@@ -447,8 +457,7 @@ class CompoundClause(ClauseElement):
def __init__(self, operator, *clauses, **kwargs):
self.operator = operator
self.clauses = []
self.parens = kwargs.pop('parens', False)
self.spaces = kwargs.pop('spaces', False)
self.parens = False
for c in clauses:
if c is None: continue
self.append(c)
@@ -459,7 +468,7 @@ class CompoundClause(ClauseElement):
def append(self, clause):
if _is_literal(clause):
clause = TextClause(repr(clause))
clause = TextClause(str(clause))
elif isinstance(clause, CompoundClause):
clause.parens = True
self.clauses.append(clause)
@@ -479,8 +488,9 @@ class CompoundClause(ClauseElement):
return string.join([c.hash_key() for c in self.clauses], self.operator)
class ClauseList(ClauseElement):
def __init__(self, *clauses):
def __init__(self, *clauses, **kwargs):
self.clauses = clauses
self.parens = kwargs.get('parens', False)
def accept_visitor(self, visitor):
for c in self.clauses: