Domain type

Added a new PostgreSQL :class:`_postgresql.DOMAIN` datatype, which follows
the same CREATE TYPE / DROP TYPE behaviors as that of PostgreSQL
:class:`_postgresql.ENUM`. Much thanks to David Baumgold for the efforts on
this.

Fixes: #7316
Closes: #7317
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/7317
Pull-request-sha: bc9a82f010

Change-Id: Id8d7e48843a896de17d20cc466b115b3cc065132
This commit is contained in:
David Baumgold
2022-02-11 12:30:24 -05:00
committed by Mike Bayer
parent 4e2a89c41b
commit 017fd9ae06
12 changed files with 1264 additions and 493 deletions
+21
View File
@@ -0,0 +1,21 @@
.. change::
:tags: feature, postgresql
:tickets: 7316
Added a new PostgreSQL :class:`_postgresql.DOMAIN` datatype, which follows
the same CREATE TYPE / DROP TYPE behaviors as that of PostgreSQL
:class:`_postgresql.ENUM`. Much thanks to David Baumgold for the efforts on
this.
.. seealso::
:class:`_postgresql.DOMAIN`
.. change::
:tags: change, postgresql
The :paramref:`_postgresql.ENUM.name` parameter for the PostgreSQL-specific
:class:`_postgresql.ENUM` datatype is now a required keyword argument. The
"name" is necessary in any case in order for the :class:`_postgresql.ENUM`
to be usable as an error would be raised at SQL/DDL render time if "name"
were not present.
+3
View File
@@ -49,6 +49,9 @@ construction arguments, are as follows:
.. autoclass:: CIDR
.. autoclass:: DOMAIN
:members: __init__, create, drop
.. autoclass:: DOUBLE_PRECISION
:members: __init__
:noindex:
+11 -3
View File
@@ -22,6 +22,7 @@ from .base import BIGINT
from .base import BOOLEAN
from .base import CHAR
from .base import DATE
from .base import DOMAIN
from .base import DOUBLE_PRECISION
from .base import FLOAT
from .base import INTEGER
@@ -40,6 +41,12 @@ from .hstore import HSTORE
from .hstore import hstore
from .json import JSON
from .json import JSONB
from .named_types import CreateDomainType
from .named_types import CreateEnumType
from .named_types import DropDomainType
from .named_types import DropEnumType
from .named_types import ENUM
from .named_types import NamedType
from .ranges import DATERANGE
from .ranges import INT4RANGE
from .ranges import INT8RANGE
@@ -49,9 +56,6 @@ from .ranges import TSTZRANGE
from .types import BIT
from .types import BYTEA
from .types import CIDR
from .types import CreateEnumType
from .types import DropEnumType
from .types import ENUM
from .types import INET
from .types import INTERVAL
from .types import MACADDR
@@ -97,6 +101,7 @@ __all__ = (
"INTERVAL",
"ARRAY",
"ENUM",
"DOMAIN",
"dialect",
"array",
"HSTORE",
@@ -113,6 +118,9 @@ __all__ = (
"Any",
"All",
"DropEnumType",
"DropDomainType",
"CreateDomainType",
"NamedType",
"CreateEnumType",
"ExcludeConstraint",
"aggregate_order_by",
+237 -64
View File
@@ -1450,6 +1450,9 @@ from __future__ import annotations
from collections import defaultdict
from functools import lru_cache
import re
from typing import Any
from typing import List
from typing import Optional
from . import array as _array
from . import dml
@@ -1457,30 +1460,34 @@ from . import hstore as _hstore
from . import json as _json
from . import pg_catalog
from . import ranges as _ranges
from .types import _DECIMAL_TYPES # noqa
from .types import _FLOAT_TYPES # noqa
from .types import _INT_TYPES # noqa
from .types import BIT
from .types import BYTEA
from .types import CIDR
from .types import CreateEnumType # noqa
from .types import DropEnumType # noqa
from .types import ENUM
from .types import INET
from .types import INTERVAL
from .types import MACADDR
from .types import MONEY
from .types import OID
from .types import PGBit # noqa
from .types import PGCidr # noqa
from .types import PGInet # noqa
from .types import PGInterval # noqa
from .types import PGMacAddr # noqa
from .types import PGUuid
from .types import REGCLASS
from .types import TIME
from .types import TIMESTAMP
from .types import TSVECTOR
from .named_types import CreateDomainType as CreateDomainType # noqa: F401
from .named_types import CreateEnumType as CreateEnumType # noqa: F401
from .named_types import DOMAIN as DOMAIN # noqa: F401
from .named_types import DropDomainType as DropDomainType # noqa: F401
from .named_types import DropEnumType as DropEnumType # noqa: F401
from .named_types import ENUM as ENUM # noqa: F401
from .named_types import NamedType as NamedType # noqa: F401
from .types import _DECIMAL_TYPES # noqa: F401
from .types import _FLOAT_TYPES # noqa: F401
from .types import _INT_TYPES # noqa: F401
from .types import BIT as BIT
from .types import BYTEA as BYTEA
from .types import CIDR as CIDR
from .types import INET as INET
from .types import INTERVAL as INTERVAL
from .types import MACADDR as MACADDR
from .types import MONEY as MONEY
from .types import OID as OID
from .types import PGBit as PGBit # noqa: F401
from .types import PGCidr as PGCidr # noqa: F401
from .types import PGInet as PGInet # noqa: F401
from .types import PGInterval as PGInterval # noqa: F401
from .types import PGMacAddr as PGMacAddr # noqa: F401
from .types import PGUuid as PGUuid
from .types import REGCLASS as REGCLASS
from .types import TIME as TIME
from .types import TIMESTAMP as TIMESTAMP
from .types import TSVECTOR as TSVECTOR
from ... import exc
from ... import schema
from ... import select
@@ -1515,6 +1522,7 @@ from ...types import SMALLINT
from ...types import TEXT
from ...types import UUID as UUID
from ...types import VARCHAR
from ...util.typing import TypedDict
IDX_USING = re.compile(r"^(?:btree|hash|gist|gin|[\w_]+)$", re.I)
@@ -2198,6 +2206,38 @@ class PGDDLCompiler(compiler.DDLCompiler):
return "DROP TYPE %s" % (self.preparer.format_type(type_))
def visit_create_domain_type(self, create):
    """Render CREATE DOMAIN DDL for a :class:`.CreateDomainType` element."""
    domain: DOMAIN = create.element

    options = []
    if domain.collation is not None:
        # COLLATE takes a quoted identifier
        options.append(f"COLLATE {self.preparer.quote(domain.collation)}")
    if domain.default is not None:
        default = self.render_default_string(domain.default)
        options.append(f"DEFAULT {default}")
    if domain.constraint_name is not None:
        # explicit name for the constraint clause(s) that follow.
        # NOTE(review): CONSTRAINT is emitted even if neither NOT NULL
        # nor CHECK is present -- confirm that combination is prevented
        # upstream.
        name = self.preparer.truncate_and_render_constraint_name(
            domain.constraint_name
        )
        options.append(f"CONSTRAINT {name}")
    if domain.not_null:
        options.append("NOT NULL")
    if domain.check is not None:
        # render the CHECK expression standalone with literals inlined,
        # since DDL cannot carry bound parameters
        check = self.sql_compiler.process(
            domain.check, include_table=False, literal_binds=True
        )
        options.append(f"CHECK ({check})")

    # NOTE(review): when no options are present this leaves a trailing
    # space after the type name; harmless to PostgreSQL.
    return (
        f"CREATE DOMAIN {self.preparer.format_type(domain)} AS "
        f"{self.type_compiler.process(domain.data_type)} "
        f"{' '.join(options)}"
    )
def visit_drop_domain_type(self, drop):
    """Render DROP DOMAIN DDL for a :class:`.DropDomainType` element."""
    domain_name = self.preparer.format_type(drop.element)
    return "DROP DOMAIN %s" % (domain_name,)
def visit_create_index(self, create):
preparer = self.preparer
index = create.element
@@ -2470,6 +2510,11 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
identifier_preparer = self.dialect.identifier_preparer
return identifier_preparer.format_type(type_)
def visit_DOMAIN(self, type_, identifier_preparer=None, **kw):
    """Render the (possibly schema-qualified) name of a DOMAIN type."""
    preparer = (
        identifier_preparer
        if identifier_preparer is not None
        else self.dialect.identifier_preparer
    )
    return preparer.format_type(type_)
def visit_TIMESTAMP(self, type_, **kw):
return "TIMESTAMP%s %s" % (
"(%d)" % type_.precision
@@ -2548,7 +2593,9 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer):
def format_type(self, type_, use_schema=True):
if not type_.name:
raise exc.CompileError("PostgreSQL ENUM type requires a name.")
raise exc.CompileError(
f"PostgreSQL {type_.__class__.__name__} type requires a name."
)
name = self.quote(type_.name)
effective_schema = self.schema_for_object(type_)
@@ -2558,14 +2605,60 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer):
and use_schema
and effective_schema is not None
):
name = self.quote_schema(effective_schema) + "." + name
name = f"{self.quote_schema(effective_schema)}.{name}"
return name
class ReflectedNamedType(TypedDict):
    """Represents a reflected named type.

    Base dictionary shared by the named types reflected from the
    database (see :class:`.ReflectedDomain`, :class:`.ReflectedEnum`).
    """

    name: str
    """Name of the type."""
    schema: str
    """The schema of the type."""
    visible: bool
    """Indicates if this type is in the current search path."""
