Refactoring alias generation.

This commit is contained in:
Charles Leifer
2014-07-09 23:41:35 -05:00
parent 052b58b365
commit 51d82fcd99
2 changed files with 130 additions and 90 deletions
+124 -74
View File
@@ -222,6 +222,7 @@ DJANGO_MAP = {
JOIN_INNER = 'inner'
JOIN_LEFT_OUTER = 'left outer'
JOIN_RIGHT_OUTER = 'right outer'
JOIN_FULL = 'full'
# Helper functions that are used in various parts of the codebase.
@@ -293,6 +294,7 @@ class _CDescriptor(object):
class Node(object):
"""Base-class for any part of a query which shall be composable."""
c = _CDescriptor()
_node_type = 'node'
def __init__(self):
self._negated = False
@@ -388,6 +390,8 @@ class Node(object):
class Expression(Node):
"""A binary expression, e.g `foo + 1` or `bar < 7`."""
_node_type = 'expression'
def __init__(self, lhs, op, rhs, flat=False):
super(Expression, self).__init__()
self.lhs = lhs
@@ -413,6 +417,8 @@ class Param(Node):
specifically treat this value as a parameter, useful for `list` which is
special-cased for `IN` lookups.
"""
_node_type = 'param'
def __init__(self, value, conv=None):
self.value = value
self.conv = conv
@@ -423,6 +429,8 @@ class Param(Node):
class SQL(Node):
"""An unescaped SQL string, with optional parameters."""
_node_type = 'sql'
def __init__(self, value, *params):
self.value = value
self.params = params
@@ -434,6 +442,8 @@ R = SQL # backwards-compat.
class Func(Node):
"""An arbitrary SQL function call."""
_node_type = 'func'
def __init__(self, name, *arguments):
self.name = name
self.arguments = arguments
@@ -492,6 +502,8 @@ class Window(Node):
class Clause(Node):
"""A SQL clause, one or more Node objects joined by spaces."""
_node_type = 'clause'
glue = ' '
parens = False
@@ -515,6 +527,8 @@ class EnclosedClause(CommaClause):
class Entity(Node):
"""A quoted-name or entity, e.g. "table"."column"."""
_node_type = 'entity'
def __init__(self, *path):
super(Entity, self).__init__()
self.path = path
@@ -552,6 +566,7 @@ class Field(Node):
"""A column on a table."""
_field_counter = 0
_order = 0
_node_type = 'field'
db_field = 'unknown'
def __init__(self, null=False, index=False, unique=False,
@@ -1034,12 +1049,36 @@ class CompositeKey(object):
return reduce(operator.and_, expressions)
class QueryContext(object):
class AliasMap(object):
"""
Provide a "scope" for a query.
"""
prefix = 't'
def __init__(self):
self._alias_map = {}
self._counter = 0
def add(self, obj, alias=None):
if obj in self._alias_map:
return
self._counter += 1
self._alias_map[obj] = alias or '%s%s' % (self.prefix, self._counter)
def __getitem__(self, obj):
if obj not in self._alias_map:
self.add(obj)
return self._alias_map[obj]
def __contains__(self, obj):
return obj in self._alias_map
def update(self, alias_map):
self._alias_map.update(alias_map._alias_map)
return self
def clone(self):
return AliasMap().update(self)
class QueryCompiler(object):
@@ -1093,8 +1132,10 @@ class QueryCompiler(object):
join_map = {
JOIN_INNER: 'INNER',
JOIN_LEFT_OUTER: 'LEFT OUTER',
JOIN_RIGHT_OUTER: 'RIGHT OUTER',
JOIN_FULL: 'FULL',
}
alias_map_class = AliasMap
def __init__(self, quote_char='"', interpolation='?', field_overrides=None,
op_overrides=None):
@@ -1102,6 +1143,17 @@ class QueryCompiler(object):
self.interpolation = interpolation
self._field_map = merge_dict(self.field_map, field_overrides or {})
self._op_map = merge_dict(self.op_map, op_overrides or {})
self._parse_map = {
'expression': self._parse_expression,
'param': self._parse_param,
'func': self._parse_func,
'clause': self._parse_clause,
'entity': self._parse_entity,
'field': self._parse_field,
'query': self._parse_query,
'sql': self._parse_sql,
}
self._unknown_types = set(['param'])
def quote(self, s):
return '%s%s%s' % (self.quote_char, s, self.quote_char)
@@ -1115,70 +1167,63 @@ class QueryCompiler(object):
def _sorted_fields(self, field_dict):
return sorted(field_dict.items(), key=lambda i: i[0]._sort_key)
def _max_alias(self, alias_map):
max_alias = 0
def _parse_default(self, node, alias_map, conv):
return self.interpolation, [node]
def _parse_expression(self, node, alias_map, conv):
if isinstance(node.lhs, Field):
conv = node.lhs
lhs, lparams = self.parse_node(node.lhs, alias_map, conv)
rhs, rparams = self.parse_node(node.rhs, alias_map, conv)
template = '%s %s %s' if node.flat else '(%s %s %s)'
sql = template % (lhs, self.get_op(node.op), rhs)
return sql, lparams + rparams
def _parse_param(self, node, alias_map, conv):
if node.conv:
params = [node.conv(node.value)]
else:
params = [node.value]
return self.interpolation, params
def _parse_func(self, node, alias_map, conv):
conv = node._coerce and conv or None
sql, params = self.parse_node_list(node.arguments, alias_map, conv)
return '%s(%s)' % (node.name, sql), params
def _parse_clause(self, node, alias_map, conv):
sql, params = self.parse_node_list(
node.nodes, alias_map, conv, node.glue)
if node.parens:
sql = '(%s)' % sql
return sql, params
def _parse_entity(self, node, alias_map, conv):
return '.'.join(map(self.quote, node.path)), []
def _parse_sql(self, node, alias_map, conv):
return node.value, list(node.params)
def _parse_field(self, node, alias_map, conv):
if alias_map:
for alias in alias_map.values():
try:
alias_number = int(alias.lstrip('t'))
except ValueError:
alias_number = 0
if alias_number > max_alias:
max_alias = alias_number
return max_alias + 1
entity = Entity(alias_map[node.model_class], node.db_column)
else:
entity = Entity(node.db_column)
sql, params, _ = self._parse(entity, alias_map, conv)
return sql, params
def _ensure_alias_set(self, model, alias_map):
if model not in alias_map:
max_alias = self._max_alias(alias_map)
alias_map[model] = 't%d' % max_alias
def _parse(self, node, alias_map, conv):
# By default treat the incoming node as a raw value that should be
# parameterized.
sql = self.interpolation
params = [node]
unknown = False
if isinstance(node, Expression):
if isinstance(node.lhs, Field):
conv = node.lhs
lhs, lparams = self.parse_node(node.lhs, alias_map, conv)
rhs, rparams = self.parse_node(node.rhs, alias_map, conv)
template = '%s %s %s' if node.flat else '(%s %s %s)'
sql = template % (lhs, self.get_op(node.op), rhs)
params = lparams + rparams
elif isinstance(node, Field):
sql = self.quote(node.db_column)
if alias_map and node.model_class in alias_map:
sql = '.'.join((alias_map[node.model_class], sql))
params = []
elif isinstance(node, Func):
conv = node._coerce and conv or None
sql, params = self.parse_node_list(node.arguments, alias_map, conv)
sql = '%s(%s)' % (node.name, sql)
elif isinstance(node, Clause):
sql, params = self.parse_node_list(
node.nodes, alias_map, conv, node.glue)
if node.parens:
sql = '(%s)' % sql
elif isinstance(node, Param):
if node.conv:
params = [node.conv(node.value)]
else:
params = [node.value]
unknown = True
elif isinstance(node, SQL):
sql = node.value
params = list(node.params)
elif isinstance(node, CompoundSelect):
def _parse_query(self, node, alias_map, conv):
if isinstance(node, CompoundSelect):
l, lp = self.generate_select(
node.lhs, self._max_alias(alias_map), alias_map)
node.lhs, alias_map._counter, alias_map)
r, rp = self.generate_select(
node.rhs, self._max_alias(alias_map), alias_map)
node.rhs, alias_map._counter, alias_map)
sql = '%s %s %s' % (l, node.operator, r)
params = lp + rp
elif isinstance(node, SelectQuery):
max_alias = self._max_alias(alias_map)
alias_copy = alias_map and alias_map.copy() or None
alias_map = alias_map or self.alias_map_class()
max_alias = alias_map._counter
alias_copy = alias_map and alias_map.clone() or None
clone = node.clone()
if not node._explicit_selection:
if conv and isinstance(conv, ForeignKeyField):
@@ -1188,14 +1233,21 @@ class QueryCompiler(object):
clone._select = (select_field,)
sub, params = self.generate_select(clone, max_alias, alias_copy)
sql = '(%s)' % sub
return sql, params
def _parse(self, node, alias_map, conv):
# By default treat the incoming node as a raw value that should be
# parameterized.
node_type = getattr(node, '_node_type', None)
unknown = False
if node_type in self._parse_map:
sql, params = self._parse_map[node_type](node, alias_map, conv)
unknown = node_type in self._unknown_types
elif isinstance(node, (list, tuple)):
# If you're wondering how to pass a list into your query, simply
# wrap it in Param().
sql, params = self.parse_node_list(node, alias_map, conv)
sql = '(%s)' % sql
elif isinstance(node, Entity):
sql = '.'.join(map(self.quote, node.path))
params = []
elif isinstance(node, Model):
sql = self.interpolation
if conv and isinstance(conv, ForeignKeyField):
@@ -1203,11 +1255,12 @@ class QueryCompiler(object):
else:
params = [node.get_id()]
elif isclass(node) and issubclass(node, Model):
self._ensure_alias_set(node, alias_map)
entity = node._as_entity().alias(alias_map[node])
sql, params = self.parse_node(entity, alias_map, conv)
else:
sql, params = self._parse_default(node, alias_map, conv)
unknown = True
return sql, params, unknown
def parse_node(self, node, alias_map=None, conv=None):
@@ -1234,16 +1287,13 @@ class QueryCompiler(object):
return glue.join(sql), params
def calculate_alias_map(self, query, start=1):
make_alias = lambda model: model._meta.table_alias or 't%s' % start
alias_map = {query.model_class: make_alias(query.model_class)}
for dest, joins in query._joins.items():
if dest not in alias_map:
start += 1
alias_map[dest] = make_alias(dest)
for join in joins:
if join.dest not in alias_map:
start += 1
alias_map[join.dest] = make_alias(join.dest)
alias_map = self.alias_map_class()
alias_map.add(query.model_class, query.model_class._meta.table_alias)
for src_model, joined_models in query._joins.items():
alias_map.add(src_model, src_model._meta.table_alias)
for join_obj in joined_models:
alias_map.add(join_obj.dest, join_obj.dest._meta.table_alias)
return alias_map
def build_query(self, clauses, alias_map=None):
@@ -1293,8 +1343,7 @@ class QueryCompiler(object):
model = query.model_class
db = model._meta.database
alias_map = alias_map or {}
alias_map.update(self.calculate_alias_map(query, start))
alias_map = self.calculate_alias_map(query, start)
if isinstance(query, CompoundSelect):
clauses = [query]
@@ -1718,6 +1767,7 @@ class ModelQueryResultWrapper(QueryResultWrapper):
class Query(Node):
"""Base class representing a database query on one or more tables."""
_node_type = 'query'
require_commit = True
def __init__(self, model_class):
+6 -16
View File
@@ -16,6 +16,7 @@ except ImportError:
from functools import wraps
from peewee import *
from peewee import AliasMap
from peewee import DeleteQuery
from peewee import InsertQuery
from peewee import logger
@@ -104,23 +105,12 @@ else:
# TEST-ONLY QUERY COMPILER USED TO CREATE "predictable" QUERIES
#
class TestAliasMap(AliasMap):
def add(self, obj, alias=None):
self._alias_map[obj] = obj._meta.db_table
class TestQueryCompiler(QueryCompiler):
def _max_alias(self, alias_map):
return 't0'
def _ensure_alias_set(self, model, alias_map):
if model not in alias_map:
alias_map[model] = model._meta.db_table
def calculate_alias_map(self, query, start=1):
alias_map = {query.model_class: query.model_class._meta.db_table}
for model, joins in query._joins.items():
if model not in alias_map:
alias_map[model] = model._meta.db_table
for join in joins:
if join.dest not in alias_map:
alias_map[join.dest] = join.dest._meta.db_table
return alias_map
alias_map_class = TestAliasMap
class TestDatabase(database_class):
compiler_class = TestQueryCompiler