NPArrayUFuncOverloadMixin

class overload_numpy.NPArrayUFuncOverloadMixin

Bases: object

Mixin for adding __array_ufunc__ to a class.

This mixin adds the method __array_ufunc__. Subclasses must define a class variable NP_OVERLOADS holding a NumPyOverloader instance.
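
To make the protocol concrete, here is a minimal sketch of what an __array_ufunc__ hook does under NEP 13. It is not the library's implementation. The class UFuncOverloadSketch, its plain-dict registry OVERLOADS, and the Wrap example are hypothetical, and the sketch handles only direct calls such as np.add(...) (method "__call__").

```python
import numpy as np


class UFuncOverloadSketch:
    """Hypothetical stand-in for an __array_ufunc__ mixin (not overload_numpy's code).

    Subclasses set a class-level dict ``OVERLOADS`` mapping a ufunc to a
    function that implements it for that subclass.
    """

    OVERLOADS: dict = {}

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # NumPy calls this hook (NEP 13) whenever an input or output defines it.
        impl = self.OVERLOADS.get(ufunc)
        if impl is None or method != "__call__":
            # Declining lets NumPy raise a TypeError instead of coercing self.
            return NotImplemented
        return impl(*inputs, **kwargs)


class Wrap(UFuncOverloadSketch):
    """Hypothetical one-array wrapper."""

    def __init__(self, x):
        self.x = np.asarray(x)

    def __repr__(self):
        return f"Wrap({self.x!r})"


Wrap.OVERLOADS = {np.add: lambda a, b: Wrap(np.add(a.x, b.x))}

w = Wrap([0, 1, 2])
print(np.add(w, w))  # Wrap(array([0, 2, 4]))
```

Because unregistered ufuncs and ufunc methods return NotImplemented, NumPy raises a TypeError for them rather than silently treating the wrapper as an object array.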

Notes

When compiled, this class is a mypyc trait() and permits interpreted subclasses (see https://mypyc.readthedocs.io/en/latest/native_classes.html#inheritance).

Examples

First, some imports:

>>> from dataclasses import dataclass, fields
>>> from typing import ClassVar
>>> import numpy as np
>>> from overload_numpy import NumPyOverloader, NPArrayUFuncOverloadMixin

Now we can define a NumPyOverloader instance:

>>> W_FUNCS = NumPyOverloader()

The overloads apply to an array-wrapping class. Let’s define one:

>>> @dataclass
... class Wrap1D(NPArrayUFuncOverloadMixin):
...     '''A simple array wrapper.'''
...     x: np.ndarray
...     NP_OVERLOADS: ClassVar[NumPyOverloader] = W_FUNCS
>>> w1d = Wrap1D(np.arange(3))

Now a numpy.ufunc can be overloaded and registered for Wrap1D:

>>> @W_FUNCS.implements(np.add, Wrap1D)
... def add(w1, w2):
...     return Wrap1D(np.add(w1.x, w2.x))

Time to check this works:

>>> np.add(w1d, w1d)
Wrap1D(x=array([0, 2, 4]))

ufuncs also have a number of methods: ‘at’, ‘accumulate’, etc. The function dispatch mechanism in NEP 13 says that “If one of the input or output arguments implements __array_ufunc__, it is executed instead of the ufunc.” So far, however, only direct calls to numpy.add are overloaded, not any of its ufunc methods, so calling one fails:

>>> try: np.add.accumulate(w1d)
... except Exception: print("failed")
failed
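
The failure follows from how NEP 13 routes ufunc methods: NumPy calls the same __array_ufunc__ hook and passes the method name as a string. The hypothetical MethodRecorder below declines every call and only logs what it receives. It is a sketch for inspecting the protocol, not part of overload_numpy.

```python
import numpy as np


class MethodRecorder:
    """Hypothetical class that records which ufunc method NumPy dispatches."""

    def __init__(self):
        self.calls = []

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # ``method`` is one of "__call__", "reduce", "accumulate",
        # "reduceat", "outer", or "at" (NEP 13).
        self.calls.append((ufunc.__name__, method))
        return NotImplemented  # decline, so NumPy raises TypeError


rec = MethodRecorder()
for attempt in (lambda: np.add(rec, 1), lambda: np.add.accumulate(rec)):
    try:
        attempt()
    except TypeError:
        pass  # expected: the hook declined to implement the operation

print(rec.calls)  # [('add', '__call__'), ('add', 'accumulate')]
```

An overload that only handles "__call__" therefore has nothing to run when the method is "accumulate".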

ufunc method overloads can be registered on the wrapped add implementation:

>>> @add.register('accumulate')
... def add_accumulate(w1):
...     return Wrap1D(np.add.accumulate(w1.x))
>>> np.add.accumulate(w1d)
Wrap1D(x=array([0, 1, 3]))
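
One way to organize per-method implementations is a small wrapper keyed by method name. The sketch below is an assumption about the general shape, not overload_numpy's internals. MethodDispatcher is hypothetical and operates on plain arrays to keep it short.

```python
import numpy as np


class MethodDispatcher:
    """Hypothetical wrapper holding one callable per ufunc method name."""

    def __init__(self, call_impl):
        # The plain call (``ufunc(...)``) is registered under "__call__".
        self.impls = {"__call__": call_impl}

    def register(self, method):
        """Decorator registering ``func`` for the ufunc method ``method``."""
        def decorator(func):
            self.impls[method] = func
            return func
        return decorator

    def __call__(self, method, *args, **kwargs):
        impl = self.impls.get(method)
        # Unregistered methods decline, mirroring the NEP 13 convention.
        return NotImplemented if impl is None else impl(*args, **kwargs)


add = MethodDispatcher(lambda a, b: np.add(a, b))


@add.register("accumulate")
def add_accumulate(a):
    return np.add.accumulate(a)


x = np.arange(3)
print(add("__call__", x, x))  # [0 2 4]
print(add("accumulate", x))   # [0 1 3]
print(add("reduce", x))       # NotImplemented
```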

What if we defined a subclass of Wrap1D?

>>> @dataclass
... class Wrap2D(Wrap1D):
...     '''A simple 2-array wrapper.'''
...     y: np.ndarray

The numpy.add overload registered on Wrap1D will not work correctly for Wrap2D: it builds a Wrap1D from x alone and ignores y. However, NumPyOverloader supports single dispatch on the calling type for the overload, so overloads can be customized for subclasses.

>>> @W_FUNCS.implements(np.add, Wrap2D)
... def add(w1, w2):
...     print("using Wrap2D implementation...")
...     return Wrap2D(np.add(w1.x, w2.x), np.add(w1.y, w2.y))

Checking this works:

>>> w2d = Wrap2D(np.arange(3), np.arange(3, 6))
>>> np.add(w2d, w2d)
using Wrap2D implementation...
Wrap2D(x=array([0, 2, 4]), y=array([ 6,  8, 10]))
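
Single dispatch on the calling type resembles the standard library's functools.singledispatch, which selects the implementation registered for the nearest class in the argument's MRO. The sketch below shows that rule on its own with hypothetical Base, Child, and GrandChild dataclasses. Treating NumPyOverloader as equivalent is an analogy, not a statement about its internals.

```python
from dataclasses import dataclass
from functools import singledispatch


@dataclass
class Base:
    x: int


@dataclass
class Child(Base):
    y: int


@dataclass
class GrandChild(Child):
    z: int


@singledispatch
def describe(obj):
    # Fallback used when no registered class appears in the MRO.
    return "no implementation"


@describe.register
def _(obj: Base):
    return "Base implementation"


@describe.register
def _(obj: Child):
    return "Child implementation"


print(describe(Base(1)))              # Base implementation
print(describe(Child(1, 2)))          # Child implementation
print(describe(GrandChild(1, 2, 3)))  # Child implementation (nearest registered class)
```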

Great! But rather than defining a new implementation for each subclass, let’s see how we could write a more broadly applicable overload:

>>> @W_FUNCS.implements(np.add, Wrap1D)  # replaces the earlier Wrap1D overload
... @W_FUNCS.implements(np.add, Wrap2D)  # replaces the earlier Wrap2D overload
... def add_general(w1, w2):
...     WT = type(w1)
...     return WT(*(np.add(getattr(w1, f.name), getattr(w2, f.name))
...                 for f in fields(WT)))

Checking this works:

>>> np.add(w2d, w2d)
Wrap2D(x=array([0, 2, 4]), y=array([ 6,  8, 10]))

A further subclass needs no new overload:

>>> @dataclass
... class Wrap3D(Wrap2D):
...     '''A simple 3-array wrapper.'''
...     z: np.ndarray
>>> w3d = Wrap3D(np.arange(2), np.arange(3, 5), np.arange(6, 8))
>>> np.add(w3d, w3d)
Wrap3D(x=array([0, 2]), y=array([6, 8]), z=array([12, 14]))
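
Once dispatch reaches add_general, rebuilding a Wrap3D positionally is safe because dataclasses.fields() lists inherited fields first, in definition order, which is the positional order the generated __init__ expects. The hypothetical dataclasses A, B, and C below check that ordering without overload_numpy.

```python
from dataclasses import dataclass, fields

import numpy as np


@dataclass
class A:
    x: np.ndarray


@dataclass
class B(A):
    y: np.ndarray


@dataclass
class C(B):
    z: np.ndarray


# Inherited fields come first, matching the __init__ argument order.
print([f.name for f in fields(C)])  # ['x', 'y', 'z']

c = C(np.arange(2), np.arange(3, 5), np.arange(6, 8))
# Same field-wise pattern as add_general, applied directly.
doubled = C(*(np.add(getattr(c, f.name), getattr(c, f.name)) for f in fields(C)))
print(doubled.z)  # [12 14]
```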

In the previous examples, we wrote an implementation for each NumPy function individually. Overloading the full set of NumPy functions this way would take a long time.

Wouldn’t it be better to write far fewer implementations, each covering a group of related NumPy functions?

>>> add_funcs = {np.add, np.subtract}
>>> @W_FUNCS.assists(add_funcs, types=Wrap1D, dispatch_on=Wrap1D)
... def add_assists(cls, func, w1, w2, *args, **kwargs):
...     return cls(*(func(getattr(w1, f.name), getattr(w2, f.name), *args, **kwargs)
...                  for f in fields(cls)))

Checking this works:

>>> np.subtract(w2d, w2d)
Wrap2D(x=array([0, 0, 0]), y=array([0, 0, 0]))
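
Conceptually, a grouped registration maps every function in the group to one shared implementation that receives the concrete function as an argument, which is why add_assists takes func. The sketch below shows that idea with a hypothetical module-level REGISTRY, decorator assists_sketch, and dispatch helper, using dicts of arrays in place of wrapper instances. It is not how NumPyOverloader.assists is implemented.

```python
import numpy as np

# Hypothetical registry mapping each function to a shared implementation.
REGISTRY = {}


def assists_sketch(funcs):
    """Register one implementation for every function in ``funcs``."""
    def decorator(impl):
        for func in funcs:
            REGISTRY[func] = impl
        return impl
    return decorator


def dispatch(func, *args, **kwargs):
    """Look up the shared implementation and tell it which function to run."""
    impl = REGISTRY.get(func)
    if impl is None:
        raise TypeError(f"no implementation registered for {func.__name__}")
    return impl(func, *args, **kwargs)


@assists_sketch({np.add, np.subtract})
def binary_fieldwise(func, a, b, *args, **kwargs):
    # ``a`` and ``b`` are dicts of arrays standing in for wrapper fields.
    return {name: func(a[name], b[name], *args, **kwargs) for name in a}


p = {"x": np.arange(3), "y": np.arange(3, 6)}
print(dispatch(np.subtract, p, p))  # {'x': array([0, 0, 0]), 'y': array([0, 0, 0])}
print(dispatch(np.add, p, p)["y"])  # [ 6  8 10]
```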

We would also like to implement the accumulate method for all the add_funcs overloads:

>>> @add_assists.register("accumulate")
... def add_accumulate_assists(cls, func, w1, *args, **kwargs):
...     return cls(*(func(getattr(w1, f.name), *args, **kwargs)
...                  for f in fields(cls)))
>>> np.subtract.accumulate(w2d)
Wrap2D(x=array([ 0, -1, -3]), y=array([ 3, -1, -6]))
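
The expected output follows from the definition of accumulate: for subtract, each element is the previous result minus the next input, so y = [3, 4, 5] gives [3, 3 - 4, -1 - 5] = [3, -1, -6]. A plain-Python version of that recurrence (the helper subtract_accumulate is hypothetical) can be checked against NumPy:

```python
import numpy as np


def subtract_accumulate(values):
    """Running difference: r[0] = v[0], r[i] = r[i - 1] - v[i]."""
    out = [values[0]]
    for v in values[1:]:
        out.append(out[-1] - v)
    return out


y = [3, 4, 5]
print(subtract_accumulate(y))                        # [3, -1, -6]
print(np.subtract.accumulate(np.array(y)).tolist())  # [3, -1, -6]
```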

Attributes:

NP_OVERLOADS : NumPyOverloader

    A class attribute holding an instance of NumPyOverloader.