Files
sqlalchemy/test/perf/many_table_reflection.py
Federico Caselli 9b153ff18f Update black to 24.1.1
Change-Id: Iadaea7b798d8e99302e1acb430dc7b758ca61137
2024-01-31 21:54:59 +01:00

650 lines
18 KiB
Python

from argparse import ArgumentDefaultsHelpFormatter
from argparse import ArgumentParser
from collections import defaultdict
from contextlib import contextmanager
from functools import wraps
from pprint import pprint
import random
import time
import sqlalchemy as sa
from sqlalchemy.engine import Inspector
# Pool of column types randomly assigned to generated columns.
types = (sa.Integer, sa.BigInteger, sa.String(200), sa.DateTime)
# When True, single-table reflection tests run inspect() on a Connection
# instead of on the Engine (rebound from --single-inspect-conn at startup).
USE_CONNECTION = False
def generate_table(meta: sa.MetaData, min_cols, max_cols, dialect_name):
    """Add one randomly-shaped table named ``table_<n>`` to *meta*.

    The table gets a random number of columns in ``[min_cols, max_cols]``,
    a primary key on the first column, occasional foreign keys to earlier
    tables, random indexes/unique flags, and ~10% of tables an identity
    column. Returns the new :class:`sa.Table`.
    """
    col_number = random.randint(min_cols, max_cols)
    table_num = len(meta.tables)
    # ~10% of tables get an identity column.
    add_identity = random.random() > 0.90
    identity = sa.Identity(
        always=random.randint(0, 1),
        start=random.randint(1, 100),
        increment=random.randint(1, 7),
    )
    is_mssql = dialect_name == "mssql"
    cols = []
    # On non-mssql dialects the identity column is appended after the loop,
    # so generate one fewer regular column when an identity was requested
    # (bool add_identity counts as 0/1 here).
    for i in range(col_number - (0 if is_mssql else add_identity)):
        args = []
        if random.random() < 0.99 or table_num == 0:
            if is_mssql and add_identity and i == 0:
                # mssql wants the identity on the pk column itself
                args.append(sa.Integer)
                args.append(identity)
            else:
                args.append(random.choice(types))
        else:
            # ~1% of columns: FK to the pk column of a random earlier table
            # (type is inferred from the referenced column)
            target = random.randint(0, table_num - 1)
            args.append(sa.ForeignKey(f"table_{target}.table_{target}_col_1"))
        cols.append(
            sa.Column(
                f"table_{table_num}_col_{i + 1}",
                *args,
                primary_key=i == 0,
                comment=(
                    f"primary key of table_{table_num}" if i == 0 else None
                ),
                index=random.random() > 0.97 and i > 0,
                unique=random.random() > 0.97 and i > 0,
            )
        )
    if add_identity and not is_mssql:
        cols.append(
            sa.Column(
                f"table_{table_num}_col_{col_number}",
                sa.Integer,
                identity,
            )
        )
    args = ()
    # mysql can't do check constraint on PK col, so use the second column.
    # Guard len(cols) > 1: with --min-cols 1 a table may have only the pk
    # column and cols[1] would raise IndexError.
    if table_num % 3 == 0 and len(cols) > 1:
        args = (sa.CheckConstraint(cols[1].is_not(None)),)
    return sa.Table(
        f"table_{table_num}",
        meta,
        *cols,
        *args,
        comment=f"comment for table_{table_num}" if table_num % 2 else None,
    )
def generate_meta(schema_name, table_number, min_cols, max_cols, dialect_name):
    """Build a MetaData populated with *table_number* random tables.

    Prints a summary of how many tables/columns/constraints were generated
    and returns the populated :class:`sa.MetaData`.
    """
    meta = sa.MetaData(schema=schema_name)
    counters = defaultdict(int)
    for _ in range(table_number):
        table = generate_table(meta, min_cols, max_cols, dialect_name)
        counters["tables"] += 1
        counters["columns"] += len(table.columns)
        counters["index"] += len(table.indexes)
        counters["check_con"] += sum(
            isinstance(c, sa.CheckConstraint) for c in table.constraints
        )
        counters["foreign_keys_con"] += sum(
            isinstance(c, sa.ForeignKeyConstraint) for c in table.constraints
        )
        counters["unique_con"] += sum(
            isinstance(c, sa.UniqueConstraint) for c in table.constraints
        )
        counters["identity"] += sum(1 for c in table.columns if c.identity)
    print("Meta info", dict(counters))
    return meta
