# Jaxtorch (a jax nn library)
This is my JAX-based neural network library. I created it because I was frustrated by the complexity and 'magic'-ness of the popular JAX frameworks (flax, haiku).

The objective is to enable PyTorch-like model definition and training with a minimum of magic. A simple example:
```python
import jax
import jax.numpy as jnp
import jaxlib
import jaxtorch

# Modules are just classes that inherit jaxtorch.Module
class Linear(jaxtorch.Module):
    # They can accept any constructor parameters
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        # Parameters are represented by a Param type, which identifies
        # them and specifies how to initialize them.
        self.weight = jaxtorch.init.glorot_normal(out_features, in_features)
        assert type(self.weight) is jaxtorch.Param
        if bias:
            self.bias = jaxtorch.init.zeros(out_features)
        else:
            self.bias = None

    # The forward function always accepts cx, a Context object, as the
    # first argument. It provides random number generation as well as
    # access to the parameters.
    def forward(self, cx: jaxtorch.Context, x):
        # Parameters are looked up in the context using the stored identifier.
        y = x @ jnp.transpose(cx[self.weight])
        if self.bias:
            y = y + cx[self.bias]
        return y

model = Linear(3, 3)

# You initialize the weights by passing an RNG key.
# Calling init_weights also names all the parameters in the Module tree.
params = model.init_weights(jax.random.PRNGKey(0))

# Parameters are stored in a dictionary by name.
assert type(params) is dict
assert type(params[model.weight.name]) is jaxlib.xla_extension.DeviceArray
assert model.weight.name == 'weight'

def loss(params, key):
    cx = jaxtorch.Context(params, key)
    x = jnp.array([1.0, 2.0, 3.0])
    y = jnp.array([4.0, 5.0, 6.0])
    return jnp.mean((model(cx, x) - y) ** 2)

f_grad = jax.value_and_grad(loss)

for _ in range(100):
    (loss, grad) = f_grad(params, jax.random.PRNGKey(0))
    params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grad)
print(loss)
# 4.7440533e-08
```
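The core mechanism in the example — parameters as identifier objects that a `Context` resolves to concrete values by name — can be sketched in plain Python without JAX. The class bodies below are illustrative assumptions, not jaxtorch's actual implementation; only the names `Param` and `Context` are taken from the example above.

```python
class Param:
    """Identifies a parameter and records how to initialize it."""
    def __init__(self, shape, init):
        self.shape = shape
        self.init = init
        self.name = None  # filled in when the module tree is initialized

class Context:
    """Resolves Param identifiers to concrete values via a name -> value dict."""
    def __init__(self, params):
        self.params = params
    def __getitem__(self, par):
        return self.params[par.name]

# Usage: name the parameter, build its value, then look it up through
# the context, exactly as cx[self.weight] does in the example.
w = Param((3,), init=lambda shape: [0.0] * shape[0])
w.name = 'weight'
params = {w.name: w.init(w.shape)}
cx = Context(params)
print(cx[w])  # → [0.0, 0.0, 0.0]
```

Because the value store is just a dict keyed by name, the parameters remain an ordinary pytree that standard JAX transformations (`jax.grad`, `jax.tree_util.tree_map`) can operate on, as the training loop above shows.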
## Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
**Source Distribution**

- jaxtorch-0.2.0.tar.gz (13.6 kB)

**Built Distribution**

- jaxtorch-0.2.0-py3-none-any.whl (13.8 kB)
## File details

### Details for the file jaxtorch-0.2.0.tar.gz

**File metadata**
- Download URL: jaxtorch-0.2.0.tar.gz
- Upload date:
- Size: 13.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/34.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.8 tqdm/4.62.3 importlib-metadata/4.11.3 keyring/23.5.0 rfc3986/2.0.0 colorama/0.4.4 CPython/3.9.9
**File hashes**

| Algorithm | Hash digest |
|---|---|
| SHA256 | 0fb2b30c2580620e4b1fbc7424738ad8f96373fcd5a4ba44823ecb0e263f3f24 |
| MD5 | 5673173493714fb977236c964a4f3eeb |
| BLAKE2b-256 | 8fae4c1306051f8191d52cf185167e8f9a95c9d07ab34b678088762c1428ddbf |
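A downloaded file can be checked against these digests locally. A sketch using Python's standard `hashlib` (here BLAKE2b-256 means BLAKE2b with a 32-byte digest); the commented comparison reuses the SHA256 value from the table above:

```python
import hashlib

def file_digests(path):
    """Compute the SHA256, MD5, and BLAKE2b-256 hex digests of a file."""
    sha256 = hashlib.sha256()
    md5 = hashlib.md5()
    blake2b = hashlib.blake2b(digest_size=32)  # BLAKE2b-256 = 32-byte digest
    with open(path, 'rb') as f:
        # Read in chunks so large files are not loaded into memory at once.
        for chunk in iter(lambda: f.read(8192), b''):
            sha256.update(chunk)
            md5.update(chunk)
            blake2b.update(chunk)
    return sha256.hexdigest(), md5.hexdigest(), blake2b.hexdigest()

# Compare each digest to the corresponding table entry, e.g.:
# sha, _, _ = file_digests('jaxtorch-0.2.0.tar.gz')
# assert sha == '0fb2b30c2580620e4b1fbc7424738ad8f96373fcd5a4ba44823ecb0e263f3f24'
```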
### Details for the file jaxtorch-0.2.0-py3-none-any.whl

**File metadata**
- Download URL: jaxtorch-0.2.0-py3-none-any.whl
- Upload date:
- Size: 13.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/34.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.8 tqdm/4.62.3 importlib-metadata/4.11.3 keyring/23.5.0 rfc3986/2.0.0 colorama/0.4.4 CPython/3.9.9
**File hashes**

| Algorithm | Hash digest |
|---|---|
| SHA256 | 460616670559065ce33ed8337b9ffda8cb0912d235dbb74a1de72c764249a9bb |
| MD5 | 9f6777b12bc72b5934e8240cf6ed03ca |
| BLAKE2b-256 | d713a1e56db637da6192a71fd3e162f334c47136ce59baaae536e2b84280d797 |