Visualize, create, and operate on JAX PyTree in the most intuitive way possible.
Project description
Installation |Description |Quick Example |StatefulComputation |Benchamrks |Acknowledgements
๐ ๏ธ Installation
pip install pytreeclass
Install development version
pip install git+https://github.com/ASEM000/PyTreeClass
๐ Description
PyTreeClass
is a JAX-compatible class builder to create and operate on stateful JAX PyTrees in a performant and intuitive way, by building on familiar concepts found in numpy
, dataclasses
, and others.
See documentation and ๐ณ Common recipes to check if this library is a good fit for your work. If you find the package useful consider giving it a ๐.
โฉ Quick Example
import jax
import jax.numpy as jnp
import pytreeclass as pytc
@pytc.autoinit
class Tree(pytc.TreeClass):
a: float = 1.0
b: tuple[float, float] = (2.0, 3.0)
c: jax.Array = jnp.array([4.0, 5.0, 6.0])
def __call__(self, x):
return self.a + self.b[0] + self.c + x
tree = Tree()
mask = jax.tree_map(lambda x: x > 5, tree)
tree = tree\
.at["a"].set(100.0)\
.at["b"].at[0].set(10.0)\
.at[mask].set(100.0)
print(tree)
# Tree(a=100.0, b=(10.0, 3.0), c=[ 4. 5. 100.])
print(pytc.tree_diagram(tree))
# Tree
# โโโ .a=100.0
# โโโ .b:tuple
# โ โโโ [0]=10.0
# โ โโโ [1]=3.0
# โโโ .c=f32[3](ฮผ=36.33, ฯ=45.02, โ[4.00,100.00])
print(pytc.tree_summary(tree))
# โโโโโโโฌโโโโโโโฌโโโโโโฌโโโโโโโ
# โName โType โCountโSize โ
# โโโโโโโผโโโโโโโผโโโโโโผโโโโโโโค
# โ.a โfloat โ1 โ โ
# โโโโโโโผโโโโโโโผโโโโโโผโโโโโโโค
# โ.b[0]โfloat โ1 โ โ
# โโโโโโโผโโโโโโโผโโโโโโผโโโโโโโค
# โ.b[1]โfloat โ1 โ โ
# โโโโโโโผโโโโโโโผโโโโโโผโโโโโโโค
# โ.c โf32[3]โ3 โ12.00Bโ
# โโโโโโโผโโโโโโโผโโโโโโผโโโโโโโค
# โฮฃ โTree โ6 โ12.00Bโ
# โโโโโโโดโโโโโโโดโโโโโโดโโโโโโโ
# ** pass it to jax transformations **
# works with jit, grad, vmap, etc.
@jax.jit
@jax.grad
def sum_tree(tree: Tree, x):
return sum(tree(x))
print(sum_tree(tree, 1.0))
# Tree(a=3.0, b=(3.0, 0.0), c=[1. 1. 1.])
|
๐ Stateful computations
Under jax.jit jax requires states to be explicit, this means that for any class instance; variables needs to be separated from the class and be passed explictly. However when using TreeClass
no need to separate the instance variables ; instead the whole instance is passed as a state.
Using the following pattern,Updating state functionally can be achieved under jax.jit
import jax
import pytreeclass as pytc
class Counter(pytc.TreeClass):
def __init__(self, calls: int = 0):
self.calls = calls
def increment(self):
self.calls += 1
counter = Counter() # Counter(calls=0)
|
Here, we define the update function. Since the increment method mutate the internal state, thus we need to use the functional approach to update the state by using .at
. To achieve this we can use .at[method_name].__call__(*args,**kwargs)
, this functional call will return the value of this call and a new model instance with the update state.
@jax.jit
def update(counter):
value, new_counter = counter.at["increment"]()
return new_counter
for i in range(10):
counter = update(counter)
print(counter.calls) # 10
|
โ Benchmarks
Benchmark simple training against `flax` and `equinox`
Training simple sequential linear benchmark against flax
and equinox
Num of layers | Flax/PyTC time |
Equinox/PyTC time |
10 | 1.427 | 6.671 |
100 | 1.1130 | 2.714 |
๐ Acknowledgements
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Hashes for pytreeclass-0.6.0.post0-py3-none-any.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 498ef3be6b5c0181faee9753c645af5c1bc9363253bd2253a23cffb72b4695f7 |
|
MD5 | 30a905c59db3971fa56d72f84d12bcef |
|
BLAKE2b-256 | 818a8080bbfd30669b9b2c22ab36bc41d85cba23f9e3ff20156a68dddfcdba49 |