JAX NN library.

The ✨Magical✨ JAX Scientific ML Library.

*Serket is the goddess of magic in Egyptian mythology

Installation | Description | Quick Example | Freezing/Fine-tuning | Filtering


🛠️ Installation

pip install serket

Install development version

pip install git+https://github.com/ASEM000/serket

📖 Description

  • serket aims to be the most intuitive and easy-to-use physics-based neural network library in JAX.
  • serket is built on top of pytreeclass.
  • serket is fully transparent to jax transformations (e.g. vmap, grad, jit, ...); a minimal usage sketch is shown below.
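
Because a serket model is a pytree (via pytreeclass), it can be passed straight through jit, vmap, and grad. A minimal sketch, mirroring the jax.value_and_grad pattern used in the training examples further down (not the only way to train a model):

import jax
import jax.numpy as jnp

import serket as sk

layer = sk.nn.Linear(1, 1)  # a serket layer is itself a pytree


@jax.jit
@jax.value_and_grad
def loss_func(layer, x, y):
    # vmap over the leading (batch) axis, exactly as in the examples below
    pred = jax.vmap(layer)(x)
    return jnp.mean((pred - y) ** 2)


x = jnp.linspace(0, 1, 8).reshape(-1, 1)
y = 2.0 * x
loss, grads = loss_func(layer, x, y)  # grads has the same pytree structure as `layer`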

➖➕Finite difference package: serket.fd➕➖

Finite difference layer
- Difference: applies a finite difference to an input array at any derivative order and accuracy

Finite difference functions
- difference: finite difference of an array with any accuracy and derivative order
- generate_finitediff_coeffs: generates finite-difference coefficients from sample points and derivative order
- fgrad: differentiates functions (similar to jax.grad) with custom accuracy and derivative order (see the sketch below)

Vector operator layers
- Curl, Divergence, Gradient, Laplacian, Jacobian, Hessian

Vector operator functions
- curl, divergence, gradient, laplacian, jacobian, hessian
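
As a quick taste, a minimal sketch (fgrad is called with its defaults here; only the behavior described above is assumed):

import jax.numpy as jnp

import serket as sk

# fgrad: a finite-difference counterpart of jax.grad
dsin = sk.fd.fgrad(jnp.sin)
print(dsin(0.5))  # ~ jnp.cos(0.5)

# difference: finite difference of an array along an axis (same call pattern as the examples below)
x = jnp.linspace(0, 1, 100)
dydx = sk.fd.difference(jnp.sin(x), axis=0, step_size=x[1] - x[0], accuracy=6)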

🧠 Neural network package: serket.nn 🧠

Linear
- Linear, Bilinear, Multilinear, GeneralLinear, Identity

Densely connected
- FNN (fully connected network)
- PFNN (parallel fully connected network)

Convolution
- Conv1D, Conv2D, Conv3D
- Conv1DTranspose, Conv2DTranspose, Conv3DTranspose
- DepthwiseConv1D, DepthwiseConv2D, DepthwiseConv3D
- SeparableConv1D, SeparableConv2D, SeparableConv3D
- Conv1DLocal, Conv2DLocal, Conv3DLocal

Containers
- Sequential, Lambda

Pooling (kernex backend)
- MaxPool1D, MaxPool2D, MaxPool3D
- AvgPool1D, AvgPool2D, AvgPool3D
- GlobalMaxPool1D, GlobalMaxPool2D, GlobalMaxPool3D
- GlobalAvgPool1D, GlobalAvgPool2D, GlobalAvgPool3D
- LPPool1D, LPPool2D, LPPool3D
- AdaptivePool1D, AdaptivePool2D, AdaptivePool3D
- AdaptiveConcatPool1D, AdaptiveConcatPool2D, AdaptiveConcatPool3D

Reshaping
- Flatten, Unflatten
- FlipLeftRight2D, FlipUpDown2D
- Repeat1D, Repeat2D, Repeat3D
- Resize1D, Resize2D, Resize3D
- Upsample1D, Upsample2D, Upsample3D
- Pad1D, Pad2D, Pad3D

Crop
- Crop1D, Crop2D

Normalization
- LayerNorm, InstanceNorm, GroupNorm

Blurring
- AvgBlur2D, GaussianBlur2D

Dropout
- Dropout
- Dropout1D, Dropout2D, Dropout3D

Random transforms
- RandomCrop1D, RandomCrop2D
- RandomApply
- RandomCutout1D, RandomCutout2D
- RandomZoom2D
- RandomContrast2D

Misc
- HistogramEqualization2D, AdjustContrast2D, Filter2D, PixelShuffle2D

Activations
- AdaptiveLeakyReLU, AdaptiveReLU, AdaptiveSigmoid, AdaptiveTanh
- CeLU, ELU, GELU, GLU
- HardSILU, HardShrink, HardSigmoid, HardSwish, HardTanh
- LeakyReLU, LogSigmoid, LogSoftmax, Mish, PReLU
- ReLU, ReLU6, SILU, SeLU, Sigmoid, SoftPlus, SoftShrink
- SoftSign, Swish, Tanh, TanhShrink, ThresholdedReLU, Snake

Recurrent
- SimpleRNNCell, LSTMCell, GRUCell
- ConvLSTM1D, ConvLSTM2D, ConvLSTM3D
- SeparableConvLSTM1DCell, SeparableConvLSTM2DCell, SeparableConvLSTM3DCell
- ConvGRU1DCell, ConvGRU2DCell, ConvGRU3DCell
- SeparableConvGRU1DCell, SeparableConvGRU2DCell, SeparableConvGRU3DCell

Blocks
- VGG16Block, VGG19Block, UNetBlock
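
Layers are plain callables on unbatched, channel-first arrays; batching is added with jax.vmap as in the examples below. A minimal sketch using layers from the table (the default "same" padding is an assumption based on the lazy-initialization example further down):

import jax.numpy as jnp

import serket as sk

# construction follows the (in_features, out_features, kernel_size) pattern used on this page
conv = sk.nn.Conv2D(3, 16, 3)
pool = sk.nn.MaxPool2D(2, 2)

x = jnp.ones((3, 32, 32))  # channel-first, unbatched input
y = pool(conv(x))          # shape (16, 16, 16), assuming the default "same" padding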

⏩ Examples:

Finite difference examples
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy.testing as npt

import serket as sk


# let's first define a vector-valued function F: R^3 -> R^3
# F = [F1, F2, F3]
# F1 = x^2 + y^3
# F2 = x^4 + y^3
# F3 = 0
# i.e. F = [x**2 + y**3, x**4 + y**3, 0]

x, y, z = [jnp.linspace(0, 1, 100)] * 3
dx, dy, dz = x[1] - x[0], y[1] - y[0], z[1] - z[0]
X, Y, Z = jnp.meshgrid(x, y, z, indexing="ij")
F1 = X**2 + Y**3
F2 = X**4 + Y**3
F3 = jnp.zeros_like(F1)
F = jnp.stack([F1, F2, F3], axis=0)

# ∂F1/∂x : differentiate F1 with respect to x (i.e. axis=0)
dF1dx = sk.fd.difference(F1, axis=0, step_size=dx, accuracy=6)
dF1dx_exact = 2 * X
npt.assert_allclose(dF1dx, dF1dx_exact, atol=1e-7)

# ∂F2/∂y : differentiate F2 with respect to y (i.e. axis=1)
dF2dy = sk.fd.difference(F2, axis=1, step_size=dy, accuracy=6)
dF2dy_exact = 3 * Y**2
npt.assert_allclose(dF2dy, dF2dy_exact, atol=1e-7)

# ∇.F : the divergence of F
divF = sk.fd.divergence(F, step_size=(dx, dy, dz), keepdims=False, accuracy=6)
divF_exact = 2 * X + 3 * Y**2
npt.assert_allclose(divF, divF_exact, atol=1e-7)

# ∇F1 : the gradient of F1
gradF1 = sk.fd.gradient(F1, step_size=(dx, dy, dz), accuracy=6)
gradF1_exact = jnp.stack([2 * X, 3 * Y**2, 0 * X], axis=0)
npt.assert_allclose(gradF1, gradF1_exact, atol=1e-7)

# ΔF1 : laplacian of F1
lapF1 = sk.fd.laplacian(F1, step_size=(dx, dy, dz), accuracy=6)
lapF1_exact = 2 + 6 * Y
npt.assert_allclose(lapF1, lapF1_exact, atol=1e-7)

# ∇xF : the curl of F
curlF = sk.fd.curl(F, step_size=(dx, dy, dz), accuracy=6)
curlF_exact = jnp.stack([F1 * 0, F1 * 0, 4 * X**3 - 3 * Y**2], axis=0)
npt.assert_allclose(curlF, curlF_exact, atol=1e-7)

# Jacobian of F
JF = sk.fd.jacobian(F, accuracy=4, step_size=(dx, dy, dz))
JF_exact = jnp.array(
    [
        [2 * X, 3 * Y**2, jnp.zeros_like(X)],
        [4 * X**3, 3 * Y**2, jnp.zeros_like(X)],
        [jnp.zeros_like(X), jnp.zeros_like(X), jnp.zeros_like(X)],
    ]
)
npt.assert_allclose(JF, JF_exact, atol=1e-7)

# Hessian of F1
HF1 = sk.fd.hessian(F1, accuracy=4, step_size=(dx, dy, dz))
HF1_exact = jnp.array(
    [
        [
            2 * jnp.ones_like(X),  # ∂2F1/∂x2
            0 * jnp.ones_like(X),  # ∂2F1/∂xy
            0 * jnp.ones_like(X),  # ∂2F1/∂xz
        ],
        [
            0 * jnp.ones_like(X),  # ∂2F1/∂yx
            6 * Y,                 # ∂2F1/∂y2
            0 * jnp.ones_like(X),  # ∂2F1/∂yz
        ],
        [
            0 * jnp.ones_like(X),  # ∂2F1/∂zx
            0 * jnp.ones_like(X),  # ∂2F1/∂zy
            0 * jnp.ones_like(X),  # ∂2F1/∂z2
        ],
    ]
)
npt.assert_allclose(HF1, HF1_exact, atol=1e-7)
Train a bidirectional LSTM
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax  # for gradient optimization

import serket as sk

x = jnp.linspace(0, 1, 101).reshape(-1, 1)  # 101 samples of 1D data
y = jnp.sin(2 * jnp.pi * x)
y += 0.1 * jr.normal(jr.PRNGKey(0), y.shape)

# we will use 2 time steps to predict the next time step
x_batched = jnp.stack([x[:-1], x[1:]], axis=1)
x_batched = jnp.reshape(x_batched, (100, 1, 2, 1))  # 100 minibatches x 1 sample x 2 time steps x 1D data
y_batched = jnp.reshape(y[1:], (100, 1, 1))  # 100 minibatches x 1 sample x 1D data

model = sk.nn.Sequential(
    [
        # first cell is the forward cell, second cell is the backward cell for bidirectional RNN
        # we return the full sequence of outputs for each cell by setting return_sequences=True
        # we use None in place of `in_features` to infer the input shape from the input
        sk.nn.ScanRNN(
            sk.nn.LSTMCell(None, 64),
            backward_cell=sk.nn.LSTMCell(None, 64),
            return_sequences=True,
        ),
        # here the in_features is inferred from the previous layer by setting it to None
        # or simply we can set it to 64*2 (64 for each cell from previous layer)
        # we set return_sequences=False to return only the last output of the sequence
        sk.nn.ScanRNN(sk.nn.LSTMCell(None, 1), return_sequences=False),
    ]
)


@jax.value_and_grad
def loss_func(NN, batched_x, batched_y):
    # use jax.vmap to apply the model to each minibatch
    # in our case single x minibatch has shape (1, 2, 1)
    # and single y minibatch has shape (1, 1)
    # then vmap will be applied to the leading axis
    batched_preds = jax.vmap(NN)(batched_x)
    return jnp.mean((batched_preds - batched_y) ** 2)


@jax.jit
def batch_step(NN, batched_x, batched_y, optim_state):
    loss, grads = loss_func(NN, batched_x, batched_y)
    updates, optim_state = optim.update(grads, optim_state)
    NN = optax.apply_updates(NN, updates)
    return NN, optim_state, loss


# dry run to infer in_features (i.e. replace every None with the inferred value)
# alternatively, you can specify in_features manually to restrict the model to a
# specific input shape, in which case the dry run is not necessary
model(x_batched[0, 0])

optim = optax.adam(1e-3)
opt_state = optim.init(model)

epochs = 100

for i in range(1, epochs + 1):
    epoch_loss = []
    for x_b, y_b in zip(x_batched, y_batched):
        model, opt_state, loss = batch_step(model, x_b, y_b, opt_state)
        epoch_loss.append(loss)

    epoch_loss = jnp.mean(jnp.array(epoch_loss))

    if i % 10 == 0:
        print(f"Epoch {i:3d} Loss {epoch_loss:.4f}")

# Epoch  10 Loss 0.0880
# Epoch  20 Loss 0.0796
# Epoch  30 Loss 0.0620
# Epoch  40 Loss 0.0285
# Epoch  50 Loss 0.0205
# Epoch  60 Loss 0.0187
# Epoch  70 Loss 0.0182
# Epoch  80 Loss 0.0176
# Epoch  90 Loss 0.0171
# Epoch 100 Loss 0.0166

y_pred = jax.vmap(model)(x_batched.reshape(-1, 2, 1))
plt.plot(x[1:], y[1:], "--k", label="data")
plt.plot(x[1:], y_pred, label="prediction")
plt.legend()

[figure: noisy sine data vs. bidirectional-LSTM prediction]

Lazy initialization

In cases where in_features needs to be inferred from the input, pass None in its place and the value is inferred at runtime. However, since a lazy module initializes its state on the first call (i.e. mutates its state), jax transformations (e.g. vmap, grad, ...) are not allowed before initialization; applying any jax transformation to an uninitialized model raises a ValueError.

import serket as sk 
import jax
import jax.numpy as jnp 

model = sk.nn.Sequential(
    [
        sk.nn.Conv2D(None, 128, 3),
        sk.nn.ReLU(),
        sk.nn.MaxPool2D(2, 2),
        sk.nn.Conv2D(128, 64, 3),
        sk.nn.ReLU(),
        sk.nn.MaxPool2D(2, 2),
        sk.nn.Flatten(),
        sk.nn.Linear(None, 128),
        sk.nn.ReLU(),
        sk.nn.Linear(128, 1),
    ]
)

# print the first `Conv2D` layer before initialization
print(model[0].__repr__())
# Conv2D(
#   weight=None,
#   bias=None,
#   *in_features=None,
#   *out_features=None,
#   *kernel_size=None,
#   *strides=None,
#   *padding=None,
#   *input_dilation=None,
#   *kernel_dilation=None,
#   weight_init_func=None,
#   bias_init_func=None,
#   *groups=None
# )

try:
    jax.vmap(model)(jnp.ones((10, 1, 28, 28)))
except ValueError:
    print("***** Not initialized *****")
# ***** Not initialized *****

# dry run to initialize the model
model(jnp.empty([3, 128, 128]))

print(model[0].__repr__())
# Conv2D(
#   weight=f32[128,3,3,3],
#   bias=f32[128,1,1],
#   *in_features=3,
#   *out_features=128,
#   *kernel_size=(3,3),
#   *strides=(1,1),
#   *padding=((1,1),(1,1)),
#   *input_dilation=(1,1),
#   *kernel_dilation=(1,1),
#   weight_init_func=Partial(glorot_uniform(key,shape,dtype)),
#   bias_init_func=Partial(zeros(key,shape,dtype)),
#   *groups=1
# )
Train MNIST

We will use tensorflow datasets for data loading. For more on the jax/tensorflow datasets interface, see here.

# imports
import tensorflow as tf
# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type="GPU")
import tensorflow_datasets as tfds
import tensorflow.experimental.numpy as tnp
import jax
import jax.numpy as jnp
import jax.random as jr 
import optax  # for gradient optimization
import serket as sk
import matplotlib.pyplot as plt
import functools as ft
# Construct a tf.data.Dataset
batch_size = 128

# convert the samples from integers to floating-point numbers
# and channel first format
def preprocess_data(x):
    # convert to channel first format
    image = tnp.moveaxis(x["image"], -1, 0)
    # normalize to [0, 1]
    image = tf.cast(image, tf.float32) / 255.0

    # one-hot encode the labels
    label = tf.one_hot(x["label"], 10) / 1.0
    return {"image": image, "label": label}


ds_train, ds_test = tfds.load("mnist", split=["train", "test"], shuffle_files=True)
# (batches, batch_size, 1, 28, 28)
ds_train = ds_train.shuffle(1024).map(preprocess_data).batch(batch_size).prefetch(tf.data.AUTOTUNE)

# (batches, 1, 28, 28)
ds_test = ds_test.map(preprocess_data).prefetch(tf.data.AUTOTUNE)

🏗️ Model definition

We will use jax.vmap(model) to apply the model to batches; a quick shape check follows the model definition below.

@sk.treeclass
class CNN:
    def __init__(self):
        self.conv1 = sk.nn.Conv2D(1, 32, (3, 3), padding="valid")
        self.relu1 = sk.nn.ReLU()
        self.pool1 = sk.nn.MaxPool2D((2, 2), strides=(2, 2))
        self.conv2 = sk.nn.Conv2D(32, 64, (3, 3), padding="valid")
        self.relu2 = sk.nn.ReLU()
        self.pool2 = sk.nn.MaxPool2D((2, 2), strides=(2, 2))
        self.flatten = sk.nn.Flatten(start_dim=0)
        self.dropout = sk.nn.Dropout(0.5)
        self.linear = sk.nn.Linear(5*5*64, 10)

    def __call__(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.dropout(x)
        x = self.linear(x)
        return x

model = CNN()
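
As a quick sanity check (a hypothetical batch of 8 images, with the shapes reported in the summary below):

# vmap the unbatched model over the leading batch axis
xb = jnp.ones((8, 1, 28, 28))
print(jax.vmap(model)(xb).shape)  # (8, 10)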

🎨 Visualize model

Model summary
print(model.summary(show_config=False, array=jnp.empty((1, 28, 28))))  
┌───────┬─────────┬─────────┬──────────────┬─────────────┬─────────────┐
│Name   │Type     │Param #  │Size          │Input        │Output       │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│conv1  │Conv2D   │320(0)   │1.25KB(0.00B) │f32[1,28,28] │f32[32,26,26]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│relu1  │ReLU     │0(0)     │0.00B(0.00B)  │f32[32,26,26]│f32[32,26,26]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│pool1  │MaxPool2D│0(0)     │0.00B(0.00B)  │f32[32,26,26]│f32[32,13,13]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│conv2  │Conv2D   │18,496(0)│72.25KB(0.00B)│f32[32,13,13]│f32[64,11,11]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│relu2  │ReLU     │0(0)     │0.00B(0.00B)  │f32[64,11,11]│f32[64,11,11]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│pool2  │MaxPool2D│0(0)     │0.00B(0.00B)  │f32[64,11,11]│f32[64,5,5]  │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│flatten│Flatten  │0(0)     │0.00B(0.00B)  │f32[64,5,5]  │f32[1600]    │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│dropout│Dropout  │0(0)     │0.00B(0.00B)  │f32[1600]    │f32[1600]    │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│linear │Linear   │16,010(0)│62.54KB(0.00B)│f32[1600]    │f32[10]      │
└───────┴─────────┴─────────┴──────────────┴─────────────┴─────────────┘
Total count :	34,826(0)
Dynamic count :	34,826(0)
Frozen count :	0(0)
------------------------------------------------------------------------
Total size :	136.04KB(0.00B)
Dynamic size :	136.04KB(0.00B)
Frozen size :	0.00B(0.00B)
========================================================================
Tree diagram
print(model.tree_diagram())
CNN
    ├── conv1=Conv2D
       ├── weight=f32[32,1,3,3]
       ├── bias=f32[32,1,1]
       * in_features=1
       * out_features=32
       * kernel_size=(3, 3)
       * strides=(1, 1)
       * padding=((0, 0), (0, 0))
       * input_dilation=(1, 1)
       * kernel_dilation=(1, 1)
       ├── weight_init_func=Partial(init(key,shape,dtype))
       ├── bias_init_func=Partial(zeros(key,shape,dtype))
       * groups=1    
    ├── relu1=ReLU  
    * pool1=MaxPool2D
       * kernel_size=(2, 2)
       * strides=(2, 2)
       * padding='valid' 
    ├── conv2=Conv2D
       ├── weight=f32[64,32,3,3]
       ├── bias=f32[64,1,1]
       * in_features=32
       * out_features=64
       * kernel_size=(3, 3)
       * strides=(1, 1)
       * padding=((0, 0), (0, 0))
       * input_dilation=(1, 1)
       * kernel_dilation=(1, 1)
       ├── weight_init_func=Partial(init(key,shape,dtype))
       ├── bias_init_func=Partial(zeros(key,shape,dtype))
       * groups=1    
    ├── relu2=ReLU  
    * pool2=MaxPool2D
       * kernel_size=(2, 2)
       * strides=(2, 2)
       * padding='valid' 
    * flatten=Flatten
       * start_dim=0
       * end_dim=-1  
    ├── dropout=Dropout
       * p=0.5
       └── eval=None   
    └── linear=Linear
        ├── weight=f32[1600,10]
        ├── bias=f32[10]
        * in_features=1600
        * out_features=10  
    
Plot sample predictions before training
 
# turn all dropout off for evaluation by setting every `eval` field to True
test_model = model.at[model == "eval"].set(True, is_leaf=lambda x: x is None)

def show_images_with_predictions(model, images, one_hot_labels):
    logits = jax.vmap(model)(images)
    predictions = jnp.argmax(logits, axis=-1)
    fig, axes = plt.subplots(5, 5, figsize=(10, 10))
    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i].reshape(28, 28), cmap="binary")
        ax.set(title=f"Prediction: {predictions[i]}\nLabel: {jnp.argmax(one_hot_labels[i], axis=-1)}")
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

example = ds_test.take(25).as_numpy_iterator()
example = list(example)
sample_test_images = jnp.stack([x["image"] for x in example])
sample_test_labels = jnp.stack([x["label"] for x in example])

show_images_with_predictions(test_model, sample_test_images, sample_test_labels)

[figure: 5x5 grid of MNIST test images with predicted and true labels, before training]

🏃 Train the model

@ft.partial(jax.value_and_grad, has_aux=True)
def loss_func(model, batched_images, batched_one_hot_labels):
    logits = jax.vmap(model)(batched_images)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=batched_one_hot_labels))
    return loss, logits


# using optax for gradient updates
optim = optax.adam(1e-3)
optim_state = optim.init(model)


@jax.jit
def batch_step(model, batched_images, batched_one_hot_labels, optim_state):
    (loss, logits), grads = loss_func(model, batched_images, batched_one_hot_labels)
    updates, optim_state = optim.update(grads, optim_state)
    model = optax.apply_updates(model, updates)
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == jnp.argmax(batched_one_hot_labels, axis=-1))
    return model, optim_state, loss, accuracy


epochs = 5

for i in range(epochs):
    epoch_accuracy = []
    epoch_loss = []

    for example in ds_train.as_numpy_iterator():
        image, label = example["image"], example["label"]
        model, optim_state, loss, accuracy = batch_step(model, image, label, optim_state)
        epoch_accuracy.append(accuracy)
        epoch_loss.append(loss)

    epoch_loss = jnp.mean(jnp.array(epoch_loss))
    epoch_accuracy = jnp.mean(jnp.array(epoch_accuracy))

    print(f"epoch:{i+1:00d}\tloss:{epoch_loss:.4f}\taccuracy:{epoch_accuracy:.4f}")
    
# epoch:1	loss:0.2706	accuracy:0.9268
# epoch:2	loss:0.0725	accuracy:0.9784
# epoch:3	loss:0.0533	accuracy:0.9836
# epoch:4	loss:0.0442	accuracy:0.9868
# epoch:5	loss:0.0368	accuracy:0.9889

🎨 Visualize after training

test_model = model.at[model == "eval"].set(True, is_leaf=lambda x: x is None)
show_images_with_predictions(test_model, sample_test_images, sample_test_labels)

[figure: 5x5 grid of MNIST test images with predicted and true labels, after training]

PINN with finite differences

We will train NN(x) ≈ f(x), where df(x)/dx = cos(x) and f(0) = 0 (i.e. f(x) = sin(x)). The derivative df/dx is approximated once with a finite-difference scheme (NN_fd) and once with automatic differentiation (NN_ad) for comparison.

import copy

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax

import serket as sk

x = jnp.linspace(-jnp.pi, jnp.pi, 1000)[:, None]
y = jnp.sin(x)
dx = x[1] - x[0]
dydx = jnp.cos(x)

NN_fd = sk.nn.Sequential(
    [
        sk.nn.Linear(1, 128),
        sk.nn.ReLU(),
        sk.nn.Linear(128, 128),
        sk.nn.ReLU(),
        sk.nn.Linear(128, 1),
    ]
)

NN_ad = copy.copy(NN_fd)
optim = optax.adam(1e-3)


@jax.value_and_grad
def loss_func_fd(NN, x):
    y = NN(x)
    dydx = sk.fd.difference(y, axis=0, accuracy=5, step_size=dx)
    loss = jnp.mean((dydx - jnp.cos(x)) ** 2)
    loss += jnp.mean((NN(jnp.zeros_like(x))) ** 2)  # initial condition
    return loss


@jax.value_and_grad
def loss_func_ad(NN, x):
    loss = jnp.mean((sk.diff(NN)(x) - jnp.cos(x)) ** 2)
    loss += jnp.mean(NN(jnp.zeros_like(x)) ** 2)  # initial condition
    return loss


@jax.jit
def step_fd(NN, x, optim_state):
    loss, grads = loss_func_fd(NN, x)
    updates, optim_state = optim.update(grads, optim_state)
    NN = optax.apply_updates(NN, updates)
    return NN, optim_state, loss


def train_fd(NN_fd, optim_state_fd, epochs):
    for i in range(1, epochs + 1):
        NN_fd, optim_state_fd, loss_fd = step_fd(NN_fd, x, optim_state_fd)
    return NN_fd, optim_state_fd, loss_fd


@jax.jit
def step_ad(NN, x, optim_state):
    loss, grads = loss_func_ad(NN, x)
    updates, optim_state = optim.update(grads, optim_state)
    NN = optax.apply_updates(NN, updates)
    return NN, optim_state, loss


def train_ad(NN_ad, optim_state_ad, epochs):
    for i in range(1, epochs + 1):
        NN_ad, optim_state_ad, loss_ad = step_ad(NN_ad, x, optim_state_ad)
    return NN_ad, optim_state_ad, loss_ad


epochs = 1000


optim_state_fd = optim.init(NN_fd)
optim_state_ad = optim.init(NN_ad)


NN_fd, optim_state_fd, loss_fd = train_fd(NN_fd, optim_state_fd, epochs)
NN_ad, optim_state_ad, loss_ad = train_ad(NN_ad, optim_state_ad, epochs)
print(f"Loss_fd {loss_fd:.4f} \nLoss_ad {loss_ad:.4f}")
y_fd = NN_fd(x)
y_ad = NN_ad(x)
plt.plot(x, y, "--k", label="true")
plt.plot(x, y_fd, label="fd pred")
plt.plot(x, y_ad, label="ad pred")
plt.legend()

# Loss_fd 0.0012 
# Loss_ad 0.0235

image

Reconstructing a vector field F from the conditions ∇·F = 0 and ∇×F = 2k̂
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax

import serket as sk

x, y = [jnp.linspace(-1, 1, 50)] * 2
dx, dy = [x[1] - x[0]] * 2
X, Y = jnp.meshgrid(x, y, indexing="ij")

F1 = -Y
F2 = +X
F = jnp.stack([F1, F2], axis=0)

NN = sk.nn.Sequential(
    [
        sk.nn.Conv2D(2, 32, kernel_size=3, padding="same"),
        sk.nn.ReLU(),
        sk.nn.Conv2D(32, 32, kernel_size=3, padding="same"),
        sk.nn.ReLU(),
        sk.nn.Conv2D(32, 2, kernel_size=3, padding="same"),
    ]
)

optim = optax.adam(1e-3)


@jax.value_and_grad
def loss_func(NN, F):
    F_pred = NN(F)
    div = sk.fd.divergence(F_pred, accuracy=5, step_size=(dx, dy))
    loss = jnp.mean(div**2)  # divergence-free condition
    curl = sk.fd.curl(F_pred, accuracy=2, step_size=(dx, dy))
    loss += jnp.mean((curl - 2 * jnp.ones_like(curl)) ** 2)  # curl condition
    return loss


@jax.jit
def step(NN, F, optim_state):
    loss, grads = loss_func(NN, F)
    updates, optim_state = optim.update(grads, optim_state)
    NN = optax.apply_updates(NN, updates)
    return NN, optim_state, loss


def train(NN, Z, optim_state, epochs):
    for i in range(1, epochs + 1):
        NN, optim_state, loss = step(NN, Z, optim_state)
    return NN, optim_state, loss


Z = jnp.stack([X, Y], axis=0)  # collocation points
optim_state = optim.init(NN)  # initialise optimiser
epochs = 1_000
NN, _, loss = train(NN, Z, optim_state, epochs)

Fpred = NN(Z)  # predicted field

plt.figure(figsize=(10, 10))
plt.quiver(X, Y, Fpred[0], Fpred[1], color="r", label="pred")
plt.quiver(X, Y, F1, F2, color="k", alpha=0.5, label="true")
plt.legend()

[figure: quiver plot of the predicted field (red) over the true field (black)]

🥶 Freezing parameters / Fine-tuning

See here for more about freezing

🔘 Filtering by masking

See here for more about filtering; a minimal sketch of the masking pattern is shown below.
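
A minimal sketch of filtering by masking, reusing the .at[model == "eval"].set(...) idiom from the MNIST example above (the model below is only an illustration; see the linked documentation for the full freezing/filtering API):

import serket as sk

model = sk.nn.Sequential([sk.nn.Linear(1, 8), sk.nn.Dropout(0.5), sk.nn.Linear(8, 1)])

# comparing the tree against the field name "eval" builds a boolean mask;
# .at[...] then sets every matching leaf, switching all dropout layers to evaluation mode
eval_model = model.at[model == "eval"].set(True, is_leaf=lambda x: x is None)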
