Files
peewee/tests/manytomany.py
2026-03-23 08:31:54 -05:00

757 lines
26 KiB
Python

"""
ManyToManyField behavior tests: through models (auto-generated and explicit),
backrefs, inheritance, foreign keys to non-PK columns, foreign keys used as
primary keys, and multiple M2M fields between the same pair of tables.
Test cases appear in this order:
* Core M2M operations (User/Note — the largest test class)
* Backref behavior (Student/Course)
* Inheritance of M2M through models
* FK-to-non-PK M2M (Color/Logo with non-PK FK)
* FK-as-PK M2M (Person/Account/AccountList)
* Multiple M2M between same tables (Permission/Visitor)
* Errors and edge cases (unsaved instances, invalid ManyToManyField
  arguments, adding/removing empty lists).
"""
from peewee import *
from .base import ModelTestCase
from .base import TestModel
from .base import get_in_memory_db
from .base import requires_models
# ---------------------------------------------------------------------------
# Module-local models for M2M tests.
# NOTE: User and Note here are local to this module (not base_models).
# ---------------------------------------------------------------------------
class User(TestModel):
    """Module-local user model; each username must be unique."""
    username = TextField(
        unique=True)
class Note(TestModel):
    """Note related to many users through an auto-generated through model."""
    text = TextField()
    users = ManyToManyField(
        User)


# Through model generated implicitly for ``Note.users``; exposed at module
# level so test cases can list it in ``requires``.
NoteUserThrough = Note.users.get_through_model()
# Placeholder for AltNote's through model: AltThroughModel references AltNote,
# so it can only be declared after AltNote exists.
AltThroughDeferred = DeferredThroughModel()


class AltNote(TestModel):
    """Note whose user relation goes through an explicitly declared model."""
    text = TextField()
    users = ManyToManyField(
        User,
        through_model=AltThroughDeferred)


class AltThroughModel(TestModel):
    """Explicit through model keyed on the (user, note) pair."""
    user = ForeignKeyField(
        User,
        backref='_xx_rel')
    note = ForeignKeyField(
        AltNote,
        backref='_xx_rel')

    class Meta:
        primary_key = CompositeKey('user', 'note')


# Resolve the placeholder now that the concrete through model is declared.
AltThroughDeferred.set_model(AltThroughModel)
class Student(TestModel):
    """Student related to Course through two separate M2M fields."""
    name = TextField(
    )
# Placeholder for the explicit through model of ``Course.students2``, which
# must be declared after Course.
CourseStudentDeferred = DeferredThroughModel()


class Course(TestModel):
    """Course with two M2M relations to Student.

    ``students`` uses an auto-generated through model with its backref
    disabled (``'+'``); ``students2`` uses the explicit CourseStudent2 model.
    """
    name = TextField()
    students = ManyToManyField(
        Student,
        backref='+')
    students2 = ManyToManyField(
        Student,
        through_model=CourseStudentDeferred)


CourseStudent = Course.students.get_through_model()


class CourseStudent2(TestModel):
    """Explicit through model whose foreign keys disable their backrefs."""
    course = ForeignKeyField(
        Course,
        backref='+')
    student = ForeignKeyField(
        Student,
        backref='+')


CourseStudentDeferred.set_model(CourseStudent2)
class Color(TestModel):
    """Color with a unique name, used as the target of a non-PK foreign key."""
    name = TextField(
        unique=True)
# Placeholder for Logo's explicit through model (declared after Logo).
LogoColorDeferred = DeferredThroughModel()


class Logo(TestModel):
    """Logo related to colors through LogoColor, which joins on names."""
    name = TextField(
        unique=True)
    colors = ManyToManyField(
        Color,
        through_model=LogoColorDeferred)


class LogoColor(TestModel):
    """Through model whose foreign keys both target non-PK ``name`` columns."""
    logo = ForeignKeyField(Logo, field=Logo.name)  # FK to non-PK column.
    color = ForeignKeyField(Color, field=Color.name)  # FK to non-PK column.


