Named dimensions of nested arrays, independent of array implementation (NumPy, JAX, etc.)
spekk is a tool for working with named dimensions for arrays. spekk lets you declare specifications of the shapes of your arrays.
A common problem with multidimensional arrays is that it can be hard to keep track of the dimensions of the data over time. Additionally, dimensions may be "shared" across different arrays. For example, two different arguments to a function may share a dimension, like the positions of receiving elements in an ultrasound array and their corresponding recorded signals.
spekk attempts to solve this by providing a way to declare the dimensions of arrays using a class called Spec:
import numpy as np
from spekk import Spec

# Anything that is a collection of key-value pairs can be used as a container of arrays:
spec = Spec({
    "receiving_element": {  # <- You could also use your own custom class instead of dict
        "position": ["receivers", "xyz"],
        "weight": ["receivers"],
    },
    "signal": ["transmits", "receivers", "samples"],
})
# my_fn = lambda signal, receiving_element: ...
spekk exists independently of the underlying arrays and can thus be used to specify the dimensions of both NumPy and JAX arrays (or anything else that has a shape property). This is useful when working with code that needs to support multiple array backends.
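To make the idea of shared dimensions concrete, here is a minimal sketch in plain Python and NumPy. It does not use spekk itself: dim_sizes is a hypothetical helper written for illustration, and plain_spec is an ordinary nested dict rather than a Spec object. It walks the names and the data in parallel, relies only on each leaf having a shape attribute, and checks that every dimension name has one consistent size.

```python
import numpy as np


def dim_sizes(spec, data, sizes=None):
    """Collect {dimension name: size} from a nested spec/data pair.

    Hypothetical helper, not part of spekk. `spec` is a plain nested dict
    whose leaves are lists of dimension names; `data` has the same structure
    with leaves that expose a `.shape` attribute.
    """
    sizes = {} if sizes is None else sizes
    if isinstance(spec, dict):
        for key, sub_spec in spec.items():
            dim_sizes(sub_spec, data[key], sizes)
        return sizes
    shape = tuple(data.shape)
    if len(spec) != len(shape):
        raise ValueError(f"{spec} names {len(spec)} dimensions, data has {len(shape)}")
    for name, size in zip(spec, shape):
        # A name seen before must keep the size it had the first time.
        if sizes.setdefault(name, size) != size:
            raise ValueError(f"dimension {name!r} has sizes {sizes[name]} and {size}")
    return sizes


plain_spec = {
    "receiving_element": {
        "position": ["receivers", "xyz"],
        "weight": ["receivers"],
    },
    "signal": ["transmits", "receivers", "samples"],
}
data = {
    "receiving_element": {
        "position": np.ones([5, 3]),
        "weight": np.ones([5]),
    },
    "signal": np.ones([4, 5, 6]),
}
sizes = dim_sizes(plain_spec, data)
```

Because the helper only reads the shape attribute, the same check would apply to any array type that provides one; "receivers" appears in three arrays here and must have size 5 in all of them.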
See an overview of spekk below. Read the documentation for more details.
Installation
python3 -m pip install spekk
Overview
spekk lets you name the dimensions of an array as a sequence of strings:
import numpy as np
from spekk import Spec
data = np.ones([4, 5, 6]) # A 3D array with shape [4, 5, 6].
spec = Spec(["transmits", "receivers", "samples"]) # <- The names of the dimensions,
# each name corresponding to an
# axis in the array ([4, 5, 6]).
It also lets you specify the dimensions of nested data structures of arrays:
data = {
    "receiving_element": {
        "position": np.ones([5, 3]),
        "weight": np.ones([5]),
    },
    "signal": np.ones([4, 5, 6]),
}
# Note that the structure is the same as the data:
spec = Spec({
    "receiving_element": {
        "position": ["receivers", "xyz"],
        "weight": ["receivers"],
    },
    "signal": ["transmits", "receivers", "samples"],
})
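One thing the mirrored structure makes possible is looking up an axis by name instead of by position. The sketch below assumes a plain nested dict of names (not a Spec object) and a hypothetical helper, reduce_over, that is not part of spekk. It sums every array over the "receivers" dimension, wherever that dimension happens to sit in each array.

```python
import numpy as np


def reduce_over(spec, data, dim, reducer=np.sum):
    """Apply `reducer` along the axis named `dim` in every leaf that has it.

    Hypothetical helper, not part of spekk. `spec` is a plain nested dict of
    dimension-name lists with the same structure as `data`.
    """
    if isinstance(spec, dict):
        return {key: reduce_over(spec[key], data[key], dim, reducer) for key in spec}
    if dim in spec:
        # The position of the name in the list is the axis index in the array.
        return reducer(data, axis=spec.index(dim))
    return data


plain_spec = {
    "receiving_element": {
        "position": ["receivers", "xyz"],
        "weight": ["receivers"],
    },
    "signal": ["transmits", "receivers", "samples"],
}
data = {
    "receiving_element": {
        "position": np.ones([5, 3]),
        "weight": np.ones([5]),
    },
    "signal": np.ones([4, 5, 6]),
}
reduced = reduce_over(plain_spec, data, "receivers")
```

Here "receivers" is axis 0 of position and weight but axis 1 of signal, and the caller never has to know that.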
You can specify what happens to the spec of the data when you apply a function to it:
from spekk.transformations import Specced

def f(x, y, c):
    """Return a dictionary of a circle and two hyperbolas (one for each axis) evaluated
    at point (x, y) with radius/axis-width c."""
    return {
        "circle": x**2 + y**2 - c**2,
        "hyperbola": [
            x**2 - y**2 - c**2,
            x**2 - y**2 + c**2
        ],
    }

# Ignore input_spec and just return the output_spec:
specced_f = Specced(f, lambda input_spec: {"circle": [], "hyperbola": ["axes"]})
specced_f = specced_f.build(spec)  # <- Let the function know about the spec
assert f(x=1, y=2, c=3) == specced_f(x=1, y=2, c=3)
assert specced_f.output_spec == Spec({"circle": [], "hyperbola": ["axes"]})
You can describe what happens to the spec of a function when you transform the function, for example when transforming it to loop over the arguments:
from spekk.transformations import ForAll, compose
from spekk.util import shape
# The following spec represents the input kwargs to the function f:
spec = Spec({"x": ["x-values"], "y": ["y-values"], "c": ["c-values"]})
tf = compose(
    specced_f,
    ForAll("y-values"),  # Run it for all the y-values (all rows)
    ForAll("x-values"),  # Run that for all the x-values (all columns)
    ForAll("c-values"),  # And then run that for all values of c
).build(spec)  # <- Building the transformed function lets it know the spec of the data
               # so that it also knows how to loop over it.
result = tf(x=np.linspace(-5, 5, 10), y=np.linspace(-5, 5, 11), c=np.arange(1, 6))
assert shape(result["circle"]) == (5, 10, 11)
assert tf.output_spec == Spec({
    "circle": ["c-values", "x-values", "y-values"],
    "hyperbola": ["c-values", "x-values", "y-values", "axes"],
})
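For intuition about the resulting dimension order, the composed loops can be spelled out by hand. The following is a plain NumPy sketch of the same nesting (innermost over y, then x, outermost over c). It does not use spekk and says nothing about how ForAll is implemented internally; it only shows why the output dimensions come out as c-values, x-values, y-values, with "axes" last for the hyperbolas.

```python
import numpy as np


def f(x, y, c):
    """Same function as above, repeated so the sketch is self-contained."""
    return {
        "circle": x**2 + y**2 - c**2,
        "hyperbola": [x**2 - y**2 - c**2, x**2 - y**2 + c**2],
    }


xs = np.linspace(-5, 5, 10)
ys = np.linspace(-5, 5, 11)
cs = np.arange(1, 6)

# The outermost comprehension becomes the first axis, the innermost the last.
circle = np.array([[[f(x, y, c)["circle"] for y in ys] for x in xs] for c in cs])
hyperbola = np.array([[[f(x, y, c)["hyperbola"] for y in ys] for x in xs] for c in cs])

assert circle.shape == (5, 10, 11)        # c-values, x-values, y-values
assert hyperbola.shape == (5, 10, 11, 2)  # c-values, x-values, y-values, axes
```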
You may use more powerful frameworks when transforming functions using ForAll:
from functools import partial
import jax
# Use JAX's vmap to vectorize the function in order to run in parallel on GPUs:
ForAll_jax = partial(ForAll, vmap_impl=jax.vmap)
In most cases, NumPy broadcasting will be enough to get the desired result when working with multidimensional data. However, broadcasting can sometimes be difficult or inefficient to get right, and it can be hard to keep track of the dimensions of the arrays over time. ForAll makes it easier to write code that loops over arbitrary dimensions and, if used in conjunction with for example JAX and jax.vmap, it can be very efficient as well.
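As a point of comparison, this is roughly what the broadcasting route looks like for the circle expression when written by hand in NumPy. It is an illustrative sketch only (the variable names are made up here, and spekk is not involved): each input has to be placed on its own axis manually, and the axis order lives only in the comments.

```python
import numpy as np

xs = np.linspace(-5, 5, 10)
ys = np.linspace(-5, 5, 11)
cs = np.arange(1, 6)

# Manual axis bookkeeping: c -> axis 0, x -> axis 1, y -> axis 2.
broadcast = xs[None, :, None] ** 2 + ys[None, None, :] ** 2 - cs[:, None, None] ** 2

# Reference result from explicit loops in the same (c, x, y) order.
looped = np.array([[[x**2 + y**2 - c**2 for y in ys] for x in xs] for c in cs])

assert broadcast.shape == (5, 10, 11)
assert np.allclose(broadcast, looped)
```

The broadcast version is compact, but changing the order of the dimensions means editing every index expression, which is the kind of bookkeeping that naming the dimensions is meant to avoid.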
Read the documentation for more details.