Corrected behavior of get_cls_kwargs and friends

This commit is contained in:
Jason Kirtland
2008-01-24 00:08:40 +00:00
parent 29f7a38ee0
commit f6439ffa2c
2 changed files with 95 additions and 10 deletions
+28 -10
View File
@@ -4,8 +4,9 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import itertools, sys, warnings, sets, weakref
import inspect, itertools, sets, sys, warnings, weakref
import __builtin__
types = __import__('types')
from sqlalchemy import exceptions
@@ -181,20 +182,37 @@ class ArgSingleton(type):
return instance
def get_cls_kwargs(cls):
"""Return the full set of legal kwargs for the given `cls`."""
"""Return the full set of inherited kwargs for the given `cls`.
Probes a class's __init__ method, collecting all named arguments. If the
__init__ defines a **kwargs catch-all, then the constructor is presumed to
pass along unrecognized keywords to it's base classes, and the collection
process is repeated recursively on each of the bases.
"""
kw = []
for c in cls.__mro__:
cons = c.__init__
if hasattr(cons, 'func_code'):
for vn in cons.func_code.co_varnames:
if vn != 'self':
kw.append(vn)
return kw
if '__init__' in c.__dict__:
stack = [c]
break
else:
return []
args = Set()
while stack:
class_ = stack.pop()
ctr = class_.__dict__.get('__init__', False)
if not ctr or not isinstance(ctr, types.FunctionType):
continue
names, _, has_kw, _ = inspect.getargspec(ctr)
args |= Set(names)
if has_kw:
stack.extend(class_.__bases__)
args.discard('self')
return list(args)
def get_func_kwargs(func):
"""Return the full set of legal kwargs for the given `func`."""
return [vn for vn in func.func_code.co_varnames]
return inspect.getargspec(func)[0]
# from paste.deploy.converters
def asbool(obj):
+67
View File
@@ -305,5 +305,72 @@ class DictlikeIteritemsTest(unittest.TestCase):
self._notok(duck6())
class ArgInspectionTest(PersistTest):
def test_get_cls_kwargs(self):
class A(object):
def __init__(self, a):
pass
class A1(A):
def __init__(self, a1):
pass
class A11(A1):
def __init__(self, a11, **kw):
pass
class B(object):
def __init__(self, b, **kw):
pass
class B1(B):
def __init__(self, b1, **kw):
pass
class AB(A, B):
def __init__(self, ab):
pass
class BA(B, A):
def __init__(self, ba, **kwargs):
pass
class BA1(BA):
pass
class CAB(A, B):
pass
class CBA(B, A):
pass
class CAB1(A, B1):
pass
class CB1A(B1, A):
pass
class D(object):
pass
def test(cls, *expected):
self.assertEquals(set(util.get_cls_kwargs(cls)), set(expected))
test(A, 'a')
test(A1, 'a1')
test(A11, 'a11', 'a1')
test(B, 'b')
test(B1, 'b1', 'b')
test(AB, 'ab')
test(BA, 'ba', 'b', 'a')
test(BA1, 'ba', 'b', 'a')
test(CAB, 'a')
test(CBA, 'b')
test(CAB1, 'a')
test(CB1A, 'b1', 'b')
test(D)
def test_get_func_kwargs(self):
def f1(): pass
def f2(foo): pass
def f3(*foo): pass
def f4(**foo): pass
def test(fn, *expected):
self.assertEquals(set(util.get_func_kwargs(fn)), set(expected))
test(f1)
test(f2, 'foo')
test(f3)
test(f4)
if __name__ == "__main__":
testenv.main()