Dataclasses + JAX
- Core interface
- Static fields
- Shape and data-type annotations
jax_dataclasses provides a wrapper around dataclasses.dataclass for use in JAX, which enables automatic support for:
- Pytree registration. This allows dataclasses to be used at API boundaries in JAX (necessary for function transformations, JIT, etc.).
- Serialization via
- Static analysis with tools like pyright, etc. (including for constructors).
- Optional shape and data-type annotations, which are checked at runtime.
Heavily influenced by some great existing work (the obvious one being flax.struct.dataclass); see Alternatives for comparisons.
In Python >=3.7:

```bash
pip install jax_dataclasses
```
We can then import:

```python
import jax_dataclasses as jdc
```
jdc.pytree_dataclass is meant to be a drop-in replacement for dataclasses.dataclass: it has the same interface, but also registers the target class as a pytree node.
We also provide several aliases: jdc.[field, asdict, astuple, is_dataclass, replace] are all identical to their counterparts in the standard dataclasses library.
To mark a field as static (in this context: constant at compile-time), we can wrap its type with jdc.Static[]:
```python
import jax.numpy as jnp
import jax_dataclasses as jdc


@jdc.pytree_dataclass
class A:
    a: jnp.ndarray
    b: jdc.Static[bool]
```
In a pytree node, static fields will be treated as part of the treedef instead of as a child of the node; all fields that are not explicitly marked static should contain arrays or child nodes.
Experimental: in combination with
jdc.Static can also be used in function signatures.
All dataclasses are automatically marked as frozen and thus immutable (even when no frozen= parameter is passed in). To make changes to nested structures easier, jdc.copy_and_mutate (a) makes a copy of a pytree and (b) returns a context in which any of that copy's contained dataclasses are temporarily mutable:
```python
from jax import numpy as jnp

import jax_dataclasses as jdc


@jdc.pytree_dataclass
class Node:
    child: jnp.ndarray


obj = Node(child=jnp.zeros(3))

with jdc.copy_and_mutate(obj) as obj_updated:
    # Make mutations to the dataclass. This is primarily useful for nested
    # dataclasses.
    #
    # Also does input validation: if the treedef, leaf shapes, or dtypes of `obj`
    # and `obj_updated` don't match, an AssertionError will be raised.
    # This can be disabled with a `validate=False` argument.
    obj_updated.child = jnp.ones(3)

print(obj)
print(obj_updated)
```
Shape and data-type annotations
Subclassing jdc.EnforcedAnnotationsMixin enables automatic shape and data-type validation. Arrays contained within dataclasses are validated on instantiation, and a get_batch_axes() method is exposed for grabbing any batch axes common to the shapes of contained arrays.
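We can start by importing the standard Annotated type, which lives in typing on Python >=3.9 and in the typing_extensions backport on older versions:

```python
try:
    # Python >=3.9
    from typing import Annotated
except ImportError:
    # Backport
    from typing_extensions import Annotated
```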
We can then add shape annotations:
```python
import jax.numpy as jnp
import jax_dataclasses as jdc

# `Annotated` is imported as shown above.


@jdc.pytree_dataclass
class MnistStruct(jdc.EnforcedAnnotationsMixin):
    image: Annotated[
        jnp.ndarray,
        # Note that we can move the expected location of the batch axes by
        # shifting the ellipsis around.
        #
        # If the ellipsis is excluded, we assume batch axes at the start of the
        # shape.
        (..., 28, 28),
    ]
    label: Annotated[
        jnp.ndarray,
        (..., 10),
    ]
```
Or data-type annotations:
```python
image: Annotated[
    jnp.ndarray,
    jnp.float32,
]
label: Annotated[
    jnp.ndarray,
    jnp.integer,
]
```
Or both (note that annotations are order-invariant):
```python
image: Annotated[
    jnp.ndarray,
    (..., 28, 28),
    jnp.float32,
]
label: Annotated[
    jnp.ndarray,
    (..., 10),
    jnp.integer,
]
```
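To illustrate why the order of the metadata doesn't matter, here is a hypothetical sketch (not the library's implementation) built only on the standard typing and numpy APIs: tuples in the Annotated metadata are read as shape patterns, anything else as a dtype, and dtypes are checked against numpy's type hierarchy. MnistLike, parse_annotation, and check_dtypes are made-up names, and the sketch assumes Python >=3.9.

```python
import dataclasses
from typing import Annotated, Any, Optional, Tuple, get_args, get_origin, get_type_hints

import numpy as np


@dataclasses.dataclass
class MnistLike:
    # Hypothetical class; the metadata order differs between the two fields on purpose.
    image: Annotated[np.ndarray, (..., 28, 28), np.float32]
    label: Annotated[np.ndarray, np.integer, (..., 10)]


def parse_annotation(hint: Any) -> Tuple[Optional[tuple], Any]:
    """Split Annotated metadata into (expected_shape, expected_dtype), regardless of order."""
    if get_origin(hint) is not Annotated:
        return None, None
    shape, dtype = None, None
    for meta in get_args(hint)[1:]:  # get_args(hint)[0] is the underlying type
        if isinstance(meta, tuple):
            shape = meta  # shape patterns are tuples
        else:
            dtype = meta  # anything else is treated as a dtype
    return shape, dtype


def check_dtypes(obj: Any) -> None:
    """Raise AssertionError if any annotated field holds an incompatible dtype."""
    hints = get_type_hints(type(obj), include_extras=True)
    for field in dataclasses.fields(obj):
        _, dtype = parse_annotation(hints[field.name])
        value = getattr(obj, field.name)
        # np.issubdtype walks numpy's type hierarchy, so np.integer accepts uint8, int32, ...
        if dtype is not None and not np.issubdtype(value.dtype, dtype):
            raise AssertionError(f"{field.name}: expected {dtype}, got {value.dtype}")
```

Checking with np.issubdtype rather than equality is what lets an abstract annotation such as an integer type accept any concrete integer dtype.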
Then, assuming we've constrained both the shape and data-type:
```python
import numpy as onp

# OK
struct = MnistStruct(
    image=onp.zeros((28, 28), dtype=onp.float32),
    label=onp.zeros((10,), dtype=onp.uint8),
)
print(struct.get_batch_axes())  # Prints ()

# OK
struct = MnistStruct(
    image=onp.zeros((32, 28, 28), dtype=onp.float32),
    label=onp.zeros((32, 10), dtype=onp.uint8),
)
print(struct.get_batch_axes())  # Prints (32,)

# AssertionError on instantiation because of type mismatch
MnistStruct(
    image=onp.zeros((28, 28), dtype=onp.float32),
    label=onp.zeros((10,), dtype=onp.float32),  # Not an integer type!
)

# AssertionError on instantiation because of shape mismatch
MnistStruct(
    image=onp.zeros((28, 28), dtype=onp.float32),
    label=onp.zeros((5,), dtype=onp.uint8),
)

# AssertionError on instantiation because of batch axis mismatch
struct = MnistStruct(
    image=onp.zeros((64, 28, 28), dtype=onp.float32),
    label=onp.zeros((32, 10), dtype=onp.uint8),
)
```
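The batch-axis behavior in these examples can be sketched in plain Python. This is a hypothetical illustration, not how jax_dataclasses implements get_batch_axes(): whatever the ellipsis matches in each shape is treated as that field's batch axes, a pattern without an ellipsis gets one prepended (batch axes at the start), and all fields must agree. The helpers batch_axes_for and common_batch_axes are made-up names.

```python
from typing import Any, Iterable, Optional, Tuple


def batch_axes_for(shape: Tuple[int, ...], expected: Tuple[Any, ...]) -> Optional[Tuple[int, ...]]:
    """Return the batch axes of `shape` under the `expected` pattern, or None on mismatch."""
    expected = tuple(expected)
    if Ellipsis not in expected:
        # No ellipsis: assume the batch axes are at the start of the shape.
        expected = (Ellipsis,) + expected
    split = expected.index(Ellipsis)
    prefix, suffix = expected[:split], expected[split + 1 :]
    if len(shape) < len(prefix) + len(suffix):
        return None
    end = len(shape) - len(suffix)
    if tuple(shape[: len(prefix)]) != prefix or tuple(shape[end:]) != suffix:
        return None
    # Whatever the ellipsis matched is treated as the batch axes.
    return tuple(shape[len(prefix) : end])


def common_batch_axes(fields: Iterable[Tuple[Tuple[int, ...], Tuple[Any, ...]]]) -> Tuple[int, ...]:
    """Check every (shape, expected) pair and require that all batch axes agree."""
    batch: Optional[Tuple[int, ...]] = None
    for shape, expected in fields:
        axes = batch_axes_for(shape, expected)
        if axes is None:
            raise AssertionError(f"shape {shape} does not match {expected}")
        if batch is not None and axes != batch:
            raise AssertionError(f"batch axes {axes} do not match {batch}")
        batch = axes
    return batch if batch is not None else ()
```

Applied to the (shape, pattern) pairs from the examples above, the unbatched case gives (), the batch of 32 gives (32,), and the 64-versus-32 and (5,)-versus-(..., 10) cases raise.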
A few other solutions exist for automatically integrating dataclass-style
objects into pytree structures. Great ones include:
tjax.dataclass. These all influenced
The main differentiators of
- Static analysis support. tjax has a custom mypy plugin to enable type checking, but isn't supported by other tools. dataclass_transform spec proposed by pyright, but isn't supported by other tools. Because @jdc.pytree_dataclass has the same API as @dataclasses.dataclass, it can include pytree registration behavior at runtime while being treated as the standard decorator during static analysis. This means that all static checkers, language servers, and autocomplete engines that support the standard dataclasses library should work out of the box with
- Nested dataclasses. Making replacements/modifications in deeply nested dataclasses can be really frustrating. The three alternatives all introduce a .replace(self, ...) method to dataclasses that's a bit more convenient than the traditional dataclasses.replace(obj, ...) API for shallow changes, but still becomes really cumbersome to use when dataclasses are nested. jdc.copy_and_mutate() is introduced to address this.
- Static field support. Parameters that should not be traced in JAX should be marked as static. This is supported in jax_dataclasses, but not
- Serialization. When working with flax, being able to serialize dataclasses is really handy. This is supported in jax_dataclasses, but not
- Shape and type annotations. See above.
You can also eschew the dataclass-style interface entirely;
see how brax registers pytrees.
This is a reasonable thing to prefer: it requires some floating strings and
breaks things that I care about but you may not (like immutability and
__post_init__), but gives more flexibility with custom
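For comparison, here is a generic sketch of registering a plain class by hand with jax.tree_util.register_pytree_node. It is not brax's code, and Params and its fields are hypothetical. The flatten function has to list the children explicitly, a non-array attribute goes into the auxiliary data (playing the role a static field plays above), and nothing makes the class immutable.

```python
import jax
import jax.numpy as jnp


class Params:
    """Hypothetical plain class registered as a pytree by hand."""

    def __init__(self, weight, bias, name: str):
        self.weight = weight
        self.bias = bias
        self.name = name  # stored as aux data in the treedef, like a static field


def _flatten(p: Params):
    # Children are traced/transformed; aux data should be hashable with equality
    # defined, since jit compares treedefs when caching compiled functions.
    return (p.weight, p.bias), p.name


def _unflatten(name: str, children) -> Params:
    weight, bias = children
    return Params(weight, bias, name)


jax.tree_util.register_pytree_node(Params, _flatten, _unflatten)

p = Params(jnp.ones((2, 2)), jnp.zeros(2), name="layer0")
doubled = jax.tree_util.tree_map(lambda x: 2 * x, p)


@jax.jit
def total(params: Params) -> jnp.ndarray:
    return params.weight.sum() + params.bias.sum()
```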