Differentiable Stencil computations in JAX
Installation | Description | Examples | Benchmarking
Installation
pip install pytreeclass kernex
Description
Kernex extends jax.vmap and jax.lax.scan with kmap and kscan for general stencil computations.
Examples
import jax
import jax.numpy as jnp
import kernex as kex
from pytreeclass import treeclass, static_field
import numpy as np
import matplotlib.pyplot as plt
kmap
Convolution operation
# JAX channel-first conv2d operation
@jax.jit
@kex.kmap(
    kernel_size=(3, 3, 3),
    padding=('valid', 'same', 'same'))
def kernex_conv2d(x, w):
    return jnp.sum(x * w)
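A short usage sketch (shapes chosen here purely for illustration): the channel axis collapses to 1 because that axis uses 'valid' padding with a kernel spanning all channels.

# usage sketch: CHW input, single 3x3 filter over 3 channels (illustrative shapes)
x = jax.random.normal(jax.random.PRNGKey(0), (3, 32, 32))
w = jax.random.normal(jax.random.PRNGKey(1), (3, 3, 3))
print(kernex_conv2d(x, w).shape)  # (1, 32, 32): channel axis reduced by 'valid' padding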
Laplacian operation
# see also
# https://numba.pydata.org/numba-doc/latest/user/stencil.html#basic-usage
@kex.kmap(
    kernel_size=(3, 3),
    padding='valid',
    relative=True)  # `relative=True` enables relative indexing
def laplacian(x):
    return (0 * x[1, -1] + 1 * x[1, 0] + 0 * x[1, 1] +
            1 * x[0, -1] + -4 * x[0, 0] + 1 * x[0, 1] +
            0 * x[-1, -1] + 1 * x[-1, 0] + 0 * x[-1, 1])
# apply laplacian
>>> print(laplacian(jnp.ones([10,10])))
DeviceArray(
[[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)
Get patches of an array
@kex.kmap(kernel_size=(3, 3), relative=True)
def identity(x):
    # similar to numba.stencil:
    # returns the top-left cell of the padded/unpadded kernel view,
    # or the center cell when `relative=True`
    return x[0, 0]
# unlike numba.stencil, vector output is allowed in kernex
# this function is similar to
# `jax.lax.conv_general_dilated_patches(x,(3,),(1,),padding='same')`
@jax.jit
@kex.kmap(kernel_size=(3, 3), padding='same')
def get_3x3_patches(x):
    # returns a 5x5x3x3 array
    return x
mat = jnp.arange(1,26).reshape(5,5)
>>> print(mat)
[[ 1 2 3 4 5]
[ 6 7 8 9 10]
[11 12 13 14 15]
[16 17 18 19 20]
[21 22 23 24 25]]
# get the view at array index = (0,0)
>>> print(get_3x3_patches(mat)[0,0])
[[0 0 0]
[0 1 2]
[0 6 7]]
Moving average
@kex.kmap(kernel_size=(3,))
def moving_average(x):
return jnp.mean(x)
>>> moving_average(jnp.array([1,2,3,7,9]))
DeviceArray([2. , 4. , 6.3333335], dtype=float32)
Apply stencil operations by index
To achieve the following operation with `jax.lax.switch`, we would need a list of 10 functions, one for each cell of the example array. Kernex instead adopts a modified version of `jax.lax.switch` so that the number of branches equals the number of unique functions assigned; a naive `jax.lax.switch` sketch is shown after the example for comparison.

F = kex.kmap(kernel_size=(1,))
'''
Apply f(x) = x^2 at index 0 and f(x) = x^3 at indices 1..9

   x     = [ 1,  2,  3,  4,   5,   6,   7,   8,   9,   10]
  f(x)   = [ 1,  8, 27, 64, 125, 216, 343, 512, 729, 1000]

 df/dx   =   2x at index 0 and 3x^2 at indices 1..9
df/dx(x) = [ 2, 12, 27, 48,  75, 108, 147, 192, 243,  300]
'''
array = jnp.arange(1,11).astype('float32')
# use a modified version of lax.switch to switch between functions
# assign the function at each index
F[0] = lambda x: x[0]**2
F[1:] = lambda x: x[0]**3

>>> print(F(array))
[ 1. 8. 27. 64. 125. 216. 343. 512. 729. 1000.]

dFdx = jax.grad(lambda x: jnp.sum(F(x)))

>>> print(dFdx(array))
[ 2. 12. 27. 48. 75. 108. 147. 192. 243. 300.]
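For comparison, here is a minimal, illustrative sketch (not part of kernex) of the naive `jax.lax.switch` approach, which needs one branch per cell rather than one per unique function:

# naive approach: one branch per cell, dispatched with jax.lax.switch under vmap
branches = [lambda x: x**2] + [lambda x: x**3] * 9  # 10 branches, one per index
naive_F = jax.vmap(lambda i, x: jax.lax.switch(i, branches, x))
print(naive_F(jnp.arange(10), jnp.arange(1, 11).astype('float32')))
# same values as F(array) above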
kscan
Linear convection
$\Large {\partial u \over \partial t} + c {\partial u \over \partial x} = 0$
$\Large u_i^{n} = u_i^{n-1} - c \frac{\Delta t}{\Delta x}(u_i^{n-1}-u_{i-1}^{n-1})$
# see https://nbviewer.org/github/barbagroup/CFDPython/blob/master/lessons/01_Step_1.ipynb
tmax,xmax = 0.5,2.0
nt,nx = 151,51
dt,dx = tmax/(nt-1) , xmax/(nx-1)
u = np.ones([nt,nx])
c = 0.5
# kscan moves sequentially in row-major order and updates in place using lax.scan
# (see the short kmap/kscan contrast after this example)
F = kex.kscan(
    kernel_size=(3, 3),
    padding=((1, 1), (1, 1)),
    named_axis={0: 'n', 1: 'i'},  # n for the time axis, i for the spatial axis (optional naming)
    relative=True)
# boundary condition as a function
def bc(u):
    return 1

# initial condition as a function
def ic1(u):
    return 1

def ic2(u):
    return 2

def linear_convection(u):
    return (u['i', 'n-1'] -
            (c * dt / dx) * (u['i', 'n-1'] - u['i-1', 'n-1']))
F[:,0] = F[:,-1] = bc # assign 1 for left and right boundary for all t
# square wave initial condition
F[:,:int((nx-1)/4)+1] = F[:,int((nx-1)/2):] = ic1
F[0:1, int((nx-1)/4)+1 : int((nx-1)/2)] = ic2
# assign linear convection function for
# interior spatial location [1:-1]
# and start from t>0 [1:]
F[1:,1:-1] = linear_convection
kx_solution = F(jnp.array(u))
plt.figure(figsize=(20,7))
for line in kx_solution[::20]:
    plt.plot(jnp.linspace(0, xmax, nx), line)
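To make the sequential nature of kscan concrete, here is a minimal sketch (illustrative, not part of the convection example) contrasting kmap and kscan with the same relative-indexing function: kmap applies the function to every window of the original array in parallel, while kscan sweeps the windows in row-major order so each step sees values already updated by previous steps.

@kex.kmap(kernel_size=(3,), padding='same', relative=True)
def kmap_sum(x):
    # every window reads the *original* array (vmap-like)
    return x[-1] + x[0] + x[1]

@kex.kscan(kernel_size=(3,), padding='same', relative=True)
def kscan_sum(x):
    # each window reads values already updated at earlier indices (scan-like)
    return x[-1] + x[0] + x[1]

u0 = jnp.array([1., 2., 3., 4., 5.])
print(kmap_sum(u0))   # parallel window sums
print(kscan_sum(u0))  # sequential, carried window sums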
kmap + pytreeclass = PyTorch-like layers
MaxPool2D layer
@treeclass
class MaxPool2D:
    kernel_size: tuple[int, ...] | int = static_field()
    strides: tuple[int, ...] | int = static_field()
    padding: tuple[int, ...] | int | str = static_field()

    def __init__(self, *, kernel_size=(2, 2), strides=2, padding="valid"):
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding

    def __call__(self, x):

        @jax.vmap  # apply on batch dimension
        @jax.vmap  # apply on channel dimension
        @kex.kmap(
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding=self.padding)
        def _maxpool2d(x):
            return jnp.max(x)

        return _maxpool2d(x)
layer = MaxPool2D(kernel_size=(2,2),strides=(2,2),padding='same')
array = jnp.arange(1,26).reshape(1,1,5,5) # batch,channel,row,col
>>> print(array)
[[[[ 1 2 3 4 5]
[ 6 7 8 9 10]
[11 12 13 14 15]
[16 17 18 19 20]
[21 22 23 24 25]]]]
>>> print(layer(array))
[[[[ 7 9 10]
[17 19 20]
[22 24 25]]]]
AverageBlur2D layer
import os
from PIL import Image
@treeclass
class AverageBlurLayer:
    '''channels-first average blur'''
    in_channels: int
    kernel_size: tuple[int, ...]

    def __init__(self, in_channels, kernel_size):
        self.in_channels = in_channels
        self.kernel_size = kernel_size

    def __call__(self, x):

        @jax.vmap  # vectorize on batch dimension
        @jax.vmap  # vectorize on channel dimension
        @kex.kmap(kernel_size=(*self.kernel_size,), padding='same')
        def average_blur(x):
            kernel = jnp.ones([*self.kernel_size]) / jnp.array(self.kernel_size).prod()
            return jnp.sum(x * kernel, dtype=jnp.float32)

        return average_blur(x).astype(jnp.uint8)
img = Image.open(os.path.join('assets','puppy.png'))
>>> img
batch_img = jnp.einsum('HWC->CHW', jnp.array(img))[None]  # make it channel-first and add a batch dim
layer = jax.jit(AverageBlurLayer(in_channels=4,kernel_size=(25,25)))
blurred_image = layer(batch_img)
blurred_image = jnp.einsum('CHW->HWC' ,blurred_image[0])
plt.figure(figsize=(20,20))
plt.imshow(blurred_image)
Conv2D layer
@treeclass
class Conv2D:
    weight: jnp.ndarray
    bias: jnp.ndarray
    in_channels: int = static_field()
    out_channels: int = static_field()
    kernel_size: tuple[int, ...] | int = static_field()
    strides: tuple[int, ...] | int = static_field()
    padding: tuple[int, ...] | int | str = static_field()

    def __init__(self,
                 *,
                 in_channels,
                 out_channels,
                 kernel_size,
                 strides=1,
                 padding=("same", "same"),
                 key=jax.random.PRNGKey(0),
                 use_bias=True,
                 kernel_initializer=jax.nn.initializers.kaiming_uniform()):
        self.weight = kernel_initializer(
            key, (out_channels, in_channels, *kernel_size))
        self.bias = (jnp.zeros(
            (out_channels, *((1,) * len(kernel_size)))) if use_bias else None)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = ("valid",) + padding

    def __call__(self, x):

        @kex.kmap(
            kernel_size=(self.in_channels, *self.kernel_size),
            strides=self.strides,
            padding=self.padding)
        def _conv2d(x, w):
            return jnp.sum(x * w)

        @jax.vmap  # vectorize on batch dimension
        def fwd_image(image):
            # filters shape is OIHW; vectorize over the output-channel dimension
            return jax.vmap(lambda w: _conv2d(image, w))(self.weight)[:, 0] + (
                self.bias if self.bias is not None else 0)

        return fwd_image(x)
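A short usage sketch for the layer above (shapes chosen for illustration, mirroring the MaxPool2D example):

layer = Conv2D(in_channels=3, out_channels=8, kernel_size=(3, 3))
x = jax.random.normal(jax.random.PRNGKey(1), (1, 3, 32, 32))  # batch, channel, row, col
y = layer(x)
print(y.shape)  # (1, 8, 32, 32): 'same' spatial padding keeps height and width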
Benchmarking
Conv2D
# testing and benchmarking convolution
# for complete benchmarking check /tests_and_benchmark
# 3x1024x1024 Input
C,H = 3,1024
@jax.jit
def jax_conv2d(x,w):
return jax.lax.conv_general_dilated(
lhs = x,
rhs = w,
window_strides = (1,1),
padding = 'SAME',
dimension_numbers = ('NCHW', 'OIHW', 'NCHW'),)[0]
x = jax.random.normal(jax.random.PRNGKey(0),(C,H,H))
xx = x[None]
w = jax.random.normal(jax.random.PRNGKey(0),(C,3,3))
ww = w[None]
# assert equal
np.testing.assert_allclose(kernex_conv2d(x,w),jax_conv2d(xx,ww),atol=1e-3)
# Mac M1 CPU
# check tests_and_benchmark folder for more.
%timeit kernex_conv2d(x,w).block_until_ready()
# 3.96 ms ± 272 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit jax_conv2d(xx,ww).block_until_ready()
# 27.5 ms ± 993 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
get_patches
# benchmarking `get_patches` with `jax.lax.conv_general_dilated_patches`
# On Mac M1 CPU
@jax.jit
@kex.kmap(kernel_size=(3,),padding='same')
def get_patches(x):
return x
@jax.jit
def jax_get_patches(x):
return jax.lax.conv_general_dilated_patches(x,(3,),(1,),padding='same')
x = jnp.ones([1_000_000])
xx = jnp.ones([1,1,1_000_000])
np.testing.assert_allclose(
get_patches(x),
jax_get_patches(xx).reshape(-1,1_000_000).T)
%timeit get_patches(x).block_until_ready()
# 1.73 ms ± 92.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
%timeit jax_get_patches(xx).block_until_ready()
# 10.6 ms ± 337 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)