Thursday 30 September 2010

Asserting Function Calls in Python

One of the nicest features of Python is “duck typing”, which means you don’t need to declare interfaces in order to swap out implementations. Instead, you simply create a different object that has the methods your code calls.
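To make that concrete, here is a small sketch. The classes FileLogger and ListLogger and the function greet are hypothetical names made up for illustration. greet works with either logger purely because both happen to have a write method; there is no shared base class or interface declaration.

```python
import io


class FileLogger(object):
    """Writes messages to a file-like object."""
    def __init__(self, stream):
        self.stream = stream

    def write(self, message):
        self.stream.write(message + "\n")


class ListLogger(object):
    """Keeps messages in memory instead; no common base class needed."""
    def __init__(self):
        self.messages = []

    def write(self, message):
        self.messages.append(message)


def greet(logger, name):
    # greet() only cares that logger has a write() method
    logger.write("Hello, " + name)


buffer = io.StringIO()
greet(FileLogger(buffer), "Alice")
memory = ListLogger()
greet(memory, "Bob")
print(buffer.getvalue().strip())  # Hello, Alice
print(memory.messages)            # ['Hello, Bob']
```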

One particularly useful application of this is unit testing, where it lets you create lightweight replacements for dependencies without needing a full mocking framework. Having said that, sometimes you need to check things such as whether a method was called on an existing object. I asked about this on StackOverflow and got a variety of different approaches to the problem.

Thanks to another feature of Python, sometimes called “monkey patching”, you can take almost any object and replace one of its existing methods with your own at run time. This is obviously very powerful (and potentially dangerous), but it opens up all sorts of possibilities.
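As a quick illustration, using a hypothetical Greeter class: patching a single instance only affects that instance, whereas patching the class changes the behaviour of every instance. That second case is where the danger lies if you forget to put the original back.

```python
class Greeter(object):
    def greet(self):
        return "hello"


g = Greeter()
other = Greeter()

# Patch one instance: the instance attribute shadows the class method
g.greet = lambda: "patched"
print(g.greet())      # patched
print(other.greet())  # hello

# Undo it by deleting the instance attribute
del g.greet
print(g.greet())      # hello

# Patching the class affects every instance, so save the original first
original = Greeter.greet
Greeter.greet = lambda self: "patched everywhere"
print(other.greet())  # patched everywhere
Greeter.greet = original  # restore, or later code will see the patch
print(other.greet())  # hello
```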

Here’s an example of monkey patching to replace the existing implementation of MyFunc with a lambda expression that simply counts how many times it was called.

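def testMyFunc():
    obj = MyObject()
    calls = []
    # A lambda can't contain an assignment such as calls += 1,
    # so record each call by appending to a list instead
    obj.MyFunc = lambda: calls.append(1)
    # DoSomething should call MyFunc
    DoSomething(obj)
    assert len(calls) == 1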
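A lambda can only contain a single expression, so something like calls += 1 is a syntax error inside one. If you would rather keep a plain integer counter, a nested function using Python 3's nonlocal statement also works. The sketch below fills in hypothetical MyObject and DoSomething definitions (they are not from any library) so the test can actually run:

```python
class MyObject(object):
    """Hypothetical stand-in for the real class under test."""
    def MyFunc(self):
        raise RuntimeError("real work we don't want in a unit test")


def DoSomething(obj):
    """Hypothetical code under test: expected to call MyFunc once."""
    obj.MyFunc()


def testMyFunc():
    obj = MyObject()
    calls = 0

    def countingMyFunc():
        nonlocal calls  # rebind the counter in testMyFunc's scope
        calls += 1

    # Replace MyFunc on this one instance only
    obj.MyFunc = countingMyFunc
    DoSomething(obj)
    assert calls == 1


testMyFunc()
print("testMyFunc passed")
```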

To take this one step further, we might still want to call through to the original implementation of MyFunc. We can simplify this by creating a helper class:

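class MethodCallLogger(object):
    """Wraps a method, counting calls while still calling through to it."""
    def __init__(self, meth):
        self.meth = meth
        self.CallCount = 0

    def __call__(self, *args):
        self.CallCount += 1
        # Call through and hand back the original method's result
        return self.meth(*args)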

This class calls through to the original function and counts how many times it was called. The __call__ method allows an instance of a class to be called as though it were a function. The *args syntax collects any positional arguments into a tuple, so the wrapper works with methods that take any number of parameters. Those arguments could also be saved into a list and made available to the unit test if necessary. Here’s our first example again, using the MethodCallLogger class:
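Recording the arguments might look something like this; it is a sketch, not a finished tool. RecordingCallLogger, its Calls attribute and the Calculator class are hypothetical names. Each call's positional and keyword arguments are stored as an (args, kwargs) pair, CallCount is derived from the length of that list, and the wrapped method's return value is passed back so calling code keeps working:

```python
class RecordingCallLogger(object):
    """Like MethodCallLogger, but also keeps the arguments of every call."""
    def __init__(self, meth):
        self.meth = meth
        self.Calls = []  # one (args, kwargs) tuple per call

    @property
    def CallCount(self):
        return len(self.Calls)

    def __call__(self, *args, **kwargs):
        self.Calls.append((args, kwargs))
        # Call through and pass the original return value back
        return self.meth(*args, **kwargs)


class Calculator(object):
    def add(self, a, b):
        return a + b


calc = Calculator()
calc.add = RecordingCallLogger(calc.add)  # wrap the bound method
result = calc.add(2, 3)
calc.add(10, b=5)
print(result)              # 5
print(calc.add.CallCount)  # 2
print(calc.add.Calls)      # [((2, 3), {}), ((10,), {'b': 5})]
```

A test can then assert on the exact arguments, not just on how many calls were made.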

def testMyFunc():
    obj = MyObject()
    logger = MethodCallLogger(obj.MyFunc)
    obj.MyFunc = logger
    # DoSomething should call MyFunc
    DoSomething(obj)
    assert logger.CallCount == 1
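For completeness, here is a self-contained version you can run as is. MyObject and DoSomething are hypothetical stand-ins, and this MethodCallLogger variant assumes you also want to forward keyword arguments and pass the original return value back. Giving MyFunc a visible side effect shows that the real implementation still runs underneath the logger:

```python
class MethodCallLogger(object):
    """Counts calls to a method while still calling through to it."""
    def __init__(self, meth):
        self.meth = meth
        self.CallCount = 0

    def __call__(self, *args, **kwargs):
        self.CallCount += 1
        return self.meth(*args, **kwargs)


class MyObject(object):
    """Hypothetical class under test; MyFunc has a visible side effect."""
    def __init__(self):
        self.done = False

    def MyFunc(self):
        self.done = True
        return "real result"


def DoSomething(obj):
    """Hypothetical code under test: should call MyFunc exactly once."""
    return obj.MyFunc()


def testMyFunc():
    obj = MyObject()
    logger = MethodCallLogger(obj.MyFunc)
    obj.MyFunc = logger
    result = DoSomething(obj)
    assert logger.CallCount == 1
    # Because the logger calls through, the real method still ran
    assert obj.done
    assert result == "real result"


testMyFunc()
print("testMyFunc passed")
```

Because the logger is assigned to the instance rather than the class, other MyObject instances are unaffected and nothing needs to be restored after the test.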