LogoColorDeferred.set_model(LogoColor)
# ===========================================================================
# Core M2M operations (add, remove, set, clear, prefetch)
# ===========================================================================
class TestManyToMany(ModelTestCase):
    """Core ManyToManyField operations.

    Covers the auto-generated through model (User/Note via NoteUserThrough)
    and an explicit one (AltNote via AltThroughModel): accessor queries,
    prefetching, filtering, assignment, add, remove and clear.
    """
    database = get_in_memory_db()
    requires = [User, Note, NoteUserThrough, AltNote, AltThroughModel]

    # Relationships created by _set_data(): username -> note numbers.
    user_to_note = {
        'gargie': [1, 2],
        'huey': [2, 3],
        'mickey': [3, 4],
        'zaizee': [4, 5],
    }

    def setUp(self):
        """Create users (alphabetically) and notes note-1..note-5, unlinked."""
        super(TestManyToMany, self).setUp()
        for username in sorted(self.user_to_note):
            User.create(username=username)
        for i in range(5):
            Note.create(text='note-%s' % (i + 1))

    def test_through_model(self):
        # The auto-generated through model has an id plus two non-null FKs.
        self.assertEqual(len(NoteUserThrough._meta.fields), 3)
        fields = NoteUserThrough._meta.fields
        self.assertEqual(sorted(fields), ['id', 'note', 'user'])

        note_field = fields['note']
        self.assertEqual(note_field.rel_model, Note)
        self.assertFalse(note_field.null)

        user_field = fields['user']
        self.assertEqual(user_field.rel_model, User)
        self.assertFalse(user_field.null)

    def _set_data(self):
        """Insert one NoteUserThrough row per (user, note) in user_to_note."""
        for username, notes in self.user_to_note.items():
            user = User.get(User.username == username)
            for note in notes:
                NoteUserThrough.create(
                    note=Note.get(Note.text == 'note-%s' % note),
                    user=user)

    def assertNotes(self, query, expected):
        """Assert *query* yields exactly the notes numbered in *expected*.

        The comparison ignores order. Expected numbers are formatted *before*
        sorting so both sides are sorted as strings; sorting the integers
        first would disagree with string order once numbers reach 10
        ('note-10' sorts before 'note-2').
        """
        notes = [note.text for note in query]
        self.assertEqual(sorted(notes),
                         sorted('note-%s' % i for i in expected))

    def assertUsers(self, query, expected):
        """Assert *query* yields exactly the usernames in *expected*."""
        usernames = [user.username for user in query]
        self.assertEqual(sorted(usernames), sorted(expected))

    def test_accessor_query(self):
        self._set_data()
        gargie, huey, mickey, zaizee = User.select().order_by(User.username)

        # Evaluating an accessor costs a single query.
        with self.assertQueryCount(1):
            self.assertNotes(gargie.notes, [1, 2])

        with self.assertQueryCount(1):
            self.assertNotes(zaizee.notes, [4, 5])

        # Two queries: create the user, then evaluate its (empty) accessor.
        with self.assertQueryCount(2):
            self.assertNotes(User.create(username='x').notes, [])

        n1, n2, n3, n4, n5 = Note.select().order_by(Note.text)
        with self.assertQueryCount(1):
            self.assertUsers(n1.users, ['gargie'])

        with self.assertQueryCount(1):
            self.assertUsers(n2.users, ['gargie', 'huey'])

        with self.assertQueryCount(1):
            self.assertUsers(n5.users, ['zaizee'])

        with self.assertQueryCount(2):
            self.assertUsers(Note.create(text='x').users, [])

    def test_prefetch_notes(self):
        self._set_data()
        for pt in PREFETCH_TYPE.values():
            # One query per model: users, through rows and notes.
            with self.assertQueryCount(3):
                gargie, huey, mickey, zaizee = prefetch(
                    User.select().order_by(User.username),
                    NoteUserThrough,
                    Note,
                    prefetch_type=pt)

            # Prefetched accessors are served without further queries.
            with self.assertQueryCount(0):
                self.assertNotes(gargie.notes, [1, 2])

            with self.assertQueryCount(0):
                self.assertNotes(zaizee.notes, [4, 5])

        # A fresh, non-prefetched instance still queries normally. This runs
        # once, outside the loop, because usernames are unique.
        with self.assertQueryCount(2):
            self.assertNotes(User.create(username='x').notes, [])

    def test_prefetch_users(self):
        self._set_data()
        for pt in PREFETCH_TYPE.values():
            with self.assertQueryCount(3):
                n1, n2, n3, n4, n5 = prefetch(
                    Note.select().order_by(Note.text),
                    NoteUserThrough,
                    User,
                    prefetch_type=pt)

            with self.assertQueryCount(0):
                self.assertUsers(n1.users, ['gargie'])

            with self.assertQueryCount(0):
                self.assertUsers(n2.users, ['gargie', 'huey'])

            with self.assertQueryCount(0):
                self.assertUsers(n5.users, ['zaizee'])

        with self.assertQueryCount(2):
            self.assertUsers(Note.create(text='x').users, [])

    def test_query_filtering(self):
        self._set_data()
        gargie, huey, mickey, zaizee = User.select().order_by(User.username)

        # The accessor is a regular query and can be refined further.
        with self.assertQueryCount(1):
            notes = gargie.notes.where(Note.text != 'note-2')
            self.assertNotes(notes, [1])

    def test_set_value(self):
        self._set_data()
        gargie = User.get(User.username == 'gargie')
        n1, n2, n3, n4, n5 = Note.select().order_by(Note.text)

        # Assignment replaces the existing relations: clear, then insert.
        with self.assertQueryCount(2):
            gargie.notes = n3
        self.assertNotes(gargie.notes, [3])
        self.assertUsers(n3.users, ['gargie', 'huey', 'mickey'])
        self.assertUsers(n1.users, [])

        gargie.notes = [n3, n4]
        self.assertNotes(gargie.notes, [3, 4])
        self.assertUsers(n3.users, ['gargie', 'huey', 'mickey'])
        self.assertUsers(n4.users, ['gargie', 'mickey', 'zaizee'])

    def test_set_query(self):
        huey = User.get(User.username == 'huey')
        with self.assertQueryCount(2):
            huey.notes = Note.select().where(~Note.text.endswith('4'))
        self.assertNotes(huey.notes, [1, 2, 3, 5])

    def test_add(self):
        gargie = User.get(User.username == 'gargie')
        huey = User.get(User.username == 'huey')
        n1, n2, n3, n4, n5 = Note.select().order_by(Note.text)

        gargie.notes.add([n1, n2])
        self.assertNotes(gargie.notes, [1, 2])
        self.assertUsers(n1.users, ['gargie'])
        self.assertUsers(n2.users, ['gargie'])
        for note in [n3, n4, n5]:
            self.assertUsers(note.users, [])

        # Adding from a SELECT query is a single statement.
        with self.assertQueryCount(1):
            huey.notes.add(Note.select().where(
                fn.substr(Note.text, 6, 1) << ['1', '3', '5']))

        self.assertNotes(huey.notes, [1, 3, 5])
        self.assertUsers(n1.users, ['gargie', 'huey'])
        self.assertUsers(n2.users, ['gargie'])
        self.assertUsers(n3.users, ['huey'])
        self.assertUsers(n4.users, [])
        self.assertUsers(n5.users, ['huey'])

        with self.assertQueryCount(1):
            gargie.notes.add(n4)
        self.assertNotes(gargie.notes, [1, 2, 4])

        # clear_existing=True adds one extra query to drop current relations.
        with self.assertQueryCount(2):
            n3.users.add(
                User.select().where(User.username != 'gargie'),
                clear_existing=True)
        self.assertUsers(n3.users, ['huey', 'mickey', 'zaizee'])

    def test_add_by_pk(self):
        huey = User.get(User.username == 'huey')
        n1, n2, n3 = Note.select().order_by(Note.text).limit(3)
        huey.notes.add([n1.id, n2.id])
        self.assertNotes(huey.notes, [1, 2])
        self.assertUsers(n1.users, ['huey'])
        self.assertUsers(n2.users, ['huey'])
        self.assertUsers(n3.users, [])

    def test_unique(self):
        n1 = Note.get(Note.text == 'note-1')
        huey = User.get(User.username == 'huey')

        def add_user(note, user):
            with self.assertQueryCount(1):
                note.users.add(user)

        add_user(n1, huey)
        # Adding the same (note, user) pair twice violates uniqueness.
        self.assertRaises(IntegrityError, add_user, n1, huey)

        add_user(n1, User.get(User.username == 'zaizee'))
        self.assertUsers(n1.users, ['huey', 'zaizee'])

    def test_remove(self):
        self._set_data()
        gargie, huey, mickey, zaizee = User.select().order_by(User.username)
        n1, n2, n3, n4, n5 = Note.select().order_by(Note.text)

        # Removing notes that are not related (n3) is harmless.
        with self.assertQueryCount(1):
            gargie.notes.remove([n1, n2, n3])

        self.assertNotes(gargie.notes, [])
        self.assertNotes(huey.notes, [2, 3])

        with self.assertQueryCount(1):
            huey.notes.remove(Note.select().where(
                Note.text << ['note-2', 'note-4', 'note-5']))

        self.assertNotes(huey.notes, [3])
        self.assertNotes(mickey.notes, [3, 4])
        self.assertNotes(zaizee.notes, [4, 5])

        with self.assertQueryCount(1):
            n4.users.remove([gargie, mickey])
        self.assertUsers(n4.users, ['zaizee'])

        with self.assertQueryCount(1):
            n5.users.remove(User.select())
        self.assertUsers(n5.users, [])

    def test_remove_by_id(self):
        self._set_data()
        gargie, huey = User.select().order_by(User.username).limit(2)
        n1, n2, n3, n4 = Note.select().order_by(Note.text).limit(4)
        gargie.notes.add([n3, n4])

        with self.assertQueryCount(1):
            gargie.notes.remove([n1.id, n3.id])

        self.assertNotes(gargie.notes, [2, 4])
        self.assertNotes(huey.notes, [2, 3])

    def test_clear(self):
        gargie = User.get(User.username == 'gargie')
        huey = User.get(User.username == 'huey')

        gargie.notes = Note.select()
        huey.notes = Note.select()

        self.assertEqual(gargie.notes.count(), 5)
        self.assertEqual(huey.notes.count(), 5)

        # Clearing one instance's relations leaves the others untouched.
        gargie.notes.clear()
        self.assertEqual(gargie.notes.count(), 0)
        self.assertEqual(huey.notes.count(), 5)

        n1 = Note.get(Note.text == 'note-1')
        n2 = Note.get(Note.text == 'note-2')

        n1.users = User.select()
        n2.users = User.select()

        self.assertEqual(n1.users.count(), 4)
        self.assertEqual(n2.users.count(), 4)

        n1.users.clear()
        self.assertEqual(n1.users.count(), 0)
        self.assertEqual(n2.users.count(), 4)

    def test_manual_through(self):
        gargie, huey, mickey, zaizee = User.select().order_by(User.username)
        alt_notes = []
        for i in range(5):
            alt_notes.append(AltNote.create(text='note-%s' % (i + 1)))

        self.assertNotes(gargie.altnotes, [])
        for alt_note in alt_notes:
            self.assertUsers(alt_note.users, [])

        n1, n2, n3, n4, n5 = alt_notes

        # Test adding relationships by setting the descriptor.
        gargie.altnotes = [n1, n2]
        with self.assertQueryCount(2):
            huey.altnotes = AltNote.select().where(
                fn.substr(AltNote.text, 6, 1) << ['1', '3', '5'])

        mickey.altnotes.add([n1, n4])
        with self.assertQueryCount(2):
            zaizee.altnotes = AltNote.select()

        # Test that the notes were added correctly.
        with self.assertQueryCount(1):
            self.assertNotes(gargie.altnotes, [1, 2])

        with self.assertQueryCount(1):
            self.assertNotes(huey.altnotes, [1, 3, 5])

        with self.assertQueryCount(1):
            self.assertNotes(mickey.altnotes, [1, 4])

        with self.assertQueryCount(1):
            self.assertNotes(zaizee.altnotes, [1, 2, 3, 4, 5])

        # Test removing notes.
        with self.assertQueryCount(1):
            gargie.altnotes.remove(n1)
        self.assertNotes(gargie.altnotes, [2])

        with self.assertQueryCount(1):
            huey.altnotes.remove([n1, n2, n3])
        self.assertNotes(huey.altnotes, [5])

        with self.assertQueryCount(1):
            sq = (AltNote
                  .select()
                  .where(fn.SUBSTR(AltNote.text, 6, 1) << ['1', '2', '4']))
            zaizee.altnotes.remove(sq)
        self.assertNotes(zaizee.altnotes, [3, 5])

        # Test the backside of the relationship.
        n1.users = User.select().where(User.username != 'gargie')

        with self.assertQueryCount(1):
            self.assertUsers(n1.users, ['huey', 'mickey', 'zaizee'])
        with self.assertQueryCount(1):
            self.assertUsers(n2.users, ['gargie'])
        with self.assertQueryCount(1):
            self.assertUsers(n3.users, ['zaizee'])
        with self.assertQueryCount(1):
            self.assertUsers(n4.users, ['mickey'])
        with self.assertQueryCount(1):
            self.assertUsers(n5.users, ['huey', 'zaizee'])

        with self.assertQueryCount(1):
            n1.users.remove(User.select())
        with self.assertQueryCount(1):
            n5.users.remove([gargie, huey])

        with self.assertQueryCount(1):
            self.assertUsers(n1.users, [])
        with self.assertQueryCount(1):
            self.assertUsers(n5.users, ['zaizee'])
