Drinx: Dataclass Registry in JAX
Programs often need data structures that mix JAX arrays with non-JAX values such as strings or integer hyperparameters. Such objects are hard to pass through JAX transformations like jit, which expect every pytree leaf to be something they can trace. Drinx solves this by letting dataclass fields be declared static: static fields are carried along as metadata instead of being traced. Drinx also adds numerous quality-of-life features for working with dataclasses in JAX.
Installation
Install drinx from PyPI:
pip install drinx
For GPU acceleration, additionally install a CUDA-enabled JAX build. The extra depends on your CUDA version (see the JAX installation guide), for example:
pip install -U "jax[cuda12]"
Quickstart
The examples below will get you started quickly. Drinx has many more features than shown here; they are documented in detail in the Documentation.
Decorator style
Use @drinx.dataclass as a drop-in replacement for @dataclasses.dataclass.
The class is automatically frozen and registered as a JAX pytree:
import jax
import jax.numpy as jnp
import drinx
@drinx.dataclass
class Params:
    weights: jax.Array
    bias: jax.Array
params = Params(weights=jnp.ones((3,)), bias=jnp.zeros((3,)))
# Works transparently with JAX transforms
doubled = jax.tree_util.tree_map(lambda x: x * 2, params)
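Registering a class as a pytree boils down to a pair of functions: one that splits an instance into traced leaves plus static auxiliary data, and one that rebuilds it from those parts. The stdlib-only sketch below illustrates that idea for a frozen dataclass; it is not drinx's implementation. The `static_field` helper, the `"static"` metadata key and the extra `name` field are hypothetical, and tuples of floats stand in for JAX arrays so the code runs without JAX.

```python
import dataclasses


# Hypothetical stand-in for drinx.static_field: tag the field via metadata.
def static_field(**kwargs):
    return dataclasses.field(metadata={"static": True}, **kwargs)


def flatten(obj):
    """Split a dataclass instance into (traced leaves, static aux data)."""
    names, leaves, static = [], [], []
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        if f.metadata.get("static", False):
            static.append((f.name, value))
        else:
            names.append(f.name)
            leaves.append(value)
    # Aux data is a tuple, so it stays hashable as long as the static values are.
    return leaves, (type(obj), tuple(names), tuple(static))


def unflatten(aux, leaves):
    """Rebuild an instance from aux data and (possibly transformed) leaves."""
    cls, names, static = aux
    return cls(**dict(zip(names, leaves)), **dict(static))


def tree_map(fn, obj):
    """Apply fn to every traced leaf; static fields pass through untouched."""
    leaves, aux = flatten(obj)
    return unflatten(aux, [fn(leaf) for leaf in leaves])


@dataclasses.dataclass(frozen=True)
class Params:
    weights: tuple  # stand-in for jax.Array
    bias: tuple     # stand-in for jax.Array
    name: str = static_field(default="layer0")  # hypothetical static field


params = Params(weights=(1.0, 1.0, 1.0), bias=(0.0, 0.0, 0.0))
doubled = tree_map(lambda xs: tuple(2 * v for v in xs), params)
print(doubled)
# Params(weights=(2.0, 2.0, 2.0), bias=(0.0, 0.0, 0.0), name='layer0')
```

Keeping the static values in the auxiliary data rather than in the leaves is what lets a transform touch only the numeric fields while the rest of the object is reconstructed unchanged.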
Static fields
Fields that should not be traced by JAX (e.g. shapes, dtypes, hyperparameters)
are marked with static_field or field(static=True). Changing a static
field triggers recompilation under jit:
@drinx.dataclass
class Model:
    weights: jax.Array
    hidden_size: int = drinx.static_field(default=128)
@jax.jit
def forward(model, x):
    # hidden_size is a compile-time constant; weights are traced
    return model.weights[:model.hidden_size] @ x
model = Model(weights=jnp.ones((128, 32)))
y = forward(model, jnp.ones((32,)))  # y.shape == (128,)
Inheritance style
Instead of using the decorator, you can subclass drinx.DataClass. The dataclass transform is then applied
automatically, so no @dataclass decorator is needed:
class Model(drinx.DataClass):
    weights: jax.Array
    learning_rate: float = drinx.static_field(default=1e-3)
model = Model(weights=jnp.ones((10,)))
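One way a base class can apply the dataclass transform automatically is Python's `__init_subclass__` hook. The sketch below shows that mechanism with the standard library only; `AutoDataClass` is a hypothetical name, and whether drinx uses this exact hook is an assumption. Keywords given in the class statement are forwarded to `dataclasses.dataclass`, which mirrors how options are passed to `drinx.DataClass` in the following example.

```python
import dataclasses


class AutoDataClass:
    """Hypothetical base class: every subclass becomes a frozen dataclass."""

    def __init_subclass__(cls, **dataclass_options):
        super().__init_subclass__()
        # Apply the transform in place; keywords such as kw_only=True or
        # order=True given in the class statement are forwarded unchanged.
        dataclasses.dataclass(cls, frozen=True, **dataclass_options)


class Model(AutoDataClass):
    weights: tuple  # stand-in for jax.Array
    learning_rate: float = 1e-3


model = Model(weights=(1.0,) * 10)
print(dataclasses.is_dataclass(model), model.learning_rate)  # True 0.001

try:
    model.learning_rate = 0.1  # frozen: in-place mutation is rejected
    is_frozen = False
except dataclasses.FrozenInstanceError:
    is_frozen = True
print(is_frozen)  # True
```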
Dataclass options such as kw_only or order can be passed as keyword arguments in the class definition, or by combining inheritance with the decorator:
# Option 1: pass the options as keyword arguments in the class definition
class Config(drinx.DataClass, kw_only=True, order=True):
    hidden_size: int = drinx.static_field(default=128)
    num_layers: int = drinx.static_field(default=4)
# Option 2 (recommended): type checkers recognize the kw_only argument correctly
@drinx.dataclass(kw_only=True, order=True)
class Config(drinx.DataClass):
    hidden_size: int = drinx.static_field(default=128)
    num_layers: int = drinx.static_field(default=4)
Functional updates with aset
Because drinx dataclasses are frozen, fields cannot be mutated in place.
aset performs a functional update and returns a new instance. It supports
nested paths using -> as a separator, integer indices [n], and string
dictionary keys ['k'].
Note that aset is only available on classes that inherit from drinx.DataClass, not on classes created with the @drinx.dataclass decorator.
class Inner(drinx.DataClass):
    w: jax.Array
class Outer(drinx.DataClass):
    inner: Inner
    bias: jax.Array
outer = Outer(inner=Inner(w=jnp.ones((3,))), bias=jnp.zeros((1,)))
# Update a top-level field; outer itself is left unchanged
outer2 = outer.aset("bias", jnp.ones((1,)))
# Update a nested field
outer3 = outer.aset("inner->w", jnp.zeros((3,)))
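The path syntax can be read as a tiny grammar: `->` separates attribute steps, `[n]` indexes a sequence and `['k']` looks up a dictionary key. The stdlib-only sketch below re-implements that behaviour on frozen dataclasses purely to illustrate the semantics; `parse_path`, `set_path` and the standalone `aset` function are hypothetical names, and drinx's own parser may differ in details. In particular, this sketch accepts both `layers[1]` and `layers->[1]`, which is an assumption about the real syntax. Tuples and dicts stand in for JAX arrays and containers.

```python
import dataclasses
import re

# One '->'-separated segment: optional attribute name, then any [n] / ['k'] suffixes.
_SEGMENT = re.compile(r"([A-Za-z_]\w*)?((?:\[\d+\]|\['[^']*'\])*)")
_SUFFIX = re.compile(r"\[(\d+)\]|\['([^']*)'\]")


def parse_path(path):
    """Parse e.g. "inner->cfg['lr']" into [('attr', 'inner'), ('attr', 'cfg'), ('key', 'lr')]."""
    steps = []
    for segment in path.split("->"):
        match = _SEGMENT.fullmatch(segment)
        if not segment or match is None:
            raise ValueError(f"invalid path segment: {segment!r}")
        if match.group(1):
            steps.append(("attr", match.group(1)))
        for index, key in _SUFFIX.findall(match.group(2)):
            steps.append(("index", int(index)) if index else ("key", key))
    return steps


def set_path(obj, steps, value):
    """Return a copy of obj with the target replaced; the original is not mutated."""
    if not steps:
        return value
    (kind, name), rest = steps[0], steps[1:]
    if kind == "attr":
        return dataclasses.replace(obj, **{name: set_path(getattr(obj, name), rest, value)})
    if kind == "index":
        items = list(obj)
        items[name] = set_path(items[name], rest, value)
        return type(obj)(items)
    updated = dict(obj)
    updated[name] = set_path(obj[name], rest, value)
    return updated


def aset(obj, path, value):
    return set_path(obj, parse_path(path), value)


@dataclasses.dataclass(frozen=True)
class Inner:
    w: tuple
    cfg: dict


@dataclasses.dataclass(frozen=True)
class Outer:
    inner: Inner
    layers: tuple


outer = Outer(inner=Inner(w=(1.0, 1.0), cfg={"lr": 0.1}), layers=(1, 2, 3))
outer2 = aset(outer, "inner->w", (0.0, 0.0))
outer3 = aset(outer, "layers[1]", 20)
outer4 = aset(outer, "inner->cfg['lr']", 0.01)
print(outer3.layers, outer4.inner.cfg, outer.inner.cfg)
# (1, 20, 3) {'lr': 0.01} {'lr': 0.1}
```

Because each step rebuilds only the containers along the path, untouched subtrees (for example `outer.layers` inside `outer4`) are shared with the original instead of being copied.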
JAX transforms
Drinx dataclasses work out of the box with JAX transforms such as jit, grad and vmap:
class State(drinx.DataClass):
    x: jax.Array
    step_size: float = drinx.static_field(default=0.1)
# jit
@jax.jit
def update(state):
    # updated_copy is a convenience wrapper for replacing top-level attributes
    return state.updated_copy(x=state.x - state.step_size)
state = update(State(x=jnp.array([1.0, 2.0, 3.0])))  # state.x ≈ [0.9, 1.9, 2.9]
# grad: the result is again a State, holding d(loss)/dx in its x field
def loss(state):
    return jnp.sum(state.x ** 2)
grads = jax.grad(loss)(State(x=jnp.array([1.0, 2.0, 3.0])))  # grads.x == [2., 4., 6.]
# vmap: maps over the leading axis of every traced field
@jax.vmap
def scale(state):
    return state.x * 2
batched = State(x=jnp.array([[1.0, 2.0], [3.0, 4.0]]))
result = scale(batched)  # shape (2, 2)
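The expected values for the grad and vmap examples can be worked out by hand. Only the traced field x is differentiated and batched; the static step_size = 0.1 passes through unchanged, so the gradient returned by jax.grad is again a State with step_size = 0.1. For loss(x) = Σ x_i² evaluated at x = (1, 2, 3), and for scale applied row by row to the batched input:

```latex
\begin{aligned}
\mathcal{L}(x) &= \sum_i x_i^2, \qquad
\frac{\partial \mathcal{L}}{\partial x_i} = 2x_i, \qquad
\nabla_x \mathcal{L}\,\Big|_{x=(1,2,3)} = (2,\,4,\,6), \\
\operatorname{scale}\!\left(\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}\right)
&= \begin{bmatrix} 2 & 4 \\ 6 & 8 \end{bmatrix}.
\end{aligned}
```

Likewise, the jitted update applied to x = (1, 2, 3) subtracts the static step size from every element, giving (0.9, 1.9, 2.9) up to float32 rounding.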
Visualization
tree_diagram and tree_summary let you inspect any JAX pytree at a glance:
class Encoder(drinx.DataClass):
    w: jax.Array
    b: jax.Array
class Model(drinx.DataClass):
    encoder: Encoder
    head: jax.Array
model = Model(
    encoder=Encoder(w=jnp.ones((16, 32)), b=jnp.zeros((16,))),
    head=jnp.ones((4, 16)),
)
print(drinx.tree_diagram(model))
# Model
# ├── .encoder:Encoder
# │   ├── .w=f32[16,32] ∈ [1.0, 1.0], μ=1.0, σ=0.0
# │   └── .b=f32[16] ∈ [0.0, 0.0], μ=0.0, σ=0.0
# └── .head=f32[4,16] ∈ [1.0, 1.0], μ=1.0, σ=0.0
print(drinx.tree_summary(model))
# ┌──────────────┬──────────┬───────┬────────┐
# │Name          │Type      │Count  │Size    │
# ├──────────────┼──────────┼───────┼────────┤
# │.encoder.w    │f32[16,32]│512    │2.00KB  │
# ├──────────────┼──────────┼───────┼────────┤
# │.encoder.b    │f32[16]   │16     │64.00B  │
# ├──────────────┼──────────┼───────┼────────┤
# │.head         │f32[4,16] │64     │256.00B │
# ├──────────────┼──────────┼───────┼────────┤
# │Σ             │Tree      │592    │2.31KB  │
# └──────────────┴──────────┴───────┴────────┘
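The Count and Size columns follow directly from the leaf shapes: Count is the product of a leaf's shape, and Size is Count × 4 bytes for float32. The totals only match if sizes are reported in 1024-byte kilobytes (2368 B / 1024 ≈ 2.31 KB); that formatting rule is inferred from the printed table, not taken from drinx's source. The snippet reproduces the numbers with the standard library:

```python
from math import prod

# Leaf shapes from the Model above; every leaf is float32, i.e. 4 bytes per element.
LEAF_SHAPES = {".encoder.w": (16, 32), ".encoder.b": (16,), ".head": (4, 16)}
BYTES_PER_F32 = 4


def fmt_size(num_bytes):
    """Format a byte count; 1024-byte kilobytes are inferred from the table above."""
    if num_bytes < 1024:
        return f"{num_bytes:.2f}B"
    return f"{num_bytes / 1024:.2f}KB"


rows = [
    (name, prod(shape), fmt_size(prod(shape) * BYTES_PER_F32))
    for name, shape in LEAF_SHAPES.items()
]
total_count = sum(count for _, count, _ in rows)
total_size = fmt_size(total_count * BYTES_PER_F32)

for row in rows:
    print(*row)
print("total", total_count, total_size)
# .encoder.w 512 2.00KB
# .encoder.b 16 64.00B
# .head 64 256.00B
# total 592 2.31KB
```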
Documentation
For more examples and detailed documentation, see the API reference.
Citation
TODO: add citation once published
Download files
File details
Details for the file drinx-1.1.0.tar.gz.
File metadata
- Download URL: drinx-1.1.0.tar.gz
- Upload date:
- Size: 1.7 MB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 10e3c5cc81a91503485dda54e999afadab491823abed667d75c0480315f9b12c |
| MD5 | c5c16ff5fa9d22e067f1c273e8caa05b |
| BLAKE2b-256 | f36a699aa508991a82a213e95f4e0ae36d8ff6a8e2602d0758e6daed26ec2ebc |
Provenance
The following attestation bundles were made for drinx-1.1.0.tar.gz:
- Publisher: publish.yml on ymahlau/drinx
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: drinx-1.1.0.tar.gz
- Subject digest: 10e3c5cc81a91503485dda54e999afadab491823abed667d75c0480315f9b12c
- Sigstore transparency entry: 1096286102
- Permalink: ymahlau/drinx@8dbf3249b65a55ac9bb88f76133435799e7538e3
- Branch / Tag: refs/tags/v1.1.0
- Owner: https://github.com/ymahlau
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@8dbf3249b65a55ac9bb88f76133435799e7538e3
- Trigger Event: release
File details
Details for the file drinx-1.1.0-py3-none-any.whl.
File metadata
- Download URL: drinx-1.1.0-py3-none-any.whl
- Upload date:
- Size: 17.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 176e8c767eeca043854ccafcd91fd3cfb75f31d8e1d8e3f6480f204c135658f1 |
| MD5 | df2d71d36233e8afb1822a747ab99e4c |
| BLAKE2b-256 | 9e780d24a65c1e1aea788e270c68d25bcacacd8273701af9114ca58f35c7385d |
Provenance
The following attestation bundles were made for drinx-1.1.0-py3-none-any.whl:
- Publisher: publish.yml on ymahlau/drinx
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: drinx-1.1.0-py3-none-any.whl
- Subject digest: 176e8c767eeca043854ccafcd91fd3cfb75f31d8e1d8e3f6480f204c135658f1
- Sigstore transparency entry: 1096286104
- Permalink: ymahlau/drinx@8dbf3249b65a55ac9bb88f76133435799e7538e3
- Branch / Tag: refs/tags/v1.1.0
- Owner: https://github.com/ymahlau
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@8dbf3249b65a55ac9bb88f76133435799e7538e3
- Trigger Event: release