Dataclasses + JAX
jax_dataclasses
Library for using dataclasses as JAX PyTrees.
Key features:
- PyTree registration; automatic generation of flatten/unflatten ops.
- Static analysis-friendly. Works out of the box with tools like `mypy` and `jedi`.
- Support for serialization via `flax.serialization`.
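Registering a class as a PyTree means handing JAX two functions: one that splits an instance into its children plus auxiliary data, and one that rebuilds the instance from them. The pure-Python sketch below shows how such a pair can be derived from `dataclasses.fields()` alone. It is an illustration of the idea, not this library's code; the helper name `make_flatten_unflatten` and the `Point` class are hypothetical, and it assumes every field is an `__init__` argument.

```python
import dataclasses
from typing import Any, Callable, Tuple


def make_flatten_unflatten(cls: type) -> Tuple[Callable, Callable]:
    """Hypothetical helper: derive flatten/unflatten ops from a dataclass's fields.

    Assumes every field is accepted by the generated __init__ (init=True).
    """
    names = tuple(f.name for f in dataclasses.fields(cls))

    def flatten(obj: Any) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]:
        # Children are the field values; the field names act as auxiliary data.
        children = tuple(getattr(obj, name) for name in names)
        return children, names

    def unflatten(aux_data: Tuple[str, ...], children: Tuple[Any, ...]) -> Any:
        # Rebuild the instance by pairing field names back up with values.
        return cls(**dict(zip(aux_data, children)))

    return flatten, unflatten


@dataclasses.dataclass(frozen=True)
class Point:
    x: float
    y: float


flatten, unflatten = make_flatten_unflatten(Point)
children, aux = flatten(Point(1.0, 2.0))
print(children, aux)  # (1.0, 2.0) ('x', 'y')
print(unflatten(aux, children) == Point(1.0, 2.0))  # True
```

With JAX installed, a pair shaped like this could be passed to `jax.tree_util.register_pytree_node(cls, flatten, unflatten)`, which expects exactly this `(children, aux_data)` / `(aux_data, children)` convention.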
Usage
Basic
`jax_dataclasses` is meant to be a drop-in replacement for `dataclasses.dataclass`:
- `jax_dataclasses.dataclass` has the same interface as `dataclasses.dataclass`, but also registers the class as a PyTree.
- `jax_dataclasses.static_field` has the same interface as `dataclasses.field`, but also marks the field as static. In a PyTree node, static fields are treated as part of the treedef instead of as children of the node.
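To make the split between static fields and children concrete, here is a pure-Python sketch. It assumes static fields are tagged through `dataclasses.field(metadata=...)`; the `"static"` metadata key and the helpers `static_field_sketch` and `split` are hypothetical stand-ins, not this library's internals. Two instances that differ only in a child value share a treedef, while changing a static value changes the treedef itself (under `jax.jit`, a different treedef typically means a recompilation).

```python
import dataclasses
from typing import Any, Tuple


def static_field_sketch(**kwargs: Any) -> Any:
    """Hypothetical stand-in for static_field(): tags the field via metadata."""
    return dataclasses.field(metadata={"static": True}, **kwargs)


def split(obj: Any) -> Tuple[Tuple[Any, ...], Any]:
    """Hypothetical helper returning (children, treedef) for a dataclass instance.

    Static field values are stored inside the treedef; all other values
    become children.
    """
    child_names, children, static = [], [], []
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        if f.metadata.get("static", False):
            static.append((f.name, value))
        else:
            child_names.append(f.name)
            children.append(value)
    treedef = (type(obj), tuple(child_names), tuple(static))
    return tuple(children), treedef


@dataclasses.dataclass(frozen=True)
class Layer:
    weight: float
    activation: str = static_field_sketch(default="relu")


children_a, treedef_a = split(Layer(weight=0.5))
_, treedef_b = split(Layer(weight=0.9))
_, treedef_c = split(Layer(weight=0.5, activation="tanh"))

print(treedef_a == treedef_b)  # True: only a child value differs
print(treedef_a == treedef_c)  # False: a static value differs
```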
We also provide several aliases:
`jax_dataclasses.[field, asdict, astuple, is_dataclass, replace]` are all
identical to their counterparts in the standard `dataclasses` library.
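Since these aliases are stated to be identical to the standard-library functions, their behavior can be illustrated with `dataclasses` alone. The `Config` class below is made up for illustration.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Config:
    learning_rate: float
    steps: int = dataclasses.field(default=100)


cfg = Config(learning_rate=1e-3)

# replace() builds a new instance; the frozen original is left untouched.
longer = dataclasses.replace(cfg, steps=200)

print(dataclasses.asdict(longer))     # {'learning_rate': 0.001, 'steps': 200}
print(dataclasses.astuple(longer))    # (0.001, 200)
print(dataclasses.is_dataclass(cfg))  # True
print(cfg.steps)                      # 100
```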
Mutations
All dataclasses are automatically marked as frozen and thus immutable. We do, however, provide an interface that will (a) make a copy of a PyTree and (b) return a context in which any of that copy's contained dataclasses are temporarily mutable:
```python
from jax import numpy as jnp
import jax_dataclasses

@jax_dataclasses.dataclass
class Node:
    child: jnp.ndarray

obj = Node(child=jnp.zeros(3))

with jax_dataclasses.copy_and_mutate(obj) as obj_updated:
    # Make mutations to the dataclass.
    # Also does input validation: if the treedef of `obj` and `obj_updated` don't
    # match, an AssertionError will be raised.
    obj_updated.child = jnp.ones(3)

print(obj)
print(obj_updated)
```
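How `copy_and_mutate()` works internally isn't described here. The following pure-Python sketch reproduces the documented behavior under stated assumptions: it deep-copies the input, temporarily swaps the frozen `__setattr__`/`__delattr__` of every frozen dataclass type it finds for `object`'s versions, and afterwards compares a rough stand-in for the treedef, raising `AssertionError` on a mismatch. The names `copy_and_mutate_sketch`, `_structure`, and `_frozen_types` are hypothetical. Because the patch is applied at the class level, *all* instances of those classes are mutable while the block runs, and only dataclasses, lists, and tuples are traversed.

```python
import contextlib
import copy
import dataclasses
from typing import Any, Iterator, Set


def _is_instance(obj: Any) -> bool:
    # is_dataclass() is also True for dataclass *types*; only walk instances.
    return dataclasses.is_dataclass(obj) and not isinstance(obj, type)


def _structure(obj: Any) -> Any:
    """Rough treedef stand-in: container types and field names, leaves as '*'."""
    if _is_instance(obj):
        return (type(obj), tuple((f.name, _structure(getattr(obj, f.name)))
                                 for f in dataclasses.fields(obj)))
    if isinstance(obj, (list, tuple)):
        return (type(obj), tuple(_structure(x) for x in obj))
    return "*"


def _frozen_types(obj: Any, found: Set[type]) -> Set[type]:
    """Collect every frozen dataclass type reachable from obj."""
    if _is_instance(obj):
        # __dataclass_params__ is a CPython dataclasses implementation detail.
        if type(obj).__dataclass_params__.frozen:
            found.add(type(obj))
        for f in dataclasses.fields(obj):
            _frozen_types(getattr(obj, f.name), found)
    elif isinstance(obj, (list, tuple)):
        for x in obj:
            _frozen_types(x, found)
    return found


@contextlib.contextmanager
def copy_and_mutate_sketch(obj: Any) -> Iterator[Any]:
    updated = copy.deepcopy(obj)
    types = _frozen_types(updated, set())
    saved = {t: (vars(t)["__setattr__"], vars(t)["__delattr__"]) for t in types}
    for t in types:
        t.__setattr__ = object.__setattr__
        t.__delattr__ = object.__delattr__
    try:
        yield updated
    finally:
        # Restore frozen behavior even if the with-block raised.
        for t, (setter, deleter) in saved.items():
            t.__setattr__ = setter
            t.__delattr__ = deleter
    # Only reached when the with-block finished without an exception.
    assert _structure(updated) == _structure(obj), "treedef mismatch"


@dataclasses.dataclass(frozen=True)
class Inner:
    value: int


@dataclasses.dataclass(frozen=True)
class Outer:
    inner: Inner
    items: list


original = Outer(inner=Inner(value=1), items=[1, 2])
with copy_and_mutate_sketch(original) as updated:
    updated.inner.value = 5  # allowed only inside the block
print(original.inner.value, updated.inner.value)  # 1 5
```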
Motivation
For compatibility with function transformations in JAX (jit, grad, vmap, etc.), arguments and return values must all be PyTree containers. Dataclasses, by default, are not.
A few great solutions exist for automatically integrating dataclass-style
objects into PyTree structures, notably
chex.dataclass and
flax.struct. This library implements another
one.
Why not use chex.dataclass?
chex.dataclass is handy and lightweight, but currently lacks support for:
- Static fields: parameters that are either non-differentiable or simply not arrays.
- Serialization using `flax.serialization`. This is really handy when parameters need to be saved to disk!
Why not use flax.struct?
flax.struct addresses the two points above, but both it and chex.dataclass:
- Lack support for static analysis and type-checking. Static analysis for libraries like `dataclasses` and `attrs` tends to rely on tooling-specific custom plugins, which don't exist for either `chex.dataclass` or `flax.struct`.
- Make modifying deeply nested dataclasses fairly frustrating. Both introduce a `.replace(self, ...)` method to dataclasses that's a bit more convenient than the traditional `dataclasses.replace(obj, ...)` API, but this becomes really cumbersome to use when dataclasses are nested. Fixing this is the goal of `jax_dataclasses.copy_and_mutate()`.
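To illustrate the nesting problem, here is the standard-library route for changing one deeply nested value in frozen dataclasses; the class names are made up for illustration. Each level of nesting adds another `replace()` call, whereas with `copy_and_mutate()` the same change would be a single assignment such as `updated.state.optimizer.learning_rate = 1e-4` inside the `with` block.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Optimizer:
    learning_rate: float


@dataclasses.dataclass(frozen=True)
class TrainState:
    optimizer: Optimizer
    step: int


@dataclasses.dataclass(frozen=True)
class Experiment:
    state: TrainState
    name: str


exp = Experiment(state=TrainState(optimizer=Optimizer(1e-3), step=0), name="run")

# Updating one leaf means rebuilding every enclosing level by hand.
exp2 = dataclasses.replace(
    exp,
    state=dataclasses.replace(
        exp.state,
        optimizer=dataclasses.replace(exp.state.optimizer, learning_rate=1e-4),
    ),
)
print(exp2.state.optimizer.learning_rate)  # 0.0001
```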
File details
Details for the file jax_dataclasses-0.0.1.tar.gz.
File metadata
- Download URL: jax_dataclasses-0.0.1.tar.gz
- Upload date:
- Size: 5.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.1 importlib_metadata/4.3.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.0 CPython/3.9.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d67f17cf6e8ec3271a0c86ffd5727bcd883e5002891d75d404342c84c243ed62 |
| MD5 | de1a9cc601542a15f5b3e693c1884bcc |
| BLAKE2b-256 | 5af2255911178d42e333e62fb038476b7c8455c48254afc2c0886702a486267e |
File details
Details for the file jax_dataclasses-0.0.1-py3-none-any.whl.
File metadata
- Download URL: jax_dataclasses-0.0.1-py3-none-any.whl
- Upload date:
- Size: 6.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.1 importlib_metadata/4.3.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.0 CPython/3.9.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 6a72ed945232587ff7c38c12411250bb716e09b9a5e4fb9c31cc3f8816a0f513 |
| MD5 | e1d9d11f0fce44fa2c3e590cc5c4856c |
| BLAKE2b-256 | 6d4b409328787a8eabd5219d5af67badb8de0e3b854091c8f36e312241a06f82 |