
JAX NN library.


The ✨Magical✨ JAX Scientific ML Library.

*Serket is the goddess of magic in Egyptian mythology

Installation | Description | Quick Example | Freezing/Fine-tuning | Filtering


🛠️ Installation

pip install serket

Install development version

pip install git+https://github.com/ASEM000/serket

📖 Description and motivation

  • serket aims to be the most intuitive and easy-to-use physics-based neural network library in JAX.
  • serket is fully transparent to jax transformations (e.g. vmap, grad, jit, ...); see the sketch below.
  • serket currently aims to facilitate the integration of numerical methods in a neural network setting (see the examples for more).
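
As a minimal sketch of the transparency claim (using the Linear layer described later in this README):

import jax
import jax.numpy as jnp

import serket as sk

layer = sk.nn.Linear(4, 5)


# the layer itself is a valid argument to jax transformations
@jax.jit
def forward(layer, x):
    return layer(x)


print(forward(layer, jnp.ones([4])).shape)  # (5,)

# gradients are taken with respect to the layer, like any other pytree
grads = jax.grad(lambda layer, x: jnp.sum(layer(x)))(layer, jnp.ones([4]))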

Layer structure

serket is built on top of PyTreeClass; this means that layers are represented as PyTrees whose leaves are the layer parameters.
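
Concretely (a minimal sketch):

import jax
import serket as sk

layer = sk.nn.Linear(1, 2)
# the leaves of the pytree include the layer parameters (weight and bias)
print(jax.tree_util.tree_leaves(layer))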

🧠 Neural network package: serket.nn 🧠

| Group | Layers |
| --- | --- |
| Linear | `Linear`, `Bilinear`, `Multilinear`, `GeneralLinear`, `Identity`, `Embedding` |
| Densely connected | `FNN` (fully connected network) |
| Convolution | `{Conv,FFTConv}{1D,2D,3D}`, `{Conv,FFTConv}{1D,2D,3D}Transpose`, `{Depthwise,Separable}{Conv,FFTConv}{1D,2D,3D}`, `Conv{1D,2D,3D}Local` |
| Containers | `Sequential`, `Lambda` |
| Pooling (kernex backend) | `{Avg,Max,LP}Pool{1D,2D,3D}`, `Global{Avg,Max}Pool{1D,2D,3D}`, `Adaptive{Avg,Max}Pool{1D,2D,3D}` |
| Reshaping | `Flatten`, `Unflatten`, `FlipLeftRight2D`, `FlipUpDown2D`, `Resize{1D,2D,3D}`, `Upsample{1D,2D,3D}`, `Pad{1D,2D,3D}` |
| Crop | `Crop{1D,2D}` |
| Normalization | `{Layer,Instance,Group}Norm` |
| Blurring | `{Avg,Gaussian}Blur2D` |
| Dropout | `Dropout`, `Dropout{1D,2D,3D}` |
| Random transforms | `RandomCrop{1D,2D}`, `RandomApply`, `RandomCutout{1D,2D}`, `RandomZoom2D`, `RandomContrast2D` |
| Misc | `HistogramEqualization2D`, `AdjustContrast2D`, `Filter2D`, `PixelShuffle2D` |
| Activations | `Adaptive{LeakyReLU,ReLU,Sigmoid,Tanh}`, `CeLU`, `ELU`, `GELU`, `GLU`, `Hard{SILU,Shrink,Sigmoid,Swish,Tanh}`, `Soft{Plus,Sign,Shrink}`, `LeakyReLU`, `LogSigmoid`, `LogSoftmax`, `Mish`, `PReLU`, `ReLU`, `ReLU6`, `SILU`, `SeLU`, `Sigmoid`, `Swish`, `Tanh`, `TanhShrink`, `ThresholdedReLU`, `Snake`, `Stan`, `SquarePlus` |
| Recurrent cells | `{SimpleRNN,LSTM,GRU}Cell`, `Conv{LSTM,GRU}{1D,2D,3D}Cell` |
| Blocks | `VGG{16,19}Block`, `UNetBlock` |
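
As a quick shape check for a couple of the table entries (a minimal sketch; the `Conv2D` and `MaxPool2D` signatures follow their usage later in this README):

import jax.numpy as jnp

import serket as sk

x = jnp.ones([3, 32, 32])     # channels-first 2D input
conv = sk.nn.Conv2D(3, 8, 3)  # 3 -> 8 channels, 3x3 kernel, "same" padding by default
pool = sk.nn.MaxPool2D(2, 2)  # 2x2 window, stride 2
print(pool(conv(x)).shape)    # (8, 16, 16)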

โฉ Examples:

Linear layers examples
import jax.numpy as jnp

import serket as sk

# Linear
x = jnp.ones([1, 2, 3, 4])
l1 = sk.nn.Linear(4, 5)  # last dim is 4, output dim is 5
print(l1(x).shape)  # (1, 2, 3, 5)

# Bilinear
x1, x2 = jnp.ones([1, 2, 3, 4]), jnp.ones([1, 2, 3, 5])
l2 = sk.nn.Bilinear(4, 5, 6)  # last dim of the input x1,x2 are 4,5, output dim is 6
print(l2(x1, x2).shape)  # (1, 2, 3, 6)

# Multilinear
x1, x2, x3 = jnp.ones([1, 2, 3, 4]), jnp.ones([1, 2, 3, 5]), jnp.ones([1, 2, 3, 6])
l3 = sk.nn.Multilinear((4, 5, 6), 7)  # last dim for x1,x2,x3 = 4,5,6, output dim is 7
print(l3(x1, x2, x3).shape)  # (1, 2, 3, 7)

# GeneralLinear
x = jnp.ones([4, 5, 6, 7])
# apply a linear layer to axis 1,2,3, with dim = (5, 6, 7) and output dim is 5
l4 = sk.nn.GeneralLinear((5, 6, 7), 5, in_axes=(1, 2, 3))
print(l4(x).shape)  # (4, 5)
Train Bidirectional-LSTM
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax  # for gradient optimization

import serket as sk

x = jnp.linspace(0, 1, 101).reshape(-1, 1)  # 101 samples of 1D data
y = jnp.sin(2 * jnp.pi * x)
y += 0.1 * jr.normal(jr.PRNGKey(0), y.shape)

# we will use 2 time steps to predict the next time step
x_batched = jnp.stack([x[:-1], x[1:]], axis=1)
x_batched = jnp.reshape(x_batched, (100, 1, 2, 1))  # 100 minibatches x 1 sample x 2 time steps x 1D data
y_batched = jnp.reshape(y[1:], (100, 1, 1))  # 100 minibatches x 1 sample x 1D data

model = sk.nn.Sequential(
    [
        # first cell is the forward cell, second cell is the backward cell for bidirectional RNN
        # we return the full sequence of outputs for each cell by setting return_sequences=True
        # we use None in place of `in_features` to infer the input shape from the input
        sk.nn.ScanRNN(
            sk.nn.LSTMCell(None, 64),
            backward_cell=sk.nn.LSTMCell(None, 64),
            return_sequences=True,
        ),
        # here the in_features is inferred from the previous layer by setting it to None
        # or simply we can set it to 64*2 (64 for each cell from previous layer)
        # we set return_sequences=False to return only the last output of the sequence
        sk.nn.ScanRNN(sk.nn.LSTMCell(None, 1), return_sequences=False),
    ]
)


@jax.value_and_grad
def loss_func(NN, batched_x, batched_y):
    # use jax.vmap to apply the model to each minibatch
    # in our case single x minibatch has shape (1, 2, 1)
    # and single y minibatch has shape (1, 1)
    # then vmap will be applied to the leading axis
    batched_preds = jax.vmap(NN)(batched_x)
    return jnp.mean((batched_preds - batched_y) ** 2)


@jax.jit
def batch_step(NN, batched_x, batched_y, opt_state):
    loss, grads = loss_func(NN, batched_x, batched_y)
    updates, opt_state = optim.update(grads, opt_state)
    NN = optax.apply_updates(NN, updates)
    return NN, opt_state, loss


# dry run to infer in_features (i.e. replace None with the inferred value)
# if you want to restrict the model to a specific input shape, or to avoid
# confusion, you can specify in_features manually; in that case the dry run
# is not necessary
model(x_batched[0, 0])

optim = optax.adam(1e-3)
opt_state = optim.init(model)

epochs = 100

for i in range(1, epochs + 1):
    epoch_loss = []
    for x_b, y_b in zip(x_batched, y_batched):
        model, opt_state, loss = batch_step(model, x_b, y_b, opt_state)
        epoch_loss.append(loss)

    epoch_loss = jnp.mean(jnp.array(epoch_loss))

    if i % 10 == 0:
        print(f"Epoch {i:3d} Loss {epoch_loss:.4f}")

# Epoch  10 Loss 0.0880
# Epoch  20 Loss 0.0796
# Epoch  30 Loss 0.0620
# Epoch  40 Loss 0.0285
# Epoch  50 Loss 0.0205
# Epoch  60 Loss 0.0187
# Epoch  70 Loss 0.0182
# Epoch  80 Loss 0.0176
# Epoch  90 Loss 0.0171
# Epoch 100 Loss 0.0166

y_pred = jax.vmap(model)(x_batched.reshape(-1, 2, 1))
plt.plot(x[1:], y[1:], "--k", label="data")
plt.plot(x[1:], y_pred, label="prediction")
plt.legend()


Lazy initialization

In cases where in_features needs to be inferred from the input, use None in place of in_features to infer the value at runtime. However, since a lazy module initializes its state on the first call (i.e. it mutates its state), jax transformations (e.g. vmap, grad, jit, ...) are not allowed before initialization; applying any jax transformation to an uninitialized lazy module raises a ValueError.

import serket as sk
import jax
import jax.numpy as jnp

model = sk.nn.Sequential(
    [
        sk.nn.Conv2D(None, 128, 3),
        sk.nn.ReLU(),
        sk.nn.MaxPool2D(2, 2),
        sk.nn.Conv2D(128, 64, 3),
        sk.nn.ReLU(),
        sk.nn.MaxPool2D(2, 2),
        sk.nn.Flatten(),
        sk.nn.Linear(None, 128),
        sk.nn.ReLU(),
        sk.nn.Linear(128, 1),
    ]
)

# print the first `Conv2D` layer before initialization
print(model[0].__repr__())
# Conv2D(
#   weight=None,
#   bias=None,
#   *in_features=None,
#   *out_features=None,
#   *kernel_size=None,
#   *strides=None,
#   *padding=None,
#   *input_dilation=None,
#   *kernel_dilation=None,
#   weight_init_func=None,
#   bias_init_func=None,
#   *groups=None
# )

try:
    jax.vmap(model)(jnp.ones((10, 1, 28, 28)))
except ValueError:
    print("***** Not initialized *****")
# ***** Not initialized *****

# dry run to initialize the model
model(jnp.empty([3,128,128]))

print(model[0].__repr__())
# Conv2D(
#   weight=f32[128,3,3,3],
#   bias=f32[128,1,1],
#   *in_features=3,
#   *out_features=128,
#   *kernel_size=(3,3),
#   *strides=(1,1),
#   *padding=((1,1),(1,1)),
#   *input_dilation=(1,1),
#   *kernel_dilation=(1,1),
#   weight_init_func=Partial(glorot_uniform(key,shape,dtype)),
#   bias_init_func=Partial(zeros(key,shape,dtype)),
#   *groups=1
# )
Train MNIST

We will use TensorFlow Datasets for data loading. For more on the interface between JAX and TensorFlow datasets, see here.

# imports
import tensorflow as tf
# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type="GPU")
import tensorflow_datasets as tfds
import tensorflow.experimental.numpy as tnp
import jax
import jax.numpy as jnp
import jax.random as jr
import optax  # for gradient optimization
import serket as sk
import matplotlib.pyplot as plt
import functools as ft
# Construct a tf.data.Dataset
batch_size = 128

# convert the samples from integers to floating-point numbers
# and channel first format
def preprocess_data(x):
    # convert to channel first format
    image = tnp.moveaxis(x["image"], -1, 0)
    # normalize to [0, 1]
    image = tf.cast(image, tf.float32) / 255.0

    # one-hot encode the labels
    label = tf.one_hot(x["label"], 10) / 1.0
    return {"image": image, "label": label}


ds_train, ds_test = tfds.load("mnist", split=["train", "test"], shuffle_files=True)
# (batches, batch_size, 1, 28, 28)
ds_train = ds_train.shuffle(1024).map(preprocess_data).batch(batch_size).prefetch(tf.data.AUTOTUNE)

# (batches, 1, 28, 28)
ds_test = ds_test.map(preprocess_data).prefetch(tf.data.AUTOTUNE)

๐Ÿ—๏ธ Model definition

We will use jax.vmap(model) to apply the model to batches.

@sk.treeclass
class CNN:
    def __init__(self):
        self.conv1 = sk.nn.Conv2D(1, 32, (3, 3), padding="valid")
        self.relu1 = sk.nn.ReLU()
        self.pool1 = sk.nn.MaxPool2D((2, 2), strides=(2, 2))
        self.conv2 = sk.nn.Conv2D(32, 64, (3, 3), padding="valid")
        self.relu2 = sk.nn.ReLU()
        self.pool2 = sk.nn.MaxPool2D((2, 2), strides=(2, 2))
        self.flatten = sk.nn.Flatten(start_dim=0)
        self.dropout = sk.nn.Dropout(0.5)
        self.linear = sk.nn.Linear(5*5*64, 10)

    def __call__(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.dropout(x)
        x = self.linear(x)
        return x

model = CNN()

🎨 Visualize model

Model summary
print(model.summary(show_config=False, array=jnp.empty((1, 28, 28))))  
┌───────┬─────────┬─────────┬──────────────┬─────────────┬─────────────┐
│Name   │Type     │Param #  │Size          │Input        │Output       │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│conv1  │Conv2D   │320(0)   │1.25KB(0.00B) │f32[1,28,28] │f32[32,26,26]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│relu1  │ReLU     │0(0)     │0.00B(0.00B)  │f32[32,26,26]│f32[32,26,26]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│pool1  │MaxPool2D│0(0)     │0.00B(0.00B)  │f32[32,26,26]│f32[32,13,13]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│conv2  │Conv2D   │18,496(0)│72.25KB(0.00B)│f32[32,13,13]│f32[64,11,11]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│relu2  │ReLU     │0(0)     │0.00B(0.00B)  │f32[64,11,11]│f32[64,11,11]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│pool2  │MaxPool2D│0(0)     │0.00B(0.00B)  │f32[64,11,11]│f32[64,5,5]  │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│flatten│Flatten  │0(0)     │0.00B(0.00B)  │f32[64,5,5]  │f32[1600]    │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│dropout│Dropout  │0(0)     │0.00B(0.00B)  │f32[1600]    │f32[1600]    │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│linear │Linear   │16,010(0)│62.54KB(0.00B)│f32[1600]    │f32[10]      │
└───────┴─────────┴─────────┴──────────────┴─────────────┴─────────────┘
Total count :	34,826(0)
Dynamic count :	34,826(0)
Frozen count :	0(0)
------------------------------------------------------------------------
Total size :	136.04KB(0.00B)
Dynamic size :	136.04KB(0.00B)
Frozen size :	0.00B(0.00B)
========================================================================
tree diagram
print(model.tree_diagram())
CNN
    ├── conv1=Conv2D
    │   ├── weight=f32[32,1,3,3]
    │   ├── bias=f32[32,1,1]
    │   ├*─ in_features=1
    │   ├*─ out_features=32
    │   ├*─ kernel_size=(3, 3)
    │   ├*─ strides=(1, 1)
    │   ├*─ padding=((0, 0), (0, 0))
    │   ├*─ input_dilation=(1, 1)
    │   ├*─ kernel_dilation=(1, 1)
    │   ├── weight_init_func=Partial(init(key,shape,dtype))
    │   ├── bias_init_func=Partial(zeros(key,shape,dtype))
    │   └*─ groups=1
    ├── relu1=ReLU
    ├*─ pool1=MaxPool2D
    │   ├*─ kernel_size=(2, 2)
    │   ├*─ strides=(2, 2)
    │   └*─ padding='valid'
    ├── conv2=Conv2D
    │   ├── weight=f32[64,32,3,3]
    │   ├── bias=f32[64,1,1]
    │   ├*─ in_features=32
    │   ├*─ out_features=64
    │   ├*─ kernel_size=(3, 3)
    │   ├*─ strides=(1, 1)
    │   ├*─ padding=((0, 0), (0, 0))
    │   ├*─ input_dilation=(1, 1)
    │   ├*─ kernel_dilation=(1, 1)
    │   ├── weight_init_func=Partial(init(key,shape,dtype))
    │   ├── bias_init_func=Partial(zeros(key,shape,dtype))
    │   └*─ groups=1
    ├── relu2=ReLU
    ├*─ pool2=MaxPool2D
    │   ├*─ kernel_size=(2, 2)
    │   ├*─ strides=(2, 2)
    │   └*─ padding='valid'
    ├*─ flatten=Flatten
    │   ├*─ start_dim=0
    │   └*─ end_dim=-1
    ├── dropout=Dropout
    │   ├*─ p=0.5
    │   └── eval=None
    └── linear=Linear
        ├── weight=f32[1600,10]
        ├── bias=f32[10]
        ├*─ in_features=1600
        └*─ out_features=10
Plot sample predictions before training
 
# set all dropout off
test_model = model.at[model == "eval"].set(True, is_leaf=lambda x: x is None)

def show_images_with_predictions(model, images, one_hot_labels):
    logits = jax.vmap(model)(images)
    predictions = jnp.argmax(logits, axis=-1)
    fig, axes = plt.subplots(5, 5, figsize=(10, 10))
    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i].reshape(28, 28), cmap="binary")
        ax.set(title=f"Prediction: {predictions[i]}\nLabel: {jnp.argmax(one_hot_labels[i], axis=-1)}")
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

example = ds_test.take(25).as_numpy_iterator()
example = list(example)
sample_test_images = jnp.stack([x["image"] for x in example])
sample_test_labels = jnp.stack([x["label"] for x in example])

show_images_with_predictions(test_model, sample_test_images, sample_test_labels)


๐Ÿƒ Train the model

@ft.partial(jax.value_and_grad, has_aux=True)
def loss_func(model, batched_images, batched_one_hot_labels):
    logits = jax.vmap(model)(batched_images)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=batched_one_hot_labels))
    return loss, logits


# using optax for gradient updates
optim = optax.adam(1e-3)
optim_state = optim.init(model)


@jax.jit
def batch_step(model, batched_images, batched_one_hot_labels, optim_state):
    (loss, logits), grads = loss_func(model, batched_images, batched_one_hot_labels)
    updates, optim_state = optim.update(grads, optim_state)
    model = optax.apply_updates(model, updates)
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == jnp.argmax(batched_one_hot_labels, axis=-1))
    return model, optim_state, loss, accuracy


epochs = 5

for i in range(epochs):
    epoch_accuracy = []
    epoch_loss = []

    for example in ds_train.as_numpy_iterator():
        image, label = example["image"], example["label"]
        model, optim_state, loss, accuracy = batch_step(model, image, label, optim_state)
        epoch_accuracy.append(accuracy)
        epoch_loss.append(loss)

    epoch_loss = jnp.mean(jnp.array(epoch_loss))
    epoch_accuracy = jnp.mean(jnp.array(epoch_accuracy))

    print(f"epoch:{i+1:00d}\tloss:{epoch_loss:.4f}\taccuracy:{epoch_accuracy:.4f}")

# epoch:1	loss:0.2706	accuracy:0.9268
# epoch:2	loss:0.0725	accuracy:0.9784
# epoch:3	loss:0.0533	accuracy:0.9836
# epoch:4	loss:0.0442	accuracy:0.9868
# epoch:5	loss:0.0368	accuracy:0.9889

🎨 Visualize after training

test_model = model.at[model == "eval"].set(True, is_leaf=lambda x: x is None)
show_images_with_predictions(test_model, sample_test_images, sample_test_labels)


Finite difference examples
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy.testing as npt

import serket as sk


# let's first define a vector-valued function F: R^3 -> R^3
# F = [F1, F2, F3]
# F1 = x^2 + y^3
# F2 = x^4 + y^3
# F3 = 0
# F = [x**2 + y**3, x**4 + y**3, 0]

x, y, z = [jnp.linspace(0, 1, 100)] * 3
dx, dy, dz = x[1] - x[0], y[1] - y[0], z[1] - z[0]
X, Y, Z = jnp.meshgrid(x, y, z, indexing="ij")
F1 = X**2 + Y**3
F2 = X**4 + Y**3
F3 = jnp.zeros_like(F1)
F = jnp.stack([F1, F2, F3], axis=0)

# ∂F1/∂x : differentiate F1 with respect to x (i.e. axis=0)
dF1dx = sk.fd.difference(F1, axis=0, step_size=dx, accuracy=6)
dF1dx_exact = 2 * X
npt.assert_allclose(dF1dx, dF1dx_exact, atol=1e-7)

# ∂F2/∂y : differentiate F2 with respect to y (i.e. axis=1)
dF2dy = sk.fd.difference(F2, axis=1, step_size=dy, accuracy=6)
dF2dy_exact = 3 * Y**2
npt.assert_allclose(dF2dy, dF2dy_exact, atol=1e-7)

# ∇·F : the divergence of F
divF = sk.fd.divergence(F, step_size=(dx, dy, dz), keepdims=False, accuracy=6)
divF_exact = 2 * X + 3 * Y**2
npt.assert_allclose(divF, divF_exact, atol=1e-7)

# ∇F1 : the gradient of F1
gradF1 = sk.fd.gradient(F1, step_size=(dx, dy, dz), accuracy=6)
gradF1_exact = jnp.stack([2 * X, 3 * Y**2, 0 * X], axis=0)
npt.assert_allclose(gradF1, gradF1_exact, atol=1e-7)

# ΔF1 : the laplacian of F1
lapF1 = sk.fd.laplacian(F1, step_size=(dx, dy, dz), accuracy=6)
lapF1_exact = 2 + 6 * Y
npt.assert_allclose(lapF1, lapF1_exact, atol=1e-7)

# ∇×F : the curl of F
curlF = sk.fd.curl(F, step_size=(dx, dy, dz), accuracy=6)
curlF_exact = jnp.stack([F1 * 0, F1 * 0, 4 * X**3 - 3 * Y**2], axis=0)
npt.assert_allclose(curlF, curlF_exact, atol=1e-7)

# Jacobian of F
JF = sk.fd.jacobian(F, accuracy=4, step_size=(dx, dy, dz))
JF_exact = jnp.array(
    [
        [2 * X, 3 * Y**2, jnp.zeros_like(X)],
        [4 * X**3, 3 * Y**2, jnp.zeros_like(X)],
        [jnp.zeros_like(X), jnp.zeros_like(X), jnp.zeros_like(X)],
    ]
)
npt.assert_allclose(JF, JF_exact, atol=1e-7)

# Hessian of F1
HF1 = sk.fd.hessian(F1, accuracy=4, step_size=(dx, dy, dz))
HF1_exact = jnp.array(
    [
        [
            2 * jnp.ones_like(X),  # ∂²F1/∂x²
            0 * jnp.ones_like(X),  # ∂²F1/∂x∂y
            0 * jnp.ones_like(X),  # ∂²F1/∂x∂z
        ],
        [
            0 * jnp.ones_like(X),  # ∂²F1/∂y∂x
            6 * Y,                 # ∂²F1/∂y² (F1 = x² + y³, so ∂²F1/∂y² = 6y)
            0 * jnp.ones_like(X),  # ∂²F1/∂y∂z
        ],
        [
            0 * jnp.ones_like(X),  # ∂²F1/∂z∂x
            0 * jnp.ones_like(X),  # ∂²F1/∂z∂y
            0 * jnp.ones_like(X),  # ∂²F1/∂z²
        ],
    ]
)
npt.assert_allclose(HF1, HF1_exact, atol=1e-7)
PINN with Finite difference

We will train $NN(x) \approx f(x)$, where $df(x)/dx = \cos(x)$, with $df(x)/dx$ represented by a finite difference scheme. The following code compares a finite-difference-based (fd) implementation with an automatic-differentiation-based (ad) implementation.

import copy

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax

import serket as sk

x = jnp.linspace(-jnp.pi, jnp.pi, 1000)[:, None]
y = jnp.sin(x)
dx = x[1] - x[0]
dydx = jnp.cos(x)

NN_fd = sk.nn.Sequential(
    [
        sk.nn.Linear(1, 128),
        sk.nn.ReLU(),
        sk.nn.Linear(128, 128),
        sk.nn.ReLU(),
        sk.nn.Linear(128, 1),
    ]
)

NN_ad = copy.copy(NN_fd)
optim = optax.adam(1e-3)


@jax.value_and_grad
def loss_func_fd(NN, x):
    y = NN(x)
    dydx = sk.fd.difference(y, axis=0, accuracy=5, step_size=dx)
    loss = jnp.mean((dydx - jnp.cos(x)) ** 2)
    loss += jnp.mean((NN(jnp.zeros_like(x))) ** 2)  # initial condition
    return loss


@jax.value_and_grad
def loss_func_ad(NN, x):
    loss = jnp.mean((sk.diff(NN)(x) - jnp.cos(x)) ** 2)
    loss += jnp.mean(NN(jnp.zeros_like(x)) ** 2)  # initial condition
    return loss


@jax.jit
def step_fd(NN, x, optim_state):
    loss, grads = loss_func_fd(NN, x)
    updates, optim_state = optim.update(grads, optim_state)
    NN = optax.apply_updates(NN, updates)
    return NN, optim_state, loss


def train_fd(NN_fd, optim_state_fd, epochs):
    for i in range(1, epochs + 1):
        NN_fd, optim_state_fd, loss_fd = step_fd(NN_fd, x, optim_state_fd)
    return NN_fd, optim_state_fd, loss_fd


@jax.jit
def step_ad(NN, x, optim_state):
    loss, grads = loss_func_ad(NN, x)
    updates, optim_state = optim.update(grads, optim_state)
    NN = optax.apply_updates(NN, updates)
    return NN, optim_state, loss


def train_ad(NN_ad, optim_state_ad, epochs):
    for i in range(1, epochs + 1):
        NN_ad, optim_state_ad, loss_ad = step_ad(NN_ad, x, optim_state_ad)
    return NN_ad, optim_state_ad, loss_ad


epochs = 1000


optim_state_fd = optim.init(NN_fd)
optim_state_ad = optim.init(NN_ad)


NN_fd, optim_state_fd, loss_fd = train_fd(NN_fd, optim_state_fd, epochs)
NN_ad, optim_state_ad, loss_ad = train_ad(NN_ad, optim_state_ad, epochs)
print(f"Loss_fd {loss_fd:.4f} \nLoss_ad {loss_ad:.4f}")
y_fd = NN_fd(x)
y_ad = NN_ad(x)
plt.plot(x, y, "--k", label="true")
plt.plot(x, y_fd, label="fd pred")
plt.plot(x, y_ad, label="ad pred")
plt.legend()

# Loss_fd 0.0012
# Loss_ad 0.0235


Reconstructing a vector field F using the ∇·F = 0 and ∇×F = 2k conditions
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax

import serket as sk

x, y = [jnp.linspace(-1, 1,50)] * 2
dx, dy = [x[1] - x[0]] * 2
X, Y = jnp.meshgrid(x, y, indexing="ij")

F1 = -Y
F2 = +X
F = jnp.stack([F1, F2], axis=0)

NN = sk.nn.Sequential(
    [
        sk.nn.Conv2D(2, 32, kernel_size=3, padding="same"),
        sk.nn.ReLU(),
        sk.nn.Conv2D(32, 32, kernel_size=3, padding="same"),
        sk.nn.ReLU(),
        sk.nn.Conv2D(32, 2, kernel_size=3, padding="same"),
    ]
)

optim = optax.adam(1e-3)


@jax.value_and_grad
def loss_func(NN, F):
    F_pred = NN(F)
    div = sk.fd.divergence(F_pred, accuracy=5, step_size=(dx, dy))
    loss = jnp.mean(div**2)  # divergence free condition
    curl = sk.fd.curl(F_pred, accuracy=2, step_size=(dx, dy))
    loss += jnp.mean((curl-jnp.ones_like(curl)*2)**2)  # curl condition
    return loss


@jax.jit
def step(NN, F, optim_state):
    loss, grads = loss_func(NN, F)
    updates, optim_state = optim.update(grads, optim_state)
    NN = optax.apply_updates(NN, updates)
    return NN, optim_state, loss


def train(NN, Z, optim_state, epochs):
    for i in range(1, epochs + 1):
        NN, optim_state, loss = step(NN, Z, optim_state)
    return NN, optim_state, loss


Z = jnp.stack([X, Y], axis=0)  # collocation points
optim_state = optim.init(NN)  # initialise optimiser
epochs = 1_000
NN, _, loss = train(NN, Z, optim_state, epochs)

Fpred = NN(Z)  # predicted field

plt.figure(figsize=(10, 10))
plt.quiver(X, Y, Fpred[0], Fpred[1], color="r", label="pred")
plt.quiver(X, Y, F1, F2, color="k", alpha=0.5, label="true")
plt.legend()


Vectorized differentiable stencil computation with `serket.kmap`

serket uses the kernex.kmap decorator, which applies a user-defined stencil kernel over an array. kmap uses jax.vmap as its backend to vectorize the operation, so the decorator is transparent to jax transformations.

Example

import jax.numpy as jnp

import serket as sk


@sk.kmap(
    # kernel size applied to a 2D input -> a 3x3 window
    kernel_size=(3, 3),
    # strides of 1 along each dimension
    strides=(1, 1),
    # padding can be any of the following:
    # 1) a single integer per dimension -> e.g. (1,) pads one zero before and after axis=0
    # 2) a tuple of two integers per dimension -> e.g. ((1,2),) pads one zero on the left and two zeros on the right of axis=0
    # 3) "same"/"valid"
    # 4) a "same"/"valid" tuple per dimension -> e.g. ("same",) same padding for axis=0
    padding="valid",
    # relative selects whether indexing is row/col-wise or center-wise:
    # for a 3x3 kernel over the values 1..9, x[0, 0] yields
    # 1 if relative=False (top-left) and 5 if relative=True (the center value)
    relative=True,
)
def avg_blur(x):
    return (x[-1, -1] + x[-1, 0] + x[-1, 1] +
            x[ 0, -1] + x[ 0, 0] + x[ 0, 1] +
            x[ 1, -1] + x[ 1, 0] + x[ 1, 1]) // 9

avg_blur(jnp.arange(1,26).reshape(5,5))
# [[ 7  8  9]
#  [12 13 14]
#  [17 18 19]]
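
Because the decorated stencil is an ordinary jax computation, it composes with jax transformations such as jax.grad. A minimal sketch (the 1D moving-average kernel below is illustrative and not part of the original example):

import jax
import jax.numpy as jnp

import serket as sk


@sk.kmap(kernel_size=(3,), padding="valid", relative=True)
def mean3(x):
    # 3-point moving average, indexed relative to the window center
    return (x[-1] + x[0] + x[1]) / 3.0


def loss(x):
    return jnp.sum(mean3(x) ** 2)


print(mean3(jnp.arange(5.0)))           # [1. 2. 3.]
print(jax.grad(loss)(jnp.arange(5.0)))  # gradient flows through the stencil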
Scan a stencil kernel to solve linear convection using `serket.kscan`

$\Large {\partial u \over \partial t} + c {\partial u \over \partial x} = 0$

$\Large u_i^{n} = u_i^{n-1} - c \frac{\Delta t}{\Delta x}(u_i^{n-1}-u_{i-1}^{n-1})$

(Figures: problem setup and stencil view.)

By using serket.kscan, the stencil kernel can be scanned over the array while carrying state along, similar to how an RNN works. This enables backpropagation through time (BPTT), which is useful for some problems (e.g. time-dependent PDEs).

import jax
import jax.numpy as jnp
import serket as sk
import matplotlib.pyplot as plt

# see https://nbviewer.org/github/barbagroup/CFDPython/blob/master/lessons/01_Step_1.ipynb

tmax,xmax = 0.5,2.0
nt,nx = 151,51
dt,dx = tmax/(nt-1) , xmax/(nx-1)
u = jnp.ones([nt,nx])
c = 0.5

# kscan moves sequentially in row-major order and updates in-place using lax.scan.

F = sk.kscan(
        kernel_size = (3,3),
        padding = ((1,1),(1,1)),
        named_axis={0:'n',1:'i'},  # n for time axis , i for spatial axis (optional naming)
        relative=True
    )


# boundary condition as a function
def bc(u):
    return 1

# initial condition as a function
def ic1(u):
    return 1

def ic2(u):
    return 2

def linear_convection(u):
    return ( u['i','n-1'] - (c*dt/dx) * (u['i','n-1'] - u['i-1','n-1']) )


F[:,0]  = F[:,-1] = bc # assign 1 for left and right boundary for all t

# square wave initial condition
F[:,:int((nx-1)/4)+1] = F[:,int((nx-1)/2):] = ic1
F[0:1, int((nx-1)/4)+1 : int((nx-1)/2)] = ic2

# assign linear convection function for
# interior spatial location [1:-1]
# and start from t>0  [1:]
F[1:,1:-1] = linear_convection

kx_solution = F(jnp.array(u))

plt.figure(figsize=(20,7))
for line in kx_solution[::20]:
    plt.plot(jnp.linspace(0,xmax,nx),line)


🥶 Freezing parameters / Fine-tuning

✨ See here for more about freezing ✨
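
The linked page covers serket's own freezing API. As a generic, library-agnostic sketch of the idea (this is not serket's API), selected parameters of a pytree can be excluded from gradients with jax.lax.stop_gradient:

import jax
import jax.numpy as jnp

# a hypothetical two-group parameter pytree
params = {"frozen": jnp.ones(3), "trainable": jnp.ones(3)}


def loss_func(params, x):
    w_frozen = jax.lax.stop_gradient(params["frozen"])  # no gradient flows here
    w_trainable = params["trainable"]
    return jnp.sum((x * w_frozen + x * w_trainable) ** 2)


grads = jax.grad(loss_func)(params, jnp.arange(3.0))
# grads["frozen"] is all zeros; grads["trainable"] is not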

🔘 Filtering by masking

✨ See here for more about filtering ✨
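
Similarly, the linked page covers serket's masking utilities. As a generic sketch only (not serket's API), a boolean pytree can be used to filter leaves via jax.tree_util.tree_map:

import jax
import jax.numpy as jnp

params = {"a": jnp.ones(2), "b": jnp.full(2, 3.0)}
mask = {"a": True, "b": False}  # hypothetical mask: keep "a", zero out "b"

masked = jax.tree_util.tree_map(
    lambda leaf, keep: leaf if keep else jnp.zeros_like(leaf), params, mask
)
print(masked)  # {'a': Array([1., 1.], ...), 'b': Array([0., 0.], ...)}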
