JAX NN library.
Project description
The โจMagicalโจ JAX NN Library.
*Serket is the goddess of magic in Egyptian mythology
Installation |Description |Quick Example |Freezing/Fine tuning |Filtering
๐ ๏ธ Installation
pip install serket
Install development version
pip install git+https://github.com/ASEM000/serket
๐ Description
serket
aims to be the most intuitive and easy-to-use Neural network library in JAX.serket
is built on top ofpytreeclass
serket
currently implementsLinear
- Fully connected network
FNN
, Parallel fully connected networkPFNN
Dropout
Sequential
,Lambda
AdaptiveReLU
,AdaptiveLeakyReLU
,AdaptiveSigmoid
,AdaptiveTanh
โฉ Quick Example
Simple Fully connected neural network.
๐๏ธ Model definition
import serket as sk
import jax
import jax.numpy as jnp
import jax.random as jr
@sk.treeclass
class NN:
def __init__(
self,
in_features:int,
out_features:int,
hidden_features: int, key:jr.PRNGKey = jr.PRNGKey(0)):
k1,k2,k3 = jr.split(key, 3)
self.l1 = sk.nn.Linear(in_features, hidden_features, key=k1)
self.l2 = sk.nn.Linear(hidden_features, hidden_features, key=k2)
self.l3 = sk.nn.Linear(hidden_features, out_features, key=k3)
def __call__(self, x):
x = self.l1(x)
x = jax.nn.relu(x)
x = self.l2(x)
x = jax.nn.relu(x)
x = self.l3(x)
return x
model = NN(
in_features=1,
out_features=1,
hidden_features=128,
key=jr.PRNGKey(0))
๐จ Visualize
Model representation `__repr__`
print(f"{model!r}")
# `*` represents untrainable(static) nodes.
NN(
l1=Linear(
weight=f32[1,128],
bias=f32[128],
*in_features=1,
*out_features=128,
*weight_init_func=init(key,shape,dtype),
*bias_init_func=Lambda(key,shape)),
l2=Linear(
weight=f32[128,128],
bias=f32[128],
*in_features=128,
*out_features=128,
*weight_init_func=init(key,shape,dtype),
*bias_init_func=Lambda(key,shape)),
l3=Linear(
weight=f32[128,1],
bias=f32[1],
*in_features=128,
*out_features=1,
*weight_init_func=init(key,shape,dtype),
*bias_init_func=Lambda(key,shape)))
Model values `__str__`
print(f"{model!s}")
# `*` represents untrainable(static) nodes.
NN(
l1=Linear(
weight=
[[-0.556661 -0.6288703 1.28644 -2.9053314 -0.9808919 0.02763719
-1.5992663 0.3522784 -0.72343904 2.1087773 -1.184502 0.37314773
0.13440615 -1.1792887 2.646051 -0.31855923 1.2535691 -0.350722
0.24288356 0.8924919 1.8751624 -0.4494902 -0.6869111 2.4898252
1.0088646 2.3707743 -1.212474 -0.19152707 0.51991814 -0.801294
1.9568022 -0.05682194 0.7434735 0.24796781 -0.31967887 -0.6026076
0.02562018 -2.1735084 -0.7877185 1.1945596 -0.5776542 -0.08814432
0.01738743 0.85175467 -2.4330282 2.400132 -0.15812641 -2.2410994
1.8925649 -1.4573553 -1.5524752 0.2746206 0.99534875 -0.52039754
-1.6240916 0.57301414 1.2754964 0.39254263 1.5842631 -0.4408383
0.22060809 -0.11473875 1.2702179 0.14604266 -1.1393331 -0.20517357
2.8613555 -0.76657873 -2.7623959 1.4629859 1.7641917 1.4639573
0.90266997 -1.4661105 1.1719718 0.6656477 -0.6834308 1.0311401
-3.0281627 1.7895395 -1.248399 -0.13082643 2.1665883 2.8423917
0.24363454 0.20664148 1.7082529 2.129452 0.2974662 -0.8575109
-0.5970874 0.01702698 -0.18604587 0.7464636 0.83206064 0.6965974
0.7219791 0.8652629 1.3164111 -2.788336 -0.06530724 -0.7846771
-0.7344756 1.5899261 0.2623837 -0.01147135 -0.5437088 0.68380916
-1.5405492 1.1371891 -0.67851156 -0.37528485 -0.0336573 -2.0287845
0.3067764 -1.3464272 -0.6037441 -1.6209227 -2.3215613 -3.062661
0.5440992 -0.8735671 0.9094481 2.3398476 0.5821143 1.9373481
-0.36942863 2.5151203 ]],
bias=
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1.],
*in_features=1,
*out_features=128,
*weight_init_func=init(key,shape,dtype),
*bias_init_func=Lambda(key,shape)),
l2=Linear(
weight=
[[ 0.01565691 -0.02781865 0.15829083 ... 0.00930642 0.03536453
0.01890953]
[-0.01510135 0.1975845 0.2470963 ... -0.13168702 0.01404842
-0.21973991]
[-0.07814246 -0.18890998 -0.26707044 ... -0.15391685 -0.16248046
-0.11042175]
...
[-0.01806537 0.01311939 0.00696071 ... -0.18970545 0.07411639
-0.04393121]
[ 0.07426595 0.19547018 -0.26033685 ... -0.01357261 -0.00193011
-0.00152987]
[-0.00897581 -0.0115421 0.08062097 ... -0.098473 0.1083767
0.12410464]],
bias=
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1.],
*in_features=128,
*out_features=128,
*weight_init_func=init(key,shape,dtype),
*bias_init_func=Lambda(key,shape)),
l3=Linear(
weight=
[[-0.13613197]
[ 0.14116174]
[-0.06744987]
[-0.08091136]
[-0.27361065]
[-0.06548355]
[ 0.01022272]
[ 0.0252317 ]
[ 0.0237782 ]
[ 0.00614042]
[ 0.1812661 ]
[-0.06621032]
[ 0.16613998]
[-0.05014007]
[-0.21103479]
[-0.11941364]
[ 0.00036292]
[ 0.00039283]
[ 0.08278123]
[ 0.10028461]
[ 0.07373375]
[ 0.04089416]
[-0.00426106]
[-0.0247845 ]
[ 0.2804994 ]
[-0.11494187]
[ 0.26255226]
[-0.05349432]
[-0.16621305]
[ 0.0187737 ]
[ 0.11997257]
[ 0.24926668]
[ 0.12966438]
[ 0.02550141]
[ 0.18541676]
[-0.09129915]
[-0.22716352]
[-0.18755099]
[ 0.1665244 ]
[-0.10028487]
[ 0.09164064]
[-0.02597431]
[-0.15029983]
[-0.02553205]
[ 0.16129787]
[-0.07182706]
[-0.07004812]
[-0.03763127]
[-0.06973497]
[-0.0998554 ]
[ 0.00957549]
[ 0.0948947 ]
[-0.11812133]
[ 0.00408699]
[ 0.18451509]
[-0.2392044 ]
[ 0.1889591 ]
[ 0.20876819]
[ 0.16006592]
[ 0.11820399]
[ 0.13270618]
[-0.02642066]
[-0.03972287]
[ 0.0130475 ]
[ 0.12387222]
[-0.07360736]
[-0.07168346]
[ 0.26462224]
[-0.24544406]
[ 0.02614611]
[ 0.17016351]
[-0.10638441]
[-0.01891194]
[ 0.02476142]
[ 0.00474042]
[ 0.06326718]
[-0.10003307]
[ 0.03704525]
[-0.17377096]
[ 0.02369826]
[-0.09041592]
[ 0.06363823]
[-0.00131075]
[-0.19338304]
[ 0.2741859 ]
[-0.03178171]
[-0.0061704 ]
[ 0.01059608]
[ 0.17419283]
[ 0.08168265]
[ 0.08119942]
[ 0.07225287]
[-0.02761899]
[ 0.11468761]
[ 0.0180395 ]
[-0.04214213]
[-0.10949433]
[-0.03126818]
[ 0.14708327]
[-0.25051817]
[ 0.0431254 ]
[ 0.10890955]
[-0.00171187]
[-0.07619253]
[ 0.16909993]
[-0.11504915]
[ 0.02266672]
[ 0.22796142]
[ 0.05010169]
[-0.26961675]
[-0.02833704]
[-0.21504459]
[ 0.00469143]
[ 0.23426442]
[ 0.04301503]
[-0.13504943]
[-0.1914389 ]
[-0.1553146 ]
[ 0.00082878]
[-0.05092873]
[-0.13719554]
[-0.24856809]
[-0.05966872]
[-0.04416765]
[ 0.12827884]
[-0.06721988]
[ 0.05502734]
[ 0.03519182]],
bias=[1.],
*in_features=128,
*out_features=1,
*weight_init_func=init(key,shape,dtype),
*bias_init_func=Lambda(key,shape)))
Tree diagram
# `*` represents untrainable(static) nodes.
print(model.tree_diagram())
NN
โโโ l1=Linear
โ โโโ weight=f32[1,128]
โ โโโ bias=f32[128]
โ โ*โ in_features=1
โ โ*โ out_features=128
โ โ*โ weight_init_func=init(key,shape,dtype)
โ โ*โ bias_init_func=Lambda(key,shape)
โโโ l2=Linear
โ โโโ weight=f32[128,128]
โ โโโ bias=f32[128]
โ โ*โ in_features=128
โ โ*โ out_features=128
โ โ*โ weight_init_func=init(key,shape,dtype)
โ โ*โ bias_init_func=Lambda(key,shape)
โโโ l3=Linear
โโโ weight=f32[128,1]
โโโ bias=f32[1]
โ*โ in_features=128
โ*โ out_features=1
โ*โ weight_init_func=init(key,shape,dtype)
โ*โ bias_init_func=Lambda(key,shape)
Tree summary
>>> print(model.summary())
โโโโโโฌโโโโโโโฌโโโโโโโโโโฌโโโโโโโโฌโโโโโโโโโโโโโโโโโโโโ
โNameโType โParam # โSize โConfig โ
โโโโโโผโโโโโโโผโโโโโโโโโโผโโโโโโโโผโโโโโโโโโโโโโโโโโโโโค
โl1 โLinearโ256(0) โ1.00KB โweight=f32[1,128] โ
โ โ โ โ(0.00B)โbias=f32[128] โ
โโโโโโผโโโโโโโผโโโโโโโโโโผโโโโโโโโผโโโโโโโโโโโโโโโโโโโโค
โl2 โLinearโ16,512(0)โ64.50KBโweight=f32[128,128]โ
โ โ โ โ(0.00B)โbias=f32[128] โ
โโโโโโผโโโโโโโผโโโโโโโโโโผโโโโโโโโผโโโโโโโโโโโโโโโโโโโโค
โl3 โLinearโ129(0) โ516.00Bโweight=f32[128,1] โ
โ โ โ โ(0.00B)โbias=f32[1] โ
โโโโโโดโโโโโโโดโโโโโโโโโโดโโโโโโโโดโโโโโโโโโโโโโโโโโโโโ
Total count : 16,897(0)
Dynamic count : 16,897(0)
Frozen count : 0(0)
---------------------------------------------------
Total size : 66.00KB(0.00B)
Dynamic size : 66.00KB(0.00B)
Frozen size : 0.00B(0.00B)
===================================================
Tree summary with shape inference
**Using `model.summary(array=input_array)` `serket` can evaluate the shape propagation without evaluating the model , by using `jax` no-flop shape inference operations.**print(model.summary(array=x))
โโโโโโฌโโโโโโโฌโโโโโโโโโโฌโโโโโโโโฌโโโโโโโโโโโโโโโโโโโโฌโโโโโโโโโโโโโ
โNameโType โParam # โSize โConfig โInput/Outputโ
โโโโโโผโโโโโโโผโโโโโโโโโโผโโโโโโโโผโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค
โl1 โLinearโ256(0) โ1.00KB โweight=f32[1,128] โf32[100,1] โ
โ โ โ โ(0.00B)โbias=f32[128] โf32[100,128]โ
โโโโโโผโโโโโโโผโโโโโโโโโโผโโโโโโโโผโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค
โl2 โLinearโ16,512(0)โ64.50KBโweight=f32[128,128]โf32[100,128]โ
โ โ โ โ(0.00B)โbias=f32[128] โf32[100,128]โ
โโโโโโผโโโโโโโผโโโโโโโโโโผโโโโโโโโผโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค
โl3 โLinearโ129(0) โ516.00Bโweight=f32[128,1] โf32[100,128]โ
โ โ โ โ(0.00B)โbias=f32[1] โf32[100,1] โ
โโโโโโดโโโโโโโดโโโโโโโโโโดโโโโโโโโดโโโโโโโโโโโโโโโโโโโโดโโโโโโโโโโโโโ
Total count : 16,897(0)
Dynamic count : 16,897(0)
Frozen count : 0(0)
----------------------------------------------------------------
Total size : 66.00KB(0.00B)
Dynamic size : 66.00KB(0.00B)
Frozen size : 0.00B(0.00B)
================================================================
โ๐ง Train
import matplotlib.pyplot as plt
x = jnp.linspace(0,1,100)[:,None]
y = x**3 + jax.random.uniform(jax.random.PRNGKey(0),(100,1))*0.01
@jax.value_and_grad
def loss_func(model,x,y):
return jnp.mean((model(x)-y)**2)
@jax.jit
def update(model,x,y):
value,grad = loss_func(model,x,y)
return value , model - 1e-3*grad
plt.plot(x,y,'-k',label='True')
plt.plot(x,model(x),'-r',label='Prediction')
plt.title("Before training")
plt.legend()
plt.show()
for _ in range(20_000):
value,model = update(model,x,y)
plt.plot(x,y,'-k',label='True')
plt.plot(x,model(x),'-r',label='Prediction')
plt.title("After training")
plt.legend()
plt.show()
๐ฅถ Freezing parameters /Fine tuning
In serket
simply use .at[...].freeze()
/.at[...].unfreeze()
on treeclass
instance to freeze/unfreeze it is parameters.
# Freeze the entire model
frozen_model = model.at[...].freeze()
# To freeze a certain sub module
# use model = model.at["sub_module name"].freeze()
@jax.value_and_grad
def loss_func(model,x,y):
return jnp.mean((model(x)-y)**2)
@jax.jit
def update(model,x,y):
value,grad = loss_func(model,x,y)
return value , model - 1e-3*grad
plt.plot(x,y,'-k',label='True')
plt.plot(x,frozen_model(x),'-r',label='Prediction')
plt.title("Before training")
plt.legend()
plt.show()
for _ in range(20_000):
value,frozen_model = update(frozen_model,x,y)
plt.plot(x,y,'-k',label='True')
plt.plot(x,frozen_model(x),'-r',label='Prediction')
plt.title("After training")
plt.legend()
plt.show()
๐ Filtering by masking
Filter by value
# get model negative values
negative_model = model.at[model<0].get()
# Set negative values to 0
zeroed_model = model.at[model<0].set(0)
# Apply `jnp.cos` to negative values
cosined_model = model.at[model<0].apply(jnp.cos)
Filter by field name
# get model layer named `l1`
l1_model = model.at[model == "l1" ].get()
# Set `l1` values to 0
zeroed_model = model.at[model == "l1" ].set(0)
# Apply `jnp.cos` to `l1`
cosined_model = model.at[model == "l1" ].apply(jnp.cos)
Filter by field type
# get all model `Linear` layers
l1_model = model.at[model == sk.nn.Linear ].get()
# Set `Linear` layers to 0
zeroed_model = model.at[model == sk.nn.Linear ].set(0)
# Apply `jnp.cos` to all `Linear` layers
cosined_model = model.at[model == sk.nn.Linear ].apply(jnp.cos)
Filter by mixed masks
# Set all `Linear` bias to 0
mask = (model == sk.nn.Linear) & (model == "bias" )
zero_bias_model = model.at[mask].set(0.)
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
serket-0.0.2.tar.gz
(16.5 kB
view hashes)
Built Distribution
serket-0.0.2-py3-none-any.whl
(13.7 kB
view hashes)