import unittest
from signature import Signature
from inspect import getargspec
from re import compile as re_compile
import linecache

# Matches a one-line ``def name(...):`` source line and captures, as group
# "args", the text between the first "(" and the "):" that ends the line.
# ``$`` also matches just before a trailing newline, so lines returned by
# linecache.getline (which keep their "\n") still match.
# Used by StringTests to compare the written parameter list with
# str(Signature(fxn)).
signature_regex = re_compile(r"^def [^\(]+\((?P<args>.*)\):$")

# When adding example functions, make sure the formatting of each argument
# matches what repr() or obj.__name__ would return (StringTests compares the
# text of the ``def`` line against str(Signature(fxn))).  Also avoid
# extraneous whitespace, keep each ``def`` on a single line, and name the
# catch-all parameters ``*args`` and ``**kwargs``.
def ex_empty():
    """Fixture: a function that takes no parameters at all."""
    pass
def ex_required_args(arg1, arg2):
    """Fixture: only required positional parameters."""
    pass
def ex_default_args(default1=1, default2=2):
    """Fixture: only parameters that have default values."""
    pass
def ex_excess_pos_args(*args):
    """Fixture: only a ``*args`` catch-all for extra positional arguments."""
    pass
def ex_excess_kw_args(**kwargs):
    """Fixture: only a ``**kwargs`` catch-all for extra keyword arguments."""
    pass
def ex_everything(arg, default=1, *args, **kwargs):
    """Fixture: one of each kind -- required, defaulted, *args and **kwargs."""
    pass
def ex_funky_defaults(arg=object, arg2=14, arg3='blah', arg4=None):
    """Fixture: defaults of assorted types (a class, an int, a str, None).

    Exercises how str(Signature(...)) renders default values; the ``def``
    line is written the way repr() / __name__ would render each default.
    """
    pass
"""
def ex_tuples((a,b)):
    pass
"""

# Fixture classes: any __call__ or __init__ method must take only 'self' as an
# argument, because simple_method_compare expects required_args == ('self',).
class HasCall(object):
    """New-style fixture class defining ``__call__(self)`` but no ``__init__``.

    test_init_TypeErrors expects Signature(HasCall) to raise TypeError, and
    test_init_instances expects Signature(HasCall()) to report only 'self'.
    """
    def __call__(self):
        pass
class OldHasCall:
    """Old-style (classic, no ``object`` base) counterpart of HasCall.

    test_init_TypeErrors expects Signature(OldHasCall) to raise TypeError, and
    test_init_instances expects Signature(OldHasCall()) to report only 'self'.
    """
    def __call__(self):
        pass

class NoCall(object):
    """New-style fixture class defining ``__init__(self)`` but no ``__call__``.

    test_init_classes expects Signature(NoCall) to report only 'self', and
    test_init_TypeErrors expects Signature(NoCall()) to raise TypeError.
    """
    def __init__(self):
        pass
class OldNoCall:
    """Old-style (classic, no ``object`` base) counterpart of NoCall.

    test_init_classes expects Signature(OldNoCall) to report only 'self', and
    test_init_TypeErrors expects Signature(OldNoCall()) to raise TypeError.
    """
    def __init__(self):
        pass

class Nothing(object):
    """New-style fixture class with neither ``__init__`` nor ``__call__``.

    test_init_TypeErrors expects Signature(Nothing()) to raise TypeError.
    """
    pass
class OldNothing:
    """Old-style (classic, no ``object`` base) counterpart of Nothing.

    test_init_TypeErrors expects Signature(OldNothing()) to raise TypeError.
    """
    pass

# Every module-level ``ex_*`` fixture function, sorted by name so the tests
# visit them in a reproducible order.  sorted() takes a snapshot of globals()
# before filtering, so the dict cannot change size while being iterated.  A
# generator expression (rather than a list comprehension) is still used
# because, in Python 2, a list comprehension's loop variables leak into the
# module namespace.
test_fxns = tuple(fxn_object
                  for name, fxn_object in sorted(globals().items())
                  if name.startswith("ex_"))

class ConstructionTests(unittest.TestCase):
    """Tests for which objects Signature accepts and what it reports for them."""

    def inspect_compare(self, obj):
        """Assert that Signature(obj) agrees with inspect's getargspec(obj).

        Compares the full positional parameter list, the presence of
        ``*args`` / ``**kwargs``, and the (name, default) pairs.
        """
        args, varargs, varkw, defaults = getargspec(obj)
        sig_result = Signature(obj)
        # Callers loop over many fixtures; name the offender on failure.
        msg = "Signature mismatch for %r" % (obj,)
        all_args = (list(sig_result.required_args) +
                    [arg for arg, _ in sig_result.default_args])
        self.assertEqual(args, all_args, msg)
        self.assertEqual(bool(varargs), sig_result.excess_pos_args, msg)
        self.assertEqual(bool(varkw), sig_result.excess_kw_args, msg)
        if defaults:
            # getargspec reports defaults for the trailing parameters only.
            default_names = all_args[len(all_args) - len(defaults):]
            self.assertEqual(tuple(zip(default_names, defaults)),
                             sig_result.default_args, msg)
        else:
            self.assertFalse(sig_result.default_args, msg)

    def simple_method_compare(self, obj):
        """Assert that Signature(obj) reports exactly one parameter, 'self'."""
        sig = Signature(obj)
        msg = "expected only 'self' for %r" % (obj,)
        self.assertEqual(sig.required_args, ('self',), msg)
        self.assertFalse(sig.default_args, msg)
        self.assertFalse(sig.excess_pos_args, msg)
        self.assertFalse(sig.excess_kw_args, msg)

    def test_init_TypeErrors(self):
        # Rejected: classes without their own __init__, and instances that
        # are not callable (no __call__).
        self.assertRaises(TypeError, Signature, HasCall)
        self.assertRaises(TypeError, Signature, OldHasCall)
        self.assertRaises(TypeError, Signature, NoCall())
        self.assertRaises(TypeError, Signature, OldNoCall())
        self.assertRaises(TypeError, Signature, Nothing())
        self.assertRaises(TypeError, Signature, OldNothing())

    def test_init_instances(self):
        # Callable instances are described by their __call__ method.
        for ins in (HasCall(), OldHasCall()):
            self.simple_method_compare(ins)

    def test_init_classes(self):
        # Classes are described by their __init__ method.
        for cls in (NoCall, OldNoCall):
            self.simple_method_compare(cls)

    def test_functions(self):
        for fxn in test_fxns:
            self.inspect_compare(fxn)

class StringTests(unittest.TestCase):
    """Check str(Signature(fxn)) against the parameter list in fxn's def line."""

    def tearDown(self):
        # linecache keeps file contents in memory; drop them after each test.
        linecache.clearcache()

    def test_str(self):
        for fxn in test_fxns:
            code = fxn.func_code
            # Read the source file recorded in the code object, not __file__:
            # when this module is imported from bytecode, __file__ names the
            # .pyc and getline would not return the def line.
            def_line = linecache.getline(code.co_filename, code.co_firstlineno)
            match = signature_regex.match(def_line)
            self.assertTrue(match, "could not parse def line %r of %s"
                            % (def_line, fxn.__name__))
            self.assertEqual(match.group("args"), str(Signature(fxn)),
                             "str(Signature) mismatch for %s" % fxn.__name__)


if __name__ == "__main__":
    unittest.main()








