Saturday, November 24, 2012

Semi-automatic conversion of C++ exceptions to Lua errors

0. Error handling in Lua and C++


Lua and C++ actually use a similar error handling scheme: an error is signalled at some point during execution and the stack unwinds until a handler is found. If you ask Google how to "simulate C++ exceptions in C", you will be shown several results, all using setjmp/longjmp, and this is exactly the mechanism used by the Lua interpreter. I won't go into detail here, but basically you call setjmp to mark an execution point and call longjmp to return to it, no matter how deep in the stack you are. If you think of setjmp as a try...catch and longjmp as a throw, then you have a primitive exception handling scheme.
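A minimal sketch of this primitive scheme in plain C++, with no Lua involved. The names `protected_call`, `raise_error` and the `depth` functions are hypothetical. Only trivially destructible objects live between the setjmp and the longjmp, because jumping over a C++ object with a non-trivial destructor is undefined behavior, which is exactly the problem discussed below.

```cpp
#include <cassert>
#include <csetjmp>

// Hypothetical example: the jump buffer marks the "catch" point.
static std::jmp_buf handler;
static int error_code = 0;

// Plays the role of "throw": record the error and jump back to setjmp.
[[noreturn]] static void raise_error(int code) {
    error_code = code;
    std::longjmp(handler, 1);
}

// A few nested calls; the error is signalled at the deepest level.
static int depth3(int x) {
    if (x < 0) raise_error(42);
    return x * 2;
}
static int depth2(int x) { return depth3(x) + 1; }
static int depth1(int x) { return depth2(x) + 1; }

// Plays the role of try...catch (or lua_pcall): returns 0 and stores the
// result on success, or returns the error code if raise_error was called.
int protected_call(int x, int* result) {
    if (setjmp(handler) == 0) {
        *result = depth1(x);   // "try" block: setjmp returned 0 directly
        return 0;
    }
    return error_code;         // "catch" block: reached through longjmp
}
```

The intermediate frames (`depth1`, `depth2`) are simply abandoned by the jump; nothing in them gets a chance to run cleanup code.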

Back to Lua: when you call lua_pcall in C or pcall in Lua, the interpreter calls setjmp and then executes the function you passed it. If the interpreter detects an error, or the user code calls lua_error, longjmp is called and execution resumes inside pcall.

The problem with longjmp is that it simply discards the part of the C execution stack (not Lua's stack) that it jumps over, without doing any cleanup. If you called malloc and then did a longjmp, you just leaked a chunk of memory, so you have to be very careful to clean up everything yourself. In C++, when you throw an exception, not only is the stack unwound but the destructor of every object you allocated on the stack is called. So if you want to protect yourself, you tie each resource allocation to the lifetime of an object on the stack. This practice is known as the C++ idiom resource acquisition is initialization (RAII). In the case of a memory buffer, you would hold the raw pointer in a smart pointer object, which calls delete when it is destroyed.
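To see RAII at work, here is a small sketch. The `Buffer` type and the `live_buffers` counter are hypothetical, added only to make the cleanup observable: the buffer is released during stack unwinding even though `work` never reaches its closing brace.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>

// Counts live buffers so the effect of stack unwinding can be observed.
static int live_buffers = 0;

// RAII holder: acquires in the constructor, releases in the destructor.
// std::unique_ptr<char[]> alone would suffice; the wrapper just makes the
// release visible through the counter.
struct Buffer {
    std::unique_ptr<char[]> data;
    explicit Buffer(std::size_t n) : data(new char[n]) { ++live_buffers; }
    ~Buffer() { --live_buffers; }
    Buffer(const Buffer&) = delete;
    Buffer& operator=(const Buffer&) = delete;
};

void work(bool fail) {
    Buffer buf(1024);                  // released however work() exits
    if (fail) throw std::runtime_error("something went wrong");
}

// Returns true if an exception was caught; by then `buf` is already gone.
bool run(bool fail) {
    try {
        work(fail);
        return false;
    } catch (const std::runtime_error&) {
        return true;
    }
}
```

Had `work` left through a longjmp instead of a throw, `~Buffer` would never run and the counter would stay at 1.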

Now that we have covered how the exception handling works in C++ and in Lua, let's see what happens if you mix them.

1. Signaling errors from C++ to Lua

Suppose you have a piece of C++ code that is called by Lua. What happens if the C++ code throws an exception? Not what we would like: nobody catches the exception and the process terminates. Even if there is a pcall on the stack, Lua will not catch the exception.

OK, I know that if you compile the PUC Lua interpreter with a C++ compiler it will use try...catch instead of setjmp...longjmp, but unless you are willing to recompile Lua and link it statically to your application, you can't count on that. Also, PUC Lua is not the only interpreter out there.

The solution, as you might guess, is to convert the exceptions to Lua errors at the boundary between the two languages. Typically you have some nifty C++ code that you want to use in Lua and you add a layer to make it callable from Lua. Let's call this layer a binding. This binding usually consists of a bunch of functions that conform to the lua_CFunction signature.

Consider the following scenarios:
  1. Deep in the C++ code an exception is raised that can't be corrected by the binding
  2. An error is detected at the binding level that must be reported to the calling Lua code
How should we handle these cases? A possible solution is:

"To protect ourselves from the first case, we surround every call to C++ with a try catch and call lua_error in the catch block. In the second case it doesn't make sense to throw an exception so we call lua_error directly."

It makes sense, doesn't it? But it actually has two severe problems.

The first problem is that if you call lua_error inside the catch block, you are guaranteed to leak at least one resource: the exception object itself. lua_error calls longjmp, so the exception's destructor is never called. Exception objects usually contain an error message string, so this is leaked as well. One way out is to structure your exception handling code so that lua_error is called outside the catch block. And when an error is detected in the binding and you call lua_error directly, you must also be very careful to clean up everything first.

The second problem is that it is hard to organize your code around lua_error. The slightest mistake and you are leaking resources. OK, you say, but I am a very disciplined programmer and I'll make sure that every object is destroyed before I call lua_error. But what if someone else, unaware of all these subtleties, edits your code? This solution is very hard to maintain. Also, your C++ code doesn't feel very natural anymore.

2. Solution

The solution I propose is:
  1. Handle all the conversions of exceptions to lua_error in one place
  2. Never call lua_error except for this one place
  3. Always use exceptions, even in binding lua_CFunctions
The trick, of course, is implementing item 1. We want to wrap every call to a lua_CFunction inside a try...catch statement. As an example let's use a modified version of listing 26.1 of Programming in Lua. We implement a function to list directories but instead of returning error codes we throw an exception (this is a bad idea, but it's just to illustrate our case).

#include <dirent.h>
#include <errno.h>
#include <string.h>    // strerror
#include <stdexcept>
#include <lua.hpp>     // Lua headers, wrapped in extern "C"

static int l_dir(lua_State* L) {

  const char* path = luaL_checkstring(L, 1);

  DIR* dir = opendir(path);
  if (dir == nullptr) {
    throw std::logic_error(strerror(errno));
  }

  lua_newtable(L);
  int i = 1;
  dirent* entry;

  while((entry = readdir(dir)) != NULL) {
    lua_pushnumber(L, i++);
    lua_pushstring(L, entry->d_name);
    lua_settable(L, -3);
  }
  closedir(dir);
  return 1;
}

Now we must register the function so we can call it from Lua:

static const struct luaL_Reg dirlib[] = {
  { "dir", l_dir },
  { nullptr, nullptr }
};

// extern "C" so that require can find the unmangled symbol name
extern "C" int luaopen_dirlib(lua_State* L) {
  luaL_register(L, "dirlib", dirlib);
  return 1;
}

Our first try is:

static int wrap_dir(lua_State* L)
{
  try {
    return l_dir(L);
  } catch(std::exception& ex) {
    lua_pushstring(L, "caught C++ exception: ");
    lua_pushstring(L, ex.what());
    lua_concat(L,2);
  } catch (...) {
    lua_pushstring(L, "caught unknown C++ exception");
  }

  // call lua_error outside of the catch block so that the exception
  // object is destroyed first: lua_error calls longjmp, which would
  // otherwise skip the exception's destructor
  lua_error(L);

  return 0; //just to silence warnings
}

And we register this wrapper function instead of l_dir:

static const struct luaL_Reg dirlib[] = {
  { "dir", wrap_dir },
  { nullptr, nullptr }
};

Notice that we don't call lua_error in the catch blocks, to avoid leaking the exceptions. This solution works, but it's not practical to require a hand-written wrapper for every lua_CFunction; programmers are lazy people!

What we would like to do is to create a generic wrapper function that takes a lua_CFunction pointer as argument, calls it and handles potential exceptions. The problem is that with this new signature our function wouldn't conform to the lua_CFunction signature and therefore could not be registered with luaL_register. We need another way to pass the function pointer to our wrapper function.

As usual in my posts, C++ templates come to the rescue when the situation looks most desperate. What few people know is that function pointers can be used as template arguments in C++. We can write a function template that takes a lua_CFunction (which is itself a function pointer type) as a template argument. Once instantiated, it is a lua_CFunction like any other and can be registered with luaL_register.

template<lua_CFunction func>
static int exception_translator(lua_State* L)
{
  try {
    return func(L);
  } catch(std::exception& ex) {
    lua_pushstring(L, "caught C++ exception: ");
    lua_pushstring(L, ex.what());
    lua_concat(L,2);
  } catch (...) {
    lua_pushstring(L, "caught unknown C++ exception");
  }

  // call lua_error outside of the catch block so that the exception
  // object is destroyed first: lua_error calls longjmp, which would
  // otherwise skip the exception's destructor
  lua_error(L);
  return 0; //just to silence warnings
}

static const struct luaL_Reg dirlib[] = {
  { "dir", exception_translator<l_dir> },
  { nullptr, nullptr }
};
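The key language feature, a function pointer used as a non-type template argument, can be tried out without Lua. In this sketch `FakeState`, `CFunction`, `Reg` and `isqrt_binding` are hypothetical stand-ins for `lua_State`, `lua_CFunction`, `luaL_Reg` and a binding; instead of calling `lua_error`, the translator records the message and returns -1.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for lua_State: one input and an error slot.
struct FakeState {
    int arg = 0;
    std::string error;
};

// Same shape as lua_CFunction: int (*)(lua_State*).
using CFunction = int (*)(FakeState*);

// The wrapped function is a template argument, so every instantiation is
// an ordinary function with the CFunction signature.
template <CFunction func>
int exception_translator(FakeState* s) {
    try {
        return func(s);
    } catch (const std::exception& ex) {
        s->error = std::string("caught C++ exception: ") + ex.what();
    } catch (...) {
        s->error = "caught unknown C++ exception";
    }
    // The real wrapper would call lua_error here, after the catch block
    // has finished and the exception object has been destroyed.
    return -1;
}

// A "binding" that may throw: integer square root, rounded down.
int isqrt_binding(FakeState* s) {
    if (s->arg < 0) throw std::domain_error("negative argument");
    int r = 0;
    while ((r + 1) * (r + 1) <= s->arg) ++r;
    return r;
}

// Registration table analogous to luaL_Reg.
struct Reg { const char* name; CFunction func; };
const Reg lib[] = {
    {"isqrt", exception_translator<isqrt_binding>},
    {nullptr, nullptr},
};
```

Since C++11 the template argument may also be a function with internal linkage, which is why `static` bindings such as `l_dir` are accepted.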


3. Conclusion


If you are careful to wrap every lua_CFunction with this template, you can use exceptions as much as you like in your C++ code, knowing that they will be converted to Lua errors if they aren't caught earlier. Because you use exceptions instead of lua_error, you know that everything you carefully allocated using RAII will be deallocated in case of errors. Also, you don't have to rely on the compilation settings of your Lua interpreter; this works with any conforming interpreter.

Saturday, October 20, 2012

Fun with variable typing in C++


Warning. The following listing contains disturbing code. If you are the type of viewer who is easily disturbed by the sight of templates, lambda expressions and other diabolic concoctions do not read this post.

VariantValue factorial(vector<VariantValue> args)
{
    return expr_switch(args, {
        {pattern(0), []{ return 1; }},
        {pattern(Type<int>()), [&](){ return args[0].value<int>()
            * factorial({args[0].value<int>()-1}).value<int>(); }}
   });
}


OK, I have to admit that this is the ugliest factorial function I've ever written. But that is not the point of this example. Let me first tell you how I came to this. I was reading the PhD thesis of Joe Armstrong [1], one of the creators of the Erlang programming language. As you may know, Erlang is a dynamically typed functional language. In Erlang you define functions as a sequence of head-body clauses, where the head consists of a pattern that must be matched for the corresponding body to be executed. So, for example, you could write the factorial function like this:

factorial(0) -> 1;
factorial(N) -> N * factorial(N-1).


In this example the first clause is executed if the argument matches the value "0". If it doesn't the second clause is executed and the argument is bound to the identifier "N". If the arguments don't match any clause, an exception is raised. This mechanism has the following advantages:
  1. When a body is executed, most preconditions have been checked, so there is no need to perform further tests.
  2. If no clause is matched, processing stops immediately.
In other words, no defensive programming is necessary and errors are detected early. In his thesis Armstrong argues that this intentional programming style, together with fast failure detection, is fundamental for building reliable, fault-tolerant systems. This made me think about how we could do the same thing in C++11.

As it happens, I've been building a reflection system for C++11 (check out my GitHub page [2]), and its most important building block is a class called VariantValue that enables dynamic typing. This class is similar to boost::any, but it has some enhancements that make it much more usable (but that is for another post). Anyway, I decided to see if I could emulate this mechanism using VariantValue and C++11 lambdas. It turns out I can... sort of.

But first things first. Let me introduce you to the basic API of this VariantValue:

class VariantValue {
public:
    // set the content to type T passing a
    // variable number of arguments to its constructor
    template<class T, class... Args>
    void construct(Args&&... as);

    // return true if the contained type is T
    template<class T>
    bool isA() const;

    // get the value as T, throws exception if type is wrong
    template<class T>
    T value() const;

    // same as above, but may try several conversions
    template<class T>
    T convert() const;

};


These variants have value semantics and can be put inside containers. So a function with dynamic arguments can take, for example, a vector<VariantValue> as its argument. Writing code to test whether each position contains a certain type, and maybe a specific value, is easy, but we want a generic solution. What we want is to pass a list of types and values to the matching function and let it figure out whether a sequence of variants matches the pattern.

The problem is that values and types, like oil and water, don't mix in C++, at least not without an emulsifier. Because encoding a value as a type is much more difficult, we will represent a type as a value.

template<class T>
struct Type {
    typedef T type;
};


We can use temporaries of our "emulsifier" class as placeholders for values of a specific type. Next on our list is the matching engine. As you may have guessed, this involves templates and, worse, of the variadic kind. Because function templates cannot be partially specialized, we will define a little helper class template with a static member function. First we declare the primary template:

template<class... A>
struct matcher_impl;


Then we need to handle the case of the empty argument list:

template<>
struct matcher_impl<> {
    static bool match(vector<VariantValue>& v,
                      vector<VariantValue>::iterator it)
    {
        return it == v.end();
    }
};

If there is no argument left, we expect the iterator to have reached the end of the argument vector. This is because we want to match the exact number of arguments.

Next we define the case for the argument list beginning with a value:

template<class A1, class... As>
struct matcher_impl<A1, As...> {
    static bool match(A1&& a1,
                      As&&... as, vector<VariantValue>& v,
                      vector<VariantValue>::iterator it)
    {
        if (it == v.end()) {
            return false;
        }
        if (!it->isA<A1>()) {
            return false;
        }
        if (it->value<A1>() != a1) {
            return false;
        }
        return matcher_impl<As...>::match(
            std::forward<As>(as)..., v, ++it);
    }
};


As you can see, this is fairly straightforward. We check that the iterator is valid, check the type, check the value and increment the iterator before we "recurse". The case for the type placeholder is very similar, except that we omit the value check.

template<class A1, class... As>
struct matcher_impl<Type<A1>, As...> {
    static bool match(Type<A1> t1, As&&... as,              
                      vector<VariantValue>& v,
                      vector<VariantValue>::iterator it)
    {
        if (it == v.end()) {
            return false;
        }
        if (!it->isA<A1>()) {
            return false;
        }
        return matcher_impl<As...>::match(
            std::forward<As>(as)..., v, ++it);
    }
};


And finally we can write a front-end function to hide all this template ugliness. For a reason that will become apparent soon, instead of immediately returning a bool, we return a function object that takes a vector as argument and checks whether it matches the pattern:

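template<class... As>
function<bool(vector<VariantValue>&)> pattern(As&&... as)
{
    // The lambda captures the arguments by reference, so the returned
    // object must not outlive them. Temporaries such as the 0 in
    // pattern(0) live until the end of the full-expression, which
    // covers the call to expr_switch below.
    return [&](vector<VariantValue>& v) {
        return matcher_impl<As...>::match(
            std::forward<As>(as)...,
            v, v.begin());
    };
}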

Last, we want an expression that selects the first match in a list of patterns and calls the corresponding body, or fails. But we don't want to write it by hand every time we define a function. Fortunately, with initializer lists we can write a function that does all this without being too ugly:

VariantValue expr_switch(
    vector<VariantValue>& values, 
    vector<pair<function<bool(vector<VariantValue>&)>, 
    function<VariantValue()>>> pairs)
{
    for (auto& clause : pairs) {
        if (clause.first(values)) {
            return clause.second();
        }
    }
    throw logic_error("unmatched arguments");
}


Voilà! We take a list of pairs where the first element is the pattern matching
lambda and the second is the function body. We call the pattern matching
lambdas in sequence and return on the first match. If no pattern is matched we raise an exception.
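For readers who want to experiment without the reflection library, here is a compact sketch of the same idea in C++17. It is a stand-in under stated assumptions, not the author's code: `std::any` replaces VariantValue, a fold expression replaces the recursive matcher_impl, and the pattern captures its elements by value, so the returned predicates may safely outlive the call that built them.

```cpp
#include <any>
#include <cassert>
#include <cstddef>
#include <functional>
#include <stdexcept>
#include <typeinfo>
#include <utility>
#include <vector>

using Args = std::vector<std::any>;   // std::any stands in for VariantValue

// Type placeholder, as in the post.
template <class T> struct Type { using type = T; };

// A Type<T> placeholder matches any argument holding a T...
template <class T>
bool match_one(const Type<T>&, const std::any& a) {
    return a.type() == typeid(T);
}
// ...while a plain value must match both the type and the value.
template <class T>
bool match_one(const T& value, const std::any& a) {
    return a.type() == typeid(T) && std::any_cast<T>(a) == value;
}

using Pattern = std::function<bool(const Args&)>;
using Body = std::function<std::any()>;

// Accepts exactly sizeof...(Ps) arguments, each matching its element.
template <class... Ps>
Pattern pattern(Ps... ps) {
    return [=](const Args& args) {
        if (args.size() != sizeof...(Ps)) return false;
        std::size_t i = 0;
        bool ok = true;
        ((ok = ok && match_one(ps, args[i++])), ...);  // left-to-right fold
        return ok;
    };
}

// Runs the body of the first clause whose pattern matches; fails fast.
std::any expr_switch(const Args& args,
                     const std::vector<std::pair<Pattern, Body>>& clauses) {
    for (const auto& clause : clauses) {
        if (clause.first(args)) return clause.second();
    }
    throw std::logic_error("unmatched arguments");
}

// factorial(0) -> 1; factorial(N) -> N * factorial(N-1).
int factorial(const Args& args) {
    return std::any_cast<int>(expr_switch(args, {
        {pattern(0), [] { return std::any(1); }},
        {pattern(Type<int>()), [&] {
            int n = std::any_cast<int>(args[0]);
            return std::any(n * factorial({n - 1}));
        }},
    }));
}
```

Overload resolution picks the `Type<T>` overload of `match_one` for placeholders because it is more specialized than the plain-value overload.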

Now that you have seen the mechanisms behind the factorial function, let's see how to call it:

std::cout <<
    "factorial of 4 = " <<
    factorial({4}).value<int>() <<
    std::endl;


It could be worse, right? Another interesting example is to define a function
to treat the dreaded Square and Rectangle example (is a square a special case of a rectangle or the other way around? [3]):

struct Square {
    int w;
};

struct Rectangle {
    int w;
    int h;
};

VariantValue area(vector<VariantValue> args)
{
    return expr_switch(args, {
       {pattern(Type<Square>()), [&]() -> VariantValue {
            return args[0].value<Square&>().w
                  *args[0].value<Square&>().w;
        }},
        {pattern(Type<Rectangle>()), [&]() -> VariantValue {
            return args[0].value<Rectangle&>().w
                  *args[0].value<Rectangle&>().h;
        }}
    });
}

Not bad. The only thing that bothers me is that in the function body I'm forced to refer to the generic vector of arguments. It would be nice if I could give a name to the type placeholder in the pattern. Then in the function body I could refer directly to this name and no further type conversions would be needed. I haven't figured out how to do this and I think it's not possible.

I have only written patterns for types and values, but it would be easy to
extend this basic framework to do more advanced pattern matching. Using the same idea of the type placeholder, we could have a checking type placeholder. For example, Between<int>(1, 10) would match an integer between 1 and 10.

References:
  1. Joe Armstrong, PhD thesis
  2. Reflection for C++11
  3. Robert C. Martin, The Liskov Substitution Principle

Monday, July 23, 2012

C++ name demangling



As you probably know, C++ compilers employ a scheme called name mangling to produce unique names that can be handled by traditional C linkers. This is needed because language features such as function overloading and templates allow you to create different functions with the same name. Because overloaded functions must have different signatures, compilers take the function signature and generate a name that encodes both the function name and its parameter types as a string.

For example, g++ encodes
int foo(double d, float a);
as
_Z3foodf

While the C linker is perfectly happy with the mangled names, most programmers do not find them easily readable. Fortunately, all mangling schemes I'm aware of allow you to decode the mangled name back into a C++ name.
Even better, some compilers provide libraries to do this for you. Of course, these libraries are not part of the C++ standard and therefore are not portable.

GNU mangling

The g++ compiler adopts the mangling scheme of the Itanium C++ ABI. This scheme is recognisable by the "_Z" prefix preceding the encoded name. G++ provides a demangling function in the header cxxabi.h:

namespace abi {
  char* __cxa_demangle(const char* mangled_name, char* output_buffer, size_t* length, int* status);
}

This function takes a string containing the mangled name and returns the demangled name in a buffer that, if you supply one, must have been allocated with malloc. If the buffer is not big enough for the demangled name, the function tries to obtain more space by calling realloc. You can also pass NULL for the second and third arguments; in this case the function allocates a buffer for you, which you must release with free. The status argument receives 0 on success or a negative error code on failure.

With this function in hand we can write a little command line utility to demangle names passed as command line arguments:

#include <iostream>
#include <cxxabi.h>
#include <stdlib.h>
using namespace std;

int main(int argc, char* argv[])
{
    if (argc < 2) {
        cerr << "usage: " << argv[0] << " <mangled-name>" << endl;
        return 1;
    }
    int status = 0;
    char* demangled = abi::__cxa_demangle(argv[1], nullptr, nullptr, &status);
    switch (status) {
        case 0:
            cout << demangled << endl;
            break;
        case -1:
            cout << "Error: memory allocation failure" << endl;
            break;
        case -2:
            cout << "Error: " << argv[1] << " is not a valid mangled name"  << endl;
            break;
        case -3:
            cout << "Error: invalid argument"  << endl;
            break;
    }
    if (demangled) {
        free(demangled);
    }
    return 0;
}

For our previous example, the output will be:

> ./demangler _Z3foodf
foo(double, float)

This function is useful even on Windows, because MinGW-compiled binaries use the same mangling scheme. It is also available for programs compiled with clang, at least when clang is using g++'s standard library.
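Inside a program it is convenient to wrap the call so that the malloc'ed result is always freed, which is the RAII discipline applied to a C API. A sketch for GCC or Clang; `demangle` and `FreeDeleter` are hypothetical helper names, and only `abi::__cxa_demangle` comes from the runtime library.

```cpp
#include <cassert>
#include <cstdlib>
#include <cxxabi.h>
#include <memory>
#include <string>
#include <typeinfo>

// Deleter that releases memory obtained from malloc.
struct FreeDeleter {
    void operator()(char* p) const { std::free(p); }
};

// Returns the demangled form of `mangled`, or `mangled` itself if
// __cxa_demangle reports an error (status != 0).
std::string demangle(const char* mangled) {
    int status = 0;
    // nullptr buffer and length: __cxa_demangle allocates the result.
    std::unique_ptr<char, FreeDeleter> result(
        abi::__cxa_demangle(mangled, nullptr, nullptr, &status));
    if (status != 0 || !result) return mangled;
    return result.get();
}
```

The same helper works on the strings returned by `typeid(...).name()`, which on Itanium ABI platforms are mangled type names (for example "i" for int).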

Microsoft mangling


Microsoft's Visual C++ compiler uses its own proprietary mangling scheme. The names it generates always start with a "?". The Windows API provides the undecoration function UnDecorateSymbolName in the imagehlp.h system header. Like __cxa_demangle, it takes an input string and an output buffer, but you are responsible for providing a buffer that is long enough to hold the undecorated name. The last argument to this function is a set of flags that allow you to extract only the information you are interested in. A Windows version of our command line utility could look like this:

#include <windows.h>
#include <imagehlp.h>
#include <iostream>
using namespace std;

int main(int argc, char* argv[])
{
    if (argc < 2) {
        cerr << "usage: " << argv[0] << " <decorated-name>" << endl;
        return 1;
    }
    char buffer[1024];
    DWORD result = UnDecorateSymbolName(argv[1], buffer, sizeof(buffer), UNDNAME_COMPLETE);

    if (result) {
        cout << buffer << endl;
    } else {
        cout << "Error: demangling failed. error code = " << GetLastError()  << endl;
    }
    return 0;
}


Epilogue: What if I don't want a function name to get mangled?

If you need a symbol to be compiled unmangled just declare it as extern "C":
extern "C" int foo(double d, float a);
or
extern "C" {
    int foo(double d, float a);
}

But remember: if you do this you're back to C, and you cannot overload names declared as extern "C".

Sunday, June 10, 2012

Integrating Lua unit tests in CxxTest

I'm currently working on a project that uses both C++ and Lua. In the beginning I looked at several options for unit testing in C++ and ended up choosing CxxTest. I've managed to integrate it well into my CMake build script, and all the tests run with a simple "make test".
Because the Lua code calls the C++ code, I started writing it later, and only now am I writing unit tests for it (yeah, I'm not a TDD fan). There are several good options for unit testing in Lua (see the Lua wiki), but there are two reasons why I didn't use them. The first is that I wanted to use only one testing framework, to have unified testing statistics. The other is that they lack certain features CxxTest has. The thing I like about CxxTest is that its test macros also capture the source text. For example, if I write:
int i = 3;
TS_ASSERT_EQUALS(i, 5);
TS_ASSERT(i == 5);
I get:
test.cpp:29: Error: Expected (i == 5), found (3 != 5)
test.cpp:30: Error: Assertion failed: i == 5
When you look at that output you already know what variables are wrong even before opening the code in your text editor. It would be nice to have this feature in Lua as well.
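CxxTest gets the expression text from the preprocessor's stringizing operator `#`. A stripped-down sketch of the mechanism; `MY_ASSERT`, `report_failure` and the `failures` vector are hypothetical, and real CxxTest routes failures through its TestTracker instead.

```cpp
#include <cassert>
#include <sstream>
#include <string>
#include <vector>

// Collected failure messages (stand-in for CxxTest's bookkeeping).
std::vector<std::string> failures;

inline void report_failure(const char* file, int line, const char* text) {
    std::ostringstream out;
    out << file << ":" << line << ": Error: Assertion failed: " << text;
    failures.push_back(out.str());
}

// #expr turns the tokens of the macro argument into a string literal,
// so the message can quote the exact source text that failed.
#define MY_ASSERT(expr) \
    ((expr) ? (void)0 : report_failure(__FILE__, __LINE__, #expr))
```

With `int i = 3;`, `MY_ASSERT(i == 5)` records a message ending in "Assertion failed: i == 5", while `MY_ASSERT(i == 3)` records nothing.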
Of course, this only works in CxxTest because it uses the stringizing operator (#) of the C preprocessor, which converts a macro argument into a string. Lua has no macros, but it has a much more powerful feature: dynamically loading code from strings (actually, there is a macro extension for Lua). It's impossible to write
i = 3
TS_ASSERT(i == 5)
But it should be possible to write:
i = 3
TS_ASSERT[[i == 5]]
Our first try is:
function TS_ASSERT(code)
    local func = assert(loadstring("return "..code))
    if not func() then
       local info = debug.getinfo(2, "Sl");
       local msg = {"Failed assertion: ", code}
       TS_FAIL(info.short_src, info.currentline, table.concat(msg))
    end
end
This function compiles the string passed as argument, which must contain an expression, and checks whether the result evaluates to true or false according to Lua's rules. If the result is true, the test passed and life goes on. If not, we find out the source location where our function was called and pass it on to TS_FAIL together with the expression that failed. For now, all we need to know about TS_FAIL is that it calls CxxTest to account for this failure.
The output for the sample above would be:
test.lua:29: Error: Failed assertion: i == 5
This is very nice, but we have no clue about the actual value of i. And there is another problem: it only works with global variables, because the code compiled with loadstring runs in the global environment and has no access to the local variables of the calling function. If we had written
local i = 3
TS_ASSERT[[i == 5]]
the test would always fail because the undefined global variable i would be nil.
Let's turn our attention to the first problem: we need to find out the values of the variables used inside a string of code. We could print all the values in _G, the global variable table, but this would clutter the output, confusing the programmer instead of helping him. An extreme solution would be to parse the code ourselves to find the names of the variables and look up their values in _G to print them. A more practical solution is to use environments and meta-tables. In Lua we can set a special environment for a function. At the same time, we can use the meta-method __index to track accesses to keys that are not present. With a specially crafted environment table we can track all the variable accesses of a function. Putting these ideas together, our second attempt looks like this:
function TS_ASSERT(code)

    local func = assert(loadstring("return "..code))

    local calling_globals = getfenv(2)
    local referenced = {}
    local env = {}

    setmetatable(env, {__index=
        function(table, name)
            local value = calling_globals[name]
            referenced[name] = value
            return value
        end
    })

    setfenv(func, env)
    if not func() then
       local info = debug.getinfo(2, "Sl");
       local msg = {"Failed assertion: ", code}

       local first = true
       for k, v in pairs(referenced) do
           if first then
               msg[#msg+1]= ". Variables: "
           else
               msg[#msg+1] = ", "
           end
           msg[#msg+1]= tostring(k).." = "..tostring(v)
           first = false
       end
       TS_FAIL(info.short_src, info.currentline, table.concat(msg))
    end
end
First we get the environment of the calling function. Then we give our newly compiled function an empty environment with an __index meta-method that will forward all variable lookups to the environment of the calling function. So now the output for the sample above would be:
test.lua:29: Error: Failed assertion: i == 5. Variables: i = 3
Much better, but it still won't work for local variables. Unlike globals, local variables are not stored in a table, presumably for performance reasons. For this we must, again, resort to the debug library. The function debug.getlocal takes a stack level and an index as arguments and returns the name of the corresponding local variable, followed by its value. All we have to do is copy these values into our fake environment:
...
    local calling_globals = getfenv(2)
    local calling_locals  = {}
    local i = 1
    while true do
        local name, value = debug.getlocal(2, i)
        if not name then break end
        calling_locals[name] = value
        i = i + 1
    end

    local notpresent = {}
    setmetatable(calling_locals, { __index=function() return notpresent end})

    local referenced = {}
    local env = {}

    setmetatable(env, {__index=
        function(table, name)
            local value = calling_locals[name]
            if value == notpresent then value = calling_globals[name] end
            referenced[name] = value
            return value
        end
    })
...
We collect the local variables in a table and make the intercepting function look in this table before searching the global environment, respecting Lua's rules for variable lookup. The only complication is that a table lookup returns nil both if the entry doesn't exist and if its value really is nil. To make the distinction, lookups of missing locals return the empty table "notpresent". The call to getlocal might be slow, but we don't want to write the fastest unit test in the world, only the most helpful.
As a last addition we can make the following enhancement to TS_ASSERT: if the argument is not a string we simply see if it evaluates to false.
...
    if (type(code) ~= "string") then
        if not code then
            local info = debug.getinfo(2, "Sl");
            local msg = {"Failed assertion"}
            TS_FAIL(info.short_src, info.currentline, table.concat(msg))
        end
        return
    end
...
Now we can also write:
local file = io.open("/dev/null", "w")
TS_ASSERT(file)
This assertion will fail if the io.open call is unsuccessful.
The last thing left to explain is the TS_FAIL function. We have used it passing a custom source location but it would also be nice to use it like the C++ macro to fail a test:
local file, msg = io.open("/dev/null", "w")
if not file then TS_FAIL(msg) end
All we have to do is to write a small binding for the CxxTest API function that the macro calls internally:
int ts_fail(lua_State* L) {
    const int nargs = lua_gettop(L);

    if (nargs == 3) {
        CxxTest::TestTracker::tracker().failedTest(
            luaL_checkstring(L, -3),
            luaL_checkint(L, -2),
            luaL_checkstring(L, -1) );
    } else if (nargs >= 1) {
        lua_Debug debug;
        lua_getstack(L, 1, &debug);
        lua_getinfo(L, "Sl", &debug);
        CxxTest::TestTracker::tracker().failedTest(
            debug.short_src,
            debug.currentline,
            luaL_checkstring(L, -1) );
    } else {
        luaL_error(L, "TS_FAIL called with an illegal number of arguments");
    }
    return 0;
}
If this function is called with three arguments, we assume that the first two are the location (file and line) and the third is the message. If only one argument is supplied, we assume that it is the error message and find the calling location using the C API of the debug library. In both cases the location and the message are passed to the failedTest method of CxxTest's TestTracker singleton.
It's arguable whether the names of the variables in the testing output will help you fix your errors faster, but I think it is interesting to see that it can be done. There are many more features we could add to this framework; one, for example, is measuring the code coverage of the tested code with Lua's debug library. But for now this will already help me a lot to write unit tests in Lua.

Saturday, June 9, 2012

Lua: Generic call with variadic templates

If you want to call a Lua function from C/C++ you have to follow a simple protocol: push the function on the stack, push the arguments, make the call and, finally, pop the return values. The exact sequence of steps varies according to the number and types of arguments and return values, but in essence it is always the same. In other words, it begs to be automated.

Programming in Lua, in listing 25.4, presents a solution à la printf/scanf. While it works well for numbers and strings, it shares their drawbacks: the same potential for memory access disasters, and it is not easily extensible to new types.

Unfortunately, this used to be the only way to write a generic call function in C++ without writing overloads for 0 to N arguments. However, with the variadic templates feature of the new C++11 standard we can write a generic call function that is both type-safe and user-extensible.
function add(a, b)
    return a+b
end
For example, given the Lua function above, we would like to call it like this:
int result = callFunc<int>(L, "add", 1, 2);
Our first implementation attempt looks like this:
// type-specific helper
template<typename T>
struct LuaValue;

// base case: no arguments left to push; declared first so that the
// recursive template below finds it by ordinary lookup
inline int pushArgs(lua_State*) {
    return 0;
}

// recursive function template to push a variable number of arguments
template<class H, class... T>
int pushArgs(lua_State* L, const H& h, const T&... t) {
    int size = LuaValue<H>::size();
    LuaValue<H>::pushValue(L, h);
    return size + pushArgs(L, t...);
}

// the main function
template<class R, class... Args>
R callFunc(lua_State* L, const std::string& name, const Args... args) {
    lua_getglobal(L, name.c_str());
    luaL_checktype(L, -1, LUA_TFUNCTION);

    const int size = pushArgs(L, args...);

    if (lua_pcall(L, size, LuaValue<R>::size(), 0) != 0) {
        lua_pushfstring(L, "Error running function %s:\n", name.c_str());
        lua_insert(L, -2);
        lua_concat(L, 2);
        lua_error(L);
    }
    R ret = LuaValue<R>::getStackValue(L, -1);
    lua_pop(L, LuaValue<R>::size());
    return ret;
}
The procedure is simple: first we load the function onto the stack and verify that it really is a function. Then we use recursion over the argument pack to push all arguments onto the stack. The pushArgs function returns the number of values that were pushed, and we pass this value to lua_pcall. At the end we must restore the stack to its previous state and return the value returned by Lua. For each type we must define a specialization of the helper struct LuaValue. To call the function "add" as above, we must add a specialization for int:
template<> struct LuaValue<int> {
    static int getStackValue(lua_State* L, int pos) {
        return luaL_checkinteger(L, pos);
    }
    static void pushValue(lua_State* L, int value) {
        lua_pushinteger(L, value);
    }
    static int size() { return 1; }
};
A specialization of LuaValue must define three functions: one to push the value onto the stack, one to retrieve a value from the stack, and one that returns the number of stack positions the value takes.
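The two ingredients, a trait specialized per type and recursion over the argument pack, can be tried without a Lua installation. In this sketch a hypothetical FakeState that stores every value as text stands in for lua_State, and FakeValue plays the role of LuaValue (only the push side is shown).

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for lua_State: the "stack" stores every value as text.
struct FakeState {
    std::vector<std::string> stack;
};

// Same shape as LuaValue: one specialization per supported type.
template<typename T>
struct FakeValue;

template<> struct FakeValue<int> {
    static void pushValue(FakeState& S, int value) { S.stack.push_back(std::to_string(value)); }
    static int size() { return 1; }
};

template<> struct FakeValue<std::string> {
    static void pushValue(FakeState& S, const std::string& value) { S.stack.push_back(value); }
    static int size() { return 1; }
};

// Base case, declared before the recursive template so ordinary lookup finds it.
inline int pushArgs(FakeState&) { return 0; }

// Pushes the head argument, then recurses on the rest; returns the slots used.
template<class H, class... T>
int pushArgs(FakeState& S, const H& h, const T&... t) {
    FakeValue<H>::pushValue(S, h);
    return FakeValue<H>::size() + pushArgs(S, t...);
}
```

A call such as pushArgs(S, 2.5) is rejected at compile time because FakeValue<double> has no definition, which is exactly the type safety the printf-style approach lacks.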

So far so good, but let's say that we want to call a function with no return values:
// Lua
function log(severity, message)
    --do something with the arguments
    return --return nothing
end
// C++
callFunc<void>(L, "log", WARNING, "corrupting your filesystem...");
This code fails to compile because callFunc declares a local variable of type R, which in this case is void. We needed that temporary because the returned values have to be popped from the stack after we retrieve them but before we actually return.
The solution that comes to mind is to use RAII to clear the stack:
struct popper {
    popper(lua_State* L, int size = 1)
        : m_L(L)
        , m_size(size) {}
    ~popper() {
        if (m_size > 0) lua_pop(m_L, m_size);
    }
private:
    lua_State* m_L;
    int m_size;
};
//  ... end of callFunc:
    popper p(L, LuaValue<R>::size());
    return LuaValue<R>::getStackValue(L, -1);
}
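The trick relies on the order in which a return statement executes: the return value is initialized first, and only then are local objects such as p destroyed. Here is a minimal sketch of that order, using a std::vector<int> as a hypothetical stand-in for the Lua stack and a StackPopper modelled on popper:

```cpp
#include <vector>

// Hypothetical stand-in for the Lua stack.
typedef std::vector<int> IntStack;

// Same idea as popper: remove m_size values when the scope ends.
struct StackPopper {
    StackPopper(IntStack& s, int size) : m_s(s), m_size(size) {}
    ~StackPopper() {
        for (int i = 0; i < m_size; ++i) m_s.pop_back();
    }
private:
    IntStack& m_s;
    int m_size;
};

// The return value is copied from the top of the stack before the
// destructor of p runs, so the value is read while it is still there.
int takeTop(IntStack& s) {
    StackPopper p(s, 1);
    return s.back();
}
```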
With this little trick we will be able to return a "void value". As you have surely guessed, one piece is still missing: the LuaValue for void:
template<> struct LuaValue<void> {
   static void getStackValue(lua_State* L, int pos) {}
   static int size() { return 0; }
};
The implementation for pushValue was intentionally left out, so if you try to use void as argument type your code won't compile.
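This works because of a C++ rule: a function whose return type is void may return an expression of type void. The sketch below reproduces the shape of the final callFunc on a hypothetical std::vector<int> stack, with Value and fetch as illustrative stand-ins for LuaValue and callFunc; one template body serves both int and void.

```cpp
#include <vector>

typedef std::vector<int> IntStack;  // hypothetical stand-in for the Lua stack

template<typename T> struct Value;

template<> struct Value<int> {
    static int get(const IntStack& s) { return s.back(); }
    static int size() { return 1; }
};

// No value to read: get() returns void and the type occupies no stack slots.
template<> struct Value<void> {
    static void get(const IntStack&) {}
    static int size() { return 0; }
};

// Pops n values when it goes out of scope (same role as popper).
struct Popper {
    Popper(IntStack& s, int n) : m_s(s), m_n(n) {}
    ~Popper() { for (int i = 0; i < m_n; ++i) m_s.pop_back(); }
private:
    IntStack& m_s;
    int m_n;
};

template<class R>
R fetch(IntStack& s) {
    Popper p(s, Value<R>::size());  // pops after the return value is built
    return Value<R>::get(s);        // legal even when R is void
}
```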
Before we finish, let's consider one more case:
function div(a , b)
    return a/b, a%b
end
In Lua, a function can return more than one value, so how can our generic C++ call function handle that? We return the values as a std::tuple! Along with variadic templates, C++11 introduced this new container in the standard library. The only change required is a partial specialization of LuaValue for std::tuple:
 
#include <tuple>

template<class... T> struct LuaValue<std::tuple<T...> > {
    template<class Tuple, int I, int N>
    struct helper {
        static void pushValue(lua_State* L, const Tuple& tuple) {
            LuaValue< typename std::tuple_element< I, Tuple >::type >
                ::pushValue(L, std::get<I>(tuple));
            helper<Tuple, I+1, N>::pushValue(L, tuple);
        }
        static void getValue(lua_State* L, Tuple& tuple, int pos) {
            std::get<I>(tuple) =
                LuaValue< typename std::tuple_element<I, Tuple>::type>
                    ::getStackValue(L, pos);
            helper<Tuple, I+1, N>::getValue(L, tuple, ++pos);
        }
    };

    template<class Tuple, int N>
    struct helper<Tuple, N, N> {
        static void pushValue(lua_State*, const Tuple&) {}
        static void getValue(lua_State*, Tuple&, int) {}
    };

    typedef std::tuple<T...> TType;
    static TType getStackValue(lua_State* L, int pos) {
        TType ret;
        helper<TType, 0, std::tuple_size<TType>::value>
            ::getValue(L, ret, pos-size()+1);
        return ret;
    }
    static void pushValue(lua_State* L, const TType& tuple) {
        helper<TType, 0, std::tuple_size<TType>::value>
            ::pushValue(L, tuple);
    }
    static int size() { return std::tuple_size<TType>::value; }
};
OK, I admit, this was quite a bit of template trickery, but in the end it all boils down to this: the pushValue function pushes every value in the tuple onto the stack, and the getStackValue function retrieves a value from the stack for each position in the tuple. So now we can call the div function like this:
int result, remainder;
std::tie(result, remainder) = callFunc<std::tuple<int, int> >(L, "div", 42, 3);
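The explicit template argument is needed because R cannot be deduced from the call's arguments. The least obvious part of the tuple specialization is the index arithmetic: when a call leaves N results on the stack, the last one sits at -1 and the first at -N, which is why getStackValue starts at pos-size()+1. Here is a minimal sketch of the reading half, restricted to int elements and run against a hypothetical IndexedStack that mimics Lua's negative indices; TupleReader and getTuple are illustrative names.

```cpp
#include <tuple>
#include <vector>

// Hypothetical stand-in for the Lua stack: -1 is the top, positive indices are 1-based.
struct IndexedStack {
    std::vector<int> v;
    int at(int pos) const {
        return pos < 0 ? v[static_cast<int>(v.size()) + pos] : v[pos - 1];
    }
};

// Reads tuple element I from stack position pos, then recurses on I + 1.
template<class Tuple, int I, int N>
struct TupleReader {
    static void read(const IndexedStack& s, Tuple& t, int pos) {
        std::get<I>(t) = s.at(pos);
        TupleReader<Tuple, I + 1, N>::read(s, t, pos + 1);
    }
};

// Recursion ends when I reaches the tuple size.
template<class Tuple, int N>
struct TupleReader<Tuple, N, N> {
    static void read(const IndexedStack&, Tuple&, int) {}
};

// pos is the index of the last result (-1 here); the first is at pos - n + 1.
template<class... T>
std::tuple<T...> getTuple(const IndexedStack& s, int pos) {
    typedef std::tuple<T...> TType;
    const int n = static_cast<int>(sizeof...(T));
    TType ret;
    TupleReader<TType, 0, sizeof...(T)>::read(s, ret, pos - n + 1);
    return ret;
}
```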
With this scaffolding in place, it is easy to add new types to pass as parameters or return values. The source can be found here. There are also specializations of LuaValue for collections like std::vector and std::map.