JAX-compatible dataclass.
PyTreeClass
Write PyTorch-like layers with rich visualizations in JAX.
Installation | Description | Quick Example | More | Applications
Installation
pip install pytreeclass
Description

A JAX-compatible `dataclass`-like data structure with the following functionality:

- Create PyTorch-like NN classes, similar to equinox and Treex
- Rich visualizations for pytrees wrapped with `@treeclass`
- Boolean indexing on pytrees in a functional style similar to `jax.numpy`, e.g. `x.at[x<0].set(0)`
- Apply math/numpy operations, like tree-math
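Under the hood, a class decorated with `@treeclass` is registered as a JAX pytree, so its fields become leaves that `jax.tree_util` (and transformations like `jax.grad`) can traverse. A rough sketch of what that registration means, using plain JAX and a hypothetical `Params` dataclass rather than the pytreeclass API:

```python
# A minimal sketch of what @treeclass automates: registering a plain
# dataclass as a JAX pytree so its fields become leaves.
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class Params:
    weight: jnp.ndarray
    bias: jnp.ndarray


# Tell JAX how to flatten/unflatten Params instances.
jax.tree_util.register_pytree_node(
    Params,
    lambda p: ((p.weight, p.bias), None),  # flatten: (children, aux_data)
    lambda _, children: Params(*children),  # unflatten
)

p = Params(weight=jnp.ones((2, 2)), bias=jnp.zeros(2))
leaves = jax.tree_util.tree_leaves(p)
print(len(leaves))  # 2 leaves: weight and bias
```

With the node registered, `p` can flow through `jax.tree_util.tree_map`, `jax.jit`, and `jax.grad` like any other pytree; `@treeclass` removes the need to write the flatten/unflatten pair by hand.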
Quick Example

Create simple MLP
import jax
from jax import numpy as jnp
import matplotlib.pyplot as plt

from pytreeclass import treeclass, tree_viz

@treeclass
class Linear:
    # Any variable not wrapped with @treeclass
    # should be declared as a dataclass field here
    weight: jnp.ndarray
    bias: jnp.ndarray

    def __init__(self, key, in_dim, out_dim):
        self.weight = jax.random.normal(key, shape=(in_dim, out_dim)) * jnp.sqrt(2 / in_dim)
        self.bias = jnp.ones((1, out_dim))

    def __call__(self, x):
        return x @ self.weight + self.bias

@treeclass
class StackedLinear:
    def __init__(self, key, in_dim, out_dim, hidden_dim):
        keys = jax.random.split(key, 3)

        # Declaring l1, l2, l3 as dataclass fields is optional
        # as they are already wrapped with @treeclass
        self.l1 = Linear(key=keys[0], in_dim=in_dim, out_dim=hidden_dim)
        self.l2 = Linear(key=keys[1], in_dim=hidden_dim, out_dim=hidden_dim)
        self.l3 = Linear(key=keys[2], in_dim=hidden_dim, out_dim=out_dim)

    def __call__(self, x):
        x = self.l1(x)
        x = jax.nn.tanh(x)
        x = self.l2(x)
        x = jax.nn.tanh(x)
        x = self.l3(x)
        return x
>>> model = StackedLinear(in_dim=1,out_dim=1,hidden_dim=10,key=jax.random.PRNGKey(0))
>>> x = jnp.linspace(0,1,100)[:,None]
>>> y = x**3 + jax.random.uniform(jax.random.PRNGKey(0),(100,1))*0.01
Visualize

summary | tree_box | tree_diagram
>>> print(tree_viz.summary(model))
┌──────┬───────┬───────┬─────────────────┐
│Type  │Param #│Size   │Config           │
├──────┼───────┼───────┼─────────────────┤
│Linear│20     │80.00B │weight=f32[1,10] │
│      │(0)    │(0.00B)│bias=f32[1,10]   │
├──────┼───────┼───────┼─────────────────┤
│Linear│110    │440.00B│weight=f32[10,10]│
│      │(0)    │(0.00B)│bias=f32[1,10]   │
├──────┼───────┼───────┼─────────────────┤
│Linear│11     │44.00B │weight=f32[10,1] │
│      │(0)    │(0.00B)│bias=f32[1,1]    │
└──────┴───────┴───────┴─────────────────┘
Total # : 141(0)
Dynamic #: 141(0)
Static/Frozen #: 0(0)
------------------------------------------
Total size : 564.00B(0.00B)
Dynamic size: 564.00B(0.00B)
Static/Frozen size: 0.00B(0.00B)
==========================================
>>> print(tree_viz.tree_box(model,array=x))
# uses jax.eval_shape (no-flops operation)
# ** note **: the modules created in __init__
# should be in the same order in which
# they are called in __call__
┌──────────────────────────────────────┐
│StackedLinear(Parent)                 │
├──────────────────────────────────────┤
│┌────────────┬────────┬──────────────┐│
││            │ Input  │ f32[100,1]   ││
││ Linear(l1) ├────────┼──────────────┤│
││            │ Output │ f32[100,10]  ││
│└────────────┴────────┴──────────────┘│
│┌────────────┬────────┬──────────────┐│
││            │ Input  │ f32[100,10]  ││
││ Linear(l2) ├────────┼──────────────┤│
││            │ Output │ f32[100,10]  ││
│└────────────┴────────┴──────────────┘│
│┌────────────┬────────┬──────────────┐│
││            │ Input  │ f32[100,10]  ││
││ Linear(l3) ├────────┼──────────────┤│
││            │ Output │ f32[100,1]   ││
│└────────────┴────────┴──────────────┘│
└──────────────────────────────────────┘
>>> print(tree_viz.tree_diagram(model))
StackedLinear
    ├── l1=Linear
    │   ├── weight=f32[1,10]
    │   └── bias=f32[1,10]
    ├── l2=Linear
    │   ├── weight=f32[10,10]
    │   └── bias=f32[1,10]
    └── l3=Linear
        ├── weight=f32[10,1]
        └── bias=f32[1,1]
mermaid.io (native support in GitHub/Notion)
# generate mermaid diagrams
# print(tree_viz.tree_mermaid(model)) # generate core syntax
>>> tree_viz.save_viz(model,filename="test_mermaid",method="tree_mermaid_md")
# use `method="tree_mermaid_html"` to save as html
flowchart TD
id15696277213149321320[StackedLinear]
id15696277213149321320 --> id159132120600507116(l1\nLinear)
id159132120600507116 --- id7500441386962467209["weight\nf32[1,10]"]
id159132120600507116 --- id10793958738030044218["bias\nf32[1,10]"]
id15696277213149321320 --> id10009280772564895168(l2\nLinear)
id10009280772564895168 --- id11951215191344350637["weight\nf32[10,10]"]
id10009280772564895168 --- id1196345851686744158["bias\nf32[1,10]"]
id15696277213149321320 --> id7572222925824649475(l3\nLinear)
id7572222925824649475 --- id4749243995442935477["weight\nf32[10,1]"]
id7572222925824649475 --- id8042761346510512486["bias\nf32[1,1]"]
Model surgery
# freeze l1
>>> model.l1 = model.l1.freeze()
# set non-negative values in l2 to 0
>>> model.l2 = model.l2.at[model.l2<0].set(0)
# frozen nodes are marked with #
>>> print(tree_viz.tree_diagram(model))
StackedLinear
    ├── l1=Linear
    │   ├#─ weight=f32[1,10]
    │   └#─ bias=f32[1,10]
    ├── l2=Linear
    │   ├── weight=f32[10,10]
    │   └── bias=f32[1,10]
    └── l3=Linear
        ├── weight=f32[10,1]
        └── bias=f32[1,1]
More
Train from scratch
>>> x = jnp.linspace(0,1,100)[:,None]
>>> y = x**3 + jax.random.uniform(jax.random.PRNGKey(0),(100,1))*0.01
def loss_func(model, x, y):
    return jnp.mean((model(x) - y) ** 2)

@jax.jit
def update(model, x, y):
    value, grads = jax.value_and_grad(loss_func)(model, x, y)
    # no need to use `jax.tree_map` to update the model
    # as the model is wrapped by @treeclass
    return value, model - 1e-3 * grads

for _ in range(1, 20_001):
    value, model = update(model, x, y)

plt.plot(x, model(x), '--r', label='Prediction', linewidth=3)
plt.plot(x, y, '--k', label='True', linewidth=3)
plt.legend()
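The `model - 1e-3 * grads` update works because arithmetic on a treeclass is applied leaf-wise. Without that operator overloading, the same SGD step is written with `jax.tree_util.tree_map`; a sketch using a plain dict-of-arrays pytree (the `params`/`grads` values here are made up for illustration):

```python
# The per-leaf update that `model - 1e-3 * grads` performs, spelled out
# with plain JAX pytrees (dicts) instead of pytreeclass.
import jax
import jax.numpy as jnp

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([10.0, 20.0]), "b": jnp.array(5.0)}

# SGD step applied leaf-by-leaf across the pytree.
new_params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
print(new_params["b"])  # ~0.495
```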
Using out-of-place indexing `.at[].set()` and `.at[].get()` on pytrees

Similar to JAX, pytreeclass provides an `.at` property for out-of-place updates.
# get layer1
layer1 = model.l1
# layer1 repr
>>> print(f"{layer1!r}")
Linear(
weight=f32[1,10],
bias=f32[1,10])
# layer1 str
>>> print(f"{layer1!s}")
Linear(
weight=
[[-2.5491788 1.674097 0.07813213 0.47670904 -1.8760327 -0.9941608
0.2808009 0.6522513 -0.53470623 1.0796958 ]],
bias=
[[1.0368661 0.98985153 1.0104426 0.9997676 1.2349331 0.9800282
0.9618377 0.99291945 0.9431369 1.0172408 ]])
# set negative values to 0
>>> print(layer1.at[layer1<0].set(0))
Linear(
weight=
[[0. 1.674097 0.07813213 0.47670904 0. 0.
0.2808009 0.6522513 0. 1.0796958 ]],
bias=
[[1.0368661 0.98985153 1.0104426 0.9997676 1.2349331 0.9800282
0.9618377 0.99291945 0.9431369 1.0172408 ]])
# get only positive values
>>> print(layer1.at[layer1>0].get())
Linear(
weight=
[1.674097 0.07813213 0.47670904 0.2808009 0.6522513 1.0796958 ],
bias=
[1.0368661 0.98985153 1.0104426 0.9997676 1.2349331 0.9800282
0.9618377 0.99291945 0.9431369 1.0172408 ])
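At the leaf level this is the same out-of-place semantics that JAX arrays already offer through their own `.at` property; pytreeclass lifts it to whole pytrees. A sketch on a single array:

```python
# The per-array behaviour behind layer1.at[layer1 < 0].set(0):
# JAX arrays support out-of-place boolean-masked updates via .at.
import jax.numpy as jnp

x = jnp.array([-2.0, 1.0, -0.5, 3.0])
y = x.at[x < 0].set(0.0)  # returns a new array; x is unchanged
print(y)  # [0. 1. 0. 3.]
print(x)  # [-2.   1.  -0.5  3. ] -- original untouched
```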
Perform math operations on pytrees
@treeclass
class Test:
    a: float
    b: float
    c: float
    name: str
# basic operations
>>> A = Test(10,20,30,'A')
>>> (A + A) # Test(20,40,60,'A')
>>> (A - A) # Test(0,0,0,'A')
>>> (A*A).reduce_mean() # 1400
>>> (A + 1) # Test(11,21,31,'A')
# only add 1 to field `a`
# all other fields are set to None; an instance of the same class is returned
>>> assert (A['a'] + 1) == Test(11,None,None,'A')
# use `|` to merge classes by performing ( left_node or right_node )
>>> Aa = A['a'] + 10 # Test(a=20,b=None,c=None,name=A)
>>> Ab = A['b'] + 10 # Test(a=None,b=30,c=None,name=A)
>>> assert (Aa | Ab | A ) == Test(20,30,30,'A')
# indexing by class
>>> A[A>10] # Test(a=None,b=20,c=30,name='A')
# Register custom operations
>>> B = Test([10,10],20,30,'B')
>>> B.register_op(func=lambda node: node + 1, name='plus_one')
>>> B.plus_one() # Test(a=[11, 11],b=21,c=31,name='B')
# Register custom reduce operations ( similar to functools.reduce)
>>> C = Test(jnp.array([10,10]),20,30,'C')
>>> C.register_op(
        func=jnp.prod,                 # function applied on each node
        name='product',                # name of the function
        reduce_op=lambda x, y: x * y,  # function applied between nodes (accumulated * current node)
        init_val=1                     # initializer for the reduce function
    )
# product applies only on each node
# and returns an instance of the same class
>>> C.product() # Test(a=100,b=20,c=30,name='C')
# `reduce_` + name of the registered function (`product`)
# reduces the class and returns a value
>>> C.reduce_product() # 60000
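The `reduce_` semantics above mirror `functools.reduce`: apply `func` to every node, then fold the per-node results with `reduce_op`, starting from `init_val`. A plain-Python sketch (the names match the `register_op` example; the list-of-nodes representation is an assumption for illustration):

```python
# Plain-Python sketch of reduce_product: per-node func, then a fold.
from functools import reduce
import math

nodes = [[10, 10], 20, 30]  # the numeric fields a, b, c of C


def func(node):  # function applied on each node (stand-in for jnp.prod)
    return math.prod(node) if isinstance(node, list) else node


def reduce_op(acc, current):  # function applied between nodes
    return acc * current


init_val = 1
result = reduce(reduce_op, (func(n) for n in nodes), init_val)
print(result)  # 100 * 20 * 30 = 60000, matching C.reduce_product()
```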
Applications

Description | Link
---|---
Physics informed neural network (PINN) | PINN