# ===========================================================================
# Backref behavior, inheritance, and FK-to-non-PK
# ===========================================================================
class TestManyToManyBackrefBehavior(ModelTestCase):
    """Disabled backrefs, either on the ManyToManyField itself (``'+'``) or
    on the foreign keys of an explicit through model."""
    database = get_in_memory_db()
    requires = [Student, Course, CourseStudent, CourseStudent2]

    def setUp(self):
        super(TestManyToManyBackrefBehavior, self).setUp()
        math = Course.create(name='math')
        engl = Course.create(name='engl')
        huey, mickey, zaizee = (Student.create(name=student_name)
                                for student_name in ('huey', 'mickey',
                                                     'zaizee'))

        # ``students``: auto-generated through model, field backref disabled.
        math.students.add([huey, zaizee])
        engl.students.add([mickey])
        # ``students2``: CourseStudent2, whose FKs disable their backrefs.
        math.students2.add([mickey])
        engl.students2.add([huey, zaizee])

    def test_manytomanyfield_disabled_backref(self):
        math = Course.get(name='math')
        enrolled = [student.name
                    for student in math.students.order_by(Student.name)]
        self.assertEqual(enrolled, ['huey', 'zaizee'])

        huey = Student.get(name='huey')
        math.students.remove(huey)
        remaining = [student.name for student in math.students]
        self.assertEqual(remaining, ['zaizee'])

        # The backref is via the CourseStudent2 through-model.
        courses = [course.name for course in huey.courses]
        self.assertEqual(courses, ['engl'])

    def test_through_model_disabled_backrefs(self):
        # The many-to-many field leaves back-references enabled, but the
        # foreign keys on its through model disable them.
        engl = Course.get(name='engl')
        enrolled = [student.name
                    for student in engl.students2.order_by(Student.name)]
        self.assertEqual(enrolled, ['huey', 'zaizee'])

        zaizee = Student.get(Student.name == 'zaizee')
        engl.students2.remove(zaizee)
        remaining = [student.name for student in engl.students2]
        self.assertEqual(remaining, ['huey'])

        math = Course.get(name='math')
        self.assertEqual([student.name for student in math.students2],
                         ['mickey'])
