Cleaning up alias map code.

This commit is contained in:
Charles Leifer
2014-07-11 22:47:35 -05:00
parent 51d82fcd99
commit d8d55df046
2 changed files with 52 additions and 49 deletions
+46 -45
View File
@@ -1050,15 +1050,15 @@ class CompositeKey(object):
class AliasMap(object):
"""
Provide a "scope" for a query.
"""
prefix = 't'
def __init__(self):
self._alias_map = {}
self._counter = 0
def __repr__(self):
return '<AliasMap: %s>' % self._alias_map
def add(self, obj, alias=None):
if obj in self._alias_map:
return
@@ -1074,12 +1074,12 @@ class AliasMap(object):
return obj in self._alias_map
def update(self, alias_map):
self._alias_map.update(alias_map._alias_map)
if alias_map:
for obj, alias in alias_map._alias_map.items():
if obj not in self:
self._alias_map[obj] = alias
return self
def clone(self):
return AliasMap().update(self)
class QueryCompiler(object):
# Mapping of `db_type` to actual column type used by database driver.
@@ -1150,8 +1150,9 @@ class QueryCompiler(object):
'clause': self._parse_clause,
'entity': self._parse_entity,
'field': self._parse_field,
'query': self._parse_query,
'sql': self._parse_sql,
'select_query': self._parse_select_query,
'compound_select_query': self._parse_compound_select_query,
}
self._unknown_types = set(['param'])
@@ -1206,34 +1207,28 @@ class QueryCompiler(object):
def _parse_field(self, node, alias_map, conv):
if alias_map:
entity = Entity(alias_map[node.model_class], node.db_column)
sql = '.'.join((
alias_map[node.model_class], self.quote(node.db_column)))
else:
entity = Entity(node.db_column)
sql, params, _ = self._parse(entity, alias_map, conv)
return sql, params
sql = self.quote(node.db_column)
return sql, []
def _parse_query(self, node, alias_map, conv):
if isinstance(node, CompoundSelect):
l, lp = self.generate_select(
node.lhs, alias_map._counter, alias_map)
r, rp = self.generate_select(
node.rhs, alias_map._counter, alias_map)
sql = '%s %s %s' % (l, node.operator, r)
params = lp + rp
elif isinstance(node, SelectQuery):
alias_map = alias_map or self.alias_map_class()
max_alias = alias_map._counter
alias_copy = alias_map and alias_map.clone() or None
clone = node.clone()
if not node._explicit_selection:
if conv and isinstance(conv, ForeignKeyField):
select_field = conv.to_field
else:
select_field = clone.model_class._meta.primary_key
clone._select = (select_field,)
sub, params = self.generate_select(clone, max_alias, alias_copy)
sql = '(%s)' % sub
return sql, params
def _parse_compound_select_query(self, node, alias_map, conv):
l, lp = self.generate_select(node.lhs, alias_map)
r, rp = self.generate_select(node.rhs, alias_map)
sql = '%s %s %s' % (l, node.operator, r)
return sql, lp + rp
def _parse_select_query(self, node, alias_map, conv):
clone = node.clone()
if not node._explicit_selection:
if conv and isinstance(conv, ForeignKeyField):
select_field = conv.to_field
else:
select_field = clone.model_class._meta.primary_key
clone._select = (select_field,)
sub, params = self.generate_select(clone, alias_map)
return '(%s)' % sub, params
def _parse(self, node, alias_map, conv):
# By default treat the incoming node as a raw value that should be
@@ -1286,15 +1281,18 @@ class QueryCompiler(object):
params.extend(node_params)
return glue.join(sql), params
def calculate_alias_map(self, query, start=1):
alias_map = self.alias_map_class()
alias_map.add(query.model_class, query.model_class._meta.table_alias)
for src_model, joined_models in query._joins.items():
alias_map.add(src_model, src_model._meta.table_alias)
for join_obj in joined_models:
alias_map.add(join_obj.dest, join_obj.dest._meta.table_alias)
def calculate_alias_map(self, query, alias_map=None):
new_map = self.alias_map_class()
if alias_map is not None:
new_map._counter = alias_map._counter
return alias_map
new_map.add(query.model_class, query.model_class._meta.table_alias)
for src_model, joined_models in query._joins.items():
new_map.add(src_model, src_model._meta.table_alias)
for join_obj in joined_models:
new_map.add(join_obj.dest, join_obj.dest._meta.table_alias)
return new_map.update(alias_map)
def build_query(self, clauses, alias_map=None):
return self.parse_node(Clause(*clauses), alias_map)
@@ -1339,11 +1337,11 @@ class QueryCompiler(object):
return clauses
def generate_select(self, query, start=1, alias_map=None):
def generate_select(self, query, alias_map=None):
model = query.model_class
db = model._meta.database
alias_map = self.calculate_alias_map(query, start)
alias_map = self.calculate_alias_map(query, alias_map)
if isinstance(query, CompoundSelect):
clauses = [query]
@@ -1767,7 +1765,6 @@ class ModelQueryResultWrapper(QueryResultWrapper):
class Query(Node):
"""Base class representing a database query on one or more tables."""
_node_type = 'query'
require_commit = True
def __init__(self, model_class):
@@ -1965,6 +1962,8 @@ class RawQuery(Query):
return iter(self.execute())
class SelectQuery(Query):
_node_type = 'select_query'
def __init__(self, model_class, *selection):
super(SelectQuery, self).__init__(model_class)
self.require_commit = self.database.commit_select
@@ -2219,6 +2218,8 @@ class SelectQuery(Query):
return res._result_cache[value]
class CompoundSelect(SelectQuery):
_node_type = 'compound_select_query'
def __init__(self, model_class, lhs=None, operator=None, rhs=None):
self.lhs = lhs
self.operator = operator
+6 -4
View File
@@ -919,7 +919,7 @@ class SelectTestCase(BasePeeweeTestCase):
# e.g. annotate the number of blogs per user, then annotate the number
# of users with that number of blogs.
inner = (Blog
.select(fn.COUNT(Blog.id).alias('blog_ct'))
.select(fn.COUNT(Blog.pk).alias('blog_ct'))
.group_by(Blog.user))
blog_ct = SQL('blog_ct')
outer = (Blog
@@ -930,7 +930,7 @@ class SelectTestCase(BasePeeweeTestCase):
self.assertEqual(sql, (
'SELECT blog_ct, COUNT(blog_ct) AS blog_ct_n '
'FROM ('
'SELECT COUNT("id") AS blog_ct FROM "blog" AS blog '
'SELECT COUNT(blog."pk") AS blog_ct FROM "blog" AS blog '
'GROUP BY blog."user_id") '
'GROUP BY blog_ct'))
@@ -1030,7 +1030,9 @@ class SelectTestCase(BasePeeweeTestCase):
self.assertEqual(normal_compiler.generate_select(query), expected)
def test_outer_inner_alias(self):
expected = 'SELECT t1."id", t1."username", (SELECT Sum(t2."id") FROM "users" AS t2 WHERE (t2."id" = t1."id")) AS xxx FROM "users" AS t1'
expected = ('SELECT t1."id", t1."username", '
'(SELECT Sum(t2."id") FROM "users" AS t2 '
'WHERE (t2."id" = t1."id")) AS xxx FROM "users" AS t1')
UA = User.alias()
inner = SelectQuery(UA, fn.Sum(UA.id)).where(UA.id == User.id)
query = User.select(User, inner.alias('xxx'))
@@ -2629,7 +2631,7 @@ class ModelAggregateTestCase(ModelTestCase):
def test_annotate_int(self):
users = self.create_user_blogs()
annotated = User.select().annotate(Blog, fn.Count(Blog.id).alias('ct'))
annotated = User.select().annotate(Blog, fn.Count(Blog.pk).alias('ct'))
for i, user in enumerate(annotated):
self.assertEqual(user.ct, 2)
self.assertEqual(user.username, 'u-%d' % i)