JAX compatible dataclass.
Project description
๐ฒPytreeclass๐ฒ
Write pytorch-like layers with keras-like visualizations in JAX.
Installation |Description |Examples
๐ ๏ธ Installation
pip install pytreeclass
๐ Description
A JAX compatible dataclass
like datastructure with the following functionalities
- Create PyTorch like NN classes like equinox and Treex
- Provides Keras-like
model.summary()
andplot_model
visualizations for pytrees wrapped with@treeclass
. - Apply math/numpy operations like tree-math
- Registering user-defined reduce operations on each class.
- Some fancy indexing syntax functionalities like
x[x>0]
on pytrees
๐ข Examples
Create PyTorch like NN classes
import jax
from jax import numpy as jnp
from pytreeclass import treeclass,static_field,tree_viz
@treeclass
class Linear :
weight : jnp.ndarray
bias : jnp.ndarray
def __init__(self,key,in_dim,out_dim):
self.weight = jax.random.normal(key,shape=(in_dim, out_dim)) * jnp.sqrt(2/in_dim)
self.bias = jnp.ones((1,out_dim))
def __call__(self,x):
return x @ self.weight + self.bias
@treeclass
class StackedLinear:
l1 : Linear
l2 : Linear
l3 : Linear
def __init__(self,key,in_dim,out_dim,hidden_dim):
keys= jax.random.split(key,3)
self.l1 = Linear(key=keys[0],in_dim=in_dim,out_dim=hidden_dim)
self.l2 = Linear(key=keys[1],in_dim=hidden_dim,out_dim=hidden_dim)
self.l3 = Linear(key=keys[2],in_dim=hidden_dim,out_dim=out_dim)
def __call__(self,x):
x = self.l1(x)
x = jax.nn.tanh(x)
x = self.l2(x)
x = jax.nn.tanh(x)
x = self.l3(x)
return x
x = jnp.linspace(0,1,100)[:,None]
y = x**3 + jax.random.uniform(jax.random.PRNGKey(0),(100,1))*0.01
model = StackedLinear(
in_dim=1,
out_dim=1,
hidden_dim=10,
key=jax.random.PRNGKey(0))
def loss_func(model,x,y):
return jnp.mean((model(x)-y)**2 )
@jax.jit
def update(model,x,y):
value,grads = jax.value_and_grad(loss_func)(model,x,y)
# no need to use `jax.tree_map` to update the model
# as it model is wrapped by @treeclass
return value , model-1e-3*grads
for _ in range(1,10_001):
value,model = update(model,x,y)
plt.plot(x,model(x),'--r',label = 'Prediction',linewidth=3)
plt.plot(x,y,'--k',label='True',linewidth=3)
plt.legend()
Visualize Pytrees
>>> print(f"{model!r}")
StackedLinear(
l1=Linear(weight=f32[1,10],bias=f32[1,10])
l2=Linear(weight=f32[10,10],bias=f32[1,10])
l3=Linear(weight=f32[10,1],bias=f32[1,1]))
>>> print(tree_viz.summary(model))
โโโโโโโโฌโโโโโโโโฌโโโโโโโโโโฌโโโโโโโโโโโโโโโโโโ
โType โParam #โSize โConfig โ
โโโโโโโโผโโโโโโโโผโโโโโโโโโโผโโโโโโโโโโโโโโโโโโค
โLinearโ20 โ80.000 B โbias=f32[1,10] โ
โ โ โ โweight=f32[1,10] โ
โโโโโโโโผโโโโโโโโผโโโโโโโโโโผโโโโโโโโโโโโโโโโโโค
โLinearโ110 โ440.000 Bโbias=f32[1,10] โ
โ โ โ โweight=f32[10,10]โ
โโโโโโโโผโโโโโโโโผโโโโโโโโโโผโโโโโโโโโโโโโโโโโโค
โLinearโ11 โ44.000 B โbias=f32[1,1] โ
โ โ โ โweight=f32[10,1] โ
โโโโโโโโดโโโโโโโโดโโโโโโโโโโดโโโโโโโโโโโโโโโโโโ
Total params : 141
Inexact params: 141
Other params: 0
--------------------------------------------
Total size : 564.000 B
Inexact size: 564.000 B
Other size: 0.000 B
============================================
>>> print(tree_viz.tree_box(model,array=x))
# using jax.eval_shape (no-flops operation)
# ** note ** : the created modules in __init__ should be in the same order
# where they are called in __call__
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โStackedLinear(Parent) โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโค
โโโโโโโโโโโโโโโฌโโโโโโโโโฌโโโโโโโโโโโโโโโ
โโ โ Input โ f32[100,1] โโ
โโ Linear(l1) โโโโโโโโโโผโโโโโโโโโโโโโโคโ
โโ โ Output โ f32[100,10] โโ
โโโโโโโโโโโโโโโดโโโโโโโโโดโโโโโโโโโโโโโโโ
โโโโโโโโโโโโโโโฌโโโโโโโโโฌโโโโโโโโโโโโโโโ
โโ โ Input โ f32[100,10] โโ
โโ Linear(l2) โโโโโโโโโโผโโโโโโโโโโโโโโคโ
โโ โ Output โ f32[100,10] โโ
โโโโโโโโโโโโโโโดโโโโโโโโโดโโโโโโโโโโโโโโโ
โโโโโโโโโโโโโโโฌโโโโโโโโโฌโโโโโโโโโโโโโโโ
โโ โ Input โ f32[100,10] โโ
โโ Linear(l3) โโโโโโโโโโผโโโโโโโโโโโโโโคโ
โโ โ Output โ f32[100,1] โโ
โโโโโโโโโโโโโโโดโโโโโโโโโดโโโโโโโโโโโโโโโ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
>>> print(tree_viz.tree_diagram(model))
StackedLinear
โโโ l1=Linear
โ โโโ weight=f32[1,10]
โ โโโ bias=f32[1,10]
โโโ l2=Linear
โ โโโ weight=f32[10,10]
โ โโโ bias=f32[1,10]
โโโl3=Linear
โโโ weight=f32[10,1]
โโโ bias=f32[1,1]
Using out-of-place indexing `.at[].set()` and `.at[].get()` on Pytrees
Similar to JAX pytreeclass provides .at
property for out-of-place update.
# get layer1
layer1 = model.l1
# layer1 repr
>>> print(f"{layer1!r}")
Linear(weight=f32[1,10],bias=f32[1,10])
# layer1 str
>>> print(f"{layer1!s}")
Linear(
weight=
[[-2.55655 1.674097 0.07847876 0.48010758 -1.9021134 -0.95792925
0.27486905 0.6492373 -0.51447827 1.077894 ]],
bias=
[[1.0345106 0.9914236 0.9971589 0.9965508 1.1548151 0.99296653
0.9731281 0.9994397 0.9537985 1.0100957 ]])
# get only positive values
>>> print(layer1.at[layer1>0].get())
Linear(
weight=
[1.674097 0.07847876 0.48010758 0.27486905 0.6492373 1.077894 ],
bias=
[1.0345106 0.9914236 0.9971589 0.9965508 1.1548151 0.99296653
0.9731281 0.9994397 0.9537985 1.0100957 ])
# set negative values to 0
>>> print(layer1.at[layer1<0].set(0))
Linear(
weight=
[[0. 1.674097 0.07847876 0.48010758 0. 0.
0.27486905 0.6492373 0. 1.077894 ]],
bias=
[[1.0345106 0.9914236 0.9971589 0.9965508 1.1548151 0.99296653
0.9731281 0.9994397 0.9537985 1.0100957 ]])
Perform Math operations on JAX pytrees
@treeclass
class Test :
a : float
b : float
c : float
name : str = static_field() # ignore from jax computations
# basic operations
A = Test(10,20,30,'A')
assert (A + A) == Test(20,40,60,'A')
assert (A - A) == Test(0,0,0,'A')
assert (A*A).reduce_mean() == 1400
assert (A + 1) == Test(11,21,31,'A')
# selective operations
# only add 1 to field `a`
# all other fields are set to None and returns the same class
assert (A['a'] + 1) == Test(11,None,None,'A')
# use `|` to merge classes by performing ( left_node or right_node )
Aa = A['a'] + 10 # Test(a=20,b=None,c=None,name=A)
Ab = A['b'] + 10 # Test(a=None,b=30,c=None,name=A)
assert (Aa | Ab | A ) == Test(20,30,30,'A')
# indexing by class
assert A[A>10] == Test(a=None,b=20,c=30,name='A')
# Register custom operations
B = Test([10,10],20,30,'B')
B.register_op( func=lambda node:node+1,name='plus_one')
assert B.plus_one() == Test(a=[11, 11],b=21,c=31,name='B')
# Register custom reduce operations ( similar to functools.reduce)
C = Test(jnp.array([10,10]),20,30,'C')
C.register_op(
func=jnp.prod, # function applied on each node
name='product', # name of the function
reduce_op=lambda x,y:x*y, # function applied between nodes (accumulated * current node)
init_val=1 # initializer for the reduce function
)
# product applies only on each node
# and returns an instance of the same class
assert C.product() == Test(a=100,b=20,c=30,name='C')
# `reduce_` + name of the registered function (`product`)
# reduces the class and returns a value
assert C.reduce_product() == 60000
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
pytreeclass-0.0.5.tar.gz
(20.6 kB
view hashes)
Built Distribution
Close
Hashes for pytreeclass-0.0.5-py3-none-any.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 99becf55035b44472b89024ccb5ef354366a037f13cdfd1df093dde6a60aa281 |
|
MD5 | 7d9538963fce905feff76d277c27bf92 |
|
BLAKE2b-256 | f179b7cbcab87480aa1b67e6fb31cbecbfeff40eea54f11926ab0be5e8ac48fc |