class ReflectedDomainConstraint(TypedDict):
    """Represents a reflected CHECK constraint of a domain."""

    name: str
    """Name of the constraint."""
    check: str
    """The check constraint text."""
class ReflectedDomain(ReflectedNamedType):
    """Represents a reflected DOMAIN type."""

    type: str
    """The string name of the underlying data type of the domain."""
    nullable: bool
    """Indicates if the domain allows null or not."""
    default: Optional[str]
    """The string representation of the default value of this domain
    or ``None`` if none present.
    """
    constraints: List[ReflectedDomainConstraint]
    """The constraints defined in the domain, if any.

    The constraints are in order of evaluation by PostgreSQL.
    """
class ReflectedEnum(ReflectedNamedType):
    """Represents a reflected enum, as returned by
    :meth:`.PGInspector.get_enums`.
    """

    labels: List[str]
    """The labels that compose the enum."""
class PGInspector(reflection.Inspector):
dialect: PGDialect
def get_table_oid(self, table_name, schema=None):
def get_table_oid(
self, table_name: str, schema: Optional[str] = None
) -> int:
"""Return the OID for the given table name.
:param table_name: string name of the table. For special quoting,
@@ -2582,7 +2675,38 @@ class PGInspector(reflection.Inspector):
conn, table_name, schema, info_cache=self.info_cache
)
def get_enums(self, schema=None):
def get_domains(
    self, schema: Optional[str] = None
) -> List[ReflectedDomain]:
    """Return a list of DOMAIN objects.

    Each member is a dictionary containing these fields:

        * name - name of the domain
        * schema - the schema name for the domain.
        * visible - boolean, whether or not this domain is visible
          in the default search path.
        * type - the type defined by this domain.
        * nullable - Indicates if this domain can be ``NULL``.
        * default - The default value of the domain or ``None`` if the
          domain has no default.
        * constraints - A list of dicts with the constraints defined by
          this domain. Each element contains two keys: ``name`` of the
          constraint and ``check`` with the constraint text.

    :param schema: schema name.  If None, the default schema
     (typically 'public') is used. May also be set to ``'*'`` to
     indicate load domains for all schemas.

    .. versionadded:: 2.0

    """
    with self._operation_context() as conn:
        return self.dialect._load_domains(
            conn, schema, info_cache=self.info_cache
        )
def get_enums(self, schema: Optional[str] = None) -> List[ReflectedEnum]:
"""Return a list of ENUM objects.
Each member is a dictionary containing these fields:
@@ -2594,7 +2718,7 @@ class PGInspector(reflection.Inspector):
* labels - a list of string labels that apply to the enum.
:param schema: schema name. If None, the default schema
(typically 'public') is used. May also be set to '*' to
(typically 'public') is used. May also be set to ``'*'`` to
indicate load enums for all schemas.
.. versionadded:: 1.0.0
@@ -2605,7 +2729,9 @@ class PGInspector(reflection.Inspector):
conn, schema, info_cache=self.info_cache
)
def get_foreign_table_names(self, schema=None):
def get_foreign_table_names(
self, schema: Optional[str] = None
) -> List[str]:
"""Return a list of FOREIGN TABLE names.
Behavior is similar to that of
@@ -2621,13 +2747,15 @@ class PGInspector(reflection.Inspector):
conn, schema, info_cache=self.info_cache
)
def has_type(self, type_name, schema=None, **kw):
def has_type(
self, type_name: str, schema: Optional[str] = None, **kw: Any
) -> bool:
"""Return if the database has the specified type in the provided
schema.
:param type_name: the type to check.
:param schema: schema name. If None, the default schema
(typically 'public') is used. May also be set to '*' to
(typically 'public') is used. May also be set to ``'*'`` to
check in all schemas.
.. versionadded:: 2.0
@@ -2941,10 +3069,12 @@ class PGDialect(default.DefaultDialect):
pg_catalog.pg_namespace,
pg_catalog.pg_namespace.c.oid == pg_class_table.c.relnamespace,
)
if scope is ObjectScope.DEFAULT:
query = query.where(pg_class_table.c.relpersistence != "t")
elif scope is ObjectScope.TEMPORARY:
query = query.where(pg_class_table.c.relpersistence == "t")
if schema is None:
query = query.where(
pg_catalog.pg_table_is_visible(pg_class_table.c.oid),
@@ -3319,9 +3449,12 @@ class PGDialect(default.DefaultDialect):
# dictionary with (name, ) if default search path or (schema, name)
# as keys
domains = self._load_domains(
connection, info_cache=kw.get("info_cache")
)
domains = {
((d["schema"], d["name"]) if not d["visible"] else (d["name"],)): d
for d in self._load_domains(
connection, schema="*", info_cache=kw.get("info_cache")
)
}
# dictionary with (name, ) if default search path or (schema, name)
# as keys
@@ -3446,7 +3579,7 @@ class PGDialect(default.DefaultDialect):
break
elif enum_or_domain_key in domains:
domain = domains[enum_or_domain_key]
attype = domain["attype"]
attype = domain["type"]
attype, is_array = _handle_array_type(attype)
# strip quotes from case sensitive enum or domain names
enum_or_domain_key = tuple(
@@ -3736,7 +3869,7 @@ class PGDialect(default.DefaultDialect):
@util.memoized_property
def _fk_regex_pattern(self):
# https://www.postgresql.org/docs/14.0/static/sql-createtable.html
# https://www.postgresql.org/docs/current/static/sql-createtable.html
return re.compile(
r"FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)"
r"[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?"
@@ -4201,7 +4334,7 @@ class PGDialect(default.DefaultDialect):
(
pg_catalog.pg_constraint.c.oid.is_not(None),
pg_catalog.pg_get_constraintdef(
pg_catalog.pg_constraint.c.oid
pg_catalog.pg_constraint.c.oid, True
),
),
else_=None,
@@ -4265,6 +4398,17 @@ class PGDialect(default.DefaultDialect):
check_constraints[(schema, table_name)].append(entry)
return check_constraints.items()
def _pg_type_filter_schema(self, query, schema):
if schema is None:
query = query.where(
pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid),
# ignore pg_catalog schema
pg_catalog.pg_namespace.c.nspname != "pg_catalog",
)
elif schema != "*":
query = query.where(pg_catalog.pg_namespace.c.nspname == schema)
return query
@lru_cache()
def _enum_query(self, schema):
lbl_sq = (
@@ -4310,15 +4454,7 @@ class PGDialect(default.DefaultDialect):
)
)
if schema is None:
query = query.where(
pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid),
# ignore pg_catalog schema
pg_catalog.pg_namespace.c.nspname != "pg_catalog",
)
elif schema != "*":
query = query.where(pg_catalog.pg_namespace.c.nspname == schema)
return query
return self._pg_type_filter_schema(query, schema)
@reflection.cache
def _load_enums(self, connection, schema=None, **kw):
@@ -4339,9 +4475,27 @@ class PGDialect(default.DefaultDialect):
)
return enums
@util.memoized_property
def _domain_query(self):
return (
@lru_cache()
def _domain_query(self, schema):
con_sq = (
select(
pg_catalog.pg_constraint.c.contypid,
sql.func.array_agg(
pg_catalog.pg_get_constraintdef(
pg_catalog.pg_constraint.c.oid, True
)
).label("condefs"),
sql.func.array_agg(pg_catalog.pg_constraint.c.conname).label(
"connames"
),
)
# The domain this constraint is on; zero if not a domain constraint
.where(pg_catalog.pg_constraint.c.contypid != 0)
.group_by(pg_catalog.pg_constraint.c.contypid)
.subquery("domain_constraints")
)
query = (
select(
pg_catalog.pg_type.c.typname.label("name"),
pg_catalog.format_type(
@@ -4354,38 +4508,57 @@ class PGDialect(default.DefaultDialect):
"visible"
),
pg_catalog.pg_namespace.c.nspname.label("schema"),
con_sq.c.condefs,
con_sq.c.connames,
)
.join(
pg_catalog.pg_namespace,
pg_catalog.pg_namespace.c.oid
== pg_catalog.pg_type.c.typnamespace,
)
.outerjoin(
con_sq,
pg_catalog.pg_type.c.oid == con_sq.c.contypid,
)
.where(pg_catalog.pg_type.c.typtype == "d")
.order_by(
pg_catalog.pg_namespace.c.nspname, pg_catalog.pg_type.c.typname
)
)
return self._pg_type_filter_schema(query, schema)
@reflection.cache
def _load_domains(self, connection, **kw):
def _load_domains(self, connection, schema=None, **kw):
# Load data types for domains:
result = connection.execute(self._domain_query)
result = connection.execute(self._domain_query(schema))
domains = {}
domains = []
for domain in result.mappings():
domain = domain
# strip (30) from character varying(30)
attype = re.search(r"([^\(]+)", domain["attype"]).group(1)
# 'visible' just means whether or not the domain is in a
# schema that's on the search path -- or not overridden by
# a schema with higher precedence. If it's not visible,
# it will be prefixed with the schema-name when it's used.
if domain["visible"]:
key = (domain["name"],)
else:
key = (domain["schema"], domain["name"])
constraints = []
if domain["connames"]:
# When a domain has multiple CHECK constraints, they will
# be tested in alphabetical order by name.
sorted_constraints = sorted(
zip(domain["connames"], domain["condefs"]),
key=lambda t: t[0],
)
for name, def_ in sorted_constraints:
# constraint is in the form "CHECK (expression)".
# remove "CHECK (" and the tailing ")".
check = def_[7:-1]
constraints.append({"name": name, "check": check})
domains[key] = {
"attype": attype,
domain_rec = {
"name": domain["name"],
"schema": domain["schema"],
"visible": domain["visible"],
"type": attype,
"nullable": domain["nullable"],
"default": domain["default"],
"constraints": constraints,
}
domains.append(domain_rec)
return domains
@@ -0,0 +1,476 @@
# postgresql/named_types.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
from typing import Any
from typing import Optional
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
from ... import schema
from ... import util
from ...sql import coercions
from ...sql import elements
from ...sql import roles
from ...sql import sqltypes
from ...sql import type_api
from ...sql.ddl import InvokeDDLBase
if TYPE_CHECKING:
from ...sql._typing import _TypeEngineArgument
class NamedType(sqltypes.TypeEngine):
    """Base for named types.

    Named types such as :class:`_postgresql.ENUM` and
    :class:`_postgresql.DOMAIN` exist server-side as schema objects and
    therefore require explicit CREATE / DROP DDL, emitted through the
    ``DDLGenerator`` / ``DDLDropper`` visitors declared by subclasses.
    """

    __abstract__ = True

    # DDL visitor classes supplied by each concrete subclass
    DDLGenerator: Type["NamedTypeGenerator"]
    DDLDropper: Type["NamedTypeDropper"]

    # when False, implicit CREATE/DROP alongside table DDL is suppressed
    create_type: bool

    def create(self, bind, checkfirst=True, **kw):
        """Emit ``CREATE`` DDL for this type.

        :param bind: a connectable :class:`_engine.Engine`,
         :class:`_engine.Connection`, or similar object to emit
         SQL.
        :param checkfirst: if ``True``, a query against
         the PG catalog will be first performed to see
         if the type does not exist already before
         creating.

        """
        bind._run_ddl_visitor(self.DDLGenerator, self, checkfirst=checkfirst)

    def drop(self, bind, checkfirst=True, **kw):
        """Emit ``DROP`` DDL for this type.

        :param bind: a connectable :class:`_engine.Engine`,
         :class:`_engine.Connection`, or similar object to emit
         SQL.
        :param checkfirst: if ``True``, a query against
         the PG catalog will be first performed to see
         if the type actually exists before dropping.

        """
        bind._run_ddl_visitor(self.DDLDropper, self, checkfirst=checkfirst)

    def _check_for_name_in_memos(self, checkfirst, kw):
        """Look in the 'ddl runner' for 'memos', then
        note our name in that collection.

        This to ensure a particular named type is operated
        upon only once within any kind of create/drop
        sequence without relying upon "checkfirst".

        Returns True when this (schema, name) pair was already recorded by
        the current DDL runner, or when ``create_type`` disables implicit
        DDL entirely; False when no runner is present in ``kw``.

        NOTE(review): ``self.schema`` / ``self.name`` are assumed to be
        supplied by the SchemaType-derived subclass -- confirm.
        """
        if not self.create_type:
            return True
        if "_ddl_runner" in kw:
            ddl_runner = kw["_ddl_runner"]
            # memo key is namespaced per visit name so that ENUM and
            # DOMAIN entries do not collide
            type_name = f"pg_{self.__visit_name__}"
            if type_name in ddl_runner.memo:
                existing = ddl_runner.memo[type_name]
            else:
                existing = ddl_runner.memo[type_name] = set()
            present = (self.schema, self.name) in existing
            existing.add((self.schema, self.name))
            return present
        else:
            return False

    def _on_table_create(self, target, bind, checkfirst=False, **kw):
        # event hook: create the type ahead of its parent table, unless a
        # metadata-level operation is responsible for it instead
        if (
            checkfirst
            or (
                not self.metadata
                and not kw.get("_is_metadata_operation", False)
            )
        ) and not self._check_for_name_in_memos(checkfirst, kw):
            self.create(bind=bind, checkfirst=checkfirst)

    def _on_table_drop(self, target, bind, checkfirst=False, **kw):
        # event hook: drop the type after its parent table, for
        # table-local (non-metadata) types only
        if (
            not self.metadata
            and not kw.get("_is_metadata_operation", False)
            and not self._check_for_name_in_memos(checkfirst, kw)
        ):
            self.drop(bind=bind, checkfirst=checkfirst)

    def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
        # event hook: metadata-wide create always creates the type once
        if not self._check_for_name_in_memos(checkfirst, kw):
            self.create(bind=bind, checkfirst=checkfirst)

    def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
        # event hook: metadata-wide drop always drops the type once
        if not self._check_for_name_in_memos(checkfirst, kw):
            self.drop(bind=bind, checkfirst=checkfirst)
class NamedTypeGenerator(InvokeDDLBase):
    """DDL visitor which emits CREATE statements for named types."""

    def __init__(self, dialect, connection, checkfirst=False, **kwargs):
        super().__init__(connection, **kwargs)
        self.checkfirst = checkfirst

    def _can_create_type(self, type_):
        """Return True if a CREATE should proceed for ``type_``."""
        if not self.checkfirst:
            # no catalog probe requested; always emit the CREATE
            return True
        schema = self.connection.schema_for_object(type_)
        already_present = self.connection.dialect.has_type(
            self.connection, type_.name, schema=schema
        )
        return not already_present
class NamedTypeDropper(InvokeDDLBase):
    """DDL visitor which emits DROP statements for named types."""

    def __init__(self, dialect, connection, checkfirst=False, **kwargs):
        super().__init__(connection, **kwargs)
        self.checkfirst = checkfirst

    def _can_drop_type(self, type_):
        """Return True if a DROP should proceed for ``type_``."""
        if not self.checkfirst:
            # unconditional drop when no catalog check was requested
            return True
        schema = self.connection.schema_for_object(type_)
        return self.connection.dialect.has_type(
            self.connection, type_.name, schema=schema
        )
class EnumGenerator(NamedTypeGenerator):
    """Emits CREATE TYPE statements for ENUM types."""

    def visit_enum(self, enum):
        if self._can_create_type(enum):
            self.connection.execute(CreateEnumType(enum))
class EnumDropper(NamedTypeDropper):
    """Emits DROP TYPE statements for ENUM types."""

    def visit_enum(self, enum):
        if self._can_drop_type(enum):
            self.connection.execute(DropEnumType(enum))
class ENUM(NamedType, sqltypes.NativeForEmulated, sqltypes.Enum):
    """PostgreSQL ENUM type.

    This is a subclass of :class:`_types.Enum` which includes
    support for PG's ``CREATE TYPE`` and ``DROP TYPE``.

    When the builtin type :class:`_types.Enum` is used and the
    :paramref:`.Enum.native_enum` flag is left at its default of
    True, the PostgreSQL backend will use a :class:`_postgresql.ENUM`
    type as the implementation, so the special create/drop rules
    will be used.

    The create/drop behavior of ENUM is necessarily intricate, due to the
    awkward relationship the ENUM type has in relationship to the
    parent table, in that it may be "owned" by just a single table, or
    may be shared among many tables.

    When using :class:`_types.Enum` or :class:`_postgresql.ENUM`
    in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted
    corresponding to when the :meth:`_schema.Table.create` and
    :meth:`_schema.Table.drop`
    methods are called::

        table = Table('sometable', metadata,
            Column('some_enum', ENUM('a', 'b', 'c', name='myenum'))
        )

        table.create(engine)  # will emit CREATE ENUM and CREATE TABLE
        table.drop(engine)  # will emit DROP TABLE and DROP ENUM

    To use a common enumerated type between multiple tables, the best
    practice is to declare the :class:`_types.Enum` or
    :class:`_postgresql.ENUM` independently, and associate it with the
    :class:`_schema.MetaData` object itself::

        my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata)

        t1 = Table('sometable_one', metadata,
            Column('some_enum', myenum)
        )

        t2 = Table('sometable_two', metadata,
            Column('some_enum', myenum)
        )

    When this pattern is used, care must still be taken at the level
    of individual table creates.  Emitting CREATE TABLE without also
    specifying ``checkfirst=True`` will still cause issues::

        t1.create(engine)  # will fail: no such type 'myenum'

    If we specify ``checkfirst=True``, the individual table-level create
    operation will check for the ``ENUM`` and create if not exists::

        # will check if enum exists, and emit CREATE TYPE if not
        t1.create(engine, checkfirst=True)

    When using a metadata-level ENUM type, the type will always be created
    and dropped if either the metadata-wide create/drop is called::

        metadata.create_all(engine)  # will emit CREATE TYPE
        metadata.drop_all(engine)  # will emit DROP TYPE

    The type can also be created and dropped directly::

        my_enum.create(engine)
        my_enum.drop(engine)

    .. versionchanged:: 1.0.0 The PostgreSQL :class:`_postgresql.ENUM` type
       now behaves more strictly with regards to CREATE/DROP.  A
       metadata-level ENUM type will only be created and dropped at the
       metadata level, not the table level, with the exception of
       ``table.create(checkfirst=True)``.
       The ``table.drop()`` call will now emit a DROP TYPE for a table-level
       enumerated type.

    """

    native_enum = True
    DDLGenerator = EnumGenerator
    DDLDropper = EnumDropper

    def __init__(self, *enums, name: str, create_type: bool = True, **kw):
        """Construct an :class:`_postgresql.ENUM`.

        Arguments are the same as that of
        :class:`_types.Enum`, but also including
        the following parameters.

        :param create_type: Defaults to True.
         Indicates that ``CREATE TYPE`` should be
         emitted, after optionally checking for the
         presence of the type, when the parent
         table is being created; and additionally
         that ``DROP TYPE`` is called when the table
         is dropped.  When ``False``, no check
         will be performed and no ``CREATE TYPE``
         or ``DROP TYPE`` is emitted, unless
         :meth:`~.postgresql.ENUM.create`
         or :meth:`~.postgresql.ENUM.drop`
         are called directly.
         Setting to ``False`` is helpful
         when invoking a creation scheme to a SQL file
         without access to the actual database -
         the :meth:`~.postgresql.ENUM.create` and
         :meth:`~.postgresql.ENUM.drop` methods can
         be used to emit SQL to a target bind.

        """
        # native_enum=False makes no sense on the PG-native ENUM; warn
        # and discard rather than silently accepting it
        native_enum = kw.pop("native_enum", None)
        if native_enum is False:
            util.warn(
                "the native_enum flag does not apply to the "
                "sqlalchemy.dialects.postgresql.ENUM datatype; this type "
                "always refers to ENUM. Use sqlalchemy.types.Enum for "
                "non-native enum."
            )
        self.create_type = create_type
        super().__init__(*enums, name=name, **kw)

    @classmethod
    def __test_init__(cls):
        # minimal-argument construction hook ("name" is required)
        return cls(name="name")

    @classmethod
    def adapt_emulated_to_native(cls, impl, **kw):
        """Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain
        :class:`.Enum`.

        """
        kw.setdefault("validate_strings", impl.validate_strings)
        kw.setdefault("name", impl.name)
        kw.setdefault("schema", impl.schema)
        kw.setdefault("inherit_schema", impl.inherit_schema)
        kw.setdefault("metadata", impl.metadata)
        kw.setdefault("_create_events", False)
        kw.setdefault("values_callable", impl.values_callable)
        kw.setdefault("omit_aliases", impl._omit_aliases)
        return cls(**kw)

    def create(self, bind=None, checkfirst=True):
        """Emit ``CREATE TYPE`` for this
        :class:`_postgresql.ENUM`.

        If the underlying dialect does not support
        PostgreSQL CREATE TYPE, no action is taken.

        :param bind: a connectable :class:`_engine.Engine`,
         :class:`_engine.Connection`, or similar object to emit
         SQL.
        :param checkfirst: if ``True``, a query against
         the PG catalog will be first performed to see
         if the type does not exist already before
         creating.

        """
        if not bind.dialect.supports_native_enum:
            return

        super().create(bind, checkfirst=checkfirst)

    def drop(self, bind=None, checkfirst=True):
        """Emit ``DROP TYPE`` for this
        :class:`_postgresql.ENUM`.

        If the underlying dialect does not support
        PostgreSQL DROP TYPE, no action is taken.

        :param bind: a connectable :class:`_engine.Engine`,
         :class:`_engine.Connection`, or similar object to emit
         SQL.
        :param checkfirst: if ``True``, a query against
         the PG catalog will be first performed to see
         if the type actually exists before dropping.

        """
        if not bind.dialect.supports_native_enum:
            return

        super().drop(bind, checkfirst=checkfirst)

    def get_dbapi_type(self, dbapi):
        """dont return dbapi.STRING for ENUM in PostgreSQL, since that's
        a different type"""

        return None
class DomainGenerator(NamedTypeGenerator):
    """Emits CREATE DOMAIN statements for DOMAIN types."""

    def visit_DOMAIN(self, domain):
        if self._can_create_type(domain):
            self.connection.execute(CreateDomainType(domain))
class DomainDropper(NamedTypeDropper):
    """Emits DROP DOMAIN statements for DOMAIN types."""

    def visit_DOMAIN(self, domain):
        if self._can_drop_type(domain):
            self.connection.execute(DropDomainType(domain))
class DOMAIN(NamedType, sqltypes.SchemaType):
    r"""Represent the DOMAIN PostgreSQL type.

    A domain is essentially a data type with optional constraints
    that restrict the allowed set of values. E.g.::

        PositiveInt = DOMAIN(
            "pos_int", Integer, check="VALUE > 0", not_null=True
        )

        UsPostalCode = DOMAIN(
            "us_postal_code",
            Text,
            check="VALUE ~ '^\d{5}$' OR VALUE ~ '^\d{5}-\d{4}$'"
        )

    See the `PostgreSQL documentation`__ for additional details.

    __ https://www.postgresql.org/docs/current/sql-createdomain.html

    .. versionadded:: 2.0

    """

    DDLGenerator = DomainGenerator
    DDLDropper = DomainDropper

    __visit_name__ = "DOMAIN"

    def __init__(
        self,
        name: str,
        data_type: _TypeEngineArgument[Any],
        *,
        collation: Optional[str] = None,
        default: Optional[Union[str, elements.TextClause]] = None,
        constraint_name: Optional[str] = None,
        not_null: Optional[bool] = None,
        check: Optional[str] = None,
        create_type: bool = True,
        **kw: Any,
    ):
        """
        Construct a DOMAIN.

        :param name: the name of the domain
        :param data_type: The underlying data type of the domain.
         This can include array specifiers.
        :param collation: An optional collation for the domain.
         If no collation is specified, the underlying data type's default
         collation is used. The underlying type must be collatable if
         ``collation`` is specified.
        :param default: The DEFAULT clause specifies a default value for
         columns of the domain data type. The default should be a string
         or a :func:`_expression.text` value.
         If no default value is specified, then the default value is
         the null value.
        :param constraint_name: An optional name for a constraint.
         If not specified, the backend generates a name.
        :param not_null: Values of this domain are prevented from being null.
         By default domains are allowed to be null. If not specified
         no nullability clause will be emitted.
        :param check: CHECK clause specify integrity constraint or test
         which values of the domain must satisfy. A constraint must be
         an expression producing a Boolean result that can use the key
         word VALUE to refer to the value being tested.
         Differently from PostgreSQL, only a single check clause is
         currently allowed in SQLAlchemy.
        :param schema: optional schema name
        :param metadata: optional :class:`_schema.MetaData` object which
         this :class:`_postgresql.DOMAIN` will be directly associated
        :param create_type: Defaults to True.
         Indicates that ``CREATE TYPE`` should be emitted, after optionally
         checking for the presence of the type, when the parent table is
         being created; and additionally that ``DROP TYPE`` is called
         when the table is dropped.

        """
        self.data_type = type_api.to_instance(data_type)
        self.default = default
        self.collation = collation
        self.constraint_name = constraint_name
        self.not_null = not_null
        if check is not None:
            # coerce a plain string into a DDL expression construct so it
            # renders verbatim, keeping the VALUE keyword intact
            check = coercions.expect(roles.DDLExpressionRole, check)
        self.check = check
        self.create_type = create_type
        super().__init__(name=name, **kw)

    @classmethod
    def __test_init__(cls):
        # minimal-argument construction hook
        return cls("name", sqltypes.Integer)
class CreateEnumType(schema._CreateDropBase):
    """Represent a CREATE TYPE statement for an ENUM."""

    __visit_name__ = "create_enum_type"
class DropEnumType(schema._CreateDropBase):
    """Represent a DROP TYPE statement for an ENUM."""

    __visit_name__ = "drop_enum_type"
class CreateDomainType(schema._CreateDropBase):
    """Represent a CREATE DOMAIN statement."""

    # rendered by PGDDLCompiler.visit_create_domain_type
    __visit_name__ = "create_domain_type"
class DropDomainType(schema._CreateDropBase):
    """Represent a DROP DOMAIN statement."""

    # rendered by PGDDLCompiler.visit_drop_domain_type
    __visit_name__ = "drop_domain_type"
-285
View File
@@ -8,10 +8,7 @@
import datetime as dt
from typing import Any
from ... import schema
from ... import util
from ...sql import sqltypes
from ...sql.ddl import InvokeDDLBase
_DECIMAL_TYPES = (1231, 1700)
@@ -201,285 +198,3 @@ class TSVECTOR(sqltypes.TypeEngine[Any]):
"""
__visit_name__ = "TSVECTOR"
class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
"""PostgreSQL ENUM type.
This is a subclass of :class:`_types.Enum` which includes
support for PG's ``CREATE TYPE`` and ``DROP TYPE``.
When the builtin type :class:`_types.Enum` is used and the
:paramref:`.Enum.native_enum` flag is left at its default of
True, the PostgreSQL backend will use a :class:`_postgresql.ENUM`
type as the implementation, so the special create/drop rules
will be used.
The create/drop behavior of ENUM is necessarily intricate, due to the
awkward relationship the ENUM type has in relationship to the
parent table, in that it may be "owned" by just a single table, or
may be shared among many tables.
When using :class:`_types.Enum` or :class:`_postgresql.ENUM`
in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted
corresponding to when the :meth:`_schema.Table.create` and
:meth:`_schema.Table.drop`
methods are called::
table = Table('sometable', metadata,
Column('some_enum', ENUM('a', 'b', 'c', name='myenum'))
)
table.create(engine) # will emit CREATE ENUM and CREATE TABLE
table.drop(engine) # will emit DROP TABLE and DROP ENUM
To use a common enumerated type between multiple tables, the best
practice is to declare the :class:`_types.Enum` or
:class:`_postgresql.ENUM` independently, and associate it with the
:class:`_schema.MetaData` object itself::
my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata)
t1 = Table('sometable_one', metadata,
Column('some_enum', myenum)
)
t2 = Table('sometable_two', metadata,
Column('some_enum', myenum)
)
When this pattern is used, care must still be taken at the level
of individual table creates. Emitting CREATE TABLE without also
specifying ``checkfirst=True`` will still cause issues::
t1.create(engine) # will fail: no such type 'myenum'
If we specify ``checkfirst=True``, the individual table-level create
operation will check for the ``ENUM`` and create if not exists::
# will check if enum exists, and emit CREATE TYPE if not
t1.create(engine, checkfirst=True)
When using a metadata-level ENUM type, the type will always be created
and dropped if either the metadata-wide create/drop is called::
metadata.create_all(engine) # will emit CREATE TYPE
metadata.drop_all(engine) # will emit DROP TYPE
The type can also be created and dropped directly::
my_enum.create(engine)
my_enum.drop(engine)
.. versionchanged:: 1.0.0 The PostgreSQL :class:`_postgresql.ENUM` type
now behaves more strictly with regards to CREATE/DROP. A metadata-level
ENUM type will only be created and dropped at the metadata level,
not the table level, with the exception of
``table.create(checkfirst=True)``.
The ``table.drop()`` call will now emit a DROP TYPE for a table-level
enumerated type.
"""
native_enum = True
def __init__(self, *enums, **kw):
"""Construct an :class:`_postgresql.ENUM`.
Arguments are the same as that of
:class:`_types.Enum`, but also including
the following parameters.
:param create_type: Defaults to True.
Indicates that ``CREATE TYPE`` should be
emitted, after optionally checking for the
presence of the type, when the parent
table is being created; and additionally
that ``DROP TYPE`` is called when the table
is dropped. When ``False``, no check
will be performed and no ``CREATE TYPE``
or ``DROP TYPE`` is emitted, unless
:meth:`~.postgresql.ENUM.create`
or :meth:`~.postgresql.ENUM.drop`
are called directly.
Setting to ``False`` is helpful
when invoking a creation scheme to a SQL file
without access to the actual database -
the :meth:`~.postgresql.ENUM.create` and
:meth:`~.postgresql.ENUM.drop` methods can
be used to emit SQL to a target bind.
"""
native_enum = kw.pop("native_enum", None)
if native_enum is False:
util.warn(
"the native_enum flag does not apply to the "
"sqlalchemy.dialects.postgresql.ENUM datatype; this type "
"always refers to ENUM. Use sqlalchemy.types.Enum for "
"non-native enum."
)
self.create_type = kw.pop("create_type", True)
super(ENUM, self).__init__(*enums, **kw)
@classmethod
def adapt_emulated_to_native(cls, impl, **kw):
"""Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain
:class:`.Enum`.
"""
kw.setdefault("validate_strings", impl.validate_strings)
kw.setdefault("name", impl.name)
kw.setdefault("schema", impl.schema)
kw.setdefault("inherit_schema", impl.inherit_schema)
kw.setdefault("metadata", impl.metadata)
kw.setdefault("_create_events", False)
kw.setdefault("values_callable", impl.values_callable)
kw.setdefault("omit_aliases", impl._omit_aliases)
return cls(**kw)
def create(self, bind=None, checkfirst=True):
"""Emit ``CREATE TYPE`` for this
:class:`_postgresql.ENUM`.
If the underlying dialect does not support
PostgreSQL CREATE TYPE, no action is taken.
:param bind: a connectable :class:`_engine.Engine`,
:class:`_engine.Connection`, or similar object to emit
SQL.
:param checkfirst: if ``True``, a query against
the PG catalog will be first performed to see
if the type does not exist already before
creating.
"""
if not bind.dialect.supports_native_enum:
return
bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst)
def drop(self, bind=None, checkfirst=True):
"""Emit ``DROP TYPE`` for this
:class:`_postgresql.ENUM`.
If the underlying dialect does not support
PostgreSQL DROP TYPE, no action is taken.
:param bind: a connectable :class:`_engine.Engine`,
:class:`_engine.Connection`, or similar object to emit
SQL.
:param checkfirst: if ``True``, a query against
the PG catalog will be first performed to see
if the type actually exists before dropping.
"""
if not bind.dialect.supports_native_enum:
return
bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst)
class EnumGenerator(InvokeDDLBase):
def __init__(self, dialect, connection, checkfirst=False, **kwargs):
super(ENUM.EnumGenerator, self).__init__(connection, **kwargs)
self.checkfirst = checkfirst
def _can_create_enum(self, enum):
if not self.checkfirst:
return True
effective_schema = self.connection.schema_for_object(enum)
return not self.connection.dialect.has_type(
self.connection, enum.name, schema=effective_schema
)
def visit_enum(self, enum):
if not self._can_create_enum(enum):
return
self.connection.execute(CreateEnumType(enum))
class EnumDropper(InvokeDDLBase):
def __init__(self, dialect, connection, checkfirst=False, **kwargs):
super(ENUM.EnumDropper, self).__init__(connection, **kwargs)
self.checkfirst = checkfirst
def _can_drop_enum(self, enum):
if not self.checkfirst:
return True
effective_schema = self.connection.schema_for_object(enum)
return self.connection.dialect.has_type(
self.connection, enum.name, schema=effective_schema
)
def visit_enum(self, enum):
if not self._can_drop_enum(enum):
return
self.connection.execute(DropEnumType(enum))
def get_dbapi_type(self, dbapi):
    """Don't return dbapi.STRING for ENUM in PostgreSQL, since that's
    a different type.
    """
    return None
def _check_for_name_in_memos(self, checkfirst, kw):
    """Look in the 'ddl runner' for 'memos', then
    note our name in that collection.

    This to ensure a particular named enum is operated
    upon only once within any kind of create/drop
    sequence without relying upon "checkfirst".

    Returns True when the caller should skip emitting CREATE/DROP:
    either ``create_type`` is disabled, or this (schema, name) pair
    was already recorded during the current DDL sequence.  Returns
    False when no DDL runner is present (direct create()/drop() call).
    """
    # create_type=False means DDL for this type is managed explicitly
    # by the user; report "already handled" so nothing is emitted.
    if not self.create_type:
        return True
    if "_ddl_runner" in kw:
        ddl_runner = kw["_ddl_runner"]
        # lazily create the per-run memo set shared by all ENUMs
        if "_pg_enums" in ddl_runner.memo:
            pg_enums = ddl_runner.memo["_pg_enums"]
        else:
            pg_enums = ddl_runner.memo["_pg_enums"] = set()
        # record ourselves; report whether we were seen before
        present = (self.schema, self.name) in pg_enums
        pg_enums.add((self.schema, self.name))
        return present
    else:
        # no DDL runner in play: no memoization is possible
        return False
def _on_table_create(self, target, bind, checkfirst=False, **kw):
if (
checkfirst
or (
not self.metadata
and not kw.get("_is_metadata_operation", False)
)
) and not self._check_for_name_in_memos(checkfirst, kw):
self.create(bind=bind, checkfirst=checkfirst)
def _on_table_drop(self, target, bind, checkfirst=False, **kw):
if (
not self.metadata
and not kw.get("_is_metadata_operation", False)
and not self._check_for_name_in_memos(checkfirst, kw)
):
self.drop(bind=bind, checkfirst=checkfirst)
def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
if not self._check_for_name_in_memos(checkfirst, kw):
self.create(bind=bind, checkfirst=checkfirst)
def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
if not self._check_for_name_in_memos(checkfirst, kw):
self.drop(bind=bind, checkfirst=checkfirst)
class CreateEnumType(schema._CreateDropBase):
    """Represent a CREATE TYPE statement for a PostgreSQL ENUM type."""

    # dispatch key for the DDL compiler visitor
    __visit_name__ = "create_enum_type"
class DropEnumType(schema._CreateDropBase):
    """Represent a DROP TYPE statement for a PostgreSQL ENUM type."""

    # dispatch key for the DDL compiler visitor
    __visit_name__ = "drop_enum_type"
+9 -8
View File
@@ -5251,17 +5251,18 @@ class DDLCompiler(Compiled):
def get_column_default_string(self, column):
if isinstance(column.server_default, schema.DefaultClause):
if isinstance(column.server_default.arg, str):
return self.sql_compiler.render_literal_value(
column.server_default.arg, sqltypes.STRINGTYPE
)
else:
return self.sql_compiler.process(
column.server_default.arg, literal_binds=True
)
return self.render_default_string(column.server_default.arg)
else:
return None
def render_default_string(self, default):
if isinstance(default, str):
return self.sql_compiler.render_literal_value(
default, sqltypes.STRINGTYPE
)
else:
return self.sql_compiler.process(default, literal_binds=True)
def visit_table_or_column_check_constraint(self, constraint, **kw):
if constraint.is_column_level:
return self.visit_column_check_constraint(constraint)
+12 -10
View File
@@ -312,20 +312,22 @@ class MemUsageTest(EnsureZeroed):
eng = engines.testing_engine()
for args in (
(types.Integer,),
(types.String,),
(types.PickleType,),
(types.Enum, "a", "b", "c"),
(sqlite.DATETIME,),
(postgresql.ENUM, "a", "b", "c"),
(types.Interval,),
(postgresql.INTERVAL,),
(mysql.VARCHAR,),
(types.Integer, {}),
(types.String, {}),
(types.PickleType, {}),
(types.Enum, "a", "b", "c", {}),
(sqlite.DATETIME, {}),
(postgresql.ENUM, "a", "b", "c", {"name": "pgenum"}),
(types.Interval, {}),
(postgresql.INTERVAL, {}),
(mysql.VARCHAR, {}),
):
@profile_memory()
def go():
type_ = args[0](*args[1:])
kwargs = args[-1]
posargs = args[1:-1]
type_ = args[0](*posargs, **kwargs)
bp = type_._cached_bind_processor(eng.dialect)
rp = type_._cached_result_processor(eng.dialect, 0)
bp, rp # strong reference
+75 -1
View File
@@ -38,6 +38,7 @@ from sqlalchemy.dialects.postgresql import aggregate_order_by
from sqlalchemy.dialects.postgresql import ARRAY as PG_ARRAY
from sqlalchemy.dialects.postgresql import array
from sqlalchemy.dialects.postgresql import array_agg as pg_array_agg
from sqlalchemy.dialects.postgresql import DOMAIN
from sqlalchemy.dialects.postgresql import ExcludeConstraint
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.dialects.postgresql import TSRANGE
@@ -270,7 +271,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
render_schema_translate=True,
)
def test_create_type_schema_translate(self):
def test_create_enum_schema_translate(self):
e1 = Enum("x", "y", "z", name="somename")
e2 = Enum("x", "y", "z", name="somename", schema="someschema")
schema_translate_map = {None: "foo", "someschema": "bar"}
@@ -289,6 +290,79 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
render_schema_translate=True,
)
def test_domain(self):
self.assert_compile(
postgresql.CreateDomainType(
DOMAIN(
"x",
Integer,
default=text("11"),
not_null=True,
check="VALUE < 0",
)
),
"CREATE DOMAIN x AS INTEGER DEFAULT 11 NOT NULL CHECK (VALUE < 0)",
)
self.assert_compile(
postgresql.CreateDomainType(
DOMAIN(
"sOmEnAmE",
Text,
collation="utf8",
constraint_name="a constraint",
not_null=True,
)
),
'CREATE DOMAIN "sOmEnAmE" AS TEXT COLLATE utf8 CONSTRAINT '
'"a constraint" NOT NULL',
)
self.assert_compile(
postgresql.CreateDomainType(
DOMAIN(
"foo",
Text,
collation="utf8",
default="foobar",
constraint_name="no_bar",
not_null=True,
check="VALUE != 'bar'",
)
),
"CREATE DOMAIN foo AS TEXT COLLATE utf8 DEFAULT 'foobar' "
"CONSTRAINT no_bar NOT NULL CHECK (VALUE != 'bar')",
)
def test_cast_domain_schema(self):
"""test #6739"""
d1 = DOMAIN("somename", Integer)
d2 = DOMAIN("somename", Integer, schema="someschema")
stmt = select(cast(column("foo"), d1), cast(column("bar"), d2))
self.assert_compile(
stmt,
"SELECT CAST(foo AS somename) AS foo, "
"CAST(bar AS someschema.somename) AS bar",
)
def test_create_domain_schema_translate(self):
d1 = DOMAIN("somename", Integer)
d2 = DOMAIN("somename", Integer, schema="someschema")
schema_translate_map = {None: "foo", "someschema": "bar"}
self.assert_compile(
postgresql.CreateDomainType(d1),
"CREATE DOMAIN foo.somename AS INTEGER ",
schema_translate_map=schema_translate_map,
render_schema_translate=True,
)
self.assert_compile(
postgresql.CreateDomainType(d2),
"CREATE DOMAIN bar.somename AS INTEGER ",
schema_translate_map=schema_translate_map,
render_schema_translate=True,
)
def test_create_table_with_schema_type_schema_translate(self):
e1 = Enum("x", "y", "z", name="somename")
e2 = Enum("x", "y", "z", name="somename", schema="someschema")
+123 -3
View File
@@ -410,6 +410,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
"CREATE DOMAIN nullable_domain AS TEXT CHECK "
"(VALUE IN('FOO', 'BAR'))",
"CREATE DOMAIN not_nullable_domain AS TEXT NOT NULL",
"CREATE DOMAIN my_int AS int CONSTRAINT b_my_int_one CHECK "
"(VALUE > 1) CONSTRAINT a_my_int_two CHECK (VALUE < 42) "
"CHECK(VALUE != 22)",
]:
try:
con.exec_driver_sql(ddl)
@@ -468,6 +471,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
con.exec_driver_sql("DROP TABLE nullable_domain_test")
con.exec_driver_sql("DROP DOMAIN nullable_domain")
con.exec_driver_sql("DROP DOMAIN not_nullable_domain")
con.exec_driver_sql("DROP DOMAIN my_int")
def test_table_is_reflected(self, connection):
metadata = MetaData()
@@ -579,6 +583,122 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
finally:
base.PGDialect.ischema_names = ischema_names
@property
def all_domains(self):
return {
"public": [
{
"visible": True,
"name": "arraydomain",
"schema": "public",
"nullable": True,
"type": "integer[]",
"default": None,
"constraints": [],
},
{
"visible": True,
"name": "enumdomain",
"schema": "public",
"nullable": True,
"type": "testtype",
"default": None,
"constraints": [],
},
{
"visible": True,
"name": "my_int",
"schema": "public",
"nullable": True,
"type": "integer",
"default": None,
"constraints": [
{"check": "VALUE < 42", "name": "a_my_int_two"},
{"check": "VALUE > 1", "name": "b_my_int_one"},
# autogenerated name by pg
{"check": "VALUE <> 22", "name": "my_int_check"},
],
},
{
"visible": True,
"name": "not_nullable_domain",
"schema": "public",
"nullable": False,
"type": "text",
"default": None,
"constraints": [],
},
{
"visible": True,
"name": "nullable_domain",
"schema": "public",
"nullable": True,
"type": "text",
"default": None,
"constraints": [
{
"check": "VALUE = ANY (ARRAY['FOO'::text, "
"'BAR'::text])",
# autogenerated name by pg
"name": "nullable_domain_check",
}
],
},
{
"visible": True,
"name": "testdomain",
"schema": "public",
"nullable": False,
"type": "integer",
"default": "42",
"constraints": [],
},
],
"test_schema": [
{
"visible": False,
"name": "testdomain",
"schema": "test_schema",
"nullable": True,
"type": "integer",
"default": "0",
"constraints": [],
}
],
"SomeSchema": [
{
"visible": False,
"name": "Quoted.Domain",
"schema": "SomeSchema",
"nullable": True,
"type": "integer",
"default": "0",
"constraints": [],
}
],
}
def test_inspect_domains(self, connection):
inspector = inspect(connection)
eq_(inspector.get_domains(), self.all_domains["public"])
def test_inspect_domains_schema(self, connection):
inspector = inspect(connection)
eq_(
inspector.get_domains("test_schema"),
self.all_domains["test_schema"],
)
eq_(
inspector.get_domains("SomeSchema"), self.all_domains["SomeSchema"]
)
def test_inspect_domains_star(self, connection):
inspector = inspect(connection)
all_ = [d for dl in self.all_domains.values() for d in dl]
all_ += inspector.get_domains("information_schema")
exp = sorted(all_, key=lambda d: (d["schema"], d["name"]))
eq_(inspector.get_domains("*"), exp)
class ReflectionTest(
ReflectionFixtures, AssertsCompiledSQL, fixtures.TestBase
@@ -1800,10 +1920,10 @@ class ReflectionTest(
eq_(
check_constraints,
{
"cc1": "(a > 1) AND (a < 5)",
"cc2": "(a = 1) OR ((a > 2) AND (a < 5))",
"cc1": "a > 1 AND a < 5",
"cc2": "a = 1 OR a > 2 AND a < 5",
"cc3": "is_positive(a)",
"cc4": "(b)::text <> 'hi\nim a name \nyup\n'::text",
"cc4": "b::text <> 'hi\nim a name \nyup\n'::text",
},
)
+275 -96
View File
@@ -38,12 +38,15 @@ from sqlalchemy import util
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import array
from sqlalchemy.dialects.postgresql import DATERANGE
from sqlalchemy.dialects.postgresql import DOMAIN
from sqlalchemy.dialects.postgresql import ENUM
from sqlalchemy.dialects.postgresql import HSTORE
from sqlalchemy.dialects.postgresql import hstore
from sqlalchemy.dialects.postgresql import INT4RANGE
from sqlalchemy.dialects.postgresql import INT8RANGE
from sqlalchemy.dialects.postgresql import JSON
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.dialects.postgresql import NamedType
from sqlalchemy.dialects.postgresql import NUMRANGE
from sqlalchemy.dialects.postgresql import TSRANGE
from sqlalchemy.dialects.postgresql import TSTZRANGE
@@ -161,7 +164,7 @@ class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults):
eq_(row, ([5], [5], [6], [7], [decimal.Decimal("6.4")]))
class EnumTest(fixtures.TestBase, AssertsExecutionResults):
class NamedTypeTest(fixtures.TestBase, AssertsExecutionResults):
__backend__ = True
__only_on__ = "postgresql > 8.3"
@@ -173,16 +176,18 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
"the native_enum flag does not apply to the "
"sqlalchemy.dialects.postgresql.ENUM datatype;"
):
e1 = postgresql.ENUM("a", "b", "c", native_enum=False)
e1 = postgresql.ENUM(
"a", "b", "c", name="pgenum", native_enum=False
)
e2 = postgresql.ENUM("a", "b", "c", native_enum=True)
e3 = postgresql.ENUM("a", "b", "c")
e2 = postgresql.ENUM("a", "b", "c", name="pgenum", native_enum=True)
e3 = postgresql.ENUM("a", "b", "c", name="pgenum")
is_(e1.native_enum, True)
is_(e2.native_enum, True)
is_(e3.native_enum, True)
def test_create_table(self, metadata, connection):
def test_enum_create_table(self, metadata, connection):
metadata = self.metadata
t1 = Table(
"table",
@@ -202,50 +207,147 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
[(1, "two"), (2, "three"), (3, "three")],
)
def test_domain_create_table(self, metadata, connection):
metadata = self.metadata
Email = DOMAIN(
name="email",
data_type=Text,
check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
)
PosInt = DOMAIN(
name="pos_int",
data_type=Integer,
not_null=True,
check=r"VALUE > 0",
)
t1 = Table(
"table",
metadata,
Column("id", Integer, primary_key=True),
Column("email", Email),
Column("number", PosInt),
)
t1.create(connection)
t1.create(connection, checkfirst=True) # check the create
connection.execute(
t1.insert(), {"email": "test@example.com", "number": 42}
)
connection.execute(t1.insert(), {"email": "a@b.c", "number": 1})
connection.execute(
t1.insert(), {"email": "example@gmail.co.uk", "number": 99}
)
eq_(
connection.execute(t1.select().order_by(t1.c.id)).fetchall(),
[
(1, "test@example.com", 42),
(2, "a@b.c", 1),
(3, "example@gmail.co.uk", 99),
],
)
@testing.combinations(
(ENUM("one", "two", "three", name="mytype"), "get_enums"),
(
DOMAIN(
name="mytype",
data_type=Text,
check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
),
"get_domains",
),
argnames="datatype, method",
)
def test_drops_on_table(
self, connection, metadata, datatype: "NamedType", method
):
table = Table("e1", metadata, Column("e1", datatype))
table.create(connection)
table.drop(connection)
assert "mytype" not in [
e["name"] for e in getattr(inspect(connection), method)()
]
table.create(connection)
assert "mytype" in [
e["name"] for e in getattr(inspect(connection), method)()
]
table.drop(connection)
assert "mytype" not in [
e["name"] for e in getattr(inspect(connection), method)()
]
@testing.combinations(
(
lambda symbol_name: ENUM(
"one", "two", "three", name="schema_mytype", schema=symbol_name
),
["two", "three", "three"],
"get_enums",
),
(
lambda symbol_name: DOMAIN(
name="schema_mytype",
data_type=Text,
check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
schema=symbol_name,
),
["test@example.com", "a@b.c", "example@gmail.co.uk"],
"get_domains",
),
argnames="datatype,data,method",
)
@testing.combinations(None, "foo", argnames="symbol_name")
def test_create_table_schema_translate_map(self, connection, symbol_name):
def test_create_table_schema_translate_map(
self, connection, symbol_name, datatype, data, method
):
# note we can't use the fixture here because it will not drop
# from the correct schema
metadata = MetaData()
dt = datatype(symbol_name)
t1 = Table(
"table",
metadata,
Column("id", Integer, primary_key=True),
Column(
"value",
Enum(
"one",
"two",
"three",
name="schema_enum",
schema=symbol_name,
),
),
Column("value", dt),
schema=symbol_name,
)
conn = connection.execution_options(
schema_translate_map={symbol_name: testing.config.test_schema}
)
t1.create(conn)
assert "schema_enum" in [
assert "schema_mytype" in [
e["name"]
for e in inspect(conn).get_enums(schema=testing.config.test_schema)
for e in getattr(inspect(conn), method)(
schema=testing.config.test_schema
)
]
t1.create(conn, checkfirst=True)
conn.execute(t1.insert(), dict(value="two"))
conn.execute(t1.insert(), dict(value="three"))
conn.execute(t1.insert(), dict(value="three"))
conn.execute(
t1.insert(),
dict(value=data[0]),
)
conn.execute(t1.insert(), dict(value=data[1]))
conn.execute(t1.insert(), dict(value=data[2]))
eq_(
conn.execute(t1.select().order_by(t1.c.id)).fetchall(),
[(1, "two"), (2, "three"), (3, "three")],
[
(1, data[0]),
(2, data[1]),
(3, data[2]),
],
)
t1.drop(conn)
assert "schema_enum" not in [
assert "schema_mytype" not in [
e["name"]
for e in inspect(conn).get_enums(schema=testing.config.test_schema)
for e in getattr(inspect(conn), method)(
schema=testing.config.test_schema
)
]
t1.drop(conn, checkfirst=True)
@@ -256,40 +358,48 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
("override_metadata_schema",),
argnames="test_case",
)
@testing.combinations("enum", "domain", argnames="datatype")
@testing.requires.schemas
def test_schema_inheritance(self, test_case, metadata, connection):
def test_schema_inheritance(
self, test_case, metadata, connection, datatype
):
"""test #6373"""
metadata.schema = testing.config.test_schema
def make_type(**kw):
if datatype == "enum":
return Enum("four", "five", "six", name="mytype", **kw)
elif datatype == "domain":
return DOMAIN(
name="mytype",
data_type=Text,
check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
**kw,
)
else:
assert False
if test_case == "metadata_schema_only":
enum = Enum(
"four", "five", "six", metadata=metadata, name="myenum"
)
enum = make_type(metadata=metadata)
assert_schema = testing.config.test_schema
elif test_case == "override_metadata_schema":
enum = Enum(
"four",
"five",
"six",
enum = make_type(
metadata=metadata,
schema=testing.config.test_schema_2,
name="myenum",
)
assert_schema = testing.config.test_schema_2
elif test_case == "inherit_table_schema":
enum = Enum(
"four",
"five",
"six",
enum = make_type(
metadata=metadata,
inherit_schema=True,
name="myenum",
)
assert_schema = testing.config.test_schema_2
elif test_case == "local_schema":
enum = Enum("four", "five", "six", name="myenum")
enum = make_type()
assert_schema = testing.config.db.dialect.default_schema_name
else:
assert False
Table(
"t",
@@ -300,27 +410,62 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
metadata.create_all(connection)
eq_(
inspect(connection).get_enums(schema=assert_schema),
[
{
"labels": ["four", "five", "six"],
"name": "myenum",
"schema": assert_schema,
"visible": assert_schema
== testing.config.db.dialect.default_schema_name,
}
],
)
if datatype == "enum":
eq_(
inspect(connection).get_enums(schema=assert_schema),
[
{
"labels": ["four", "five", "six"],
"name": "mytype",
"schema": assert_schema,
"visible": assert_schema
== testing.config.db.dialect.default_schema_name,
}
],
)
elif datatype == "domain":
def test_name_required(self, metadata, connection):
etype = Enum("four", "five", "six", metadata=metadata)
assert_raises(exc.CompileError, etype.create, connection)
def_schame = testing.config.db.dialect.default_schema_name
eq_(
inspect(connection).get_domains(schema=assert_schema),
[
{
"name": "mytype",
"type": "text",
"nullable": True,
"default": None,
"schema": assert_schema,
"visible": assert_schema == def_schame,
"constraints": [
{
"name": "mytype_check",
"check": r"VALUE ~ '[^@]+@[^@]+\.[^@]+'::text",
}
],
}
],
)
else:
assert False
@testing.combinations(
(ENUM("one", "two", "three", name=None)),
(
DOMAIN(
name=None,
data_type=Text,
check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
),
),
argnames="datatype",
)
def test_name_required(self, metadata, connection, datatype):
assert_raises(exc.CompileError, datatype.create, connection)
assert_raises(
exc.CompileError, etype.compile, dialect=connection.dialect
exc.CompileError, datatype.compile, dialect=connection.dialect
)
def test_unicode_labels(self, connection, metadata):
def test_enum_unicode_labels(self, connection, metadata):
t1 = Table(
"table",
metadata,
@@ -426,22 +571,30 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
connection.execute(t1.insert(), {"bar": "Ü"})
eq_(connection.scalar(select(t1.c.bar)), "Ü")
def test_disable_create(self, metadata, connection):
@testing.combinations(
(ENUM("one", "two", "three", name="mytype", create_type=False),),
(
DOMAIN(
name="mytype",
data_type=Text,
check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
create_type=False,
),
),
argnames="datatype",
)
def test_disable_create(self, metadata, connection, datatype):
metadata = self.metadata
e1 = postgresql.ENUM(
"one", "two", "three", name="myenum", create_type=False
)
t1 = Table("e1", metadata, Column("c1", e1))
t1 = Table("e1", metadata, Column("c1", datatype))
# table can be created separately
# without conflict
e1.create(bind=connection)
datatype.create(bind=connection)
t1.create(connection)
t1.drop(connection)
e1.drop(bind=connection)
datatype.drop(bind=connection)
def test_dont_keep_checking(self, metadata, connection):
def test_enum_dont_keep_checking(self, metadata, connection):
metadata = self.metadata
e1 = postgresql.ENUM("one", "two", "three", name="myenum")
@@ -486,7 +639,36 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
RegexSQL("DROP TYPE myenum", dialect="postgresql"),
)
def test_generate_multiple(self, metadata, connection):
@testing.combinations(
(
Enum(
"one",
"two",
"three",
name="mytype",
),
"get_enums",
),
(
ENUM(
"one",
"two",
"three",
name="mytype",
),
"get_enums",
),
(
DOMAIN(
name="mytype",
data_type=Text,
check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
),
"get_domains",
),
argnames="datatype, method",
)
def test_generate_multiple(self, metadata, connection, datatype, method):
"""Test that the same enum twice only generates once
for the create_all() call, without using checkfirst.
@@ -494,15 +676,20 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
now handles this.
"""
e1 = Enum("one", "two", "three", name="myenum")
Table("e1", metadata, Column("c1", e1))
Table("e1", metadata, Column("c1", datatype))
Table("e2", metadata, Column("c1", e1))
Table("e2", metadata, Column("c1", datatype))
metadata.create_all(connection, checkfirst=False)
assert "mytype" in [
e["name"] for e in getattr(inspect(connection), method)()
]
metadata.drop_all(connection, checkfirst=False)
assert "myenum" not in [
e["name"] for e in inspect(connection).get_enums()
assert "mytype" not in [
e["name"] for e in getattr(inspect(connection), method)()
]
def test_generate_alone_on_metadata(self, connection, metadata):
@@ -571,23 +758,6 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
for e in inspect(connection).get_enums(schema="test_schema")
]
def test_drops_on_table(self, connection, metadata):
e1 = Enum("one", "two", "three", name="myenum")
table = Table("e1", metadata, Column("c1", e1))
table.create(connection)
table.drop(connection)
assert "myenum" not in [
e["name"] for e in inspect(connection).get_enums()
]
table.create(connection)
assert "myenum" in [e["name"] for e in inspect(connection).get_enums()]
table.drop(connection)
assert "myenum" not in [
e["name"] for e in inspect(connection).get_enums()
]
def test_create_drop_schema_translate_map(self, connection):
conn = connection.execution_options(
@@ -1445,15 +1615,16 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase):
array_agg,
)
element_type = ENUM if with_enum else Integer
element = ENUM(name="pgenum") if with_enum else Integer()
element_type = type(element)
expr = (
array_agg(
aggregate_order_by(
column("q", element_type), column("idx", Integer)
column("q", element), column("idx", Integer)
)
)
if using_aggregate_order_by
else array_agg(column("q", element_type))
else array_agg(column("q", element))
)
is_(expr.type.__class__, postgresql.ARRAY)
is_(expr.type.item_type.__class__, element_type)
@@ -2081,10 +2252,13 @@ class ArrayRoundTripTest:
],
testing.requires.hstore,
),
(postgresql.ENUM(AnEnum), enum_values),
(postgresql.ENUM(AnEnum, name="pgenum"), enum_values),
(sqltypes.Enum(AnEnum, native_enum=True), enum_values),
(sqltypes.Enum(AnEnum, native_enum=False), enum_values),
(postgresql.ENUM(AnEnum, native_enum=True), enum_values),
(
postgresql.ENUM(AnEnum, name="pgenum", native_enum=True),
enum_values,
),
(
make_difficult_enum(sqltypes.Enum, native=True),
difficult_enum_values,
@@ -2102,10 +2276,15 @@ class ArrayRoundTripTest:
if not exclude_empty_lists:
elements.extend(
[
(postgresql.ENUM(AnEnum), empty_list),
(postgresql.ENUM(AnEnum, name="pgenum"), empty_list),
(sqltypes.Enum(AnEnum, native_enum=True), empty_list),
(sqltypes.Enum(AnEnum, native_enum=False), empty_list),
(postgresql.ENUM(AnEnum, native_enum=True), empty_list),
(
postgresql.ENUM(
AnEnum, name="pgenum", native_enum=True
),
empty_list,
),
]
)
if not exclude_json:
@@ -2410,7 +2589,7 @@ class ArrayEnum(fixtures.TestBase):
),
Column(
"pyenum_col",
array_cls(enum_cls(MyEnum)),
array_cls(enum_cls(MyEnum, name="pgenum")),
),
)
+22 -23
View File
@@ -111,7 +111,11 @@ def _all_dialects():
def _types_for_mod(mod):
for key in dir(mod):
typ = getattr(mod, key)
if not isinstance(typ, type) or not issubclass(typ, types.TypeEngine):
if (
not isinstance(typ, type)
or not issubclass(typ, types.TypeEngine)
or typ.__dict__.get("__abstract__")
):
continue
yield typ
@@ -143,6 +147,17 @@ def _all_types(omit_special_types=False):
yield typ
def _get_instance(type_):
if issubclass(type_, ARRAY):
return type_(String)
elif hasattr(type_, "__test_init__"):
t1 = type_.__test_init__()
is_(isinstance(t1, type_), True)
return t1
else:
return type_()
class AdaptTest(fixtures.TestBase):
@testing.combinations(((t,) for t in _types_for_mod(types)), id_="n")
def test_uppercase_importable(self, typ):
@@ -240,11 +255,8 @@ class AdaptTest(fixtures.TestBase):
adapt() beyond their defaults.
"""
t1 = _get_instance(typ)
if issubclass(typ, ARRAY):
t1 = typ(String)
else:
t1 = typ()
for cls in target_adaptions:
if (is_down_adaption and issubclass(typ, sqltypes.Emulated)) or (
not is_down_adaption and issubclass(cls, sqltypes.Emulated)
@@ -301,19 +313,13 @@ class AdaptTest(fixtures.TestBase):
@testing.uses_deprecated()
@testing.combinations(*[(t,) for t in _all_types(omit_special_types=True)])
def test_repr(self, typ):
if issubclass(typ, ARRAY):
t1 = typ(String)
else:
t1 = typ()
t1 = _get_instance(typ)
repr(t1)
@testing.uses_deprecated()
@testing.combinations(*[(t,) for t in _all_types(omit_special_types=True)])
def test_str(self, typ):
if issubclass(typ, ARRAY):
t1 = typ(String)
else:
t1 = typ()
t1 = _get_instance(typ)
str(t1)
def test_str_third_party(self):
@@ -400,7 +406,7 @@ class AsGenericTest(fixtures.TestBase):
(pg.JSON(), sa.JSON()),
(pg.ARRAY(sa.String), sa.ARRAY(sa.String)),
(Enum("a", "b", "c"), Enum("a", "b", "c")),
(pg.ENUM("a", "b", "c"), Enum("a", "b", "c")),
(pg.ENUM("a", "b", "c", name="pgenum"), Enum("a", "b", "c")),
(mysql.ENUM("a", "b", "c"), Enum("a", "b", "c")),
(pg.INTERVAL(precision=5), Interval(native=True, second_precision=5)),
(
@@ -419,11 +425,7 @@ class AsGenericTest(fixtures.TestBase):
]
)
def test_as_generic_all_types_heuristic(self, type_):
if issubclass(type_, ARRAY):
t1 = type_(String)
else:
t1 = type_()
t1 = _get_instance(type_)
try:
gentype = t1.as_generic()
except NotImplementedError:
@@ -445,10 +447,7 @@ class AsGenericTest(fixtures.TestBase):
]
)
def test_as_generic_all_types_custom(self, type_):
if issubclass(type_, ARRAY):
t1 = type_(String)
else:
t1 = type_()
t1 = _get_instance(type_)
gentype = t1.as_generic(allow_nulltype=False)
assert isinstance(gentype, TypeEngine)