def log(fn):
    """Decorator that prints start/done markers around a call to *fn*.

    A ``NotImplementedError`` raised by *fn* is swallowed, reported as
    ``[not implemented]``, and turned into a ``None`` return value.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        print("Running", fn.__name__, "...", flush=True, end="")
        try:
            result = fn(*args, **kwargs)
        except NotImplementedError:
            print(" [not implemented]", flush=True)
            result = None
        else:
            print("... done", flush=True)
        return result

    return wrapper
# Registry of reflection tests, keyed by short name; filled by define_test.
tests = {}


def define_test(fn):
    """Register *fn* (wrapped with ``log``) in ``tests``.

    The registry key is the function name with any leading ``reflect_``
    prefix stripped. Returns the wrapped function.
    """
    name: str = fn.__name__
    if name.startswith("reflect_"):
        name = name[len("reflect_"):]
    wrapped = log(fn)
    tests[name] = wrapped
    return wrapped
@log
def create_tables(engine, meta):
    """Emit CREATE TABLE for every table in *meta*, 500 tables per batch."""
    all_tables = list(meta.tables.values())
    batch = 500
    for start in range(0, len(all_tables), batch):
        meta.create_all(engine, all_tables[start : start + batch])
def _drop_ddl(name, schema_name, dialect_name):
    """Return a DROP TABLE DDL statement appropriate for *dialect_name*.

    Postgres and Oracle need a cascade suffix so dependent objects do not
    block the drop.
    """
    if dialect_name.startswith("postgres"):
        suffix = "CASCADE"
    elif dialect_name.startswith("oracle"):
        suffix = "CASCADE CONSTRAINTS PURGE"
    else:
        suffix = ""
    qualified = f"{schema_name}.{name}" if schema_name else name
    return sa.schema.DDL(f"DROP TABLE {qualified} {suffix}")
@log
def drop_tables(engine, meta, schema_name, table_names: list):
    """Drop everything created by the benchmark.

    First drops the tables known to *meta* (in reverse creation order, in
    batches of 500), then sweeps any table still present in the schema with
    raw DROP TABLE DDL, committing every 500 drops.

    NOTE(review): the ``table_names`` parameter is accepted for signature
    parity with the caller but is not read here.
    """
    reversed_tables = list(meta.tables.values())[::-1]
    for start in range(0, len(reversed_tables), 500):
        meta.drop_all(engine, reversed_tables[start : start + 500])
    # Anything left over (e.g. from a previous failed run): drop highest
    # table number first so FK targets go last.
    leftover = sa.inspect(engine).get_table_names(schema=schema_name)
    leftover.sort(key=lambda tn: int(tn.partition("_")[2]), reverse=True)
    with engine.connect() as conn:
        for index, tn in enumerate(leftover):
            if engine.dialect.requires_name_normalize:
                name = engine.dialect.denormalize_name(tn)
            else:
                name = tn
            conn.execute(_drop_ddl(name, schema_name, engine.dialect.name))
            if index % 500 == 0:
                conn.commit()
        conn.commit()
@log
def reflect_tables(engine, schema_name):
    """Reflect every table of the schema into a throwaway MetaData."""
    sa.MetaData(schema=schema_name).reflect(engine)
def verify_dict(multi, single, str_compare=False):
    """Print any per-key differences between the two result dicts.

    Does nothing when either argument is None (that half of the test was
    skipped). With *str_compare* values are compared by their ``str()``
    form, which papers over type-only differences.
    """
    if single is None or multi is None:
        return
    if single == multi:
        return
    differences = []
    for key in sorted(set(single) | set(multi)):
        sv = single.get(key)
        mv = multi.get(key)
        differ = str(sv) != str(mv) if str_compare else sv != mv
        if differ:
            differences.append((key, sv, mv))
    if differences:
        print("\nfound different result:")
        pprint(differences)
def _single_test(
    singe_fn_name,
    multi_fn_name,
    engine,
    schema_name,
    table_names,
    timing,
    mode,
):
    """Run the per-table and/or multi-table variant of an Inspector method.

    Returns ``(multi_result, single_result)``; each half is None when its
    mode is not requested. The single variant is timed across all tables,
    using a Connection instead of the Engine when USE_CONNECTION is set.
    """
    single = None
    if "single" in mode:
        single_fn = getattr(Inspector, singe_fn_name)

        def run_single(bind):
            inspector = sa.inspect(bind)
            results = {}
            with timing(single_fn.__name__):
                for table in table_names:
                    results[(schema_name, table)] = single_fn(
                        inspector, table, schema=schema_name
                    )
            return results

        if USE_CONNECTION:
            with engine.connect() as conn:
                single = run_single(conn)
        else:
            single = run_single(engine)

    multi = None
    if "multi" in mode:
        inspector = sa.inspect(engine)
        multi_fn = getattr(Inspector, multi_fn_name)
        with timing(multi_fn.__name__):
            multi = multi_fn(inspector, schema=schema_name)
    return (multi, single)
@define_test
def reflect_columns(
    engine, schema_name, table_names, timing, mode, ignore_diff
):
    """Time get_columns vs get_multi_columns and diff their results."""
    results = _single_test(
        "get_columns",
        "get_multi_columns",
        engine,
        schema_name,
        table_names,
        timing,
        mode,
    )
    if not ignore_diff:
        # column reflection returns type objects; compare their str() form
        verify_dict(*results, str_compare=True)
@define_test
def reflect_table_options(
    engine, schema_name, table_names, timing, mode, ignore_diff
):
    """Time get_table_options vs get_multi_table_options and diff them."""
    results = _single_test(
        "get_table_options",
        "get_multi_table_options",
        engine,
        schema_name,
        table_names,
        timing,
        mode,
    )
    if not ignore_diff:
        verify_dict(*results)
@define_test
def reflect_pk(engine, schema_name, table_names, timing, mode, ignore_diff):
    """Time get_pk_constraint vs get_multi_pk_constraint and diff them."""
    results = _single_test(
        "get_pk_constraint",
        "get_multi_pk_constraint",
        engine,
        schema_name,
        table_names,
        timing,
        mode,
    )
    if not ignore_diff:
        verify_dict(*results)
@define_test
def reflect_comment(
    engine, schema_name, table_names, timing, mode, ignore_diff
):
    """Time get_table_comment vs get_multi_table_comment and diff them."""
    results = _single_test(
        "get_table_comment",
        "get_multi_table_comment",
        engine,
        schema_name,
        table_names,
        timing,
        mode,
    )
    if not ignore_diff:
        verify_dict(*results)
@define_test
def reflect_whole_tables(
    engine, schema_name, table_names, timing, mode, ignore_diff
):
    """Time full Table reflection (autoload_with) vs MetaData.reflect.

    Unlike the other tests this returns ``(multi_meta, single)`` rather
    than plain dicts, so no diff is attempted.
    """
    single = None
    meta = sa.MetaData(schema=schema_name)
    if "single" in mode:

        def run_single(bind):
            reflected = {}
            with timing("Table_autoload_with"):
                for table_name in table_names:
                    reflected[(None, table_name)] = sa.Table(
                        table_name, meta, autoload_with=bind
                    )
            return reflected

        if USE_CONNECTION:
            with engine.connect() as conn:
                single = run_single(conn)
        else:
            single = run_single(engine)

    multi_meta = sa.MetaData(schema=schema_name)
    if "multi" in mode:
        with timing("MetaData_reflect"):
            multi_meta.reflect(engine, only=table_names)
    return (multi_meta, single)
@define_test
def reflect_check_constraints(
    engine, schema_name, table_names, timing, mode, ignore_diff
):
    """Time get_check_constraints vs get_multi_check_constraints, diff."""
    results = _single_test(
        "get_check_constraints",
        "get_multi_check_constraints",
        engine,
        schema_name,
        table_names,
        timing,
        mode,
    )
    if not ignore_diff:
        verify_dict(*results)
@define_test
def reflect_indexes(
    engine, schema_name, table_names, timing, mode, ignore_diff
):
    """Time get_indexes vs get_multi_indexes and diff their results."""
    results = _single_test(
        "get_indexes",
        "get_multi_indexes",
        engine,
        schema_name,
        table_names,
        timing,
        mode,
    )
    if not ignore_diff:
        verify_dict(*results)
@define_test
def reflect_foreign_keys(
    engine, schema_name, table_names, timing, mode, ignore_diff
):
    """Time get_foreign_keys vs get_multi_foreign_keys and diff them."""
    results = _single_test(
        "get_foreign_keys",
        "get_multi_foreign_keys",
        engine,
        schema_name,
        table_names,
        timing,
        mode,
    )
    if not ignore_diff:
        verify_dict(*results)
@define_test
def reflect_unique_constraints(
    engine, schema_name, table_names, timing, mode, ignore_diff
):
    """Time get_unique_constraints vs get_multi_unique_constraints, diff."""
    results = _single_test(
        "get_unique_constraints",
        "get_multi_unique_constraints",
        engine,
        schema_name,
        table_names,
        timing,
        mode,
    )
    if not ignore_diff:
        verify_dict(*results)
def _apply_events(engine):
    """Attach cursor-execute listeners that record per-statement timings.

    Returns a ``defaultdict(list)`` mapping the statement string (the
    compiled form when available) to a list of execution times in seconds;
    the dict keeps filling as the engine is used.
    """
    queries = defaultdict(list)
    started_at = 0

    @sa.event.listens_for(engine, "before_cursor_execute")
    def before_cursor_execute(
        conn, cursor, statement, parameters, context, executemany
    ):
        nonlocal started_at
        started_at = time.time()

    @sa.event.listens_for(engine, "after_cursor_execute")
    def after_cursor_execute(
        conn, cursor, statement, parameters, context, executemany
    ):
        elapsed = time.time() - started_at
        if context and context.compiled:
            key = context.compiled.string
        else:
            key = statement
        queries[key].append(elapsed)

    return queries
def _print_query_stats(queries):
number_of_queries = sum(
len(query_times) for query_times in queries.values()
)
print("-" * 50)
q_list = list(queries.items())
q_list.sort(key=lambda rec: -sum(rec[1]))
total = sum([sum(t) for _, t in q_list])
print(f"total number of queries: {number_of_queries}. Total time {total}")
print("-" * 50)
for stmt, times in q_list:
total_t = sum(times)
max_t = max(times)
min_t = min(times)
avg_t = total_t / len(times)
times.sort()
median_t = times[len(times) // 2]
print(
f"Query times: {total_t=}, {max_t=}, {min_t=}, {avg_t=}, "
f"{median_t=} Number of calls: {len(times)}"
)
print(stmt.strip(), "\n")
def main(db, schema_name, table_number, min_cols, max_cols, args):
    """Drive the benchmark: create tables, run the reflection tests, drop.

    :param db: database URL passed to ``sa.create_engine``.
    :param schema_name: optional schema to create/reflect in (None = default).
    :param table_number: how many random tables to generate.
    :param min_cols: minimum columns per table.
    :param max_cols: maximum columns per table.
    :param args: parsed CLI namespace controlling every optional behavior.
    """
    timing = timer()
    if args.pool_class:
        engine = sa.create_engine(
            db,
            echo=args.echo,
            poolclass=getattr(sa.pool, args.pool_class),
            future=True,
        )
    else:
        engine = sa.create_engine(db, echo=args.echo, future=True)
    if args.drop_all:
        # maintenance mode: remove leftovers from a previous run and exit
        return drop_all(engine, schema_name)
    if engine.name == "oracle":
        # clear out oracle caches so that we get the real-world time the
        # queries would normally take for scripts that aren't run repeatedly
        with engine.connect() as conn:
            # https://stackoverflow.com/questions/2147456/how-to-clear-all-cached-items-in-oracle
            conn.exec_driver_sql("alter system flush buffer_cache")
            conn.exec_driver_sql("alter system flush shared_pool")
    if not args.no_create:
        print(
            f"Generating {table_number} using engine {engine} in "
            f"schema {schema_name or 'default'}",
        )
    meta = sa.MetaData()
    table_names = []
    stats = {}
    try:
        if not args.no_create:
            with timing("populate-meta"):
                meta = generate_meta(
                    schema_name, table_number, min_cols, max_cols, engine.name
                )
            with timing("create-tables"):
                create_tables(engine, meta)
        # re-read the actual table names: the tests run against what is in
        # the database, not what the in-memory metadata believes
        with timing("get_table_names"):
            with engine.connect() as conn:
                table_names = engine.dialect.get_table_names(
                    conn, schema=schema_name
                )
        print(
            f"Reflected table number {len(table_names)} in "
            f"schema {schema_name or 'default'}"
        )
        mode = {"single", "multi"}
        if args.multi_only:
            mode.discard("single")
        if args.single_only:
            mode.discard("multi")
        if args.sqlstats:
            # attach listeners now so per-query stats only cover the tests
            print("starting stats for subsequent tests")
            stats = _apply_events(engine)
        for test_name, test_fn in tests.items():
            if test_name in args.test or "all" in args.test:
                test_fn(
                    engine,
                    schema_name,
                    table_names,
                    timing,
                    mode,
                    args.ignore_diff,
                )
        if args.reflect:
            with timing("reflect-tables"):
                reflect_tables(engine, schema_name)
    finally:
        # copy stats to new dict
        if args.sqlstats:
            stats = dict(stats)
        try:
            if not args.no_drop:
                with timing("drop-tables"):
                    drop_tables(engine, meta, schema_name, table_names)
        finally:
            # always report whatever timings were collected
            pprint(timing.timing, sort_dicts=False)
            if args.sqlstats:
                _print_query_stats(stats)
def timer():
    """Create a context-manager factory that records elapsed times.

    Returns a callable ``track_time(name)``; the elapsed seconds of each
    ``with track_time(name):`` block are stored under *name* in the
    ``track_time.timing`` dict.
    """
    timing = {}

    @contextmanager
    def track_time(name):
        # perf_counter is monotonic and high-resolution, unlike time.time
        # which can jump with wall-clock adjustments.
        start = time.perf_counter()
        try:
            yield
        finally:
            # record the elapsed time even if the timed block raises, so
            # main()'s final report still shows partial timings on failure
            timing[name] = time.perf_counter() - start

    track_time.timing = timing
    return track_time
def drop_all(engine, schema_name):
    """Drop every table in the schema, retrying drops that fail.

    A failed drop (commonly a foreign-key dependency) is rolled back and
    the table is pushed to the back of the queue for a later attempt.
    Prints a progress line every 25 attempts.
    """
    with engine.connect() as conn:
        table_names = engine.dialect.get_table_names(conn, schema=schema_name)
        print(f"Dropping {len(table_names)} tables")
        dialect_name = engine.dialect.name
        attempts = 0
        while table_names:
            name = table_names.pop()
            try:
                conn.execute(_drop_ddl(name, schema_name, dialect_name))
                conn.commit()
            except Exception:
                conn.rollback()
                table_names.insert(0, name)
            attempts += 1
            if attempts % 25 == 0:
                print(f"Still running. Tables left {len(table_names)}")
if __name__ == "__main__":
    # CLI entry point: parse options, validate them, then run the benchmark.
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--db", help="Database url", default="sqlite:///many-table.db"
    )
    parser.add_argument(
        "--schema-name",
        help="optional schema name",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--table-number",
        help="Number of table to generate.",
        type=int,
        default=250,
    )
    parser.add_argument(
        "--min-cols",
        help="Min number of column per table.",
        type=int,
        default=15,
    )
    parser.add_argument(
        "--max-cols",
        help="Max number of column per table.",
        type=int,
        default=250,
    )
    parser.add_argument(
        "--no-create", help="Do not run create tables", action="store_true"
    )
    parser.add_argument(
        "--no-drop", help="Do not run drop tables", action="store_true"
    )
    parser.add_argument(
        "--reflect", help="Run metadata reflect", action="store_true"
    )
    parser.add_argument(
        "--test",
        help="Run these tests. 'all' runs all tests",
        nargs="+",
        # choices come from the define_test registry; 'none' selects nothing
        choices=tuple(tests) + ("all", "none"),
        default=["all"],
    )
    parser.add_argument(
        "--sqlstats",
        help="count and time individual queries",
        action="store_true",
    )
    parser.add_argument(
        "--multi-only", help="Only run multi table tests", action="store_true"
    )
    parser.add_argument(
        "--single-only",
        help="Only run single table tests",
        action="store_true",
    )
    parser.add_argument(
        "--echo", action="store_true", help="Enable echo on the engine"
    )
    parser.add_argument(
        "--ignore-diff",
        action="store_true",
        help="Ignores differences in the single/multi reflections",
    )
    parser.add_argument(
        "--single-inspect-conn",
        action="store_true",
        help="Uses inspect on a connection instead of on the engine when "
        "using single reflections. Mainly for sqlite.",
    )
    parser.add_argument("--pool-class", help="The pool class to use")
    parser.add_argument(
        "--drop-all",
        action="store_true",
        help="Drop all tables, do nothing else",
    )

    args = parser.parse_args()
    min_cols = args.min_cols
    max_cols = args.max_cols
    # rebinds the module-level flag read by the single-reflection tests
    # (plain assignment works here because this is module scope)
    USE_CONNECTION = args.single_inspect_conn
    assert min_cols <= max_cols and min_cols >= 1
    # --multi-only and --single-only are mutually exclusive
    assert not (args.multi_only and args.single_only)
    main(
        args.db, args.schema_name, args.table_number, min_cols, max_cols, args
    )