class TestManyToManyInheritance(ModelTestCase):
    """A model carrying a ManyToManyField cannot be subclassed."""

    def test_manytomany_inheritance(self):
        class BaseModel(TestModel):
            class Meta:
                database = self.database

        class User(BaseModel):
            username = TextField()

        class Project(BaseModel):
            name = TextField()
            users = ManyToManyField(User, backref='projects')

        # We cannot subclass Project, because the many-to-many field "users"
        # would be inherited, but its through-model (auto-generated for
        # Project.users, with foreign-keys to project and user) has no
        # foreign-key to VProject.
        with self.assertRaises(ValueError):
            class VProject(Project):
                pass

        PThrough = Project.users.through_model
        # assertIs reports both objects on failure, unlike assertTrue(a is b).
        self.assertIs(PThrough.project.rel_model, Project)
        self.assertIs(PThrough.user.rel_model, User)
class TestManyToManyFKtoNonPK(ModelTestCase):
    """M2M through LogoColor, whose foreign keys reference the unique
    ``name`` columns of Logo and Color rather than their primary keys."""
    database = get_in_memory_db()
    requires = [Color, Logo, LogoColor]

    def test_manytomany_fk_to_non_pk(self):
        red = Color.create(name='red')
        green = Color.create(name='green')
        blue = Color.create(name='blue')
        lrg = Logo.create(name='logo-rg')
        lrb = Logo.create(name='logo-rb')
        lrgb = Logo.create(name='logo-rgb')
        lrg.colors.add([red, green])
        lrb.colors.add([red, blue])
        lrgb.colors.add([red, green, blue])

        def assertColors(logo, expected):
            """Assert *logo*'s colors, by name, equal *expected* (sorted)."""
            colors = [c.name for c in logo.colors.order_by(Color.name)]
            self.assertEqual(colors, expected)

        assertColors(lrg, ['green', 'red'])
        assertColors(lrb, ['blue', 'red'])
        assertColors(lrgb, ['blue', 'green', 'red'])

        def assertLogos(color, expected):
            """Assert *color*'s logos, by name, equal *expected* (sorted)."""
            logos = [logo.name for logo in color.logos.order_by(Logo.name)]
            self.assertEqual(logos, expected)

        assertLogos(red, ['logo-rb', 'logo-rg', 'logo-rgb'])
        assertLogos(green, ['logo-rg', 'logo-rgb'])
        assertLogos(blue, ['logo-rb', 'logo-rgb'])

        # Verify we can delete data as well.
        lrg.colors.remove(red)
        self.assertEqual([c.name for c in lrg.colors], ['green'])

        blue.logos.remove(lrb)
        self.assertEqual([c.name for c in lrb.colors], ['red'])

        # Verify we can insert using a SELECT query, replacing lrg's existing
        # colors (keyword form, as used elsewhere for add()).
        lrg.colors.add(Color.select().where(Color.name != 'blue'),
                       clear_existing=True)
        assertColors(lrg, ['green', 'red'])

        lrb.colors.add(Color.select().where(Color.name == 'blue'))
        assertColors(lrb, ['blue', 'red'])

        # Verify we can insert logos using a SELECT query.
        black = Color.create(name='black')
        black.logos.add(Logo.select().where(Logo.name != 'logo-rgb'))
        assertLogos(black, ['logo-rb', 'logo-rg'])
        assertColors(lrb, ['black', 'blue', 'red'])
        assertColors(lrg, ['black', 'green', 'red'])
        assertColors(lrgb, ['blue', 'green', 'red'])

        # Verify we can delete using a SELECT query.
        lrg.colors.remove(Color.select().where(Color.name == 'red'))
        assertColors(lrg, ['black', 'green'])

        black.logos.remove(Logo.select().where(Logo.name == 'logo-rg'))
        assertLogos(black, ['logo-rb'])

        # Verify we can clear.
        lrg.colors.clear()
        assertColors(lrg, [])
        assertColors(lrb, ['black', 'blue', 'red'])  # Not affected.

        black.logos.clear()
        assertLogos(black, [])
        assertLogos(red, ['logo-rb', 'logo-rgb'])
