Overload NumPy Functions
overload_numpy provides easy-to-use tools for working with NumPy’s __array_ufunc__ and __array_function__ protocols. The library is fully typed, and its wheels are compiled with mypyc.
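For context, this is roughly what the __array_function__ side looks like when written by hand. It follows the dispatch-table pattern from NumPy’s guide to writing custom array containers. This is a plain-NumPy sketch, not overload_numpy’s implementation, and the names HANDLED_FUNCTIONS, implements and Container are illustrative only.

```python
from dataclasses import dataclass

import numpy as np

# Illustrative registry: NumPy function -> custom implementation.
HANDLED_FUNCTIONS = {}


def implements(np_function):
    """Register an implementation of ``np_function`` for Container."""
    def decorator(func):
        HANDLED_FUNCTIONS[np_function] = func
        return func
    return decorator


@dataclass
class Container:
    x: np.ndarray

    def __array_function__(self, func, types, args, kwargs):
        # NEP 18: NumPy calls this instead of running its own implementation.
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        # Only handle the call if every overriding argument type is ours.
        if not all(issubclass(t, Container) for t in types):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)


@implements(np.concatenate)
def concatenate(containers):
    return Container(np.concatenate(tuple(c.x for c in containers)))


c = Container(np.arange(3))
result = np.concatenate((c, c))
print(result.x)

try:
    np.stack((c, c))  # not registered
except TypeError:
    print("np.stack is not overloaded for Container")
```

A function with no registered implementation returns NotImplemented, which NumPy turns into a TypeError.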
Implementing an Overload
First, some imports:
>>> from dataclasses import dataclass, fields
>>> from typing import ClassVar
>>> import numpy as np
>>> from overload_numpy import NumPyOverloader, NPArrayOverloadMixin
Now we can define a NumPyOverloader instance:
>>> W_FUNCS = NumPyOverloader()
The overloads apply to an array-wrapping class. Let’s define one:
>>> @dataclass
... class Wrap1D(NPArrayOverloadMixin):
...     '''A simple array wrapper.'''
...     x: np.ndarray
...     NP_OVERLOADS: ClassVar[NumPyOverloader] = W_FUNCS

>>> w1d = Wrap1D(np.arange(3))
Now both ufuncs (e.g. numpy.add) and regular NumPy functions (e.g. numpy.concatenate) can be overloaded and registered for Wrap1D.
>>> @W_FUNCS.implements(np.add, Wrap1D)
... def add(w1, w2):
...     return Wrap1D(np.add(w1.x, w2.x))

>>> @W_FUNCS.implements(np.concatenate, Wrap1D)
... def concatenate(w1ds):
...     return Wrap1D(np.concatenate(tuple(w.x for w in w1ds)))
Time to check these work:
>>> np.add(w1d, w1d)
Wrap1D(x=array([0, 2, 4]))

>>> np.concatenate((w1d, w1d))
Wrap1D(x=array([0, 1, 2, 0, 1, 2]))
ufuncs also have a number of methods: ‘at’, ‘accumulate’, etc. The dispatch mechanism in NEP 13 says that “If one of the input or output arguments implements __array_ufunc__, it is executed instead of the ufunc.” So far, however, the overload registered for numpy.add does not cover any of the ufunc methods:
>>> try: np.add.accumulate(w1d)
... except Exception: print("failed")
failed
ufunc method overloads can be registered on the wrapped add implementation:
>>> @add.register('accumulate')
... def add_accumulate(w1):
...     return Wrap1D(np.add.accumulate(w1.x))

>>> np.add.accumulate(w1d)
Wrap1D(x=array([0, 1, 3]))
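Why does each ufunc method need its own registration? Under NEP 13, __array_ufunc__ receives both the ufunc and a method string such as "__call__" or "accumulate". The plain-NumPy sketch below keeps a hypothetical table keyed on (ufunc, method), so np.add and np.add.accumulate are separate entries. It illustrates the protocol only; it is not how overload_numpy stores its overloads.

```python
from dataclasses import dataclass

import numpy as np

# Hypothetical registry keyed on (ufunc, method name).
UFUNC_IMPLS = {}


@dataclass
class Tracked:
    x: np.ndarray

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # NEP 13: `method` is "__call__", "reduce", "accumulate", "at", ...
        impl = UFUNC_IMPLS.get((ufunc, method))
        if impl is None:
            return NotImplemented  # NumPy then raises TypeError
        return impl(*inputs, **kwargs)


def _add_call(a, b, **kwargs):
    return Tracked(np.add(a.x, b.x, **kwargs))


def _add_accumulate(a, **kwargs):
    return Tracked(np.add.accumulate(a.x, **kwargs))


UFUNC_IMPLS[(np.add, "__call__")] = _add_call

t = Tracked(np.arange(3))
print(np.add(t, t).x)

try:
    np.add.accumulate(t)  # no (np.add, "accumulate") entry yet
except TypeError:
    print("accumulate is not registered yet")

UFUNC_IMPLS[(np.add, "accumulate")] = _add_accumulate
print(np.add.accumulate(t).x)
```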
Dispatching Overloads for Subclasses
What if we defined a subclass of Wrap1D?
>>> @dataclass
... class Wrap2D(Wrap1D):
...     '''A simple 2-array wrapper.'''
...     y: np.ndarray
The overloads for numpy.add and numpy.concatenate registered on Wrap1D will not work correctly for Wrap2D: they build a Wrap1D from the x arrays alone, so y is lost. However, NumPyOverloader supports single dispatch on the calling type, so overloads can be customized for subclasses.
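Single dispatch on the calling type behaves much like Python’s functools.singledispatch, which picks an implementation by walking the argument type’s MRO. The stdlib-only sketch below uses singledispatch as a stand-in (NumPyOverloader may work differently internally) with hypothetical classes Base and Child, to show a subclass falling back to its parent’s implementation until it registers its own.

```python
from dataclasses import dataclass
from functools import singledispatch


@dataclass
class Base:
    x: int


@dataclass
class Child(Base):
    y: int


@singledispatch
def describe(obj):
    raise NotImplementedError(type(obj).__name__)


@describe.register(Base)
def _describe_base(obj):
    return "Base implementation"


# Child has no entry yet, so dispatch walks Child's MRO and finds Base.
before = describe(Child(1, 2))


@describe.register(Child)
def _describe_child(obj):
    return "Child implementation"


# Registering clears the dispatch cache, so Child now gets its own version.
after = describe(Child(1, 2))
print(before)
print(after)
```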
>>> @W_FUNCS.implements(np.add, Wrap2D)
... def add(w1, w2):
...     print("using Wrap2D implementation...")
...     return Wrap2D(np.add(w1.x, w2.x),
...                   np.add(w1.y, w2.y))

>>> @W_FUNCS.implements(np.concatenate, Wrap2D)
... def concatenate2(w2ds):
...     print("using Wrap2D implementation...")
...     return Wrap2D(np.concatenate(tuple(w.x for w in w2ds)),
...                   np.concatenate(tuple(w.y for w in w2ds)))
Checking these work:
>>> w2d = Wrap2D(np.arange(3), np.arange(3, 6))
>>> np.add(w2d, w2d)
using Wrap2D implementation...
Wrap2D(x=array([0, 2, 4]), y=array([ 6,  8, 10]))

>>> np.concatenate((w2d, w2d))
using Wrap2D implementation...
Wrap2D(x=array([0, 1, 2, 0, 1, 2]), y=array([3, 4, 5, 3, 4, 5]))
Great! But rather than defining a new implementation for each subclass, let’s see how we could write a more broadly applicable overload:
>>> @W_FUNCS.implements(np.add, Wrap1D)  # overriding both
... @W_FUNCS.implements(np.add, Wrap2D)  # overriding both
... def add_general(w1, w2):
...     WT = type(w1)
...     return WT(*(np.add(getattr(w1, f.name), getattr(w2, f.name))
...                 for f in fields(WT)))

>>> @W_FUNCS.implements(np.concatenate, Wrap1D)  # overriding both
... @W_FUNCS.implements(np.concatenate, Wrap2D)  # overriding both
... def concatenate_general(ws):
...     WT = type(ws[0])
...     return WT(*(np.concatenate(tuple(getattr(w, f.name) for w in ws))
...                 for f in fields(WT)))
Checking these work:
>>> np.add(w2d, w2d)
Wrap2D(x=array([0, 2, 4]), y=array([ 6,  8, 10]))

>>> np.concatenate((w2d, w2d))
Wrap2D(x=array([0, 1, 2, 0, 1, 2]), y=array([3, 4, 5, 3, 4, 5]))

The general overloads also cover a further subclass, Wrap3D, without any new registration:

>>> @dataclass
... class Wrap3D(Wrap2D):
...     '''A simple 3-array wrapper.'''
...     z: np.ndarray

>>> w3d = Wrap3D(np.arange(2), np.arange(3, 5), np.arange(6, 8))
>>> np.add(w3d, w3d)
Wrap3D(x=array([0, 2]), y=array([6, 8]), z=array([12, 14]))
>>> np.concatenate((w3d, w3d))
Wrap3D(x=array([0, 1, 0, 1]), y=array([3, 4, 3, 4]), z=array([6, 7, 6, 7]))
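The general overloads work because dataclasses.fields is called on the runtime type, so the extra arrays of each subclass are picked up automatically, while ClassVar attributes such as NP_OVERLOADS are not returned as fields. The hypothetical helper apply_fieldwise below isolates that pattern in plain NumPy, outside overload_numpy.

```python
from dataclasses import dataclass, fields
from typing import ClassVar

import numpy as np


def apply_fieldwise(func, *wrappers):
    """Hypothetical helper: apply ``func`` to matching fields of wrappers.

    The output type comes from the first wrapper, so subclasses that add
    fields are handled without extra code.
    """
    cls = type(wrappers[0])
    return cls(*(func(*(getattr(w, f.name) for w in wrappers))
                 for f in fields(cls)))


@dataclass
class Pair:
    x: np.ndarray
    y: np.ndarray
    LABEL: ClassVar[str] = "pair"  # ClassVar: not returned by fields()


@dataclass
class Triple(Pair):
    z: np.ndarray


t = Triple(np.arange(2), np.arange(3, 5), np.arange(6, 8))
summed = apply_fieldwise(np.add, t, t)
concatenated = apply_fieldwise(lambda *arrays: np.concatenate(arrays), t, t)
print(summed)
print(concatenated)
```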
Assisting Groups of Overloads
In the previous examples we wrote one implementation per NumPy function. Overloading the full set of NumPy functions this way would take a long time.
Wouldn’t it be better to write far fewer implementations, each covering a whole group of NumPy functions?
>>> add_funcs = {np.add, np.subtract}
>>> @W_FUNCS.assists(add_funcs, types=Wrap1D, dispatch_on=Wrap1D)
... def add_assists(cls, func, w1, w2, *args, **kwargs):
...     return cls(*(func(getattr(w1, f.name), getattr(w2, f.name), *args, **kwargs)
...                  for f in fields(cls)))

>>> stack_funcs = {np.vstack, np.hstack, np.dstack, np.column_stack, np.row_stack}
>>> @W_FUNCS.assists(stack_funcs, types=Wrap1D, dispatch_on=Wrap1D)
... def stack_assists(cls, func, ws, *args, **kwargs):
...     return cls(*(func(tuple(getattr(v, f.name) for v in ws), *args, **kwargs)
...                  for f in fields(cls)))
Checking these work:
>>> np.subtract(w2d, w2d)
Wrap2D(x=array([0, 0, 0]), y=array([0, 0, 0]))

>>> np.vstack((w1d, w1d))
Wrap1D(x=array([[0, 1, 2],
       [0, 1, 2]]))

>>> np.hstack((w1d, w1d))
Wrap1D(x=array([0, 1, 2, 0, 1, 2]))
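Conceptually, an assisting implementation is one function registered under every member of a group, with the NumPy function that was actually called handed in as an argument (the func parameter above). The plain-Python sketch below illustrates that idea with a hypothetical dictionary registry (REGISTRY, assists_group and call are made-up names); it does not reproduce overload_numpy’s internals or its type checks.

```python
from dataclasses import dataclass, fields

import numpy as np

# Hypothetical registry: NumPy function -> generic implementation.
REGISTRY = {}


def assists_group(funcs):
    """Register one generic implementation for every function in ``funcs``."""
    def decorator(impl):
        for func in funcs:
            REGISTRY[func] = impl
        return impl
    return decorator


@dataclass
class Vec:
    x: np.ndarray


@assists_group({np.add, np.subtract})
def binary_fieldwise(func, w1, w2, *args, **kwargs):
    # ``func`` is whichever group member was called.
    cls = type(w1)
    return cls(*(func(getattr(w1, f.name), getattr(w2, f.name), *args, **kwargs)
                 for f in fields(cls)))


def call(func, *args, **kwargs):
    """Look up the generic implementation and pass it the called function."""
    return REGISTRY[func](func, *args, **kwargs)


v = Vec(np.arange(3))
print(call(np.subtract, v, v))
print(call(np.add, v, v))
```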
We would also like to implement the accumulate method for all the add_funcs overloads:
>>> @add_assists.register("accumulate")
... def add_accumulate_assists(cls, func, w1, *args, **kwargs):
...     return cls(*(func(getattr(w1, f.name), *args, **kwargs)
...                  for f in fields(cls)))

>>> np.subtract.accumulate(w2d)
Wrap2D(x=array([ 0, -1, -3]), y=array([ 3, -1, -6]))
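The same idea extends to ufunc methods: one implementation per method name can serve every ufunc in a group, because getattr(ufunc, "accumulate") retrieves the corresponding ufunc method (for example np.subtract.accumulate). The sketch below uses hypothetical names (GROUP, METHOD_IMPLS, register_method, call_method) and is not overload_numpy’s API.

```python
from dataclasses import dataclass, fields

import numpy as np

# Hypothetical group of ufuncs sharing per-method implementations.
GROUP = {np.add, np.subtract}
METHOD_IMPLS = {}


def register_method(method):
    """Register a generic implementation for one ufunc method name."""
    def decorator(impl):
        METHOD_IMPLS[method] = impl
        return impl
    return decorator


@register_method("accumulate")
def accumulate_fieldwise(ufunc_method, w, *args, **kwargs):
    cls = type(w)
    return cls(*(ufunc_method(getattr(w, f.name), *args, **kwargs)
                 for f in fields(cls)))


def call_method(ufunc, method, *args, **kwargs):
    if ufunc not in GROUP or method not in METHOD_IMPLS:
        raise TypeError(f"{ufunc.__name__}.{method} is not overloaded")
    # getattr(np.subtract, "accumulate") is np.subtract.accumulate
    return METHOD_IMPLS[method](getattr(ufunc, method), *args, **kwargs)


@dataclass
class Pair:
    x: np.ndarray
    y: np.ndarray


p = Pair(np.arange(3), np.arange(3, 6))
print(call_method(np.subtract, "accumulate", p))
print(call_method(np.add, "accumulate", p))
```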
Details
Want to learn about type constraints and the full API? Check out the docs!
Download files
Source Distribution
Built Distributions
Hashes for overload_numpy-0.1.0-py3-none-any.whl
Algorithm | Hash digest |
---|---|
SHA256 | a407d991ca734d8cb586d4f15b1a6f8384418a6a8f6bc793b0f2fe53c9b95ca9 |
MD5 | 391891e9f58f8f712789b398d642670c |
BLAKE2b-256 | 1ad9bad8797c33d936f3f18e16622a687d31073a0ed77b97ca52c0065c04b004 |
Hashes for overload_numpy-0.1.0-cp310-cp310-win_amd64.whl
Algorithm | Hash digest |
---|---|
SHA256 | 0ebabe5326ef2abc221679732df7f6449b56fb37e6d629687126830dfb001421 |
MD5 | 55bfe09d2de46094e1cdafe40c90955b |
BLAKE2b-256 | 94e1f66893799e6d21ad67879cb244713f59d85305747a6cf243b5b257494871 |
Hashes for overload_numpy-0.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm | Hash digest |
---|---|
SHA256 | 161167cba5fda5761a23d12ec95e98cee65480c27960d74001423e548de7790e |
MD5 | e61ef26856a5f97c922e6738f467ee47 |
BLAKE2b-256 | 3aa4bcb4efca97551c400e6b5c7b85cc602cbdf107ffb872487950d480305108 |
Hashes for overload_numpy-0.1.0-cp310-cp310-macosx_11_0_arm64.whl
Algorithm | Hash digest |
---|---|
SHA256 | 87bad0051ba795165002736e9d76eacd9995f4cea54161c05f835597fa59dfab |
MD5 | abd2d849c4035d44ba718a7c532dd074 |
BLAKE2b-256 | 3253a397b543ad2847089043a7839a09e0cb495179c7f45c325d0a3c89880283 |
Hashes for overload_numpy-0.1.0-cp310-cp310-macosx_10_9_x86_64.whl
Algorithm | Hash digest |
---|---|
SHA256 | 7aa431d79ef86ce7a03424009791bce61ec214861152868b7fa2dc3302b70298 |
MD5 | 48549ba2cd7ea753825bab318cb7809f |
BLAKE2b-256 | 7620c7d1b1b912d3485c67cf13a1480807a7eb1f3b4bb8c8ec29c92f3874d45f |
Hashes for overload_numpy-0.1.0-cp310-cp310-macosx_10_9_universal2.whl
Algorithm | Hash digest |
---|---|
SHA256 | 09109158ac55bff365d51e034444e703e86f19caf5b0eaf6c145071ca61c7903 |
MD5 | 307ae0729ab928dbcee27e443321deca |
BLAKE2b-256 | 461ee7d25327152bbca028f6f3bb5d315f49a74f6d24ef632444c588f919896d |
Hashes for overload_numpy-0.1.0-cp39-cp39-win_amd64.whl
Algorithm | Hash digest |
---|---|
SHA256 | e29582ddfd17b79037bdbd58e0f3441c427d10c095b8a2b63c802deff3bf15c5 |
MD5 | 20c8c0d2e456b91c787f35bca54f2478 |
BLAKE2b-256 | d09c89f7708b72e2600005947537b2c4afc1550fb3eee9d03f5379e2afdbcf6a |
Hashes for overload_numpy-0.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm | Hash digest |
---|---|
SHA256 | d6fb93f973a1c96228e7c457daeb5d663e929f278449428d1000adfa80a7f07e |
MD5 | 1730fa3cfcf51b953204f30e8709ca08 |
BLAKE2b-256 | 2eed4618900bbe85e149c972f581d1792b1c9c640c28f8d78ae852641566f2d8 |
Hashes for overload_numpy-0.1.0-cp39-cp39-macosx_11_0_arm64.whl
Algorithm | Hash digest |
---|---|
SHA256 | 5347ffce79ec43d93db57d336218b200209fff51d945113edb2a5f8111682cc8 |
MD5 | bde17cd181e2b2239341bf07c50909dd |
BLAKE2b-256 | c0ec9f973301fd5bec73fc344c9040d4ed65702aab5e18925403aaedf911aab4 |
Hashes for overload_numpy-0.1.0-cp39-cp39-macosx_10_9_x86_64.whl
Algorithm | Hash digest |
---|---|
SHA256 | 3b312e0db4e196acb477dc2f978498a184f6daa0e5fce14acedc578ef06c37a3 |
MD5 | a9e341c841e61d96ecc476f7b5507d83 |
BLAKE2b-256 | e362beb1de96d61e9aeffae00330e6f5f2e237ea8965205574f3a6c68e2aa9a4 |
Hashes for overload_numpy-0.1.0-cp39-cp39-macosx_10_9_universal2.whl
Algorithm | Hash digest |
---|---|
SHA256 | f928b727741b3e649c064b5a51fbe665bc4e971710a5b23f3494d4efa9da1f82 |
MD5 | 7f4681df601d23342d98a80e845975bb |
BLAKE2b-256 | 4b4c8072aaddb59080ea5fa4568b6e8d93a28b251d55a2f7f7d31761064e94ff |
Hashes for overload_numpy-0.1.0-cp38-cp38-win_amd64.whl
Algorithm | Hash digest |
---|---|
SHA256 | c088a817539f34c52947b599b6f9dbc72f6e377fd90c3888114e675b59440a6b |
MD5 | 05f2f1c44359d9cd951290342fee2e75 |
BLAKE2b-256 | 293b38e6bce9c0ec7f6d56d23545e5fa0449e7f5872efd9578cbddaf46ab00e6 |
Hashes for overload_numpy-0.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm | Hash digest |
---|---|
SHA256 | 4c30b2ff0a78c9104b2199cf5cc8b054cc3e9c9913c89da2f27bcf796762bad6 |
MD5 | 92b36e013ceed8f17894d8d899ffe959 |
BLAKE2b-256 | fc083287cd105495278c0b0802fb6d4cfbb9d31bd029c810ca4cce645ef7a593 |
Hashes for overload_numpy-0.1.0-cp38-cp38-macosx_11_0_arm64.whl
Algorithm | Hash digest |
---|---|
SHA256 | ff1ec301954cc845ad06498afa052bb087f4d749b15b7fad963c29ec4f07db2a |
MD5 | 1e16e849ba70cb5bee02f177d11651db |
BLAKE2b-256 | 6aed3bb40ec70b39a3d6d4b8840ccbef66a8a0a3fc0cc9c8da868870ff85e352 |
Hashes for overload_numpy-0.1.0-cp38-cp38-macosx_10_9_x86_64.whl
Algorithm | Hash digest |
---|---|
SHA256 | 4193cd0dc3561682394134a2ddd00c9b360a0423da2bbc3b587ea684c999a21c |
MD5 | 33b2909489ca039f1e656912306f247f |
BLAKE2b-256 | b2563f8d84c7253b6949fbb2ba504d3437ae231fe1d337f151df8d002d9237a4 |
Hashes for overload_numpy-0.1.0-cp38-cp38-macosx_10_9_universal2.whl
Algorithm | Hash digest |
---|---|
SHA256 | 4e9e200cf2bb1bd318073a00c0f32afa1f8a6a021ad71f04d905f35028e9aab4 |
MD5 | 99ff8ea1911e652020c34accda895352 |
BLAKE2b-256 | 65527ff371444d81b7bbd8951ad1cb36fe0afe80ae63928ab476bee90357e941 |