NPArrayUFuncOverloadMixin

class overload_numpy.NPArrayUFuncOverloadMixin
Bases: object
Mixin for adding __array_ufunc__ to a class.

This mixin adds the method __array_ufunc__. Subclasses must define a class variable NP_OVERLOADS.

Notes
When compiled, this class is a mypyc trait() and permits interpreted subclasses (see https://mypyc.readthedocs.io/en/latest/native_classes.html#inheritance).

Examples
First, some imports:

>>> from dataclasses import dataclass, fields
>>> from typing import ClassVar
>>> import numpy as np
>>> from overload_numpy import NumPyOverloader, NPArrayUFuncOverloadMixin
Now we can define a NumPyOverloader instance:

>>> W_FUNCS = NumPyOverloader()
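Conceptually, an overloader is a registry from a (function, type) pair to an implementation, filled in by a decorator. The toy below is a hypothetical sketch (ToyOverloader, Box and add_boxes are illustrative names, not overload_numpy API): it does exact-type lookup only, returns the decorated function unchanged, and uses operator.add as a stand-in for np.add so that it runs without NumPy.

```python
import operator
from typing import Any, Callable, Dict, Tuple


class ToyOverloader:
    """Hypothetical registry: (function, type) -> implementation."""

    def __init__(self) -> None:
        self._impls: Dict[Tuple[Callable[..., Any], type], Callable[..., Any]] = {}

    def implements(self, func: Callable[..., Any], typ: type):
        """Return a decorator that registers an implementation of func for typ."""
        def decorator(impl: Callable[..., Any]) -> Callable[..., Any]:
            self._impls[(func, typ)] = impl
            return impl  # the decorated function is returned unchanged
        return decorator

    def lookup(self, func: Callable[..., Any], typ: type) -> Callable[..., Any]:
        # Exact-type lookup only; raises KeyError if nothing is registered.
        return self._impls[(func, typ)]


class Box:
    """Hypothetical wrapper holding a single int."""

    def __init__(self, x: int) -> None:
        self.x = x


TOY = ToyOverloader()


@TOY.implements(operator.add, Box)  # operator.add stands in for np.add
def add_boxes(b1: Box, b2: Box) -> Box:
    return Box(b1.x + b2.x)


impl = TOY.lookup(operator.add, Box)
print(impl(Box(1), Box(2)).x)  # 3
```

The real NumPyOverloader goes further, handling subclasses and ufunc methods, as the examples that follow demonstrate.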
The overloads apply to an array-wrapping class. Let’s define one:

>>> @dataclass
... class Wrap1D(NPArrayUFuncOverloadMixin):
...     '''A simple array wrapper.'''
...     x: np.ndarray
...     NP_OVERLOADS: ClassVar[NumPyOverloader] = W_FUNCS

>>> w1d = Wrap1D(np.arange(3))
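Annotating NP_OVERLOADS with ClassVar matters here: dataclasses exclude ClassVar pseudo-fields from the generated __init__, __repr__ and dataclasses.fields(), which is why Wrap1D(np.arange(3)) needs only x. A stdlib-only check, using a hypothetical Wrapper class with a list instead of an ndarray and a plain object in place of the overloader:

```python
from dataclasses import dataclass, fields
from typing import ClassVar, List


@dataclass
class Wrapper:
    """Hypothetical stand-in for Wrap1D, using a list instead of an ndarray."""

    x: List[int]
    # ClassVar: shared by all instances and ignored by @dataclass.
    REGISTRY: ClassVar[object] = object()


print([f.name for f in fields(Wrapper)])  # ['x']
w = Wrapper([0, 1, 2])  # __init__ takes only x
print(w)  # Wrapper(x=[0, 1, 2])
```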
Now a numpy.ufunc, here numpy.add, can be overloaded and registered for Wrap1D:

>>> @W_FUNCS.implements(np.add, Wrap1D)
... def add(w1, w2):
...     return Wrap1D(np.add(w1.x, w2.x))
Time to check this works:
>>> np.add(w1d, w1d)
Wrap1D(x=array([0, 2, 4]))
A ufunc also has a number of methods: ‘at’, ‘accumulate’, etc. The function dispatch mechanism in NEP 13 says that “If one of the input or output arguments implements __array_ufunc__, it is executed instead of the ufunc.” Currently, the overloaded numpy.add does not work for any of the ufunc methods:

>>> try: np.add.accumulate(w1d)
... except Exception: print("failed")
failed
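One way this failure arises follows from the protocol itself: NumPy calls __array_ufunc__(ufunc, method, *inputs, **kwargs), where method is "__call__" for a plain call and "reduce", "accumulate", "reduceat", "outer" or "at" for the ufunc methods; if every override returns NotImplemented, NumPy raises TypeError. The minimal sketch below uses a hypothetical MiniWrap class with a plain dict keyed on (ufunc, method), not overload_numpy's machinery, to reproduce both behaviours.

```python
import numpy as np


class MiniWrap:
    """Hypothetical array wrapper (not part of overload_numpy)."""

    # Stand-in for NP_OVERLOADS: maps (ufunc, method name) -> implementation.
    OVERLOADS = {}

    def __init__(self, x):
        self.x = np.asarray(x)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        impl = self.OVERLOADS.get((ufunc, method))
        if impl is None:
            # Declining: if no other operand handles the call, NumPy raises TypeError.
            return NotImplemented
        return impl(*inputs, **kwargs)


def mini_add(w1, w2):
    return MiniWrap(np.add(w1.x, w2.x))


# Only the plain call ("__call__") of np.add is registered.
MiniWrap.OVERLOADS[(np.add, "__call__")] = mini_add

w = MiniWrap(np.arange(3))
print(np.add(w, w).x)  # [0 2 4]

try:
    np.add.accumulate(w)  # NumPy passes method="accumulate"
except TypeError:
    print("accumulate is not overloaded")
```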
Overloads for ufunc methods can be registered on the wrapped add implementation:

>>> @add.register('accumulate')
... def add_accumulate(w1):
...     return Wrap1D(np.add.accumulate(w1.x))

>>> np.add.accumulate(w1d)
Wrap1D(x=array([0, 1, 3]))
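One way to picture add.register('accumulate') is as a table, held by the overloaded implementation, from ufunc method name to implementation, with "__call__" for the plain call. This is a hypothetical sketch (UFuncImpls and dispatch are illustrative names, not overload_numpy classes), with lists and itertools.accumulate standing in for arrays and np.add.accumulate.

```python
import itertools
import operator
from typing import Any, Callable, Dict


class UFuncImpls:
    """Hypothetical per-method table for one overloaded ufunc."""

    def __init__(self, call_impl: Callable[..., Any]) -> None:
        self.methods: Dict[str, Callable[..., Any]] = {"__call__": call_impl}

    def register(self, method: str):
        """Decorator registering an implementation for a ufunc method name."""
        def decorator(impl: Callable[..., Any]) -> Callable[..., Any]:
            self.methods[method] = impl
            return impl
        return decorator

    def dispatch(self, method: str, *inputs: Any, **kwargs: Any) -> Any:
        impl = self.methods.get(method)
        if impl is None:
            return NotImplemented  # unregistered method: decline
        return impl(*inputs, **kwargs)


# Elementwise list addition stands in for np.add.
add = UFuncImpls(lambda a, b: [i + j for i, j in zip(a, b)])


@add.register("accumulate")
def add_accumulate(a):
    return list(itertools.accumulate(a, operator.add))


print(add.dispatch("__call__", [0, 1, 2], [0, 1, 2]))  # [0, 2, 4]
print(add.dispatch("accumulate", [0, 1, 2]))  # [0, 1, 3]
print(add.dispatch("reduce", [0, 1, 2]) is NotImplemented)  # True
```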
What if we defined a subclass of Wrap1D?

>>> @dataclass
... class Wrap2D(Wrap1D):
...     '''A simple 2-array wrapper.'''
...     y: np.ndarray
The overload for numpy.add registered on Wrap1D will not work correctly for Wrap2D, since it returns a Wrap1D built from x alone and drops y. However, NumPyOverloader supports single dispatch on the calling type for the overload, so overloads can be customized for subclasses:

>>> @W_FUNCS.implements(np.add, Wrap2D)
... def add(w1, w2):
...     print("using Wrap2D implementation...")
...     return Wrap2D(np.add(w1.x, w2.x), np.add(w1.y, w2.y))
Checking this works:
>>> w2d = Wrap2D(np.arange(3), np.arange(3, 6))
>>> np.add(w2d, w2d)
using Wrap2D implementation...
Wrap2D(x=array([0, 2, 4]), y=array([ 6,  8, 10]))
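Single dispatch on the calling type behaves much like the standard library's functools.singledispatch: the implementation registered for the most specific class in the argument's MRO is chosen, and an unregistered subclass falls back to its nearest registered parent. The sketch below is an analogy only (NumPyOverloader's internals may differ), using hypothetical list-based wrappers W1, W2 and W3.

```python
from dataclasses import dataclass
from functools import singledispatch
from typing import List


# Hypothetical wrappers mirroring Wrap1D / Wrap2D / Wrap3D with lists.
@dataclass
class W1:
    x: List[int]


@dataclass
class W2(W1):
    y: List[int]


@dataclass
class W3(W2):  # never registered below
    z: List[int]


@singledispatch
def describe(w: object) -> str:
    return "fallback"


@describe.register
def _(w: W1) -> str:
    return "W1 implementation"


@describe.register
def _(w: W2) -> str:
    return "W2 implementation"


print(describe(W1([0])))  # W1 implementation
print(describe(W2([0], [1])))  # W2 implementation
print(describe(W3([0], [1], [2])))  # W2 implementation (nearest registered class)
```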
Great! But rather than defining a new implementation for each subclass, let’s see how we could write a more broadly applicable overload:
>>> @W_FUNCS.implements(np.add, Wrap1D)  # overriding both
... @W_FUNCS.implements(np.add, Wrap2D)  # overriding both
... def add_general(w1, w2):
...     WT = type(w1)
...     return WT(*(np.add(getattr(w1, f.name), getattr(w2, f.name))
...                 for f in fields(WT)))
Checking this works:
>>> np.add(w2d, w2d)
Wrap2D(x=array([0, 2, 4]), y=array([ 6,  8, 10]))
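Because dispatch uses the implementation registered for a parent class when the exact class has none, this also works for a further subclass that was never registered explicitly:

>>> @dataclass
... class Wrap3D(Wrap2D):
...     '''A simple 3-array wrapper.'''
...     z: np.ndarray

>>> w3d = Wrap3D(np.arange(2), np.arange(3, 5), np.arange(6, 8))
>>> np.add(w3d, w3d)
Wrap3D(x=array([0, 2]), y=array([6, 8]), z=array([12, 14]))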
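The trick in add_general is that dataclasses.fields(type(w1)) lists every field of the runtime class, including those added by subclasses, so one function serves Wrap1D, Wrap2D and Wrap3D. A stdlib-only sketch of the same pattern, with hypothetical list-based classes P1 and P2 and a hypothetical elementwise helper standing in for np.add:

```python
import operator
from dataclasses import dataclass, fields
from typing import Any, Callable, List


# Hypothetical list-based wrappers.
@dataclass
class P1:
    x: List[int]


@dataclass
class P2(P1):
    y: List[int]


def elementwise(op: Callable[[int, int], int], a: List[int], b: List[int]) -> List[int]:
    """Apply a binary op element by element (stand-in for a NumPy ufunc)."""
    return [op(i, j) for i, j in zip(a, b)]


def combine_fields(op: Callable[[int, int], int], w1: Any, w2: Any) -> Any:
    WT = type(w1)  # the runtime class, which may be a subclass
    # Positional construction assumes every field is an __init__ parameter.
    return WT(*(elementwise(op, getattr(w1, f.name), getattr(w2, f.name))
                for f in fields(WT)))


p = P2([0, 1, 2], [3, 4, 5])
print(combine_fields(operator.add, p, p))  # P2(x=[0, 2, 4], y=[6, 8, 10])
```

Passing values positionally relies on fields() returning fields in __init__ order, which holds as long as there are no init=False or keyword-only fields.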
In the previous examples we wrote implementations for a single NumPy function. Overloading the full set of NumPy functions this way would take a long time.
Wouldn’t it be better to write far fewer implementations, each covering a group of NumPy functions?

>>> add_funcs = {np.add, np.subtract}
>>> @W_FUNCS.assists(add_funcs, types=Wrap1D, dispatch_on=Wrap1D)
... def add_assists(cls, func, w1, w2, *args, **kwargs):
...     return cls(*(func(getattr(w1, f.name), getattr(w2, f.name), *args, **kwargs)
...                  for f in fields(cls)))
Checking this works:
>>> np.subtract(w2d, w2d)
Wrap2D(x=array([0, 0, 0]), y=array([0, 0, 0]))
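The assists pattern registers one implementation for a whole group of functions and passes the function being overloaded in as an argument (func), so the body never names np.add or np.subtract directly. A hypothetical stdlib sketch of that idea (assists, REGISTRY, call and Pair are illustrative names, not overload_numpy API), with operator.add and operator.sub standing in for the ufuncs:

```python
import operator
from dataclasses import dataclass, fields
from typing import Any, Callable, Dict, Iterable, List

REGISTRY: Dict[Callable[..., Any], Callable[..., Any]] = {}


def assists(funcs: Iterable[Callable[..., Any]]):
    """Hypothetical decorator: register one implementation for several functions."""
    def decorator(impl: Callable[..., Any]) -> Callable[..., Any]:
        for func in funcs:
            REGISTRY[func] = impl
        return impl
    return decorator


@dataclass
class Pair:
    x: List[int]
    y: List[int]


@assists({operator.add, operator.sub})
def fieldwise(cls, func, w1, w2):
    # func is whichever function was requested; cls is the wrapper class.
    return cls(*([func(a, b) for a, b in zip(getattr(w1, f.name), getattr(w2, f.name))]
                 for f in fields(cls)))


def call(func, w1, w2):
    """Look up the shared implementation and pass func itself along."""
    return REGISTRY[func](type(w1), func, w1, w2)


p = Pair([0, 1, 2], [3, 4, 5])
print(call(operator.sub, p, p))  # Pair(x=[0, 0, 0], y=[0, 0, 0])
print(call(operator.add, p, p))  # Pair(x=[0, 2, 4], y=[6, 8, 10])
```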
We would also like to implement the accumulate method for all the add_funcs overloads:

>>> @add_assists.register("accumulate")
... def add_accumulate_assists(cls, func, w1, *args, **kwargs):
...     return cls(*(func(getattr(w1, f.name), *args, **kwargs)
...                  for f in fields(cls)))

>>> np.subtract.accumulate(w2d)
Wrap2D(x=array([ 0, -1, -3]), y=array([ 3, -1, -6]))
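As a check on the expected output above: subtract.accumulate computes running differences (a0, a0 - a1, (a0 - a1) - a2, ...), which the standard library's itertools.accumulate reproduces with operator.sub on the same inputs as w2d.x and w2d.y:

```python
import itertools
import operator

# Same values as np.arange(3) and np.arange(3, 6).
x = list(itertools.accumulate([0, 1, 2], operator.sub))
y = list(itertools.accumulate([3, 4, 5], operator.sub))
print(x)  # [0, -1, -3]
print(y)  # [3, -1, -6]
```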
Attributes:
- NP_OVERLOADS : NumPyOverloader
  A class attribute holding an instance of NumPyOverloader.