Dataclasses that behave like numpy arrays (with indexing, slicing, vectorization).
Dataclass Array
DataclassArray are dataclasses that behave like numpy arrays (they can be batched, reshaped, sliced, ...) and are compatible with Jax, TensorFlow, and numpy (with torch support planned).
Documentation
To create a dca.DataclassArray, take a frozen dataclass and:
- Inherit from dca.DataclassArray
- Annotate the fields with etils.array_types to specify the inner shape and dtype of the array (see below for static or nested dataclass fields).
import dataclasses

import dataclass_array as dca
from etils.array_types import FloatArray

@dataclasses.dataclass(frozen=True)
class Ray(dca.DataclassArray):
    pos: FloatArray['*batch_shape 3']
    dir: FloatArray['*batch_shape 3']
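The annotation '*batch_shape 3' says that the trailing dimension (3) is the inner shape of one element and everything in front of it is the batch shape. As a rough illustration of that split, here is a minimal pure-Python sketch; `infer_batch_shape` is a hypothetical helper written for this example and is not part of dataclass_array.

```python
def infer_batch_shape(array_shape, inner_shape):
    """Return the leading batch dims of `array_shape` given the trailing `inner_shape`.

    Hypothetical helper for illustration only (not a dataclass_array API).
    """
    array_shape = tuple(array_shape)
    inner_shape = tuple(inner_shape)
    split = len(array_shape) - len(inner_shape)
    # The trailing dims of the array must match the declared inner shape.
    if split < 0 or array_shape[split:] != inner_shape:
        raise ValueError(
            f"Shape {array_shape} does not end with inner shape {inner_shape}"
        )
    return array_shape[:split]


print(infer_batch_shape((3, 3), (3,)))     # (3,)
print(infer_batch_shape((4, 5, 3), (3,)))  # (4, 5)
print(infer_batch_shape((3,), (3,)))       # ()
```

Under this reading, a field of shape (3, 3) annotated with '*batch_shape 3' holds 3 elements, each a vector of size 3.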
Afterwards, the dataclass can be used as a numpy array:
import jax
import jax.numpy as jnp

ray = Ray(pos=jnp.zeros((3, 3)), dir=jnp.eye(3))

ray.shape == (3,)  # 3 rays batched together
ray.pos.shape == (3, 3)  # Individual fields still available

# Each statement below illustrates one operation; they are not meant to be
# run in sequence on the same `ray`.

# Numpy slicing/indexing/masking
ray = ray[..., 1:2]
ray = ray[jnp.linalg.norm(ray.dir, axis=-1) > 1e-7]

# Shape transformation
ray = ray.reshape((1, 3))
ray = ray.reshape('h w -> w h')  # Native einops support
ray = ray.flatten()

# Stack multiple dataclass arrays together
ray = dca.stack([ray0, ray1, ...])

# Supports TF, Jax, Numpy (torch planned) and can be easily converted
ray = ray.as_jax()  # as_np(), as_tf()
ray.xnp == jax.numpy  # `numpy`, `jax.numpy`, `tf.experimental.numpy`

# Compatible with `jax.tree_util`, `jax.vmap`, ...
ray = jax.tree_util.tree_map(lambda x: x+1, ray)
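To make the batch semantics concrete, here is a conceptual sketch in plain numpy of what indexing and reshaping "on the batch dimensions only" could look like. `PlainRay`, `INNER_SHAPE`, and `_map` are hypothetical names invented for this example; this is not how dataclass_array is implemented, only an illustration of the behaviour described above.

```python
import dataclasses

import numpy as np

INNER_SHAPE = (3,)  # Assumed per-element shape shared by both fields


@dataclasses.dataclass(frozen=True)
class PlainRay:
    """Hypothetical stand-in for `Ray`, built on plain numpy arrays."""
    pos: np.ndarray  # shape: (*batch_shape, 3)
    dir: np.ndarray  # shape: (*batch_shape, 3)

    @property
    def shape(self):
        # Batch shape = field shape without the trailing inner dims.
        return self.pos.shape[:-len(INNER_SHAPE)]

    def _map(self, fn):
        # Apply `fn` to every field and rebuild the dataclass.
        return PlainRay(**{
            f.name: fn(getattr(self, f.name)) for f in dataclasses.fields(self)
        })

    def __getitem__(self, index):
        index = index if isinstance(index, tuple) else (index,)
        # With `...`, keep the inner dims untouched by appending full slices.
        if any(i is Ellipsis for i in index):
            index = index + (slice(None),) * len(INNER_SHAPE)
        return self._map(lambda x: x[index])

    def reshape(self, batch_shape):
        # Only the batch dims change; the inner shape is re-appended.
        return self._map(lambda x: x.reshape(tuple(batch_shape) + INNER_SHAPE))


ray = PlainRay(pos=np.zeros((3, 3)), dir=np.eye(3))
sliced = ray[..., 1:2]                                  # batch shape (1,)
masked = ray[np.linalg.norm(ray.dir, axis=-1) > 1e-7]   # batch shape (3,)
reshaped = ray.reshape((1, 3))                          # fields: (1, 3, 3)
```

The point of the sketch is that every operation is forwarded to each field while the trailing inner dimension is left alone, which is why `ray.shape` reports only the batch dimensions.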
A DataclassArray has 2 types of fields:
- Array fields: Fields batched like numpy arrays, with reshape, slicing, ... Can be xnp.ndarray or nested dca.DataclassArray.
- Static fields: Other non-numpy fields. They are not modified by reshaping, ... Static fields are also ignored in jax.tree_map.
from typing import Any

import numpy as np

@dataclasses.dataclass(frozen=True)
class MyArray(dca.DataclassArray):
    # Array fields
    a: FloatArray['*batch_shape 3']  # Defined by `etils.array_types`
    b: Ray  # Nested DataclassArray (inner shape == `()`)

    # Array fields explicitly defined
    c: Any = dca.field(shape=(3,), dtype=np.float32)
    d: Ray = dca.field(shape=(3,), dtype=Ray)  # Nested DataclassArray

    # Static field (everything not defined as above)
    e: float
    f: np.array
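The array/static distinction can be pictured as a mapping that visits only the array fields and passes static fields through unchanged. The sketch below shows that idea with the standard library and numpy; `array_field`, `map_array_fields`, and `Sample` are hypothetical names for this example, and the metadata flag is an assumption made for illustration, not the mechanism dataclass_array uses.

```python
import dataclasses

import numpy as np


def array_field():
    # Hypothetical marker: tags a dataclass field as an array field.
    return dataclasses.field(metadata={"is_array": True})


@dataclasses.dataclass(frozen=True)
class Sample:
    values: np.ndarray = array_field()  # Array field (transformed)
    label: str = "unnamed"              # Static field (left untouched)


def map_array_fields(fn, obj):
    """Apply `fn` to array fields only, keeping static fields as-is."""
    updates = {
        f.name: fn(getattr(obj, f.name))
        for f in dataclasses.fields(obj)
        if f.metadata.get("is_array", False)
    }
    return dataclasses.replace(obj, **updates)


sample = Sample(values=np.zeros((2, 3)), label="a")
shifted = map_array_fields(lambda x: x + 1, sample)
print(shifted.label)         # a
print(shifted.values.shape)  # (2, 3)
```

This mirrors the statement above that static fields are not modified by reshaping and are ignored by tree mapping: only `values` is passed to the function, while `label` is carried over as metadata.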
Installation
pip install dataclass_array
This is not an official Google product.