JAX NN library.
The ✨Magical✨ JAX NN Library.
*Serket is the goddess of magic in Egyptian mythology
Installation | Description | Quick Example | Freezing/Fine tuning | Filtering
🛠️ Installation
```
pip install serket
```
Install development version
```
pip install git+https://github.com/ASEM000/serket
```
Description
- `serket` aims to be the most intuitive and easy-to-use physics-based neural network library in JAX.
- `serket` is built on top of `pytreeclass`.
- `serket` currently implements:
🧠 Neural network package: serket.nn
| Group | Layers |
|---|---|
| Linear | `Linear`, `Bilinear`, `Identity` |
| Densely connected | `FNN` (Fully connected network), `PFNN` (Parallel fully connected network) |
| Convolution | `Conv1D`, `Conv2D`, `Conv3D`, `Conv1DTranspose`, `Conv2DTranspose`, `Conv3DTranspose`, `DepthwiseConv1D`, `DepthwiseConv2D`, `DepthwiseConv3D`, `SeparableConv1D`, `SeparableConv2D`, `SeparableConv3D`, `Conv1DLocal`, `Conv2DLocal`, `Conv3DLocal`, `Conv1DSemiLocal`, `Conv2DSemiLocal`, `Conv3DSemiLocal`* |
| Containers | `Sequential`, `Lambda` |
| Pooling | `MaxPool1D`, `MaxPool2D`, `MaxPool3D`, `AvgPool1D`, `AvgPool2D`, `AvgPool3D`, `GlobalMaxPool1D`, `GlobalMaxPool2D`, `GlobalMaxPool3D`, `GlobalAvgPool1D`, `GlobalAvgPool2D`, `GlobalAvgPool3D` (kernex backend) |
| Reshaping | `Flatten`, `Unflatten`, `FlipLeftRight2D`, `FlipUpDown2D`, `Repeat1D`, `Repeat2D`, `Repeat3D`, `Resize1D`, `Resize2D`, `Resize3D`, `Upsample1D`, `Upsample2D`, `Upsample3D`, `Pad1D`, `Pad2D`, `Pad3D` |
| Crop | `Crop1D`, `Crop2D` |
| Normalization | `LayerNorm`, `InstanceNorm`, `GroupNorm` |
| Blurring | `AvgBlur2D`, `GaussianBlur2D` |
| Dropout | `Dropout`, `Dropout1D`, `Dropout2D`, `Dropout3D` |
| Random transforms | `RandomCrop1D`, `RandomCrop2D`, `RandomApply`, `RandomCutout1D`, `RandomCutout2D`, `RandomZoom2D`, `RandomContrast2D` |
| Preprocessing | `HistogramEqualization2D`, `AdjustContrast2D` |
| Activations | `AdaptiveLeakyReLU`, `AdaptiveReLU`, `AdaptiveSigmoid`, `AdaptiveTanh`, `CeLU`, `ELU`, `GELU`, `GLU`, `HardSILU`, `HardShrink`, `HardSigmoid`, `HardSwish`, `HardTanh`, `LeakyReLU`, `LogSigmoid`, `LogSoftmax`, `Mish`, `PReLU`, `ReLU`, `ReLU6`, `SILU`, `SeLU`, `Sigmoid`, `SoftPlus`, `SoftShrink`, `SoftSign`, `Swish`, `Tanh`, `TanhShrink`, `ThresholdedReLU` |
| Blocks | `VGG16Block`, `VGG19Block`, `UNetBlock` |
*Applies a set of different shared kernel weights to each spatial group, where the number of spatial groups <= the total number of patches of the input.
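As the MNIST example later in this README shows, layers are plain callables on unbatched, channel-first arrays. A minimal sketch (shapes chosen for illustration; constructor arguments follow the examples below):

```python
import jax.numpy as jnp
import serket as sk

# Conv2D(in_features, out_features, kernel_size) acts on a channel-first array
layer = sk.nn.Conv2D(1, 8, (3, 3))
x = jnp.ones((1, 28, 28))  # (channels, height, width)
y = layer(x)               # (8, 28, 28): the default padding preserves spatial shape
```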
Finite difference package: serket.fd
| Group | Function/Layer |
|---|---|
| Finite difference layer | `Difference`: apply a finite difference to an input array, with any derivative order and accuracy |
| Finite difference functions | `difference`: finite difference of an array with any accuracy and derivative order. `generate_finitediff_coeffs`: generate coefficients from sample points and a derivative order. `fgrad`: differentiate functions (similar to `jax.grad`) with custom accuracy and derivative order |
| Vector operator layers | `Curl`, `Divergence`, `Gradient`, `Laplacian` |
| Vector operator functions | `curl`, `divergence`, `gradient`, `laplacian` |
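For example, `fgrad` can serve as a finite-difference analogue of `jax.grad`. A minimal sketch, assuming the default call mirrors `jax.grad` as the table describes (the accuracy/derivative-order keywords are omitted since their exact names are not shown here):

```python
import serket as sk

f = lambda x: x**3

# finite-difference analogue of jax.grad(f)
dfdx = sk.fd.fgrad(f)
print(dfdx(2.0))  # ~ 12.0, since d/dx x**3 = 3x**2
```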
⏩ Quick Example
Lazy initialization
In cases where `in_features` needs to be inferred from the input, pass `None` for `in_features` to infer the value at runtime. However, since a lazy module initializes its state after the first call (i.e. mutates its state), `jax` transformations (e.g. `vmap`, `grad`, ...) are not allowed before initialization; using any `jax` transformation before initialization will raise a `ValueError`.
```python
import serket as sk
import jax
import jax.numpy as jnp
model = sk.nn.Sequential(
    [
        sk.nn.Conv2D(None, 128, 3),
        sk.nn.ReLU(),
        sk.nn.MaxPool2D(2, 2),
        sk.nn.Conv2D(128, 64, 3),
        sk.nn.ReLU(),
        sk.nn.MaxPool2D(2, 2),
        sk.nn.Flatten(),
        sk.nn.Linear(None, 128),
        sk.nn.ReLU(),
        sk.nn.Linear(128, 1),
    ]
)
# print the first `Conv2D` layer before initialization
print(model[0].__repr__())
# Conv2D(
# weight=None,
# bias=None,
# *in_features=None,
# *out_features=None,
# *kernel_size=None,
# *strides=None,
# *padding=None,
# *input_dilation=None,
# *kernel_dilation=None,
# weight_init_func=None,
# bias_init_func=None,
# *groups=None
# )
try:
    jax.vmap(model)(jnp.ones((10, 1, 28, 28)))
except ValueError:
    print("***** Not initialized *****")
# ***** Not initialized *****

# dry run to initialize the model
model(jnp.empty([3, 128, 128]))
print(model[0].__repr__())
# Conv2D(
# weight=f32[128,3,3,3],
# bias=f32[128,1,1],
# *in_features=3,
# *out_features=128,
# *kernel_size=(3,3),
# *strides=(1,1),
# *padding=((1,1),(1,1)),
# *input_dilation=(1,1),
# *kernel_dilation=(1,1),
# weight_init_func=Partial(glorot_uniform(key,shape,dtype)),
# bias_init_func=Partial(zeros(key,shape,dtype)),
# *groups=1
# )
```
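After the dry run, the parameters are materialized and `jax` transformations work as usual; for example (consistent with the 3-channel dry-run input above):

```python
# vmap is now allowed: the model is initialized
out = jax.vmap(model)(jnp.ones((10, 3, 128, 128)))  # -> (10, 1)
```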
Train MNIST
We will use `tensorflow` datasets for data loading. For more on the JAX/TensorFlow datasets interface, see here.
```python
# imports
import tensorflow as tf
# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type="GPU")
import tensorflow_datasets as tfds
import tensorflow.experimental.numpy as tnp
import jax
import jax.numpy as jnp
import jax.random as jr
import optax # for gradient optimization
import serket as sk
import matplotlib.pyplot as plt
import functools as ft
# Construct a tf.data.Dataset
batch_size = 128
# convert the samples from integers to floating-point numbers
# and channel first format
def preprocess_data(x):
    # convert to channel-first format
    image = tnp.moveaxis(x["image"], -1, 0)
    # normalize to [0, 1]
    image = tf.cast(image, tf.float32) / 255.0
    # one-hot encode the labels
    label = tf.one_hot(x["label"], 10)
    return {"image": image, "label": label}
ds_train, ds_test = tfds.load("mnist", split=["train", "test"], shuffle_files=True)
# (batches, batch_size, 1, 28, 28)
ds_train = ds_train.shuffle(1024).map(preprocess_data).batch(batch_size).prefetch(tf.data.AUTOTUNE)
# (batches, 1, 28, 28)
ds_test = ds_test.map(preprocess_data).prefetch(tf.data.AUTOTUNE)
```
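To sanity-check the pipeline, you can peek at a single training batch; the expected shapes follow the comments above:

```python
batch = next(ds_train.as_numpy_iterator())
print(batch["image"].shape)  # (128, 1, 28, 28)
print(batch["label"].shape)  # (128, 10)
```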
Model definition
We will use `jax.vmap(model)` to apply the `model` on batches.
```python
@sk.treeclass
class CNN:
    def __init__(self):
        self.conv1 = sk.nn.Conv2D(1, 32, (3, 3), padding="valid")
        self.relu1 = sk.nn.ReLU()
        self.pool1 = sk.nn.MaxPool2D((2, 2), strides=(2, 2))
        self.conv2 = sk.nn.Conv2D(32, 64, (3, 3), padding="valid")
        self.relu2 = sk.nn.ReLU()
        self.pool2 = sk.nn.MaxPool2D((2, 2), strides=(2, 2))
        self.flatten = sk.nn.Flatten(start_dim=0)
        self.dropout = sk.nn.Dropout(0.5)
        self.linear = sk.nn.Linear(5 * 5 * 64, 10)

    def __call__(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.dropout(x)
        x = self.linear(x)
        return x

model = CNN()
```
🎨 Visualize model
Model summary
```python
print(model.summary(show_config=False, array=jnp.empty((1, 28, 28))))
```
```
┌───────┬─────────┬─────────┬──────────────┬─────────────┬─────────────┐
│Name   │Type     │Param #  │Size          │Input        │Output       │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│conv1  │Conv2D   │320(0)   │1.25KB(0.00B) │f32[1,28,28] │f32[32,26,26]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│relu1  │ReLU     │0(0)     │0.00B(0.00B)  │f32[32,26,26]│f32[32,26,26]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│pool1  │MaxPool2D│0(0)     │0.00B(0.00B)  │f32[32,26,26]│f32[32,13,13]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│conv2  │Conv2D   │18,496(0)│72.25KB(0.00B)│f32[32,13,13]│f32[64,11,11]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│relu2  │ReLU     │0(0)     │0.00B(0.00B)  │f32[64,11,11]│f32[64,11,11]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│pool2  │MaxPool2D│0(0)     │0.00B(0.00B)  │f32[64,11,11]│f32[64,5,5]  │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│flatten│Flatten  │0(0)     │0.00B(0.00B)  │f32[64,5,5]  │f32[1600]    │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│dropout│Dropout  │0(0)     │0.00B(0.00B)  │f32[1600]    │f32[1600]    │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│linear │Linear   │16,010(0)│62.54KB(0.00B)│f32[1600]    │f32[10]      │
└───────┴─────────┴─────────┴──────────────┴──────────────┴────────────┘
Total count :   34,826(0)
Dynamic count : 34,826(0)
Frozen count :  0(0)
------------------------------------------------------------------------
Total size :    136.04KB(0.00B)
Dynamic size :  136.04KB(0.00B)
Frozen size :   0.00B(0.00B)
========================================================================
```
Tree diagram
```python
print(model.tree_diagram())
```
```
CNN
    ├── conv1=Conv2D
    │   ├── weight=f32[32,1,3,3]
    │   ├── bias=f32[32,1,1]
    │   ├*─ in_features=1
    │   ├*─ out_features=32
    │   ├*─ kernel_size=(3, 3)
    │   ├*─ strides=(1, 1)
    │   ├*─ padding=((0, 0), (0, 0))
    │   ├*─ input_dilation=(1, 1)
    │   ├*─ kernel_dilation=(1, 1)
    │   ├── weight_init_func=Partial(init(key,shape,dtype))
    │   ├── bias_init_func=Partial(zeros(key,shape,dtype))
    │   └*─ groups=1
    ├── relu1=ReLU
    ├*─ pool1=MaxPool2D
    │   ├*─ kernel_size=(2, 2)
    │   ├*─ strides=(2, 2)
    │   └*─ padding='valid'
    ├── conv2=Conv2D
    │   ├── weight=f32[64,32,3,3]
    │   ├── bias=f32[64,1,1]
    │   ├*─ in_features=32
    │   ├*─ out_features=64
    │   ├*─ kernel_size=(3, 3)
    │   ├*─ strides=(1, 1)
    │   ├*─ padding=((0, 0), (0, 0))
    │   ├*─ input_dilation=(1, 1)
    │   ├*─ kernel_dilation=(1, 1)
    │   ├── weight_init_func=Partial(init(key,shape,dtype))
    │   ├── bias_init_func=Partial(zeros(key,shape,dtype))
    │   └*─ groups=1
    ├── relu2=ReLU
    ├*─ pool2=MaxPool2D
    │   ├*─ kernel_size=(2, 2)
    │   ├*─ strides=(2, 2)
    │   └*─ padding='valid'
    ├*─ flatten=Flatten
    │   ├*─ start_dim=0
    │   └*─ end_dim=-1
    ├── dropout=Dropout
    │   ├*─ p=0.5
    │   └── eval=None
    └── linear=Linear
        ├── weight=f32[1600,10]
        ├── bias=f32[10]
        ├*─ in_features=1600
        └*─ out_features=10
```
Plot sample predictions before training
```python
# set all dropout off
test_model = model.at[model == "eval"].set(True, is_leaf=lambda x: x is None)

def show_images_with_predictions(model, images, one_hot_labels):
    logits = jax.vmap(model)(images)
    predictions = jnp.argmax(logits, axis=-1)
    fig, axes = plt.subplots(5, 5, figsize=(10, 10))
    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i].reshape(28, 28), cmap="binary")
        ax.set(title=f"Prediction: {predictions[i]}\nLabel: {jnp.argmax(one_hot_labels[i], axis=-1)}")
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

example = ds_test.take(25).as_numpy_iterator()
example = list(example)
sample_test_images = jnp.stack([x["image"] for x in example])
sample_test_labels = jnp.stack([x["label"] for x in example])
show_images_with_predictions(test_model, sample_test_images, sample_test_labels)
```
Train the model
```python
@ft.partial(jax.value_and_grad, has_aux=True)
def loss_func(model, batched_images, batched_one_hot_labels):
    logits = jax.vmap(model)(batched_images)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=batched_one_hot_labels))
    return loss, logits

# using optax for gradient updates
optim = optax.adam(1e-3)
optim_state = optim.init(model)

@jax.jit
def batch_step(model, batched_images, batched_one_hot_labels, optim_state):
    (loss, logits), grads = loss_func(model, batched_images, batched_one_hot_labels)
    updates, optim_state = optim.update(grads, optim_state)
    model = optax.apply_updates(model, updates)
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == jnp.argmax(batched_one_hot_labels, axis=-1))
    return model, optim_state, loss, accuracy

epochs = 5

for i in range(epochs):
    epoch_accuracy = []
    epoch_loss = []
    for example in ds_train.as_numpy_iterator():
        image, label = example["image"], example["label"]
        model, optim_state, loss, accuracy = batch_step(model, image, label, optim_state)
        epoch_accuracy.append(accuracy)
        epoch_loss.append(loss)
    epoch_loss = jnp.mean(jnp.array(epoch_loss))
    epoch_accuracy = jnp.mean(jnp.array(epoch_accuracy))
    print(f"epoch:{i+1}\tloss:{epoch_loss:.4f}\taccuracy:{epoch_accuracy:.4f}")

# epoch:1 loss:0.2706 accuracy:0.9268
# epoch:2 loss:0.0725 accuracy:0.9784
# epoch:3 loss:0.0533 accuracy:0.9836
# epoch:4 loss:0.0442 accuracy:0.9868
# epoch:5 loss:0.0368 accuracy:0.9889
```
🎨 Visualize after training
```python
test_model = model.at[model == "eval"].set(True, is_leaf=lambda x: x is None)
show_images_with_predictions(test_model, sample_test_images, sample_test_labels)
```
🥶 Freezing parameters / Fine tuning
✨See here for more about freezing✨
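As a rough sketch of fine-tuning by masking gradients: the following assumes the name-mask idiom used elsewhere in this README (`model == "eval"`) also selects fields by name, and reuses `loss_func` and `optim` from the MNIST example; the link above documents the dedicated freezing utilities.

```python
@jax.jit
def finetune_step(model, images, labels, optim_state):
    (loss, logits), grads = loss_func(model, images, labels)
    # hypothetical: zero out `conv1`'s gradients so it stays fixed,
    # reusing the `tree == "name"` mask pattern from the dropout example
    grads = grads.at[grads == "conv1"].set(0.0)
    updates, optim_state = optim.update(grads, optim_state)
    model = optax.apply_updates(model, updates)
    return model, optim_state, loss
```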
Filtering by masking
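The boolean-mask pattern already used above is the entry point for filtering: comparing a tree against a value produces a mask, and `.at[mask]` updates only the selected leaves. A sketch, reusing only the idiom shown in this README:

```python
# build a boolean mask selecting the `eval` fields by name, then flip them on,
# exactly as in the dropout example; other comparisons build masks the same way
mask = model == "eval"
eval_model = model.at[mask].set(True, is_leaf=lambda x: x is None)
```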