
Implementation of various flavours of recurrent neural networks in Jax and Equinox

Project description

RNN-Jax


Usage

To start using the library, clone the repository and install the dependencies listed in pyproject.toml.

Example usage

Defining and running a model takes only a few lines:

import jax
import equinox as eqx
import jax.random as jr
from rnn_jax.cells import ElmanRNNCell
from rnn_jax.layers import RNN


key = jr.key(0)  # PRNG key
model_key, data_key = jr.split(key, 2)  # split into model and data keys
cell_key, out_key = jr.split(model_key, 2)  # keys for the cell and the RNN layer

rnn = RNN(cell=ElmanRNNCell(idim=1, hdim=16, key=cell_key), odim=1, key=out_key)

x = jr.normal(key=data_key, shape=(100, 1))  # (seq_len, idim)

outs = rnn(x)
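For intuition about what such a layer computes, here is a minimal sketch of a classic Elman recurrence, $h_t = \tanh(W_{in} x_t + W_h h_{t-1} + b)$, followed by a linear readout and unrolled over time with `jax.lax.scan`. It uses plain arrays; `init_params`, `elman_rnn`, the weight scaling, the zero initial state and the linear readout are illustrative assumptions, not the internals of `ElmanRNNCell` or `RNN`.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def init_params(key, idim, hdim, odim):
    """Hypothetical initialisation: scaled Gaussian weights, zero biases."""
    k_in, k_h, k_out = jr.split(key, 3)
    return {
        "W_in": jr.normal(k_in, (hdim, idim)) / jnp.sqrt(idim),
        "W_h": jr.normal(k_h, (hdim, hdim)) / jnp.sqrt(hdim),
        "b": jnp.zeros(hdim),
        "W_out": jr.normal(k_out, (odim, hdim)) / jnp.sqrt(hdim),
        "b_out": jnp.zeros(odim),
    }


def elman_rnn(params, x):
    """Run an Elman recurrence over x of shape (seq_len, idim); return (seq_len, odim)."""
    hdim = params["W_h"].shape[0]

    def step(h, x_t):
        # h_t = tanh(W_in x_t + W_h h_{t-1} + b)
        h = jnp.tanh(params["W_in"] @ x_t + params["W_h"] @ h + params["b"])
        # linear readout of the new hidden state
        y = params["W_out"] @ h + params["b_out"]
        return h, y

    h0 = jnp.zeros(hdim)  # assumed zero initial state
    _, ys = jax.lax.scan(step, h0, x)  # scans over the leading (time) axis
    return ys


params = init_params(jr.key(0), idim=1, hdim=16, odim=1)
x = jr.normal(jr.key(1), (100, 1))  # (seq_len, idim)
outs = elman_rnn(params, x)
print(outs.shape)  # (100, 1)
```

`lax.scan` keeps the time loop inside a single compiled computation, which is the usual way to express a nonlinear recurrence in JAX.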

For batched inputs, vmap the model over the batch dimension:

x = jr.normal(key=data_key, shape=(64, 100, 1))  # (batch, seq_len, idim)
outs = eqx.filter_vmap(rnn)(x)

Overview of the cell types (other types will likely be added)

State Space Models (SSM)

State space models are a class of recurrent networks whose hidden-state update is linear, so the forward and backward passes through time can be computed with a parallel scan rather than a step-by-step loop. In JAX this can be implemented efficiently using jax.lax.associative_scan.
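A minimal sketch of the idea, assuming a diagonal (elementwise) recurrence $h_t = a_t \odot h_{t-1} + b_t$ with $h_{-1} = 0$, where $b_t$ stands for the input contribution (e.g. $B x_t$). Each step is an affine map $h \mapsto a h + b$; composing two such maps gives another one, $(a_l, b_l)$ then $(a_r, b_r)$ yields $(a_r a_l,\ a_r b_l + b_r)$, and that composition is associative, which is what `jax.lax.associative_scan` needs. The function names are hypothetical and this is not the library's SSM code.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def binary_op(left, right):
    """Compose two affine maps h -> a*h + b: apply `left` first, then `right`."""
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r


def linear_recurrence_parallel(a, b):
    """h_t = a_t * h_{t-1} + b_t with h_{-1} = 0, via a parallel associative scan."""
    _, h = jax.lax.associative_scan(binary_op, (a, b))
    return h


def linear_recurrence_sequential(a, b):
    """Same recurrence computed step by step with lax.scan, for comparison."""

    def step(h, ab):
        a_t, b_t = ab
        h = a_t * h + b_t
        return h, h

    _, h = jax.lax.scan(step, jnp.zeros_like(b[0]), (a, b))
    return h


key_a, key_b = jr.split(jr.key(0))
a = jr.uniform(key_a, (1000, 16), minval=0.5, maxval=0.99)  # per-step diagonal transition
b = jr.normal(key_b, (1000, 16))  # per-step input term
h_par = linear_recurrence_parallel(a, b)
h_seq = linear_recurrence_sequential(a, b)
print(h_par.shape)  # (1000, 16)
```

The parallel version has logarithmic rather than linear sequential depth in the sequence length, which is why the linear recurrence is attractive on accelerators; a nonlinear update such as the Elman cell cannot be composed this way.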

Third-Party Attributions

This project includes dataset files sourced from: reservoirpy (https://github.com/reservoirpy/reservoirpy.git) Copyright (c) Xavier Hinaut (2018) xavier.hinaut@inria.fr

The dataset retains its original MIT License, found in rnn_jax/datasets/_reservoirpy/LICENSE.md.

To-dos (roughly in order of importance)

  • code to integrate the reservoirpy datasets
  • implement some out-of-the-box training methods
  • modular layers (these would require models with additional inputs, e.g. $\sigma(W_{in} x + W_{h} h + W_{m} m)$, where $m$ is the message from other modules)
  • message-passing neural networks (MPNNs) with recurrent cells; the modular layer could perhaps be viewed as an MPNN.



Download files

Download the file for your platform.

Source Distribution

rnn_jax-0.0.1.tar.gz (235.3 kB)

Uploaded Source

Built Distribution


rnn_jax-0.0.1-py3-none-any.whl (246.7 kB)

Uploaded Python 3

File details

Details for the file rnn_jax-0.0.1.tar.gz.

File metadata

  • Download URL: rnn_jax-0.0.1.tar.gz
  • Upload date:
  • Size: 235.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.12 (macOS)

File hashes

Hashes for rnn_jax-0.0.1.tar.gz
Algorithm Hash digest
SHA256 8c23bcce4b0581c9cfeddd5ef12b13b50cd8bd19b20c624e99b75b345439ef03
MD5 e857cb713aad37c3bf4cc255b41ea743
BLAKE2b-256 f9234cded9f91add6533cbafc907ec91914ea802b779d43a9316ced7373b9449

See the pip documentation on hash-checking mode for more details on using hashes.
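A minimal sketch for checking a downloaded file against the SHA256 digests listed above, using only Python's standard `hashlib`. The helper names are hypothetical, and the usage assumes the sdist has been saved as rnn_jax-0.0.1.tar.gz in the current directory.

```python
import hashlib


def sha256_of_file(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected):
    """True if the file's SHA256 digest matches the expected hex string."""
    return sha256_of_file(path) == expected.lower()


# Digest for rnn_jax-0.0.1.tar.gz, copied from the hashes listed above.
EXPECTED_SDIST_SHA256 = "8c23bcce4b0581c9cfeddd5ef12b13b50cd8bd19b20c624e99b75b345439ef03"

# Usage (assumes the file has been downloaded):
# verify("rnn_jax-0.0.1.tar.gz", EXPECTED_SDIST_SHA256)
```

pip can also enforce these digests itself through its hash-checking mode, using `--hash=sha256:...` entries in a requirements file.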

File details

Details for the file rnn_jax-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: rnn_jax-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 246.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.12 (macOS)

File hashes

Hashes for rnn_jax-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 72fb8b600efe6a8074f2f27ec36702ebb9b1689058034610d8ee830d0505698b
MD5 89728d6d0a2d4daa150f5901b52a305e
BLAKE2b-256 7b227b2c3b9db676f946d0d64d76ece5745301a0fd8cc6f093e8a9ce57222208

