Files
sqlalchemy/test/ext/test_serializer.py
Mike Bayer 239f629b9a update pickle tests
Since I want to get rid of util.portable_instancemethod, first
make sure we are testing pickle extensively including going through
all protocols for all metadata-oriented tests.

Change-Id: I0064bc16033939780e50c7a8a4ede60ef5835b38
2025-06-11 15:19:23 -04:00

374 lines
12 KiB
Python

import pickle
from sqlalchemy import desc
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Integer
from sqlalchemy import join
from sqlalchemy import literal_column
from sqlalchemy import MetaData
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy.ext import serializer
from sqlalchemy.orm import aliased
from sqlalchemy.orm import class_mapper
from sqlalchemy.orm import column_property
from sqlalchemy.orm import configure_mappers
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import relationship
from sqlalchemy.orm import scoped_session
from sqlalchemy.orm import sessionmaker
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import combinations
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
def pickle_protocols():
    """Return the pickle protocol numbers the round-trip tests iterate over.

    The values run from ``-2`` up to, but not including,
    ``pickle.HIGHEST_PROTOCOL``.  :mod:`pickle` treats any negative
    protocol number as ``HIGHEST_PROTOCOL``, so the highest protocol is
    still exercised (through ``-2`` and ``-1``) even though its own number
    is excluded by the half-open range.
    """
    lowest = -2
    stop = pickle.HIGHEST_PROTOCOL
    return range(lowest, stop)
class User(ComparableEntity):
    """Plain class mapped imperatively onto the ``users`` table.

    Kept at module level, presumably so the serializer can pickle it by
    reference -- confirm against ``sqlalchemy.ext.serializer``.
    """
class Address(ComparableEntity):
    """Plain class mapped imperatively onto the ``addresses`` table.

    Kept at module level, presumably so the serializer can pickle it by
    reference -- confirm against ``sqlalchemy.ext.serializer``.
    """
