JAX NN library.

The ✨Magical✨ JAX NN Library.

*Serket is the goddess of magic in Egyptian mythology.

Installation | Description | Quick Example | Freezing/Fine-tuning | Filtering


๐Ÿ› ๏ธ Installation

pip install serket

Install the development version:

pip install git+https://github.com/ASEM000/serket

📖 Description

  • serket aims to be the most intuitive and easy-to-use neural network library in JAX.

  • serket is built on top of pytreeclass.

  • serket currently implements the following layers:

Group | Layers
Linear | Linear, Bilinear, Identity
Densely connected | FNN (Fully connected network), PFNN (Parallel fully connected network)
Convolution | Conv1D, Conv2D, Conv3D, Conv1DTranspose, Conv2DTranspose, Conv3DTranspose, DepthwiseConv1D, DepthwiseConv2D, DepthwiseConv3D, SeparableConv1D, SeparableConv2D, SeparableConv3D, Conv1DLocal, Conv2DLocal, Conv3DLocal
Containers | Sequential, Lambda
Pooling | MaxPool1D, MaxPool2D, MaxPool3D, AvgPool1D, AvgPool2D, AvgPool3D, GlobalMaxPool1D, GlobalMaxPool2D, GlobalMaxPool3D, GlobalAvgPool1D, GlobalAvgPool2D, GlobalAvgPool3D
Reshaping | Flatten, Unflatten, FlipLeftRight2D, FlipUpDown2D, Repeat1D, Repeat2D, Repeat3D, Resize1D, Resize2D, Resize3D, Upsampling1D, Upsampling2D, Upsampling3D, Padding1D, Padding2D, Padding3D
Crop | Crop1D, Crop2D
Normalization | LayerNorm, InstanceNorm, GroupNorm
Blurring | AvgBlur2D, GaussianBlur2D
Dropout | Dropout, Dropout1D, Dropout2D, Dropout3D
Physics | Laplace2D
Random transforms | RandomCrop1D, RandomCrop2D, RandomApply, RandomCutout1D, RandomCutout2D, RandomZoom2D, RandomContrast2D
Preprocessing | HistogramEqualization2D, AdjustContrast2D
Activations | AdaptiveLeakyReLU, AdaptiveReLU, AdaptiveSigmoid, AdaptiveTanh, CeLU, ELU, GELU, GLU, HardSILU, HardShrink, HardSigmoid, HardSwish, HardTanh, LeakyReLU, LogSigmoid, LogSoftmax, Mish, PReLU, ReLU, ReLU6, SILU, SeLU, Sigmoid, SoftPlus, SoftShrink, SoftSign, Swish, Tanh, TanhShrink, ThresholdedReLU
Blocks | VGG16Block, VGG19Block, UNetBlock
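
As a quick taste, any layer from the table can be used on its own. A minimal sketch (the Conv2D arguments mirror the call in the MNIST example below; shapes are channel-first):

import jax.numpy as jnp
import serket as sk

# 1 input channel, 8 output channels, 3x3 kernel, no padding
layer = sk.nn.Conv2D(1, 8, (3, 3), padding="valid")
y = layer(jnp.ones((1, 28, 28)))  # a single unbatched example
print(y.shape)                    # (8, 26, 26)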

โฉ Quick Example: Train MNIST

We will use TensorFlow Datasets for data loading. For more on the JAX/TensorFlow datasets interface, see here.

# imports
import tensorflow as tf
# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type="GPU")
import tensorflow_datasets as tfds
import tensorflow.experimental.numpy as tnp
import jax
import jax.numpy as jnp
import jax.random as jr 
import optax  # for gradient optimization
import serket as sk
import matplotlib.pyplot as plt
import functools as ft
# Construct a tf.data.Dataset
batch_size = 128

# convert the samples from integers to floating-point numbers
# and channel first format
def preprocess_data(x):
    # convert to channel first format
    image = tnp.moveaxis(x["image"], -1, 0)
    # normalize to [0, 1]
    image = tf.cast(image, tf.float32) / 255.0

    # one-hot encode the labels
    label = tf.one_hot(x["label"], 10) / 1.0
    return {"image": image, "label": label}


ds_train, ds_test = tfds.load("mnist", split=["train", "test"], shuffle_files=True)
# (batches, batch_size, 1, 28, 28)
ds_train = ds_train.shuffle(1024).map(preprocess_data).batch(batch_size).prefetch(tf.data.AUTOTUNE)

# (batches, 1, 28, 28)
ds_test = ds_test.map(preprocess_data).prefetch(tf.data.AUTOTUNE)

๐Ÿ—๏ธ Model definition

We will use jax.vmap(model) to apply the model to batches.
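
To see what that means, here is a tiny standalone sketch of jax.vmap (independent of serket):

import jax
import jax.numpy as jnp

def per_example(x):
    # defined for a single example of shape (3,)
    return x.sum()

batched = jax.vmap(per_example)   # lifted over a leading batch axis
print(batched(jnp.ones((4, 3))))  # [3. 3. 3. 3.]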

@sk.treeclass
class CNN:
    def __init__(self):
        self.conv1 = sk.nn.Conv2D(1, 32, (3, 3), padding="valid")
        self.relu1 = sk.nn.ReLU()
        self.pool1 = sk.nn.MaxPool2D((2, 2), strides=(2, 2))
        self.conv2 = sk.nn.Conv2D(32, 64, (3, 3), padding="valid")
        self.relu2 = sk.nn.ReLU()
        self.pool2 = sk.nn.MaxPool2D((2, 2), strides=(2, 2))
        self.flatten = sk.nn.Flatten(start_dim=0)
        self.dropout = sk.nn.Dropout(0.5)
        self.linear = sk.nn.Linear(5*5*64, 10)

    def __call__(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.dropout(x)
        x = self.linear(x)
        return x

model = CNN()

🎨 Visualize model

Model summary
print(model.summary(show_config=False, array=jnp.empty((1, 28, 28))))  
┌───────┬─────────┬─────────┬──────────────┬─────────────┬─────────────┐
│Name   │Type     │Param #  │Size          │Input        │Output       │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│conv1  │Conv2D   │320(0)   │1.25KB(0.00B) │f32[1,28,28] │f32[32,26,26]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│relu1  │ReLU     │0(0)     │0.00B(0.00B)  │f32[32,26,26]│f32[32,26,26]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│pool1  │MaxPool2D│0(0)     │0.00B(0.00B)  │f32[32,26,26]│f32[32,13,13]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│conv2  │Conv2D   │18,496(0)│72.25KB(0.00B)│f32[32,13,13]│f32[64,11,11]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│relu2  │ReLU     │0(0)     │0.00B(0.00B)  │f32[64,11,11]│f32[64,11,11]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│pool2  │MaxPool2D│0(0)     │0.00B(0.00B)  │f32[64,11,11]│f32[64,5,5]  │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│flatten│Flatten  │0(0)     │0.00B(0.00B)  │f32[64,5,5]  │f32[1600]    │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│dropout│Dropout  │0(0)     │0.00B(0.00B)  │f32[1600]    │f32[1600]    │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│linear │Linear   │16,010(0)│62.54KB(0.00B)│f32[1600]    │f32[10]      │
└───────┴─────────┴─────────┴──────────────┴─────────────┴─────────────┘
Total count :	34,826(0)
Dynamic count :	34,826(0)
Frozen count :	0(0)
------------------------------------------------------------------------
Total size :	136.04KB(0.00B)
Dynamic size :	136.04KB(0.00B)
Frozen size :	0.00B(0.00B)
========================================================================
Tree diagram
print(model.tree_diagram())
CNN
    ├── conv1=Conv2D
    │   ├── weight=f32[32,1,3,3]
    │   ├── bias=f32[32,1,1]
    │   ├*─ in_features=1
    │   ├*─ out_features=32
    │   ├*─ kernel_size=(3, 3)
    │   ├*─ strides=(1, 1)
    │   ├*─ padding=((0, 0), (0, 0))
    │   ├*─ input_dilation=(1, 1)
    │   ├*─ kernel_dilation=(1, 1)
    │   ├── weight_init_func=Partial(init(key,shape,dtype))
    │   ├── bias_init_func=Partial(zeros(key,shape,dtype))
    │   └*─ groups=1
    ├── relu1=ReLU
    ├*─ pool1=MaxPool2D
    │   ├*─ kernel_size=(2, 2)
    │   ├*─ strides=(2, 2)
    │   └*─ padding='valid'
    ├── conv2=Conv2D
    │   ├── weight=f32[64,32,3,3]
    │   ├── bias=f32[64,1,1]
    │   ├*─ in_features=32
    │   ├*─ out_features=64
    │   ├*─ kernel_size=(3, 3)
    │   ├*─ strides=(1, 1)
    │   ├*─ padding=((0, 0), (0, 0))
    │   ├*─ input_dilation=(1, 1)
    │   ├*─ kernel_dilation=(1, 1)
    │   ├── weight_init_func=Partial(init(key,shape,dtype))
    │   ├── bias_init_func=Partial(zeros(key,shape,dtype))
    │   └*─ groups=1
    ├── relu2=ReLU
    ├*─ pool2=MaxPool2D
    │   ├*─ kernel_size=(2, 2)
    │   ├*─ strides=(2, 2)
    │   └*─ padding='valid'
    ├*─ flatten=Flatten
    │   ├*─ start_dim=0
    │   └*─ end_dim=-1
    ├── dropout=Dropout
    │   ├*─ p=0.5
    │   └── eval=None
    └── linear=Linear
        ├── weight=f32[1600,10]
        ├── bias=f32[10]
        ├*─ in_features=1600
        └*─ out_features=10
Plot sample predictions before training
 
# Set dropout to eval mode: flip every `eval` field (initially None) to True,
# which disables dropout at inference time.
test_model = model.at[model == "eval"].set(True, is_leaf=lambda x: x is None)

def show_images_with_predictions(model, images, one_hot_labels):
    logits = jax.vmap(model)(images)
    predictions = jnp.argmax(logits, axis=-1)
    fig, axes = plt.subplots(5, 5, figsize=(10, 10))
    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i].reshape(28, 28), cmap="binary")
        ax.set(title=f"Prediction: {predictions[i]}\nLabel: {jnp.argmax(one_hot_labels[i], axis=-1)}")
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

example = ds_test.take(25).as_numpy_iterator()
example = list(example)
sample_test_images = jnp.stack([x["image"] for x in example])
sample_test_labels = jnp.stack([x["label"] for x in example])

show_images_with_predictions(test_model, sample_test_images, sample_test_labels)

[image: sample test images with predicted and true labels before training]

๐Ÿƒ Train the model

# `value_and_grad` with `has_aux=True`: loss_func returns (loss, logits);
# gradients are taken w.r.t. the first argument (the model pytree).
@ft.partial(jax.value_and_grad, has_aux=True)
def loss_func(model, batched_images, batched_one_hot_labels):
    logits = jax.vmap(model)(batched_images)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=batched_one_hot_labels))
    return loss, logits


# using optax for gradient updates
optim = optax.adam(1e-3)
optim_state = optim.init(model)


@jax.jit
def batch_step(model, batched_images, batched_one_hot_labels, optim_state):
    (loss, logits), grads = loss_func(model, batched_images, batched_one_hot_labels)
    updates, optim_state = optim.update(grads, optim_state)
    model = optax.apply_updates(model, updates)
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == jnp.argmax(batched_one_hot_labels, axis=-1))
    return model, optim_state, loss, accuracy


epochs = 5

for i in range(epochs):
    epoch_accuracy = []
    epoch_loss = []

    for example in ds_train.as_numpy_iterator():
        image, label = example["image"], example["label"]
        model, optim_state, loss, accuracy = batch_step(model, image, label, optim_state)
        epoch_accuracy.append(accuracy)
        epoch_loss.append(loss)

    epoch_loss = jnp.mean(jnp.array(epoch_loss))
    epoch_accuracy = jnp.mean(jnp.array(epoch_accuracy))

    print(f"epoch:{i+1:00d}\tloss:{epoch_loss:.4f}\taccuracy:{epoch_accuracy:.4f}")
    
# epoch:1	loss:0.2706	accuracy:0.9268
# epoch:2	loss:0.0725	accuracy:0.9784
# epoch:3	loss:0.0533	accuracy:0.9836
# epoch:4	loss:0.0442	accuracy:0.9868
# epoch:5	loss:0.0368	accuracy:0.9889

🎨 Visualize after training

test_model = model.at[model == "eval"].set(True, is_leaf=lambda x: x is None)
show_images_with_predictions(test_model, sample_test_images, sample_test_labels)

[image: sample test images with predicted and true labels after training]
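
To put a number on the improvement, here is a small follow-up sketch (not part of the original example) that reuses test_model to compute test-set accuracy:

correct = total = 0
for example in ds_test.batch(256).as_numpy_iterator():
    image, label = example["image"], example["label"]
    logits = jax.vmap(test_model)(image)
    correct += int(jnp.sum(jnp.argmax(logits, axis=-1) == jnp.argmax(label, axis=-1)))
    total += image.shape[0]
print(f"test accuracy: {correct / total:.4f}")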

🥶 Freezing parameters / Fine-tuning

✨See here for more about freezing✨
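
The linked documentation covers serket's own freezing utilities. As a library-agnostic illustration only (a sketch, not serket's API), updates for any pytree subtree can be zeroed with optax.multi_transform, so the frozen leaves never change:

import jax
import jax.numpy as jnp
import optax

# Hypothetical parameter pytree standing in for a model.
params = {
    "backbone": {"w": jnp.ones((4, 4))},
    "head": {"w": jnp.ones((4, 2))},
}
# Label every leaf; "frozen" leaves get zero updates, "trainable" get adam.
labels = {"backbone": {"w": "frozen"}, "head": {"w": "trainable"}}

ft_optim = optax.multi_transform(
    {"trainable": optax.adam(1e-3), "frozen": optax.set_to_zero()}, labels
)
ft_state = ft_optim.init(params)

loss_fn = lambda p: jnp.sum(p["head"]["w"] ** 2) + jnp.sum(p["backbone"]["w"] ** 2)
grads = jax.grad(loss_fn)(params)
updates, ft_state = ft_optim.update(grads, ft_state)
params = optax.apply_updates(params, updates)  # backbone leaves are unchanged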

🔘 Filtering by masking

✨See here for more about filtering✨
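
Again as a generic sketch only (serket's actual filtering API is in the linked docs), a boolean mask over a pytree can be built and applied with jax.tree_util.tree_map:

import jax
import jax.numpy as jnp

tree = {"conv": {"w": jnp.ones((3, 3)), "b": jnp.zeros(3)}, "scale": jnp.array(2.0)}

# Boolean mask: True only for matrix-shaped (2D) leaves.
mask = jax.tree_util.tree_map(lambda x: x.ndim == 2, tree)

# Keep masked leaves as-is, zero out the rest.
filtered = jax.tree_util.tree_map(
    lambda x, m: x if m else jnp.zeros_like(x), tree, mask
)
print(filtered["scale"])  # 0.0 -- non-matrix leaves were zeroed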
