JAX-compatible dataclass.
Installation | Description | Quick Example | Filtering | Stateful Computation | Applications | Acknowledgements
🛠️ Installation
pip install pytreeclass
Install development version
pip install git+https://github.com/ASEM000/PyTreeClass
📖 Description
PyTreeClass is a JAX-compatible dataclass-like decorator to create and operate on stateful JAX PyTrees.
The package aims to achieve two goals:
- 🔐 To maintain safe and correct behaviour by using immutable modules with a functional API.
- To achieve the most intuitive user experience in the JAX ecosystem by:
  - 🏗️ Defining layers similar to PyTorch or TensorFlow subclassing style.
  - ☝️ Filtering/Indexing layer values using boolean masking similar to jax.numpy.at[].{get,set,apply,...}
  - 🎨 Visualizing defined layers in a plethora of ways for better debugging and sharing of information.
⏩ Quick Example
🏗️ Create a simple MLP
import jax
from jax import numpy as jnp
import pytreeclass as pytc
import matplotlib.pyplot as plt
@pytc.treeclass
class Linear:
    # Any variable not wrapped with @pytc.treeclass
    # should be declared as a dataclass field here
    weight: jnp.ndarray
    bias: jnp.ndarray

    def __init__(self, key, in_dim, out_dim):
        self.weight = jax.random.normal(key, shape=(in_dim, out_dim)) * jnp.sqrt(2 / in_dim)
        self.bias = jnp.ones((1, out_dim))

    def __call__(self, x):
        return x @ self.weight + self.bias
@pytc.treeclass
class StackedLinear:
    def __init__(self, key, in_dim, out_dim, hidden_dim):
        keys = jax.random.split(key, 3)
        # Declaring l1, l2, l3 as dataclass fields is optional,
        # as l1, l2, l3 are Linear instances already wrapped with @pytc.treeclass.
        # To strictly include only nodes defined in dataclass fields,
        # use `@pytc.treeclass(field_only=True)`.
        self.l1 = Linear(key=keys[0], in_dim=in_dim, out_dim=hidden_dim)
        self.l2 = Linear(key=keys[1], in_dim=hidden_dim, out_dim=hidden_dim)
        self.l3 = Linear(key=keys[2], in_dim=hidden_dim, out_dim=out_dim)

    def __call__(self, x):
        x = self.l1(x)
        x = jax.nn.tanh(x)
        x = self.l2(x)
        x = jax.nn.tanh(x)
        x = self.l3(x)
        return x
model = StackedLinear(in_dim=1, out_dim=1, hidden_dim=10, key=jax.random.PRNGKey(0))

x = jnp.linspace(0, 1, 100)[:, None]
y = x**3 + jax.random.uniform(jax.random.PRNGKey(0), (100, 1)) * 0.01
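Before visualizing, we can fit the model to the toy data. The training loop below is a minimal illustrative sketch, not part of the original example; it assumes only what the document shows elsewhere, namely that treeclass instances support math operations (see model - 1e-3*grad in the filter_nondiff section below).

@jax.value_and_grad
def loss_func(model, x, y):
    # mean squared error between prediction and target
    return jnp.mean((model(x) - y) ** 2)

@jax.jit
def train_step(model, x, y):
    loss, grad = loss_func(model, x, y)
    # treeclass arithmetic: SGD is just a pytree subtraction
    return loss, model - 1e-3 * grad

for _ in range(1_000):
    loss, model = train_step(model, x, y)

plt.plot(x, y, "k.", label="data")
plt.plot(x, model(x), label="prediction")
plt.legend()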
🎨 Visualize
summary | tree_box | tree_diagram
print(model.summary())
┌────┬──────┬───────┬───────┬─────────────────┐
│Name│Type  │Param #│Size   │Config           │
├────┼──────┼───────┼───────┼─────────────────┤
│l1  │Linear│20(0)  │80.00B │weight=f32[1,10] │
│    │      │       │(0.00B)│bias=f32[1,10]   │
├────┼──────┼───────┼───────┼─────────────────┤
│l2  │Linear│110(0) │440.00B│weight=f32[10,10]│
│    │      │       │(0.00B)│bias=f32[1,10]   │
├────┼──────┼───────┼───────┼─────────────────┤
│l3  │Linear│11(0)  │44.00B │weight=f32[10,1] │
│    │      │       │(0.00B)│bias=f32[1,1]    │
└────┴──────┴───────┴───────┴─────────────────┘
Total count : 141(0)
Dynamic count : 141(0)
Frozen count : 0(0)
-----------------------------------------------
Total size : 564.00B(0.00B)
Dynamic size : 564.00B(0.00B)
Frozen size : 0.00B(0.00B)
===============================================
print(model.tree_box(array=x))
# note: shapes are evaluated using jax.eval_shape (a no-FLOPs operation)
┌──────────────────────────────────────┐
│StackedLinear[Parent]                 │
├──────────────────────────────────────┤
│┌────────────┬────────┬──────────────┐│
││            │ Input  │ f32[100,1]   ││
││ Linear[l1] │────────┼──────────────┤│
││            │ Output │ f32[100,10]  ││
│└────────────┴────────┴──────────────┘│
│┌────────────┬────────┬──────────────┐│
││            │ Input  │ f32[100,10]  ││
││ Linear[l2] │────────┼──────────────┤│
││            │ Output │ f32[100,10]  ││
│└────────────┴────────┴──────────────┘│
│┌────────────┬────────┬──────────────┐│
││            │ Input  │ f32[100,10]  ││
││ Linear[l3] │────────┼──────────────┤│
││            │ Output │ f32[100,1]   ││
│└────────────┴────────┴──────────────┘│
└──────────────────────────────────────┘
print(model.tree_diagram())
StackedLinear
    ├── l1=Linear
    │   ├── weight=f32[1,10]
    │   └── bias=f32[1,10]
    ├── l2=Linear
    │   ├── weight=f32[10,10]
    │   └── bias=f32[1,10]
    └── l3=Linear
        ├── weight=f32[10,1]
        └── bias=f32[1,1]
mermaid.io (native support in GitHub/Notion)
✨ Generate shareable visualization links ✨
# generate mermaid diagrams
# print(pytc.tree_viz.tree_mermaid(model)) # generate core syntax
>>> pytc.tree_viz.tree_mermaid(model,link=True)
# 'Open URL in browser: https://pytreeclass.herokuapp.com/temp/?id=*********'
flowchart LR
id15696277213149321320[StackedLinear]
id15696277213149321320 --> id159132120600507116(l1\nLinear)
id159132120600507116 --- id7500441386962467209["weight\nf32[1,10]"]
id159132120600507116 --- id10793958738030044218["bias\nf32[1,10]"]
id15696277213149321320 --> id10009280772564895168(l2\nLinear)
id10009280772564895168 --- id11951215191344350637["weight\nf32[10,10]"]
id10009280772564895168 --- id1196345851686744158["bias\nf32[1,10]"]
id15696277213149321320 --> id7572222925824649475(l3\nLinear)
id7572222925824649475 --- id4749243995442935477["weight\nf32[10,1]"]
id7572222925824649475 --- id8042761346510512486["bias\nf32[1,1]"]
✂️ Model surgery
# freeze l1
from pytreeclass.tree_util import tree_freeze
model = model.at["l1"].set(tree_freeze(model.l1))

# set negative values in l2 to 0
filtered_l2 = model.l2.at[model.l2 < 0].set(0)
model = model.at["l2"].set(filtered_l2)

# apply sin(x) to all values in l3
filtered_l3 = model.l3.at[...].apply(jnp.sin)
model = model.at["l3"].set(filtered_l3)
# frozen nodes are marked with #
print(model.tree_diagram())
StackedLinear
    ├#─ l1=Linear
    │   ├#─ weight=f32[1,10]
    │   └#─ bias=f32[1,10]
    ├── l2=Linear
    │   ├── weight=f32[10,10]
    │   └── bias=f32[1,10]
    └── l3=Linear
        ├── weight=f32[10,1]
        └── bias=f32[1,1]
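Ordinary attribute access still works on the modified model, which gives a quick way to sanity-check the surgery; a short illustrative snippet:

# verify the surgery: l2 no longer contains negative values
assert jnp.all(model.l2.weight >= 0)
assert jnp.all(model.l2.bias >= 0)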
☝️ Filtering with .at[]
PyTreeClass offers four means of filtering:
- Filter by value
- Filter by field name
- Filter by field type
- Filter by field metadata

The following examples demonstrate the usage of filtering. Suppose you have the following multilayer perceptron (MLP) class.
- Note that in StackedLinear, l1 and l2 have a description in their field metadata.
Model definition
import jax
from jax import numpy as jnp
import pytreeclass as pytc
import matplotlib.pyplot as plt
from dataclasses import field
@pytc.treeclass
class Linear :
weight : jnp.ndarray
bias : jnp.ndarray
def __init__(self,key,in_dim,out_dim):
self.weight = jax.random.normal(key,shape=(in_dim, out_dim)) * jnp.sqrt(2/in_dim)
self.bias = jnp.ones((1,out_dim))
def __call__(self,x):
return x @ self.weight + self.bias
@pytc.treeclass
class StackedLinear:
l1 : Linear = field(metadata={"description": "First layer"})
l2 : Linear = field(metadata={"description": "Second layer"})
def __init__(self,key,in_dim,out_dim,hidden_dim):
keys= jax.random.split(key,3)
self.l1 = Linear(key=keys[0],in_dim=in_dim,out_dim=hidden_dim)
self.l2 = Linear(key=keys[2],in_dim=hidden_dim,out_dim=out_dim)
def __call__(self,x):
x = self.l1(x)
x = jax.nn.tanh(x)
x = self.l2(x)
return x
model = StackedLinear(in_dim=1,out_dim=1,hidden_dim=5,key=jax.random.PRNGKey(0))
- Raw model values before any filtering.
print(model)
StackedLinear(
l1=Linear(
weight=[[-1.6248673 -2.8383057 1.3969219 1.3169124 -0.40784812]],
bias=[[1. 1. 1. 1. 1.]]
),
l2=Linear(
weight=
[[ 0.98507565]
[ 0.99815285]
[-1.0687716 ]
[-0.19255024]
[-1.2108876 ]],
bias=[[1.]]
)
)
Filter by value
- Get all negative values
print(model.at[model<0].get())
StackedLinear(
l1=Linear(
weight=[-1.6248673 -2.8383057 -0.40784812],
bias=[]
),
l2=Linear(
weight=[-1.0687716 -0.19255024 -1.2108876 ],
bias=[]
)
)
- Set negative values to 0
print(model.at[model<0].set(0))
StackedLinear(
l1=Linear(
weight=[[0. 0. 1.3969219 1.3169124 0. ]],
bias=[[1. 1. 1. 1. 1.]]
),
l2=Linear(
weight=
[[0.98507565]
[0.99815285]
[0. ]
[0. ]
[0. ]],
bias=[[1.]]
)
)
- Apply f(x)=x^2 to negative values
print(model.at[model<0].apply(lambda x:x**2))
StackedLinear(
l1=Linear(
weight=[[2.6401937 8.05598 1.3969219 1.3169124 0.16634008]],
bias=[[1. 1. 1. 1. 1.]]
),
l2=Linear(
weight=
[[0.98507565]
[0.99815285]
[1.1422727 ]
[0.03707559]
[1.4662486 ]],
bias=[[1.]]
)
)
- Sum all negative values
print(model.at[model<0].reduce(lambda acc,cur: acc+jnp.sum(cur)))
-7.3432307
Filter by field name
- Get all fields named `l1`
print(model.at[model == "l1"].get())
StackedLinear(
l1=Linear(
weight=[-1.6248673 -2.8383057 1.3969219 1.3169124 -0.40784812],
bias=[1. 1. 1. 1. 1.]
),
l2=Linear(weight=[],bias=[])
)
Filter by field type
- Get all fields of type `Linear`
print(model.at[model == Linear].get())
StackedLinear(
l1=Linear(
weight=[-1.6248673 -2.8383057 1.3969219 1.3169124 -0.40784812],
bias=[1. 1. 1. 1. 1.]
),
l2=Linear(
weight=[ 0.98507565 0.99815285 -1.0687716 -0.19255024 -1.2108876 ],
bias=[1.]
)
)
Filter by field metadata
- Get all fields with metadata equal to `{"description": "First layer"}`
print(model.at[model == {"description": "First layer"}].get())
StackedLinear(
l1=Linear(
weight=[-1.6248673 -2.8383057 1.3969219 1.3169124 -0.40784812],
bias=[1. 1. 1. 1. 1.]
),
l2=Linear(weight=[],bias=[])
)
Mix and match different filtering methods.
- Get only the positive values of fields named `weight`.
mask = (model == "weight") & (model>0)
print(model.at[mask].get())
StackedLinear(
l1=Linear(weight=[1.3969219 1.3169124],bias=[]),
l2=Linear(weight=[0.98507565 0.99815285],bias=[])
)
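Since a treeclass instance is itself a JAX PyTree, the result of a filtered .get() composes with the standard jax.tree_util API. A small sketch using only standard JAX functions and the mask defined above:

# sum the positive `weight` entries selected by `mask`
leaves = jax.tree_util.tree_leaves(model.at[mask].get())
total = sum(jnp.sum(leaf) for leaf in leaves)
# comparable to: model.at[mask].reduce(lambda acc, cur: acc + jnp.sum(cur))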
Marking fields non-differentiable ✨ NEW ✨
Automatically marking fields non-differentiable
In the following code example, we train a model with both differentiable and non-differentiable fields. Using jax.grad directly would throw an error, so we use pytc.filter_nondiff to filter out any non-differentiable field.
import pytreeclass as pytc
import jax.numpy as jnp
import jax
from typing import Callable

@pytc.treeclass
class Linear:
    weight: jnp.ndarray                    # ✅ differentiable
    bias: jnp.ndarray                      # ✅ differentiable
    other: tuple[int, ...] = (1, 2, 3, 4)  # ❌ non-differentiable
    a: int = 1                             # ❌ non-differentiable
    b: float = 1.0                         # ✅ differentiable
    c: int = 1                             # ❌ non-differentiable
    d: float = 2.0                         # ✅ differentiable
    act: Callable = jax.nn.tanh            # ❌ non-differentiable

    def __init__(self, in_dim, out_dim):
        self.weight = jnp.ones((in_dim, out_dim))
        self.bias = jnp.ones((1, out_dim))

    def __call__(self, x):
        return self.act(self.b + x)
@jax.value_and_grad
def loss_func(model):
    # let's optimize the differentiable field `b`
    # used inside the non-differentiable field `act`
    return jnp.mean((model(1.0) - 0.5) ** 2)

@jax.jit
def update(model):
    value, grad = loss_func(model)
    return value, model - 1e-3 * grad

def train(model, epochs=10_000):
    # here we use the filter_nondiff function
    # to filter out the non-differentiable fields;
    # otherwise we would get an error
    model = pytc.filter_nondiff(model)
    for _ in range(epochs):
        value, model = update(model)
    return model
# before any filtering or training
model = Linear(1,1)
print(model)
# Linear(
# weight=[[1.]],
# bias=[[1.]],
# other=(1,2,3,4),
# a=1,
# b=1.0,
# c=1,
# d=2.0,
# act=tanh(x)
# )
model = train(model)
# after filtering and training
# note that the non-differentiable fields are not updated
# and the differentiable fields are updated
# the non-differentiable fields are marked with a `*`
print(model)
# Linear(
# weight=[[1.]],
# bias=[[1.]],
# *other=(1,2,3,4),
# *a=1,
# b=-0.36423424,
# *c=1,
# d=2.0,
# *act=tanh(x)
# )
Marking fields non-differentiable with a mask
In the following example, let's say we want to train only the field `b` and mark all other fields non-differentiable. We can do this with a mask:

new_model = pytc.filter_nondiff(model, model != "b")
# all fields except `b` are marked with `*`,
# denoting non-differentiable
print(new_model)
# Linear(
# *weight=f32[1,1],
# *bias=f32[1,1],
# *other=(1,2,3,4),
# *a=1,
# b=f32[],
# *c=1,
# *d=f32[],
# *act=tanh(x)
# )
# undo the filtering
# note the removal of `*` that marks non-diff fields
unfiltered_model = pytc.unfilter_nondiff(new_model)
print(unfiltered_model)
# Linear(
# weight=f32[1,1],
# bias=f32[1,1],
# other=(1,2,3,4),
# a=1,
# b=f32[],
# c=1,
# d=f32[],
# act=tanh(x)
# )
📜 Stateful computations
First, under jax.jit, JAX requires state to be explicit. This means that for any class instance, variables need to be separated from the class and passed explicitly. However, when using @pytc.treeclass there is no need to separate the instance variables; instead, the whole instance is passed as a state.
Using the following pattern, updating state functionally can be achieved under jax.jit.
import jax
import pytreeclass as pytc

@pytc.treeclass
class Counter:
    calls: float = 0.0

    def increment(self):
        self.calls += 1

counter = Counter()  # Counter(calls=0.0)
Here, we define the update function. Since the increment method mutates the internal state, we need to use the functional approach to update the state via .at. To achieve this we can use .at[method_name].__call__(*args, **kwargs); this functional call returns the value of the call and a new model instance with the updated state.
@jax.jit
def update(counter):
    value, new_counter = counter.at["increment"]()
    return new_counter

for i in range(10):
    counter = update(counter)

print(counter.calls)  # 10.0
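The same pattern extends to methods that take arguments and return values. A small illustrative sketch (the Accumulator class below is hypothetical, not part of the library):

@pytc.treeclass
class Accumulator:
    total: float = 0.0

    def add(self, x):
        self.total += x
        return self.total

acc = Accumulator()
# returns the method's value and a new instance with the updated state
value, acc = acc.at["add"](2.5)
print(value, acc.total)  # 2.5 2.5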
📙 Applications
Check out other packages built on top of PyTreeClass.
📝 Acknowledgements