Overload NumPy Functions

overload_numpy provides easy-to-use tools for working with NumPy’s __array_ufunc__ and __array_function__ protocols. The library is fully typed and its wheels are compiled with mypyc.

Implementing an Overload

First, some imports:

>>> from dataclasses import dataclass, fields
>>> from typing import ClassVar
>>> import numpy as np
>>> from overload_numpy import NumPyOverloader, NPArrayOverloadMixin

Now we can define a NumPyOverloader instance:

>>> W_FUNCS = NumPyOverloader()

The overloads apply to an array-wrapping class. Let’s define one that inherits from NPArrayOverloadMixin and points its NP_OVERLOADS class variable at W_FUNCS:

>>> @dataclass
... class Wrap1D(NPArrayOverloadMixin):
...     '''A simple array wrapper.'''
...     x: np.ndarray
...     NP_OVERLOADS: ClassVar[NumPyOverloader] = W_FUNCS
>>> w1d = Wrap1D(np.arange(3))

Now both NumPy ufuncs (e.g. numpy.add) and other NumPy functions (e.g. numpy.concatenate) can be overloaded and registered for Wrap1D.

>>> @W_FUNCS.implements(np.add, Wrap1D)
... def add(w1, w2):
...     return Wrap1D(np.add(w1.x, w2.x))
>>> @W_FUNCS.implements(np.concatenate, Wrap1D)
... def concatenate(w1ds):
...     return Wrap1D(np.concatenate(tuple(w.x for w in w1ds)))

Time to check these work:

>>> np.add(w1d, w1d)
Wrap1D(x=array([0, 2, 4]))
>>> np.concatenate((w1d, w1d))
Wrap1D(x=array([0, 1, 2, 0, 1, 2]))

ufuncs also have a number of methods, such as ‘at’ and ‘accumulate’. The dispatch mechanism in NEP 13 says that “If one of the input or output arguments implements __array_ufunc__, it is executed instead of the ufunc.” This applies to the ufunc methods too, yet the overloaded numpy.add does not currently work for any of them:

>>> try: np.add.accumulate(w1d)
... except Exception: print("failed")
failed
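
Under NEP 13, NumPy tells __array_ufunc__ which method was invoked through its method argument: '__call__' for a plain call, or the method name ('accumulate', 'reduce', and so on). That is why each method needs its own registration. A minimal sketch using only NumPy; the Probe class is hypothetical and simply records what it receives:

```python
import numpy as np


class Probe:
    """Hypothetical helper that records which ufunc method NumPy dispatched."""

    def __init__(self):
        self.calls = []

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # NEP 13 passes the ufunc object and the method name as a string.
        self.calls.append((ufunc.__name__, method))
        return "handled"  # NumPy returns whatever the override returns


p = Probe()
np.add(p, 1)
np.add.accumulate(p)
np.add.reduce(p)
print(p.calls)
# [('add', '__call__'), ('add', 'accumulate'), ('add', 'reduce')]
```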

Overloads for ufunc methods can be registered on the decorated add implementation via its register method:

>>> @add.register('accumulate')
... def add_accumulate(w1):
...     return Wrap1D(np.add.accumulate(w1.x))
>>> np.add.accumulate(w1d)
Wrap1D(x=array([0, 1, 3]))

Dispatching Overloads for Subclasses

What if we defined a subclass of Wrap1D?

>>> @dataclass
... class Wrap2D(Wrap1D):
...     '''A simple 2-array wrapper.'''
...     y: np.ndarray

The overloads for numpy.add and numpy.concatenate registered on Wrap1D will not work correctly for Wrap2D, since they only handle the x field. However, NumPyOverloader supports single dispatch on the calling type, so overloads can be customized for subclasses.

>>> @W_FUNCS.implements(np.add, Wrap2D)
... def add2(w1, w2):
...     print("using Wrap2D implementation...")
...     return Wrap2D(np.add(w1.x, w2.x),
...                   np.add(w1.y, w2.y))
>>> @W_FUNCS.implements(np.concatenate, Wrap2D)
... def concatenate2(w2ds):
...     print("using Wrap2D implementation...")
...     return Wrap2D(np.concatenate(tuple(w.x for w in w2ds)),
...                   np.concatenate(tuple(w.y for w in w2ds)))

Checking these work:

>>> w2d = Wrap2D(np.arange(3), np.arange(3, 6))
>>> np.add(w2d, w2d)
using Wrap2D implementation...
Wrap2D(x=array([0, 2, 4]), y=array([ 6,  8, 10]))
>>> np.concatenate((w2d, w2d))
using Wrap2D implementation...
Wrap2D(x=array([0, 1, 2, 0, 1, 2]), y=array([3, 4, 5, 3, 4, 5]))
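
Choosing an implementation by the argument’s type resembles the standard library’s functools.singledispatch, which walks the type’s method resolution order and uses the nearest registered class. The analogy below uses only the standard library and is not a description of overload_numpy’s internals; Base, Child, GrandChild, and describe are hypothetical names:

```python
from dataclasses import dataclass
from functools import singledispatch


@dataclass
class Base:
    x: int


@dataclass
class Child(Base):
    y: int


@dataclass
class GrandChild(Child):
    z: int


@singledispatch
def describe(obj):
    # Fallback used when no registered class appears in type(obj).__mro__.
    raise NotImplementedError(f"no implementation for {type(obj).__name__}")


@describe.register
def _(obj: Base):
    return "Base implementation"


@describe.register
def _(obj: Child):
    return "Child implementation"


print(describe(Base(1)))              # Base implementation
print(describe(Child(1, 2)))          # Child implementation
print(describe(GrandChild(1, 2, 3)))  # Child implementation (nearest registered ancestor)
```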

Great! But rather than defining a new implementation for each subclass, let’s see how we could write a more broadly applicable overload:

>>> @W_FUNCS.implements(np.add, Wrap1D)  # overriding both
... @W_FUNCS.implements(np.add, Wrap2D)  # overriding both
... def add_general(w1, w2):
...     WT = type(w1)
...     return WT(*(np.add(getattr(w1, f.name), getattr(w2, f.name))
...                 for f in fields(WT)))
>>> @W_FUNCS.implements(np.concatenate, Wrap1D)  # overriding both
... @W_FUNCS.implements(np.concatenate, Wrap2D)  # overriding both
... def concatenate_general(ws):
...     WT = type(ws[0])
...     return WT(*(np.concatenate(tuple(getattr(w, f.name) for w in ws))
...                 for f in fields(WT)))

Checking these work:

>>> np.add(w2d, w2d)
Wrap2D(x=array([0, 2, 4]), y=array([ 6,  8, 10]))
>>> np.concatenate((w2d, w2d))
Wrap2D(x=array([0, 1, 2, 0, 1, 2]), y=array([3, 4, 5, 3, 4, 5]))
>>> @dataclass
... class Wrap3D(Wrap2D):
...     '''A simple 3-array wrapper.'''
...     z: np.ndarray
>>> w3d = Wrap3D(np.arange(2), np.arange(3, 5), np.arange(6, 8))
>>> np.add(w3d, w3d)
Wrap3D(x=array([0, 2]), y=array([6, 8]), z=array([12, 14]))
>>> np.concatenate((w3d, w3d))
Wrap3D(x=array([0, 1, 0, 1]), y=array([3, 4, 3, 4]), z=array([6, 7, 6, 7]))
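
Both general implementations follow the same pattern: apply a NumPy function to each dataclass field in turn, then rebuild an instance of the argument’s type. Here is a sketch of that pattern as a standalone helper, independent of overload_numpy. The names fieldwise and Pair are hypothetical, and the helper assumes every field is a regular __init__ argument so positional reconstruction works:

```python
from dataclasses import dataclass, fields

import numpy as np


def fieldwise(func, *objs):
    """Apply func field-by-field across dataclass instances.

    Assumes all objs share the type of the first one and that every
    field is an __init__ argument, so cls(*values) rebuilds an instance.
    """
    cls = type(objs[0])
    return cls(*(func(*(getattr(o, f.name) for o in objs)) for f in fields(cls)))


@dataclass
class Pair:
    x: np.ndarray
    y: np.ndarray


p = Pair(np.arange(3), np.arange(3, 6))
summed = fieldwise(np.add, p, p)
print(summed.x.tolist(), summed.y.tolist())  # [0, 2, 4] [6, 8, 10]

# Functions taking a sequence (like np.concatenate) receive the gathered fields.
joined = fieldwise(lambda *xs: np.concatenate(xs), p, p)
print(joined.x.tolist())  # [0, 1, 2, 0, 1, 2]
```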

Assisting Groups of Overloads

In the previous examples we wrote implementations for a single NumPy function. Overloading the full set of NumPy functions this way would take a long time.

Wouldn’t it be better to write far fewer implementations, each covering a whole group of NumPy functions? NumPyOverloader.assists does this:

>>> add_funcs = {np.add, np.subtract}
>>> @W_FUNCS.assists(add_funcs, types=Wrap1D, dispatch_on=Wrap1D)
... def add_assists(cls, func, w1, w2, *args, **kwargs):
...     return cls(*(func(getattr(w1, f.name), getattr(w2, f.name), *args, **kwargs)
...                     for f in fields(cls)))
>>> stack_funcs = {np.vstack, np.hstack, np.dstack, np.column_stack, np.row_stack}
>>> @W_FUNCS.assists(stack_funcs, types=Wrap1D, dispatch_on=Wrap1D)
... def stack_assists(cls, func, ws, *args, **kwargs):
...     return cls(*(func(tuple(getattr(v, f.name) for v in ws), *args, **kwargs)
...                     for f in fields(cls)))

Checking these work:

>>> np.subtract(w2d, w2d)
Wrap2D(x=array([0, 0, 0]), y=array([0, 0, 0]))
>>> np.vstack((w1d, w1d))
Wrap1D(x=array([[0, 1, 2],
       [0, 1, 2]]))
>>> np.hstack((w1d, w1d))
Wrap1D(x=array([0, 1, 2, 0, 1, 2]))
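
Each assists implementation receives the NumPy function actually being called (func) alongside the class, so one body can serve every function in its group. A minimal sketch of that idea with a plain dictionary and functools.partial; REGISTRY, assists_sketch, and elementwise are hypothetical names, and this is not overload_numpy’s actual machinery:

```python
from functools import partial

import numpy as np

# Hypothetical table: NumPy function -> callable implementing it.
REGISTRY = {}


def assists_sketch(funcs):
    """Register one handler for a whole group of NumPy functions.

    The handler is stored once per function, with that function bound as
    its first argument, so the handler knows which one it is standing in for.
    """
    def decorator(handler):
        for fn in funcs:
            REGISTRY[fn] = partial(handler, fn)
        return handler
    return decorator


@assists_sketch({np.add, np.subtract})
def elementwise(func, a, b):
    # One body serves both np.add and np.subtract.
    return func(np.asarray(a), np.asarray(b))


print(REGISTRY[np.add]([1, 2], [10, 20]).tolist())       # [11, 22]
print(REGISTRY[np.subtract]([1, 2], [10, 20]).tolist())  # [-9, -18]
```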

We would also like to implement the accumulate method for all the add_funcs overloads:

>>> @add_assists.register("accumulate")
... def add_accumulate_assists(cls, func, w1, *args, **kwargs):
...     return cls(*(func(getattr(w1, f.name), *args, **kwargs)
...                  for f in fields(cls)))
>>> np.subtract.accumulate(w2d)
Wrap2D(x=array([ 0, -1, -3]), y=array([ 3, -1, -6]))
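
The expected Wrap2D result can be checked against plain NumPy. Judging from the output, func in the accumulate assist appears to be the ufunc method itself (here np.subtract.accumulate), applied to each field separately:

```python
import numpy as np

# The same per-field computation, without any wrapper class.
x, y = np.arange(3), np.arange(3, 6)
print(np.subtract.accumulate(x).tolist())  # [0, -1, -3], i.e. 0, 0-1, 0-1-2
print(np.subtract.accumulate(y).tolist())  # [3, -1, -6], i.e. 3, 3-4, 3-4-5
```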

Details

For details on type constraints and the full API, see the documentation.

Download files

Source distribution

overload_numpy-0.1.0.tar.gz (52.3 kB)

Built distributions

overload_numpy-0.1.0-py3-none-any.whl (28.6 kB): Python 3

overload_numpy-0.1.0-cp310-cp310-win_amd64.whl (126.4 kB): CPython 3.10, Windows x86-64

overload_numpy-0.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (237.0 kB): CPython 3.10, manylinux (glibc 2.17+) x86-64

overload_numpy-0.1.0-cp310-cp310-macosx_11_0_arm64.whl (136.2 kB): CPython 3.10, macOS 11.0+ ARM64

overload_numpy-0.1.0-cp310-cp310-macosx_10_9_x86_64.whl (140.0 kB): CPython 3.10, macOS 10.9+ x86-64

overload_numpy-0.1.0-cp310-cp310-macosx_10_9_universal2.whl (246.6 kB): CPython 3.10, macOS 10.9+ universal2 (ARM64, x86-64)

overload_numpy-0.1.0-cp39-cp39-win_amd64.whl (126.3 kB): CPython 3.9, Windows x86-64

overload_numpy-0.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (236.5 kB): CPython 3.9, manylinux (glibc 2.17+) x86-64

overload_numpy-0.1.0-cp39-cp39-macosx_11_0_arm64.whl (136.3 kB): CPython 3.9, macOS 11.0+ ARM64

overload_numpy-0.1.0-cp39-cp39-macosx_10_9_x86_64.whl (140.0 kB): CPython 3.9, macOS 10.9+ x86-64

overload_numpy-0.1.0-cp39-cp39-macosx_10_9_universal2.whl (246.6 kB): CPython 3.9, macOS 10.9+ universal2 (ARM64, x86-64)

overload_numpy-0.1.0-cp38-cp38-win_amd64.whl (126.1 kB): CPython 3.8, Windows x86-64

overload_numpy-0.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (234.4 kB): CPython 3.8, manylinux (glibc 2.17+) x86-64

overload_numpy-0.1.0-cp38-cp38-macosx_11_0_arm64.whl (134.8 kB): CPython 3.8, macOS 11.0+ ARM64

overload_numpy-0.1.0-cp38-cp38-macosx_10_9_x86_64.whl (138.2 kB): CPython 3.8, macOS 10.9+ x86-64

overload_numpy-0.1.0-cp38-cp38-macosx_10_9_universal2.whl (243.4 kB): CPython 3.8, macOS 10.9+ universal2 (ARM64, x86-64)