# ===========================================================================
# FK-as-PK M2M and multiple M2M between same tables
# ===========================================================================
class Person(TestModel):
    """Person referenced by Account's primary-key foreign key."""
    name = CharField(
    )
class Account(TestModel):
    """Account whose primary key is also its foreign key to Person."""
    person = ForeignKeyField(
        Person,
        primary_key=True)
class AccountList(TestModel):
    """Named list of accounts; Account reaches its lists via ``lists``."""
    name = CharField()
    accounts = ManyToManyField(
        Account,
        backref='lists')


# Auto-generated through model for ``AccountList.accounts``.
AccountListThrough = AccountList.accounts.get_through_model()
class TestForeignKeyPrimaryKeyManyToMany(ModelTestCase):
    """M2M where the related model (Account) uses a foreign key as its
    primary key."""
    database = get_in_memory_db()
    requires = [Person, Account, AccountList, AccountListThrough]

    # (person name, names of the account lists that person's account is in)
    test_data = (
        ('huey', ('cats', 'evil')),
        ('zaizee', ('cats', 'good')),
        ('mickey', ('dogs', 'good')),
        ('zombie', ()),
    )

    def setUp(self):
        """Create a person and account per row, creating lists on demand."""
        super(TestForeignKeyPrimaryKeyManyToMany, self).setUp()
        name2list = {}
        for name, list_names in self.test_data:
            person = Person.create(name=name)
            account = Account.create(person=person)
            for list_name in list_names:
                if list_name not in name2list:
                    name2list[list_name] = AccountList.create(name=list_name)
                name2list[list_name].accounts.add(account)

    def account_for(self, name):
        """Return the Account belonging to the person called *name*."""
        return Account.select().join(Person).where(Person.name == name).get()

    def assertLists(self, l1, l2):
        """Assert *l1* and *l2* contain the same items, ignoring order."""
        self.assertEqual(sorted(l1), sorted(l2))

    def test_pk_is_fk(self):
        list2names = {}
        for name, list_names in self.test_data:
            account = self.account_for(name)
            self.assertLists([al.name for al in account.lists], list_names)
            for list_name in list_names:
                list2names.setdefault(list_name, []).append(name)

        for list_name, names in list2names.items():
            account_list = AccountList.get(AccountList.name == list_name)
            self.assertLists(
                [acct.person.name for acct in account_list.accounts],
                names)

    def test_empty(self):
        al = AccountList.create(name='empty')
        self.assertEqual(list(al.accounts), [])
