JAX compatible dataclass.
Installation | Description | Quick Example | Filtering | Stateful Computation | Applications | More | Acknowledgements
Installation

pip install pytreeclass

Install the development version:

pip install git+https://github.com/ASEM000/PyTreeClass
Description

PyTreeClass is a JAX-compatible dataclass-like decorator to create and operate on stateful JAX PyTrees.

The package aims to achieve two goals:

- To maintain safe and correct behaviour by using immutable modules with a functional API.
- To achieve the most intuitive user experience in the JAX ecosystem by:
  - Defining layers in a subclassing style similar to PyTorch or TensorFlow.
  - Filtering/indexing layer values using boolean masking, similar to jax.numpy.at[].{get,set,apply,...}.
  - Visualizing defined layers in a plethora of ways for better debugging and sharing of information.
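Conceptually, a treeclass instance behaves like any other JAX pytree: it can be flattened into a list of leaves plus static structure, and rebuilt from them. A minimal plain-Python sketch of that idea, with no JAX dependency and hypothetical helper names (`tree_flatten`/`tree_unflatten` here are illustrative, not the library's API):

```python
from dataclasses import dataclass, fields

# A frozen dataclass mirrors the immutable-module design.
@dataclass(frozen=True)
class Linear:
    weight: list
    bias: list

def tree_flatten(module):
    """Split a dataclass instance into (leaves, field names)."""
    names = [f.name for f in fields(module)]
    leaves = [getattr(module, n) for n in names]
    return leaves, names

def tree_unflatten(cls, names, leaves):
    """Rebuild the dataclass from leaves and static field names."""
    return cls(**dict(zip(names, leaves)))

layer = Linear(weight=[1.0, 2.0], bias=[0.5])
leaves, names = tree_flatten(layer)
# Transform every leaf, then rebuild: the original instance is untouched.
doubled = tree_unflatten(Linear, names, [[2 * v for v in leaf] for leaf in leaves])
```

This flatten/transform/unflatten round trip is what lets JAX transformations such as `jax.grad` and `jax.jit` traverse module parameters.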
Quick Example

Create a simple MLP
import jax
from jax import numpy as jnp
import pytreeclass as pytc

@pytc.treeclass
class Linear:
    # Any variable not wrapped with @pytc.treeclass
    # should be declared as a dataclass field here
    weight: jnp.ndarray
    bias: jnp.ndarray

    def __init__(self, key, in_dim, out_dim):
        self.weight = jax.random.normal(key, shape=(in_dim, out_dim)) * jnp.sqrt(2 / in_dim)
        self.bias = jnp.ones((1, out_dim))

    def __call__(self, x):
        return x @ self.weight + self.bias
@pytc.treeclass
class StackedLinear:
    def __init__(self, key, in_dim, out_dim, hidden_dim):
        keys = jax.random.split(key, 3)
        # Declaring l1, l2, l3 as dataclass fields is optional,
        # as l1, l2, l3 are Linear instances wrapped with @pytc.treeclass.
        # To strictly include only nodes defined in dataclass fields,
        # use `@pytc.treeclass(field_only=True)`.
        self.l1 = Linear(key=keys[0], in_dim=in_dim, out_dim=hidden_dim)
        self.l2 = Linear(key=keys[1], in_dim=hidden_dim, out_dim=hidden_dim)
        self.l3 = Linear(key=keys[2], in_dim=hidden_dim, out_dim=out_dim)

    def __call__(self, x):
        x = self.l1(x)
        x = jax.nn.tanh(x)
        x = self.l2(x)
        x = jax.nn.tanh(x)
        x = self.l3(x)
        return x
model = StackedLinear(in_dim=1, out_dim=1, hidden_dim=10, key=jax.random.PRNGKey(0))

x = jnp.linspace(0, 1, 100)[:, None]
y = x ** 3 + jax.random.uniform(jax.random.PRNGKey(0), (100, 1)) * 0.01
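The immutable, functional design means a training step returns a new model rather than mutating the old one. A minimal plain-Python sketch of that update pattern on a toy one-parameter model (analytic gradient, no JAX; all names here are illustrative, not the library's API):

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Toy:
    w: float  # single scalar weight fitting y = 2x

def loss(m, xs, ys):
    # mean squared error of the linear prediction w * x
    return sum((m.w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

def grad_w(m, xs, ys):
    # analytic gradient of the mean squared error w.r.t. w
    return sum(2 * (m.w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def step(m, xs, ys, lr=0.1):
    # returns a *new* instance; the input model is untouched
    return replace(m, w=m.w - lr * grad_w(m, xs, ys))

xs, ys = [0.0, 1.0, 2.0], [0.0, 2.0, 4.0]
m = Toy(w=0.0)
for _ in range(100):
    m = step(m, xs, ys)
```

With JAX, `grad_w` would instead come from `jax.grad(loss)`, but the new-model-per-step shape of the loop is the same.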
Visualize

summary
print(model.summary())
┌────┬──────┬───────┬───────┬─────────────────┐
│Name│Type  │Param #│Size   │Config           │
├────┼──────┼───────┼───────┼─────────────────┤
│l1  │Linear│20(0)  │80.00B │weight=f32[1,10] │
│    │      │       │(0.00B)│bias=f32[1,10]   │
├────┼──────┼───────┼───────┼─────────────────┤
│l2  │Linear│110(0) │440.00B│weight=f32[10,10]│
│    │      │       │(0.00B)│bias=f32[1,10]   │
├────┼──────┼───────┼───────┼─────────────────┤
│l3  │Linear│11(0)  │44.00B │weight=f32[10,1] │
│    │      │       │(0.00B)│bias=f32[1,1]    │
└────┴──────┴───────┴───────┴─────────────────┘
Total count : 141(0)
Dynamic count : 141(0)
Frozen count : 0(0)
-----------------------------------------------
Total size : 564.00B(0.00B)
Dynamic size : 564.00B(0.00B)
Frozen size : 0.00B(0.00B)
===============================================
tree_box

# note: shapes are evaluated with jax.eval_shape (a no-FLOPs operation)
print(model.tree_box(array=x))
┌──────────────────────────────────────┐
│StackedLinear[Parent]                 │
├──────────────────────────────────────┤
│┌────────────┬────────┬─────────────┐│
││            │ Input  │ f32[100,1]  ││
││ Linear[l1] ├────────┼─────────────┤│
││            │ Output │ f32[100,10] ││
│└────────────┴────────┴─────────────┘│
│┌────────────┬────────┬─────────────┐│
││            │ Input  │ f32[100,10] ││
││ Linear[l2] ├────────┼─────────────┤│
││            │ Output │ f32[100,10] ││
│└────────────┴────────┴─────────────┘│
│┌────────────┬────────┬─────────────┐│
││            │ Input  │ f32[100,10] ││
││ Linear[l3] ├────────┼─────────────┤│
││            │ Output │ f32[100,1]  ││
│└────────────┴────────┴─────────────┘│
└──────────────────────────────────────┘
tree_diagram

print(model.tree_diagram())
StackedLinear
    ├── l1=Linear
    │   ├── weight=f32[1,10]
    │   └── bias=f32[1,10]
    ├── l2=Linear
    │   ├── weight=f32[10,10]
    │   └── bias=f32[1,10]
    └── l3=Linear
        ├── weight=f32[10,1]
        └── bias=f32[1,1]
mermaid.io (native support in GitHub/Notion)
# generate mermaid diagrams
# print(pytc.tree_viz.tree_mermaid(model)) # generate core syntax
pytc.tree_viz.save_viz(model,filename="test_mermaid",method="tree_mermaid_md")
# use `method="tree_mermaid_html"` to save as html
flowchart LR
id15696277213149321320[StackedLinear]
id15696277213149321320 --> id159132120600507116(l1\nLinear)
id159132120600507116 --- id7500441386962467209["weight\nf32[1,10]"]
id159132120600507116 --- id10793958738030044218["bias\nf32[1,10]"]
id15696277213149321320 --> id10009280772564895168(l2\nLinear)
id10009280772564895168 --- id11951215191344350637["weight\nf32[10,10]"]
id10009280772564895168 --- id1196345851686744158["bias\nf32[1,10]"]
id15696277213149321320 --> id7572222925824649475(l3\nLinear)
id7572222925824649475 --- id4749243995442935477["weight\nf32[10,1]"]
id7572222925824649475 --- id8042761346510512486["bias\nf32[1,1]"]
✨ Generate shareable visualization links ✨

>>> pytc.tree_viz.tree_mermaid(model, link=True)
'Open URL in browser: https://pytreeclass.herokuapp.com/temp/?id=*********'
Model surgery
# freeze l1
model = model.at["l1"].freeze()

# set negative values in l2 to 0
filtered_l2 = model.l2.at[model.l2 < 0].set(0)
model = model.at["l2"].set(filtered_l2)

# apply sin(x) to all values in l3
filtered_l3 = model.l3.at[...].apply(jnp.sin)
model = model.at["l3"].set(filtered_l3)

# frozen nodes are marked with #
print(model.tree_diagram())
StackedLinear
    ├#─ l1=Linear
    │   ├#─ weight=f32[1,10]
    │   └#─ bias=f32[1,10]
    ├── l2=Linear
    │   ├── weight=f32[10,10]
    │   └── bias=f32[1,10]
    └── l3=Linear
        ├── weight=f32[10,1]
        └── bias=f32[1,1]
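The freeze semantics above can be pictured on a flat list of leaves: frozen leaves pass through a transformation untouched, while dynamic leaves are updated. A plain-Python sketch of that rule (the helper name is hypothetical, not the library's implementation):

```python
def apply_unfrozen(leaves, frozen_mask, fn):
    """Apply fn only to leaves whose mask entry is False (i.e. not frozen)."""
    return [leaf if frozen else fn(leaf)
            for leaf, frozen in zip(leaves, frozen_mask)]

leaves = [1.0, 2.0, 3.0]
# first leaf frozen, e.g. everything under a frozen l1
out = apply_unfrozen(leaves, [True, False, False], lambda v: v * 10)
```

In training, this is what makes a frozen layer's parameters survive optimizer updates unchanged.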
Filtering with .at[]

PyTreeClass offers four means of filtering:

- Filter by value
- Filter by field name
- Filter by field type
- Filter by field metadata

The following example demonstrates the usage of filtering. Suppose you have the following multilayer perceptron (MLP) class.

Note: in StackedLinear, l1 and l2 have a description in their field metadata.
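The `.at[mask].{get,set,apply}` semantics can be sketched on a flat list of values; the library applies the same logic across every array leaf of the tree. The helper names below are illustrative only:

```python
def mask_get(vals, mask):
    """Keep only the values where the mask is True."""
    return [v for v, m in zip(vals, mask) if m]

def mask_set(vals, mask, fill):
    """Replace masked values with a fill value, keep the rest."""
    return [fill if m else v for v, m in zip(vals, mask)]

def mask_apply(vals, mask, fn):
    """Apply fn to masked values, keep the rest."""
    return [fn(v) if m else v for v, m in zip(vals, mask)]

vals = [-1.6, 2.8, -0.4]
neg = [v < 0 for v in vals]          # boolean mask, like `model < 0`
negatives = mask_get(vals, neg)
zeroed = mask_set(vals, neg, 0.0)
squared = mask_apply(vals, neg, lambda v: v * v)
```

The real `.at[]` variants additionally preserve array shapes (`get` returns non-selected entries as empty), but the masking rule is the same.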
Model definition

import jax
from jax import numpy as jnp
import pytreeclass as pytc
from dataclasses import field

@pytc.treeclass
class Linear:
    weight: jnp.ndarray
    bias: jnp.ndarray

    def __init__(self, key, in_dim, out_dim):
        self.weight = jax.random.normal(key, shape=(in_dim, out_dim)) * jnp.sqrt(2 / in_dim)
        self.bias = jnp.ones((1, out_dim))

    def __call__(self, x):
        return x @ self.weight + self.bias

@pytc.treeclass
class StackedLinear:
    l1: Linear = field(metadata={"description": "First layer"})
    l2: Linear = field(metadata={"description": "Second layer"})

    def __init__(self, key, in_dim, out_dim, hidden_dim):
        keys = jax.random.split(key, 3)
        self.l1 = Linear(key=keys[0], in_dim=in_dim, out_dim=hidden_dim)
        self.l2 = Linear(key=keys[2], in_dim=hidden_dim, out_dim=out_dim)

    def __call__(self, x):
        x = self.l1(x)
        x = jax.nn.tanh(x)
        x = self.l2(x)
        return x

model = StackedLinear(in_dim=1, out_dim=1, hidden_dim=5, key=jax.random.PRNGKey(0))
- Raw model values before any filtering.
print(model)
StackedLinear(
l1=Linear(
weight=[[-1.6248673 -2.8383057 1.3969219 1.3169124 -0.40784812]],
bias=[[1. 1. 1. 1. 1.]]
),
l2=Linear(
weight=
[[ 0.98507565]
[ 0.99815285]
[-1.0687716 ]
[-0.19255024]
[-1.2108876 ]],
bias=[[1.]]
)
)
Filter by value
- Get all negative values
print(model.at[model<0].get())
StackedLinear(
l1=Linear(
weight=[-1.6248673 -2.8383057 -0.40784812],
bias=[]
),
l2=Linear(
weight=[-1.0687716 -0.19255024 -1.2108876 ],
bias=[]
)
)
- Set negative values to 0
print(model.at[model<0].set(0))
StackedLinear(
l1=Linear(
weight=[[0. 0. 1.3969219 1.3169124 0. ]],
bias=[[1. 1. 1. 1. 1.]]
),
l2=Linear(
weight=
[[0.98507565]
[0.99815285]
[0. ]
[0. ]
[0. ]],
bias=[[1.]]
)
)
- Apply f(x)=x^2 to negative values
print(model.at[model<0].apply(lambda x:x**2))
StackedLinear(
l1=Linear(
weight=[[2.6401937 8.05598 1.3969219 1.3169124 0.16634008]],
bias=[[1. 1. 1. 1. 1.]]
),
l2=Linear(
weight=
[[0.98507565]
[0.99815285]
[1.1422727 ]
[0.03707559]
[1.4662486 ]],
bias=[[1.]]
)
)
- Sum all negative values
print(model.at[model<0].reduce(lambda acc,cur: acc+jnp.sum(cur)))
-7.3432307
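The reduce above folds a binary function over only the selected leaves, starting from a zero accumulator. A plain-Python analogue with `functools.reduce`:

```python
from functools import reduce

vals = [-1.0, 2.0, -3.0, 4.0]
negatives = [v for v in vals if v < 0]          # the masked selection
total = reduce(lambda acc, cur: acc + cur, negatives, 0.0)
```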
Filter by field name
- Get all fields named l1

print(model.at[model == "l1"].get())
StackedLinear(
l1=Linear(
weight=[-1.6248673 -2.8383057 1.3969219 1.3169124 -0.40784812],
bias=[1. 1. 1. 1. 1.]
),
l2=Linear(weight=[],bias=[])
)
Filter by field type
- Get all fields of Linear type

print(model.at[model == Linear].get())
StackedLinear(
l1=Linear(
weight=[-1.6248673 -2.8383057 1.3969219 1.3169124 -0.40784812],
bias=[1. 1. 1. 1. 1.]
),
l2=Linear(
weight=[ 0.98507565 0.99815285 -1.0687716 -0.19255024 -1.2108876 ],
bias=[1.]
)
)
Filter by field metadata
- Get all fields with {"description": "First layer"} in their metadata

print(model.at[model == {"description": "First layer"}].get())
StackedLinear(
l1=Linear(
weight=[-1.6248673 -2.8383057 1.3969219 1.3169124 -0.40784812],
bias=[1. 1. 1. 1. 1.]
),
l2=Linear(weight=[],bias=[])
)
Mix and match different filtering methods.
- Get the positive values of fields named weight

mask = (model == "weight") & (model > 0)
print(model.at[mask].get())
StackedLinear(
l1=Linear(weight=[1.3969219 1.3169124],bias=[]),
l2=Linear(weight=[0.98507565 0.99815285],bias=[])
)
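Masks compose elementwise, which is what makes mix-and-match filtering work: a name-based mask and a value-based mask can be intersected before `.get()`. A one-line plain-Python analogue of the `&` above (the masks here are made up for illustration):

```python
name_mask = [True, True, False]    # e.g. leaves belonging to a field named "weight"
value_mask = [False, True, True]   # e.g. leaves with positive values
both = [a and b for a, b in zip(name_mask, value_mask)]
```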
Applications
More

More compact boilerplate using param:

- A more compact definition is possible with nodes defined at call time.
- The Linear layers are defined on the first call and retrieved on subsequent calls.
- This pattern is useful if the module definition depends on runtime data.
from typing import Any

@pytc.treeclass
class StackedLinear:
    keys: Any

    def __init__(self, key):
        self.keys = jax.random.split(key, 3)

    def __call__(self, x):
        x = self.param(Linear(self.keys[0], x.shape[-1], 10), name="l1")(x)
        x = jax.nn.tanh(x)
        x = self.param(Linear(self.keys[1], 10, 10), name="l2")(x)
        x = jax.nn.tanh(x)
        x = self.param(Linear(self.keys[2], 10, x.shape[-1]), name="l3")(x)
        return x

# Upon defining the module, the layers are not yet instantiated.
# After the first call, however, the nodes are defined.
model = StackedLinear(jax.random.PRNGKey(0))
print(model)
StackedLinear(keys=ui32[3,2])
model(jnp.ones((10, 10)))  # first call
print(f"{model!r}")
StackedLinear(
keys=ui32[3,2],
l1=Linear(weight=f32[10,10],bias=f32[1,10]),
l2=Linear(weight=f32[10,10],bias=f32[1,10]),
l3=Linear(weight=f32[10,10],bias=f32[1,10])
)
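The create-on-first-call, retrieve-on-later-calls behaviour can be sketched with a plain dictionary cache. This sketch takes a factory callable to make the laziness explicit (the library example above passes an eagerly constructed module instead); all names here are illustrative:

```python
class Lazy:
    """Hypothetical module with param-style lazy registration."""

    def __init__(self):
        self._store = {}

    def param(self, factory, name):
        if name not in self._store:
            self._store[name] = factory()   # instantiated only on first use
        return self._store[name]            # retrieved on every later call

m = Lazy()
first = m.param(lambda: [0.0] * 3, name="l1")
second = m.param(lambda: [1.0] * 3, name="l1")  # same name: cached value returned
```

Because the node is keyed by `name`, the second factory is never evaluated; this is why shapes that depend on runtime data (like `x.shape[-1]` above) can be fixed at the first call.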
Acknowledgements