diff --git a/tests/model_sql.py b/tests/model_sql.py index cf401d4e..25b794ec 100644 --- a/tests/model_sql.py +++ b/tests/model_sql.py @@ -1258,6 +1258,149 @@ class TestModelAdvancedSQL(ModelDatabaseTestCase): 'WHERE ("t3"."author_id" = "t1"."id") ' 'ORDER BY "t3"."id" DESC LIMIT ?) AS "t2" ON ?'), [2, True]) + # -- orwhere -- + + def test_orwhere(self): + query = (User + .select() + .orwhere(User.username == 'huey') + .orwhere(User.username == 'zaizee')) + self.assertSQL(query, ( + 'SELECT "t1"."id", "t1"."username" ' + 'FROM "users" AS "t1" ' + 'WHERE (("t1"."username" = ?) OR ("t1"."username" = ?))'), + ['huey', 'zaizee']) + + def test_where_then_orwhere(self): + query = (User + .select() + .where(User.id > 0) + .orwhere(User.username == 'huey') + .orwhere(User.username == 'zaizee')) + self.assertSQL(query, ( + 'SELECT "t1"."id", "t1"."username" ' + 'FROM "users" AS "t1" ' + 'WHERE ((("t1"."id" > ?) OR ' + '("t1"."username" = ?)) OR ("t1"."username" = ?))'), + [0, 'huey', 'zaizee']) + + # -- ensure_join -- + + def test_ensure_join_noop(self): + """ensure_join is a no-op when the join already exists.""" + query = (User + .select(User, Tweet.content) + .join(Tweet) + .switch(User) + .ensure_join(User, Tweet)) + self.assertSQL(query, ( + 'SELECT "t1"."id", "t1"."username", "t2"."content" ' + 'FROM "users" AS "t1" ' + 'INNER JOIN "tweet" AS "t2" ON ("t2"."user_id" = "t1"."id")'), + []) + + def test_ensure_join_adds(self): + """ensure_join adds the join when it doesn't exist.""" + query = (User + .select(User, Tweet.content) + .ensure_join(User, Tweet)) + self.assertSQL(query, ( + 'SELECT "t1"."id", "t1"."username", "t2"."content" ' + 'FROM "users" AS "t1" ' + 'INNER JOIN "tweet" AS "t2" ON ("t2"."user_id" = "t1"."id")'), + []) + + # -- for_update SQL -- + + def test_for_update(self): + query = (User + .select() + .where(User.username == 'huey') + .for_update()) + self.assertSQL(query, ( + 'SELECT "t1"."id", "t1"."username" ' + 'FROM "users" AS "t1" ' + 'WHERE ("t1"."username" = ?) FOR UPDATE'), + ['huey'], for_update=True) + + def test_for_update_options(self): + query = User.select().for_update(for_update='FOR SHARE') + self.assertSQL(query, ( + 'SELECT "t1"."id", "t1"."username" ' + 'FROM "users" AS "t1" FOR SHARE'), [], for_update=True) + + query = User.select().for_update(nowait=True) + self.assertSQL(query, ( + 'SELECT "t1"."id", "t1"."username" ' + 'FROM "users" AS "t1" FOR UPDATE NOWAIT'), [], for_update=True) + + query = User.select().for_update(skip_locked=True) + self.assertSQL(query, ( + 'SELECT "t1"."id", "t1"."username" ' + 'FROM "users" AS "t1" FOR UPDATE SKIP LOCKED'), [], + for_update=True) + + def test_for_update_unsupported(self): + """FOR UPDATE on a database that doesn't support it raises.""" + query = User.select().for_update() + self.assertRaises(ValueError, self.assertSQL, query, + '', []) + + # -- having -- + + def test_having(self): + query = (User + .select(User.username, + fn.COUNT(Tweet.id).alias('ct')) + .join(Tweet, JOIN.LEFT_OUTER) + .group_by(User.username) + .having(fn.COUNT(Tweet.id) > 5)) + self.assertSQL(query, ( + 'SELECT "t1"."username", COUNT("t2"."id") AS "ct" ' + 'FROM "users" AS "t1" ' + 'LEFT OUTER JOIN "tweet" AS "t2" ' + 'ON ("t2"."user_id" = "t1"."id") ' + 'GROUP BY "t1"."username" ' + 'HAVING (COUNT("t2"."id") > ?)'), [5]) + + # -- distinct on -- + + def test_distinct_on(self): + query = User.select().distinct(User.username) + self.assertSQL(query, ( + 'SELECT DISTINCT ON ("t1"."username") ' + '"t1"."id", "t1"."username" ' + 'FROM "users" AS "t1"'), []) + + +# =========================================================================== +# ON CONFLICT shortcut SQL (on_conflict_ignore / on_conflict_replace) +# =========================================================================== + +class TestOnConflictShortcutSQL(ModelDatabaseTestCase): + database = get_in_memory_db() + requires = [User, Emp] + + def test_on_conflict_ignore(self): + query = User.insert(username='test').on_conflict_ignore() + self.assertSQL(query, ( + 'INSERT OR IGNORE INTO "users" ("username") VALUES (?)'), + ['test']) + + def test_on_conflict_replace(self): + query = Emp.insert(first='h', last='c', empno='1').on_conflict_replace() + self.assertSQL(query, ( + 'INSERT OR REPLACE INTO "emp" ' + '("first", "last", "empno") VALUES (?, ?, ?)'), + ['h', 'c', '1']) + + def test_on_conflict_both_target_and_constraint_raises(self): + self.assertRaises( + ValueError, + User.insert(username='test').on_conflict, + conflict_target=[User.username], + conflict_constraint='foo') + # =========================================================================== # ON CONFLICT / upsert SQL with Models @@ -1549,6 +1692,35 @@ class TestModelCompoundSelect(BaseTestCase): 'UNION ' 'SELECT "t2"."beta" FROM "beta" AS "t2"))'), []) + def test_union_method(self): + lhs = Alpha.select(Alpha.alpha).where(Alpha.alpha > 1) + rhs = Beta.select(Beta.beta).where(Beta.beta < 10) + query = lhs.union(rhs) + self.assertSQL(query, ( + 'SELECT "t1"."alpha" FROM "alpha" AS "t1" ' + 'WHERE ("t1"."alpha" > ?) ' + 'UNION ' + 'SELECT "t2"."beta" FROM "beta" AS "t2" ' + 'WHERE ("t2"."beta" < ?)'), [1, 10]) + + def test_intersect_method(self): + lhs = Alpha.select(Alpha.alpha) + rhs = Beta.select(Beta.beta) + query = lhs.intersect(rhs) + self.assertSQL(query, ( + 'SELECT "t1"."alpha" FROM "alpha" AS "t1" ' + 'INTERSECT ' + 'SELECT "t2"."beta" FROM "beta" AS "t2"'), []) + + def test_except_method(self): + lhs = Alpha.select(Alpha.alpha) + rhs = Beta.select(Beta.beta) + query = lhs.except_(rhs) + self.assertSQL(query, ( + 'SELECT "t1"."alpha" FROM "alpha" AS "t1" ' + 'EXCEPT ' + 'SELECT "t2"."beta" FROM "beta" AS "t2"'), []) + # =========================================================================== # Model index SQL and miscellaneous diff --git a/tests/models.py b/tests/models.py index 3375b774..314eb4a2 100644 --- a/tests/models.py +++ b/tests/models.py @@ -3229,6 +3229,158 @@ class TestSelectValueConversion(ModelTestCase): self.assertEqual(u1_id, str(u1.id)) +class TestOrWhere(ModelTestCase): + requires = [User] + + def test_orwhere(self): + User.insert_many([{'username': u} for u in + ('huey', 'mickey', 'zaizee')]).execute() + query = (User + .select() + .orwhere(User.username == 'huey') + .orwhere(User.username == 'zaizee') + .order_by(User.username)) + self.assertEqual([u.username for u in query], ['huey', 'zaizee']) + + def test_where_then_orwhere(self): + User.insert_many([{'username': u} for u in + ('huey', 'mickey', 'zaizee')]).execute() + # where + orwhere: the where is OR'd with subsequent orwhere calls. + query = (User + .select() + .where(User.username == 'huey') + .orwhere(User.username == 'zaizee') + .order_by(User.username)) + self.assertEqual([u.username for u in query], ['huey', 'zaizee']) + + +class TestEnsureJoin(ModelTestCase): + requires = [User, Tweet] + + def test_ensure_join_noop(self): + """If join already exists, ensure_join doesn't duplicate it.""" + u = User.create(username='huey') + Tweet.create(user=u, content='meow') + + query = (User + .select(User, Tweet.content) + .join(Tweet) + .switch(User) + .ensure_join(User, Tweet)) + result = [(row.username, row.tweet.content) for row in query] + self.assertEqual(result, [('huey', 'meow')]) + + def test_ensure_join_adds_when_missing(self): + """If join doesn't exist, ensure_join adds it.""" + u = User.create(username='huey') + Tweet.create(user=u, content='meow') + + query = (User + .select(User, Tweet.content) + .ensure_join(User, Tweet)) + result = [(row.username, row.tweet.content) for row in query] + self.assertEqual(result, [('huey', 'meow')]) + + +class TestScalarIntegration(ModelTestCase): + requires = [User, Sample] + + @requires_models(User) + def test_scalar(self): + for u in ('huey', 'mickey', 'zaizee'): + User.create(username=u) + count = User.select(fn.COUNT(User.id)).scalar() + self.assertEqual(count, 3) + + @requires_models(User) + def test_scalar_as_tuple(self): + for u in ('huey', 'mickey', 'zaizee'): + User.create(username=u) + count, mx = (User + .select(fn.COUNT(User.id), fn.MAX(User.id)) + .scalar(as_tuple=True)) + self.assertEqual(count, 3) + self.assertTrue(mx > 0) + + @requires_models(User) + def test_scalar_as_dict(self): + for u in ('huey', 'mickey', 'zaizee'): + User.create(username=u) + result = (User + .select(fn.COUNT(User.id).alias('ct')) + .scalar(as_dict=True)) + self.assertEqual(result, {'ct': 3}) + + @requires_models(Sample) + def test_scalar_empty_result(self): + val = (Sample + .select(fn.MAX(Sample.value)) + .scalar()) + self.assertTrue(val is None) + + +class TestExistsIntegration(ModelTestCase): + requires = [User] + + @requires_models(User) + def test_exists_true(self): + User.create(username='huey') + self.assertTrue( + User.select().where(User.username == 'huey').exists()) + + @requires_models(User) + def test_exists_false(self): + User.create(username='huey') + self.assertFalse( + User.select().where(User.username == 'nobody').exists()) + + +class TestObjectsIntegration(ModelTestCase): + requires = [User, Tweet] + + @requires_models(User, Tweet) + def test_objects_returns_target_model(self): + u = User.create(username='huey') + Tweet.create(user=u, content='meow') + Tweet.create(user=u, content='purr') + + query = (Tweet + .select(Tweet, User) + .join(User) + .order_by(Tweet.content) + .objects()) + results = list(query) + # .objects() maps all columns onto the Tweet model. + self.assertEqual(len(results), 2) + self.assertTrue(isinstance(results[0], Tweet)) + self.assertEqual(results[0].content, 'meow') + self.assertEqual(results[0].username, 'huey') + self.assertEqual(results[1].content, 'purr') + + +class TestRowTypeIntegration(ModelTestCase): + requires = [User] + + @requires_models(User) + def test_tuples(self): + User.create(username='huey') + result = list(User.select(User.username).tuples()) + self.assertEqual(result, [('huey',)]) + + @requires_models(User) + def test_dicts(self): + User.create(username='huey') + result = list(User.select(User.username).dicts()) + self.assertEqual(result, [{'username': 'huey'}]) + + @requires_models(User) + def test_namedtuples(self): + User.create(username='huey') + result = list(User.select(User.username).namedtuples()) + self.assertEqual(len(result), 1) + self.assertEqual(result[0].username, 'huey') + + class VL(TestModel): n = IntegerField() s = CharField()