Skip to main content

A jax based nn library

Project description

Jaxtorch (a jax nn library)

This is my jax based nn library. I created this because I was annoyed by the complexity and 'magic'-ness of the popular jax frameworks (flax, haiku).

The objective is to enable pytorch-like model definition and training with a minimum of magic. Simple example:

import jax
import jax.numpy as jnp
import jaxlib
import jaxtorch

# Modules are just classes that inherit jaxtorch.Module
class Linear(jaxtorch.Module):
    # They can accept any constructor parameters
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        # Parameters are represented by a Param type, which identifies
        # them, and specifies how to initialize them.
        self.weight = jaxtorch.init.glorot_normal(out_features, in_features)
        assert type(self.weight) is jaxtorch.Param
        if bias:
            self.bias = jaxtorch.init.zeros(out_features)
        else:
            self.bias = None

    # The forward function accepts cx, a Context object as the first argument
    # always. This provides random number generation as well as the parameters.
    def forward(self, cx: jaxtorch.Context, x):
        # Parameters are looked up in the context using the stored identifier.
        y = x @ jnp.transpose(cx[self.weight])
        if self.bias:
            y = y + cx[self.bias]
        return y

model = Linear(3, 3)

# You initialize the weights by passing a RNG key.
# Calling init_weights also names all the parameters in the Module tree.
params = model.init_weights(jax.random.PRNGKey(0))

# Parameters are stored in a dictionary by name.
assert type(params) is dict
assert type(params[model.weight.name]) is jaxlib.xla_extension.DeviceArray
assert model.weight.name == 'weight'

def loss(params, key):
    cx = jaxtorch.Context(params, key)
    x = jnp.array([1.0,2.0,3.0])
    y = jnp.array([4.0,5.0,6.0])
    return jnp.mean((model(cx, x) - y)**2)
f_grad = jax.value_and_grad(loss)

for _ in range(100):
    (loss, grad) = f_grad(params, jax.random.PRNGKey(0))
    params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grad)
print(loss)
# 4.7440533e-08

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

jaxtorch-0.2.0.tar.gz (13.6 kB view details)

Uploaded Source

Built Distribution

jaxtorch-0.2.0-py3-none-any.whl (13.8 kB view details)

Uploaded Python 3

File details

Details for the file jaxtorch-0.2.0.tar.gz.

File metadata

  • Download URL: jaxtorch-0.2.0.tar.gz
  • Upload date:
  • Size: 13.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/34.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.8 tqdm/4.62.3 importlib-metadata/4.11.3 keyring/23.5.0 rfc3986/2.0.0 colorama/0.4.4 CPython/3.9.9

File hashes

Hashes for jaxtorch-0.2.0.tar.gz
Algorithm Hash digest
SHA256 0fb2b30c2580620e4b1fbc7424738ad8f96373fcd5a4ba44823ecb0e263f3f24
MD5 5673173493714fb977236c964a4f3eeb
BLAKE2b-256 8fae4c1306051f8191d52cf185167e8f9a95c9d07ab34b678088762c1428ddbf

See more details on using hashes here.

File details

Details for the file jaxtorch-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: jaxtorch-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 13.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/34.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.8 tqdm/4.62.3 importlib-metadata/4.11.3 keyring/23.5.0 rfc3986/2.0.0 colorama/0.4.4 CPython/3.9.9

File hashes

Hashes for jaxtorch-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 460616670559065ce33ed8337b9ffda8cb0912d235dbb74a1de72c764249a9bb
MD5 9f6777b12bc72b5934e8240cf6ed03ca
BLAKE2b-256 d713a1e56db637da6192a71fd3e162f334c47136ce59baaae536e2b84280d797

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page