class SerializeTest(AssertsCompiledSQL, fixtures.MappedTest):
    """Round-trip tables, columns, mappers, SQL expressions and ORM queries
    through :mod:`sqlalchemy.ext.serializer` and check the results.

    Mappers and fixture rows are set up once for the whole class, and rows
    are never deleted between tests.
    """

    run_setup_mappers = "once"
    run_inserts = "once"
    run_deletes = None

    @classmethod
    def define_tables(cls, metadata):
        # Publish the tables as module globals; the mapped classes and the
        # tests below refer to them by these names.
        global users, addresses

        users = Table(
            "users",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("name", String(50)),
        )
        addresses = Table(
            "addresses",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("email", String(50)),
            Column("user_id", Integer, ForeignKey("users.id")),
        )

    @classmethod
    def setup_mappers(cls):
        # Module-level scoped session shared by every test in this class.
        global Session
        Session = scoped_session(sessionmaker(testing.db))

        user_properties = {
            "addresses": relationship(
                Address, backref="user", order_by=addresses.c.id
            )
        }
        cls.mapper_registry.map_imperatively(
            User, users, properties=user_properties
        )
        cls.mapper_registry.map_imperatively(Address, addresses)
        configure_mappers()

    @classmethod
    def insert_data(cls, connection):
        user_rows = [
            {"id": user_id, "name": name}
            for user_id, name in (
                (7, "jack"),
                (8, "ed"),
                (9, "fred"),
                (10, "chuck"),
            )
        ]
        address_rows = [
            {"id": address_id, "user_id": owner_id, "email": email}
            for address_id, owner_id, email in (
                (1, 7, "jack@bean.com"),
                (2, 8, "ed@wood.com"),
                (3, 8, "ed@bettyboop.com"),
                (4, 8, "ed@lala.com"),
                (5, 9, "fred@fred.com"),
            )
        ]
        connection.execute(users.insert(), user_rows)
        connection.execute(addresses.insert(), address_rows)

    def test_tables(self):
        payload = serializer.dumps(users, -1)
        restored = serializer.loads(payload, users.metadata, Session)
        # deserialization must resolve back to the very same Table object
        assert restored is users

    def test_columns(self):
        payload = serializer.dumps(users.c.name, -1)
        restored = serializer.loads(payload, users.metadata, Session)
        assert restored is users.c.name

    def test_mapper(self):
        mapper = class_mapper(User)
        payload = serializer.dumps(mapper, -1)
        assert serializer.loads(payload, None, None) is mapper

    def test_attribute(self):
        payload = serializer.dumps(User.name, -1)
        assert serializer.loads(payload, None, None) is User.name

    def test_expression(self):
        stmt = select(users).select_from(users.join(addresses)).limit(5)
        restored = serializer.loads(
            serializer.dumps(stmt, -1), users.metadata, None
        )
        eq_(str(stmt), str(restored))

        rows = Session.connection().execute(restored).fetchall()
        eq_(
            rows,
            [(7, "jack"), (8, "ed"), (8, "ed"), (8, "ed"), (9, "fred")],
        )

    def test_query_one(self):
        eager_query = (
            Session.query(User)
            .filter(User.name == "ed")
            .options(joinedload(User.addresses))
        )
        restored = serializer.loads(
            serializer.dumps(eager_query, -1), users.metadata, Session
        )

        def go():
            loaded = restored.all()
            expected = [
                User(
                    name="ed",
                    addresses=[
                        Address(id=2),
                        Address(id=3),
                        Address(id=4),
                    ],
                )
            ]
            eq_(loaded, expected)

        # the joinedload option must survive the round trip: one SELECT only
        self.assert_sql_count(testing.db, go, 1)

        match_count = (
            restored.join(User.addresses)
            .filter(Address.email == "ed@bettyboop.com")
            .enable_eagerloads(False)
            .with_entities(func.count(literal_column("*")))
            .scalar()
        )
        eq_(match_count, 1)

        ed = Session.get(User, 8)
        address_query = (
            Session.query(Address)
            .filter(Address.user == ed)
            .order_by(desc(Address.email))
        )
        restored_addresses = serializer.loads(
            serializer.dumps(address_query, -1), users.metadata, Session
        )
        loaded_addresses = restored_addresses.all()
        eq_(
            loaded_addresses,
            [
                Address(email="ed@wood.com"),
                Address(email="ed@lala.com"),
                Address(email="ed@bettyboop.com"),
            ],
        )

    def test_query_two(self):
        join_query = (
            Session.query(User)
            .join(User.addresses)
            .filter(Address.email.like("%fred%"))
        )
        restored = serializer.loads(
            serializer.dumps(join_query, -1), users.metadata, Session
        )
        eq_(restored.all(), [User(name="fred")])

        id_name_rows = list(restored.with_entities(User.id, User.name))
        eq_(id_name_rows, [(9, "fred")])

    def test_query_three(self):
        user_alias = aliased(User)
        alias_query = (
            Session.query(user_alias)
            .join(user_alias.addresses)
            .filter(Address.email.like("%fred%"))
        )
        for protocol in pickle_protocols():
            payload = serializer.dumps(alias_query, protocol)
            restored = serializer.loads(payload, users.metadata, Session)
            eq_(restored.all(), [User(name="fred")])

            # recover the aliased entity from the deserialized query and
            # make sure its columns are usable in a new query
            restored_alias = (
                restored._compile_state()._entities[0].entity_zero.entity
            )
            id_name_rows = list(
                restored.with_entities(restored_alias.id, restored_alias.name)
            )
            eq_(id_name_rows, [(9, "fred")])

    def test_annotated_one(self):
        annotated_join = join(users, addresses)._annotate({"foo": "bar"})
        stmt = select(addresses).select_from(annotated_join)

        # Compile the statement before pickling; presumably this leaves
        # memoized state on the annotated join that the serializer must
        # cope with (regression guard -- TODO confirm the original issue).
        str(stmt)
        for protocol in pickle_protocols():
            payload = serializer.dumps(annotated_join, protocol)
            serializer.loads(payload, users.metadata, None)

    def test_orm_join(self):
        from sqlalchemy.orm import join as orm_join

        orm_j = orm_join(User, Address, User.addresses)
        restored = serializer.loads(
            serializer.dumps(orm_j, -1), users.metadata
        )
        assert restored.left is orm_j.left
        assert restored.right is orm_j.right

    @testing.exclude(
        "sqlite", "<=", (3, 5, 9), "id comparison failing on the buildbot"
    )
    def test_aliases(self):
        u7, u8, u9, u10 = Session.query(User).order_by(User.id).all()

        ualias = aliased(User)
        pair_query = (
            Session.query(User, ualias)
            .join(ualias, User.id < ualias.id)
            .filter(User.id < 9)
            .order_by(User.id, ualias.id)
        )
        expected_pairs = [(u7, u8), (u7, u9), (u7, u10), (u8, u9), (u8, u10)]
        eq_(list(pair_query.all()), expected_pairs)

        restored = serializer.loads(
            serializer.dumps(pair_query, -1), users.metadata, Session
        )
        eq_(list(restored.all()), expected_pairs)

    def test_any(self):
        criterion = User.addresses.any(Address.email == "x")
        payload = serializer.dumps(criterion, -1)
        restored = serializer.loads(payload, users.metadata)
        eq_(str(criterion), str(restored))

    def test_unicode(self):
        md = MetaData()
        tbl = Table("\u6e2c\u8a66", md, Column("\u6e2c\u8a66_id", Integer))
        stmt = select(tbl).where(tbl.c["\u6e2c\u8a66_id"] == 5)

        restored = serializer.loads(serializer.dumps(stmt, -1), md)
        self.assert_compile(
            restored,
            'SELECT "\u6e2c\u8a66"."\u6e2c\u8a66_id" FROM "\u6e2c\u8a66" '
            'WHERE "\u6e2c\u8a66"."\u6e2c\u8a66_id" = :\u6e2c\u8a66_id_1',
            dialect="default",
        )

    @combinations(
        (
            lambda: func.max(users.c.name).over(range_=(None, 0)),
            "max(users.name) OVER (RANGE BETWEEN UNBOUNDED "
            "PRECEDING AND CURRENT ROW)",
        ),
        (
            lambda: func.max(users.c.name).over(range_=(0, None)),
            "max(users.name) OVER (RANGE BETWEEN CURRENT "
            "ROW AND UNBOUNDED FOLLOWING)",
        ),
        (
            lambda: func.max(users.c.name).over(rows=(None, 0)),
            "max(users.name) OVER (ROWS BETWEEN UNBOUNDED "
            "PRECEDING AND CURRENT ROW)",
        ),
        (
            lambda: func.max(users.c.name).over(rows=(0, None)),
            "max(users.name) OVER (ROWS BETWEEN CURRENT "
            "ROW AND UNBOUNDED FOLLOWING)",
        ),
        (
            lambda: func.max(users.c.name).over(groups=(None, 0)),
            "max(users.name) OVER (GROUPS BETWEEN UNBOUNDED "
            "PRECEDING AND CURRENT ROW)",
        ),
        (
            lambda: func.max(users.c.name).over(groups=(0, None)),
            "max(users.name) OVER (GROUPS BETWEEN CURRENT "
            "ROW AND UNBOUNDED FOLLOWING)",
        ),
    )
    def test_over(self, over_fn, sql):
        # the window frame must compile identically before and after
        # a round trip through the serializer (default pickle protocol)
        window_expr = over_fn()
        self.assert_compile(window_expr, sql)

        restored = serializer.loads(
            serializer.dumps(window_expr), users.metadata
        )
        self.assert_compile(restored, sql)
