Extending Numba - FOSDEM 2019

February 15th, 2020

In the first weekend of February, I faithfully visit FOSDEM. This year was no different. Every year it gets a bit bigger, and this year was no exception, to the extent that I'm starting to feel it's all becoming a bit too much. Still, it's an amazing conference, and over the years it has become really professional while remaining volunteer-led.

Last year I gave a talk in the Python devroom titled «Extending Numba». In my talk I explained how the architecture of Numba (a Python accelerator) allows you to extend it. At Luceda we've extended Numba to solve some of its shortcomings. Docs on this functionality were sparse though, so I thought it would be useful to share some of the things I've learned. The examples below are also available on the FOSDEM archive (see the link above). I've only added some inline comments to make them a bit clearer.

One of the extension points Numba offers is rewriting its intermediate representation (IR) by registering a subclass of Rewrite.

You can register a rewrite at two stages:

  • 'before-inference' (no type information)
  • 'after-inference' (has type information)

In the meaningless example below, we replace the value of meaningful_var with 42:

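import numba as nb
# In Numba 0.49 and later these modules live under numba.core
# (numba.core.rewrites and numba.core.ir).
from numba.rewrites import Rewrite, register_rewrite
from numba import ir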

@register_rewrite('before-inference')
class MyRewrite(Rewrite):
    
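    def match(self, func_ir, block, typemap, calltypes):
        # collect every assignment to a variable named 'meaningful_var'
        assigns = [
            assign
            for assign in block.find_insts(ir.Assign)
            if assign.target.name == 'meaningful_var'
        ]

        # the hasattr guard makes this rewrite match only once per
        # compiled function; without it the rewritten block would
        # simply match again
        if len(assigns) > 0 and not hasattr(self, 'assigns'):
            self.assigns = assigns
            self.block = block
            return True
        return False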

    def apply(self):
        for assign in self.assigns:
            # replace with 42
            assign.value = ir.Const(42, assign.value.loc)

        return self.block

You can now execute the code below and the value of meaningful_var will be replaced by 42. Note that the integer type of 42 is correctly inferred by Numba because it's replaced before any type inference is done.

        
@nb.jit(nopython=True)
def test(x):
    meaningful_var = '#to be replaced#'
    y = meaningful_var + 1
    return y
    
print(test(1))

Another important way to extend Numba is by adding new types and data models. The examples below show how to make our custom 'MyPoint' object available for use in Numba. It's similar to the interval example in the Numba documentation.

# this is the type we want to use in Numba:

class MyPoint(object):
    def __init__(self, x, y):
        self.x = x
        self.y = y

The first step is to create an object to represent the type of MyPoint in Numba.

import numba
from numba.extending import type_callable, typeof_impl
from numba import types

class MyPointType(numba.types.Type):
    # A custom type to represent a point
    # used during inference
    def __init__(self):
        super(MyPointType, self).__init__(name='Point')
        
# use type_callable decorator to annotate the types of the MyPoint callable.
@type_callable(MyPoint)
def type_MyPoint(context):
    # MyPoint callable accepts two arguments
    def typer(x, y):
        # calling MyPoint returns a point
        return MyPointType()
    return typer

At this point Numba can infer the type of the MyPoint callable, but it has no idea about its behavior. In our case it creates a simple object with two attributes. A common way to represent this natively is a struct. We'll do that next:


from numba.extending import register_model, models
@register_model(MyPointType)
class MyPointModel(models.StructModel):
    """ A Struct - like model to store the x,y 
    coordinates of a point. 
    """
    def __init__(self, dmm, fe_type):
        members = [
          ('x', types.int64),
          ('y', types.int64),
        ]
        assert isinstance(fe_type, MyPointType)
        models.StructModel.__init__(self, dmm, 
                                   fe_type, members)

With this, Numba knows that the internal data structure of MyPointType is a struct(-like) object with two integer attributes. But it doesn't yet know how to create it. We can tell Numba how to do so with the lower_builtin decorator.

from numba.extending import lower_builtin
# llvm codegen utils (numba.core.cgutils in Numba 0.49 and later)
from numba import cgutils

@lower_builtin(MyPoint, types.Integer, types.Integer)
def impl_point(context, builder, sig, args):
    typ = sig.return_type
    assert isinstance(typ, MyPointType)
    x, y = args
    # create_struct_proxy is a helper to easily generate llvm code 
    # to create a struct 
    point = cgutils.create_struct_proxy(typ)(context, builder)
    point.x = x 
    point.y = y
    return point._getvalue()

At this point we're ready to execute our first code with Numba:

@numba.njit()
def test1():
    # you can create a point but that's about it
    # you can't really use it yet
    pt = MyPoint(1, 2)
    
    # you can make a list though:
    lst = [
        MyPoint(1, 2),
        MyPoint(3, 2)
    ]
    
    return len(lst)

assert test1()  == 2

In order to do something more useful, we at least have to implement the attribute retrieval:


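# In Numba 0.49 and later these two modules are
# numba.core.typing.templates and numba.core.imputils.
from numba.typing.templates import AttributeTemplate
from numba.extending import infer_getattr, lower_getattr
from numba.targets.imputils import impl_ret_borrowed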

@infer_getattr
class MyPointAttribute(AttributeTemplate):
    """ Templates are used for inference, though typically 
    you'll want to use the higher level constructs.
    """
    key = MyPointType

    def generic_resolve(self, typ, attr):
        if attr in ['x', 'y']:
            return numba.types.int64

@lower_getattr(MyPointType, 'x')
def struct_getattr_x_impl(context, builder, typ, val):
    # wrap the raw struct value so its members can be read by name
    point = cgutils.create_struct_proxy(typ)(context, builder, value=val)
    return impl_ret_borrowed(context, builder, numba.types.int64, point.x)

@lower_getattr(MyPointType, 'y')
def struct_getattr_y_impl(context, builder, typ, val):
    point = cgutils.create_struct_proxy(typ)(context, builder, value=val)
    return impl_ret_borrowed(context, builder, numba.types.int64, point.y)

The first step is to tell Numba about the types of the getattr operation. In this case we know x and y are of type int64; as this is quite a common type, Numba has a built-in type to represent it. After type inference, there's again a lowering step to generate the actual code. Again we use Numba's struct proxy to make it easy to generate code for accessing the members of the struct.

Let's test that we can indeed use the attributes now:


@numba.njit()
def test2():
    pt = MyPoint(1, 2)
    return pt.x + pt.y

assert test2() == 1 + 2

Nice, that works! But that was quite some code just to add attributes, while most of the information was already available. Luckily Numba provides a helper for this common case.

from numba.extending import make_attribute_wrapper

# takes care of the type inference and the lowering at 
# the same time. 
make_attribute_wrapper(MyPointType, 'x', 'x')
make_attribute_wrapper(MyPointType, 'y', 'y')

We're not quite done yet. There's still something that doesn't work: if you execute the following code, Numba will raise an error:

@numba.njit()
def test3():
    return MyPoint(1, 2)

# returning a point raises a TypeError
try:
    point = test3()
except TypeError as e:
    print("Error returning MyPoint: {}".format(e))

We get this error because Numba doesn't know how to turn the internal struct into a Python MyPoint object. This operation is called boxing. Like before, Numba provides a decorator to tell it how to box the MyPoint struct into the right Python object:

from numba.extending import box

@box(MyPointType)
def box_point(typ, val, c):
    point = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
    x_obj = c.pyapi.long_from_signed_int(point.x)
    y_obj = c.pyapi.long_from_signed_int(point.y)
    class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(MyPoint))
    res = c.pyapi.call_function_objargs(class_obj, (x_obj, y_obj))
    c.pyapi.decref(x_obj)
    c.pyapi.decref(y_obj)
    c.pyapi.decref(class_obj)
    return res

This allows us to return a point back to Python:

@numba.njit()
def test4():
    return MyPoint(1, 2)

pt = test4()
assert (pt.x, pt.y) == (1, 2)

Similarly, you can implement an unboxing operation to pass Python objects as arguments to Numba functions.