Sunday, June 10, 2012

Integrating Lua unit tests in CxxTest

I'm currently working on a project that uses both C++ and Lua. In the beginning I looked at several options for unit testing in C++ and ended up choosing CxxTest. I've managed to integrate it well into my CMake build script, and all the tests run with a simple "make test".
Because the Lua code calls the C++ code, I started writing it later, and only now am I writing unit tests for it (yeah, I'm not a TDD fan). There are several good options for unit testing in Lua (see the Lua wiki), but there are two reasons why I didn't use them. The first is that I wanted to use only one testing framework, to have unified testing statistics. The other is that they lack a certain feature that CxxTest has: its test macros also capture the source text. For example, if I write:
int i = 3;
TS_ASSERT_EQUALS(i, 5);
TS_ASSERT(i == 5);
I get:
test.cpp:29: Error: Expected (i == 5), found (3 != 5)
test.cpp:30: Error: Assertion failed: i == 5
When you look at that output you already know what variables are wrong even before opening the code in your text editor. It would be nice to have this feature in Lua as well.
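The source-text capture comes from the preprocessor's stringizing operator. Here is a minimal sketch of the idea; it is not how CxxTest is actually implemented, and the names MY_ASSERT_EQUALS, checkEquals and demoStringize are made up for this illustration:

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Hypothetical helper (not CxxTest's implementation): returns an empty
// string when the values match, otherwise a CxxTest-style message.
template <class A, class B>
std::string checkEquals(const char* file, int line,
                        const char* lhsText, const char* rhsText,
                        const A& lhs, const B& rhs) {
    if (lhs == rhs) return std::string();
    std::ostringstream out;
    out << file << ":" << line << ": Error: Expected ("
        << lhsText << " == " << rhsText << "), found ("
        << lhs << " != " << rhs << ")";
    return out.str();
}

// '#a' is the stringizing operator: it turns the macro argument into a
// string literal holding its source text.
#define MY_ASSERT_EQUALS(a, b) \
    checkEquals(__FILE__, __LINE__, #a, #b, (a), (b))

// Usage: the message carries both the source text and the actual values.
inline void demoStringize() {
    int i = 3;
    std::string msg = MY_ASSERT_EQUALS(i, 5);
    assert(msg.find("Expected (i == 5), found (3 != 5)") != std::string::npos);
}
```

The macro is only a thin wrapper; all the formatting happens in an ordinary function template, which keeps the macro short and easy to debug.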
Of course, this only works in CxxTest because it uses the stringizing feature of C macros, which converts a macro argument to a string. Lua has no macros, but it has a much more powerful feature: loading code dynamically from strings (actually, there is a macro extension for Lua). It's impossible to write the following and still capture the source text of the expression:
i = 3
TS_ASSERT(i == 5)
But it should be possible to write:
i = 3
TS_ASSERT[[i == 5]]
Our first try is
function TS_ASSERT(code)
    local func = assert(loadstring("return "..code))
    if not func() then
       local info = debug.getinfo(2, "Sl");
       local msg = {"Failed assertion: ", code}
       TS_FAIL(info.short_src, info.currentline, table.concat(msg))
    end
end
This function tries to compile a string, containing an expression, passed as argument and checks if the result evaluates to true or false according to Lua's rules. If the result is true the test passed and life goes on. If not, we find out the source location where our function was called and pass it on to TS_FAIL together with the expression that failed. For now, all we need to know about TS_FAIL is that it calls CxxTest to account for this failure.
The output for the sample above would be:
test.lua:29: Error: Failed assertion: i == 5
This is very nice, but we have no clue about the actual value of i. And there is another problem: it only works with global variables, because code compiled with loadstring runs in the global environment and cannot see the caller's local variables. If we had written
local i = 3
TS_ASSERT[[i == 5]]
the test would always fail because the undefined global variable i would be nil.
Let's turn our attention to the first problem: we need to find out the values of the variables inside a string of code. We could print all the values in _G, the global variable table, but this would clutter the output, confusing the programmer instead of helping. An extreme solution would be to parse the code ourselves to find the names of the variables and look up their values in _G to print them. A more practical solution is to use environments and meta-tables. In Lua we can set a special environment for a function. At the same time we can use the meta-method __index to track accesses to keys that are not present. With a specially crafted environment table we can track all the variable accesses of a function. Combining these ideas, our second attempt looks like this:
function TS_ASSERT(code)

    local func = assert(loadstring("return "..code))

    local calling_globals = getfenv(2)
    local referenced = {}
    local env = {}

    setmetatable(env, {__index=
        function(table, name)
            local value = calling_globals[name]
            referenced[name] = value
            return value
        end
    })

    setfenv(func, env)
    if not func() then
       local info = debug.getinfo(2, "Sl");
       local msg = {"Failed assertion: ", code}

       local first = true
       for k, v in pairs(referenced) do
           if first then
               msg[#msg+1]= ". Variables: "
           else
               msg[#msg+1] = ", "
           end
           msg[#msg+1]= tostring(k).." = "..tostring(v)
           first = false
       end
       TS_FAIL(info.short_src, info.currentline, table.concat(msg))
    end
end
First we get the environment of the calling function. Then we give our newly compiled function an empty environment with an __index meta-method that will forward all variable lookups to the environment of the calling function. So now the output for the sample above would be:
test.lua:29: Error: Failed assertion: i == 5. Variables: i = 3
Much better, but it still won't work for local variables. Unlike globals, local variables are not stored in a table, possibly for performance reasons. For this we must, again, resort to the debug library. The function debug.getlocal takes a stack level and an index as arguments and returns the name of the corresponding local variable, followed by its value. All we have to do is copy these values to our fake environment:
...
    local calling_globals = getfenv(2)
    local calling_locals  = {}
    local i = 1
    while true do
        local name, value = debug.getlocal(2, i)
        if not name then break end
        calling_locals[name] = value
        i = i + 1
    end

    local notpresent = {}
    setmetatable(calling_locals, { __index=function() return notpresent end})

    local referenced = {}
    local env = {}

    setmetatable(env, {__index=
        function(table, name)
            local value = calling_locals[name]
            if value == notpresent then value = calling_globals[name] end
            referenced[name] = value
            return value
        end
    })
...
We collect the local variables in a table and make the intercepting function look in this table before searching the global environment, respecting Lua's rules for variable lookup. The only complication is that a table lookup returns nil both when the entry doesn't exist and when its value really is nil. To make the distinction we use the empty table "notpresent". One caveat: assigning nil to a table field stores nothing, so a local whose value is nil still falls through to the global lookup. The call to getlocal might be slow, but we aren't trying to write the fastest unit test framework in the world, just the most helpful one.
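The lookup order is easy to get wrong, so here is a rough C++ analogue of the fake environment. It is only an illustration of the idea, not a translation of the Lua code: Scope and TrackingEnv are names invented for this sketch, and values are limited to int to keep it short.

```cpp
#include <cassert>
#include <map>
#include <string>

typedef std::map<std::string, int> Scope;

// Hypothetical analogue of the fake environment: looks a name up in the
// locals first, then in the globals, and records every name it was asked for.
struct TrackingEnv {
    const Scope& locals;
    const Scope& globals;
    std::map<std::string, std::string> referenced;  // name -> printed value

    TrackingEnv(const Scope& l, const Scope& g) : locals(l), globals(g) {}

    // Returns true and sets 'out' if the name exists in either scope.
    // map::find tells "absent" apart from any stored value, which is the
    // job the 'notpresent' sentinel is meant to do in the Lua version.
    bool lookup(const std::string& name, int& out) {
        Scope::const_iterator it = locals.find(name);
        if (it == locals.end()) {
            it = globals.find(name);
            if (it == globals.end()) {
                referenced[name] = "nil";
                return false;
            }
        }
        out = it->second;
        referenced[name] = std::to_string(out);
        return true;
    }
};

// Usage: a local shadows a global of the same name.
inline void demoTrackingEnv() {
    Scope globals; globals["i"] = 10;
    Scope locals;  locals["i"] = 3;
    TrackingEnv env(locals, globals);
    int v = 0;
    assert(env.lookup("i", v) && v == 3);
    assert(env.referenced["i"] == "3");
}
```

After a failed assertion, the contents of referenced are exactly what ends up in the "Variables:" part of the message.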
As a last addition we can make the following enhancement to TS_ASSERT: if the argument is not a string we simply see if it evaluates to false.
...
    if (type(code) ~= "string") then
        if not code then
            local info = debug.getinfo(2, "Sl");
            local msg = {"Failed assertion"}
            TS_FAIL(info.short_src, info.currentline, table.concat(msg))
        end
        return
    end
...
Now we can also write:
local file = io.open("/dev/null", "w")
TS_ASSERT(file)
This assertion will fail if the io.open call is unsuccessful.
The last thing left to explain is the TS_FAIL function. We have used it passing a custom source location but it would also be nice to use it like the C++ macro to fail a test:
local file, msg = io.open("/dev/null", "w")
if not file then TS_FAIL(msg) end
All we have to do is to write a small binding for the CxxTest API function that the macro calls internally:
int ts_fail(lua_State* L) {
    const int nargs = lua_gettop(L);

    if (nargs == 3) {
        CxxTest::TestTracker::tracker().failedTest(
            luaL_checkstring(L, -3),
            luaL_checkint(L, -2),
            luaL_checkstring(L, -1) );
    } else if (nargs >= 1) {
        lua_Debug debug;
        lua_getstack(L, 1, &debug);
        lua_getinfo(L, "Sl", &debug);
        CxxTest::TestTracker::tracker().failedTest(
            debug.short_src,
            debug.currentline,
            luaL_checkstring(L, -1) );
    } else {
        luaL_error(L, "TS_FAIL called with an illegal number of arguments");
    }
    return 0;
}
If this function is called with three arguments we assume that the first two are the location and the third is the message. If only one argument is supplied, we assume that it is the error message and find out the calling location using the C API of the debug library. In both cases the location and the message are passed to the failedTest method of CxxTest's TestTracker singleton.
It's arguable whether the names of the variables in the testing output will help you fix your errors faster, but I think it is interesting to see that it can be done. There are many more features we can add to this framework. One, for example, is measuring the code coverage of the tested code with Lua's debug library. But for now this will already help me a lot to write unit tests in Lua.

Saturday, June 9, 2012

Lua: Generic call with variadic templates

If you want to call a Lua function from C/C++ you have to follow a simple protocol: push the function on the stack, push the arguments, make the call and, finally, pop the return values. The exact sequence of steps varies according to the number and types of arguments and return values, but in essence it is always the same. In other words, it begs to be automated.

Programming in Lua presents, in listing 25.4, a solution à la printf/scanf. While it works well for numbers and strings, it has the same drawbacks as those functions: the same potential for memory access disasters, and it is not easily extensible to new types.

Unfortunately this was the only way to write a generic call function in C++ without writing overloads for 0 to N arguments. However, with the variadic templates feature of the new C++11 standard we can write a new generic call function that is both type-safe and user-extensible.
function add(a, b)
    return a+b
end
For example, given the Lua function above, we would like to call it like this:
int result = callFunc<int>(L, "add", 1, 2);
Our first attempt at an implementation looks like this:
// type-specific helper
template<typename T>
struct LuaValue;

// base case: no arguments left to push
// (declared first so the recursive template below can always see it)
int pushArgs(lua_State* L) {
    return 0;
}

// recursive function template to push a variable number of arguments
template<class H, class... T>
int pushArgs(lua_State* L, const H& h, const T&... t) {
    int size = LuaValue<H>::size();
    LuaValue<H>::pushValue(L, h);
    return size + pushArgs(L, t...);
}

// the main function
template<class R, class... Args>
R callFunc(lua_State* L, const std::string& name, const Args... args) {
    lua_getglobal(L, name.c_str());
    luaL_checktype(L, -1, LUA_TFUNCTION);

    const int size = pushArgs(L, args...);

    if (lua_pcall(L, size, LuaValue<R>::size(), 0) != 0) {
        lua_pushfstring(L, "Error running function %s:\n", name.c_str());
        lua_insert(L, -2);
        lua_concat(L, 2);
        lua_error(L);
    }
    R ret = LuaValue<R>::getStackValue(L, -1);
    lua_pop(L, LuaValue<R>::size());
    return ret;
}
The procedure is simple: first we load the function onto the stack and verify that it really is a function. Then we use recursion on the argument pack to push all arguments onto the stack. The pushArgs function returns the number of values that were pushed, and we pass this value to lua_pcall. At the end we must restore the stack to its previous state and return the value returned by Lua. For each type we must define a specialization of the helper struct LuaValue. To call the function "add" as above we must add a specialization for int:
template<> struct LuaValue<int> {
    static int getStackValue(lua_State* L, int pos) {
        return luaL_checkinteger(L, pos);
    }
    static void pushValue(lua_State* L, int value) {
        lua_pushinteger(L, value);
    }
    static int size() { return 1; }
};
A specialization of LuaValue must define three functions: a function to push the value on the stack, a function to retrieve a value from the stack, and a function that returns the number of stack positions that this value takes.
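To see the recursion and the per-type helper at work without a Lua installation, here is a small sketch that swaps lua_State for a fake stack. FakeStack, StackValue and pushAll are names invented for this illustration; the structure mirrors LuaValue and pushArgs, but this is not the code from the project.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Stand-in for lua_State so the sketch runs without Lua: every slot is
// stored as text. FakeStack and StackValue are invented for illustration.
typedef std::vector<std::string> FakeStack;

template <class T> struct StackValue;  // specialized per type, like LuaValue

template <> struct StackValue<int> {
    static void pushValue(FakeStack& s, int v) { s.push_back(std::to_string(v)); }
    static int size() { return 1; }
};

// Supporting a new argument type only takes one more specialization.
template <> struct StackValue<std::string> {
    static void pushValue(FakeStack& s, const std::string& v) { s.push_back(v); }
    static int size() { return 1; }
};

// Base case: no arguments left, nothing pushed.
inline int pushAll(FakeStack&) { return 0; }

// Recursive case: push the head, then recurse on the tail of the pack.
template <class H, class... T>
int pushAll(FakeStack& s, const H& h, const T&... t) {
    StackValue<H>::pushValue(s, h);
    return StackValue<H>::size() + pushAll(s, t...);
}

// Usage: mixed argument types, pushed left to right.
inline void demoPushAll() {
    FakeStack s;
    int n = pushAll(s, 1, std::string("two"), 3);
    assert(n == 3);
    assert(s[0] == "1" && s[1] == "two" && s[2] == "3");
}
```

Passing a type without a specialization (a double, say) fails at compile time because StackValue<double> is incomplete, which is the type safety the printf-style interface cannot offer.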

So far so good, but let's say that we want to call a function with no return values:
-- Lua
function log(severity, message)
    --do something with the arguments
    return --return nothing
end
// C++
callFunc<void>(L, "log", WARNING, "corrupting your filesystem...");
This code fails to compile because callFunc declares a local variable of type R, which in this case is void. We needed that temporary because we have to pop the returned values from the stack, and this must be done after we retrieve the return value but before actually returning.
The solution that comes to mind is to use RAII to clear the stack:
struct popper {
    popper(lua_State* L, int size = 1)
        : m_L(L)
        , m_size(size) {}
    ~popper() {
        if (m_size > 0) lua_pop(m_L, m_size);
    }
private:
    lua_State* m_L;
    int m_size;
};
//  ... end of callFunc:
    popper p(L, LuaValue<R>::size());
    return LuaValue<R>::getStackValue(L, -1);
}
With this little trick we will be able to return a "void value". As you have surely guessed, one piece is still missing: the LuaValue for void:
template<> struct LuaValue<void> {
   static void getStackValue(lua_State* L, int pos) {}
   static int size() { return 0; }
};
The implementation for pushValue was intentionally left out, so if you try to use void as argument type your code won't compile.
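The two ingredients here, a guard whose destructor runs after the return expression has been evaluated and a return statement that is legal even when R is void, can be tried in isolation. The sketch below uses a std::vector as a stand-in for the Lua stack; FakeStack, StackValue, Popper and takeResult are names made up for this example.

```cpp
#include <cassert>
#include <vector>

typedef std::vector<int> FakeStack;  // invented stand-in for the Lua stack

template <class R> struct StackValue;

template <> struct StackValue<int> {
    static int getStackValue(FakeStack& s) { return s.back(); }
    static int size() { return 1; }
};

template <> struct StackValue<void> {
    static void getStackValue(FakeStack&) {}
    static int size() { return 0; }
};

// RAII guard: removes 'count' slots when it goes out of scope.
struct Popper {
    Popper(FakeStack& s, int count) : m_stack(s), m_count(count) {}
    ~Popper() {
        m_stack.resize(m_stack.size() -
                       static_cast<FakeStack::size_type>(m_count));
    }
private:
    FakeStack& m_stack;
    int m_count;
};

// The return expression is evaluated first; the guard's destructor runs
// afterwards. For R = void, 'return f();' with f returning void is legal.
template <class R>
R takeResult(FakeStack& s) {
    Popper p(s, StackValue<R>::size());
    return StackValue<R>::getStackValue(s);
}

// Usage: the value is read before the slot is popped.
inline void demoPopper() {
    FakeStack s;
    s.push_back(42);
    assert(takeResult<int>(s) == 42);
    assert(s.empty());
}
```

The same function body serves both the int and the void case, so no separate overload is needed for functions that return nothing.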
Before we finish, let's consider one more case:
function div(a , b)
    return a/b, a%b
end
In Lua, a function can return more than one value, so how can our generic C++ call function handle it? We return the values as a std::tuple! Along with variadic templates, C++11 introduced this new container in the standard library. The only change required is to add an appropriate specialization for std::tuple:
 
template<class... T> struct LuaValue<std::tuple<T...> > {
    template<class Tuple, int I, int N>
    struct helper {
        static void pushValue(lua_State* L, const Tuple& tuple) {
            LuaValue< typename std::tuple_element< I, Tuple >::type >
                ::pushValue(L, std::get<I>(tuple));
            helper<Tuple, I+1, N>::pushValue(L, tuple);
        }
        static void getValue(lua_State* L, Tuple& tuple, int pos) {
            std::get<I>(tuple) =
                LuaValue< typename std::tuple_element<I, Tuple>::type>
                    ::getStackValue(L, pos);
            helper<Tuple, I+1, N>::getValue(L, tuple, ++pos);
        }
    };

    template<class Tuple, int N>
    struct helper<Tuple, N, N> {
        static void pushValue(lua_State*, const Tuple&) {}
        static void getValue(lua_State*, Tuple&, int) {}
    };

    typedef std::tuple<T...> TType;
    static TType getStackValue(lua_State* L, int pos) {
        TType ret;
        helper<TType, 0, std::tuple_size<TType>::value>
            ::getValue(L, ret, pos-size()+1);
        return ret;
    }
    static void pushValue(lua_State* L, const TType& tuple) {
        helper<TType, 0, std::tuple_size<TType>::value>
            ::pushValue(L, tuple);
    }
    static int size() { return std::tuple_size<TType>::value; }
};
OK, I admit, this was quite a bit of template trickery, but in the end it all boils down to this: the pushValue function pushes every value in the tuple onto the stack, and the getStackValue function retrieves a value from the stack for each position in the tuple. So now we can call the div function like this:
int result, remainder;
std::tie(result, remainder) = callFunc<std::tuple<int, int>>(L, "div", 42, 3);
With this scaffolding in place it is easy to add new types to pass as parameters or return values. The source can be found here. There are specializations of LuaValue for collections like std::vector and std::map.
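The position arithmetic in the tuple's getStackValue is the least obvious part, so here is a stripped-down sketch of just that piece. It reads only int values from a fake stack; FakeStack, slot, Reader and readResults are names invented for this illustration and are not part of the real code.

```cpp
#include <cassert>
#include <tuple>
#include <vector>

typedef std::vector<int> FakeStack;  // invented stand-in for the Lua stack

// Lua-style indexing: -1 is the top slot, -2 the one below it, and so on.
inline int slot(const FakeStack& s, int pos) {
    return s[static_cast<FakeStack::size_type>(
        static_cast<int>(s.size()) + pos)];
}

// Fills element I of the tuple from stack position pos, then moves on.
template <class Tuple, int I, int N>
struct Reader {
    static void read(const FakeStack& s, Tuple& t, int pos) {
        std::get<I>(t) = slot(s, pos);
        Reader<Tuple, I + 1, N>::read(s, t, pos + 1);
    }
};

// Terminating specialization: I == N, nothing left to read.
template <class Tuple, int N>
struct Reader<Tuple, N, N> {
    static void read(const FakeStack&, Tuple&, int) {}
};

// With N results, the first one sits N-1 slots below the top (pos = -1),
// which is where the 'pos - size() + 1' in getStackValue comes from.
template <class... T>
std::tuple<T...> readResults(const FakeStack& s) {
    typedef std::tuple<T...> Tuple;
    const int n = static_cast<int>(sizeof...(T));
    Tuple t;
    Reader<Tuple, 0, n>::read(s, t, -1 - n + 1);
    return t;
}

// Usage: two results on top of an unrelated slot.
inline void demoReadResults() {
    FakeStack s;
    s.push_back(99);  // unrelated slot below the results
    s.push_back(14);  // first result
    s.push_back(0);   // second result
    int quotient = -1, remainder = -1;
    std::tie(quotient, remainder) = readResults<int, int>(s);
    assert(quotient == 14 && remainder == 0);
}
```

The results come back in the order Lua returned them, so the first tuple element is the slot furthest from the top of the stack.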