class ColumnPropertyWParamTest(
    AssertsCompiledSQL, fixtures.DeclarativeMappedTest
):
    """Serialize a query whose ``column_property`` embeds a bound parameter
    and compare the compiled SQL before and after the round trip.
    """

    __dialect__ = "default"
    run_create_tables = None

    @classmethod
    def setup_classes(cls):
        Base = cls.DeclarativeBasic

        # Bind the mapped class at module level; presumably required so the
        # serializer can pickle it by reference -- confirm.
        global TestTable

        class TestTable(Base):
            __tablename__ = "test"

            id = Column(Integer, primary_key=True, autoincrement=True)
            _some_id = Column("some_id", String)
            # SQL expression column containing the literal 6 as a bound param
            some_primary_id = column_property(
                func.left(_some_id, 6).cast(Integer)
            )

    def test_deserailize_colprop(self):
        TestTable = self.classes.TestTable

        session = scoped_session(sessionmaker())
        query = session.query(TestTable).filter(
            TestTable.some_primary_id == 123456
        )
        restored = serializer.loads(
            serializer.dumps(query), TestTable.metadata, session
        )

        # Before serialization the column_property and the WHERE clause
        # share a single bound parameter object, so it renders as :left_1
        # in both places.
        self.assert_compile(
            query,
            "SELECT "
            "CAST(left(test.some_id, :left_1) AS INTEGER) AS anon_1, "
            "test.id AS test_id, test.some_id AS test_some_id FROM test WHERE "
            "CAST(left(test.some_id, :left_1) AS INTEGER) = :param_1",
            checkparams={"left_1": 6, "param_1": 123456},
        )

        # After deserialization the two usages are separate parameter
        # objects, so they get distinct anonymous names (:left_1, :left_2)
        # while still carrying the same value.
        self.assert_compile(
            restored,
            "SELECT CAST(left(test.some_id, :left_1) AS INTEGER) AS anon_1, "
            "test.id AS test_id, test.some_id AS test_some_id FROM test WHERE "
            "CAST(left(test.some_id, :left_2) AS INTEGER) = :param_1",
            checkparams={"left_1": 6, "left_2": 6, "param_1": 123456},
        )