mirror of
https://github.com/coleifer/peewee.git
synced 2026-05-06 07:56:41 -04:00
Cleaning up alias map code.
This commit is contained in:
@@ -1050,15 +1050,15 @@ class CompositeKey(object):
|
||||
|
||||
|
||||
class AliasMap(object):
|
||||
"""
|
||||
Provide a "scope" for a query.
|
||||
"""
|
||||
prefix = 't'
|
||||
|
||||
def __init__(self):
|
||||
self._alias_map = {}
|
||||
self._counter = 0
|
||||
|
||||
def __repr__(self):
|
||||
return '<AliasMap: %s>' % self._alias_map
|
||||
|
||||
def add(self, obj, alias=None):
|
||||
if obj in self._alias_map:
|
||||
return
|
||||
@@ -1074,12 +1074,12 @@ class AliasMap(object):
|
||||
return obj in self._alias_map
|
||||
|
||||
def update(self, alias_map):
|
||||
self._alias_map.update(alias_map._alias_map)
|
||||
if alias_map:
|
||||
for obj, alias in alias_map._alias_map.items():
|
||||
if obj not in self:
|
||||
self._alias_map[obj] = alias
|
||||
return self
|
||||
|
||||
def clone(self):
|
||||
return AliasMap().update(self)
|
||||
|
||||
|
||||
class QueryCompiler(object):
|
||||
# Mapping of `db_type` to actual column type used by database driver.
|
||||
@@ -1150,8 +1150,9 @@ class QueryCompiler(object):
|
||||
'clause': self._parse_clause,
|
||||
'entity': self._parse_entity,
|
||||
'field': self._parse_field,
|
||||
'query': self._parse_query,
|
||||
'sql': self._parse_sql,
|
||||
'select_query': self._parse_select_query,
|
||||
'compound_select_query': self._parse_compound_select_query,
|
||||
}
|
||||
self._unknown_types = set(['param'])
|
||||
|
||||
@@ -1206,34 +1207,28 @@ class QueryCompiler(object):
|
||||
|
||||
def _parse_field(self, node, alias_map, conv):
|
||||
if alias_map:
|
||||
entity = Entity(alias_map[node.model_class], node.db_column)
|
||||
sql = '.'.join((
|
||||
alias_map[node.model_class], self.quote(node.db_column)))
|
||||
else:
|
||||
entity = Entity(node.db_column)
|
||||
sql, params, _ = self._parse(entity, alias_map, conv)
|
||||
return sql, params
|
||||
sql = self.quote(node.db_column)
|
||||
return sql, []
|
||||
|
||||
def _parse_query(self, node, alias_map, conv):
|
||||
if isinstance(node, CompoundSelect):
|
||||
l, lp = self.generate_select(
|
||||
node.lhs, alias_map._counter, alias_map)
|
||||
r, rp = self.generate_select(
|
||||
node.rhs, alias_map._counter, alias_map)
|
||||
sql = '%s %s %s' % (l, node.operator, r)
|
||||
params = lp + rp
|
||||
elif isinstance(node, SelectQuery):
|
||||
alias_map = alias_map or self.alias_map_class()
|
||||
max_alias = alias_map._counter
|
||||
alias_copy = alias_map and alias_map.clone() or None
|
||||
clone = node.clone()
|
||||
if not node._explicit_selection:
|
||||
if conv and isinstance(conv, ForeignKeyField):
|
||||
select_field = conv.to_field
|
||||
else:
|
||||
select_field = clone.model_class._meta.primary_key
|
||||
clone._select = (select_field,)
|
||||
sub, params = self.generate_select(clone, max_alias, alias_copy)
|
||||
sql = '(%s)' % sub
|
||||
return sql, params
|
||||
def _parse_compound_select_query(self, node, alias_map, conv):
|
||||
l, lp = self.generate_select(node.lhs, alias_map)
|
||||
r, rp = self.generate_select(node.rhs, alias_map)
|
||||
sql = '%s %s %s' % (l, node.operator, r)
|
||||
return sql, lp + rp
|
||||
|
||||
def _parse_select_query(self, node, alias_map, conv):
|
||||
clone = node.clone()
|
||||
if not node._explicit_selection:
|
||||
if conv and isinstance(conv, ForeignKeyField):
|
||||
select_field = conv.to_field
|
||||
else:
|
||||
select_field = clone.model_class._meta.primary_key
|
||||
clone._select = (select_field,)
|
||||
sub, params = self.generate_select(clone, alias_map)
|
||||
return '(%s)' % sub, params
|
||||
|
||||
def _parse(self, node, alias_map, conv):
|
||||
# By default treat the incoming node as a raw value that should be
|
||||
@@ -1286,15 +1281,18 @@ class QueryCompiler(object):
|
||||
params.extend(node_params)
|
||||
return glue.join(sql), params
|
||||
|
||||
def calculate_alias_map(self, query, start=1):
|
||||
alias_map = self.alias_map_class()
|
||||
alias_map.add(query.model_class, query.model_class._meta.table_alias)
|
||||
for src_model, joined_models in query._joins.items():
|
||||
alias_map.add(src_model, src_model._meta.table_alias)
|
||||
for join_obj in joined_models:
|
||||
alias_map.add(join_obj.dest, join_obj.dest._meta.table_alias)
|
||||
def calculate_alias_map(self, query, alias_map=None):
|
||||
new_map = self.alias_map_class()
|
||||
if alias_map is not None:
|
||||
new_map._counter = alias_map._counter
|
||||
|
||||
return alias_map
|
||||
new_map.add(query.model_class, query.model_class._meta.table_alias)
|
||||
for src_model, joined_models in query._joins.items():
|
||||
new_map.add(src_model, src_model._meta.table_alias)
|
||||
for join_obj in joined_models:
|
||||
new_map.add(join_obj.dest, join_obj.dest._meta.table_alias)
|
||||
|
||||
return new_map.update(alias_map)
|
||||
|
||||
def build_query(self, clauses, alias_map=None):
|
||||
return self.parse_node(Clause(*clauses), alias_map)
|
||||
@@ -1339,11 +1337,11 @@ class QueryCompiler(object):
|
||||
|
||||
return clauses
|
||||
|
||||
def generate_select(self, query, start=1, alias_map=None):
|
||||
def generate_select(self, query, alias_map=None):
|
||||
model = query.model_class
|
||||
db = model._meta.database
|
||||
|
||||
alias_map = self.calculate_alias_map(query, start)
|
||||
alias_map = self.calculate_alias_map(query, alias_map)
|
||||
|
||||
if isinstance(query, CompoundSelect):
|
||||
clauses = [query]
|
||||
@@ -1767,7 +1765,6 @@ class ModelQueryResultWrapper(QueryResultWrapper):
|
||||
|
||||
class Query(Node):
|
||||
"""Base class representing a database query on one or more tables."""
|
||||
_node_type = 'query'
|
||||
require_commit = True
|
||||
|
||||
def __init__(self, model_class):
|
||||
@@ -1965,6 +1962,8 @@ class RawQuery(Query):
|
||||
return iter(self.execute())
|
||||
|
||||
class SelectQuery(Query):
|
||||
_node_type = 'select_query'
|
||||
|
||||
def __init__(self, model_class, *selection):
|
||||
super(SelectQuery, self).__init__(model_class)
|
||||
self.require_commit = self.database.commit_select
|
||||
@@ -2219,6 +2218,8 @@ class SelectQuery(Query):
|
||||
return res._result_cache[value]
|
||||
|
||||
class CompoundSelect(SelectQuery):
|
||||
_node_type = 'compound_select_query'
|
||||
|
||||
def __init__(self, model_class, lhs=None, operator=None, rhs=None):
|
||||
self.lhs = lhs
|
||||
self.operator = operator
|
||||
|
||||
@@ -919,7 +919,7 @@ class SelectTestCase(BasePeeweeTestCase):
|
||||
# e.g. annotate the number of blogs per user, then annotate the number
|
||||
# of users with that number of blogs.
|
||||
inner = (Blog
|
||||
.select(fn.COUNT(Blog.id).alias('blog_ct'))
|
||||
.select(fn.COUNT(Blog.pk).alias('blog_ct'))
|
||||
.group_by(Blog.user))
|
||||
blog_ct = SQL('blog_ct')
|
||||
outer = (Blog
|
||||
@@ -930,7 +930,7 @@ class SelectTestCase(BasePeeweeTestCase):
|
||||
self.assertEqual(sql, (
|
||||
'SELECT blog_ct, COUNT(blog_ct) AS blog_ct_n '
|
||||
'FROM ('
|
||||
'SELECT COUNT("id") AS blog_ct FROM "blog" AS blog '
|
||||
'SELECT COUNT(blog."pk") AS blog_ct FROM "blog" AS blog '
|
||||
'GROUP BY blog."user_id") '
|
||||
'GROUP BY blog_ct'))
|
||||
|
||||
@@ -1030,7 +1030,9 @@ class SelectTestCase(BasePeeweeTestCase):
|
||||
self.assertEqual(normal_compiler.generate_select(query), expected)
|
||||
|
||||
def test_outer_inner_alias(self):
|
||||
expected = 'SELECT t1."id", t1."username", (SELECT Sum(t2."id") FROM "users" AS t2 WHERE (t2."id" = t1."id")) AS xxx FROM "users" AS t1'
|
||||
expected = ('SELECT t1."id", t1."username", '
|
||||
'(SELECT Sum(t2."id") FROM "users" AS t2 '
|
||||
'WHERE (t2."id" = t1."id")) AS xxx FROM "users" AS t1')
|
||||
UA = User.alias()
|
||||
inner = SelectQuery(UA, fn.Sum(UA.id)).where(UA.id == User.id)
|
||||
query = User.select(User, inner.alias('xxx'))
|
||||
@@ -2629,7 +2631,7 @@ class ModelAggregateTestCase(ModelTestCase):
|
||||
|
||||
def test_annotate_int(self):
|
||||
users = self.create_user_blogs()
|
||||
annotated = User.select().annotate(Blog, fn.Count(Blog.id).alias('ct'))
|
||||
annotated = User.select().annotate(Blog, fn.Count(Blog.pk).alias('ct'))
|
||||
for i, user in enumerate(annotated):
|
||||
self.assertEqual(user.ct, 2)
|
||||
self.assertEqual(user.username, 'u-%d' % i)
|
||||
|
||||
Reference in New Issue
Block a user