A JAX-based neural network library
Project description
Jaxtorch (a JAX nn library)
This is my JAX-based neural network library. I created it because I was annoyed by the complexity and 'magic'-ness of the popular JAX frameworks (Flax, Haiku).
The objective is to enable PyTorch-like model definition and training with a minimum of magic. A simple example:
```python
import jax
import jax.numpy as jnp
import jaxlib
import jaxtorch

# Modules are just classes that inherit from jaxtorch.Module
class Linear(jaxtorch.Module):
    # They can accept any constructor parameters
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        # Parameters are represented by a Param type, which identifies
        # them and specifies how to initialize them.
        self.weight = jaxtorch.init.glorot_normal(out_features, in_features)
        assert type(self.weight) is jaxtorch.Param
        if bias:
            self.bias = jaxtorch.init.zeros(out_features)
        else:
            self.bias = None

    # The forward function always accepts cx, a Context object, as the
    # first argument. It provides random number generation as well as
    # the parameters.
    def forward(self, cx: jaxtorch.Context, x):
        # Parameters are looked up in the context using the stored identifier.
        y = x @ jnp.transpose(cx[self.weight])
        if self.bias is not None:
            y = y + cx[self.bias]
        return y

model = Linear(3, 3)

# You initialize the weights by passing an RNG key.
# Calling init_weights also names all the parameters in the Module tree.
params = model.init_weights(jax.random.PRNGKey(0))

# Parameters are stored in a dictionary by name.
assert type(params) is dict
assert type(params[model.weight.name]) is jaxlib.xla_extension.DeviceArray
assert model.weight.name == 'weight'

def loss(params, key):
    cx = jaxtorch.Context(params, key)
    x = jnp.array([1.0, 2.0, 3.0])
    y = jnp.array([4.0, 5.0, 6.0])
    return jnp.mean((model(cx, x) - y)**2)

f_grad = jax.value_and_grad(loss)

for _ in range(100):
    loss_val, grad = f_grad(params, jax.random.PRNGKey(0))
    params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grad)

print(loss_val)
# 4.7440533e-08
```
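The update line applies plain SGD leaf by leaf across the parameter dictionary. As a minimal pure-Python sketch of what `jax.tree_util.tree_map` does here when the pytree is a flat dict (a stand-in for illustration only, not JAX's actual implementation, which handles arbitrary nested pytrees):

```python
def tree_map(f, *trees):
    # Apply f to corresponding leaves of dicts sharing the same keys,
    # mirroring how jax.tree_util.tree_map pairs up pytree leaves.
    return {k: f(*(t[k] for t in trees)) for t in trees[:1] for k in t}

params = {"weight": 1.0, "bias": -2.0}
grads = {"weight": 0.5, "bias": -1.0}

# One SGD step with learning rate 0.01, applied leaf by leaf.
params = tree_map(lambda p, g: p - 0.01 * g, params, grads)
print(params)  # ≈ {'weight': 0.995, 'bias': -1.99}
```

Because gradients come back in the same dict shape as the parameters, the whole optimizer step stays a one-liner with no parameter registration machinery.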
Project details
Release history
Download files
Download the file for your platform.
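Rather than fetching files by hand, the package can presumably be installed from PyPI with pip (assuming it is published under the name `jaxtorch`, as the distribution filenames below suggest):

```shell
# Install the jaxtorch package from PyPI
pip install jaxtorch
```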
Source Distribution
jaxtorch-0.1.0.tar.gz (13.1 kB)
Built Distribution
jaxtorch-0.1.0-py3-none-any.whl (13.5 kB)
File details
Details for the file jaxtorch-0.1.0.tar.gz.
File metadata
- Download URL: jaxtorch-0.1.0.tar.gz
- Upload date:
- Size: 13.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/34.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.8 tqdm/4.62.3 importlib-metadata/4.11.3 keyring/23.5.0 rfc3986/2.0.0 colorama/0.4.4 CPython/3.9.9
File hashes

Algorithm | Hash digest
---|---
SHA256 | e7fc5f6ad94254ff6b72be2122b1df54e399350f2b38b20867f1db5135943dd0
MD5 | 6c67aeb230ad710dedb05e991f8d84a5
BLAKE2b-256 | cf418cbcd31b6e99d9cb36f6126704f5d322cf6b1a2558637f65794f7d490529
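The digests above can be used to verify a downloaded archive. A generic sketch using Python's standard `hashlib` (nothing jaxtorch-specific; the filename in the comment is just the tarball listed above):

```python
import hashlib

def sha256_of(path, chunk_size=8192):
    """Compute the SHA256 hex digest of a file, streaming it in chunks
    so large archives are not read into memory at once."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# Compare against the SHA256 digest listed above, e.g.:
# sha256_of("jaxtorch-0.1.0.tar.gz") should equal
# "e7fc5f6ad94254ff6b72be2122b1df54e399350f2b38b20867f1db5135943dd0"
```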
File details
Details for the file jaxtorch-0.1.0-py3-none-any.whl.
File metadata
- Download URL: jaxtorch-0.1.0-py3-none-any.whl
- Upload date:
- Size: 13.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/34.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.8 tqdm/4.62.3 importlib-metadata/4.11.3 keyring/23.5.0 rfc3986/2.0.0 colorama/0.4.4 CPython/3.9.9
File hashes

Algorithm | Hash digest
---|---
SHA256 | cd84d38421d2930096ab2d814b1e428b6449161585e225cfdc106b7b75dfae59
MD5 | a79854826df94feff709c251b343add6
BLAKE2b-256 | 8265e82fad3ba761f5827bde9ebbbacf9f9dc5a5c23b6fb5cc277ef6e5e4d51f