diff --git a/doc/build/changelog/unreleased_13/4653.rst b/doc/build/changelog/unreleased_13/4653.rst new file mode 100644 index 0000000000..67e198ce63 --- /dev/null +++ b/doc/build/changelog/unreleased_13/4653.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, sql + :tickets: 4653 + + Fixed that the :class:`.GenericFunction` class was inadvertently + registering itself as one of the named functions. Pull request courtesy + Adrien Berchet. \ No newline at end of file diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index cb503892c8..3340451440 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -635,7 +635,17 @@ class _GenericMeta(VisitableType): # legacy if "__return_type__" in clsdict: cls.type = clsdict["__return_type__"] - register_function(identifier, cls, package) + + # Check _register attribute status + cls._register = getattr(cls, '_register', True) + + # Register the function if required + if cls._register: + register_function(identifier, cls, package) + else: + # Set _register to True to register child classes by default + cls._register = True + super(_GenericMeta, cls).__init__(clsname, bases, clsdict) @@ -703,6 +713,7 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)): """ coerce_arguments = True + _register = False def __init__(self, *args, **kwargs): parsed_args = kwargs.pop("_parsed_args", None) diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index b03b156bc1..7f7ba14e23 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -1057,3 +1057,43 @@ def exec_sorted(statement, *args, **kw): return sorted( [tuple(row) for row in statement.execute(*args, **kw).fetchall()] ) + + +class RegisterTest(fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = "default" + + def setup(self): + self._registry = deepcopy(functions._registry) + + def teardown(self): + functions._registry = self._registry + + def test_GenericFunction_is_registered(self): + assert 'GenericFunction' not in functions._registry['_default'] + + def test_register_function(self): + + # test generic function registering + class registered_func(GenericFunction): + _register = True + + def __init__(self, *args, **kwargs): + GenericFunction.__init__(self, *args, **kwargs) + + class registered_func_child(registered_func): + type = sqltypes.Integer + + assert 'registered_func' in functions._registry['_default'] + assert isinstance(func.registered_func_child().type, Integer) + + class not_registered_func(GenericFunction): + _register = False + + def __init__(self, *args, **kwargs): + GenericFunction.__init__(self, *args, **kwargs) + + class not_registered_func_child(not_registered_func): + type = sqltypes.Integer + + assert 'not_registered_func' not in functions._registry['_default'] + assert isinstance(func.not_registered_func_child().type, Integer)