mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-18 06:32:08 -04:00
Corrected behavior of get_cls_kwargs and friends
This commit is contained in:
+28
-10
@@ -4,8 +4,9 @@
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
import itertools, sys, warnings, sets, weakref
|
||||
import inspect, itertools, sets, sys, warnings, weakref
|
||||
import __builtin__
|
||||
types = __import__('types')
|
||||
|
||||
from sqlalchemy import exceptions
|
||||
|
||||
@@ -181,20 +182,37 @@ class ArgSingleton(type):
|
||||
return instance
|
||||
|
||||
def get_cls_kwargs(cls):
|
||||
"""Return the full set of legal kwargs for the given `cls`."""
|
||||
"""Return the full set of inherited kwargs for the given `cls`.
|
||||
|
||||
Probes a class's __init__ method, collecting all named arguments. If the
|
||||
__init__ defines a **kwargs catch-all, then the constructor is presumed to
|
||||
pass along unrecognized keywords to it's base classes, and the collection
|
||||
process is repeated recursively on each of the bases.
|
||||
"""
|
||||
|
||||
kw = []
|
||||
for c in cls.__mro__:
|
||||
cons = c.__init__
|
||||
if hasattr(cons, 'func_code'):
|
||||
for vn in cons.func_code.co_varnames:
|
||||
if vn != 'self':
|
||||
kw.append(vn)
|
||||
return kw
|
||||
if '__init__' in c.__dict__:
|
||||
stack = [c]
|
||||
break
|
||||
else:
|
||||
return []
|
||||
|
||||
args = Set()
|
||||
while stack:
|
||||
class_ = stack.pop()
|
||||
ctr = class_.__dict__.get('__init__', False)
|
||||
if not ctr or not isinstance(ctr, types.FunctionType):
|
||||
continue
|
||||
names, _, has_kw, _ = inspect.getargspec(ctr)
|
||||
args |= Set(names)
|
||||
if has_kw:
|
||||
stack.extend(class_.__bases__)
|
||||
args.discard('self')
|
||||
return list(args)
|
||||
|
||||
def get_func_kwargs(func):
|
||||
"""Return the full set of legal kwargs for the given `func`."""
|
||||
return [vn for vn in func.func_code.co_varnames]
|
||||
return inspect.getargspec(func)[0]
|
||||
|
||||
# from paste.deploy.converters
|
||||
def asbool(obj):
|
||||
|
||||
@@ -305,5 +305,72 @@ class DictlikeIteritemsTest(unittest.TestCase):
|
||||
self._notok(duck6())
|
||||
|
||||
|
||||
class ArgInspectionTest(PersistTest):
|
||||
def test_get_cls_kwargs(self):
|
||||
class A(object):
|
||||
def __init__(self, a):
|
||||
pass
|
||||
class A1(A):
|
||||
def __init__(self, a1):
|
||||
pass
|
||||
class A11(A1):
|
||||
def __init__(self, a11, **kw):
|
||||
pass
|
||||
class B(object):
|
||||
def __init__(self, b, **kw):
|
||||
pass
|
||||
class B1(B):
|
||||
def __init__(self, b1, **kw):
|
||||
pass
|
||||
class AB(A, B):
|
||||
def __init__(self, ab):
|
||||
pass
|
||||
class BA(B, A):
|
||||
def __init__(self, ba, **kwargs):
|
||||
pass
|
||||
class BA1(BA):
|
||||
pass
|
||||
class CAB(A, B):
|
||||
pass
|
||||
class CBA(B, A):
|
||||
pass
|
||||
class CAB1(A, B1):
|
||||
pass
|
||||
class CB1A(B1, A):
|
||||
pass
|
||||
class D(object):
|
||||
pass
|
||||
|
||||
def test(cls, *expected):
|
||||
self.assertEquals(set(util.get_cls_kwargs(cls)), set(expected))
|
||||
|
||||
test(A, 'a')
|
||||
test(A1, 'a1')
|
||||
test(A11, 'a11', 'a1')
|
||||
test(B, 'b')
|
||||
test(B1, 'b1', 'b')
|
||||
test(AB, 'ab')
|
||||
test(BA, 'ba', 'b', 'a')
|
||||
test(BA1, 'ba', 'b', 'a')
|
||||
test(CAB, 'a')
|
||||
test(CBA, 'b')
|
||||
test(CAB1, 'a')
|
||||
test(CB1A, 'b1', 'b')
|
||||
test(D)
|
||||
|
||||
def test_get_func_kwargs(self):
|
||||
def f1(): pass
|
||||
def f2(foo): pass
|
||||
def f3(*foo): pass
|
||||
def f4(**foo): pass
|
||||
|
||||
def test(fn, *expected):
|
||||
self.assertEquals(set(util.get_func_kwargs(fn)), set(expected))
|
||||
|
||||
test(f1)
|
||||
test(f2, 'foo')
|
||||
test(f3)
|
||||
test(f4)
|
||||
|
||||
if __name__ == "__main__":
|
||||
testenv.main()
|
||||
|
||||
Reference in New Issue
Block a user