Remove Python 2.x compatibility code. So long, old friend.

This commit is contained in:
Charles Leifer
2026-03-09 10:39:07 -05:00
parent 1da0bc72e3
commit ac959db5b4
24 changed files with 121 additions and 2612 deletions
+2
View File
@@ -7,6 +7,8 @@ https://github.com/coleifer/peewee/releases
## master
* Remove all Python 2.x compatibility code.
[View commits](https://github.com/coleifer/peewee/compare/4.0.1...master)
## 4.0.1
+55 -92
View File
@@ -1,7 +1,10 @@
from bisect import bisect_left
from bisect import bisect_right
from collections.abc import Callable
from collections.abc import Mapping
from contextlib import contextmanager
from copy import deepcopy
from functools import reduce
from functools import wraps
from inspect import isclass
import calendar
@@ -20,18 +23,11 @@ import threading
import time
import uuid
import warnings
try:
from collections.abc import Mapping
except ImportError:
from collections import Mapping
try:
from pysqlite3 import dbapi2 as pysq3
except ImportError:
try:
from pysqlite2 import dbapi2 as pysq3
except ImportError:
pysq3 = None
pysq3 = None
try:
import sqlite3
except ImportError:
@@ -154,48 +150,17 @@ __all__ = [
'Window',
]
try: # Python 2.7+
from logging import NullHandler
except ImportError:
class NullHandler(logging.Handler):
def emit(self, record):
pass
logger = logging.getLogger('peewee')
logger.addHandler(NullHandler())
logger.addHandler(logging.NullHandler())
if sys.version_info[0] == 2:
text_type = unicode
bytes_type = str
buffer_type = buffer
izip_longest = itertools.izip_longest
callable_ = callable
multi_types = (list, tuple, frozenset, set)
exec('def reraise(tp, value, tb=None): raise tp, value, tb')
def print_(s):
sys.stdout.write(s)
sys.stdout.write('\n')
else:
import builtins
try:
from collections.abc import Callable
except ImportError:
from collections import Callable
from functools import reduce
callable_ = lambda c: isinstance(c, Callable)
text_type = str
bytes_type = bytes
buffer_type = memoryview
basestring = str
long = int
multi_types = (list, tuple, frozenset, set, range)
print_ = getattr(builtins, 'print')
izip_longest = itertools.zip_longest
def reraise(tp, value, tb=None):
if value.__traceback__ is not tb:
raise value.with_traceback(tb)
raise value
callable_ = lambda c: isinstance(c, Callable)
multi_types = (list, tuple, frozenset, set, range)
def reraise(tp, value, tb=None):
if value.__traceback__ is not tb:
raise value.with_traceback(tb)
raise value
# Other compat issues.
if sys.version_info < (3, 12):
@@ -444,8 +409,8 @@ def make_snake_case(s):
def chunked(it, n):
marker = object()
for group in (list(g) for g in izip_longest(*[iter(it)] * n,
fillvalue=marker)):
groups = itertools.zip_longest(*[iter(it)] * n, fillvalue=marker)
for group in (list(g) for g in groups):
if group[-1] is marker:
del group[group.index(marker):]
yield group
@@ -735,10 +700,10 @@ def query_to_string(query):
def _query_val_transform(v):
# Interpolate parameters.
if isinstance(v, (text_type, datetime.datetime, datetime.date,
if isinstance(v, (str, datetime.datetime, datetime.date,
datetime.time)):
v = "'%s'" % v
elif isinstance(v, bytes_type):
elif isinstance(v, bytes):
try:
v = v.decode('utf8')
except UnicodeDecodeError:
@@ -1133,7 +1098,7 @@ class CTE(_HashableSource, Source):
self._recursive = recursive
self._materialized = materialized
if columns is not None:
columns = [Entity(c) if isinstance(c, basestring) else c
columns = [Entity(c) if isinstance(c, str) else c
for c in columns]
self._columns = columns
query._cte_list = ()
@@ -1769,7 +1734,7 @@ class Window(Node):
@Node.copy
def exclude(self, frame_exclusion=None):
if isinstance(frame_exclusion, basestring):
if isinstance(frame_exclusion, str):
frame_exclusion = SQL(frame_exclusion)
self._exclude = frame_exclusion
@@ -1796,7 +1761,7 @@ class Window(Node):
ext = self._extends
if isinstance(ext, Window):
ext = SQL(ext._alias)
elif isinstance(ext, basestring):
elif isinstance(ext, str):
ext = SQL(ext)
parts.append(ext)
if self.partition_by:
@@ -2778,7 +2743,7 @@ class Insert(_WriteQuery):
# Infer column names from the dict of data being inserted.
accum = []
for column in row:
if isinstance(column, basestring):
if isinstance(column, str):
column = getattr(self.table, column)
accum.append(column)
@@ -2794,7 +2759,7 @@ class Insert(_WriteQuery):
clean_columns = []
seen = set()
for column in columns:
if isinstance(column, basestring):
if isinstance(column, str):
column_obj = getattr(self.table, column)
else:
column_obj = column
@@ -3010,7 +2975,7 @@ class Index(Node):
ctx.literal('USING %s ' % self._using) # Postgres/default.
ctx.sql(EnclosedNodeList([
SQL(expr) if isinstance(expr, basestring) else expr
SQL(expr) if isinstance(expr, str) else expr
for expr in self._expressions]))
if self._where is not None:
ctx.literal(' WHERE ').sql(self._where)
@@ -3045,7 +3010,7 @@ class ModelIndex(Index):
def _generate_name_from_fields(self, model, fields):
accum = []
for field in fields:
if isinstance(field, basestring):
if isinstance(field, str):
accum.append(field.split()[0])
else:
if isinstance(field, Node) and not isinstance(field, Field):
@@ -3378,7 +3343,7 @@ class Database(_callable_context_manager):
if on_conflict._conflict_target:
stmt = SQL('ON CONFLICT')
target = EnclosedNodeList([
Entity(col) if isinstance(col, basestring) else col
Entity(col) if isinstance(col, str) else col
for col in on_conflict._conflict_target])
if on_conflict._conflict_where is not None:
target = NodeList([target, SQL('WHERE'),
@@ -3386,7 +3351,7 @@ class Database(_callable_context_manager):
else:
stmt = SQL('ON CONFLICT ON CONSTRAINT')
target = on_conflict._conflict_constraint
if isinstance(target, basestring):
if isinstance(target, str):
target = Entity(target)
updates = []
@@ -3403,7 +3368,7 @@ class Database(_callable_context_manager):
if not isinstance(v, Node):
# Attempt to resolve string field-names to their respective
# field object, to apply data-type conversions.
if isinstance(k, basestring):
if isinstance(k, str):
k = getattr(query.table, k)
if isinstance(k, Field):
v = k.to_value(v)
@@ -4294,7 +4259,7 @@ class PostgresqlDatabase(Database):
parts = [SQL('ON CONFLICT')]
if oc._conflict_target:
parts.append(EnclosedNodeList([
Entity(col) if isinstance(col, basestring) else col
Entity(col) if isinstance(col, str) else col
for col in oc._conflict_target]))
parts.append(SQL('DO NOTHING'))
return NodeList(parts)
@@ -4534,7 +4499,7 @@ class MySQLDatabase(Database):
if not isinstance(v, Node):
# Attempt to resolve string field-names to their respective
# field object, to apply data-type conversions.
if isinstance(k, basestring):
if isinstance(k, str):
k = getattr(query.table, k)
if isinstance(k, Field):
v = k.to_value(v)
@@ -5137,7 +5102,7 @@ class DecimalField(Field):
if not value:
return value if value is None else D(0)
if self.auto_round:
decimal_value = D(text_type(value))
decimal_value = D(str(value))
return decimal_value.quantize(self._exp, rounding=self.rounding)
return value
@@ -5145,16 +5110,16 @@ class DecimalField(Field):
if value is not None:
if isinstance(value, decimal.Decimal):
return value
return decimal.Decimal(text_type(value))
return decimal.Decimal(str(value))
class _StringField(Field):
def adapt(self, value):
if isinstance(value, text_type):
if isinstance(value, str):
return value
elif isinstance(value, bytes_type):
elif isinstance(value, bytes):
return value.decode('utf-8')
return text_type(value)
return str(value)
def __add__(self, other): return StringExpression(self, OP.CONCAT, other)
def __radd__(self, other): return StringExpression(other, OP.CONCAT, self)
@@ -5216,9 +5181,9 @@ class BlobField(FieldDatabaseHook, Field):
self._constructor = database.get_binary_type()
def db_value(self, value):
if isinstance(value, text_type):
if isinstance(value, str):
value = value.encode('raw_unicode_escape')
if isinstance(value, bytes_type):
if isinstance(value, bytes):
return self._constructor(value)
return value
@@ -5353,10 +5318,10 @@ class BigBitFieldData(object):
return repr(self._buffer)
if sys.version_info[0] < 3:
def __str__(self):
return bytes_type(self._buffer)
return bytes(self._buffer)
else:
def __bytes__(self):
return bytes_type(self._buffer)
return bytes(self._buffer)
class BigBitFieldAccessor(FieldAccessor):
@@ -5367,15 +5332,13 @@ class BigBitFieldAccessor(FieldAccessor):
def __set__(self, instance, value):
if isinstance(value, memoryview):
value = value.tobytes()
elif isinstance(value, buffer_type):
value = bytes(value)
elif isinstance(value, bytearray):
value = bytes_type(value)
value = bytes(value)
elif isinstance(value, BigBitFieldData):
value = bytes_type(value._buffer)
elif isinstance(value, text_type):
value = bytes(value._buffer)
elif isinstance(value, str):
value = value.encode('utf-8')
elif not isinstance(value, bytes_type):
elif not isinstance(value, bytes):
raise ValueError('Value must be either a bytes, memoryview or '
'BigBitFieldData instance.')
super(BigBitFieldAccessor, self).__set__(instance, value)
@@ -5385,18 +5348,18 @@ class BigBitField(BlobField):
accessor_class = BigBitFieldAccessor
def __init__(self, *args, **kwargs):
kwargs.setdefault('default', bytes_type)
kwargs.setdefault('default', bytes)
super(BigBitField, self).__init__(*args, **kwargs)
def db_value(self, value):
return bytes_type(value) if value is not None else value
return bytes(value) if value is not None else value
class UUIDField(Field):
field_type = 'UUID'
def db_value(self, value):
if isinstance(value, basestring) and len(value) == 32:
if isinstance(value, str) and len(value) == 32:
# Hex string. No transformation is necessary.
return value
elif isinstance(value, bytes) and len(value) == 16:
@@ -5422,7 +5385,7 @@ class BinaryUUIDField(BlobField):
if isinstance(value, bytes) and len(value) == 16:
# Raw binary value. No transformation is necessary.
return self._constructor(value)
elif isinstance(value, basestring) and len(value) == 32:
elif isinstance(value, str) and len(value) == 32:
# Allow hex string representation.
value = uuid.UUID(hex=value)
if isinstance(value, uuid.UUID):
@@ -5482,7 +5445,7 @@ class DateTimeField(_BaseFormattedField):
]
def adapt(self, value):
if value and isinstance(value, basestring):
if value and isinstance(value, str):
return format_date_time(value, self.formats)
return value
@@ -5509,7 +5472,7 @@ class DateField(_BaseFormattedField):
]
def adapt(self, value):
if value and isinstance(value, basestring):
if value and isinstance(value, str):
pp = lambda x: x.date()
return format_date_time(value, self.formats, pp)
elif value and isinstance(value, datetime.datetime):
@@ -5539,7 +5502,7 @@ class TimeField(_BaseFormattedField):
def adapt(self, value):
if value:
if isinstance(value, basestring):
if isinstance(value, str):
pp = lambda x: x.time()
return format_date_time(value, self.formats, pp)
elif isinstance(value, datetime.datetime):
@@ -5623,7 +5586,7 @@ class TimestampField(BigIntegerField):
return int(round(timestamp))
def python_value(self, value):
if value is not None and isinstance(value, (int, float, long)):
if value is not None and isinstance(value, (int, float)):
if self.resolution > 1:
value, ticks = divmod(value, self.resolution)
microseconds = int(ticks * self.ticks_to_microsecond)
@@ -5762,7 +5725,7 @@ class ForeignKeyField(Field):
'name.' % (model._meta.name, name))
if self._is_self_reference:
self.rel_model = model
if isinstance(self.rel_field, basestring):
if isinstance(self.rel_field, str):
self.rel_field = getattr(self.rel_model, self.rel_field)
elif self.rel_field is None:
self.rel_field = self.rel_model._meta.primary_key
@@ -6168,7 +6131,7 @@ class SchemaManager(object):
if meta.table_settings is not None:
table_settings = ensure_tuple(meta.table_settings)
for setting in table_settings:
if not isinstance(setting, basestring):
if not isinstance(setting, str):
raise ValueError('table_settings must be strings')
ctx.literal(' ').literal(setting)
@@ -6619,7 +6582,7 @@ class Metadata(object):
index_parts, unique = index_obj
fields = []
for part in index_parts:
if isinstance(part, basestring):
if isinstance(part, str):
fields.append(self.combined[part])
elif isinstance(part, Node):
fields.append(part)
@@ -6884,7 +6847,7 @@ class Model(with_metaclass(ModelBase, Node)):
@classmethod
def insert_from(cls, query, fields):
columns = [getattr(cls, field) if isinstance(field, basestring)
columns = [getattr(cls, field) if isinstance(field, str)
else field for field in fields]
return ModelInsert(cls, insert=query, columns=columns)
@@ -6954,7 +6917,7 @@ class Model(with_metaclass(ModelBase, Node)):
'a composite primary key.')
# First normalize list of fields so all are field instances.
fields = [cls._meta.fields[f] if isinstance(f, basestring) else f
fields = [cls._meta.fields[f] if isinstance(f, str) else f
for f in fields]
# Now collect list of attribute names to use for values.
attrs = [field.object_id_name if isinstance(field, ForeignKeyField)
@@ -7071,7 +7034,7 @@ class Model(with_metaclass(ModelBase, Node)):
def _prune_fields(self, field_dict, only):
new_data = {}
for field in only:
if isinstance(field, basestring):
if isinstance(field, str):
field = self._meta.combined[field]
if field.name in field_dict:
new_data[field.name] = field_dict[field.name]
+2 -11
View File
@@ -2,13 +2,10 @@
from libc.stdlib cimport free, malloc
from libc.math cimport log, sqrt
import sys
from difflib import SequenceMatcher
from random import randint
IS_PY3K = sys.version_info[0] == 3
# FTS ranking functions.
cdef double *get_weights(int ncol, tuple raw_weights):
@@ -249,10 +246,7 @@ def damerau_levenshtein_dist(s1, s2):
list one_ago, two_ago, current_row
list zeroes = [0] * (s2_len + 1)
if IS_PY3K:
current_row = list(range(1, s2_len + 2))
else:
current_row = range(1, s2_len + 2)
current_row = list(range(1, s2_len + 2))
current_row[-1] = 0
one_ago = None
@@ -290,10 +284,7 @@ def levenshtein_dist(a, b):
zeroes = [0] * (m + 1)
if IS_PY3K:
current = list(range(n + 1))
else:
current = range(n + 1)
current = list(range(n + 1))
for i in range(1, m + 1):
previous = current
+1 -4
View File
@@ -20,9 +20,6 @@ try:
except ImportError: # psycopg2 not installed, ignore.
ArrayField = BinaryJSONField = IntervalField = JSONField = None
if sys.version_info[0] > 2:
basestring = str
NESTED_TX_MIN_VERSION = 200100
@@ -119,7 +116,7 @@ class CockroachDatabase(PostgresqlDatabase):
parts = [SQL('ON CONFLICT')]
if oc._conflict_target:
parts.append(EnclosedNodeList([
Entity(col) if isinstance(col, basestring) else col
Entity(col) if isinstance(col, str) else col
for col in oc._conflict_target]))
parts.append(SQL('DO NOTHING'))
return NodeList(parts)
+6 -18
View File
@@ -1,15 +1,12 @@
import base64
import csv
import datetime
from decimal import Decimal
import json
import operator
try:
from urlparse import urlparse
except ImportError:
from urllib.parse import urlparse
import sys
import uuid
from decimal import Decimal
from functools import reduce
from urllib.parse import urlparse
from peewee import *
from playhouse.db_url import connect
@@ -17,15 +14,6 @@ from playhouse.migrate import migrate
from playhouse.migrate import SchemaMigrator
from playhouse.reflection import Introspector
if sys.version_info[0] == 3:
basestring = str
from functools import reduce
def open_file(f, mode, encoding='utf8'):
return open(f, mode, encoding=encoding)
else:
def open_file(f, mode, encoding='utf8'):
return open(f, mode)
class DataSet(object):
def __init__(self, url, include_views=False, **kwargs):
@@ -164,7 +152,7 @@ class DataSet(object):
encoding='utf8', **kwargs):
self._check_arguments(filename, file_obj, format, self._export_formats)
if filename:
file_obj = open_file(filename, 'w', encoding)
file_obj = open(filename, 'w', encoding=encoding)
exporter = self._export_formats[format](query)
exporter.export(file_obj, **kwargs)
@@ -176,7 +164,7 @@ class DataSet(object):
strict=False, encoding='utf8', **kwargs):
self._check_arguments(filename, file_obj, format, self._export_formats)
if filename:
file_obj = open_file(filename, 'r', encoding)
file_obj = open(filename, 'r', encoding=encoding)
importer = self._import_formats[format](self[table], strict)
count = importer.load(file_obj, **kwargs)
@@ -223,7 +211,7 @@ class Table(object):
self.dataset._database.execute(index)
def _guess_field_type(self, value):
if isinstance(value, basestring):
if isinstance(value, str):
return TextField
if isinstance(value, (datetime.date, datetime.datetime)):
return DateTimeField
+1 -4
View File
@@ -1,7 +1,4 @@
try:
from urlparse import parse_qsl, unquote, urlparse
except ImportError:
from urllib.parse import parse_qsl, unquote, urlparse
from urllib.parse import parse_qsl, unquote, urlparse
from peewee import *
from playhouse.pool import PooledMySQLDatabase
+1 -7
View File
@@ -1,3 +1,4 @@
import pickle
try:
import bz2
except ImportError:
@@ -6,13 +7,8 @@ try:
import zlib
except ImportError:
zlib = None
try:
import cPickle as pickle
except ImportError:
import pickle
from peewee import BlobField
from peewee import buffer_type
class CompressedField(BlobField):
@@ -50,8 +46,6 @@ class CompressedField(BlobField):
class PickleField(BlobField):
def python_value(self, value):
if value is not None:
if isinstance(value, buffer_type):
value = bytes(value)
return pickle.loads(value)
def db_value(self, value):
+4 -8
View File
@@ -1,11 +1,8 @@
try:
from collections import OrderedDict
except ImportError:
OrderedDict = dict
from collections import namedtuple
from inspect import isclass
import re
import warnings
from collections import OrderedDict
from collections import namedtuple
from inspect import isclass
from peewee import *
from peewee import _StringField
@@ -13,7 +10,6 @@ from peewee import _query_val_transform
from peewee import CommaNodeList
from peewee import SCOPE_VALUES
from peewee import make_snake_case
from peewee import text_type
try:
from pymysql.constants import FIELD_TYPE
except ImportError:
@@ -207,7 +203,7 @@ class Metadata(object):
default.lower() == 'null':
return
if issubclass(field_class, _StringField) and \
isinstance(default, text_type) and not default.startswith("'"):
isinstance(default, str) and not default.startswith("'"):
default = "'%s'" % default
return default or "''"
+4 -7
View File
@@ -49,13 +49,10 @@ import decimal
import sys
from peewee import *
if sys.version_info[0] != 3:
from pysqlcipher import dbapi2 as sqlcipher
else:
try:
from sqlcipher3 import dbapi2 as sqlcipher
except ImportError:
from pysqlcipher3 import dbapi2 as sqlcipher
try:
from sqlcipher3 import dbapi2 as sqlcipher
except ImportError:
from pysqlcipher3 import dbapi2 as sqlcipher
sqlcipher.register_adapter(decimal.Decimal, str)
sqlcipher.register_adapter(datetime.date, str)
+1 -4
View File
@@ -20,9 +20,6 @@ from playhouse.sqlite_udf import RANK
from playhouse.sqlite_udf import register_udf_groups
if sys.version_info[0] == 3:
basestring = str
FTS3_MATCHINFO = 'pcx'
FTS4_MATCHINFO = 'pcnalx'
@@ -389,7 +386,7 @@ class BaseFTSModel(VirtualModel):
tokenize = options.get('tokenize')
content_rowid = options.get('content_rowid')
if isinstance(content, basestring) and content == '':
if isinstance(content, str) and content == '':
# Special-case content-less full-text search tables.
options['content'] = "''"
elif isinstance(content, Field):
+11 -28
View File
@@ -174,24 +174,15 @@ def file_read(filename):
except:
pass
if sys.version_info[0] == 2:
@udf(HELPER)
def gzip(data, compression=9):
return buffer(zlib.compress(data, compression))
@udf(HELPER)
def gzip(data, compression=9):
if isinstance(data, str):
data = bytes(data.encode('raw_unicode_escape'))
return zlib.compress(data, compression)
@udf(HELPER)
def gunzip(data):
return zlib.decompress(data)
else:
@udf(HELPER)
def gzip(data, compression=9):
if isinstance(data, str):
data = bytes(data.encode('raw_unicode_escape'))
return zlib.compress(data, compression)
@udf(HELPER)
def gunzip(data):
return zlib.decompress(data)
@udf(HELPER)
def gunzip(data):
return zlib.decompress(data)
@udf(HELPER)
def hostname(url):
@@ -319,14 +310,6 @@ class _datetime_heap_agg(_heap_agg):
def process(self, value):
return format_date_time_sqlite(value)
if sys.version_info[:2] == (2, 6):
def total_seconds(td):
return (td.seconds +
(td.days * 86400) +
(td.microseconds / (10.**6)))
else:
total_seconds = lambda td: td.total_seconds()
@aggregate(DATE)
class mintdiff(_datetime_heap_agg):
def finalize(self):
@@ -342,7 +325,7 @@ class mintdiff(_datetime_heap_agg):
min_diff = diff
dtp = dt
if min_diff is not None:
return total_seconds(min_diff)
return min_diff.total_seconds()
@aggregate(DATE)
class avgtdiff(_datetime_heap_agg):
@@ -363,7 +346,7 @@ class avgtdiff(_datetime_heap_agg):
dt = heapq.heappop(self.heap)
diff = dt - dtp
ct += 1
total += total_seconds(diff)
total += diff.total_seconds()
dtp = dt
return float(total) / ct
@@ -383,7 +366,7 @@ class duration(object):
def finalize(self):
if self._min and self._max:
td = (self._max - self._min)
return total_seconds(td)
return td.total_seconds()
return None
@aggregate(MATH)
+1 -4
View File
@@ -1,13 +1,10 @@
import logging
import weakref
from queue import Queue
from threading import local as thread_local
from threading import Event
from threading import Lock
from threading import Thread
try:
from Queue import Queue
except ImportError:
from queue import Queue
try:
import gevent
+23 -24
View File
@@ -7,7 +7,6 @@ from getpass import getpass
from optparse import OptionParser
from peewee import *
from peewee import print_
from peewee import __version__ as peewee_version
from playhouse.cockroachdb import CockroachDatabase
from playhouse.reflection import *
@@ -63,12 +62,12 @@ def print_models(introspector, tables=None, preserve_order=False,
introspector.get_database_class().__name__,
introspector.get_database_name().replace('\\', '\\\\'),
', **%s' % repr(db_kwargs) if db_kwargs else '')
print_(header)
print(header)
if not ignore_unknown:
print_(UNKNOWN_FIELD)
print(UNKNOWN_FIELD)
print_(BASE_MODEL)
print(BASE_MODEL)
def _print_table(table, seen, accum=None):
accum = accum or []
@@ -79,7 +78,7 @@ def print_models(introspector, tables=None, preserve_order=False,
# In the event the destination table has already been pushed
# for printing, then we have a reference cycle.
if dest in accum and table not in accum:
print_('# Possible reference cycle: %s' % dest)
print('# Possible reference cycle: %s' % dest)
# If this is not a self-referential foreign key, and we have
# not already processed the destination table, do so now.
@@ -88,7 +87,7 @@ def print_models(introspector, tables=None, preserve_order=False,
if dest != table:
_print_table(dest, seen, accum + [table])
print_('class %s(BaseModel):' % database.model_names[table])
print('class %s(BaseModel):' % database.model_names[table])
columns = database.columns[table].items()
if not preserve_order:
columns = sorted(columns)
@@ -109,34 +108,34 @@ def print_models(introspector, tables=None, preserve_order=False,
is_unknown = column.field_class is UnknownField
if is_unknown and ignore_unknown:
disp = '%s - %s' % (column.name, column.raw_column_type or '?')
print_(' # %s' % disp)
print(' # %s' % disp)
else:
print_(' %s' % column.get_field())
print(' %s' % column.get_field())
print_('')
print_(' class Meta:')
print_(' table_name = \'%s\'' % table)
print('')
print(' class Meta:')
print(' table_name = \'%s\'' % table)
multi_column_indexes = database.multi_column_indexes(table)
if multi_column_indexes:
print_(' indexes = (')
print(' indexes = (')
for fields, unique in sorted(multi_column_indexes):
print_(' ((%s), %s),' % (
print(' ((%s), %s),' % (
', '.join("'%s'" % field for field in fields),
unique,
))
print_(' )')
print(' )')
if introspector.schema:
print_(' schema = \'%s\'' % introspector.schema)
print(' schema = \'%s\'' % introspector.schema)
if len(primary_keys) > 1:
pk_field_names = sorted([
field.name for col, field in columns
if col in primary_keys])
pk_list = ', '.join("'%s'" % pk for pk in pk_field_names)
print_(' primary_key = CompositeKey(%s)' % pk_list)
print(' primary_key = CompositeKey(%s)' % pk_list)
elif not primary_keys:
print_(' primary_key = False')
print_('')
print(' primary_key = False')
print('')
seen.add(table)
@@ -148,12 +147,12 @@ def print_models(introspector, tables=None, preserve_order=False,
def print_header(cmd_line, introspector):
timestamp = datetime.datetime.now()
print_('# Code generated by:')
print_('# python -m pwiz %s' % cmd_line)
print_('# Date: %s' % timestamp.strftime('%B %d, %Y %I:%M%p'))
print_('# Database: %s' % introspector.get_database_name())
print_('# Peewee version: %s' % peewee_version)
print_('')
print('# Code generated by:')
print('# python -m pwiz %s' % cmd_line)
print('# Date: %s' % timestamp.strftime('%B %d, %Y %I:%M%p'))
print('# Database: %s' % introspector.get_database_name())
print('# Peewee version: %s' % peewee_version)
print('')
def err(msg):
+1 -2
View File
@@ -80,8 +80,7 @@ except (ImportError, SyntaxError):
if __name__ == '__main__':
from peewee import print_
print_(r"""\x1b[1;31m
print(r"""\x1b[1;31m
______ ______ ______ __ __ ______ ______
/\ == \ /\ ___\ /\ ___\ /\ \ _ \ \ /\ ___\ /\ ___\\
\ \ _-/ \ \ __\ \ \ __\ \ \ \/ ".\ \ \ \ __\ \ \ __\\
-4
View File
@@ -5,10 +5,6 @@ import logging
import os
import re
import unittest
try:
from unittest import mock
except ImportError:
from .libs import mock
from peewee import *
from peewee import sqlite3
+1 -6
View File
@@ -3,12 +3,8 @@ import datetime
import json
import operator
import os
import sys
import tempfile
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
from io import StringIO
from peewee import *
from playhouse.dataset import DataSet
@@ -402,7 +398,6 @@ class TestDataSet(ModelTestCase):
'charlie',
'huey'])
@skip_if(sys.version_info[0] < 3, 'requires python 3.x')
def test_freeze_thaw_csv_utf8(self):
self._test_freeze_thaw_utf8('csv')
+1 -4
View File
@@ -1,8 +1,5 @@
from itertools import permutations
try:
from Queue import Queue
except ImportError:
from queue import Queue
from queue import Queue
import platform
import re
import threading
+2 -3
View File
@@ -6,7 +6,6 @@ import uuid
from decimal import Decimal as D
from decimal import ROUND_UP
from peewee import bytes_type
from peewee import NodeList
from peewee import *
@@ -666,10 +665,10 @@ class TestBitFields(ModelTestCase):
for i in range(128):
b.data.clear_bit(i)
buf = bytes_type(b.data._buffer)
buf = bytes(b.data._buffer)
self.assertEqual(len(buf), 16)
self.assertEqual(bytes_type(buf), b'\x00' * 16)
self.assertEqual(bytes(buf), b'\x00' * 16)
def test_bigbit_zero_idx(self):
b = Bits()
View File
-2367
View File
File diff suppressed because it is too large Load Diff
+2 -7
View File
@@ -1,8 +1,8 @@
import datetime
import sys
import threading
import time
import unittest
from unittest import mock
from peewee import *
from peewee import Entity
@@ -12,7 +12,6 @@ from peewee import sort_models
from .base import db
from .base import get_in_memory_db
from .base import mock
from .base import new_connection
from .base import requires_models
from .base import requires_mysql
@@ -38,10 +37,6 @@ from .base import TestModel
from .base_models import *
if sys.version_info[0] >= 3:
long = int
class Color(TestModel):
name = CharField(primary_key=True)
is_neutral = BooleanField(default=False)
@@ -195,7 +190,7 @@ class TestModelAPIs(ModelTestCase):
with self.assertQueryCount(1):
huey = self.add_user('huey')
self.assertEqual(huey.username, 'huey')
self.assertTrue(isinstance(huey.id, (int, long)))
self.assertTrue(isinstance(huey.id, int))
self.assertTrue(huey.id > 0)
with self.assertQueryCount(1):
+2 -5
View File
@@ -1,11 +1,9 @@
import datetime
import os
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
import textwrap
import sys
from io import StringIO
from unittest import mock
from peewee import *
from pwiz import *
@@ -13,7 +11,6 @@ from pwiz import *
from .base import ModelTestCase
from .base import TestModel
from .base import db_loader
from .base import mock
from .base import skip_if
-2
View File
@@ -1,7 +1,6 @@
import datetime
import json
import random
import sys
import threading
import time
import uuid
@@ -1267,7 +1266,6 @@ class CharPKKV(TestModel):
class TestBulkUpdateNonIntegerPK(ModelTestCase):
@skip_if(sys.version_info[0] == 2)
@requires_models(UUIDReg)
def test_bulk_update_uuid_pk(self):
r1 = UUIDReg.create(key='k1')
-1
View File
@@ -1647,7 +1647,6 @@ class TestCollatedFieldDefinitions(ModelTestCase):
class TestReadOnly(ModelTestCase):
database = get_sqlite_db()
@skip_if(sys.version_info < (3, 4, 0), 'requres python >= 3.4.0')
@requires_models(User)
def test_read_only(self):
User.create(username='foo')