diff --git a/peewee.py b/peewee.py index e9251209..46c9d214 100644 --- a/peewee.py +++ b/peewee.py @@ -2633,11 +2633,15 @@ class Insert(_WriteQuery): if col not in seen: columns.append(col) + nullable_columns = set() value_lookups = {} for column in columns: lookups = [column, column.name] - if isinstance(column, Field) and column.name != column.column_name: - lookups.append(column.column_name) + if isinstance(column, Field): + if column.name != column.column_name: + lookups.append(column.column_name) + if column.null: + nullable_columns.add(column) value_lookups[column] = lookups ctx.sql(EnclosedNodeList(columns)).literal(' VALUES ') @@ -2671,6 +2675,8 @@ class Insert(_WriteQuery): val = defaults[column] if callable_(val): val = val() + elif column in nullable_columns: + val = None else: raise ValueError('Missing value for %s.' % column.name) diff --git a/tests/base_models.py b/tests/base_models.py index 22be93a8..3fe75b1f 100644 --- a/tests/base_models.py +++ b/tests/base_models.py @@ -104,3 +104,10 @@ class UKVP(TestModel): indexes = [ SQL('CREATE UNIQUE INDEX "ukvp_kve" ON "ukvp" ("key", "value") ' 'WHERE "extra" > 1')] + + +class DfltM(TestModel): + name = CharField() + dflt1 = IntegerField(default=1) + dflt2 = IntegerField(default=lambda: 2) + dfltn = IntegerField(null=True) diff --git a/tests/model_sql.py b/tests/model_sql.py index e8b35f37..c3fdb20d 100644 --- a/tests/model_sql.py +++ b/tests/model_sql.py @@ -23,7 +23,7 @@ class CKM(TestModel): class TestModelSQL(ModelDatabaseTestCase): database = get_in_memory_db() - requires = [Category, CKM, Note, Person, Relationship, Sample, User] + requires = [Category, CKM, Note, Person, Relationship, Sample, User, DfltM] def test_select(self): query = (Person @@ -440,6 +440,21 @@ class TestModelSQL(ModelDatabaseTestCase): 'INSERT INTO "sample" ("counter", "value") VALUES (?, ?), (?, ?)'), [3, 1., 2, 2.]) + def test_insert_many_defaults_nulls(self): + data = [ + {'name': 'd1'}, + {'name': 'd2', 'dflt1': 10}, + {'name': 'd3', 'dflt2': 30}, + {'name': 'd4', 'dfltn': 40}] + fields = [DfltM.name, DfltM.dflt1, DfltM.dflt2, DfltM.dfltn] + self.assertSQL(DfltM.insert_many(data, fields=fields), ( + 'INSERT INTO "dflt_m" ("name", "dflt1", "dflt2", "dfltn") VALUES ' + '(?, ?, ?, ?), (?, ?, ?, ?), (?, ?, ?, ?), (?, ?, ?, ?)'), + ['d1', 1, 2, None, + 'd2', 10, 2, None, + 'd3', 1, 30, None, + 'd4', 1, 2, 40]) + def test_insert_many_list_with_fields(self): data = [(i,) for i in ('charlie', 'huey', 'zaizee')] query = User.insert_many(data, fields=[User.username]) diff --git a/tests/models.py b/tests/models.py index 641aa184..fbf22531 100644 --- a/tests/models.py +++ b/tests/models.py @@ -170,6 +170,25 @@ class TestModelAPIs(ModelTestCase): names = [u.username for u in User.select().order_by(User.username)] self.assertEqual(names, ['u%02d' % i for i in range(100)]) + @requires_models(DfltM) + def test_insert_many_defaults_nullable(self): + data = [ + {'name': 'd1'}, + {'name': 'd2', 'dflt1': 10}, + {'name': 'd3', 'dflt2': 30}, + {'name': 'd4', 'dfltn': 40}] + fields = [DfltM.name, DfltM.dflt1, DfltM.dflt2, DfltM.dfltn] + DfltM.insert_many(data, fields).execute() + + expected = [ + ('d1', 1, 2, None), + ('d2', 10, 2, None), + ('d3', 1, 30, None), + ('d4', 1, 2, 40)] + query = DfltM.select().order_by(DfltM.name) + actual = [(d.name, d.dflt1, d.dflt2, d.dfltn) for d in query] + self.assertEqual(actual, expected) + @requires_models(User, Tweet) def test_create(self): with self.assertQueryCount(1):