mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-08 18:00:23 -04:00
394 lines
12 KiB
Python
394 lines
12 KiB
Python
from sqlalchemy.testing import eq_
|
|
from sqlalchemy import testing
|
|
from sqlalchemy.testing.schema import Table, Column
|
|
from sqlalchemy.types import TypeDecorator
|
|
from sqlalchemy.testing import fixtures, AssertsExecutionResults, engines, \
|
|
assert_raises_message
|
|
from sqlalchemy import exc as sa_exc
|
|
from sqlalchemy import MetaData, String, Integer, Boolean, func, select, \
|
|
Sequence
|
|
import itertools
|
|
|
|
# Module-level placeholders.  The setup() methods of the test classes in
# this module rebind these through ``global`` statements; declaring them
# here keeps the names defined at import time.
table = None
GoofyType = None
seq = None
|
|
|
|
|
|
class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
    """Round-trip tests for explicit returning() on INSERT/UPDATE/DELETE.

    setup() creates a fresh 'tables' table before every test and
    teardown() drops it again, which is why the tests can assert that
    the first inserted row receives id 1.
    """

    __requires__ = 'returning',
    __backend__ = True

    def setup(self):
        """Create the 'tables' table and publish it as a module global."""
        meta = MetaData(testing.db)
        global table, GoofyType

        class GoofyType(TypeDecorator):
            """String type that marks values on the way in and out.

            Bound values get a "FOO" prefix, result values get a "BAR"
            suffix; None passes through unchanged in both directions.
            """

            impl = String

            def process_bind_param(self, value, dialect):
                if value is None:
                    return None
                return "FOO" + value

            def process_result_value(self, value, dialect):
                if value is None:
                    return None
                return value + "BAR"

        table = Table(
            'tables', meta,
            Column(
                'id', Integer, primary_key=True,
                test_needs_autoincrement=True),
            Column('persons', Integer),
            Column('full', Boolean),
            Column('goofy', GoofyType(50)))
        table.create(checkfirst=True)

    def teardown(self):
        """Drop the table created by setup()."""
        table.drop()

    def test_column_targeting(self):
        """RETURNING rows are addressable by Column object and by name."""
        result = table.insert().returning(
            table.c.id, table.c.full).execute({'persons': 1, 'full': False})

        row = result.first()
        assert row[table.c.id] == row['id'] == 1
        assert row[table.c.full] == row['full']
        assert row['full'] is False

        result = table.insert().values(
            persons=5, full=True, goofy="somegoofy").\
            returning(table.c.persons, table.c.full, table.c.goofy).execute()
        row = result.first()
        assert row[table.c.persons] == row['persons'] == 5
        assert row[table.c.full] == row['full']

        # the value went through both GoofyType conversions
        eq_(row[table.c.goofy], row['goofy'])
        eq_(row['goofy'], "FOOsomegoofyBAR")

    @testing.fails_on('firebird', "fb can't handle returning x AS y")
    def test_labeling(self):
        """A labeled column in returning() is retrievable by its label."""
        result = table.insert().values(persons=6).\
            returning(table.c.persons.label('lala')).execute()
        row = result.first()
        assert row['lala'] == 6

    @testing.fails_on(
        'firebird',
        "fb/kintersbasdb can't handle the bind params")
    @testing.fails_on('oracle+zxjdbc', "JDBC driver bug")
    def test_anon_expressions(self):
        """Unnamed SQL expressions in returning() are fetched by position."""
        result = table.insert().values(goofy="someOTHERgoofy").\
            returning(func.lower(table.c.goofy, type_=GoofyType)).execute()
        row = result.first()
        # "FOO" was added on bind and lower-cased by the database;
        # "BAR" is appended afterwards by the result processor.
        eq_(row[0], "foosomeothergoofyBAR")

        result = table.insert().values(persons=12).\
            returning(table.c.persons + 18).execute()
        row = result.first()
        eq_(row[0], 30)

    def test_update_returning(self):
        """UPDATE..RETURNING yields only the ids of the rows it matched."""
        table.insert().execute(
            [{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])

        result = table.update(
            table.c.persons > 4, dict(
                full=True)).returning(
            table.c.id).execute()
        eq_(result.fetchall(), [(1,)])

        result2 = select([table.c.id, table.c.full]).order_by(
            table.c.id).execute()
        eq_(result2.fetchall(), [(1, True), (2, False)])

    def test_insert_returning(self):
        """INSERT..RETURNING hands back the newly generated id."""
        result = table.insert().returning(
            table.c.id).execute({'persons': 1, 'full': False})

        eq_(result.fetchall(), [(1,)])

    @testing.requires.multivalues_inserts
    def test_multirow_returning(self):
        """A multi-VALUES INSERT returns one row per inserted record."""
        ins = table.insert().returning(table.c.id, table.c.persons).values(
            [
                {'persons': 1, 'full': False},
                {'persons': 2, 'full': True},
                {'persons': 3, 'full': False},
            ]
        )
        result = testing.db.execute(ins)
        eq_(
            result.fetchall(),
            [(1, 1), (2, 2), (3, 3)]
        )

    def test_no_ipk_on_returning(self):
        """inserted_primary_key raises once returning() has been used."""
        result = testing.db.execute(
            table.insert().returning(table.c.id),
            {'persons': 1, 'full': False}
        )
        assert_raises_message(
            sa_exc.InvalidRequestError,
            # raw string: the pattern escapes the parentheses with
            # backslashes, which are invalid escapes in a plain literal
            r"Can't call inserted_primary_key when returning\(\) is used.",
            getattr, result, "inserted_primary_key"
        )

    @testing.fails_on_everything_except('postgresql', 'firebird')
    def test_literal_returning(self):
        """A textual INSERT with a RETURNING clause yields result rows."""
        if testing.against("postgresql"):
            literal_true = "true"
        else:
            literal_true = "1"

        # literal_true is one of the two constants above, never external
        # input, so interpolating it into the statement text is safe.
        result4 = testing.db.execute(
            'insert into tables (id, persons, "full") '
            'values (5, 10, %s) returning persons' %
            literal_true)
        eq_([dict(row) for row in result4], [{'persons': 10}])

    def test_delete_returning(self):
        """DELETE..RETURNING yields the ids of the rows it removed."""
        table.insert().execute(
            [{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])

        result = table.delete(
            table.c.persons > 4).returning(
            table.c.id).execute()
        eq_(result.fetchall(), [(1,)])

        result2 = select([table.c.id, table.c.full]).order_by(
            table.c.id).execute()
        eq_(result2.fetchall(), [(2, False), ])
|
|
|
|
|
|
class SequenceReturningTest(fixtures.TestBase):
    """Check returning() on a table whose primary key uses a Sequence."""

    __backend__ = True
    __requires__ = 'returning', 'sequences'

    def setup(self):
        """Build the sequence-backed 'tables' table as a module global."""
        global table, seq
        seq = Sequence('tid_seq')
        metadata = MetaData(testing.db)
        id_column = Column('id', Integer, seq, primary_key=True)
        data_column = Column('data', String(50))
        table = Table('tables', metadata, id_column, data_column)
        table.create(checkfirst=True)

    def teardown(self):
        """Drop the table created by setup()."""
        table.drop()

    def test_insert(self):
        """The INSERT takes value 1 from the sequence; the next is 2."""
        insert_stmt = table.insert().values(data='hi').returning(table.c.id)
        returned = insert_stmt.execute()
        assert returned.first() == (1, )
        assert seq.execute() == 2
|
|
|
|
|
|
class KeyReturningTest(fixtures.TestBase, AssertsExecutionResults):

    """test returning() works with columns that define 'key'."""

    __requires__ = 'returning',
    __backend__ = True

    def setup(self):
        """Create 'tables' with an 'id' column exposed under key 'foo_id'."""
        meta = MetaData(testing.db)
        global table

        table = Table(
            'tables',
            meta,
            Column(
                'id',
                Integer,
                primary_key=True,
                key='foo_id',
                test_needs_autoincrement=True),
            Column(
                'data',
                String(20)),
        )
        table.create(checkfirst=True)

    def teardown(self):
        """Drop the table created by setup()."""
        table.drop()

    @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
    @testing.exclude('postgresql', '<', (8, 2), '8.2+ feature')
    def test_insert(self):
        """The keyed column resolves by Column object and by DB name.

        Checked first on the row produced by INSERT..RETURNING and then
        on the same row read back with a plain SELECT.
        """
        result = table.insert().returning(
            table.c.foo_id).execute(
            data='somedata')
        row = result.first()
        assert row[table.c.foo_id] == row['id'] == 1

        # Rebind ``row`` to the selected row; previously the SELECT result
        # was stored under another name and the stale INSERT row was
        # asserted a second time, so the SELECT was never verified.
        row = table.select().execute().first()
        assert row[table.c.foo_id] == row['id'] == 1
|
|
|
|
|
|
class ReturnDefaultsTest(fixtures.TablesTest):
    """Exercise return_defaults() on INSERT and UPDATE against table t1."""

    __backend__ = True
    __requires__ = ('returning', )
    run_define_tables = 'each'

    @classmethod
    def define_tables(cls, metadata):
        """Define t1, whose insdef/upddef defaults render a running counter.

        The counter is created inside this method, so every call to
        define_tables starts counting from 0 again.
        """
        from sqlalchemy.sql import ColumnElement
        from sqlalchemy.ext.compiler import compiles

        running_number = itertools.count()

        class IncDefault(ColumnElement):
            pass

        @compiles(IncDefault)
        def render_inc_default(element, compiler, **kw):
            # each compilation emits the next integer as literal SQL text
            return str(next(running_number))

        id_col = Column(
            "id", Integer, primary_key=True,
            test_needs_autoincrement=True)
        data_col = Column("data", String(50))
        insdef_col = Column("insdef", Integer, default=IncDefault())
        upddef_col = Column("upddef", Integer, onupdate=IncDefault())
        Table("t1", metadata, id_col, data_col, insdef_col, upddef_col)

    def test_chained_insert_pk(self):
        """insert().return_defaults(col) returns the pk and the default."""
        t1 = self.tables.t1
        stmt = t1.insert().values(upddef=1).return_defaults(t1.c.insdef)
        outcome = testing.db.execute(stmt)
        fetched = [
            outcome.returned_defaults[col] for col in (t1.c.id, t1.c.insdef)
        ]
        eq_(fetched, [1, 0])

    def test_arg_insert_pk(self):
        """insert(return_defaults=[col]) returns the pk and the default."""
        t1 = self.tables.t1
        stmt = t1.insert(return_defaults=[t1.c.insdef]).values(upddef=1)
        outcome = testing.db.execute(stmt)
        fetched = [
            outcome.returned_defaults[col] for col in (t1.c.id, t1.c.insdef)
        ]
        eq_(fetched, [1, 0])

    def test_chained_update_pk(self):
        """update().return_defaults(col) returns the onupdate value."""
        t1 = self.tables.t1
        seed_stmt = t1.insert().values(upddef=1)
        testing.db.execute(seed_stmt)

        stmt = t1.update().values(data='d1').return_defaults(t1.c.upddef)
        outcome = testing.db.execute(stmt)
        fetched = [outcome.returned_defaults[col] for col in (t1.c.upddef,)]
        eq_(fetched, [1])

    def test_arg_update_pk(self):
        """update(return_defaults=[col]) returns the onupdate value."""
        t1 = self.tables.t1
        seed_stmt = t1.insert().values(upddef=1)
        testing.db.execute(seed_stmt)

        stmt = t1.update(return_defaults=[t1.c.upddef]).values(data='d1')
        outcome = testing.db.execute(stmt)
        fetched = [outcome.returned_defaults[col] for col in (t1.c.upddef,)]
        eq_(fetched, [1])

    def test_insert_non_default(self):
        """A column with no default at all can be requested on INSERT."""

        t1 = self.tables.t1
        stmt = t1.insert().values(upddef=1).return_defaults(t1.c.data)
        outcome = testing.db.execute(stmt)
        fetched = [
            outcome.returned_defaults[col] for col in (t1.c.id, t1.c.data,)
        ]
        eq_(fetched, [1, None])

    def test_update_non_default(self):
        """A column with no default at all can be requested on UPDATE."""

        t1 = self.tables.t1
        seed_stmt = t1.insert().values(upddef=1)
        testing.db.execute(seed_stmt)

        stmt = t1.update().values(upddef=2).return_defaults(t1.c.data)
        outcome = testing.db.execute(stmt)
        fetched = [outcome.returned_defaults[col] for col in (t1.c.data,)]
        eq_(fetched, [None])

    @testing.fails_on("oracle+cx_oracle", "seems like a cx_oracle bug")
    def test_insert_non_default_plus_default(self):
        """INSERT may return a plain column together with a defaulted one."""
        t1 = self.tables.t1
        stmt = t1.insert().values(upddef=1).return_defaults(
            t1.c.data, t1.c.insdef)
        outcome = testing.db.execute(stmt)
        expected = {"id": 1, "data": None, "insdef": 0}
        eq_(dict(outcome.returned_defaults), expected)

    @testing.fails_on("oracle+cx_oracle", "seems like a cx_oracle bug")
    def test_update_non_default_plus_default(self):
        """UPDATE may return a plain column together with a defaulted one."""
        t1 = self.tables.t1
        seed_stmt = t1.insert().values(upddef=1)
        testing.db.execute(seed_stmt)

        stmt = t1.update().values(insdef=2).return_defaults(
            t1.c.data, t1.c.upddef)
        outcome = testing.db.execute(stmt)
        expected = {"data": None, 'upddef': 1}
        eq_(dict(outcome.returned_defaults), expected)
|
|
|
|
|
|
class ImplicitReturningFlag(fixtures.TestBase):
    """Check dialect.implicit_returning before and after a first connect."""

    __backend__ = True

    def _assert_explicit_flag_kept(self, flag):
        """Assert an explicitly passed flag is unchanged by connecting."""
        engine = engines.testing_engine(options={'implicit_returning': flag})
        assert engine.dialect.implicit_returning is flag
        conn = engine.connect()
        conn.close()
        assert engine.dialect.implicit_returning is flag

    def test_flag_turned_off(self):
        """implicit_returning=False stays False across a connect."""
        self._assert_explicit_flag_kept(False)

    def test_flag_turned_on(self):
        """implicit_returning=True stays True across a connect."""
        self._assert_explicit_flag_kept(True)

    def test_flag_turned_default(self):
        """With no option given, the flag is only settled on connect."""
        returning_supported = [False]

        def mark_supported():
            returning_supported[0] = True

        # Run the marker through the 'returning' requirement.
        # NOTE(review): presumably the wrapper only lets it run on
        # backends that support RETURNING -- confirm in testing.requires.
        testing.requires.returning(mark_supported)()
        engine = engines.testing_engine()

        # The flag begins as False: Firebird, Postgresql, Oracle and
        # SQL Server each gained RETURNING only as of some version, and
        # the flag is not set until version detection has run.  Should a
        # database appear that has RETURNING in all cases, this test can
        # be adjusted.
        assert engine.dialect.implicit_returning is False

        # connecting triggers version detection, which sets the flag
        conn = engine.connect()
        conn.close()
        assert engine.dialect.implicit_returning is returning_supported[0]
|