JAX compatible dataclass.
Installation | Description | Quick Example | Filtering | Stateful Computation | Applications | Acknowledgements
Installation
pip install pytreeclass
Install development version
pip install git+https://github.com/ASEM000/PyTreeClass
Description
PyTreeClass is a JAX-compatible dataclass-like decorator to create and operate on stateful JAX PyTrees.
The package aims to achieve two goals:
- To maintain safe and correct behaviour by using immutable modules with a functional API.
- To achieve the most intuitive user experience in the JAX ecosystem by:
  - Defining layers in a subclassing style similar to PyTorch or TensorFlow.
  - Filtering/indexing layer values using boolean masking, similar to jax.numpy.at[].{get,set,apply,...}
  - Visualizing defined layers in a plethora of ways for better debugging and sharing of information.
Quick Example
Create simple MLP
import jax
from jax import numpy as jnp
import pytreeclass as pytc
import matplotlib.pyplot as plt

@pytc.treeclass
class Linear :
    # Any variable not wrapped with @pytc.treeclass
    # should be declared as a dataclass field here
    weight : jnp.ndarray
    bias : jnp.ndarray

    def __init__(self,key,in_dim,out_dim):
        self.weight = jax.random.normal(key,shape=(in_dim, out_dim)) * jnp.sqrt(2/in_dim)
        self.bias = jnp.ones((1,out_dim))

    def __call__(self,x):
        return x @ self.weight + self.bias

@pytc.treeclass
class StackedLinear:
    def __init__(self,key,in_dim,out_dim,hidden_dim):
        keys= jax.random.split(key,3)
        # Declaring l1,l2,l3 as dataclass_fields is optional
        # as l1,l2,l3 are Linear class that is wrapped with @pytc.treeclass
        # To strictly include nodes defined in dataclass fields
        # use `@pytc.treeclass(field_only=True)`
        self.l1 = Linear(key=keys[0],in_dim=in_dim,out_dim=hidden_dim)
        self.l2 = Linear(key=keys[1],in_dim=hidden_dim,out_dim=hidden_dim)
        self.l3 = Linear(key=keys[2],in_dim=hidden_dim,out_dim=out_dim)

    def __call__(self,x):
        x = self.l1(x)
        x = jax.nn.tanh(x)
        x = self.l2(x)
        x = jax.nn.tanh(x)
        x = self.l3(x)
        return x

model = StackedLinear(in_dim=1,out_dim=1,hidden_dim=10,key=jax.random.PRNGKey(0))
x = jnp.linspace(0,1,100)[:,None]
y = x**3 + jax.random.uniform(jax.random.PRNGKey(0),(100,1))*0.01
Visualize
Three text views are available: summary, tree_box, and tree_diagram.
print(model.summary())
┌────┬──────┬───────┬───────┬─────────────────┐
│Name│Type  │Param #│Size   │Config           │
├────┼──────┼───────┼───────┼─────────────────┤
│l1  │Linear│20(0)  │80.00B │weight=f32[1,10] │
│    │      │       │(0.00B)│bias=f32[1,10]   │
├────┼──────┼───────┼───────┼─────────────────┤
│l2  │Linear│110(0) │440.00B│weight=f32[10,10]│
│    │      │       │(0.00B)│bias=f32[1,10]   │
├────┼──────┼───────┼───────┼─────────────────┤
│l3  │Linear│11(0)  │44.00B │weight=f32[10,1] │
│    │      │       │(0.00B)│bias=f32[1,1]    │
└────┴──────┴───────┴───────┴─────────────────┘
Total count : 141(0)
Dynamic count : 141(0)
Frozen count : 0(0)
-----------------------------------------------
Total size : 564.00B(0.00B)
Dynamic size : 564.00B(0.00B)
Frozen size : 0.00B(0.00B)
===============================================
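The counts and sizes in the summary can be checked by hand. The following is a minimal stdlib-only sketch, not part of PyTreeClass: the shapes are copied from the Config column, and `count_params` and `BYTES_PER_F32` are illustrative names introduced here.

```python
from math import prod

# Shapes copied from the Config column of the summary table above.
layer_shapes = {
    "l1": {"weight": (1, 10), "bias": (1, 10)},
    "l2": {"weight": (10, 10), "bias": (1, 10)},
    "l3": {"weight": (10, 1), "bias": (1, 1)},
}

BYTES_PER_F32 = 4  # a float32 element occupies 4 bytes


def count_params(shapes):
    """Number of scalar elements across all arrays of one layer."""
    return sum(prod(shape) for shape in shapes.values())


counts = {name: count_params(shapes) for name, shapes in layer_shapes.items()}
sizes = {name: n * BYTES_PER_F32 for name, n in counts.items()}
total_count = sum(counts.values())
total_size = sum(sizes.values())

for name in layer_shapes:
    print(f"{name}: {counts[name]} params, {sizes[name]} B")
print(f"total: {total_count} params, {total_size} B")
```

Each layer contributes in_dim*out_dim weights plus out_dim biases, which gives the 20, 110, and 11 entries and the 141-parameter, 564-byte totals reported above.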
tree_box, using jax.eval_shape (a no-FLOPs operation):
print(model.tree_box(array=x))
┌──────────────────────────────────────┐
│StackedLinear[Parent]                 │
├──────────────────────────────────────┤
│┌────────────┬────────┬──────────────┐│
││            │ Input  │ f32[100,1]   ││
││ Linear[l1] ├────────┼──────────────┤│
││            │ Output │ f32[100,128] ││
│└────────────┴────────┴──────────────┘│
│┌────────────┬────────┬──────────────┐│
││            │ Input  │ f32[100,128] ││
││ Linear[l2] ├────────┼──────────────┤│
││            │ Output │ f32[100,128] ││
│└────────────┴────────┴──────────────┘│
│┌────────────┬────────┬──────────────┐│
││            │ Input  │ f32[100,128] ││
││ Linear[l3] ├────────┼──────────────┤│
││            │ Output │ f32[100,1]   ││
│└────────────┴────────┴──────────────┘│
└──────────────────────────────────────┘
print(model.tree_diagram())
StackedLinear
    ├── l1=Linear
    │   ├── weight=f32[1,10]
    │   └── bias=f32[1,10]
    ├── l2=Linear
    │   ├── weight=f32[10,10]
    │   └── bias=f32[1,10]
    └── l3=Linear
        ├── weight=f32[10,1]
        └── bias=f32[1,1]
mermaid.io (native support in GitHub/Notion)
Generate shareable visualization links
# generate mermaid diagrams
# print(pytc.tree_viz.tree_mermaid(model)) # generate core syntax
>>> pytc.tree_viz.tree_mermaid(model,link=True)
# 'Open URL in browser: https://pytreeclass.herokuapp.com/temp/?id=*********'
flowchart LR
id15696277213149321320[StackedLinear]
id15696277213149321320 --> id159132120600507116(l1\nLinear)
id159132120600507116 --- id7500441386962467209["weight\nf32[1,10]"]
id159132120600507116 --- id10793958738030044218["bias\nf32[1,10]"]
id15696277213149321320 --> id10009280772564895168(l2\nLinear)
id10009280772564895168 --- id11951215191344350637["weight\nf32[10,10]"]
id10009280772564895168 --- id1196345851686744158["bias\nf32[1,10]"]
id15696277213149321320 --> id7572222925824649475(l3\nLinear)
id7572222925824649475 --- id4749243995442935477["weight\nf32[10,1]"]
id7572222925824649475 --- id8042761346510512486["bias\nf32[1,1]"]
Model surgery
# freeze l1
from pytreeclass.tree_util import tree_freeze
model = model.at["l1"].set(tree_freeze(model.l1))
# set negative values in l2 to 0
filtered_l2 = model.l2.at[model.l2<0].set(0)
model = model.at["l2"].set( filtered_l2 )
# apply sin(x) to all values in l3
filtered_l3 = model.l3.at[...].apply(jnp.sin)
model = model.at["l3"].set(filtered_l3)
# frozen nodes are marked with #
print(model.tree_diagram())
StackedLinear
    ├#─ l1=Linear
    │   ├#─ weight=f32[1,10]
    │   └#─ bias=f32[1,10]
    ├── l2=Linear
    │   ├── weight=f32[10,10]
    │   └── bias=f32[1,10]
    └── l3=Linear
        ├── weight=f32[10,1]
        └── bias=f32[1,1]
Filtering with .at[]
PyTreeClass offers four means of filtering:
- Filter by value
- Filter by field name
- Filter by field type
- Filter by field metadata.
The following example demonstrates filtering. Suppose you have the following multilayer perceptron (MLP) class.
- Note that in StackedLinear, l1 and l2 have a description in their field metadata.
Model definition
import jax
from jax import numpy as jnp
import pytreeclass as pytc
import matplotlib.pyplot as plt
from dataclasses import field

@pytc.treeclass
class Linear :
    weight : jnp.ndarray
    bias : jnp.ndarray

    def __init__(self,key,in_dim,out_dim):
        self.weight = jax.random.normal(key,shape=(in_dim, out_dim)) * jnp.sqrt(2/in_dim)
        self.bias = jnp.ones((1,out_dim))

    def __call__(self,x):
        return x @ self.weight + self.bias

@pytc.treeclass
class StackedLinear:
    l1 : Linear = field(metadata={"description": "First layer"})
    l2 : Linear = field(metadata={"description": "Second layer"})

    def __init__(self,key,in_dim,out_dim,hidden_dim):
        keys= jax.random.split(key,3)
        self.l1 = Linear(key=keys[0],in_dim=in_dim,out_dim=hidden_dim)
        self.l2 = Linear(key=keys[2],in_dim=hidden_dim,out_dim=out_dim)

    def __call__(self,x):
        x = self.l1(x)
        x = jax.nn.tanh(x)
        x = self.l2(x)
        return x

model = StackedLinear(in_dim=1,out_dim=1,hidden_dim=5,key=jax.random.PRNGKey(0))
- Raw model values before any filtering.
print(model)
StackedLinear(
  l1=Linear(
    weight=[[-1.6248673 -2.8383057 1.3969219 1.3169124 -0.40784812]],
    bias=[[1. 1. 1. 1. 1.]]
  ),
  l2=Linear(
    weight=
      [[ 0.98507565]
       [ 0.99815285]
       [-1.0687716 ]
       [-0.19255024]
       [-1.2108876 ]],
    bias=[[1.]]
  )
)
Filter by value
- Get all negative values
print(model.at[model<0].get())
StackedLinear(
  l1=Linear(
    weight=[-1.6248673 -2.8383057 -0.40784812],
    bias=[]
  ),
  l2=Linear(
    weight=[-1.0687716 -0.19255024 -1.2108876 ],
    bias=[]
  )
)
- Set negative values to 0
print(model.at[model<0].set(0))
StackedLinear(
  l1=Linear(
    weight=[[0. 0. 1.3969219 1.3169124 0. ]],
    bias=[[1. 1. 1. 1. 1.]]
  ),
  l2=Linear(
    weight=
      [[0.98507565]
       [0.99815285]
       [0.        ]
       [0.        ]
       [0.        ]],
    bias=[[1.]]
  )
)
- Apply f(x)=x^2 to negative values
print(model.at[model<0].apply(lambda x:x**2))
StackedLinear(
  l1=Linear(
    weight=[[2.6401937 8.05598 1.3969219 1.3169124 0.16634008]],
    bias=[[1. 1. 1. 1. 1.]]
  ),
  l2=Linear(
    weight=
      [[0.98507565]
       [0.99815285]
       [1.1422727 ]
       [0.03707559]
       [1.4662486 ]],
    bias=[[1.]]
  )
)
- Sum all negative values
print(model.at[model<0].reduce(lambda acc,cur: acc+jnp.sum(cur)))
-7.3432307
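To make the four value-filter operations concrete, here is a rough stdlib-only emulation on a nested dict. It is a sketch of the semantics shown above, not PyTreeClass's implementation: `tree_map` and `is_negative` are helper names introduced for illustration, and the arrays are flattened to plain lists of the printed values.

```python
from functools import reduce

# Values copied from the printed model above, flattened to plain lists.
model = {
    "l1": {
        "weight": [-1.6248673, -2.8383057, 1.3969219, 1.3169124, -0.40784812],
        "bias": [1.0, 1.0, 1.0, 1.0, 1.0],
    },
    "l2": {
        "weight": [0.98507565, 0.99815285, -1.0687716, -0.19255024, -1.2108876],
        "bias": [1.0],
    },
}


def tree_map(fn, tree):
    """Apply fn to every leaf list of a nested dict, keeping the structure."""
    if isinstance(tree, dict):
        return {key: tree_map(fn, sub) for key, sub in tree.items()}
    return fn(tree)


def is_negative(x):
    return x < 0


# .get(): keep only the selected values
negatives = tree_map(lambda leaf: [x for x in leaf if is_negative(x)], model)
# .set(0): replace the selected values, leave the rest untouched
clipped = tree_map(lambda leaf: [0.0 if is_negative(x) else x for x in leaf], model)
# .apply(f): transform the selected values only
squared = tree_map(lambda leaf: [x**2 if is_negative(x) else x for x in leaf], model)
# .reduce(f): fold the selected values of every leaf into one accumulator
leaves = [leaf for layer in negatives.values() for leaf in layer.values()]
total = reduce(lambda acc, cur: acc + sum(cur), leaves, 0.0)

print(negatives["l1"]["weight"])
print(round(total, 5))
```

The reduced total agrees with the -7.3432307 printed above up to float32 rounding.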
Filter by field name
- Get all fields named l1
print(model.at[model == "l1"].get())
StackedLinear(
  l1=Linear(
    weight=[-1.6248673 -2.8383057 1.3969219 1.3169124 -0.40784812],
    bias=[1. 1. 1. 1. 1.]
  ),
  l2=Linear(weight=[],bias=[])
)
Filter by field type
- Get all fields of Linear type
print(model.at[model == Linear].get())
StackedLinear(
  l1=Linear(
    weight=[-1.6248673 -2.8383057 1.3969219 1.3169124 -0.40784812],
    bias=[1. 1. 1. 1. 1.]
  ),
  l2=Linear(
    weight=[ 0.98507565 0.99815285 -1.0687716 -0.19255024 -1.2108876 ],
    bias=[1.]
  )
)
Filter by field metadata
- Get all fields whose metadata equals {"description": "First layer"}
print(model.at[model == {"description": "First layer"}].get())
StackedLinear(
  l1=Linear(
    weight=[-1.6248673 -2.8383057 1.3969219 1.3169124 -0.40784812],
    bias=[1. 1. 1. 1. 1.]
  ),
  l2=Linear(weight=[],bias=[])
)
Mix and match different filtering methods.
- Get only the positive values of fields named weight.
mask = (model == "weight") & (model>0)
print(model.at[mask].get())
StackedLinear(
  l1=Linear(weight=[1.3969219 1.3169124],bias=[]),
  l2=Linear(weight=[0.98507565 0.99815285],bias=[])
)
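Combining masks amounts to an element-wise AND of two boolean trees that share the model's structure. The sketch below illustrates that idea with plain dicts and lists; `build_mask`, `mask_and`, and `get` are hypothetical helpers written for this example and are not PyTreeClass functions.

```python
# Values copied from the printed model above, flattened to plain lists.
model = {
    "l1": {
        "weight": [-1.6248673, -2.8383057, 1.3969219, 1.3169124, -0.40784812],
        "bias": [1.0, 1.0, 1.0, 1.0, 1.0],
    },
    "l2": {
        "weight": [0.98507565, 0.99815285, -1.0687716, -0.19255024, -1.2108876],
        "bias": [1.0],
    },
}


def build_mask(tree, predicate):
    """Boolean tree with the model's structure; predicate sees (field_name, value)."""
    return {
        layer: {name: [predicate(name, v) for v in values] for name, values in fields.items()}
        for layer, fields in tree.items()
    }


def mask_and(left, right):
    """Element-wise AND of two masks that share one structure."""
    return {
        layer: {
            name: [a and b for a, b in zip(left[layer][name], right[layer][name])]
            for name in left[layer]
        }
        for layer in left
    }


def get(tree, mask):
    """Keep only the values whose mask entry is True."""
    return {
        layer: {
            name: [v for v, keep in zip(values, mask[layer][name]) if keep]
            for name, values in fields.items()
        }
        for layer, fields in tree.items()
    }


by_name = build_mask(model, lambda name, value: name == "weight")
by_value = build_mask(model, lambda name, value: value > 0)
selected = get(model, mask_and(by_name, by_value))
print(selected)
```

Only the positive weight entries survive, and every bias list comes back empty, mirroring the output above.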
Marking fields non-differentiable (NEW)
Automatically marking fields non-differentiable
In the following code example, we train a model that has both differentiable and non-differentiable fields.
Calling jax.grad on such a model directly throws an error; to circumvent this, we use pytc.filter_nondiff to filter out any non-differentiable field.
import pytreeclass as pytc
import jax.numpy as jnp
import jax
from typing import Callable

@pytc.treeclass
class Linear:
    weight: jnp.ndarray  # differentiable
    bias: jnp.ndarray  # differentiable
    other: tuple[int,...] = (1,2,3,4)  # non-differentiable
    a: int = 1  # non-differentiable
    b: float = 1.0  # differentiable
    c: int = 1  # non-differentiable
    d: float = 2.0  # differentiable
    act : Callable = jax.nn.tanh  # non-differentiable

    def __init__(self,in_dim,out_dim):
        self.weight = jnp.ones((in_dim,out_dim))
        self.bias = jnp.ones((1,out_dim))

    def __call__(self,x):
        return self.act(self.b+x)

@jax.value_and_grad
def loss_func(model):
    # lets optimize a differentiable field `b`
    # inside a non-differentiable field `act`
    return jnp.mean((model(1.)-0.5)**2)

@jax.jit
def update(model):
    value,grad = loss_func(model)
    return value,model-1e-3*grad

def train(model,epochs=10_000):
    # here we use the filter_nondiff function
    # to filter out the non-differentiable fields
    # otherwise we would get an error
    model = pytc.filter_nondiff(model)
    for _ in range(epochs):
        value,model = update(model)
    return model
# before any filtering or training
model = Linear(1,1)
print(model)
# Linear(
# weight=[[1.]],
# bias=[[1.]],
# other=(1,2,3,4),
# a=1,
# b=1.0,
# c=1,
# d=2.0,
# act=tanh(x)
# )
model = train(model)
# after filtering and training
# note that the non-differentiable fields are not updated
# and the differentiable fields are updated
# the non-differentiable fields are marked with a `*`
print(model)
# Linear(
# weight=[[1.]],
# bias=[[1.]],
# *other=(1,2,3,4),
# *a=1,
# b=-0.36423424,
# *c=1,
# d=2.0,
# *act=tanh(x)
# )
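The update applied to `b` can be reproduced without JAX. The sketch below is a plain-Python stand-in under stated assumptions: the gradient is derived by hand with the chain rule rather than obtained from jax.value_and_grad, the arithmetic is float64 instead of float32, and `loss`, `loss_grad`, and the constants are names chosen for this example.

```python
import math

TARGET = 0.5
LEARNING_RATE = 1e-3
EPOCHS = 10_000


def loss(b, x=1.0):
    """Squared error between tanh(b + x) and the target."""
    return (math.tanh(b + x) - TARGET) ** 2


def loss_grad(b, x=1.0):
    """d(loss)/db via the chain rule: 2*(tanh(u) - target) * (1 - tanh(u)**2)."""
    t = math.tanh(b + x)
    return 2.0 * (t - TARGET) * (1.0 - t * t)


b = 1.0  # same starting value as the field `b` above
initial_loss = loss(b)
for _ in range(EPOCHS):
    b -= LEARNING_RATE * loss_grad(b)  # plain gradient-descent step

final_loss = loss(b)
optimum = math.atanh(TARGET) - 1.0  # the b at which tanh(b + 1) equals 0.5
print(f"b={b:.4f}, loss={final_loss:.6f}, optimum={optimum:.4f}")
```

With this small learning rate, `b` decreases monotonically from 1.0 toward the optimum near -0.4507, which is why only `b` changes in the trained model while `weight`, `bias`, and `d` (unused by `__call__`) keep their values.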
Marking fields non-differentiable with a mask
Suppose we want to train only the field `b` and mark all other fields non-differentiable. This can be done by passing a mask:
new_model = pytc.filter_nondiff(model, model != "b")
# we can see all fields except `b` are marked with
# `*` to mark non-differentiable.
print(new_model)
# Linear(
# *weight=f32[1,1],
# *bias=f32[1,1],
# *other=(1,2,3,4),
# *a=1,
# b=f32[],
# *c=1,
# *d=f32[],
# *act=tanh(x)
# )
# undo the filtering
# note the removal of `*` that marks non-diff fields
unfiltered_model = pytc.unfilter_nondiff(new_model)
print(unfiltered_model)
# Linear(
# weight=f32[1,1],
# bias=f32[1,1],
# other=(1,2,3,4),
# a=1,
# b=f32[],
# c=1,
# d=f32[],
# act=tanh(x)
# )
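The two marking modes above can be pictured as a partition of the fields into a trainable group and a frozen group. The following sketch is an assumption-laden illustration and not the library's logic: it guesses a type-based rule (floats and lists of floats are differentiable) that happens to match the comments in the Linear class, and `is_differentiable`, `partition`, and `nondiff_mask` are hypothetical names.

```python
import math


def is_differentiable(value):
    """Assumed rule: a float, or a non-empty list made only of such values."""
    if isinstance(value, float):
        return True
    if isinstance(value, list):
        return bool(value) and all(is_differentiable(v) for v in value)
    return False  # ints, tuples of ints, callables, ...


def partition(fields, nondiff_mask=None):
    """Split fields into (trainable, frozen) dicts.

    nondiff_mask, if given, maps a field name to True when that field
    should be frozen; otherwise the type-based rule above decides.
    """
    trainable, frozen = {}, {}
    for name, value in fields.items():
        if nondiff_mask is not None:
            is_frozen = nondiff_mask(name)
        else:
            is_frozen = not is_differentiable(value)
        (frozen if is_frozen else trainable)[name] = value
    return trainable, frozen


# Stand-ins for the fields of the Linear class above (arrays become nested lists).
fields = {
    "weight": [[1.0]], "bias": [[1.0]], "other": (1, 2, 3, 4),
    "a": 1, "b": 1.0, "c": 1, "d": 2.0, "act": math.tanh,
}

auto_trainable, auto_frozen = partition(fields)
only_b, rest = partition(fields, nondiff_mask=lambda name: name != "b")
print(sorted(auto_trainable), sorted(only_b))
```

The automatic mode leaves `weight`, `bias`, `b`, and `d` trainable, while the name mask leaves only `b`, matching the `*` markers in the two printouts.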
Stateful computations
Under jax.jit, JAX requires state to be explicit: for any class instance, its variables need to be separated from the class and passed explicitly. With @pytc.treeclass there is no need to separate the instance variables; instead, the whole instance is passed as the state.
Using the following pattern, state can be updated functionally under jax.jit.
import jax
import pytreeclass as pytc

@pytc.treeclass
class Counter:
    calls : int = 0.

    def increment(self):
        self.calls += 1

counter = Counter() # Counter(calls=0.0)
Here, we define the update function. Since the increment method mutates the internal state, we need to update the state functionally by using .at. To achieve this, use .at[method_name].__call__(*args,**kwargs); this functional call returns the value of the call together with a new model instance holding the updated state.
@jax.jit
def update(counter):
    value, new_counter = counter.at["increment"]()
    return new_counter

for i in range(10):
    counter = update(counter)

print(counter.calls) # 10.0
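Conceptually, a functional method call runs the mutating method on a copy and hands back both the return value and the updated copy. Below is a minimal stdlib-only sketch of that pattern, assuming nothing about how PyTreeClass implements it; `call_functionally` is a hypothetical helper, and a regular dataclass stands in for the treeclass.

```python
import copy
from dataclasses import dataclass


@dataclass
class Counter:
    calls: float = 0.0

    def increment(self):
        self.calls += 1  # mutates the instance in place


def call_functionally(instance, method_name, *args, **kwargs):
    """Run a mutating method on a copy; return (value, updated copy)."""
    new_instance = copy.deepcopy(instance)
    value = getattr(new_instance, method_name)(*args, **kwargs)
    return value, new_instance


counter = Counter()
original = counter  # keep a handle on the untouched starting state
for _ in range(10):
    _, counter = call_functionally(counter, "increment")

print(original.calls, counter.calls)  # 0.0 10.0
```

The starting instance is never modified, so every intermediate state remains a valid value that can be passed through a pure function.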
Applications
Check other packages built on top of PyTreeClass
Acknowledgements