JAX NN library.
The ✨Magical✨ JAX Scientific ML Library.
*Serket is the goddess of magic in Egyptian mythology
Installation | Description | Quick Example | Freezing/Fine tuning | Filtering
🛠️ Installation
pip install serket
Install development version
pip install git+https://github.com/ASEM000/serket
📖 Description and motivation
- `serket` aims to be the most intuitive and easy-to-use physics-based neural network library in JAX.
- `serket` is fully transparent to `jax` transformations (e.g. `vmap`, `grad`, `jit`, ...).
- `serket` currently aims to facilitate the integration of numerical methods in a neural network setting (see the examples for more).
Layer structure
`serket` is built on top of PyTreeClass, which means that layers are represented as PyTrees whose leaves are the layer parameters.
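As a minimal sketch of what this implies (using the same `sk.nn.Linear` layer as in the examples below), standard `jax` tree utilities operate directly on a layer, and its leaves are the parameter arrays:

```python
import jax
import serket as sk

# a layer is a PyTree whose leaves are the layer parameters,
# so standard jax tree utilities apply to it directly
layer = sk.nn.Linear(4, 5)
print([leaf.shape for leaf in jax.tree_util.tree_leaves(layer)])
# expected: the weight (4, 5) and bias (5,) shapes
```

The exact leaf set and ordering shown in the comment is an assumption based on the `Linear` fields in the tree diagram further below.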
🧠 Neural network package: serket.nn
| Group | Layers |
|---|---|
| Linear | `Linear`, `Bilinear`, `Multilinear`, `GeneralLinear`, `Identity`, `Embedding` |
| Densely connected | `FNN` (fully connected network) |
| Convolution | `{Conv,FFTConv}{1D,2D,3D}`, `{Conv,FFTConv}{1D,2D,3D}Transpose`, `{Depthwise,Separable}{Conv,FFTConv}{1D,2D,3D}`, `Conv{1D,2D,3D}Local` |
| Containers | `Sequential`, `Lambda` |
| Pooling (`kernex` backend) | `{Avg,Max,LP}Pool{1D,2D,3D}`, `Global{Avg,Max}Pool{1D,2D,3D}`, `Adaptive{Avg,Max}Pool{1D,2D,3D}` |
| Reshaping | `Flatten`, `Unflatten`, `FlipLeftRight2D`, `FlipUpDown2D`, `Resize{1D,2D,3D}`, `Upsample{1D,2D,3D}`, `Pad{1D,2D,3D}` |
| Crop | `Crop{1D,2D}` |
| Normalization | `{Layer,Instance,Group}Norm` |
| Blurring | `{Avg,Gaussian}Blur2D` |
| Dropout | `Dropout`, `Dropout{1D,2D,3D}` |
| Random transforms | `RandomCrop{1D,2D}`, `RandomApply`, `RandomCutout{1D,2D}`, `RandomZoom2D`, `RandomContrast2D` |
| Misc | `HistogramEqualization2D`, `AdjustContrast2D`, `Filter2D`, `PixelShuffle2D` |
| Activations | `Adaptive{LeakyReLU,ReLU,Sigmoid,Tanh}`, `CeLU`, `ELU`, `GELU`, `GLU`, `Hard{SILU,Shrink,Sigmoid,Swish,Tanh}`, `Soft{Plus,Sign,Shrink}`, `LeakyReLU`, `LogSigmoid`, `LogSoftmax`, `Mish`, `PReLU`, `ReLU`, `ReLU6`, `SILU`, `SeLU`, `Sigmoid`, `Swish`, `Tanh`, `TanhShrink`, `ThresholdedReLU`, `Snake`, `Stan`, `SquarePlus` |
| Recurrent cells | `{SimpleRNN,LSTM,GRU}Cell`, `Conv{LSTM,GRU}{1D,2D,3D}Cell` |
| Blocks | `VGG{16,19}Block`, `UNetBlock` |
⏩ Examples:
Linear layers examples
import jax.numpy as jnp
import serket as sk
# Linear
x = jnp.ones([1, 2, 3, 4])
l1 = sk.nn.Linear(4, 5) # last dim is 4, output dim is 5
print(l1(x).shape) # (1, 2, 3, 5)
# Bilinear
x1, x2 = jnp.ones([1, 2, 3, 4]), jnp.ones([1, 2, 3, 5])
l2 = sk.nn.Bilinear(4, 5, 6) # last dim of the input x1,x2 are 4,5, output dim is 6
print(l2(x1, x2).shape) # (1, 2, 3, 6)
# Multilinear
x1, x2, x3 = jnp.ones([1, 2, 3, 4]), jnp.ones([1, 2, 3, 5]), jnp.ones([1, 2, 3, 6])
l3 = sk.nn.Multilinear((4, 5, 6), 7) # last dim for x1,x2,x3 = 4,5,6, output dim is 7
print(l3(x1, x2, x3).shape) # (1, 2, 3, 7)
# GeneralLinear
x = jnp.ones([4, 5, 6, 7])
# apply a linear layer to axis 1,2,3, with dim = (5, 6, 7) and output dim is 5
l4 = sk.nn.GeneralLinear((5, 6, 7), 5, in_axes=(1, 2, 3))
print(l4(x).shape) # (4, 5)
Train Bidirectional-LSTM
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax # for gradient optimization
import serket as sk
x = jnp.linspace(0, 1, 101).reshape(-1, 1) # 101 samples of 1D data
y = jnp.sin(2 * jnp.pi * x)
y += 0.1 * jr.normal(jr.PRNGKey(0), y.shape)
# we will use 2 time steps to predict the next time step
x_batched = jnp.stack([x[:-1], x[1:]], axis=1)
x_batched = jnp.reshape(x_batched, (100, 1, 2, 1)) # 100 minibatches x 1 sample x 2 time steps x 1D data
y_batched = jnp.reshape(y[1:], (100, 1, 1)) # 100 minibatches x 1 samples x 1D data
model = sk.nn.Sequential(
    [
        # first cell is the forward cell, second cell is the backward cell for a bidirectional RNN
        # we return the full sequence of outputs for each cell by setting return_sequences=True
        # we use None in place of `in_features` to infer the input shape from the input
        sk.nn.ScanRNN(
            sk.nn.LSTMCell(None, 64),
            backward_cell=sk.nn.LSTMCell(None, 64),
            return_sequences=True,
        ),
        # here in_features is inferred from the previous layer by setting it to None,
        # or we can simply set it to 64*2 (64 for each cell of the previous layer)
        # we set return_sequences=False to return only the last output of the sequence
        sk.nn.ScanRNN(sk.nn.LSTMCell(None, 1), return_sequences=False),
    ]
)
@jax.value_and_grad
def loss_func(NN, batched_x, batched_y):
    # use jax.vmap to apply the model to each minibatch
    # in our case a single x minibatch has shape (1, 2, 1)
    # and a single y minibatch has shape (1, 1)
    # vmap is applied over the leading axis
    batched_preds = jax.vmap(NN)(batched_x)
    return jnp.mean((batched_preds - batched_y) ** 2)
@jax.jit
def batch_step(NN, batched_x, batched_y, opt_state):
    loss, grads = loss_func(NN, batched_x, batched_y)
    updates, opt_state = optim.update(grads, opt_state)
    NN = optax.apply_updates(NN, updates)
    return NN, opt_state, loss
# dry run to infer in_features (i.e. replace None with the inferred value)
# if you want to restrict the model to a specific input shape, or to avoid
# confusion, you can specify in_features manually; in that case the
# dry run is not necessary
model(x_batched[0, 0])
optim = optax.adam(1e-3)
opt_state = optim.init(model)
epochs = 100
for i in range(1, epochs + 1):
    epoch_loss = []
    for x_b, y_b in zip(x_batched, y_batched):
        model, opt_state, loss = batch_step(model, x_b, y_b, opt_state)
        epoch_loss.append(loss)
    epoch_loss = jnp.mean(jnp.array(epoch_loss))
    if i % 10 == 0:
        print(f"Epoch {i:3d} Loss {epoch_loss:.4f}")
# Epoch 10 Loss 0.0880
# Epoch 20 Loss 0.0796
# Epoch 30 Loss 0.0620
# Epoch 40 Loss 0.0285
# Epoch 50 Loss 0.0205
# Epoch 60 Loss 0.0187
# Epoch 70 Loss 0.0182
# Epoch 80 Loss 0.0176
# Epoch 90 Loss 0.0171
# Epoch 100 Loss 0.0166
y_pred = jax.vmap(model)(x_batched.reshape(-1, 2, 1))
plt.plot(x[1:], y[1:], "--k", label="data")
plt.plot(x[1:], y_pred, label="prediction")
plt.legend()
Lazy initialization
In cases where `in_features` needs to be inferred from the input, pass `None` in its place to infer the value at runtime. However, since a lazy module initializes its state on the first call (i.e. mutates its state), `jax` transformations (e.g. `vmap`, `grad`, `jit`, ...) are not allowed before initialization; applying any `jax` transformation to an uninitialized model raises a `ValueError`.
import serket as sk
import jax
import jax.numpy as jnp
model = sk.nn.Sequential(
    [
        sk.nn.Conv2D(None, 128, 3),
        sk.nn.ReLU(),
        sk.nn.MaxPool2D(2, 2),
        sk.nn.Conv2D(128, 64, 3),
        sk.nn.ReLU(),
        sk.nn.MaxPool2D(2, 2),
        sk.nn.Flatten(),
        sk.nn.Linear(None, 128),
        sk.nn.ReLU(),
        sk.nn.Linear(128, 1),
    ]
)
# print the first `Conv2D` layer before initialization
print(model[0].__repr__())
# Conv2D(
# weight=None,
# bias=None,
# *in_features=None,
# *out_features=None,
# *kernel_size=None,
# *strides=None,
# *padding=None,
# *input_dilation=None,
# *kernel_dilation=None,
# weight_init_func=None,
# bias_init_func=None,
# *groups=None
# )
try:
    jax.vmap(model)(jnp.ones((10, 1, 28, 28)))
except ValueError:
    print("***** Not initialized *****")
# ***** Not initialized *****
# dry run to initialize the model
model(jnp.empty([3, 128, 128]))
print(model[0].__repr__())
# Conv2D(
# weight=f32[128,3,3,3],
# bias=f32[128,1,1],
# *in_features=3,
# *out_features=128,
# *kernel_size=(3,3),
# *strides=(1,1),
# *padding=((1,1),(1,1)),
# *input_dilation=(1,1),
# *kernel_dilation=(1,1),
# weight_init_func=Partial(glorot_uniform(key,shape,dtype)),
# bias_init_func=Partial(zeros(key,shape,dtype)),
# *groups=1
# )
Train MNIST
We will use tensorflow datasets for data loading. For more on the interface between jax and tensorflow datasets, see here.
# imports
import tensorflow as tf
# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type="GPU")
import tensorflow_datasets as tfds
import tensorflow.experimental.numpy as tnp
import jax
import jax.numpy as jnp
import jax.random as jr
import optax # for gradient optimization
import serket as sk
import matplotlib.pyplot as plt
import functools as ft
# Construct a tf.data.Dataset
batch_size = 128
# convert the samples from integers to floating-point numbers
# and channel first format
def preprocess_data(x):
    # convert to channel-first format
    image = tnp.moveaxis(x["image"], -1, 0)
    # normalize to [0, 1]
    image = tf.cast(image, tf.float32) / 255.0
    # one-hot encode the labels
    label = tf.one_hot(x["label"], 10) / 1.0
    return {"image": image, "label": label}
ds_train, ds_test = tfds.load("mnist", split=["train", "test"], shuffle_files=True)
# (batches, batch_size, 1, 28, 28)
ds_train = ds_train.shuffle(1024).map(preprocess_data).batch(batch_size).prefetch(tf.data.AUTOTUNE)
# (batches, 1, 28, 28)
ds_test = ds_test.map(preprocess_data).prefetch(tf.data.AUTOTUNE)
🏗️ Model definition
We will use `jax.vmap(model)` to apply `model` on batches.
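As a quick illustration of this batching pattern (a minimal sketch; `per_sample` is a hypothetical stand-in for the model defined below), `jax.vmap` maps a single-sample function over the leading batch axis:

```python
import jax
import jax.numpy as jnp

# hypothetical stand-in for the single-sample model below
def per_sample(x):  # x: (1, 28, 28), one channel-first image
    return jnp.zeros(10)  # stands in for the 10-class logits

batched = jnp.ones((128, 1, 28, 28))  # a batch of 128 images
print(jax.vmap(per_sample)(batched).shape)  # (128, 10)
```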
@sk.treeclass
class CNN:
    def __init__(self):
        self.conv1 = sk.nn.Conv2D(1, 32, (3, 3), padding="valid")
        self.relu1 = sk.nn.ReLU()
        self.pool1 = sk.nn.MaxPool2D((2, 2), strides=(2, 2))
        self.conv2 = sk.nn.Conv2D(32, 64, (3, 3), padding="valid")
        self.relu2 = sk.nn.ReLU()
        self.pool2 = sk.nn.MaxPool2D((2, 2), strides=(2, 2))
        self.flatten = sk.nn.Flatten(start_dim=0)
        self.dropout = sk.nn.Dropout(0.5)
        self.linear = sk.nn.Linear(5 * 5 * 64, 10)

    def __call__(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.dropout(x)
        x = self.linear(x)
        return x
model = CNN()
🎨 Visualize model
Model summary
print(model.summary(show_config=False, array=jnp.empty((1, 28, 28))))
┌───────┬─────────┬─────────┬──────────────┬─────────────┬─────────────┐
│Name   │Type     │Param #  │Size          │Input        │Output       │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│conv1  │Conv2D   │320(0)   │1.25KB(0.00B) │f32[1,28,28] │f32[32,26,26]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│relu1  │ReLU     │0(0)     │0.00B(0.00B)  │f32[32,26,26]│f32[32,26,26]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│pool1  │MaxPool2D│0(0)     │0.00B(0.00B)  │f32[32,26,26]│f32[32,13,13]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│conv2  │Conv2D   │18,496(0)│72.25KB(0.00B)│f32[32,13,13]│f32[64,11,11]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│relu2  │ReLU     │0(0)     │0.00B(0.00B)  │f32[64,11,11]│f32[64,11,11]│
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│pool2  │MaxPool2D│0(0)     │0.00B(0.00B)  │f32[64,11,11]│f32[64,5,5]  │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│flatten│Flatten  │0(0)     │0.00B(0.00B)  │f32[64,5,5]  │f32[1600]    │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│dropout│Dropout  │0(0)     │0.00B(0.00B)  │f32[1600]    │f32[1600]    │
├───────┼─────────┼─────────┼──────────────┼─────────────┼─────────────┤
│linear │Linear   │16,010(0)│62.54KB(0.00B)│f32[1600]    │f32[10]      │
└───────┴─────────┴─────────┴──────────────┴─────────────┴─────────────┘
Total count : 34,826(0)
Dynamic count : 34,826(0)
Frozen count : 0(0)
------------------------------------------------------------------------
Total size : 136.04KB(0.00B)
Dynamic size : 136.04KB(0.00B)
Frozen size : 0.00B(0.00B)
========================================================================
tree diagram
print(model.tree_diagram())
CNN
    ├── conv1=Conv2D
    │   ├── weight=f32[32,1,3,3]
    │   ├── bias=f32[32,1,1]
    │   ├*─ in_features=1
    │   ├*─ out_features=32
    │   ├*─ kernel_size=(3, 3)
    │   ├*─ strides=(1, 1)
    │   ├*─ padding=((0, 0), (0, 0))
    │   ├*─ input_dilation=(1, 1)
    │   ├*─ kernel_dilation=(1, 1)
    │   ├── weight_init_func=Partial(init(key,shape,dtype))
    │   ├── bias_init_func=Partial(zeros(key,shape,dtype))
    │   └*─ groups=1
    ├── relu1=ReLU
    ├*─ pool1=MaxPool2D
    │   ├*─ kernel_size=(2, 2)
    │   ├*─ strides=(2, 2)
    │   └*─ padding='valid'
    ├── conv2=Conv2D
    │   ├── weight=f32[64,32,3,3]
    │   ├── bias=f32[64,1,1]
    │   ├*─ in_features=32
    │   ├*─ out_features=64
    │   ├*─ kernel_size=(3, 3)
    │   ├*─ strides=(1, 1)
    │   ├*─ padding=((0, 0), (0, 0))
    │   ├*─ input_dilation=(1, 1)
    │   ├*─ kernel_dilation=(1, 1)
    │   ├── weight_init_func=Partial(init(key,shape,dtype))
    │   ├── bias_init_func=Partial(zeros(key,shape,dtype))
    │   └*─ groups=1
    ├── relu2=ReLU
    ├*─ pool2=MaxPool2D
    │   ├*─ kernel_size=(2, 2)
    │   ├*─ strides=(2, 2)
    │   └*─ padding='valid'
    ├*─ flatten=Flatten
    │   ├*─ start_dim=0
    │   └*─ end_dim=-1
    ├── dropout=Dropout
    │   ├*─ p=0.5
    │   └── eval=None
    └── linear=Linear
        ├── weight=f32[1600,10]
        ├── bias=f32[10]
        ├*─ in_features=1600
        └*─ out_features=10
Plot sample predictions before training
# set all dropout off
test_model = model.at[model == "eval"].set(True, is_leaf=lambda x: x is None)
def show_images_with_predictions(model, images, one_hot_labels):
    logits = jax.vmap(model)(images)
    predictions = jnp.argmax(logits, axis=-1)
    fig, axes = plt.subplots(5, 5, figsize=(10, 10))
    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i].reshape(28, 28), cmap="binary")
        ax.set(title=f"Prediction: {predictions[i]}\nLabel: {jnp.argmax(one_hot_labels[i], axis=-1)}")
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()
example = ds_test.take(25).as_numpy_iterator()
example = list(example)
sample_test_images = jnp.stack([x["image"] for x in example])
sample_test_labels = jnp.stack([x["label"] for x in example])
show_images_with_predictions(test_model, sample_test_images, sample_test_labels)
🚀 Train the model
@ft.partial(jax.value_and_grad, has_aux=True)
def loss_func(model, batched_images, batched_one_hot_labels):
    logits = jax.vmap(model)(batched_images)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=batched_one_hot_labels))
    return loss, logits
# using optax for gradient updates
optim = optax.adam(1e-3)
optim_state = optim.init(model)
@jax.jit
def batch_step(model, batched_images, batched_one_hot_labels, optim_state):
    (loss, logits), grads = loss_func(model, batched_images, batched_one_hot_labels)
    updates, optim_state = optim.update(grads, optim_state)
    model = optax.apply_updates(model, updates)
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == jnp.argmax(batched_one_hot_labels, axis=-1))
    return model, optim_state, loss, accuracy
epochs = 5

for i in range(epochs):
    epoch_accuracy = []
    epoch_loss = []

    for example in ds_train.as_numpy_iterator():
        image, label = example["image"], example["label"]
        model, optim_state, loss, accuracy = batch_step(model, image, label, optim_state)
        epoch_accuracy.append(accuracy)
        epoch_loss.append(loss)

    epoch_loss = jnp.mean(jnp.array(epoch_loss))
    epoch_accuracy = jnp.mean(jnp.array(epoch_accuracy))
    print(f"epoch:{i+1:00d}\tloss:{epoch_loss:.4f}\taccuracy:{epoch_accuracy:.4f}")
# epoch:1 loss:0.2706 accuracy:0.9268
# epoch:2 loss:0.0725 accuracy:0.9784
# epoch:3 loss:0.0533 accuracy:0.9836
# epoch:4 loss:0.0442 accuracy:0.9868
# epoch:5 loss:0.0368 accuracy:0.9889
🎨 Visualize after training
test_model = model.at[model == "eval"].set(True, is_leaf=lambda x: x is None)
show_images_with_predictions(test_model, sample_test_images, sample_test_labels)
Finite difference examples
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy.testing as npt
import serket as sk
# let's first define a vector-valued function F: R^3 -> R^3
# F = [F1, F2, F3]
# F1 = x^2 + y^3
# F2 = x^4 + y^3
# F3 = 0
# F = [x**2 + y**3, x**4 + y**3, 0]
x, y, z = [jnp.linspace(0, 1, 100)] * 3
dx, dy, dz = x[1] - x[0], y[1] - y[0], z[1] - z[0]
X, Y, Z = jnp.meshgrid(x, y, z, indexing="ij")
F1 = X**2 + Y**3
F2 = X**4 + Y**3
F3 = jnp.zeros_like(F1)
F = jnp.stack([F1, F2, F3], axis=0)
# ∂F1/∂x : differentiate F1 with respect to x (i.e. axis=0)
dF1dx = sk.fd.difference(F1, axis=0, step_size=dx, accuracy=6)
dF1dx_exact = 2 * X
npt.assert_allclose(dF1dx, dF1dx_exact, atol=1e-7)
# ∂F2/∂y : differentiate F2 with respect to y (i.e. axis=1)
dF2dy = sk.fd.difference(F2, axis=1, step_size=dy, accuracy=6)
dF2dy_exact = 3 * Y**2
npt.assert_allclose(dF2dy, dF2dy_exact, atol=1e-7)
# ∇·F : the divergence of F
divF = sk.fd.divergence(F, step_size=(dx, dy, dz), keepdims=False, accuracy=6)
divF_exact = 2 * X + 3 * Y**2
npt.assert_allclose(divF, divF_exact, atol=1e-7)
# ∇F1 : the gradient of F1
gradF1 = sk.fd.gradient(F1, step_size=(dx, dy, dz), accuracy=6)
gradF1_exact = jnp.stack([2 * X, 3 * Y**2, 0 * X], axis=0)
npt.assert_allclose(gradF1, gradF1_exact, atol=1e-7)
# ΔF1 : the laplacian of F1
lapF1 = sk.fd.laplacian(F1, step_size=(dx, dy, dz), accuracy=6)
lapF1_exact = 2 + 6 * Y
npt.assert_allclose(lapF1, lapF1_exact, atol=1e-7)
# ∇×F : the curl of F
curlF = sk.fd.curl(F, step_size=(dx, dy, dz), accuracy=6)
curlF_exact = jnp.stack([F1 * 0, F1 * 0, 4 * X**3 - 3 * Y**2], axis=0)
npt.assert_allclose(curlF, curlF_exact, atol=1e-7)
# Jacobian of F
JF = sk.fd.jacobian(F, accuracy=4, step_size=(dx, dy, dz))
JF_exact = jnp.array(
    [
        [2 * X, 3 * Y**2, jnp.zeros_like(X)],
        [4 * X**3, 3 * Y**2, jnp.zeros_like(X)],
        [jnp.zeros_like(X), jnp.zeros_like(X), jnp.zeros_like(X)],
    ]
)
npt.assert_allclose(JF, JF_exact, atol=1e-7)
# Hessian of F1
HF1 = sk.fd.hessian(F1, accuracy=4, step_size=(dx, dy, dz))
HF1_exact = jnp.array(
    [
        [
            2 * jnp.ones_like(X),  # ∂²F1/∂x²
            0 * jnp.ones_like(X),  # ∂²F1/∂x∂y
            0 * jnp.ones_like(X),  # ∂²F1/∂x∂z
        ],
        [
            0 * jnp.ones_like(X),  # ∂²F1/∂y∂x
            6 * Y,  # ∂²F1/∂y²
            0 * jnp.ones_like(X),  # ∂²F1/∂y∂z
        ],
        [
            0 * jnp.ones_like(X),  # ∂²F1/∂z∂x
            0 * jnp.ones_like(X),  # ∂²F1/∂z∂y
            0 * jnp.ones_like(X),  # ∂²F1/∂z²
        ],
    ]
)
npt.assert_allclose(HF1, HF1_exact, atol=1e-7)
PINN with Finite difference
We will try to estimate $NN(x) \approx f(x)$, where $df(x)/dx = \cos(x)$ and $df(x)/dx$ is represented with a finite difference scheme. The following code compares a finite-difference-based implementation (`fd`) with an automatic-differentiation-based implementation (`ad`).
import copy
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
import serket as sk
x = jnp.linspace(-jnp.pi, jnp.pi, 1000)[:, None]
y = jnp.sin(x)
dx = x[1] - x[0]
dydx = jnp.cos(x)
NN_fd = sk.nn.Sequential(
    [
        sk.nn.Linear(1, 128),
        sk.nn.ReLU(),
        sk.nn.Linear(128, 128),
        sk.nn.ReLU(),
        sk.nn.Linear(128, 1),
    ]
)
NN_ad = copy.copy(NN_fd)
optim = optax.adam(1e-3)
@jax.value_and_grad
def loss_func_fd(NN, x):
    y = NN(x)
    dydx = sk.fd.difference(y, axis=0, accuracy=5, step_size=dx)
    loss = jnp.mean((dydx - jnp.cos(x)) ** 2)
    loss += jnp.mean((NN(jnp.zeros_like(x))) ** 2)  # initial condition
    return loss

@jax.value_and_grad
def loss_func_ad(NN, x):
    loss = jnp.mean((sk.diff(NN)(x) - jnp.cos(x)) ** 2)
    loss += jnp.mean(NN(jnp.zeros_like(x)) ** 2)  # initial condition
    return loss

@jax.jit
def step_fd(NN, x, optim_state):
    loss, grads = loss_func_fd(NN, x)
    updates, optim_state = optim.update(grads, optim_state)
    NN = optax.apply_updates(NN, updates)
    return NN, optim_state, loss

def train_fd(NN_fd, optim_state_fd, epochs):
    for i in range(1, epochs + 1):
        NN_fd, optim_state_fd, loss_fd = step_fd(NN_fd, x, optim_state_fd)
    return NN_fd, optim_state_fd, loss_fd

@jax.jit
def step_ad(NN, x, optim_state):
    loss, grads = loss_func_ad(NN, x)
    updates, optim_state = optim.update(grads, optim_state)
    NN = optax.apply_updates(NN, updates)
    return NN, optim_state, loss

def train_ad(NN_ad, optim_state_ad, epochs):
    for i in range(1, epochs + 1):
        NN_ad, optim_state_ad, loss_ad = step_ad(NN_ad, x, optim_state_ad)
    return NN_ad, optim_state_ad, loss_ad
epochs = 1000
optim_state_fd = optim.init(NN_fd)
optim_state_ad = optim.init(NN_ad)
NN_fd, optim_state_fd, loss_fd = train_fd(NN_fd, optim_state_fd, epochs)
NN_ad, optim_state_ad, loss_ad = train_ad(NN_ad, optim_state_ad, epochs)
print(f"Loss_fd {loss_fd:.4f} \nLoss_ad {loss_ad:.4f}")
y_fd = NN_fd(x)
y_ad = NN_ad(x)
plt.plot(x, y, "--k", label="true")
plt.plot(x, y_fd, label="fd pred")
plt.plot(x, y_ad, label="ad pred")
plt.legend()
# Loss_fd 0.0012
# Loss_ad 0.0235
Reconstructing a vector field F using the conditions ∇·F = 0 and ∇×F = 2k
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
import serket as sk
x, y = [jnp.linspace(-1, 1, 50)] * 2
dx, dy = [x[1] - x[0]] * 2
X, Y = jnp.meshgrid(x, y, indexing="ij")
F1 = -Y
F2 = +X
F = jnp.stack([F1, F2], axis=0)
NN = sk.nn.Sequential(
    [
        sk.nn.Conv2D(2, 32, kernel_size=3, padding="same"),
        sk.nn.ReLU(),
        sk.nn.Conv2D(32, 32, kernel_size=3, padding="same"),
        sk.nn.ReLU(),
        sk.nn.Conv2D(32, 2, kernel_size=3, padding="same"),
    ]
)
optim = optax.adam(1e-3)
@jax.value_and_grad
def loss_func(NN, F):
    F_pred = NN(F)
    div = sk.fd.divergence(F_pred, accuracy=5, step_size=(dx, dy))
    loss = jnp.mean(div**2)  # divergence-free condition
    curl = sk.fd.curl(F_pred, accuracy=2, step_size=(dx, dy))
    loss += jnp.mean((curl - jnp.ones_like(curl) * 2) ** 2)  # curl condition
    return loss

@jax.jit
def step(NN, F, optim_state):
    loss, grads = loss_func(NN, F)
    updates, optim_state = optim.update(grads, optim_state)
    NN = optax.apply_updates(NN, updates)
    return NN, optim_state, loss

def train(NN, Z, optim_state, epochs):
    for i in range(1, epochs + 1):
        NN, optim_state, loss = step(NN, Z, optim_state)
    return NN, optim_state, loss
Z = jnp.stack([X, Y], axis=0) # collocation points
optim_state = optim.init(NN) # initialise optimiser
epochs = 1_000
NN, _, loss = train(NN, Z, optim_state, epochs)
Fpred = NN(Z) # predicted field
plt.figure(figsize=(10, 10))
plt.quiver(X, Y, Fpred[0], Fpred[1], color="r", label="pred")
plt.quiver(X, Y, F1, F2, color="k", alpha=0.5, label="true")
plt.legend()
Vectorized differentiable stencil computation with `serket.kmap`
Serket uses the `kernex.kmap` decorator, which applies a user-defined stencil kernel over an array. `kmap` uses `jax.vmap` as its backend to vectorize the operation, which means the decorator is transparent to `jax` transformations.
Example
import jax.numpy as jnp
import serket as sk

@sk.kmap(
    # kernel size applied to the 2D input, here 3x3
    kernel_size=(3, 3),
    # strides of 1 along each dimension
    strides=(1, 1),
    # padding can be any of the following options:
    # 1) a single integer for each dimension -> e.g. (1,) pads one zero before and after axis=0
    # 2) a tuple of two integers for each dimension -> e.g. ((1,2),) pads one zero on the left and two zeros on the right of axis=0
    # 3) "same"/"valid"
    # 4) a "same"/"valid" tuple for each dimension -> e.g. ("same",) same padding for axis=0
    padding="valid",
    # relative determines whether indexing is row/col-wise or center-wise.
    # for example, for a 3x3 kernel over values 1..9, x[0, 0] yields
    # 1 if relative=False
    # 5 if relative=True (i.e. the center value)
    relative=True,
)
def avg_blur(x):
    return (x[-1, -1] + x[-1, 0] + x[-1, 1] +
            x[ 0, -1] + x[ 0, 0] + x[ 0, 1] +
            x[ 1, -1] + x[ 1, 0] + x[ 1, 1]) // 9

avg_blur(jnp.arange(1, 26).reshape(5, 5))
# [[ 7 8 9]
# [12 13 14]
# [17 18 19]]
Scan a stencil kernel to solve linear convection using `serket.kscan`
$\Large {\partial u \over \partial t} + c {\partial u \over \partial x} = 0$
$\Large u_i^{n} = u_i^{n-1} - c \frac{\Delta t}{\Delta x}(u_i^{n-1}-u_{i-1}^{n-1})$
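For reference, the update rule above follows from discretizing the PDE with a forward difference in time and a backward difference in space, then solving for $u_i^{n}$:

$\frac{u_i^{n} - u_i^{n-1}}{\Delta t} + c \, \frac{u_i^{n-1} - u_{i-1}^{n-1}}{\Delta x} = 0 \quad \Rightarrow \quad u_i^{n} = u_i^{n-1} - c \frac{\Delta t}{\Delta x} \left( u_i^{n-1} - u_{i-1}^{n-1} \right)$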
*(figures: problem setup and stencil view)*
By using `serket.kscan`, the stencil kernel can be scanned over the array while carrying state along, similar to how an RNN works. This enables backpropagation through time (BPTT), which is useful for some problems (e.g. time-dependent PDEs).
import jax
import jax.numpy as jnp
import serket as sk
import matplotlib.pyplot as plt
# see https://nbviewer.org/github/barbagroup/CFDPython/blob/master/lessons/01_Step_1.ipynb
tmax, xmax = 0.5, 2.0
nt, nx = 151, 51
dt, dx = tmax / (nt - 1), xmax / (nx - 1)
u = jnp.ones([nt, nx])
c = 0.5
# kscan moves sequentially in row-major order and updates in-place using lax.scan.
F = sk.kscan(
    kernel_size=(3, 3),
    padding=((1, 1), (1, 1)),
    named_axis={0: 'n', 1: 'i'},  # n for the time axis, i for the spatial axis (optional naming)
    relative=True,
)
# boundary condition as a function
def bc(u):
    return 1

# initial conditions as functions
def ic1(u):
    return 1

def ic2(u):
    return 2
def linear_convection(u):
    return u['i', 'n-1'] - (c * dt / dx) * (u['i', 'n-1'] - u['i-1', 'n-1'])
F[:,0] = F[:,-1] = bc # assign 1 for left and right boundary for all t
# square wave initial condition
F[:,:int((nx-1)/4)+1] = F[:,int((nx-1)/2):] = ic1
F[0:1, int((nx-1)/4)+1 : int((nx-1)/2)] = ic2
# assign linear convection function for
# interior spatial location [1:-1]
# and start from t>0 [1:]
F[1:,1:-1] = linear_convection
kx_solution = F(jnp.array(u))
plt.figure(figsize=(20, 7))
for line in kx_solution[::20]:
    plt.plot(jnp.linspace(0, xmax, nx), line)
🥶 Freezing parameters / Fine-tuning
✨See here for more about freezing✨
Filtering by masking
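As a minimal sketch of filtering by masking, reusing the `.at[...]` comparison pattern from the MNIST example above (the name-based mask here is an assumption extrapolated from that same usage, not a documented guarantee):

```python
import serket as sk

layer = sk.nn.Linear(4, 5)

# hypothetical usage, mirroring `model.at[model == "eval"].set(...)`
# from the MNIST example: select leaves by field name and update them
layer = layer.at[layer == "bias"].set(0.0)
print(layer.bias)  # the bias leaf replaced with 0.0
```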