
Dataclasses + JAX


jax_dataclasses


Library for using dataclasses as JAX PyTrees.

Key features:

  • PyTree registration; automatic generation of flatten/unflatten ops.
  • Static analysis-friendly. Works out of the box with tools like mypy and jedi.
  • Support for serialization via flax.serialization.
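One common way to make a custom dataclass decorator friendly to static analysis is to branch on `typing.TYPE_CHECKING`, so that type checkers see the standard `dataclasses.dataclass` while the runtime uses an extended version. The sketch below assumes that pattern; `pytree_dataclass` is a hypothetical name, and this is not necessarily how jax_dataclasses itself is written.

```python
import dataclasses
from typing import TYPE_CHECKING, Type, TypeVar

T = TypeVar("T")

if TYPE_CHECKING:
    # Static analyzers (mypy, jedi) resolve the hypothetical decorator to the
    # standard-library one, so generated __init__ signatures and field types
    # are understood without a custom plugin.
    from dataclasses import dataclass as pytree_dataclass
else:

    def pytree_dataclass(cls: Type[T]) -> Type[T]:
        # Runtime path: build a frozen stdlib dataclass. A real implementation
        # would also register the class as a PyTree here (omitted in this sketch).
        return dataclasses.dataclass(frozen=True)(cls)


@pytree_dataclass
class Params:
    weight: float
    bias: float = 0.0


p = Params(weight=2.0)
print(p)  # Params(weight=2.0, bias=0.0)
```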

Usage

Basic

jax_dataclasses is meant to be a drop-in replacement for dataclasses.dataclass:

  • jax_dataclasses.dataclass has the same interface as dataclasses.dataclass, but also registers the class as a PyTree.
  • jax_dataclasses.static_field has the same interface as dataclasses.field, but will also mark the field as static. In a PyTree node, static fields are treated as part of the treedef instead of as a child of the node.
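To make the static/child distinction concrete, here is a minimal sketch of registering a frozen dataclass with `jax.tree_util.register_pytree_node`, routing fields marked static into the treedef's auxiliary data. The decorator `pytree_dataclass`, the helper `static_field`, and the `"static"` metadata key are hypothetical stand-ins, not the library's internals.

```python
import dataclasses
from typing import Any, Tuple, Type, TypeVar

import jax
import jax.numpy as jnp

T = TypeVar("T")


def static_field(**kwargs: Any) -> Any:
    # Hypothetical: tag the field via metadata so the flatten function can find it.
    return dataclasses.field(metadata={"static": True}, **kwargs)


def pytree_dataclass(cls: Type[T]) -> Type[T]:
    # Hypothetical decorator: frozen dataclass + PyTree registration.
    cls = dataclasses.dataclass(frozen=True)(cls)
    fields = dataclasses.fields(cls)
    child_names = [f.name for f in fields if not f.metadata.get("static", False)]
    static_names = [f.name for f in fields if f.metadata.get("static", False)]

    def flatten(obj: Any) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:
        # Non-static fields become children (leaves or subtrees).
        children = tuple(getattr(obj, name) for name in child_names)
        # Static values become auxiliary data, i.e. part of the treedef.
        aux = tuple(getattr(obj, name) for name in static_names)
        return children, aux

    def unflatten(aux: Tuple[Any, ...], children: Tuple[Any, ...]) -> Any:
        kwargs = dict(zip(child_names, children))
        kwargs.update(zip(static_names, aux))
        return cls(**kwargs)

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@pytree_dataclass
class Layer:
    weight: jnp.ndarray
    activation: str = static_field(default="relu")


layer = Layer(weight=jnp.ones(3))
# Only `weight` is a leaf; `activation` lives in the treedef.
print(len(jax.tree_util.tree_leaves(layer)))  # 1
doubled = jax.tree_util.tree_map(lambda x: x * 2, layer)
print(doubled.activation)  # relu
```

Because static values are part of the treedef, passing an object with a different static value into a `jax.jit`-compiled function triggers a retrace instead of being traced as an array.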

We also provide several aliases: jax_dataclasses.[field, asdict, astuple, is_dataclass, replace] are all identical to their counterparts in the standard dataclasses library.
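Because these aliases are the standard-library functions, their behavior can be checked with `dataclasses` alone. The sketch below uses plain frozen dataclasses as stand-ins for jax_dataclasses classes, showing `replace` returning a modified copy and `asdict`/`astuple` recursing into nested dataclasses.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Inner:
    x: int


@dataclasses.dataclass(frozen=True)
class Outer:
    inner: Inner
    scale: float


obj = Outer(inner=Inner(x=1), scale=2.0)

# replace() builds a new instance; the original is untouched.
obj2 = dataclasses.replace(obj, scale=3.0)

print(dataclasses.asdict(obj2))   # {'inner': {'x': 1}, 'scale': 3.0}
print(dataclasses.astuple(obj2))  # ((1,), 3.0)
print(dataclasses.is_dataclass(obj2), obj.scale)  # True 2.0
```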

Mutations

All dataclasses are automatically marked as frozen and thus immutable. We do, however, provide an interface that will (a) make a copy of a PyTree and (b) return a context in which any of that copy's contained dataclasses are temporarily mutable:

from jax import numpy as jnp
import jax_dataclasses

@jax_dataclasses.dataclass
class Node:
  child: jnp.ndarray

obj = Node(child=jnp.zeros(3))

with jax_dataclasses.copy_and_mutate(obj) as obj_updated:
  # Make mutations to the dataclass.
  # Also does input validation: if the treedef of `obj` and `obj_updated` don't
  # match, an AssertionError will be raised.
  obj_updated.child = jnp.ones(3)

print(obj)  # The original is unchanged: child is still zeros.
print(obj_updated)  # The copy now holds ones.
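The treedef validation mentioned in the comment above can be illustrated with `jax.tree_util.tree_structure`. The helper `assert_same_treedef` below is hypothetical (not part of jax_dataclasses) and only sketches the kind of check described; it uses a `NamedTuple`, which JAX treats as a PyTree by default. Note that a treedef captures container structure and static data, not leaf shapes or dtypes.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class Node(NamedTuple):
    child: Any


def assert_same_treedef(before: Any, after: Any) -> None:
    # Hypothetical check: compare PyTree structure only; leaves are opaque.
    before_def = jax.tree_util.tree_structure(before)
    after_def = jax.tree_util.tree_structure(after)
    assert before_def == after_def, f"treedef mismatch: {before_def} vs {after_def}"


obj = Node(child=jnp.zeros(3))

# Same structure: passes, even though the leaf's shape changed.
assert_same_treedef(obj, Node(child=jnp.ones(4)))

# Different structure: `child` is now a tuple of two leaves, so this raises.
try:
    assert_same_treedef(obj, Node(child=(jnp.ones(3), jnp.ones(3))))
except AssertionError as e:
    print(e)
```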

Motivation

For compatibility with function transformations in JAX (jit, grad, vmap, etc.), arguments and return values must all be PyTrees. Dataclasses, by default, are not registered as PyTree nodes.
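The sketch below shows what "not a PyTree" means in practice: `jax.tree_util` treats an unregistered dataclass as a single opaque leaf, so JAX cannot see the arrays inside it, whereas a built-in container such as `dict` is traversed and `jax.grad` returns gradients with the same structure. The class and loss function are illustrative only.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class PlainParams:
    w: jnp.ndarray
    b: jnp.ndarray


params = PlainParams(w=jnp.ones(2), b=jnp.zeros(2))

# Unregistered: the whole object is one leaf, so transformations
# cannot reach `w` or `b`.
print(len(jax.tree_util.tree_leaves(params)))  # 1

# A dict is a built-in PyTree: each array is its own leaf.
params_dict = {"w": jnp.ones(2), "b": jnp.zeros(2)}
print(len(jax.tree_util.tree_leaves(params_dict)))  # 2


def loss(p):
    return jnp.sum(p["w"] * 3.0 + p["b"])


# Gradients come back as a dict with the same keys.
grads = jax.grad(loss)(params_dict)
print(grads["w"])  # [3. 3.]
```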

A few great solutions exist for automatically integrating dataclass-style objects into PyTree structures, notably chex.dataclass and flax.struct. This library implements another one.

Why not use chex.dataclass?

chex.dataclass is handy and lightweight, but currently lacks support for:

  • Static fields: parameters that are either non-differentiable or simply not arrays.
  • Serialization using flax.serialization. This is really handy when parameters need to be saved to disk!

Why not use flax.struct?

flax.struct addresses the two points above, but both it and chex.dataclass:

  • Lack support for static analysis and type-checking. Static analysis for libraries like dataclasses and attrs tends to rely on tooling-specific custom plugins, which don't exist for either chex.dataclass or flax.struct.
  • Make modifying deeply nested dataclasses fairly frustrating. Both introduce a .replace(self, ...) method to dataclasses that's a bit more convenient than the traditional dataclasses.replace(obj, ...) API, but this becomes really cumbersome to use when dataclasses are nested. Fixing this is the goal of jax_dataclasses.copy_and_mutate().
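The nesting problem is easy to reproduce with the standard library alone. With frozen dataclasses, updating one deeply nested field through `dataclasses.replace` means rebuilding every level on the path by hand; the `.replace` methods mentioned above follow the same shape. The classes here are hypothetical.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Optimizer:
    learning_rate: float


@dataclasses.dataclass(frozen=True)
class Trainer:
    optimizer: Optimizer


@dataclasses.dataclass(frozen=True)
class Experiment:
    trainer: Trainer
    name: str


exp = Experiment(trainer=Trainer(optimizer=Optimizer(learning_rate=1e-3)), name="run")

# Changing one value requires one replace() call per level of nesting.
exp_updated = dataclasses.replace(
    exp,
    trainer=dataclasses.replace(
        exp.trainer,
        optimizer=dataclasses.replace(exp.trainer.optimizer, learning_rate=1e-4),
    ),
)

print(exp_updated.trainer.optimizer.learning_rate)  # 0.0001
print(exp.trainer.optimizer.learning_rate)  # 0.001
```

If these were jax_dataclasses classes, copy_and_mutate would reduce the update to a single assignment, `exp_updated.trainer.optimizer.learning_rate = 1e-4`, inside the `with` block.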



Download files


Source Distribution

jax_dataclasses-0.0.1.tar.gz (5.7 kB view details)

Uploaded Source

Built Distribution


jax_dataclasses-0.0.1-py3-none-any.whl (6.2 kB view details)

Uploaded Python 3

File details

Details for the file jax_dataclasses-0.0.1.tar.gz.

File metadata

  • Download URL: jax_dataclasses-0.0.1.tar.gz
  • Upload date:
  • Size: 5.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.3.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.0 CPython/3.9.5

File hashes

Hashes for jax_dataclasses-0.0.1.tar.gz
Algorithm Hash digest
SHA256 d67f17cf6e8ec3271a0c86ffd5727bcd883e5002891d75d404342c84c243ed62
MD5 de1a9cc601542a15f5b3e693c1884bcc
BLAKE2b-256 5af2255911178d42e333e62fb038476b7c8455c48254afc2c0886702a486267e


File details

Details for the file jax_dataclasses-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: jax_dataclasses-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 6.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.3.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.0 CPython/3.9.5

File hashes

Hashes for jax_dataclasses-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 6a72ed945232587ff7c38c12411250bb716e09b9a5e4fb9c31cc3f8816a0f513
MD5 e1d9d11f0fce44fa2c3e590cc5c4856c
BLAKE2b-256 6d4b409328787a8eabd5219d5af67badb8de0e3b854091c8f36e312241a06f82
