mirror of
https://github.com/coleifer/peewee.git
synced 2026-05-06 07:56:41 -04:00
Refactoring alias generation.
This commit is contained in:
@@ -222,6 +222,7 @@ DJANGO_MAP = {
|
||||
|
||||
JOIN_INNER = 'inner'
|
||||
JOIN_LEFT_OUTER = 'left outer'
|
||||
JOIN_RIGHT_OUTER = 'right outer'
|
||||
JOIN_FULL = 'full'
|
||||
|
||||
# Helper functions that are used in various parts of the codebase.
|
||||
@@ -293,6 +294,7 @@ class _CDescriptor(object):
|
||||
class Node(object):
|
||||
"""Base-class for any part of a query which shall be composable."""
|
||||
c = _CDescriptor()
|
||||
_node_type = 'node'
|
||||
|
||||
def __init__(self):
|
||||
self._negated = False
|
||||
@@ -388,6 +390,8 @@ class Node(object):
|
||||
|
||||
class Expression(Node):
|
||||
"""A binary expression, e.g `foo + 1` or `bar < 7`."""
|
||||
_node_type = 'expression'
|
||||
|
||||
def __init__(self, lhs, op, rhs, flat=False):
|
||||
super(Expression, self).__init__()
|
||||
self.lhs = lhs
|
||||
@@ -413,6 +417,8 @@ class Param(Node):
|
||||
specifically treat this value as a parameter, useful for `list` which is
|
||||
special-cased for `IN` lookups.
|
||||
"""
|
||||
_node_type = 'param'
|
||||
|
||||
def __init__(self, value, conv=None):
|
||||
self.value = value
|
||||
self.conv = conv
|
||||
@@ -423,6 +429,8 @@ class Param(Node):
|
||||
|
||||
class SQL(Node):
|
||||
"""An unescaped SQL string, with optional parameters."""
|
||||
_node_type = 'sql'
|
||||
|
||||
def __init__(self, value, *params):
|
||||
self.value = value
|
||||
self.params = params
|
||||
@@ -434,6 +442,8 @@ R = SQL # backwards-compat.
|
||||
|
||||
class Func(Node):
|
||||
"""An arbitrary SQL function call."""
|
||||
_node_type = 'func'
|
||||
|
||||
def __init__(self, name, *arguments):
|
||||
self.name = name
|
||||
self.arguments = arguments
|
||||
@@ -492,6 +502,8 @@ class Window(Node):
|
||||
|
||||
class Clause(Node):
|
||||
"""A SQL clause, one or more Node objects joined by spaces."""
|
||||
_node_type = 'clause'
|
||||
|
||||
glue = ' '
|
||||
parens = False
|
||||
|
||||
@@ -515,6 +527,8 @@ class EnclosedClause(CommaClause):
|
||||
|
||||
class Entity(Node):
|
||||
"""A quoted-name or entity, e.g. "table"."column"."""
|
||||
_node_type = 'entity'
|
||||
|
||||
def __init__(self, *path):
|
||||
super(Entity, self).__init__()
|
||||
self.path = path
|
||||
@@ -552,6 +566,7 @@ class Field(Node):
|
||||
"""A column on a table."""
|
||||
_field_counter = 0
|
||||
_order = 0
|
||||
_node_type = 'field'
|
||||
db_field = 'unknown'
|
||||
|
||||
def __init__(self, null=False, index=False, unique=False,
|
||||
@@ -1034,12 +1049,36 @@ class CompositeKey(object):
|
||||
return reduce(operator.and_, expressions)
|
||||
|
||||
|
||||
class QueryContext(object):
|
||||
class AliasMap(object):
|
||||
"""
|
||||
Provide a "scope" for a query.
|
||||
"""
|
||||
prefix = 't'
|
||||
|
||||
def __init__(self):
|
||||
self._alias_map = {}
|
||||
self._counter = 0
|
||||
|
||||
def add(self, obj, alias=None):
|
||||
if obj in self._alias_map:
|
||||
return
|
||||
self._counter += 1
|
||||
self._alias_map[obj] = alias or '%s%s' % (self.prefix, self._counter)
|
||||
|
||||
def __getitem__(self, obj):
|
||||
if obj not in self._alias_map:
|
||||
self.add(obj)
|
||||
return self._alias_map[obj]
|
||||
|
||||
def __contains__(self, obj):
|
||||
return obj in self._alias_map
|
||||
|
||||
def update(self, alias_map):
|
||||
self._alias_map.update(alias_map._alias_map)
|
||||
return self
|
||||
|
||||
def clone(self):
|
||||
return AliasMap().update(self)
|
||||
|
||||
|
||||
class QueryCompiler(object):
|
||||
@@ -1093,8 +1132,10 @@ class QueryCompiler(object):
|
||||
join_map = {
|
||||
JOIN_INNER: 'INNER',
|
||||
JOIN_LEFT_OUTER: 'LEFT OUTER',
|
||||
JOIN_RIGHT_OUTER: 'RIGHT OUTER',
|
||||
JOIN_FULL: 'FULL',
|
||||
}
|
||||
alias_map_class = AliasMap
|
||||
|
||||
def __init__(self, quote_char='"', interpolation='?', field_overrides=None,
|
||||
op_overrides=None):
|
||||
@@ -1102,6 +1143,17 @@ class QueryCompiler(object):
|
||||
self.interpolation = interpolation
|
||||
self._field_map = merge_dict(self.field_map, field_overrides or {})
|
||||
self._op_map = merge_dict(self.op_map, op_overrides or {})
|
||||
self._parse_map = {
|
||||
'expression': self._parse_expression,
|
||||
'param': self._parse_param,
|
||||
'func': self._parse_func,
|
||||
'clause': self._parse_clause,
|
||||
'entity': self._parse_entity,
|
||||
'field': self._parse_field,
|
||||
'query': self._parse_query,
|
||||
'sql': self._parse_sql,
|
||||
}
|
||||
self._unknown_types = set(['param'])
|
||||
|
||||
def quote(self, s):
|
||||
return '%s%s%s' % (self.quote_char, s, self.quote_char)
|
||||
@@ -1115,70 +1167,63 @@ class QueryCompiler(object):
|
||||
def _sorted_fields(self, field_dict):
|
||||
return sorted(field_dict.items(), key=lambda i: i[0]._sort_key)
|
||||
|
||||
def _max_alias(self, alias_map):
|
||||
max_alias = 0
|
||||
def _parse_default(self, node, alias_map, conv):
|
||||
return self.interpolation, [node]
|
||||
|
||||
def _parse_expression(self, node, alias_map, conv):
|
||||
if isinstance(node.lhs, Field):
|
||||
conv = node.lhs
|
||||
lhs, lparams = self.parse_node(node.lhs, alias_map, conv)
|
||||
rhs, rparams = self.parse_node(node.rhs, alias_map, conv)
|
||||
template = '%s %s %s' if node.flat else '(%s %s %s)'
|
||||
sql = template % (lhs, self.get_op(node.op), rhs)
|
||||
return sql, lparams + rparams
|
||||
|
||||
def _parse_param(self, node, alias_map, conv):
|
||||
if node.conv:
|
||||
params = [node.conv(node.value)]
|
||||
else:
|
||||
params = [node.value]
|
||||
return self.interpolation, params
|
||||
|
||||
def _parse_func(self, node, alias_map, conv):
|
||||
conv = node._coerce and conv or None
|
||||
sql, params = self.parse_node_list(node.arguments, alias_map, conv)
|
||||
return '%s(%s)' % (node.name, sql), params
|
||||
|
||||
def _parse_clause(self, node, alias_map, conv):
|
||||
sql, params = self.parse_node_list(
|
||||
node.nodes, alias_map, conv, node.glue)
|
||||
if node.parens:
|
||||
sql = '(%s)' % sql
|
||||
return sql, params
|
||||
|
||||
def _parse_entity(self, node, alias_map, conv):
|
||||
return '.'.join(map(self.quote, node.path)), []
|
||||
|
||||
def _parse_sql(self, node, alias_map, conv):
|
||||
return node.value, list(node.params)
|
||||
|
||||
def _parse_field(self, node, alias_map, conv):
|
||||
if alias_map:
|
||||
for alias in alias_map.values():
|
||||
try:
|
||||
alias_number = int(alias.lstrip('t'))
|
||||
except ValueError:
|
||||
alias_number = 0
|
||||
if alias_number > max_alias:
|
||||
max_alias = alias_number
|
||||
return max_alias + 1
|
||||
entity = Entity(alias_map[node.model_class], node.db_column)
|
||||
else:
|
||||
entity = Entity(node.db_column)
|
||||
sql, params, _ = self._parse(entity, alias_map, conv)
|
||||
return sql, params
|
||||
|
||||
def _ensure_alias_set(self, model, alias_map):
|
||||
if model not in alias_map:
|
||||
max_alias = self._max_alias(alias_map)
|
||||
alias_map[model] = 't%d' % max_alias
|
||||
|
||||
def _parse(self, node, alias_map, conv):
|
||||
# By default treat the incoming node as a raw value that should be
|
||||
# parameterized.
|
||||
sql = self.interpolation
|
||||
params = [node]
|
||||
unknown = False
|
||||
if isinstance(node, Expression):
|
||||
if isinstance(node.lhs, Field):
|
||||
conv = node.lhs
|
||||
lhs, lparams = self.parse_node(node.lhs, alias_map, conv)
|
||||
rhs, rparams = self.parse_node(node.rhs, alias_map, conv)
|
||||
template = '%s %s %s' if node.flat else '(%s %s %s)'
|
||||
sql = template % (lhs, self.get_op(node.op), rhs)
|
||||
params = lparams + rparams
|
||||
elif isinstance(node, Field):
|
||||
sql = self.quote(node.db_column)
|
||||
if alias_map and node.model_class in alias_map:
|
||||
sql = '.'.join((alias_map[node.model_class], sql))
|
||||
params = []
|
||||
elif isinstance(node, Func):
|
||||
conv = node._coerce and conv or None
|
||||
sql, params = self.parse_node_list(node.arguments, alias_map, conv)
|
||||
sql = '%s(%s)' % (node.name, sql)
|
||||
elif isinstance(node, Clause):
|
||||
sql, params = self.parse_node_list(
|
||||
node.nodes, alias_map, conv, node.glue)
|
||||
if node.parens:
|
||||
sql = '(%s)' % sql
|
||||
elif isinstance(node, Param):
|
||||
if node.conv:
|
||||
params = [node.conv(node.value)]
|
||||
else:
|
||||
params = [node.value]
|
||||
unknown = True
|
||||
elif isinstance(node, SQL):
|
||||
sql = node.value
|
||||
params = list(node.params)
|
||||
elif isinstance(node, CompoundSelect):
|
||||
def _parse_query(self, node, alias_map, conv):
|
||||
if isinstance(node, CompoundSelect):
|
||||
l, lp = self.generate_select(
|
||||
node.lhs, self._max_alias(alias_map), alias_map)
|
||||
node.lhs, alias_map._counter, alias_map)
|
||||
r, rp = self.generate_select(
|
||||
node.rhs, self._max_alias(alias_map), alias_map)
|
||||
node.rhs, alias_map._counter, alias_map)
|
||||
sql = '%s %s %s' % (l, node.operator, r)
|
||||
params = lp + rp
|
||||
elif isinstance(node, SelectQuery):
|
||||
max_alias = self._max_alias(alias_map)
|
||||
alias_copy = alias_map and alias_map.copy() or None
|
||||
alias_map = alias_map or self.alias_map_class()
|
||||
max_alias = alias_map._counter
|
||||
alias_copy = alias_map and alias_map.clone() or None
|
||||
clone = node.clone()
|
||||
if not node._explicit_selection:
|
||||
if conv and isinstance(conv, ForeignKeyField):
|
||||
@@ -1188,14 +1233,21 @@ class QueryCompiler(object):
|
||||
clone._select = (select_field,)
|
||||
sub, params = self.generate_select(clone, max_alias, alias_copy)
|
||||
sql = '(%s)' % sub
|
||||
return sql, params
|
||||
|
||||
def _parse(self, node, alias_map, conv):
|
||||
# By default treat the incoming node as a raw value that should be
|
||||
# parameterized.
|
||||
node_type = getattr(node, '_node_type', None)
|
||||
unknown = False
|
||||
if node_type in self._parse_map:
|
||||
sql, params = self._parse_map[node_type](node, alias_map, conv)
|
||||
unknown = node_type in self._unknown_types
|
||||
elif isinstance(node, (list, tuple)):
|
||||
# If you're wondering how to pass a list into your query, simply
|
||||
# wrap it in Param().
|
||||
sql, params = self.parse_node_list(node, alias_map, conv)
|
||||
sql = '(%s)' % sql
|
||||
elif isinstance(node, Entity):
|
||||
sql = '.'.join(map(self.quote, node.path))
|
||||
params = []
|
||||
elif isinstance(node, Model):
|
||||
sql = self.interpolation
|
||||
if conv and isinstance(conv, ForeignKeyField):
|
||||
@@ -1203,11 +1255,12 @@ class QueryCompiler(object):
|
||||
else:
|
||||
params = [node.get_id()]
|
||||
elif isclass(node) and issubclass(node, Model):
|
||||
self._ensure_alias_set(node, alias_map)
|
||||
entity = node._as_entity().alias(alias_map[node])
|
||||
sql, params = self.parse_node(entity, alias_map, conv)
|
||||
else:
|
||||
sql, params = self._parse_default(node, alias_map, conv)
|
||||
unknown = True
|
||||
|
||||
return sql, params, unknown
|
||||
|
||||
def parse_node(self, node, alias_map=None, conv=None):
|
||||
@@ -1234,16 +1287,13 @@ class QueryCompiler(object):
|
||||
return glue.join(sql), params
|
||||
|
||||
def calculate_alias_map(self, query, start=1):
|
||||
make_alias = lambda model: model._meta.table_alias or 't%s' % start
|
||||
alias_map = {query.model_class: make_alias(query.model_class)}
|
||||
for dest, joins in query._joins.items():
|
||||
if dest not in alias_map:
|
||||
start += 1
|
||||
alias_map[dest] = make_alias(dest)
|
||||
for join in joins:
|
||||
if join.dest not in alias_map:
|
||||
start += 1
|
||||
alias_map[join.dest] = make_alias(join.dest)
|
||||
alias_map = self.alias_map_class()
|
||||
alias_map.add(query.model_class, query.model_class._meta.table_alias)
|
||||
for src_model, joined_models in query._joins.items():
|
||||
alias_map.add(src_model, src_model._meta.table_alias)
|
||||
for join_obj in joined_models:
|
||||
alias_map.add(join_obj.dest, join_obj.dest._meta.table_alias)
|
||||
|
||||
return alias_map
|
||||
|
||||
def build_query(self, clauses, alias_map=None):
|
||||
@@ -1293,8 +1343,7 @@ class QueryCompiler(object):
|
||||
model = query.model_class
|
||||
db = model._meta.database
|
||||
|
||||
alias_map = alias_map or {}
|
||||
alias_map.update(self.calculate_alias_map(query, start))
|
||||
alias_map = self.calculate_alias_map(query, start)
|
||||
|
||||
if isinstance(query, CompoundSelect):
|
||||
clauses = [query]
|
||||
@@ -1718,6 +1767,7 @@ class ModelQueryResultWrapper(QueryResultWrapper):
|
||||
|
||||
class Query(Node):
|
||||
"""Base class representing a database query on one or more tables."""
|
||||
_node_type = 'query'
|
||||
require_commit = True
|
||||
|
||||
def __init__(self, model_class):
|
||||
|
||||
@@ -16,6 +16,7 @@ except ImportError:
|
||||
from functools import wraps
|
||||
|
||||
from peewee import *
|
||||
from peewee import AliasMap
|
||||
from peewee import DeleteQuery
|
||||
from peewee import InsertQuery
|
||||
from peewee import logger
|
||||
@@ -104,23 +105,12 @@ else:
|
||||
# TEST-ONLY QUERY COMPILER USED TO CREATE "predictable" QUERIES
|
||||
#
|
||||
|
||||
class TestAliasMap(AliasMap):
|
||||
def add(self, obj, alias=None):
|
||||
self._alias_map[obj] = obj._meta.db_table
|
||||
|
||||
class TestQueryCompiler(QueryCompiler):
|
||||
def _max_alias(self, alias_map):
|
||||
return 't0'
|
||||
|
||||
def _ensure_alias_set(self, model, alias_map):
|
||||
if model not in alias_map:
|
||||
alias_map[model] = model._meta.db_table
|
||||
|
||||
def calculate_alias_map(self, query, start=1):
|
||||
alias_map = {query.model_class: query.model_class._meta.db_table}
|
||||
for model, joins in query._joins.items():
|
||||
if model not in alias_map:
|
||||
alias_map[model] = model._meta.db_table
|
||||
for join in joins:
|
||||
if join.dest not in alias_map:
|
||||
alias_map[join.dest] = join.dest._meta.db_table
|
||||
return alias_map
|
||||
alias_map_class = TestAliasMap
|
||||
|
||||
class TestDatabase(database_class):
|
||||
compiler_class = TestQueryCompiler
|
||||
|
||||
Reference in New Issue
Block a user