Dataclasses + JAX
jax_dataclasses
Overview
`jax_dataclasses` provides a wrapper around `dataclasses.dataclass` for use in JAX, which enables automatic support for:
- Pytree registration. This allows dataclasses to be used at API boundaries in JAX (necessary for function transformations, JIT, etc.).
- Serialization via `flax.serialization`.
- Static analysis with tools like `mypy`, `jedi`, `pyright`, etc. (including for constructors).
- Optional shape and data-type annotations, which are checked at runtime.
Heavily influenced by some great existing work (the obvious one being `flax.struct.dataclass`); see Alternatives for comparisons.
Installation
The latest version of `jax_dataclasses` requires Python>=3.7. Python 3.6 will work as well, but is missing support for shape annotations.
```
pip install jax_dataclasses
```
Core interface
`jax_dataclasses` is meant to provide a drop-in replacement for `dataclasses.dataclass`:
- `jax_dataclasses.pytree_dataclass` has the same interface as `dataclasses.dataclass`, but also registers the target class as a pytree container.
- `jax_dataclasses.static_field` has the same interface as `dataclasses.field`, but will also mark the field as static. In a pytree node, static fields will be treated as part of the treedef instead of as a child of the node; all fields that are not explicitly marked static should contain arrays or child nodes.
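For intuition about what "registers the target class as a pytree container" and "static" mean, the sketch below registers an ordinary frozen dataclass by hand with `jax.tree_util.register_pytree_node`. It is a minimal illustration, not how `jax_dataclasses` is implemented; the `Params` class and its fields are hypothetical. The array field becomes a leaf, while the "static" `name` field is returned as auxiliary data and therefore lives in the treedef.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Params:  # hypothetical example class
    weights: jnp.ndarray  # child: a leaf that JAX traces
    name: str  # "static": kept in the treedef, never traced


def _flatten(p):
    # Return (children, aux_data); aux_data becomes part of the treedef.
    return (p.weights,), p.name


def _unflatten(name, children):
    (weights,) = children
    return Params(weights=weights, name=name)


jax.tree_util.register_pytree_node(Params, _flatten, _unflatten)

p = Params(weights=jnp.ones(3), name="layer0")
leaves, treedef = jax.tree_util.tree_flatten(p)
print(len(leaves))  # 1: only the array is a leaf


@jax.jit
def double(q: Params) -> Params:
    # tree_map only visits leaves, so `name` passes through untouched.
    return jax.tree_util.tree_map(lambda x: 2 * x, q)


doubled = double(p)
print(doubled.name, doubled.weights)
```

Because `name` is auxiliary data, changing it yields a different treedef, so under `jax.jit` it acts like a compile-time constant rather than a traced value.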
We also provide several aliases: `jax_dataclasses.[field, asdict, astuple, is_dataclass, replace]` are all identical to their counterparts in the standard `dataclasses` library.
Mutations
All dataclasses are automatically marked as frozen and thus immutable (even when no `frozen=` parameter is passed in). To make changes to nested structures easier, `jax_dataclasses.copy_and_mutate` (a) makes a copy of a pytree and (b) returns a context in which any of that copy's contained dataclasses are temporarily mutable:
```python
from jax import numpy as jnp
import jax_dataclasses


@jax_dataclasses.pytree_dataclass
class Node:
    child: jnp.ndarray


obj = Node(child=jnp.zeros(3))

with jax_dataclasses.copy_and_mutate(obj) as obj_updated:
    # Make mutations to the dataclass. This is primarily useful for nested
    # dataclasses.
    #
    # Also does input validation: if the treedef, leaf shapes, or dtypes of `obj`
    # and `obj_updated` don't match, an AssertionError will be raised.
    # This can be disabled with a `validate=False` argument.
    obj_updated.child = jnp.ones(3)

print(obj)  # The original is unchanged: `child` is still zeros.
print(obj_updated)  # The copy now holds ones.
```
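The two steps described above (copy, then temporarily unfreeze) can be illustrated without JAX. The sketch below is hypothetical: `FreezableMixin`, `copy_and_mutate_sketch`, and the `Inner`/`Outer` classes are invented for illustration, and it omits the pytree handling and the treedef/shape/dtype validation that `jax_dataclasses.copy_and_mutate` performs.

```python
import contextlib
import copy
import dataclasses
from typing import Iterator, TypeVar

T = TypeVar("T")


class FreezableMixin:
    """Hypothetical mixin: attribute writes are rejected once __post_init__ runs."""

    _frozen = False  # class-level default, so the generated __init__ can assign fields

    def __post_init__(self):
        object.__setattr__(self, "_frozen", True)

    def __setattr__(self, name, value):
        if self._frozen:
            raise dataclasses.FrozenInstanceError(f"cannot assign to field {name!r}")
        object.__setattr__(self, name, value)


def _set_frozen(obj, frozen: bool) -> None:
    # Toggle the flag on obj and, recursively, on any nested dataclass fields.
    if isinstance(obj, FreezableMixin):
        object.__setattr__(obj, "_frozen", frozen)
        for f in dataclasses.fields(obj):
            _set_frozen(getattr(obj, f.name), frozen)


@contextlib.contextmanager
def copy_and_mutate_sketch(obj: T) -> Iterator[T]:
    obj_copy = copy.deepcopy(obj)  # (a) copy: the original is never modified
    _set_frozen(obj_copy, False)  # (b) unfreeze the copy and its nested dataclasses
    try:
        yield obj_copy
    finally:
        _set_frozen(obj_copy, True)  # refreeze everything reachable on exit


@dataclasses.dataclass
class Inner(FreezableMixin):
    value: int


@dataclasses.dataclass
class Outer(FreezableMixin):
    inner: Inner
    scale: float


obj = Outer(inner=Inner(value=1), scale=2.0)
with copy_and_mutate_sketch(obj) as obj_updated:
    obj_updated.inner.value = 5  # nested write, no chain of replace() calls

print(obj.inner.value, obj_updated.inner.value)  # 1 5
```

Outside the `with` block, both `obj` and `obj_updated` reject writes again.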
Shape and data-type annotations
Subclassing from `jax_dataclasses.EnforcedAnnotationsMixin` enables automatic shape and data-type validation. Arrays contained within dataclasses are validated on instantiation, and a `.get_batch_axes()` method is exposed for grabbing any common prefixes to the shapes of contained arrays.
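The "common prefix" rule behind `.get_batch_axes()` can be sketched in plain NumPy: each array's annotated shape must match the trailing axes of its actual shape, and whatever leading axes remain must be identical across fields. The function `batch_axes_sketch` and its dictionary arguments are hypothetical, chosen to mirror the `MnistStruct` example in this section; the library's actual logic may differ.

```python
from typing import Dict, Tuple

import numpy as np


def batch_axes_sketch(
    arrays: Dict[str, np.ndarray], expected: Dict[str, Tuple[int, ...]]
) -> Tuple[int, ...]:
    """Return the shape prefix shared by all arrays, given each field's annotated
    trailing shape. Raises AssertionError on a suffix or prefix mismatch."""
    batch_axes = None
    for name, array in arrays.items():
        suffix = expected[name]
        n_batch = array.ndim - len(suffix)  # number of leading (batch) axes
        assert n_batch >= 0 and array.shape[n_batch:] == suffix, f"shape mismatch: {name}"
        prefix = array.shape[:n_batch]
        if batch_axes is None:
            batch_axes = prefix
        assert prefix == batch_axes, f"batch axis mismatch: {name}"
    return batch_axes if batch_axes is not None else ()


expected = {"image": (28, 28), "label": (10,)}
print(batch_axes_sketch({"image": np.zeros((28, 28)), "label": np.zeros((10,))}, expected))  # ()
print(batch_axes_sketch({"image": np.zeros((32, 28, 28)), "label": np.zeros((32, 10))}, expected))  # (32,)
```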
We can start by importing the standard `Annotated` type:

```python
# Python >=3.9
from typing import Annotated

# Backport for older Python versions (requires the typing_extensions package)
from typing_extensions import Annotated
```
We can then add shape annotations:

```python
@jax_dataclasses.pytree_dataclass
class MnistStruct(jax_dataclasses.EnforcedAnnotationsMixin):
    image: Annotated[
        jnp.ndarray,
        (28, 28),
    ]
    label: Annotated[
        jnp.ndarray,
        (10,),
    ]
```
Or data-type annotations:

```python
image: Annotated[
    jnp.ndarray,
    jnp.float32,
]
label: Annotated[
    jnp.ndarray,
    jnp.integer,
]
```
Or both (note that annotations are order-invariant):

```python
image: Annotated[
    jnp.ndarray,
    (28, 28),
    jnp.float32,
]
label: Annotated[
    jnp.ndarray,
    (10,),
    jnp.integer,
]
```
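Because the shape and dtype are stored as `Annotated` metadata, they can be recovered at runtime with `typing.get_type_hints(..., include_extras=True)`. The sketch below shows one way order invariance could work: a tuple is read as a shape and anything else as a dtype. `parse_annotations` and `Example` are hypothetical, the code needs Python >=3.9, and it is not the library's implementation.

```python
from typing import Annotated, get_args, get_origin, get_type_hints

import numpy as np


class Example:  # hypothetical; a plain class with annotations only
    image: Annotated[np.ndarray, (28, 28), np.float32]
    label: Annotated[np.ndarray, np.integer, (10,)]  # metadata order swapped on purpose


def parse_annotations(cls):
    """Map each Annotated field to (shape, dtype), independent of metadata order."""
    parsed = {}
    for name, hint in get_type_hints(cls, include_extras=True).items():
        if get_origin(hint) is not Annotated:
            continue  # field carries no shape/dtype metadata
        metadata = get_args(hint)[1:]  # drop the underlying type (np.ndarray)
        shape = next((m for m in metadata if isinstance(m, tuple)), None)
        dtype = next((m for m in metadata if not isinstance(m, tuple)), None)
        parsed[name] = (shape, dtype)
    return parsed


print(parse_annotations(Example))
```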
Then, assuming we've constrained both the shape and data-type:

```python
import numpy as onp

# OK
struct = MnistStruct(
    image=onp.zeros((28, 28), dtype=onp.float32),
    label=onp.zeros((10,), dtype=onp.uint8),
)
print(struct.get_batch_axes())  # Prints ()

# OK
struct = MnistStruct(
    image=onp.zeros((32, 28, 28), dtype=onp.float32),
    label=onp.zeros((32, 10), dtype=onp.uint8),
)
print(struct.get_batch_axes())  # Prints (32,)

# AssertionError on instantiation because of type mismatch
MnistStruct(
    image=onp.zeros((28, 28), dtype=onp.float32),
    label=onp.zeros((10,), dtype=onp.float32),  # Not an integer type!
)

# AssertionError on instantiation because of shape mismatch
MnistStruct(
    image=onp.zeros((28, 28), dtype=onp.float32),
    label=onp.zeros((5,), dtype=onp.uint8),
)

# AssertionError on instantiation because of batch axis mismatch
struct = MnistStruct(
    image=onp.zeros((64, 28, 28), dtype=onp.float32),
    label=onp.zeros((32, 10), dtype=onp.uint8),
)
```
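The dtype outcomes above are consistent with treating a dtype annotation as a dtype family. Assuming the check behaves like NumPy's `np.issubdtype` (an assumption; the README does not say how it is implemented) and that `jnp.integer` behaves like NumPy's abstract `np.integer`, a `uint8` label satisfies the annotation while a `float32` label does not. `dtype_matches` is a hypothetical helper.

```python
import numpy as np


def dtype_matches(array: np.ndarray, annotated) -> bool:
    """Hypothetical check: is the array's dtype the annotated dtype or a subtype of it?"""
    return bool(np.issubdtype(array.dtype, annotated))


print(dtype_matches(np.zeros((10,), dtype=np.uint8), np.integer))  # True
print(dtype_matches(np.zeros((10,), dtype=np.float32), np.integer))  # False
print(dtype_matches(np.zeros((28, 28), dtype=np.float32), np.float32))  # True
```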
Alternatives
A few other solutions exist for automatically integrating dataclass-style objects into pytree structures. Great ones include: `chex.dataclass`, `flax.struct`, and `tjax.dataclass`. These all influenced this library.
The main differentiators of `jax_dataclasses` are:
- Static analysis support. `tjax` has a custom mypy plugin to enable type checking, but isn't supported by other tools. `flax.struct` implements the `dataclass_transform` spec proposed by pyright, but isn't supported by other tools. Because `@jax_dataclasses.pytree_dataclass` has the same API as `@dataclasses.dataclass`, it can include pytree registration behavior at runtime while being treated as the standard decorator during static analysis. This means that all static checkers, language servers, and autocomplete engines that support the standard `dataclasses` library should work out of the box with `jax_dataclasses`.
- Nested dataclasses. Making replacements/modifications in deeply nested dataclasses can be really frustrating. The three alternatives all introduce a `.replace(self, ...)` method to dataclasses that's a bit more convenient than the traditional `dataclasses.replace(obj, ...)` API for shallow changes, but still becomes really cumbersome to use when dataclasses are nested. `jax_dataclasses.copy_and_mutate()` is introduced to address this.
- Static field support. Parameters that should not be traced in JAX should be marked as static. This is supported in `flax`, `tjax`, and `jax_dataclasses`, but not `chex`.
- Serialization. When working with `flax`, being able to serialize dataclasses is really handy. This is supported in `flax.struct` (naturally) and `jax_dataclasses`, but not `chex` or `tjax`.
- Shape and type annotations. See above.
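To make the nested-replacement point concrete, here is what changing one deeply nested field looks like with the standard `dataclasses.replace` API: every enclosing level has to be rebuilt by hand. The `Experiment`, `TrainState`, and `Optimizer` classes are hypothetical.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Optimizer:  # hypothetical
    learning_rate: float


@dataclasses.dataclass(frozen=True)
class TrainState:  # hypothetical
    optimizer: Optimizer
    step: int


@dataclasses.dataclass(frozen=True)
class Experiment:  # hypothetical
    train_state: TrainState
    seed: int


exp = Experiment(TrainState(Optimizer(learning_rate=1e-3), step=0), seed=0)

# Changing one leaf three levels down means calling replace() once per level.
exp2 = dataclasses.replace(
    exp,
    train_state=dataclasses.replace(
        exp.train_state,
        optimizer=dataclasses.replace(exp.train_state.optimizer, learning_rate=1e-4),
    ),
)
print(exp.train_state.optimizer.learning_rate, exp2.train_state.optimizer.learning_rate)
```

With `copy_and_mutate()` (see Mutations), the same change would instead be a single nested assignment on the copy inside the context.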
Misc
This code was originally written for and factored out of jaxfg, where Nick Heppert provided valuable feedback.
File details
Details for the file jax_dataclasses-1.2.1.tar.gz.
File metadata
- Download URL: jax_dataclasses-1.2.1.tar.gz
- Upload date:
- Size: 11.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.7.1 importlib_metadata/4.10.1 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.10.2
File hashes
Algorithm | Hash digest
---|---
SHA256 | 9637dc71bfc7be4e08ca8e9945ff3d4c3d6a1471eab3da2352b5efffbf3de725
MD5 | 336ec0a7dbe0e0bc5f288808b4ee1bb3
BLAKE2b-256 | cc8a7c3803099bbed6cc6bbf4dea1e0a8711c910062c3c8531b42ab9fe2077de
File details
Details for the file jax_dataclasses-1.2.1-py3-none-any.whl.
File metadata
- Download URL: jax_dataclasses-1.2.1-py3-none-any.whl
- Upload date:
- Size: 10.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.7.1 importlib_metadata/4.10.1 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.10.2
File hashes
Algorithm | Hash digest
---|---
SHA256 | cadf9583c864cb1d45a28551a6857e5b5ff9190bc407b961f99390e685c8c54a
MD5 | f316d205d32496193bc8dd92cdc2b067
BLAKE2b-256 | 0c59c682294ba1c593dfae7fb6f6d2976cf9715a67c629e20e4717310fa931f7