
Drinx: Dataclass Registry in JAX


Programs often need structures that mix JAX arrays with non-JAX values (e.g., strings). Such objects are difficult to pass through JAX transformations. Drinx solves this by allowing dataclass fields to be declared as static. In addition, drinx provides numerous quality-of-life features for working with dataclasses in JAX.

Installation

You can install drinx simply via

pip install drinx

To use GPU acceleration with JAX, additionally install (the quotes prevent shells such as zsh from expanding the brackets):

pip install "jax[cuda]"

Quickstart

Below are some examples to get you started quickly with drinx. Many more features are available; they are documented in detail in our Documentation.

Decorator style

Use @drinx.dataclass as a drop-in replacement for @dataclasses.dataclass. The class is automatically frozen and registered as a JAX pytree:

import jax
import jax.numpy as jnp
import drinx

@drinx.dataclass
class Params:
    weights: jax.Array
    bias: jax.Array

params = Params(weights=jnp.ones((3,)), bias=jnp.zeros((3,)))

# Works transparently with JAX transforms
doubled = jax.tree_util.tree_map(lambda x: x * 2, params)
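Registering a class as a pytree means JAX can split an instance into its leaves and a description of its structure, apply a function to the leaves, and rebuild an instance of the same class. The following is a minimal, stdlib-only sketch of that idea using a plain frozen dataclass. The helpers `flatten`, `unflatten`, and `tree_map_sketch` are hypothetical illustrations, not drinx or JAX APIs. Tuples of floats stand in for arrays, and each field is treated as a single leaf, whereas JAX also recurses into nested containers.

```python
import dataclasses
from typing import Any, Callable


@dataclasses.dataclass(frozen=True)
class Params:
    weights: tuple[float, ...]
    bias: tuple[float, ...]


def flatten(obj: Any) -> tuple[list[Any], tuple[type, tuple[str, ...]]]:
    """Split a dataclass instance into its field values (leaves) and a structure description."""
    names = tuple(f.name for f in dataclasses.fields(obj))
    leaves = [getattr(obj, name) for name in names]
    return leaves, (type(obj), names)


def unflatten(treedef: tuple[type, tuple[str, ...]], leaves: list[Any]) -> Any:
    """Rebuild an instance of the original class from the structure and new leaves."""
    cls, names = treedef
    return cls(**dict(zip(names, leaves)))


def tree_map_sketch(fn: Callable[[Any], Any], obj: Any) -> Any:
    # Apply fn to every leaf, then reassemble; the input instance is left untouched.
    leaves, treedef = flatten(obj)
    return unflatten(treedef, [fn(leaf) for leaf in leaves])


params = Params(weights=(1.0, 1.0, 1.0), bias=(0.0, 0.0, 0.0))
doubled = tree_map_sketch(lambda xs: tuple(2 * x for x in xs), params)
print(doubled)  # Params(weights=(2.0, 2.0, 2.0), bias=(0.0, 0.0, 0.0))
```

Because the class is frozen, the mapped result is always a new instance. This matches the functional style that JAX transformations expect.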

Static fields

Fields that should not be traced by JAX (e.g. shapes, dtypes, hyperparameters) are marked with static_field or field(static=True). Changing a static field triggers recompilation under jit:

@drinx.dataclass
class Model:
    weights: jax.Array
    hidden_size: int = drinx.static_field(default=128)

@jax.jit
def forward(model, x):
    # hidden_size is a compile-time constant; weights are traced
    return model.weights[:model.hidden_size] @ x

model = Model(weights=jnp.ones((128, 32)))
y = forward(model, jnp.ones((32,)))  # shape (128,)
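Conceptually, static fields become part of the tree structure rather than leaves, and jit keeps one compiled function per distinct structure. The stdlib-only sketch below models that behaviour with a plain frozen dataclass. Fields whose metadata contains `static=True` go into a hashable cache key, the remaining fields are the traced leaves, and a hypothetical `jit_sketch` wrapper records how often it would need to "compile". None of these names are drinx or JAX APIs, and `metadata={"static": True}` is only a stand-in for `drinx.static_field`. A real jit also keys its cache on array shapes and dtypes.

```python
import dataclasses
from typing import Any, Callable


@dataclasses.dataclass(frozen=True)
class Model:
    weights: tuple[float, ...]
    # Stand-in for drinx.static_field: mark the field as static via its metadata.
    hidden_size: int = dataclasses.field(default=128, metadata={"static": True})


def partition(obj: Any) -> tuple[tuple[Any, ...], tuple[Any, ...]]:
    """Return (static_key, leaves): static values form a hashable key, leaves are the traced data."""
    static, leaves = [], []
    for f in dataclasses.fields(obj):
        target = static if f.metadata.get("static", False) else leaves
        target.append((f.name, getattr(obj, f.name)))
    return (type(obj), tuple(static)), tuple(leaves)


def jit_sketch(fn: Callable[[Any], Any]) -> Callable[[Any], Any]:
    cache: dict[Any, Callable[[Any], Any]] = {}

    def wrapper(obj: Any) -> Any:
        key, _ = partition(obj)
        if key not in cache:
            # A real jit would trace and compile here; the sketch only records a new entry.
            cache[key] = fn
        return cache[key](obj)

    wrapper.cache = cache  # expose the cache so the number of "compilations" can be inspected
    return wrapper


@jit_sketch
def truncated_sum(model: Model) -> float:
    return sum(model.weights[: model.hidden_size])


truncated_sum(Model(weights=(1.0, 2.0)))                 # new static key: "compiles"
truncated_sum(Model(weights=(3.0, 4.0)))                 # same static key: cache hit
truncated_sum(Model(weights=(3.0, 4.0), hidden_size=1))  # new static key: "compiles" again
print(len(truncated_sum.cache))  # 2
```

Changing only the leaf values reuses the cached entry. Changing a static value produces a new key, which is why hyperparameters that vary often are usually better kept as traced data.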

Inheritance style

Subclass drinx.DataClass instead of using the decorator. The dataclass transform is applied automatically, so no @dataclass decorator is needed:

class Model(drinx.DataClass):
    weights: jax.Array
    learning_rate: float = drinx.static_field(default=1e-3)

model = Model(weights=jnp.ones((10,)))

Dataclass options can be passed as keyword arguments in the class definition, or alternatively by combining inheritance with the decorator:

# Options passed as keyword arguments in the class definition
class Config(drinx.DataClass, kw_only=True, order=True):
    hidden_size: int = drinx.static_field(default=128)
    num_layers: int = drinx.static_field(default=4)

# Recommended: type checkers correctly recognize the kw_only argument in this form
@drinx.dataclass(kw_only=True, order=True)
class Config(drinx.DataClass):
    hidden_size: int = drinx.static_field(default=128)
    num_layers: int = drinx.static_field(default=4)

Functional updates with aset

Because drinx dataclasses are frozen, fields cannot be mutated in place. Instead, aset performs a functional update and returns a new instance. It supports nested paths using -> as a separator, integer indices [n], and string dictionary keys ['k']. Note that aset is only available when inheriting from drinx.DataClass, not when using the decorator.

class Inner(drinx.DataClass):
    w: jax.Array

class Outer(drinx.DataClass):
    inner: Inner
    bias: jax.Array

outer = Outer(inner=Inner(w=jnp.ones((3,))), bias=jnp.zeros((1,)))

# Update a top-level field
outer2 = outer.aset("bias", jnp.ones((1,)))

# Update a nested field
outer3 = outer.aset("inner->w", jnp.zeros((3,)))
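A path-based functional update can be built by rebuilding each level of the path from the innermost field outwards. The stdlib-only sketch below shows this with `dataclasses.replace` on plain frozen dataclasses. `aset_sketch` is a hypothetical helper, not drinx's implementation, and it only handles "->"-separated attribute names; drinx's aset also accepts [n] indices and ['k'] dictionary keys.

```python
import dataclasses
from typing import Any


def aset_sketch(obj: Any, path: str, value: Any) -> Any:
    """Return a copy of `obj` with the attribute addressed by `path` replaced by `value`."""
    head, sep, rest = path.partition("->")
    if not sep:
        # Last path segment: replace the field on this (frozen) instance.
        return dataclasses.replace(obj, **{head: value})
    # Rebuild the child first, then swap it into a copy of the parent.
    child = getattr(obj, head)
    return dataclasses.replace(obj, **{head: aset_sketch(child, rest, value)})


@dataclasses.dataclass(frozen=True)
class Inner:
    w: tuple[float, ...]


@dataclasses.dataclass(frozen=True)
class Outer:
    inner: Inner
    bias: tuple[float, ...]


outer = Outer(inner=Inner(w=(1.0, 1.0, 1.0)), bias=(0.0,))
outer3 = aset_sketch(outer, "inner->w", (0.0, 0.0, 0.0))
print(outer3.inner.w)  # (0.0, 0.0, 0.0)
print(outer.inner.w)   # (1.0, 1.0, 1.0): the original is unchanged
```

In this sketch, fields that are not on the updated path are shared by reference with the original instance rather than copied.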

JAX transforms

Drinx dataclasses work out of the box with JAX transforms such as jit, grad, and vmap:

class State(drinx.DataClass):
    x: jax.Array
    step_size: float = drinx.static_field(default=0.1)

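# jit
@jax.jit
def update(state):
    # updated_copy is a convenience wrapper for replacing top-level fields
    return state.updated_copy(x=state.x - state.step_size)

new_state = update(State(x=jnp.array([1.0, 2.0, 3.0])))

# grad: the gradient is returned as a State with the same structure
def loss(state):
    return jnp.sum(state.x ** 2)

grads = jax.grad(loss)(State(x=jnp.array([1.0, 2.0, 3.0])))  # grads.x == [2., 4., 6.]

# vmap: maps over the leading axis of the array fields
@jax.vmap
def scale(state):
    return state.x * 2

batched = State(x=jnp.array([[1.0, 2.0], [3.0, 4.0]]))
result = scale(batched)  # shape (2, 2)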
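For loss(state) = sum(x_i ** 2), the gradient with respect to x is 2 * x, so `grads.x` should equal [2.0, 4.0, 6.0] for the input above. The stdlib-only sketch below checks that arithmetic with central finite differences on plain lists. `numerical_grad` is a hypothetical helper for illustration, not part of drinx or JAX.

```python
from typing import Callable


def loss(x: list[float]) -> float:
    # Same objective as the loss above, on a plain list instead of a State.
    return sum(v ** 2 for v in x)


def numerical_grad(f: Callable[[list[float]], float], x: list[float], eps: float = 1e-6) -> list[float]:
    """Central finite differences: (f(x + eps * e_i) - f(x - eps * e_i)) / (2 * eps)."""
    grads = []
    for i in range(len(x)):
        plus = list(x)
        minus = list(x)
        plus[i] += eps
        minus[i] -= eps
        grads.append((f(plus) - f(minus)) / (2 * eps))
    return grads


x = [1.0, 2.0, 3.0]
approx = numerical_grad(loss, x)
analytic = [2 * v for v in x]  # d/dx_i sum_j x_j ** 2 = 2 * x_i
print(analytic)  # [2.0, 4.0, 6.0]
```

The numerical estimate agrees with the analytic gradient up to floating-point round-off. jax.grad returns the exact derivative via automatic differentiation.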

Visualization

tree_diagram and tree_summary let you inspect any JAX pytree at a glance:

class Encoder(drinx.DataClass):
    w: jax.Array
    b: jax.Array

class Model(drinx.DataClass):
    encoder: Encoder
    head: jax.Array

model = Model(encoder=Encoder(w=jnp.ones((16, 32)), b=jnp.zeros((16,))), head=jnp.ones((4, 16)))

print(drinx.tree_diagram(model))
# Model
# ├── .encoder:Encoder
# │   ├── .w=f32[16,32] ∈ [1.0, 1.0], μ=1.0, σ=0.0
# │   └── .b=f32[16] ∈ [0.0, 0.0], μ=0.0, σ=0.0
# └── .head=f32[4,16] ∈ [1.0, 1.0], μ=1.0, σ=0.0

print(drinx.tree_summary(model))
# ┌──────────────┬──────────┬───────┬────────┐
# │Name          │Type      │Count  │Size    │
# ├──────────────┼──────────┼───────┼────────┤
# │.encoder.w    │f32[16,32]│512    │2.00KB  │
# ├──────────────┼──────────┼───────┼────────┤
# │.encoder.b    │f32[16]   │16     │64.00B  │
# ├──────────────┼──────────┼───────┼────────┤
# │.head         │f32[4,16] │64     │256.00B │
# ├──────────────┼──────────┼───────┼────────┤
# │Σ             │Tree      │592    │2.31KB  │
# └──────────────┴──────────┴───────┴────────┘
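The Count and Size columns follow from the shapes. Count is the product of the dimensions, and each float32 element occupies 4 bytes. The figures in the table are consistent with 1 KB = 1024 bytes (2048 bytes is shown as 2.00KB). The stdlib-only sketch below reproduces the numbers. `summarize` is a hypothetical helper, not drinx's implementation.

```python
import math

# Leaf shapes from the example model; all leaves are float32 (4 bytes per element).
shapes = {".encoder.w": (16, 32), ".encoder.b": (16,), ".head": (4, 16)}
BYTES_PER_F32 = 4


def summarize(shapes: dict[str, tuple[int, ...]]) -> dict[str, tuple[int, int]]:
    """Map each leaf name to (element count, size in bytes)."""
    return {
        name: (math.prod(shape), math.prod(shape) * BYTES_PER_F32)
        for name, shape in shapes.items()
    }


rows = summarize(shapes)
total_count = sum(count for count, _ in rows.values())
total_bytes = sum(size for _, size in rows.values())
print(rows)                           # {'.encoder.w': (512, 2048), '.encoder.b': (16, 64), '.head': (64, 256)}
print(total_count, total_bytes)       # 592 2368
print(f"{total_bytes / 1024:.2f}KB")  # 2.31KB
```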

Documentation

For more examples and detailed documentation, see the API reference.

Citation

TODO: add citation once published



Download files


Source Distribution

drinx-1.1.0.tar.gz (1.7 MB)

Uploaded Source

Built Distribution


drinx-1.1.0-py3-none-any.whl (17.4 kB)

Uploaded Python 3

File details

Details for the file drinx-1.1.0.tar.gz.

File metadata

  • Download URL: drinx-1.1.0.tar.gz
  • Upload date:
  • Size: 1.7 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for drinx-1.1.0.tar.gz
Algorithm Hash digest
SHA256 10e3c5cc81a91503485dda54e999afadab491823abed667d75c0480315f9b12c
MD5 c5c16ff5fa9d22e067f1c273e8caa05b
BLAKE2b-256 f36a699aa508991a82a213e95f4e0ae36d8ff6a8e2602d0758e6daed26ec2ebc


Provenance

The following attestation bundles were made for drinx-1.1.0.tar.gz:

Publisher: publish.yml on ymahlau/drinx

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file drinx-1.1.0-py3-none-any.whl.

File metadata

  • Download URL: drinx-1.1.0-py3-none-any.whl
  • Upload date:
  • Size: 17.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for drinx-1.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 176e8c767eeca043854ccafcd91fd3cfb75f31d8e1d8e3f6480f204c135658f1
MD5 df2d71d36233e8afb1822a747ab99e4c
BLAKE2b-256 9e780d24a65c1e1aea788e270c68d25bcacacd8273701af9114ca58f35c7385d


Provenance

The following attestation bundles were made for drinx-1.1.0-py3-none-any.whl:

Publisher: publish.yml on ymahlau/drinx

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