class Permission(TestModel):
    """Permission targeted by both of Visitor's M2M fields."""
    name = TextField(
    )
# Placeholder for ``Visitor.denied``'s explicit through model, which
# references Visitor and therefore has to be declared after it.
DeniedThroughDeferred = DeferredThroughModel()


class Visitor(TestModel):
    """Visitor with two M2M fields to Permission.

    ``allowed`` uses an auto-generated through model; ``denied`` uses the
    explicit DeniedThrough model.
    """
    name = TextField()
    allowed = ManyToManyField(
        Permission)
    denied = ManyToManyField(
        Permission,
        through_model=DeniedThroughDeferred)


class DeniedThrough(TestModel):
    """Explicit through model backing ``Visitor.denied``."""
    permission = ForeignKeyField(
        Permission)
    visitor = ForeignKeyField(
        Visitor)


DeniedThroughDeferred.set_model(DeniedThrough)
class TestMultipleManyToManySameTables(ModelTestCase):
    """Two M2M fields (allowed/denied) between the same pair of tables must
    keep their relations independent."""
    database = get_in_memory_db()
    requires = [Permission, Visitor, Visitor.allowed.through_model,
                Visitor.denied.through_model]

    def test_multiple_manytomany_same_tables(self):
        p1, p2, p3 = [Permission.create(name=n) for n in ('p1', 'p2', 'p3')]
        v1, v2, v3 = [Visitor.create(name=n) for n in ('v1', 'v2', 'v3')]

        v1.allowed.add([p1, p2, p3])
        v2.allowed.add(p2)
        v2.denied.add([p1, p3])
        v3.allowed.add(p3)
        v3.denied.add(p1)

        def permission_names(relation):
            """Names of the permissions in *relation*, alphabetically."""
            return [perm.name
                    for perm in relation.order_by(Permission.name)]

        summary = [
            (visitor.name,
             permission_names(visitor.allowed),
             permission_names(visitor.denied))
            for visitor in Visitor.select().order_by(Visitor.name)]

        self.assertEqual(summary, [
            ('v1', ['p1', 'p2', 'p3'], []),
            ('v2', ['p2'], ['p1', 'p3']),
            ('v3', ['p3'], ['p1'])])
