# Files
# sqlalchemy/test/ext/test_compiler.py
# T
# Khairi Hafsham 772374735d Make all tests to be PEP8 compliant
# tested using pycodestyle version 2.2.0
#
# Fixes: #3885
# Change-Id: I5df43adc3aefe318f9eeab72a078247a548ec566
# Pull-request: https://github.com/zzzeek/sqlalchemy/pull/343
# 2017-02-07 11:21:56 -05:00
#
# 427 lines
# 12 KiB
# Python

from sqlalchemy import *
from sqlalchemy.types import TypeEngine
from sqlalchemy.sql.expression import ClauseElement, ColumnClause,\
FunctionElement, Select, \
BindParameter, ColumnElement
from sqlalchemy.schema import DDLElement, CreateColumn, CreateTable
from sqlalchemy.ext.compiler import compiles, deregister
from sqlalchemy import exc
from sqlalchemy.testing import eq_
from sqlalchemy.sql import table, column
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import fixtures, AssertsCompiledSQL
class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
    """Exercise ``@compiles`` handlers registered by the tests themselves.

    Most tests declare a throwaway construct or type locally and attach
    one or more handlers to it; a few attach handlers to library classes
    (``CreateColumn``, ``Select``) and undo that registration afterwards.
    """

    __dialect__ = 'default'

    def test_column(self):
        """A ColumnClause subclass is rendered by its custom handler."""

        class MyThingy(ColumnClause):
            def __init__(self, arg=None):
                super(MyThingy, self).__init__(arg or 'MYTHINGY!')

        @compiles(MyThingy)
        def visit_thingy(thingy, compiler, **kw):
            return ">>%s<<" % thingy.name

        self.assert_compile(
            select([column('foo'), MyThingy()]),
            "SELECT foo, >>MYTHINGY!<<"
        )

        self.assert_compile(
            select([MyThingy('x'), MyThingy('y')]).where(MyThingy() == 5),
            "SELECT >>x<<, >>y<< WHERE >>MYTHINGY!<< = :MYTHINGY!_1"
        )

    def test_create_column_skip(self):
        """A CreateColumn handler returning None omits that column."""

        @compiles(CreateColumn)
        def skip_xmin(element, compiler, **kw):
            if element.element.name == 'xmin':
                return None
            return compiler.visit_create_column(element, **kw)

        # CreateColumn is a shared library class; remove the handler again
        # so it cannot affect tests that run after this one.
        try:
            t = Table('t', MetaData(), Column('a', Integer),
                      Column('xmin', Integer),
                      Column('c', Integer))

            self.assert_compile(
                CreateTable(t),
                "CREATE TABLE t (a INTEGER, c INTEGER)"
            )
        finally:
            deregister(CreateColumn)

    def test_types(self):
        """A custom type can have a different handler per dialect."""

        class MyType(TypeEngine):
            pass

        @compiles(MyType, 'sqlite')
        def visit_type_sqlite(type_, compiler, **kw):
            return "SQLITE_FOO"

        @compiles(MyType, 'postgresql')
        def visit_type_postgresql(type_, compiler, **kw):
            return "POSTGRES_FOO"

        from sqlalchemy.dialects.sqlite import base as sqlite
        from sqlalchemy.dialects.postgresql import base as postgresql

        self.assert_compile(
            MyType(),
            "SQLITE_FOO",
            dialect=sqlite.dialect()
        )

        self.assert_compile(
            MyType(),
            "POSTGRES_FOO",
            dialect=postgresql.dialect()
        )

    def test_stateful(self):
        """A handler may keep per-compilation state on the compiler."""

        class MyThingy(ColumnClause):
            def __init__(self):
                super(MyThingy, self).__init__('MYTHINGY!')

        @compiles(MyThingy)
        def visit_thingy(thingy, compiler, **kw):
            # The counter lives on the compiler object, so the expected
            # strings below restart at 1 for each statement compiled.
            compiler.counter = getattr(compiler, 'counter', 0) + 1
            return str(compiler.counter)

        self.assert_compile(
            select([column('foo'), MyThingy()]).order_by(desc(MyThingy())),
            "SELECT foo, 1 ORDER BY 2 DESC"
        )

        self.assert_compile(
            select([MyThingy(), MyThingy()]).where(MyThingy() == 5),
            "SELECT 1, 2 WHERE 3 = :MYTHINGY!_1"
        )

    def test_callout_to_compiler(self):
        """A handler can call ``compiler.process()`` on nested elements."""

        class InsertFromSelect(ClauseElement):
            def __init__(self, table, select):
                self.table = table
                self.select = select

        @compiles(InsertFromSelect)
        def visit_insert_from_select(element, compiler, **kw):
            return "INSERT INTO %s (%s)" % (
                compiler.process(element.table, asfrom=True),
                compiler.process(element.select)
            )

        t1 = table("mytable", column('x'), column('y'), column('z'))
        self.assert_compile(
            InsertFromSelect(
                t1,
                select([t1]).where(t1.c.x > 5)
            ),
            "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z "
            "FROM mytable WHERE mytable.x > :x_1)"
        )

    def test_no_default_but_has_a_visit(self):
        """With only a postgresql handler, ``str()`` still yields "x"."""

        class MyThingy(ColumnClause):
            pass

        @compiles(MyThingy, 'postgresql')
        def visit_thingy(thingy, compiler, **kw):
            return "mythingy"

        eq_(str(MyThingy('x')), "x")

    def test_no_default_has_no_visit(self):
        """A TypeEngine with only a postgresql handler fails on ``str()``."""

        class MyThingy(TypeEngine):
            pass

        @compiles(MyThingy, 'postgresql')
        def visit_thingy(thingy, compiler, **kw):
            return "mythingy"

        assert_raises_message(
            exc.CompileError,
            "<class 'test.ext.test_compiler..*MyThingy'> "
            "construct has no default compilation handler.",
            str, MyThingy()
        )

    def test_no_default_message(self):
        """A ClauseElement with only a postgresql handler fails on ``str()``.

        Same expectation as ``test_no_default_has_no_visit``, but for a
        bare ClauseElement subclass instead of a TypeEngine subclass.
        """

        class MyThingy(ClauseElement):
            pass

        @compiles(MyThingy, 'postgresql')
        def visit_thingy(thingy, compiler, **kw):
            return "mythingy"

        assert_raises_message(
            exc.CompileError,
            "<class 'test.ext.test_compiler..*MyThingy'> "
            "construct has no default compilation handler.",
            str, MyThingy()
        )

    def test_default_subclass(self):
        """A sqlite-only handler leaves postgresql rendering unchanged."""
        from sqlalchemy.types import ARRAY

        class MyArray(ARRAY):
            pass

        @compiles(MyArray, "sqlite")
        def sl_array(elem, compiler, **kw):
            return "array"

        self.assert_compile(
            MyArray(Integer),
            "INTEGER[]",
            dialect="postgresql"
        )

    def test_annotations(self):
        """Test that annotated clause constructs use the decorated
        class' compiler.

        The handler is attached to ``Select`` itself, so the previous
        dispatch is saved up front and put back in ``finally``.
        """

        t1 = table('t1', column('c1'), column('c2'))

        dispatch = Select._compiler_dispatch
        try:
            @compiles(Select)
            def compile_select(element, compiler, **kw):
                return "OVERRIDE"

            s1 = select([t1])
            self.assert_compile(
                s1, "OVERRIDE"
            )
            self.assert_compile(
                s1._annotate({}),
                "OVERRIDE"
            )
        finally:
            Select._compiler_dispatch = dispatch
            if hasattr(Select, '_compiler_dispatcher'):
                del Select._compiler_dispatcher

    def test_dialect_specific(self):
        """Dialect-specific and default handlers coexist on DDL elements.

        Handlers are registered in both orders (dialect-specific before
        the default, and after it) and each combination is then compiled.
        """

        class AddThingy(DDLElement):
            __visit_name__ = 'add_thingy'

        class DropThingy(DDLElement):
            __visit_name__ = 'drop_thingy'

        # AddThingy: the sqlite handler is registered before the default.
        @compiles(AddThingy, 'sqlite')
        def visit_add_thingy_sqlite(thingy, compiler, **kw):
            return "ADD SPECIAL SL THINGY"

        @compiles(AddThingy)
        def visit_add_thingy_default(thingy, compiler, **kw):
            return "ADD THINGY"

        @compiles(DropThingy)
        def visit_drop_thingy_default(thingy, compiler, **kw):
            return "DROP THINGY"

        self.assert_compile(AddThingy(), "ADD THINGY")
        self.assert_compile(DropThingy(), "DROP THINGY")

        from sqlalchemy.dialects.sqlite import base
        self.assert_compile(AddThingy(),
                            "ADD SPECIAL SL THINGY",
                            dialect=base.dialect())

        self.assert_compile(DropThingy(),
                            "DROP THINGY",
                            dialect=base.dialect())

        # DropThingy: the sqlite handler is added after the default one
        # has already been used.
        @compiles(DropThingy, 'sqlite')
        def visit_drop_thingy_sqlite(thingy, compiler, **kw):
            return "DROP SPECIAL SL THINGY"

        self.assert_compile(DropThingy(),
                            "DROP SPECIAL SL THINGY",
                            dialect=base.dialect())

        self.assert_compile(DropThingy(), "DROP THINGY")

    def test_functions(self):
        """A FunctionElement can render differently per dialect."""
        from sqlalchemy.dialects import postgresql

        class MyUtcFunction(FunctionElement):
            pass

        @compiles(MyUtcFunction)
        def visit_myfunc_default(element, compiler, **kw):
            return "utcnow()"

        @compiles(MyUtcFunction, 'postgresql')
        def visit_myfunc_postgresql(element, compiler, **kw):
            return "timezone('utc', current_timestamp)"

        self.assert_compile(
            MyUtcFunction(),
            "utcnow()",
            use_default_dialect=True
        )
        self.assert_compile(
            MyUtcFunction(),
            "timezone('utc', current_timestamp)",
            dialect=postgresql.dialect()
        )

    def test_function_calls_base(self):
        """The default handler delegates to ``compiler.visit_function()``.

        The mssql handler instead renders a CASE expression built from
        the function's two arguments.
        """
        from sqlalchemy.dialects import mssql

        class greatest(FunctionElement):
            type = Numeric()
            name = 'greatest'

        @compiles(greatest)
        def default_greatest(element, compiler, **kw):
            return compiler.visit_function(element)

        @compiles(greatest, 'mssql')
        def case_greatest(element, compiler, **kw):
            arg1, arg2 = list(element.clauses)
            return "CASE WHEN %s > %s THEN %s ELSE %s END" % (
                compiler.process(arg1),
                compiler.process(arg2),
                compiler.process(arg1),
                compiler.process(arg2),
            )

        self.assert_compile(
            greatest('a', 'b'),
            'greatest(:greatest_1, :greatest_2)',
            use_default_dialect=True
        )
        self.assert_compile(
            greatest('a', 'b'),
            "CASE WHEN :greatest_1 > :greatest_2 "
            "THEN :greatest_1 ELSE :greatest_2 END",
            dialect=mssql.dialect()
        )

    def test_subclasses_one(self):
        """A subclass without its own handler uses the base class's."""

        class Base(FunctionElement):
            name = 'base'

        class Sub1(Base):
            name = 'sub1'

        class Sub2(Base):
            name = 'sub2'

        @compiles(Base)
        def visit_base(element, compiler, **kw):
            return element.name

        @compiles(Sub1)
        def visit_sub1(element, compiler, **kw):
            return "FOO" + element.name

        self.assert_compile(
            select([Sub1(), Sub2()]),
            'SELECT FOOsub1, sub2',
            use_default_dialect=True
        )

    def test_subclasses_two(self):
        """A handler added to a subclass later also affects its subclasses.

        ``Sub2`` and ``SubSub1`` are declared after the ``Base`` handler
        is registered; the ``Sub1`` handler is added only after the first
        compilation has been asserted.
        """

        class Base(FunctionElement):
            name = 'base'

        class Sub1(Base):
            name = 'sub1'

        @compiles(Base)
        def visit_base(element, compiler, **kw):
            return element.name

        class Sub2(Base):
            name = 'sub2'

        class SubSub1(Sub1):
            name = 'subsub1'

        self.assert_compile(
            select([Sub1(), Sub2(), SubSub1()]),
            'SELECT sub1, sub2, subsub1',
            use_default_dialect=True
        )

        @compiles(Sub1)
        def visit_sub1(element, compiler, **kw):
            return "FOO" + element.name

        self.assert_compile(
            select([Sub1(), Sub2(), SubSub1()]),
            'SELECT FOOsub1, sub2, FOOsubsub1',
            use_default_dialect=True
        )
