
Dataclasses + JAX

jax_dataclasses


Overview

jax_dataclasses provides a wrapper around dataclasses.dataclass for use in JAX, which enables automatic support for:

  • Pytree registration. This allows dataclasses to be used at API boundaries in JAX, which is necessary for function transformations, JIT, etc.
  • Serialization via flax.serialization.
  • Static analysis with tools like mypy, jedi, pyright, etc., including for constructors.
  • Optional shape and data-type annotations, which are checked at runtime.

Heavily influenced by some great existing work (the obvious one being flax.struct.dataclass); see Alternatives for comparisons.

Installation

The latest version of jax_dataclasses requires Python>=3.7. Python 3.6 will work as well, but is missing support for shape annotations.

pip install jax_dataclasses

We can then import:

import jax_dataclasses as jdc

Core interface

jax_dataclasses is meant to provide a drop-in replacement for dataclasses.dataclass:

  • jdc.pytree_dataclass has the same interface as dataclasses.dataclass, but also registers the target class as a pytree container.
  • jdc.static_field has the same interface as dataclasses.field, but will also mark the field as static. In a pytree node, static fields will be treated as part of the treedef instead of as a child of the node; all fields that are not explicitly marked static should contain arrays or child nodes.

We also provide several aliases: jdc.[field, asdict, astuple, is_dataclass, replace] are all identical to their counterparts in the standard dataclasses library.

Mutations

All dataclasses are automatically marked as frozen and thus immutable (even when no frozen= parameter is passed in). To make changes to nested structures easier, jdc.copy_and_mutate (a) makes a copy of a pytree and (b) returns a context in which any of that copy's contained dataclasses are temporarily mutable:

from jax import numpy as jnp
import jax_dataclasses as jdc

@jdc.pytree_dataclass
class Node:
  child: jnp.ndarray

obj = Node(child=jnp.zeros(3))

with jdc.copy_and_mutate(obj) as obj_updated:
  # Make mutations to the dataclass. This is primarily useful for nested
  # dataclasses.
  #
  # Also does input validation: if the treedef, leaf shapes, or dtypes of `obj`
  # and `obj_updated` don't match, an AssertionError will be raised.
  # This can be disabled with a `validate=False` argument.
  obj_updated.child = jnp.ones(3)

print(obj)
print(obj_updated)

Shape and data-type annotations

Subclassing from jdc.EnforcedAnnotationsMixin enables automatic shape and data-type validation. Arrays contained within dataclasses are validated on instantiation, and a .get_batch_axes() method is exposed for grabbing any batch axes common to the shapes of all contained arrays.

We can start by importing the standard Annotated type:

try:
    # Python >=3.9
    from typing import Annotated
except ImportError:
    # Backport
    from typing_extensions import Annotated

We can then add shape annotations:

@jdc.pytree_dataclass
class MnistStruct(jdc.EnforcedAnnotationsMixin):
    image: Annotated[
        jnp.ndarray,
        # Note that we can move the expected location of the batch axes by
        # shifting the ellipsis around.
        #
        # If the ellipsis is excluded, we assume batch axes at the start of the
        # shape.
        (..., 28, 28),
    ]
    label: Annotated[
        jnp.ndarray,
        (..., 10),
    ]

Or data-type annotations:

    image: Annotated[
        jnp.ndarray,
        jnp.float32,
    ]
    label: Annotated[
        jnp.ndarray,
        jnp.integer,
    ]

Or both (note that annotations are order-invariant):

    image: Annotated[
        jnp.ndarray,
        (..., 28, 28),
        jnp.float32,
    ]
    label: Annotated[
        jnp.ndarray,
        (..., 10),
        jnp.integer,
    ]

Then, assuming we've constrained both the shape and data-type:

import numpy as onp

# OK
struct = MnistStruct(
  image=onp.zeros((28, 28), dtype=onp.float32),
  label=onp.zeros((10,), dtype=onp.uint8),
)
print(struct.get_batch_axes()) # Prints ()

# OK
struct = MnistStruct(
  image=onp.zeros((32, 28, 28), dtype=onp.float32),
  label=onp.zeros((32, 10), dtype=onp.uint8),
)
print(struct.get_batch_axes()) # Prints (32,)

# AssertionError on instantiation because of type mismatch
MnistStruct(
  image=onp.zeros((28, 28), dtype=onp.float32),
  label=onp.zeros((10,), dtype=onp.float32), # Not an integer type!
)

# AssertionError on instantiation because of shape mismatch
MnistStruct(
  image=onp.zeros((28, 28), dtype=onp.float32),
  label=onp.zeros((5,), dtype=onp.uint8),
)

# AssertionError on instantiation because of batch axis mismatch
struct = MnistStruct(
  image=onp.zeros((64, 28, 28), dtype=onp.float32),
  label=onp.zeros((32, 10), dtype=onp.uint8),
)
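To make the batch-axis rule in these examples concrete, here is a pure-Python sketch of how batch axes could be inferred from an annotation like (..., 28, 28). It is not the library's implementation: `infer_batch_axes` and `common_batch_axes` are hypothetical helpers, and the sketch only assumes the rules stated above (the ellipsis marks where batch axes go, a missing ellipsis means batch axes at the start, and all fields must agree on their batch axes).

```python
from typing import Iterable, Optional, Sequence, Tuple

Shape = Tuple[int, ...]


def infer_batch_axes(shape: Shape, expected: Sequence) -> Optional[Shape]:
    """Hypothetical: return the batch axes of `shape`, or None on mismatch."""
    if Ellipsis not in expected:
        # No ellipsis: assume batch axes at the start of the shape.
        expected = (Ellipsis, *expected)
    i = list(expected).index(Ellipsis)
    prefix = tuple(expected[:i])
    suffix = tuple(expected[i + 1 :])
    if len(shape) < len(prefix) + len(suffix):
        return None
    # Fixed axes before and after the ellipsis must match exactly.
    if shape[: len(prefix)] != prefix or shape[len(shape) - len(suffix) :] != suffix:
        return None
    return shape[len(prefix) : len(shape) - len(suffix)]


def common_batch_axes(fields: Iterable[Tuple[Shape, Sequence]]) -> Shape:
    """Hypothetical: every (shape, annotation) pair must share batch axes."""
    batch_axes: Optional[Shape] = None
    for shape, expected in fields:
        axes = infer_batch_axes(shape, expected)
        assert axes is not None, f"shape {shape} does not match {expected}"
        assert batch_axes is None or axes == batch_axes, "batch axis mismatch"
        batch_axes = axes
    return batch_axes if batch_axes is not None else ()
```

With the MnistStruct annotations, `common_batch_axes([((32, 28, 28), (..., 28, 28)), ((32, 10), (..., 10))])` returns `(32,)`, while mixing batch sizes 64 and 32 trips the assertion, mirroring the last example above.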

Alternatives

A few other solutions exist for automatically integrating dataclass-style objects into pytree structures. Great ones include chex.dataclass, flax.struct, and tjax.dataclass, all of which influenced this library.

The main differentiators of jax_dataclasses are:

  • Static analysis support. tjax has a custom mypy plugin to enable type checking, but isn't supported by other tools. flax.struct implements the dataclass_transform spec proposed by pyright, but isn't supported by other tools. Because @jdc.pytree_dataclass has the same API as @dataclasses.dataclass, it can include pytree registration behavior at runtime while being treated as the standard decorator during static analysis. This means that all static checkers, language servers, and autocomplete engines that support the standard dataclasses library should work out of the box with jax_dataclasses.

  • Nested dataclasses. Making replacements/modifications in deeply nested dataclasses can be really frustrating. The three alternatives all introduce a .replace(self, ...) method to dataclasses that's a bit more convenient than the traditional dataclasses.replace(obj, ...) API for shallow changes, but still becomes really cumbersome to use when dataclasses are nested. jdc.copy_and_mutate() is introduced to address this.

  • Static field support. Parameters that should not be traced in JAX should be marked as static. This is supported in flax, tjax, and jax_dataclasses, but not chex.

  • Serialization. When working with flax, being able to serialize dataclasses is really handy. This is supported in flax.struct (naturally) and jax_dataclasses, but not chex or tjax.

  • Shape and type annotations. See above.

You can also eschew the dataclass-style interface entirely; see how brax registers pytrees. This is a reasonable thing to prefer. It requires some floating strings and breaks things that I care about but you may not (like immutability and __post_init__), but in exchange it gives more flexibility with custom __init__ methods.
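For comparison, here is a generic sketch of registering a class by hand with jax.tree_util.register_pytree_node_class. It is not brax's actual code; the `Transform` class and its fields are hypothetical, chosen to show the kind of custom __init__ logic this approach allows.

```python
import jax
from jax import numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Transform:
    """Hypothetical class, registered by hand instead of via a dataclass."""

    def __init__(self, pos, rot=None):
        # A custom __init__ is free to normalize inputs or fill in defaults
        # (here: an identity quaternion when no rotation is given).
        self.pos = jnp.asarray(pos)
        self.rot = jnp.array([1.0, 0.0, 0.0, 0.0]) if rot is None else jnp.asarray(rot)

    def tree_flatten(self):
        # Children are traced by JAX; the second element is static aux data.
        return (self.pos, self.rot), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__, since JAX may unflatten with non-array placeholders.
        obj = object.__new__(cls)
        obj.pos, obj.rot = children
        return obj


t = Transform(jnp.zeros(3))
shifted = jax.jit(lambda tr: jax.tree_util.tree_map(lambda x: x + 1.0, tr))(t)
```

Note that nothing here stops code from assigning `t.pos = ...` directly, which is the immutability trade-off described in the paragraph above.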

Misc

This code was originally written for and factored out of jaxfg, where Nick Heppert provided valuable feedback.

Download files


Source Distribution

jax_dataclasses-1.3.0.tar.gz (12.9 kB view details)

Uploaded Source

Built Distribution

jax_dataclasses-1.3.0-py3-none-any.whl (11.6 kB view details)

Uploaded Python 3

File details

Details for the file jax_dataclasses-1.3.0.tar.gz.

File metadata

  • Download URL: jax_dataclasses-1.3.0.tar.gz
  • Size: 12.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.10.5

File hashes

Hashes for jax_dataclasses-1.3.0.tar.gz
Algorithm Hash digest
SHA256 cdbafd1e9acae808d3b4016848e0ad8a56c11e0beb5349ca1a9f690bdc026056
MD5 299cb5cca997e3250b0439e44e10baca
BLAKE2b-256 e60c0dff5058db1477cba05f84652e415ff23c9c0de7ceeb22f186e386f5cfad


File details

Details for the file jax_dataclasses-1.3.0-py3-none-any.whl.

File hashes

Hashes for jax_dataclasses-1.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 85c1d69ec077dee7ab0d5bed7d5eba9236578257f9a516f1ed9346f98f4b3c43
MD5 4d301d02fe91d9984d6b1cf0e8632576
BLAKE2b-256 4d544b10c8b941a39e583d7d1f250f4edf0f55a04278f5b6f4b33c6006d2031b
