JAX NN library.
Project description
The ✨Magical✨ JAX NN Library.
*Serket is the goddess of magic in Egyptian mythology
Installation |Description |Quick Example |Freezing/Fine tuning |Filtering
🛠️ Installation
pip install serket
Install development version
pip install git+https://github.com/ASEM000/serket
📖 Description
-
serket
aims to be the most intuitive and easy-to-use Neural network library in JAX. -
serket
is built on top of pytreeclass
-
serket
currently implements
Group | Layers |
---|---|
Linear | Linear , Bilinear , FNN (Fully connected network), PFNN (Parallel fully connected network) |
Convolution | Conv1D , Conv2D , Conv3D , Conv1DTranspose , Conv2DTranspose , Conv3DTranspose , DepthwiseConv1D , DepthwiseConv2D , DepthwiseConv3D |
Containers | Sequential , Lambda |
Activation | AdaptiveReLU , AdaptiveLeakyReLU , AdaptiveSigmoid , AdaptiveTanh |
Pooling | MaxPool1D , MaxPool2D , MaxPool3D , AvgPool1D , AvgPool2D , AvgPool3D GlobalMaxPool1D , GlobalMaxPool2D , GlobalMaxPool3D , GlobalAvgPool1D , GlobalAvgPool2D , GlobalAvgPool3D |
Reshaping | Flatten , Unflatten , FlipLeftRight2D , FlipUpDown2D , Repeat1D , Repeat2D , Repeat3D , Resize1D , Resize2D , Resize3D , Upsampling1D , Upsampling2D , Upsampling3D , Padding1D , Padding2D , Padding3D |
Recurrent | RNNCell , LSTMCell |
Normalization | LayerNorm |
Blurring | AvgBlur2D |
Physics | Laplace2D |
Blocks | VGG16Block , VGG19Block |
⏩ Quick Example
Simple Fully connected neural network.
🏗️ Model definition
import serket as sk
import jax
import jax.numpy as jnp
import jax.random as jr
@sk.treeclass
class NN:
    """Three-layer fully-connected network: Linear -> ReLU -> Linear -> ReLU -> Linear."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: int,
        # NOTE: `jr.PRNGKey` is a function, not a type, so the key is annotated
        # as the array it returns. The default key is created once at definition
        # time, which is safe because JAX key arrays are immutable.
        key: jnp.ndarray = jr.PRNGKey(0),
    ):
        """Build the three `Linear` layers.

        Args:
            in_features: size of the input feature dimension.
            out_features: size of the output feature dimension.
            hidden_features: width of the two hidden layers.
            key: PRNG key used to initialize the layer parameters.
        """
        # Split once so each layer gets an independent initialization key.
        k1, k2, k3 = jr.split(key, 3)
        self.l1 = sk.nn.Linear(in_features, hidden_features, key=k1)
        self.l2 = sk.nn.Linear(hidden_features, hidden_features, key=k2)
        self.l3 = sk.nn.Linear(hidden_features, out_features, key=k3)

    def __call__(self, x):
        """Apply the network to `x` and return the output activations."""
        x = self.l1(x)
        x = jax.nn.relu(x)
        x = self.l2(x)
        x = jax.nn.relu(x)
        x = self.l3(x)
        return x
# Instantiate a 1 -> 128 -> 128 -> 1 network with an explicit, reproducible seed.
model = NN(in_features=1, out_features=1, hidden_features=128, key=jr.PRNGKey(0))
🎨 Visualize
Model representation `__repr__`
print(f"{model!r}")
# `*` represents untrainable(static) nodes.
NN(
l1=Linear(
weight=f32[1,128],
bias=f32[128],
*in_features=1,
*out_features=128,
*weight_init_func=init(key,shape,dtype),
*bias_init_func=Lambda(key,shape)
),
l2=Linear(
weight=f32[128,128],
bias=f32[128],
*in_features=128,
*out_features=128,
*weight_init_func=init(key,shape,dtype),
*bias_init_func=Lambda(key,shape)
),
l3=Linear(
weight=f32[128,1],
bias=f32[1],
*in_features=128,
*out_features=1,
*weight_init_func=init(key,shape,dtype),
*bias_init_func=Lambda(key,shape)
)
)
Model values `__str__`
print(f"{model!s}")
# `*` represents untrainable(static) nodes.
NN(
l1=Linear(
weight=
[[-0.556661 -0.6288703 1.28644 -2.9053314 -0.9808919 0.02763719
-1.5992663 0.3522784 -0.72343904 2.1087773 -1.184502 0.37314773
0.13440615 -1.1792887 2.646051 -0.31855923 1.2535691 -0.350722
0.24288356 0.8924919 1.8751624 -0.4494902 -0.6869111 2.4898252
1.0088646 2.3707743 -1.212474 -0.19152707 0.51991814 -0.801294
1.9568022 -0.05682194 0.7434735 0.24796781 -0.31967887 -0.6026076
0.02562018 -2.1735084 -0.7877185 1.1945596 -0.5776542 -0.08814432
0.01738743 0.85175467 -2.4330282 2.400132 -0.15812641 -2.2410994
1.8925649 -1.4573553 -1.5524752 0.2746206 0.99534875 -0.52039754
-1.6240916 0.57301414 1.2754964 0.39254263 1.5842631 -0.4408383
0.22060809 -0.11473875 1.2702179 0.14604266 -1.1393331 -0.20517357
2.8613555 -0.76657873 -2.7623959 1.4629859 1.7641917 1.4639573
0.90266997 -1.4661105 1.1719718 0.6656477 -0.6834308 1.0311401
-3.0281627 1.7895395 -1.248399 -0.13082643 2.1665883 2.8423917
0.24363454 0.20664148 1.7082529 2.129452 0.2974662 -0.8575109
-0.5970874 0.01702698 -0.18604587 0.7464636 0.83206064 0.6965974
0.7219791 0.8652629 1.3164111 -2.788336 -0.06530724 -0.7846771
-0.7344756 1.5899261 0.2623837 -0.01147135 -0.5437088 0.68380916
-1.5405492 1.1371891 -0.67851156 -0.37528485 -0.0336573 -2.0287845
0.3067764 -1.3464272 -0.6037441 -1.6209227 -2.3215613 -3.062661
0.5440992 -0.8735671 0.9094481 2.3398476 0.5821143 1.9373481
-0.36942863 2.5151203 ]],
bias=
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1.],
*in_features=1,
*out_features=128,
*weight_init_func=init(key,shape,dtype),
*bias_init_func=Lambda(key,shape)
),
l2=Linear(
weight=
[[ 0.01565691 -0.02781865 0.15829083 ... 0.00930642 0.03536453
0.01890953]
[-0.01510135 0.1975845 0.2470963 ... -0.13168702 0.01404842
-0.21973991]
[-0.07814246 -0.18890998 -0.26707044 ... -0.15391685 -0.16248046
-0.11042175]
...
[-0.01806537 0.01311939 0.00696071 ... -0.18970545 0.07411639
-0.04393121]
[ 0.07426595 0.19547018 -0.26033685 ... -0.01357261 -0.00193011
-0.00152987]
[-0.00897581 -0.0115421 0.08062097 ... -0.098473 0.1083767
0.12410464]],
bias=
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1.],
*in_features=128,
*out_features=128,
*weight_init_func=init(key,shape,dtype),
*bias_init_func=Lambda(key,shape)
),
l3=Linear(
weight=
[[-0.13613197]
[ 0.14116174]
[-0.06744987]
[-0.08091136]
[-0.27361065]
[-0.06548355]
[ 0.01022272]
[ 0.0252317 ]
[ 0.0237782 ]
[ 0.00614042]
[ 0.1812661 ]
[-0.06621032]
[ 0.16613998]
[-0.05014007]
[-0.21103479]
[-0.11941364]
[ 0.00036292]
[ 0.00039283]
[ 0.08278123]
[ 0.10028461]
[ 0.07373375]
[ 0.04089416]
[-0.00426106]
[-0.0247845 ]
[ 0.2804994 ]
[-0.11494187]
[ 0.26255226]
[-0.05349432]
[-0.16621305]
[ 0.0187737 ]
[ 0.11997257]
[ 0.24926668]
[ 0.12966438]
[ 0.02550141]
[ 0.18541676]
[-0.09129915]
[-0.22716352]
[-0.18755099]
[ 0.1665244 ]
[-0.10028487]
[ 0.09164064]
[-0.02597431]
[-0.15029983]
[-0.02553205]
[ 0.16129787]
[-0.07182706]
[-0.07004812]
[-0.03763127]
[-0.06973497]
[-0.0998554 ]
[ 0.00957549]
[ 0.0948947 ]
[-0.11812133]
[ 0.00408699]
[ 0.18451509]
[-0.2392044 ]
[ 0.1889591 ]
[ 0.20876819]
[ 0.16006592]
[ 0.11820399]
[ 0.13270618]
[-0.02642066]
[-0.03972287]
[ 0.0130475 ]
[ 0.12387222]
[-0.07360736]
[-0.07168346]
[ 0.26462224]
[-0.24544406]
[ 0.02614611]
[ 0.17016351]
[-0.10638441]
[-0.01891194]
[ 0.02476142]
[ 0.00474042]
[ 0.06326718]
[-0.10003307]
[ 0.03704525]
[-0.17377096]
[ 0.02369826]
[-0.09041592]
[ 0.06363823]
[-0.00131075]
[-0.19338304]
[ 0.2741859 ]
[-0.03178171]
[-0.0061704 ]
[ 0.01059608]
[ 0.17419283]
[ 0.08168265]
[ 0.08119942]
[ 0.07225287]
[-0.02761899]
[ 0.11468761]
[ 0.0180395 ]
[-0.04214213]
[-0.10949433]
[-0.03126818]
[ 0.14708327]
[-0.25051817]
[ 0.0431254 ]
[ 0.10890955]
[-0.00171187]
[-0.07619253]
[ 0.16909993]
[-0.11504915]
[ 0.02266672]
[ 0.22796142]
[ 0.05010169]
[-0.26961675]
[-0.02833704]
[-0.21504459]
[ 0.00469143]
[ 0.23426442]
[ 0.04301503]
[-0.13504943]
[-0.1914389 ]
[-0.1553146 ]
[ 0.00082878]
[-0.05092873]
[-0.13719554]
[-0.24856809]
[-0.05966872]
[-0.04416765]
[ 0.12827884]
[-0.06721988]
[ 0.05502734]
[ 0.03519182]],
bias=[1.],
*in_features=128,
*out_features=1,
*weight_init_func=init(key,shape,dtype),
*bias_init_func=Lambda(key,shape)
)
)
Tree diagram
# `*` represents untrainable(static) nodes.
print(model.tree_diagram())
NN
โโโ l1=Linear
โ โโโ weight=f32[1,128]
โ โโโ bias=f32[128]
โ โ*โ in_features=1
โ โ*โ out_features=128
โ โ*โ weight_init_func=init(key,shape,dtype)
โ โ*โ bias_init_func=Lambda(key,shape)
โโโ l2=Linear
โ โโโ weight=f32[128,128]
โ โโโ bias=f32[128]
โ โ*โ in_features=128
โ โ*โ out_features=128
โ โ*โ weight_init_func=init(key,shape,dtype)
โ โ*โ bias_init_func=Lambda(key,shape)
โโโ l3=Linear
โโโ weight=f32[128,1]
โโโ bias=f32[1]
โ*โ in_features=128
โ*โ out_features=1
โ*โ weight_init_func=init(key,shape,dtype)
โ*โ bias_init_func=Lambda(key,shape)
Tree summary
>>> print(model.summary())
โโโโโโฌโโโโโโโฌโโโโโโโโโโฌโโโโโโโโฌโโโโโโโโโโโโโโโโโโโโ
โNameโType โParam # โSize โConfig โ
โโโโโโผโโโโโโโผโโโโโโโโโโผโโโโโโโโผโโโโโโโโโโโโโโโโโโโโค
โl1 โLinearโ256(0) โ1.00KB โweight=f32[1,128] โ
โ โ โ โ(0.00B)โbias=f32[128] โ
โโโโโโผโโโโโโโผโโโโโโโโโโผโโโโโโโโผโโโโโโโโโโโโโโโโโโโโค
โl2 โLinearโ16,512(0)โ64.50KBโweight=f32[128,128]โ
โ โ โ โ(0.00B)โbias=f32[128] โ
โโโโโโผโโโโโโโผโโโโโโโโโโผโโโโโโโโผโโโโโโโโโโโโโโโโโโโโค
โl3 โLinearโ129(0) โ516.00Bโweight=f32[128,1] โ
โ โ โ โ(0.00B)โbias=f32[1] โ
โโโโโโดโโโโโโโดโโโโโโโโโโดโโโโโโโโดโโโโโโโโโโโโโโโโโโโโ
Total count : 16,897(0)
Dynamic count : 16,897(0)
Frozen count : 0(0)
---------------------------------------------------
Total size : 66.00KB(0.00B)
Dynamic size : 66.00KB(0.00B)
Frozen size : 0.00B(0.00B)
===================================================
Tree summary with shape inference
**Using `model.summary(array=input_array)`, `serket` can evaluate the shape propagation without evaluating the model, by using `jax` no-flop shape inference operations.**
print(model.summary(array=x))
โโโโโโฌโโโโโโโฌโโโโโโโโโโฌโโโโโโโโฌโโโโโโโโโโโโโโโโโโโโฌโโโโโโโโโโโโโ
โNameโType โParam # โSize โConfig โInput/Outputโ
โโโโโโผโโโโโโโผโโโโโโโโโโผโโโโโโโโผโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค
โl1 โLinearโ256(0) โ1.00KB โweight=f32[1,128] โf32[100,1] โ
โ โ โ โ(0.00B)โbias=f32[128] โf32[100,128]โ
โโโโโโผโโโโโโโผโโโโโโโโโโผโโโโโโโโผโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค
โl2 โLinearโ16,512(0)โ64.50KBโweight=f32[128,128]โf32[100,128]โ
โ โ โ โ(0.00B)โbias=f32[128] โf32[100,128]โ
โโโโโโผโโโโโโโผโโโโโโโโโโผโโโโโโโโผโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค
โl3 โLinearโ129(0) โ516.00Bโweight=f32[128,1] โf32[100,128]โ
โ โ โ โ(0.00B)โbias=f32[1] โf32[100,1] โ
โโโโโโดโโโโโโโดโโโโโโโโโโดโโโโโโโโดโโโโโโโโโโโโโโโโโโโโดโโโโโโโโโโโโโ
Total count : 16,897(0)
Dynamic count : 16,897(0)
Frozen count : 0(0)
----------------------------------------------------------------
Total size : 66.00KB(0.00B)
Dynamic size : 66.00KB(0.00B)
Frozen size : 0.00B(0.00B)
================================================================
🧠 Train
import matplotlib.pyplot as plt
# Toy regression data: 100 points on [0, 1] as a (100, 1) column,
# with targets y = x^3 plus small uniform noise.
x = jnp.linspace(0, 1, 100).reshape(-1, 1)
y = x**3 + jax.random.uniform(jax.random.PRNGKey(0), (100, 1)) * 0.01
@jax.value_and_grad
def loss_func(model, x, y):
    """Mean-squared error of `model(x)` against `y`.

    Wrapped in `jax.value_and_grad`, so calling it returns
    `(loss, gradient-with-respect-to-model)`.
    """
    residual = model(x) - y
    return jnp.mean(residual**2)
@jax.jit
def update(model, x, y):
    """One plain gradient-descent step with learning rate 1e-3.

    Returns `(loss, updated model)`; the pytree arithmetic
    `model - 1e-3 * grad` updates every trainable leaf.
    """
    loss, grads = loss_func(model, x, y)
    new_model = model - 1e-3 * grads
    return loss, new_model
# Plot the ground truth against the *untrained* model's prediction.
plt.plot(x,y,'-k',label='True')
plt.plot(x,model(x),'-r',label='Prediction')
plt.title("Before training")
plt.legend()
plt.show()
# Run 20,000 gradient-descent steps; `update` is jitted, so the loop body
# is a single compiled call per iteration.
for step in range(20_000):
    value, model = update(model, x, y)

# Plot the ground truth against the *trained* model's prediction.
plt.plot(x, y, '-k', label='True')
plt.plot(x, model(x), '-r', label='Prediction')
plt.title("After training")
plt.legend()
plt.show()
🥶 Freezing parameters / Fine tuning
🔍 Filtering by masking
Filter by value
# Get the model's negative values (all other leaves are masked out).
negative_model = model.at[model<0].get()
# Set every negative value to 0, leaving other leaves untouched.
zeroed_model = model.at[model<0].set(0)
# Apply `jnp.cos` elementwise, but only to the negative values.
cosined_model = model.at[model<0].apply(jnp.cos)
Filter by field name
# Get only the layer whose field name is `l1` (other fields are masked out).
l1_model = model.at[model == "l1" ].get()
# Set every value inside `l1` to 0, leaving other layers untouched.
zeroed_model = model.at[model == "l1" ].set(0)
# Apply `jnp.cos` elementwise to the values inside `l1` only.
cosined_model = model.at[model == "l1" ].apply(jnp.cos)
Filter by field type
# Get all leaves belonging to `Linear` layers (comparison by field type).
l1_model = model.at[model == sk.nn.Linear ].get()
# Set every value inside every `Linear` layer to 0.
zeroed_model = model.at[model == sk.nn.Linear ].set(0)
# Apply `jnp.cos` elementwise to all `Linear` layer values.
cosined_model = model.at[model == sk.nn.Linear ].apply(jnp.cos)
Filter by mixed masks
# Combine a type mask and a name mask with `&`: select the `bias` field of
# every `Linear` layer, then set those biases to 0.
mask = (model == sk.nn.Linear) & (model == "bias" )
zero_bias_model = model.at[mask].set(0.)
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
serket-0.0.5.tar.gz
(35.2 kB
view hashes)
Built Distribution
serket-0.0.5-py3-none-any.whl
(39.0 kB
view hashes)