JAX NN library.

The ✨Magical✨ JAX NN Library.

*Serket is the goddess of magic in Egyptian mythology.

🛠️ Installation

pip install serket

Install the development version

pip install git+

📖 Description

  • serket aims to be the most intuitive and easy-to-use neural network library in JAX.

  • serket is built on top of pytreeclass.

  • serket currently implements the following layers:

| Group | Layers |
| --- | --- |
| Linear | Linear, Bilinear, Identity |
| Densely connected | FNN (Fully connected network), PFNN (Parallel fully connected network) |
| Convolution | Conv1D, Conv2D, Conv3D, Conv1DTranspose, Conv2DTranspose, Conv3DTranspose, DepthwiseConv1D, DepthwiseConv2D, DepthwiseConv3D, SeparableConv1D, SeparableConv2D, SeparableConv3D, Conv1DLocal, Conv2DLocal, Conv3DLocal |
| Containers | Sequential, Lambda |
| Pooling | MaxPool1D, MaxPool2D, MaxPool3D, AvgPool1D, AvgPool2D, AvgPool3D, GlobalMaxPool1D, GlobalMaxPool2D, GlobalMaxPool3D, GlobalAvgPool1D, GlobalAvgPool2D, GlobalAvgPool3D |
| Reshaping | Flatten, Unflatten, FlipLeftRight2D, FlipUpDown2D, Repeat1D, Repeat2D, Repeat3D, Resize1D, Resize2D, Resize3D, Upsampling1D, Upsampling2D, Upsampling3D, Padding1D, Padding2D, Padding3D |
| Crop | Crop1D, Crop2D |
| Normalization | LayerNorm, InstanceNorm, GroupNorm |
| Blurring | AvgBlur2D, GaussianBlur2D |
| Dropout | Dropout, Dropout1D, Dropout2D, Dropout3D |
| Physics | Laplace2D |
| Random transforms | RandomCrop1D, RandomCrop2D, RandomApply, RandomCutout1D, RandomCutout2D, RandomZoom2D, RandomContrast2D |
| Preprocessing | HistogramEqualization2D, AdjustContrast2D |
| Activations | AdaptiveLeakyReLU, AdaptiveReLU, AdaptiveSigmoid, AdaptiveTanh, SoftSign, Swish, Tanh, TanhShrink, ThresholdedReLU |
| Blocks | VGG16Block, VGG19Block, UNetBlock |

⏩ Quick Example: Train MNIST

We will use TensorFlow Datasets for data loading. For more on the interface between JAX and TensorFlow datasets, see here.

# imports
import tensorflow as tf
# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type="GPU")
import tensorflow_datasets as tfds
import tensorflow.experimental.numpy as tnp
import jax
import jax.numpy as jnp
import jax.random as jr 
import optax  # for gradient optimization
import serket as sk
import matplotlib.pyplot as plt
import functools as ft
# Build the input pipeline
batch_size = 128

# convert the samples from integers to floating-point numbers
# and to channel-first format
def preprocess_data(x):
    # convert to channel-first format
    image = tnp.moveaxis(x["image"], -1, 0)
    # normalize to [0, 1]
    image = tf.cast(image, tf.float32) / 255.0

    # one-hot encode the labels
    label = tf.one_hot(x["label"], 10) / 1.0
    return {"image": image, "label": label}

ds_train, ds_test = tfds.load("mnist", split=["train", "test"], shuffle_files=True)
# (batches, batch_size, 1, 28, 28)
# NOTE: the prefetch argument was truncated in the source; tf.data.AUTOTUNE is an assumed value
ds_train = ds_train.shuffle(1024).map(preprocess_data).batch(batch_size).prefetch(tf.data.AUTOTUNE)

# (batches, 1, 28, 28)
# NOTE: this line was truncated in the source; an unbatched map is assumed from the shape comment above
ds_test = ds_test.map(preprocess_data)

🏗️ Model definition

We will use jax.vmap(model) to apply the model to batches; the model itself is written for a single (channels, height, width) sample.

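# NOTE: decorator assumed (it appears to have been lost from the source);
# the model must be a pytree for summary(), .at[...] and optax to work
@sk.treeclass
class CNN: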
    def __init__(self):
        self.conv1 = sk.nn.Conv2D(1, 32, (3, 3), padding="valid")
        self.relu1 = sk.nn.ReLU()
        self.pool1 = sk.nn.MaxPool2D((2, 2), strides=(2, 2))
        self.conv2 = sk.nn.Conv2D(32, 64, (3, 3), padding="valid")
        self.relu2 = sk.nn.ReLU()
        self.pool2 = sk.nn.MaxPool2D((2, 2), strides=(2, 2))
        self.flatten = sk.nn.Flatten(start_dim=0)
        self.dropout = sk.nn.Dropout(0.5)
        self.linear = sk.nn.Linear(5*5*64, 10)

    def __call__(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.dropout(x)
        x = self.linear(x)
        return x

model = CNN()
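
The jax.vmap pattern mentioned above can be sketched in plain JAX, independent of serket. In this minimal sketch, `per_sample_model` is a hypothetical stand-in for the CNN: it is written for one (channels, height, width) image, and jax.vmap adds the leading batch axis.

```python
import jax
import jax.numpy as jnp


def per_sample_model(image):
    # hypothetical stand-in for the CNN: sees ONE image of shape (channels, height, width)
    assert image.ndim == 3
    # reduce over the spatial axes, leaving one value per channel
    return jnp.mean(image, axis=(1, 2))


# a batch of 128 single-channel 28x28 images
batch = jnp.ones((128, 1, 28, 28))

# vmap maps the per-sample function over the leading (batch) axis
batched_model = jax.vmap(per_sample_model)
out = batched_model(batch)

print(out.shape)
```

Writing the model for a single sample keeps shape handling inside each layer simple; batching is then added once, at the call site.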

🎨 Visualize model

Model summary
print(model.summary(show_config=False, array=jnp.empty((1, 28, 28))))
│Name   │Type     │Param #  │Size          │Input        │Output       │
│conv1  │Conv2D   │320(0)   │1.25KB(0.00B) │f32[1,28,28] │f32[32,26,26]│
│relu1  │ReLU     │0(0)     │0.00B(0.00B)  │f32[32,26,26]│f32[32,26,26]│
│pool1  │MaxPool2D│0(0)     │0.00B(0.00B)  │f32[32,26,26]│f32[32,13,13]│
│conv2  │Conv2D   │18,496(0)│72.25KB(0.00B)│f32[32,13,13]│f32[64,11,11]│
│relu2  │ReLU     │0(0)     │0.00B(0.00B)  │f32[64,11,11]│f32[64,11,11]│
│pool2  │MaxPool2D│0(0)     │0.00B(0.00B)  │f32[64,11,11]│f32[64,5,5]  │
│flatten│Flatten  │0(0)     │0.00B(0.00B)  │f32[64,5,5]  │f32[1600]    │
│dropout│Dropout  │0(0)     │0.00B(0.00B)  │f32[1600]    │f32[1600]    │
│linear │Linear   │16,010(0)│62.54KB(0.00B)│f32[1600]    │f32[10]      │
Total count :	34,826(0)
Dynamic count :	34,826(0)
Frozen count :	0(0)
Total size :	136.04KB(0.00B)
Dynamic size :	136.04KB(0.00B)
Frozen size :	0.00B(0.00B)
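
The shapes and parameter counts in the summary can be checked by hand. The sketch below uses only plain Python; the helper names are hypothetical and are not part of serket. It assumes "valid" padding, a bias per output feature, and 4 bytes per float32 parameter.

```python
def valid_out_size(size, kernel, stride=1):
    # output length of a "valid" (no padding) convolution or pooling window
    return (size - kernel) // stride + 1


def conv2d_param_count(in_features, out_features, kernel_size):
    kh, kw = kernel_size
    # one (in_features x kh x kw) kernel plus one bias per output feature
    return in_features * out_features * kh * kw + out_features


def linear_param_count(in_features, out_features):
    # weight matrix plus bias vector
    return in_features * out_features + out_features


# spatial size through the network, starting from 28x28 MNIST images
size = 28
size = valid_out_size(size, 3)      # conv1: 28 -> 26
size = valid_out_size(size, 2, 2)   # pool1: 26 -> 13
size = valid_out_size(size, 3)      # conv2: 13 -> 11
size = valid_out_size(size, 2, 2)   # pool2: 11 -> 5
flat_features = 64 * size * size    # flatten: 64 * 5 * 5 = 1600

counts = {
    "conv1": conv2d_param_count(1, 32, (3, 3)),
    "conv2": conv2d_param_count(32, 64, (3, 3)),
    "linear": linear_param_count(flat_features, 10),
}
total = sum(counts.values())
total_kb = total * 4 / 1024  # float32 parameters, 4 bytes each

print(counts, total, round(total_kb, 2))
```

This is also why the final layer is declared as Linear(5*5*64, 10): two valid 3x3 convolutions and two 2x2 poolings reduce 28x28 to 5x5 with 64 channels.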
Tree diagram
    ├── conv1=Conv2D
    │   ├── weight=f32[32,1,3,3]
    │   ├── bias=f32[32,1,1]
    │   ├*─ in_features=1
    │   ├*─ out_features=32
    │   ├*─ kernel_size=(3, 3)
    │   ├*─ strides=(1, 1)
    │   ├*─ padding=((0, 0), (0, 0))
    │   ├*─ input_dilation=(1, 1)
    │   ├*─ kernel_dilation=(1, 1)
    │   ├── weight_init_func=Partial(init(key,shape,dtype))
    │   ├── bias_init_func=Partial(zeros(key,shape,dtype))
    │   └*─ groups=1
    ├── relu1=ReLU
    ├*─ pool1=MaxPool2D
    │   ├*─ kernel_size=(2, 2)
    │   ├*─ strides=(2, 2)
    │   └*─ padding='valid'
    ├── conv2=Conv2D
    │   ├── weight=f32[64,32,3,3]
    │   ├── bias=f32[64,1,1]
    │   ├*─ in_features=32
    │   ├*─ out_features=64
    │   ├*─ kernel_size=(3, 3)
    │   ├*─ strides=(1, 1)
    │   ├*─ padding=((0, 0), (0, 0))
    │   ├*─ input_dilation=(1, 1)
    │   ├*─ kernel_dilation=(1, 1)
    │   ├── weight_init_func=Partial(init(key,shape,dtype))
    │   ├── bias_init_func=Partial(zeros(key,shape,dtype))
    │   └*─ groups=1
    ├── relu2=ReLU
    ├*─ pool2=MaxPool2D
    │   ├*─ kernel_size=(2, 2)
    │   ├*─ strides=(2, 2)
    │   └*─ padding='valid'
    ├*─ flatten=Flatten
    │   ├*─ start_dim=0
    │   └*─ end_dim=-1
    ├── dropout=Dropout
    │   ├*─ p=0.5
    │   └── eval=None
    └── linear=Linear
        ├── weight=f32[1600,10]
        ├── bias=f32[10]
        ├*─ in_features=1600
        └*─ out_features=10
Plot sample predictions before training
# turn off all dropout layers by setting their `eval` field (None by default) to True
test_model = model.at[model == "eval"].set(True, is_leaf=lambda x: x is None)

def show_images_with_predictions(model, images, one_hot_labels):
    logits = jax.vmap(model)(images)
    predictions = jnp.argmax(logits, axis=-1)
    fig, axes = plt.subplots(5, 5, figsize=(10, 10))
    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i].reshape(28, 28), cmap="binary")
        ax.set(title=f"Prediction: {predictions[i]}\nLabel: {jnp.argmax(one_hot_labels[i], axis=-1)}")

example = ds_test.take(25).as_numpy_iterator()
example = list(example)
sample_test_images = jnp.stack([x["image"] for x in example])
sample_test_labels = jnp.stack([x["label"] for x in example])

show_images_with_predictions(test_model, sample_test_images, sample_test_labels)

🏃 Train the model

@ft.partial(jax.value_and_grad, has_aux=True)
def loss_func(model, batched_images, batched_one_hot_labels):
    logits = jax.vmap(model)(batched_images)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=batched_one_hot_labels))
    return loss, logits

# using optax for gradient updates
optim = optax.adam(1e-3)
optim_state = optim.init(model)

def batch_step(model, batched_images, batched_one_hot_labels, optim_state):
    (loss, logits), grads = loss_func(model, batched_images, batched_one_hot_labels)
    updates, optim_state = optim.update(grads, optim_state)
    model = optax.apply_updates(model, updates)
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == jnp.argmax(batched_one_hot_labels, axis=-1))
    return model, optim_state, loss, accuracy

epochs = 5

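for i in range(1, epochs + 1):
    epoch_accuracy = []
    epoch_loss = []

    for example in ds_train.as_numpy_iterator():
        image, label = example["image"], example["label"]
        model, optim_state, loss, accuracy = batch_step(model, image, label, optim_state)
        # record per-batch metrics so they can be averaged per epoch
        epoch_accuracy.append(accuracy)
        epoch_loss.append(loss)

    epoch_loss = jnp.mean(jnp.array(epoch_loss))
    epoch_accuracy = jnp.mean(jnp.array(epoch_accuracy))

    print(f"epoch:{i}\tloss:{epoch_loss:.4f}\taccuracy:{epoch_accuracy:.4f}")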

# epoch:1	loss:0.2706	accuracy:0.9268
# epoch:2	loss:0.0725	accuracy:0.9784
# epoch:3	loss:0.0533	accuracy:0.9836
# epoch:4	loss:0.0442	accuracy:0.9868
# epoch:5	loss:0.0368	accuracy:0.9889
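
The training step above wraps the loss in jax.value_and_grad with has_aux=True, so each call returns ((loss, logits), grads). A minimal sketch of the same pattern in plain JAX follows. It is not the serket model: it assumes a hypothetical linear softmax classifier on random data and swaps Adam for plain gradient descent so that it needs neither serket nor optax.

```python
import functools as ft

import jax
import jax.numpy as jnp


@ft.partial(jax.value_and_grad, has_aux=True)
def loss_func(params, x, one_hot_labels):
    # hypothetical linear classifier standing in for the CNN
    logits = x @ params["weight"] + params["bias"]
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # mean softmax cross-entropy over the batch
    loss = -jnp.mean(jnp.sum(one_hot_labels * log_probs, axis=-1))
    # the second return value is auxiliary output and is not differentiated
    return loss, logits


@jax.jit
def sgd_step(params, x, one_hot_labels):
    (loss, logits), grads = loss_func(params, x, one_hot_labels)
    # plain gradient descent with a fixed step size (assumed value: 0.1)
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return params, loss


x = jax.random.normal(jax.random.PRNGKey(0), (8, 4))
labels = jax.nn.one_hot(jnp.arange(8) % 3, 3)
params = {"weight": jnp.zeros((4, 3)), "bias": jnp.zeros(3)}

losses = []
for _ in range(50):
    params, loss = sgd_step(params, x, labels)
    losses.append(float(loss))

print(losses[0], losses[-1])
```

Gradients are taken with respect to the first argument only, which is why the model (here `params`) comes first in the loss signature.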

🎨 Visualize after training

test_model = model.at[model == "eval"].set(True, is_leaf=lambda x: x is None)
show_images_with_predictions(test_model, sample_test_images, sample_test_labels)

🥶 Freezing parameters / fine-tuning

✨See here for more about freezing✨

🔘 Filtering by masking

✨See here for more about filtering✨