# ===========================================================================
# Errors and edge-cases
# ===========================================================================
class TestManyToManyPreventUnsaved(ModelTestCase):
    """M2M accessors refuse to operate on an instance that was never saved."""
    database = get_in_memory_db()
    requires = [User, Note, NoteUserThrough]

    def test_m2m_unsaved_raises(self):
        unsaved = Note(text='unsaved note')

        # The note was never saved (its id is None), so both reading and
        # assigning the M2M accessor must fail.
        with self.assertRaises(ValueError):
            unsaved.users  # Triggers ManyToManyFieldAccessor.__get__
        with self.assertRaises(ValueError):
            unsaved.users = [User(username='u')]

        huey = User.create(username='huey')
        note = Note.create(text='note1')
        self.assertEqual(list(note.users), [])

        # Cannot set instance with no primary key. The assignment runs inside
        # atomic() so the failed write is confined to that block.
        with self.assertRaises(IntegrityError):
            with self.database.atomic():
                note.users = [User()]

        note.users = [huey]
        self.assertEqual(list(note.users), [huey])
class TestManyToManyInitErrors(ModelTestCase):
    """Argument validation performed by the ManyToManyField constructor."""
    database = get_in_memory_db()

    def test_invalid_through_model_type(self):
        # A plain string is not an acceptable through_model.
        self.assertRaises(TypeError, ManyToManyField, User,
                          through_model='not_a_model')

    def test_on_delete_with_through_model_raises(self):
        class DummyThrough(TestModel):
            pass

        # on_delete cannot be combined with an explicit through_model.
        self.assertRaises(ValueError, ManyToManyField, User,
                          through_model=DummyThrough, on_delete='CASCADE')

    def test_on_update_with_through_model_raises(self):
        class DummyThrough(TestModel):
            pass

        # Likewise for on_update.
        self.assertRaises(ValueError, ManyToManyField, User,
                          through_model=DummyThrough, on_update='CASCADE')
class TestManyToManyEmptyOperations(ModelTestCase):
    """add()/remove() with an empty list are no-ops."""
    database = get_in_memory_db()
    requires = [User, Note, NoteUserThrough]

    def test_add_empty_list(self):
        # An existing user ensures an empty add() cannot pick up rows by
        # accident.
        User.create(username='huey')
        n = Note.create(text='note1')
        n.users.add([])
        self.assertEqual(list(n.users), [])
        # Nothing at all should have been written to the through table.
        self.assertEqual(NoteUserThrough.select().count(), 0)

    def test_remove_empty_list(self):
        u = User.create(username='huey')
        n = Note.create(text='note1')
        n.users.add([u])
        result = n.users.remove([])
        # remove with empty list returns None (early exit).
        self.assertIsNone(result)
        # The original relationship must still exist, unchanged.
        self.assertEqual(list(n.users), [u])