Visualize, create, and operate on JAX PyTree in the most intuitive way possible.
Installation | Description | Quick Example | Stateful Computation | Benchmarks | Acknowledgements
Installation
```bash
pip install pytreeclass
```

Install the development version:

```bash
pip install git+https://github.com/ASEM000/pytreeclass
```
Description
`pytreeclass` is a JAX-compatible class builder to create and operate on stateful JAX PyTrees in a performant and intuitive way, by building on familiar concepts found in `numpy`, `dataclasses`, and others.
See the documentation and Common recipes to check whether this library is a good fit for your work. If you find the package useful, consider giving it a star.
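To make "JAX PyTree" concrete: a PyTree is any structure that JAX can flatten into leaves and rebuild, which is what allows an object to pass through `jax.jit`, `jax.grad`, and `jax.tree_util.tree_map`. The sketch below registers a small class by hand with `jax.tree_util.register_pytree_node_class`. The names `Point` and `squared_norm` are hypothetical; this is the registration boilerplate that a class builder such as `pytreeclass` is meant to spare you, not a description of how `pytreeclass` works internally.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Point:
    """Hypothetical class turned into a PyTree by manual registration."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        # (children, aux_data): the children are what JAX traverses as leaves
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # rebuild an instance from (possibly transformed) children
        return cls(*children)


@jax.jit
def squared_norm(p):
    # `p` crosses the jit boundary as a PyTree; no manual unpacking needed
    return p.x ** 2 + jnp.sum(p.y ** 2)


p = Point(1.0, jnp.array([2.0, 3.0]))
leaves = jax.tree_util.tree_leaves(p)                 # two leaves: 1.0 and the array
doubled = jax.tree_util.tree_map(lambda v: v * 2, p)  # a new Point with doubled leaves
print(squared_norm(p))  # 14.0
```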
Quick Example
```python
import jax
import jax.numpy as jnp
import pytreeclass as tc

@tc.autoinit
class Tree(tc.TreeClass):
    a: float = 1.0
    b: tuple[float, float] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])

    def __call__(self, x):
        return self.a + self.b[0] + self.c + x

tree = Tree()
# boolean mask with the same structure as `tree`: True where a leaf value is > 5
mask = jax.tree_util.tree_map(lambda x: x > 5, tree)

# each `.at[...].set(...)` returns a new instance with the update applied
tree = tree\
    .at["a"].set(100.0)\
    .at["b"][0].set(10.0)\
    .at[mask].set(100.0)

print(tree)
# Tree(a=100.0, b=(10.0, 3.0), c=[ 4. 5. 100.])

print(tc.tree_diagram(tree))
# Tree
# ├── .a=100.0
# ├── .b:tuple
# │   ├── [0]=10.0
# │   └── [1]=3.0
# └── .c=f32[3](μ=36.33, σ=45.02, ∈[4.00,100.00])

print(tc.tree_summary(tree))
# ┌─────┬──────┬─────┬──────┐
# │Name │Type  │Count│Size  │
# ├─────┼──────┼─────┼──────┤
# │.a   │float │1    │      │
# ├─────┼──────┼─────┼──────┤
# │.b[0]│float │1    │      │
# ├─────┼──────┼─────┼──────┤
# │.b[1]│float │1    │      │
# ├─────┼──────┼─────┼──────┤
# │.c   │f32[3]│3    │12.00B│
# ├─────┼──────┼─────┼──────┤
# │Σ    │Tree  │6    │12.00B│
# └─────┴──────┴─────┴──────┘

# ** pass it to jax transformations **
# works with jit, grad, vmap, etc.
@jax.jit
@jax.grad
def sum_tree(tree: Tree, x):
    return sum(tree(x))

print(sum_tree(tree, 1.0))
# Tree(a=3.0, b=(3.0, 0.0), c=[1. 1. 1.])
```
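Conceptually, the boolean-mask update `.at[mask].set(100.0)` in the example is a leaf-by-leaf select: wherever the mask is `True` the new value is used, and elsewhere the existing leaf is kept. The sketch below reproduces that behaviour with plain JAX on dictionaries standing in for the `Tree` instance. The helper `masked_set` is hypothetical and is not how `pytreeclass` implements `.at`.

```python
import jax
import jax.numpy as jnp


def masked_set(tree, mask, value):
    # Hypothetical helper: for each (leaf, mask-leaf) pair, take `value` where
    # the mask is True and keep the original leaf elsewhere.
    return jax.tree_util.tree_map(lambda leaf, m: jnp.where(m, value, leaf), tree, mask)


# mask computed on the original values, as in the example above
original = {"a": 1.0, "b": (2.0, 3.0), "c": jnp.array([4.0, 5.0, 6.0])}
mask = jax.tree_util.tree_map(lambda x: x > 5, original)

# values after the `.at["a"]` and `.at["b"][0]` updates
updated = {"a": 100.0, "b": (10.0, 3.0), "c": jnp.array([4.0, 5.0, 6.0])}
result = masked_set(updated, mask, 100.0)
# only the last element of "c" changes; scalar leaves come back as 0-d arrays
```

Selecting with `jnp.where` rather than a Python `if` keeps the update traceable under `jax.jit`, because no Python branch depends on array values.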
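The `Count` and `Size` columns of the summary can be checked by hand: `Count` adds up the number of elements in every leaf (1 + 1 + 1 + 3 = 6), and `Size` is filled in only for the array leaf, 3 `float32` values × 4 bytes = 12 bytes. The sketch below recomputes both numbers with plain JAX on a dictionary mirroring `Tree`; treating Python scalars as having no byte size is an assumption inferred from the blank cells in the table.

```python
import jax
import jax.numpy as jnp
import numpy as np

# dictionary mirroring the final `Tree` values from the example
tree_like = {"a": 100.0, "b": (10.0, 3.0), "c": jnp.array([4.0, 5.0, 100.0])}
leaves = jax.tree_util.tree_leaves(tree_like)

# number of scalar elements across all leaves (a Python float counts as 1)
count = sum(int(np.size(leaf)) for leaf in leaves)

# bytes of array leaves only (Python floats have no `nbytes` attribute)
size_bytes = sum(leaf.nbytes for leaf in leaves if hasattr(leaf, "nbytes"))

print(count, size_bytes)  # 6 12
```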
Stateful computations
Under `jax.jit`, JAX requires state to be explicit: for any class instance, the variables need to be separated from the class and passed explicitly. With `TreeClass`, there is no need to separate the instance variables; instead, the whole instance is passed as the state.
Using the following pattern, state can be updated functionally under `jax.jit`:
```python
import jax
import pytreeclass as tc

class Counter(tc.TreeClass):
    def __init__(self, calls: int = 0):
        self.calls = calls

    def increment(self):
        self.calls += 1

counter = Counter()  # Counter(calls=0)
```
Here, we define the update function. Since the `increment` method mutates the internal state, we need to update the state functionally using `.at`. To do this, we call `.at[method_name].__call__(*args, **kwargs)`; this functional call returns the value of the method call and a new instance with the updated state.
```python
@jax.jit
def update(counter):
    # returns (return value of `increment`, updated copy of `counter`)
    value, new_counter = counter.at["increment"]()
    return new_counter

for i in range(10):
    counter = update(counter)

print(counter.calls)  # 10
```
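A useful mental model for `counter.at["increment"]()`: copy the instance, run the method on the copy, and hand back both the method's return value and the modified copy, leaving the original untouched. The sketch below spells that out for a plain, mutable Python class. `PlainCounter` and `call_functionally` are hypothetical names, and this is an illustration of the semantics described above, not the library's actual implementation.

```python
import copy


class PlainCounter:
    """Hypothetical mutable counter standing in for `Counter` above."""

    def __init__(self, calls: int = 0):
        self.calls = calls

    def increment(self):
        self.calls += 1


def call_functionally(obj, method_name, *args, **kwargs):
    # Hypothetical helper: run a mutating method on a shallow copy,
    # so the caller's object is never modified.
    new_obj = copy.copy(obj)
    value = getattr(new_obj, method_name)(*args, **kwargs)
    return value, new_obj


counter = PlainCounter()
value, new_counter = call_functionally(counter, "increment")
print(counter.calls, new_counter.calls, value)  # 0 1 None
```

A shallow copy is enough here because `increment` only rebinds an attribute; a method that mutated a nested list in place would also change the original and would need `copy.deepcopy`.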
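For comparison, this is roughly what the explicit-state requirement mentioned at the start of this section looks like without `TreeClass`: the counter's variables live in a separate state container that is passed into, and returned from, the jitted function. The dictionary layout and the function name `increment` are illustrative choices for this sketch.

```python
import jax
import jax.numpy as jnp


@jax.jit
def increment(state):
    # pure function: reads the old state and returns a new state dict
    return {"calls": state["calls"] + 1}


state = {"calls": jnp.array(0)}
for _ in range(10):
    state = increment(state)

print(int(state["calls"]))  # 10
```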
Benchmarks
Benchmark of training a simple sequential linear model against `flax` and `equinox`:

| Number of layers | Flax time / tc time | Equinox time / tc time |
|---|---|---|
| 10 | 1.427 | 6.671 |
| 100 | 1.1130 | 2.714 |
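The ratios above compare wall-clock training times. When reproducing this kind of measurement, two JAX details matter: the first call of a jitted function includes compilation, and JAX dispatches work asynchronously, so results must be synchronized with `block_until_ready()` before the clock stops. The helper `mean_runtime` and the workload `step` below are hypothetical; this is a sketch of such a timer, not the script used to produce the table.

```python
import time

import jax
import jax.numpy as jnp


def mean_runtime(fn, *args, repeats: int = 100) -> float:
    """Hypothetical helper: average seconds per call of a jitted `fn`.

    Assumes `fn` returns a single JAX array and `repeats` >= 1.
    """
    out = fn(*args)
    out.block_until_ready()  # warm-up: compilation happens outside the timed region
    start = time.perf_counter()
    for _ in range(repeats):
        out = fn(*args)
    out.block_until_ready()  # wait for asynchronously dispatched work to finish
    return (time.perf_counter() - start) / repeats


@jax.jit
def step(w, x):
    # stand-in workload: one linear layer followed by a sum
    return jnp.sum(x @ w)


w = jnp.ones((64, 64))
x = jnp.ones((32, 64))
t = mean_runtime(step, w, x)
```

Dividing the per-step time of one framework by that of another, for equivalent training steps, gives ratios in the style of the table.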
Acknowledgements
File details
Details for the file pytreeclass-0.9.2.tar.gz.
File metadata
- Download URL: pytreeclass-0.9.2.tar.gz
- Upload date:
- Size: 55.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.5
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | dbc2b13cead4c4ab3bb2cb026fce44ae673938d38d3e091e7f7c87ccaa1b7db8 |
| MD5 | 2815430d5d64eeccbc527d25f11a588f |
| BLAKE2b-256 | dd4f4970819ca7424d551ac4c02efb1cb2e0d20fb320a41998a9898d36af4c2e |
File details
Details for the file pytreeclass-0.9.2-py3-none-any.whl.
File metadata
- Download URL: pytreeclass-0.9.2-py3-none-any.whl
- Upload date:
- Size: 46.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.5
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | de379df40a58bce323e4be1a55453fd8adcaa482f27c4024fa5bc05768589163 |
| MD5 | 6b9f32658af5f559898dc95959c60fcb |
| BLAKE2b-256 | bb63dda8f04586d299e895ddbbeaa0d18b994be047017142fd791c20cc5840e5 |