class DefaultOnExistingTest(fixtures.TestBase, AssertsCompiledSQL):
    """Test replacement of default compilation on existing constructs."""

    __dialect__ = 'default'

    def teardown(self):
        """Deregister the handlers attached to Select and BindParameter."""
        for cls in (Select, BindParameter):
            deregister(cls)

    def test_select(self):
        """A sqlite-only Select handler leaves default rendering alone."""
        t1 = table('t1', column('c1'), column('c2'))

        @compiles(Select, 'sqlite')
        def compile_select(element, compiler, **kw):
            return "OVERRIDE"

        s1 = select([t1])
        self.assert_compile(
            s1, "SELECT t1.c1, t1.c2 FROM t1",
        )

        from sqlalchemy.dialects.sqlite import base as sqlite
        self.assert_compile(
            s1, "OVERRIDE",
            dialect=sqlite.dialect()
        )

    def test_binds_in_select(self):
        """A BindParameter handler wraps the bind rendered in a SELECT."""
        t = table('t',
                  column('a'),
                  column('b'),
                  column('c'))

        @compiles(BindParameter)
        def gen_bind(element, compiler, **kw):
            # Delegate to the stock rendering, then wrap its output.
            return "BIND(%s)" % compiler.visit_bindparam(element, **kw)

        self.assert_compile(
            t.select().where(t.c.c == 5),
            "SELECT t.a, t.b, t.c FROM t WHERE t.c = BIND(:c_1)",
            use_default_dialect=True
        )

    def test_binds_in_dml(self):
        """A BindParameter handler wraps the binds rendered in an INSERT."""
        t = table('t',
                  column('a'),
                  column('b'),
                  column('c'))

        @compiles(BindParameter)
        def gen_bind(element, compiler, **kw):
            # Delegate to the stock rendering, then wrap its output.
            return "BIND(%s)" % compiler.visit_bindparam(element, **kw)

        self.assert_compile(
            t.insert(),
            "INSERT INTO t (a, b) VALUES (BIND(:a), BIND(:b))",
            {'a': 1, 'b': 2},
            use_default_dialect=True
        )