JAX NN library.
Project description
The โจMagicalโจ JAX NN Library.
*Serket is the goddess of magic in Egyptian mythology
๐ ๏ธ Installation
pip install serket
๐ Description
serket
aims to be the most intuitive and easy-to-use Neural network library in JAX.serket
is built on top ofpytreeclass
serket
currently implementsLinear
,FNN
Dropout
Sequential
Lambda
โฉ Quick Example
Simple Fully connected neural network.
Model definition
import serket as sk
import jax.numpy as jnp
import jax.random as jr
@sk.treeclass
class NN:
def __init__(
self,
in_features:int,
out_features:int,
hidden_features: int, key:jr.PRNGKey = jr.PRNGKey(0)):
k1,k2,k3 = jr.split(key, 3)
self.l1 = sk.nn.Linear(in_features, hidden_features, key=k1)
self.l2 = sk.nn.Linear(hidden_features, hidden_features, key=k2)
self.l3 = sk.nn.Linear(hidden_features, out_features, key=k3)
def __call__(self, x):
x = self.l1(x)
x = jax.nn.relu(x)
x = self.l2(x)
x = jax.nn.relu(x)
x = self.l3(x)
return x
model = NN(
in_features=1,
out_features=1,
hidden_features=128,
key=jr.PRNGKey(0))
# `*` represents untrainable(static) nodes.
print(model.tree_diagram())
NN
โโโ l1=Linear
โ โโโ weight=f32[1,128]
โ โโโ bias=f32[128]
โ โ*โ in_features=1
โ โ*โ out_features=128
โ โ*โ weight_init_func=init(key,shape,dtype)
โ โ*โ bias_init_func=Lambda(key,shape)
โโโ l2=Linear
โ โโโ weight=f32[128,128]
โ โโโ bias=f32[128]
โ โ*โ in_features=128
โ โ*โ out_features=128
โ โ*โ weight_init_func=init(key,shape,dtype)
โ โ*โ bias_init_func=Lambda(key,shape)
โโโ l3=Linear
โโโ weight=f32[128,1]
โโโ bias=f32[1]
โ*โ in_features=128
โ*โ out_features=1
โ*โ weight_init_func=init(key,shape,dtype)
โ*โ bias_init_func=Lambda(key,shape)
>>> print(model.summary())
โโโโโโฌโโโโโโโฌโโโโโโโโโโฌโโโโโโโโฌโโโโโโโโโโโโโโโโโโโโ
โNameโType โParam # โSize โConfig โ
โโโโโโผโโโโโโโผโโโโโโโโโโผโโโโโโโโผโโโโโโโโโโโโโโโโโโโโค
โl1 โLinearโ256(0) โ1.00KB โweight=f32[1,128] โ
โ โ โ โ(0.00B)โbias=f32[128] โ
โโโโโโผโโโโโโโผโโโโโโโโโโผโโโโโโโโผโโโโโโโโโโโโโโโโโโโโค
โl2 โLinearโ16,512(0)โ64.50KBโweight=f32[128,128]โ
โ โ โ โ(0.00B)โbias=f32[128] โ
โโโโโโผโโโโโโโผโโโโโโโโโโผโโโโโโโโผโโโโโโโโโโโโโโโโโโโโค
โl3 โLinearโ129(0) โ516.00Bโweight=f32[128,1] โ
โ โ โ โ(0.00B)โbias=f32[1] โ
โโโโโโดโโโโโโโดโโโโโโโโโโดโโโโโโโโดโโโโโโโโโโโโโโโโโโโโ
Total count : 16,897(0)
Dynamic count : 16,897(0)
Frozen count : 0(0)
---------------------------------------------------
Total size : 66.00KB(0.00B)
Dynamic size : 66.00KB(0.00B)
Frozen size : 0.00B(0.00B)
===================================================
Train
x = jnp.linspace(0,1,100)[:,None]
y = x**3 + jax.random.uniform(jax.random.PRNGKey(0),(100,1))*0.01
@jax.value_and_grad
def loss_func(model,x,y):
return jnp.mean((model(x)-y)**2)
@jax.jit
def update(model,x,y):
value,grad = loss_func(model,x,y)
return value , model - 1e-3*grad
for _ in range(20_000):
value,model = update(model,x,y)
Filter
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
serket-0.0.1.tar.gz
(6.7 kB
view hashes)
Built Distribution
serket-0.0.1-py3-none-any.whl
(7.6 kB
view hashes)