Skip to main content

A jax based nn library

Project description

Jaxtorch (a jax nn library)

This is my jax based nn library. I created this because I was annoyed by the complexity and 'magic'-ness of the popular jax frameworks (flax, haiku).

The objective is to enable pytorch-like model definition and training with a minimum of magic. Simple example:

import jax
import jax.numpy as jnp
import jaxlib
import jaxtorch

# Modules are just classes that inherit jaxtorch.Module
class Linear(jaxtorch.Module):
    # They can accept any constructor parameters
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        # Parameters are represented by a Param type, which identifies
        # them, and specifies how to initialize them.
        self.weight = jaxtorch.init.glorot_normal(out_features, in_features)
        assert type(self.weight) is jaxtorch.Param
        if bias:
            self.bias = jaxtorch.init.zeros(out_features)
        else:
            self.bias = None

    # The forward function accepts cx, a Context object as the first argument
    # always. This provides random number generation as well as the parameters.
    def forward(self, cx: jaxtorch.Context, x):
        # Parameters are looked up in the context using the stored identifier.
        y = x @ jnp.transpose(cx[self.weight])
        if self.bias:
            y = y + cx[self.bias]
        return y

model = Linear(3, 3)

# You initialize the weights by passing a RNG key.
# Calling init_weights also names all the parameters in the Module tree.
params = model.init_weights(jax.random.PRNGKey(0))

# Parameters are stored in a dictionary by name.
assert type(params) is dict
assert type(params[model.weight.name]) is jaxlib.xla_extension.DeviceArray
assert model.weight.name == 'weight'

def loss(params, key):
    cx = jaxtorch.Context(params, key)
    x = jnp.array([1.0,2.0,3.0])
    y = jnp.array([4.0,5.0,6.0])
    return jnp.mean((model(cx, x) - y)**2)
f_grad = jax.value_and_grad(loss)

for _ in range(100):
    (loss, grad) = f_grad(params, jax.random.PRNGKey(0))
    params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grad)
print(loss)
# 4.7440533e-08

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

jaxtorch-0.1.0.tar.gz (13.1 kB view details)

Uploaded Source

Built Distribution

jaxtorch-0.1.0-py3-none-any.whl (13.5 kB view details)

Uploaded Python 3

File details

Details for the file jaxtorch-0.1.0.tar.gz.

File metadata

  • Download URL: jaxtorch-0.1.0.tar.gz
  • Upload date:
  • Size: 13.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/34.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.8 tqdm/4.62.3 importlib-metadata/4.11.3 keyring/23.5.0 rfc3986/2.0.0 colorama/0.4.4 CPython/3.9.9

File hashes

Hashes for jaxtorch-0.1.0.tar.gz
Algorithm Hash digest
SHA256 e7fc5f6ad94254ff6b72be2122b1df54e399350f2b38b20867f1db5135943dd0
MD5 6c67aeb230ad710dedb05e991f8d84a5
BLAKE2b-256 cf418cbcd31b6e99d9cb36f6126704f5d322cf6b1a2558637f65794f7d490529

See more details on using hashes here.

File details

Details for the file jaxtorch-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: jaxtorch-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 13.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/34.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.8 tqdm/4.62.3 importlib-metadata/4.11.3 keyring/23.5.0 rfc3986/2.0.0 colorama/0.4.4 CPython/3.9.9

File hashes

Hashes for jaxtorch-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 cd84d38421d2930096ab2d814b1e428b6449161585e225cfdc106b7b75dfae59
MD5 a79854826df94feff709c251b343add6
BLAKE2b-256 8265e82fad3ba761f5827bde9ebbbacf9f9dc5a5c23b6fb5cc277ef6e5e4d51f

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page