JAX NN library.
The ✨Magical✨ JAX Scientific ML Library.
*Serket is the goddess of magic in Egyptian mythology
Installation | Description | Quick Example | Freezing/Fine-tuning | Filtering
🛠️ Installation
pip install serket
Install development version
pip install git+https://github.com/ASEM000/serket
📖 Description
- serket aims to be the most intuitive and easy-to-use physics-based neural network library in JAX.
- serket is built on top of pytreeclass.
- serket is fully transparent to jax transformations (e.g. vmap, grad, jit, ...).
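Since every serket layer is a pytree, it can be passed straight through jax transformations; a minimal sketch of this transparency (the same pattern appears with jax.vmap in the examples below):
import jax
import jax.numpy as jnp
import serket as sk
layer = sk.nn.Linear(1, 1)
# the layer is a pytree, so it can be an argument of a jit-ed function
@jax.jit
def forward(layer, x):
    return layer(x)
print(forward(layer, jnp.ones([1])).shape)  # (1,)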
➖➕ Finite difference package: serket.fd ➕➖
| Group | Function/Layer |
|---|---|
| Finite difference layers | `Difference`: apply finite difference to an input array, with any derivative order and accuracy |
| Finite difference functions | `difference`: finite difference of an array with any accuracy and derivative order. `generate_finitediff_coeffs`: generate coefficients from sample points and derivative order. `fgrad`: differentiate functions (similar to `jax.grad`) with custom accuracy and derivative order |
| Vector operator layers | `Curl`, `Divergence`, `Gradient`, `Laplacian`, `Jacobian`, `Hessian` |
| Vector operator functions | `curl`, `divergence`, `gradient`, `laplacian`, `jacobian`, `hessian` |
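As a quick taste of the functional side, a minimal sketch using fgrad (assuming the jax.grad-like calling convention with an accuracy keyword, as described in the table above):
import jax.numpy as jnp
import serket as sk
# differentiate sin with a finite-difference stencil instead of autodiff
dsin = sk.fd.fgrad(jnp.sin, accuracy=4)
print(dsin(1.0))  # ≈ cos(1.0) ≈ 0.5403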
🧠 Neural network package: serket.nn
| Group | Layers |
|---|---|
| Linear | `Linear`, `Bilinear`, `Multilinear`, `GeneralLinear`, `Identity` |
| Densely connected | `FNN` (fully connected network), `PFNN` (parallel fully connected network) |
| Convolution | `Conv1D`, `Conv2D`, `Conv3D`, `Conv1DTranspose`, `Conv2DTranspose`, `Conv3DTranspose`, `DepthwiseConv1D`, `DepthwiseConv2D`, `DepthwiseConv3D`, `SeparableConv1D`, `SeparableConv2D`, `SeparableConv3D`, `Conv1DLocal`, `Conv2DLocal`, `Conv3DLocal`, `FFTConv1D`, `FFTConv2D`, `FFTConv3D`, `FFTConv1DTranspose`, `FFTConv2DTranspose`, `FFTConv3DTranspose`, `DepthwiseFFTConv1D`, `DepthwiseFFTConv2D`, `DepthwiseFFTConv3D`, `SeparableFFTConv1D`, `SeparableFFTConv2D`, `SeparableFFTConv3D` |
| Containers | `Sequential`, `Lambda` |
| Pooling (`kernex` backend) | `MaxPool1D`, `MaxPool2D`, `MaxPool3D`, `AvgPool1D`, `AvgPool2D`, `AvgPool3D`, `GlobalMaxPool1D`, `GlobalMaxPool2D`, `GlobalMaxPool3D`, `GlobalAvgPool1D`, `GlobalAvgPool2D`, `GlobalAvgPool3D`, `LPPool1D`, `LPPool2D`, `LPPool3D`, `AdaptiveAvgPool1D`, `AdaptiveAvgPool2D`, `AdaptiveAvgPool3D`, `AdaptiveMaxPool1D`, `AdaptiveMaxPool2D`, `AdaptiveMaxPool3D`, `AdaptiveConcatPool1D`, `AdaptiveConcatPool2D`, `AdaptiveConcatPool3D` |
| Reshaping | `Flatten`, `Unflatten`, `FlipLeftRight2D`, `FlipUpDown2D`, `Repeat1D`, `Repeat2D`, `Repeat3D`, `Resize1D`, `Resize2D`, `Resize3D`, `Upsample1D`, `Upsample2D`, `Upsample3D`, `Pad1D`, `Pad2D`, `Pad3D` |
| Crop | `Crop1D`, `Crop2D` |
| Normalization | `LayerNorm`, `InstanceNorm`, `GroupNorm` |
| Blurring | `AvgBlur2D`, `GaussianBlur2D` |
| Dropout | `Dropout`, `Dropout1D`, `Dropout2D`, `Dropout3D` |
| Random transforms | `RandomCrop1D`, `RandomCrop2D`, `RandomApply`, `RandomCutout1D`, `RandomCutout2D`, `RandomZoom2D`, `RandomContrast2D` |
| Misc | `HistogramEqualization2D`, `AdjustContrast2D`, `Filter2D`, `PixelShuffle2D` |
| Activations | `AdaptiveLeakyReLU`, `AdaptiveReLU`, `AdaptiveSigmoid`, `AdaptiveTanh`, `CeLU`, `ELU`, `GELU`, `GLU`, `HardSILU`, `HardShrink`, `HardSigmoid`, `HardSwish`, `HardTanh`, `LeakyReLU`, `LogSigmoid`, `LogSoftmax`, `Mish`, `PReLU`, `ReLU`, `ReLU6`, `SILU`, `SeLU`, `Sigmoid`, `SoftPlus`, `SoftShrink`, `SoftSign`, `Swish`, `Tanh`, `TanhShrink`, `ThresholdedReLU`, `Snake` |
| Recurrent cells | `SimpleRNNCell`, `LSTMCell`, `GRUCell`, `ConvLSTM1D`, `ConvLSTM2D`, `ConvLSTM3D`, `SeparableConvLSTM1DCell`, `SeparableConvLSTM2DCell`, `SeparableConvLSTM3DCell`, `ConvGRU1DCell`, `ConvGRU2DCell`, `ConvGRU3DCell`, `SeparableConvGRU1DCell`, `SeparableConvGRU2DCell`, `SeparableConvGRU3DCell`, `FFTConvLSTM1D`, `FFTConvLSTM2D`, `FFTConvLSTM3D`, `SeparableFFTConvLSTM1DCell`, `SeparableFFTConvLSTM2DCell`, `SeparableFFTConvLSTM3DCell`, `FFTConvGRU1DCell`, `FFTConvGRU2DCell`, `FFTConvGRU3DCell`, `SeparableFFTConvGRU1DCell`, `SeparableFFTConvGRU2DCell`, `SeparableFFTConvGRU3DCell` |
| Blocks | `VGG16Block`, `VGG19Block`, `UNetBlock` |
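Layers follow a construct-then-call pattern and compose via the containers above; a minimal sketch (assuming `Lambda` simply wraps a plain callable as a layer):
import jax.numpy as jnp
import serket as sk
net = sk.nn.Sequential([sk.nn.Linear(4, 8), sk.nn.Lambda(jnp.tanh), sk.nn.Linear(8, 1)])
print(net(jnp.ones([4])).shape)  # (1,)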
⏩ Examples:
Linear layers examples
import jax.numpy as jnp
import serket as sk
# Linear
x = jnp.ones([1, 2, 3, 4])
l1 = sk.nn.Linear(4, 5) # last dim is 4, output dim is 5
print(l1(x).shape) # (1, 2, 3, 5)
# Bilinear
x1, x2 = jnp.ones([1, 2, 3, 4]), jnp.ones([1, 2, 3, 5])
l2 = sk.nn.Bilinear(4, 5, 6) # last dim is 4,5, output dim is 6
print(l2(x1, x2).shape) # (1, 2, 3, 6)
# Multilinear
x1, x2, x3 = jnp.ones([1, 2, 3, 4]), jnp.ones([1, 2, 3, 5]), jnp.ones([1, 2, 3, 6])
l3 = sk.nn.Multilinear((4, 5, 6), 7) # last dim is 4,5,6, output dim is 7
print(l3(x1, x2, x3).shape) # (1, 2, 3, 7)
# GeneralLinear
x = jnp.ones([4, 5, 6, 7])
# apply a linear layer to axes 1, 2, 3; output dim is 5
# in_features = (5, 6, 7) (the sizes of axes 1, 2, 3)
l4 = sk.nn.GeneralLinear((5, 6, 7), 5, in_axes=(1, 2, 3))
print(l4(x).shape) # (4, 5)
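# For intuition, the GeneralLinear call above contracts the selected axes
# with a weight tensor, much like the following einsum (bias omitted; the
# all-ones weight is a stand-in, not the layer's parameter)
w = jnp.ones([5, 6, 7, 5])  # (sizes of axes 1, 2, 3, then out_features)
print(jnp.einsum("abcd,bcde->ae", x, w).shape)  # (4, 5), same as l4(x).shape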
Finite difference examples
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy.testing as npt
import serket as sk
# let's first define a vector-valued function F: R^3 -> R^3
# F = [F1, F2, F3]
# F1 = x^2 + y^3
# F2 = x^4 + y^3
# F3 = 0
x, y, z = [jnp.linspace(0, 1, 100)] * 3
dx, dy, dz = x[1] - x[0], y[1] - y[0], z[1] - z[0]
X, Y, Z = jnp.meshgrid(x, y, z, indexing="ij")
F1 = X**2 + Y**3
F2 = X**4 + Y**3
F3 = jnp.zeros_like(F1)
F = jnp.stack([F1, F2, F3], axis=0)
# ∂F1/∂x : differentiate F1 with respect to x (i.e. axis=0)
dF1dx = sk.fd.difference(F1, axis=0, step_size=dx, accuracy=6)
dF1dx_exact = 2 * X
npt.assert_allclose(dF1dx, dF1dx_exact, atol=1e-7)
# ∂F2/∂y : differentiate F2 with respect to y (i.e. axis=1)
dF2dy = sk.fd.difference(F2, axis=1, step_size=dy, accuracy=6)
dF2dy_exact = 3 * Y**2
npt.assert_allclose(dF2dy, dF2dy_exact, atol=1e-7)
# ∇.F : the divergence of F
divF = sk.fd.divergence(F, step_size=(dx, dy, dz), keepdims=False, accuracy=6)
divF_exact = 2 * X + 3 * Y**2
npt.assert_allclose(divF, divF_exact, atol=1e-7)
# ∇F1 : the gradient of F1
gradF1 = sk.fd.gradient(F1, step_size=(dx, dy, dz), accuracy=6)
gradF1_exact = jnp.stack([2 * X, 3 * Y**2, 0 * X], axis=0)
npt.assert_allclose(gradF1, gradF1_exact, atol=1e-7)
# ΔF1 : laplacian of F1
lapF1 = sk.fd.laplacian(F1, step_size=(dx, dy, dz), accuracy=6)
lapF1_exact = 2 + 6 * Y
npt.assert_allclose(lapF1, lapF1_exact, atol=1e-7)
# ∇xF : the curl of F
curlF = sk.fd.curl(F, step_size=(dx, dy, dz), accuracy=6)
curlF_exact = jnp.stack([F1 * 0, F1 * 0, 4 * X**3 - 3 * Y**2], axis=0)
npt.assert_allclose(curlF, curlF_exact, atol=1e-7)
# Jacobian of F
JF = sk.fd.jacobian(F, accuracy=4, step_size=(dx, dy, dz))
JF_exact = jnp.array(
[
[2 * X, 3 * Y**2, jnp.zeros_like(X)],
[4 * X**3, 3 * Y**2, jnp.zeros_like(X)],
[jnp.zeros_like(X), jnp.zeros_like(X), jnp.zeros_like(X)],
]
)
npt.assert_allclose(JF, JF_exact, atol=1e-7)
# Hessian of F1
HF1 = sk.fd.hessian(F1, accuracy=4, step_size=(dx, dy, dz))
HF1_exact = jnp.array(
[
[
2 * jnp.ones_like(X), # ∂2F1/∂x2
0 * jnp.ones_like(X), # ∂2F1/∂xy
0 * jnp.ones_like(X), # ∂2F1/∂xz
],
[
0 * jnp.ones_like(X), # ∂2F1/∂yx
        6 * Y,  # ∂2F1/∂y2
0 * jnp.ones_like(X), # ∂2F1/∂yz
],
[
0 * jnp.ones_like(X), # ∂2F1/∂zx
0 * jnp.ones_like(X), # ∂2F1/∂zy
0 * jnp.ones_like(X), # ∂2F1/∂z2
],
]
)
npt.assert_allclose(HF1, HF1_exact, atol=1e-7)
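# The layer counterparts in the fd table mirror the function forms; a minimal
# sketch for the Difference layer (assuming its constructor takes the same
# keywords as sk.fd.difference above)
diff_layer = sk.fd.Difference(axis=0, step_size=dx, accuracy=6)
npt.assert_allclose(diff_layer(F1), dF1dx, atol=1e-7)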
Train Bidirectional-LSTM
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax # for gradient optimization
import serket as sk
x = jnp.linspace(0, 1, 101).reshape(-1, 1) # 101 samples of 1D data
y = jnp.sin(2 * jnp.pi * x)
y += 0.1 * jr.normal(jr.PRNGKey(0), y.shape)
# we will use 2 time steps to predict the next time step
x_batched = jnp.stack([x[:-1], x[1:]], axis=1)
x_batched = jnp.reshape(x_batched, (100, 1, 2, 1)) # 100 minibatches x 1 sample x 2 time steps x 1D data
y_batched = jnp.reshape(y[1:], (100, 1, 1))  # 100 minibatches x 1 sample x 1D data
model = sk.nn.Sequential(
[
# first cell is the forward cell, second cell is the backward cell for bidirectional RNN
# we return the full sequence of outputs for each cell by setting return_sequences=True
# we use None in place of `in_features` to infer the input shape from the input
sk.nn.ScanRNN(
sk.nn.LSTMCell(None, 64),
backward_cell=sk.nn.LSTMCell(None, 64),
return_sequences=True,
),
# here the in_features is inferred from the previous layer by setting it to None
# or simply we can set it to 64*2 (64 for each cell from previous layer)
# we set return_sequences=False to return only the last output of the sequence
sk.nn.ScanRNN(sk.nn.LSTMCell(None, 1), return_sequences=False),
]
)
@jax.value_and_grad
def loss_func(NN, batched_x, batched_y):
# use jax.vmap to apply the model to each minibatch
# in our case single x minibatch has shape (1, 2, 1)
# and single y minibatch has shape (1, 1)
# then vmap will be applied to the leading axis
batched_preds = jax.vmap(NN)(batched_x)
return jnp.mean((batched_preds - batched_y) ** 2)
@jax.jit
def batch_step(NN, batched_x, batched_y, optim_state):
    loss, grads = loss_func(NN, batched_x, batched_y)
    updates, optim_state = optim.update(grads, optim_state)
    NN = optax.apply_updates(NN, updates)
    return NN, optim_state, loss
# dry run to infer the in_features (i.e. replace None with the actual value)
# if you want to restrict the model to a specific input shape, or to avoid
# confusion, you can specify in_features manually; in that case the dry run
# is not necessary
model(x_batched[0, 0])
optim = optax.adam(1e-3)
opt_state = optim.init(model)
epochs = 100
for i in range(1, epochs + 1):
epoch_loss = []
for x_b, y_b in zip(x_batched, y_batched):
model, opt_state, loss = batch_step(model, x_b, y_b, opt_state)
epoch_loss.append(loss)
epoch_loss = jnp.mean(jnp.array(epoch_loss))
if i % 10 == 0:
print(f"Epoch {i:3d} Loss {epoch_loss:.4f}")
# Epoch 10 Loss 0.0880
# Epoch 20 Loss 0.0796
# Epoch 30 Loss 0.0620
# Epoch 40 Loss 0.0285
# Epoch 50 Loss 0.0205
# Epoch 60 Loss 0.0187
# Epoch 70 Loss 0.0182
# Epoch 80 Loss 0.0176
# Epoch 90 Loss 0.0171
# Epoch 100 Loss 0.0166
y_pred = jax.vmap(model)(x_batched.reshape(-1, 2, 1))
plt.plot(x[1:], y[1:], "--k", label="data")
plt.plot(x[1:], y_pred, label="prediction")
plt.legend()
Lazy initialization
In cases where in_features needs to be inferred from the input, use None in place of in_features to infer the value at runtime. However, since a lazy module initializes its state after the first call (i.e. mutates its state), jax transformations (e.g. vmap, grad, jit, ...) are not allowed before initialization; applying any jax transformation to an uninitialized module will throw a ValueError.
import serket as sk
import jax
import jax.numpy as jnp
model = sk.nn.Sequential(
[
sk.nn.Conv2D(None, 128, 3),
sk.nn.ReLU(),
sk.nn.MaxPool2D(2, 2),
sk.nn.Conv2D(128, 64, 3),
sk.nn.ReLU(),
sk.nn.MaxPool2D(2, 2),
sk.nn.Flatten(),
sk.nn.Linear(None, 128),
sk.nn.ReLU(),
sk.nn.Linear(128, 1),
]
)
# print the first `Conv2D` layer before initialization
print(model[0].__repr__())
# Conv2D(
# weight=None,
# bias=None,
# *in_features=None,
# *out_features=None,
# *kernel_size=None,
# *strides=None,
# *padding=None,
# *input_dilation=None,
# *kernel_dilation=None,
# weight_init_func=None,
# bias_init_func=None,
# *groups=None
# )
try:
    jax.vmap(model)(jnp.ones((10, 1, 28, 28)))
except ValueError:
print("***** Not initialized *****")
# ***** Not initialized *****
# dry run to initialize the model
model(jnp.empty([3, 128, 128]))
print(model[0].__repr__())
# Conv2D(
# weight=f32[128,3,3,3],
# bias=f32[128,1,1],
# *in_features=3,
# *out_features=128,
# *kernel_size=(3,3),
# *strides=(1,1),
# *padding=((1,1),(1,1)),
# *input_dilation=(1,1),
# *kernel_dilation=(1,1),
# weight_init_func=Partial(glorot_uniform(key,shape,dtype)),
# bias_init_func=Partial(zeros(key,shape,dtype)),
# *groups=1
# )
Train MNIST
We will use tensorflow datasets for data loading. For more on the jax/tensorflow datasets interface, see here.
# imports
import tensorflow as tf
# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type="GPU")
import tensorflow_datasets as tfds
import tensorflow.experimental.numpy as tnp
import jax
import jax.numpy as jnp
import jax.random as jr
import optax # for gradient optimization
import serket as sk
import matplotlib.pyplot as plt
import functools as ft
# Construct a tf.data.Dataset
batch_size = 128
# convert the samples from integers to floating-point numbers
# and to channel-first format
def preprocess_data(x):
# convert to channel first format
image = tnp.moveaxis(x["image"], -1, 0)
# normalize to [0, 1]
image = tf.cast(image, tf.float32) / 255.0
    # one-hot encode the labels; dividing by 1.0 casts to float
    label = tf.one_hot(x["label"], 10) / 1.0
return {"image": image, "label": label}
ds_train, ds_test = tfds.load("mnist", split=["train", "test"], shuffle_files=True)
# (batches, batch_size, 1, 28, 28)
ds_train = ds_train.shuffle(1024).map(preprocess_data).batch(batch_size).prefetch(tf.data.AUTOTUNE)
# (batches, 1, 28, 28)
ds_test = ds_test.map(preprocess_data).prefetch(tf.data.AUTOTUNE)
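A quick sanity check that the pipeline yields the shapes noted in the comments above:
batch = next(ds_train.as_numpy_iterator())
print(batch["image"].shape, batch["label"].shape)  # (128, 1, 28, 28) (128, 10)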
🏗️ Model definition
We will use jax.vmap(model) to apply the model to batches.
@sk.treeclass
class CNN:
def __init__(self):
self.conv1 = sk.nn.Conv2D(1, 32, (3, 3), padding="valid")
self.relu1 = sk.nn.ReLU()
self.pool1 = sk.nn.MaxPool2D((2, 2), strides=(2, 2))
self.conv2 = sk.nn.Conv2D(32, 64, (3, 3), padding="valid")
self.relu2 = sk.nn.ReLU()
self.pool2 = sk.nn.MaxPool2D((2, 2), strides=(2, 2))
self.flatten = sk.nn.Flatten(start_dim=0)
self.dropout = sk.nn.Dropout(0.5)
self.linear = sk.nn.Linear(5*5*64, 10)
def __call__(self, x):
x = self.conv1(x)
x = self.relu1(x)
x = self.pool1(x)
x = self.conv2(x)
x = self.relu2(x)
x = self.pool2(x)
x = self.flatten(x)
x = self.dropout(x)
x = self.linear(x)
return x
model = CNN()
🎨 Visualize model
Model summary
print(model.summary(show_config=False, array=jnp.empty((1, 28, 28))))
┌───────┬─────────┬─────────┬──────────────┬─────────────┬─────────────┐
│Name │Type │Param # │Size │Input │Output │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│conv1 │Conv2D │320(0) │1.25KB(0.00B) │f32[1,28,28] │f32[32,26,26]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│relu1 │ReLU │0(0) │0.00B(0.00B) │f32[32,26,26]│f32[32,26,26]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│pool1 │MaxPool2D│0(0) │0.00B(0.00B) │f32[32,26,26]│f32[32,13,13]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│conv2 │Conv2D │18,496(0)│72.25KB(0.00B)│f32[32,13,13]│f32[64,11,11]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│relu2 │ReLU │0(0) │0.00B(0.00B) │f32[64,11,11]│f32[64,11,11]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│pool2 │MaxPool2D│0(0) │0.00B(0.00B) │f32[64,11,11]│f32[64,5,5] │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│flatten│Flatten │0(0) │0.00B(0.00B) │f32[64,5,5] │f32[1600] │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│dropout│Dropout │0(0) │0.00B(0.00B) │f32[1600] │f32[1600] │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│linear │Linear │16,010(0)│62.54KB(0.00B)│f32[1600] │f32[10] │
└───────┴─────────┴─────────┴──────────────┴─────────────┴─────────────┘
Total count : 34,826(0)
Dynamic count : 34,826(0)
Frozen count : 0(0)
------------------------------------------------------------------------
Total size : 136.04KB(0.00B)
Dynamic size : 136.04KB(0.00B)
Frozen size : 0.00B(0.00B)
========================================================================
Tree diagram
print(model.tree_diagram())
CNN
├── conv1=Conv2D
│ ├── weight=f32[32,1,3,3]
│ ├── bias=f32[32,1,1]
│ ├*─ in_features=1
│ ├*─ out_features=32
│ ├*─ kernel_size=(3, 3)
│ ├*─ strides=(1, 1)
│ ├*─ padding=((0, 0), (0, 0))
│ ├*─ input_dilation=(1, 1)
│ ├*─ kernel_dilation=(1, 1)
│ ├── weight_init_func=Partial(init(key,shape,dtype))
│ ├── bias_init_func=Partial(zeros(key,shape,dtype))
│ └*─ groups=1
├── relu1=ReLU
├*─ pool1=MaxPool2D
│ ├*─ kernel_size=(2, 2)
│ ├*─ strides=(2, 2)
│ └*─ padding='valid'
├── conv2=Conv2D
│ ├── weight=f32[64,32,3,3]
│ ├── bias=f32[64,1,1]
│ ├*─ in_features=32
│ ├*─ out_features=64
│ ├*─ kernel_size=(3, 3)
│ ├*─ strides=(1, 1)
│ ├*─ padding=((0, 0), (0, 0))
│ ├*─ input_dilation=(1, 1)
│ ├*─ kernel_dilation=(1, 1)
│ ├── weight_init_func=Partial(init(key,shape,dtype))
│ ├── bias_init_func=Partial(zeros(key,shape,dtype))
│ └*─ groups=1
├── relu2=ReLU
├*─ pool2=MaxPool2D
│ ├*─ kernel_size=(2, 2)
│ ├*─ strides=(2, 2)
│ └*─ padding='valid'
├*─ flatten=Flatten
│ ├*─ start_dim=0
│ └*─ end_dim=-1
├── dropout=Dropout
│ ├*─ p=0.5
│ └── eval=None
└── linear=Linear
├── weight=f32[1600,10]
├── bias=f32[10]
├*─ in_features=1600
└*─ out_features=10
Plot sample predictions before training
# turn all dropout layers off by setting their eval field to True
test_model = model.at[model == "eval"].set(True, is_leaf=lambda x: x is None)
def show_images_with_predictions(model, images, one_hot_labels):
logits = jax.vmap(model)(images)
predictions = jnp.argmax(logits, axis=-1)
fig, axes = plt.subplots(5, 5, figsize=(10, 10))
for i, ax in enumerate(axes.flat):
ax.imshow(images[i].reshape(28, 28), cmap="binary")
ax.set(title=f"Prediction: {predictions[i]}\nLabel: {jnp.argmax(one_hot_labels[i], axis=-1)}")
ax.set_xticks([])
ax.set_yticks([])
plt.show()
example = ds_test.take(25).as_numpy_iterator()
example = list(example)
sample_test_images = jnp.stack([x["image"] for x in example])
sample_test_labels = jnp.stack([x["label"] for x in example])
show_images_with_predictions(test_model, sample_test_images, sample_test_labels)
🏃 Train the model
@ft.partial(jax.value_and_grad, has_aux=True)
def loss_func(model, batched_images, batched_one_hot_labels):
logits = jax.vmap(model)(batched_images)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=batched_one_hot_labels))
return loss, logits
# using optax for gradient updates
optim = optax.adam(1e-3)
optim_state = optim.init(model)
@jax.jit
def batch_step(model, batched_images, batched_one_hot_labels, optim_state):
(loss, logits), grads = loss_func(model, batched_images, batched_one_hot_labels)
updates, optim_state = optim.update(grads, optim_state)
model = optax.apply_updates(model, updates)
accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == jnp.argmax(batched_one_hot_labels, axis=-1))
return model, optim_state, loss, accuracy
epochs = 5
for i in range(epochs):
epoch_accuracy = []
epoch_loss = []
for example in ds_train.as_numpy_iterator():
image, label = example["image"], example["label"]
model, optim_state, loss, accuracy = batch_step(model, image, label, optim_state)
epoch_accuracy.append(accuracy)
epoch_loss.append(loss)
epoch_loss = jnp.mean(jnp.array(epoch_loss))
epoch_accuracy = jnp.mean(jnp.array(epoch_accuracy))
print(f"epoch:{i+1:00d}\tloss:{epoch_loss:.4f}\taccuracy:{epoch_accuracy:.4f}")
# epoch:1 loss:0.2706 accuracy:0.9268
# epoch:2 loss:0.0725 accuracy:0.9784
# epoch:3 loss:0.0533 accuracy:0.9836
# epoch:4 loss:0.0442 accuracy:0.9868
# epoch:5 loss:0.0368 accuracy:0.9889
🎨 Visualize after training
test_model = model.at[model == "eval"].set(True, is_leaf=lambda x: x is None)
show_images_with_predictions(test_model, sample_test_images, sample_test_labels)
PINN with Finite difference
We will try to estimate NN(x) ≈ f(x), where df(x)/dx = cos(x) with the initial condition f(0) = 0 (i.e. f(x) = sin(x)). The derivative df/dx is represented with a finite difference scheme in one model (NN_fd) and with automatic differentiation in the other (NN_ad) for comparison.
import copy
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
import serket as sk
x = jnp.linspace(-jnp.pi, jnp.pi, 1000)[:, None]
y = jnp.sin(x)
dx = x[1] - x[0]
dydx = jnp.cos(x)
NN_fd = sk.nn.Sequential(
[
sk.nn.Linear(1, 128),
sk.nn.ReLU(),
sk.nn.Linear(128, 128),
sk.nn.ReLU(),
sk.nn.Linear(128, 1),
]
)
NN_ad = copy.copy(NN_fd)
optim = optax.adam(1e-3)
@jax.value_and_grad
def loss_func_fd(NN, x):
y = NN(x)
dydx = sk.fd.difference(y, axis=0, accuracy=5, step_size=dx)
loss = jnp.mean((dydx - jnp.cos(x)) ** 2)
loss += jnp.mean((NN(jnp.zeros_like(x))) ** 2) # initial condition
return loss
@jax.value_and_grad
def loss_func_ad(NN, x):
loss = jnp.mean((sk.diff(NN)(x) - jnp.cos(x)) ** 2)
loss += jnp.mean(NN(jnp.zeros_like(x)) ** 2) # initial condition
return loss
@jax.jit
def step_fd(NN, x, optim_state):
loss, grads = loss_func_fd(NN, x)
updates, optim_state = optim.update(grads, optim_state)
NN = optax.apply_updates(NN, updates)
return NN, optim_state, loss
def train_fd(NN_fd, optim_state_fd, epochs):
for i in range(1, epochs + 1):
NN_fd, optim_state_fd, loss_fd = step_fd(NN_fd, x, optim_state_fd)
return NN_fd, optim_state_fd, loss_fd
@jax.jit
def step_ad(NN, x, optim_state):
loss, grads = loss_func_ad(NN, x)
updates, optim_state = optim.update(grads, optim_state)
NN = optax.apply_updates(NN, updates)
return NN, optim_state, loss
def train_ad(NN_ad, optim_state_ad, epochs):
for i in range(1, epochs + 1):
NN_ad, optim_state_ad, loss_ad = step_ad(NN_ad, x, optim_state_ad)
return NN_ad, optim_state_ad, loss_ad
epochs = 1000
optim_state_fd = optim.init(NN_fd)
optim_state_ad = optim.init(NN_ad)
NN_fd, optim_state_fd, loss_fd = train_fd(NN_fd, optim_state_fd, epochs)
NN_ad, optim_state_ad, loss_ad = train_ad(NN_ad, optim_state_ad, epochs)
print(f"Loss_fd {loss_fd:.4f} \nLoss_ad {loss_ad:.4f}")
y_fd = NN_fd(x)
y_ad = NN_ad(x)
plt.plot(x, y, "--k", label="true")
plt.plot(x, y_fd, label="fd pred")
plt.plot(x, y_ad, label="ad pred")
plt.legend()
# Loss_fd 0.0012
# Loss_ad 0.0235
Reconstructing a vector field F using the conditions ∇·F = 0 and ∇×F = 2k̂
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
import serket as sk
x, y = [jnp.linspace(-1, 1, 50)] * 2
dx, dy = [x[1] - x[0]] * 2
X, Y = jnp.meshgrid(x, y, indexing="ij")
F1 = -Y
F2 = +X
F = jnp.stack([F1, F2], axis=0)
NN = sk.nn.Sequential(
[
sk.nn.Conv2D(2, 32, kernel_size=3, padding="same"),
sk.nn.ReLU(),
sk.nn.Conv2D(32, 32, kernel_size=3, padding="same"),
sk.nn.ReLU(),
sk.nn.Conv2D(32, 2, kernel_size=3, padding="same"),
]
)
optim = optax.adam(1e-3)
@jax.value_and_grad
def loss_func(NN, F):
F_pred = NN(F)
div = sk.fd.divergence(F_pred, accuracy=5, step_size=(dx, dy))
loss = jnp.mean(div**2) # divergence free condition
curl = sk.fd.curl(F_pred, accuracy=2, step_size=(dx, dy))
    loss += jnp.mean((curl - jnp.ones_like(curl) * 2) ** 2)  # curl = 2 condition
return loss
@jax.jit
def step(NN, F, optim_state):
loss, grads = loss_func(NN, F)
updates, optim_state = optim.update(grads, optim_state)
NN = optax.apply_updates(NN, updates)
return NN, optim_state, loss
def train(NN, Z, optim_state, epochs):
for i in range(1, epochs + 1):
NN, optim_state, loss = step(NN, Z, optim_state)
return NN, optim_state, loss
Z = jnp.stack([X, Y], axis=0) # collocation points
optim_state = optim.init(NN) # initialise optimiser
epochs = 1_000
NN, _, loss = train(NN, Z, optim_state, epochs)
Fpred = NN(Z) # predicted field
plt.figure(figsize=(10, 10))
plt.quiver(X, Y, Fpred[0], Fpred[1], color="r", label="pred")
plt.quiver(X, Y, F1, F2, color="k", alpha=0.5, label="true")
plt.legend()
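As a sanity check, the same finite difference operators used in the loss can verify the learned field:
# the learned field should be (approximately) divergence-free
div_pred = sk.fd.divergence(Fpred, accuracy=5, step_size=(dx, dy))
print(jnp.mean(jnp.abs(div_pred)))  # small if training succeeded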
🥶 Freezing parameters / Fine-tuning
✨See here for more about freezing✨
🔘 Filtering by masking
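The same masking idiom used to toggle dropout above extends to filtering parameters; a minimal sketch (the mask semantics come from pytreeclass, and matching the field name "bias" via `layer == "bias"` is an assumption by analogy with the `model == "eval"` example):
import serket as sk
layer = sk.nn.Linear(4, 5)
# build a boolean mask over leaves whose field name is "bias" (assumed),
# then update only those leaves
zero_bias_layer = layer.at[layer == "bias"].set(0.0)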