
Dataclasses + JAX


jax_dataclasses


jax_dataclasses provides a wrapper around dataclasses.dataclass for use in JAX, which enables automatic support for:

  • Pytree registration. This allows dataclasses to be used at API boundaries in JAX, which is necessary for function transformations, JIT, etc.
  • Serialization via flax.serialization.
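To make the pytree-registration point concrete, here is a minimal hand-rolled sketch that uses only the standard dataclasses module and jax.tree_util.register_pytree_node. It is not how jax_dataclasses is implemented; it only shows roughly what registration buys you: once a class is registered, its instances can be passed into and returned from jax.jit-compiled functions. The Point class and the norm_sq and scale functions are hypothetical examples.

```python
import dataclasses

import jax
import jax.numpy as jnp


# Hypothetical example class; both fields hold arrays, so both become leaves.
@dataclasses.dataclass(frozen=True)
class Point:
    x: jnp.ndarray
    y: jnp.ndarray


# Manual registration: flatten returns (children, aux_data), and unflatten
# rebuilds the object from (aux_data, children). There is no aux data here.
jax.tree_util.register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), None),
    lambda aux, children: Point(*children),
)


@jax.jit
def norm_sq(p: Point) -> jnp.ndarray:
    # Point is a registered pytree, so it can be a jit-compiled argument.
    return p.x**2 + p.y**2


@jax.jit
def scale(p: Point, s: float) -> Point:
    # Registered pytrees can also be returned from transformed functions.
    return Point(p.x * s, p.y * s)


p = Point(jnp.array(3.0), jnp.array(4.0))
print(norm_sq(p))
print(scale(p, 2.0))
```

Per the description above, jax_dataclasses.pytree_dataclass is meant to perform this registration automatically when the class is decorated.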

Notably, jax_dataclasses is designed to work seamlessly with static analysis, including tools like mypy and jedi.

Heavily influenced by some great existing work; see Alternatives for comparisons.

Installation

pip install jax_dataclasses

Core interface

jax_dataclasses is meant to provide a drop-in replacement for dataclasses.dataclass:

  • jax_dataclasses.pytree_dataclass has the same interface as dataclasses.dataclass, but also registers the target class as a pytree container.
  • jax_dataclasses.static_field has the same interface as dataclasses.field, but will also mark the field as static. In a pytree node, static fields will be treated as part of the treedef instead of as a child of the node; all fields that are not explicitly marked static should contain arrays or child nodes.
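The child/static split can be sketched by hand in the same way. In this hypothetical Layer class (not part of the library), weights is returned as a child, while activation is returned as auxiliary data. It therefore ends up in the treedef and appears inside jax.jit as a plain Python string rather than a traced value. This only approximates what static_field is described as doing; the apply function is also illustrative.

```python
import dataclasses

import jax
import jax.numpy as jnp


# Hypothetical layer: `weights` is a child (traced), while `activation` is
# meant to be static (stored in the treedef).
@dataclasses.dataclass(frozen=True)
class Layer:
    weights: jnp.ndarray
    activation: str


def _flatten(layer):
    # Children first, then hashable auxiliary data that becomes part of the treedef.
    return (layer.weights,), (layer.activation,)


def _unflatten(aux, children):
    (activation,) = aux
    (weights,) = children
    return Layer(weights=weights, activation=activation)


jax.tree_util.register_pytree_node(Layer, _flatten, _unflatten)


@jax.jit
def apply(layer: Layer, x: jnp.ndarray) -> jnp.ndarray:
    y = x @ layer.weights
    # `activation` is a concrete Python string here, so ordinary `if` works.
    if layer.activation == "relu":
        return jnp.maximum(y, 0.0)
    return y


layer = Layer(weights=jnp.eye(2), activation="relu")
print(jax.tree_util.tree_leaves(layer))  # only the weights array is a leaf
print(apply(layer, jnp.array([-1.0, 2.0])))
```

Because activation lives in the treedef, two layers that differ only in activation have different tree structures, so jax.jit traces apply separately for each value.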

We also provide several aliases: jax_dataclasses.[field, asdict, astuple, is_dataclass, replace] are all identical to their counterparts in the standard dataclasses library.
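Since these aliases are described as identical to the standard-library functions, the standard dataclasses module can illustrate how they behave on a frozen dataclass. The Config class is a hypothetical example.

```python
import dataclasses


# Hypothetical frozen config class.
@dataclasses.dataclass(frozen=True)
class Config:
    lr: float
    steps: int


cfg = Config(lr=1e-3, steps=100)

# Direct assignment on a frozen dataclass raises FrozenInstanceError.
try:
    cfg.steps = 5
except dataclasses.FrozenInstanceError:
    print("frozen")

# replace() returns a new instance and leaves the original untouched.
cfg2 = dataclasses.replace(cfg, steps=200)

print(dataclasses.asdict(cfg2))        # {'lr': 0.001, 'steps': 200}
print(dataclasses.astuple(cfg2))       # (0.001, 200)
print(dataclasses.is_dataclass(cfg2))  # True
```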

Mutations

All dataclasses are automatically marked as frozen and thus immutable, even when no frozen= parameter is passed in. To make changes to nested structures easier, jax_dataclasses.copy_and_mutate() (a) makes a copy of a pytree and (b) returns a context in which any dataclasses contained in that copy are temporarily mutable:

from jax import numpy as jnp
import jax_dataclasses

@jax_dataclasses.pytree_dataclass
class Node:
  child: jnp.ndarray

obj = Node(child=jnp.zeros(3))

with jax_dataclasses.copy_and_mutate(obj) as obj_updated:
  # Make mutations to the dataclass. This is primarily useful for nested
  # dataclasses.
  #
  # Also does input validation: if the treedef, leaf shapes, or dtypes of `obj`
  # and `obj_updated` don't match, an AssertionError will be raised.
  # This can be disabled with a `validate=False` argument.
  obj_updated.child = jnp.ones(3)

print(obj)
print(obj_updated)
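The validation described in the comments above compares the treedef, leaf shapes, and leaf dtypes of the original and the mutated copy. A rough standalone version of that check might look like the following. check_same_structure is a hypothetical helper, not a jax_dataclasses function, and the library's actual check may differ in detail.

```python
import jax
import jax.numpy as jnp


def check_same_structure(a, b) -> None:
    """Hypothetical helper: raise AssertionError if two pytrees differ in
    treedef, leaf shapes, or leaf dtypes."""
    if jax.tree_util.tree_structure(a) != jax.tree_util.tree_structure(b):
        raise AssertionError("treedef mismatch")
    # Same treedef implies the same number of leaves, so zip() is safe here.
    for leaf_a, leaf_b in zip(jax.tree_util.tree_leaves(a), jax.tree_util.tree_leaves(b)):
        if jnp.shape(leaf_a) != jnp.shape(leaf_b):
            raise AssertionError("leaf shape mismatch")
        if jnp.asarray(leaf_a).dtype != jnp.asarray(leaf_b).dtype:
            raise AssertionError("leaf dtype mismatch")


# Passes: same structure, shapes, and dtypes.
check_same_structure({"child": jnp.zeros(3)}, {"child": jnp.ones(3)})

# Fails: the leaf shape changed from (3,) to (4,).
try:
    check_same_structure({"child": jnp.zeros(3)}, {"child": jnp.ones(4)})
except AssertionError as e:
    print(e)
```

The check raises explicitly rather than using bare assert statements, so it still runs when Python is started with -O.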

Alternatives

A few other solutions exist for automatically integrating dataclass-style objects into pytree structures. Great ones include: chex.dataclass, flax.struct, and tjax.dataclass. These all influenced this library.

The main differentiators of jax_dataclasses are:

  • Static analysis support. Libraries like dataclasses and attrs rely on tooling-specific custom plugins for static analysis, which don't exist for chex or flax. tjax has a custom mypy plugin to enable type checking, but isn't supported by other tools. Because @jax_dataclasses.pytree_dataclass has the same API as @dataclasses.dataclass, it can include pytree registration behavior at runtime while being treated as the standard decorator during static analysis. This means that all static checkers, language servers, and autocomplete engines that support the standard dataclasses library should work out of the box with jax_dataclasses.

  • Nested dataclasses. Making replacements/modifications in deeply nested dataclasses is generally very frustrating. The three alternatives all introduce a .replace(self, ...) method to dataclasses that's a bit more convenient than the traditional dataclasses.replace(obj, ...) API for shallow changes, but still becomes really cumbersome to use when dataclasses are nested. jax_dataclasses.copy_and_mutate() is introduced to address this.

  • Static field support. Parameters that should not be traced in JAX should be marked as static. This is supported in flax, tjax, and jax_dataclasses, but not chex.

  • Serialization. When working with flax, being able to serialize dataclasses is really handy. This is supported in flax.struct (naturally) and jax_dataclasses, but not chex or tjax.
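To illustrate the nested-dataclasses point using only the standard library, the following sketch changes one deeply nested field with dataclasses.replace. Every enclosing level has to be rebuilt explicitly. The TrainState, Model, and Optimizer classes are hypothetical.

```python
import dataclasses


# Hypothetical three-level nesting of frozen dataclasses.
@dataclasses.dataclass(frozen=True)
class Optimizer:
    lr: float


@dataclasses.dataclass(frozen=True)
class Model:
    optimizer: Optimizer


@dataclasses.dataclass(frozen=True)
class TrainState:
    model: Model
    step: int


state = TrainState(model=Model(optimizer=Optimizer(lr=1e-3)), step=0)

# Changing a single leaf requires one replace() call per level of nesting.
new_state = dataclasses.replace(
    state,
    model=dataclasses.replace(
        state.model,
        optimizer=dataclasses.replace(state.model.optimizer, lr=1e-4),
    ),
)

print(state.model.optimizer.lr, new_state.model.optimizer.lr)
```

copy_and_mutate() is meant to replace this pattern with a single nested attribute assignment inside a with block, as in the Mutations example, assuming the classes are declared with pytree_dataclass.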

Misc

This code was originally written for and factored out of jaxfg, where Nick Heppert provided valuable feedback!

Download files

Source Distribution

jax_dataclasses-1.0.2.tar.gz (7.2 kB)

Uploaded Source

Built Distribution

jax_dataclasses-1.0.2-py3-none-any.whl (7.0 kB)

Uploaded Python 3

File details

Details for the file jax_dataclasses-1.0.2.tar.gz.

File metadata

  • Download URL: jax_dataclasses-1.0.2.tar.gz
  • Upload date:
  • Size: 7.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.6.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.61.2 CPython/3.9.6

File hashes

Hashes for jax_dataclasses-1.0.2.tar.gz
  • SHA256: 01900dfeb32152b1f6de280cd9a24820a6988dc2566879ab14e1ea9b5c064bb1
  • MD5: 22723e96d3443499d6c346cdb19dbb05
  • BLAKE2b-256: 9d4011cf288d581fb351552cccb52c82f3682df2e01b317539b03fc73da978d8

File details

Details for the file jax_dataclasses-1.0.2-py3-none-any.whl.

File metadata

  • Download URL: jax_dataclasses-1.0.2-py3-none-any.whl
  • Upload date:
  • Size: 7.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.6.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.61.2 CPython/3.9.6

File hashes

Hashes for jax_dataclasses-1.0.2-py3-none-any.whl
  • SHA256: 9027ecc1fde1d020ebf8b81f3391248e36c7e9a8ab9e121918db204ea04261be
  • MD5: 3f4a21f176189aada040c50e9c9b2d1c
  • BLAKE2b-256: 6bec53c8dcb5600e277eb477e5b5d0f95754f8f169af5d95a07ccaa1f1a64261
