
RNN-Jax

Implementation of various flavours of recurrent neural networks in Jax and Equinox

Usage

To start using the library, clone the repository and install the dependencies listed in pyproject.toml.

Example usage

Defining and running a model takes only a few lines:

import jax
import equinox as eqx
import jax.random as jr
from rnn_jax.cells import ElmanRNNCell
from rnn_jax.layers import RNN


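key = jr.key(0)  # PRNG key
model_key, data_key = jr.split(key, 2)  # separate keys for the model and the data
cell_key, out_key = jr.split(model_key, 2)  # separate keys for the cell and the output layer

# Elman cell with a 1-dimensional input and a 16-dimensional hidden state, with odim=1 outputs
rnn = RNN(cell=ElmanRNNCell(idim=1, hdim=16, key=cell_key), odim=1, key=out_key)

x = jr.normal(key=data_key, shape=(100, 1))  # (seq_len, idim)

outs = rnn(x)  # run the model over the whole sequence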
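The internals of ElmanRNNCell and RNN are not shown here. As a rough sketch, assuming the standard Elman update h_t = tanh(W_in x_t + W_h h_{t-1} + b) followed by a linear readout y_t = W_out h_t, the computation that rnn(x) performs on a single sequence can be written in plain JAX with jax.lax.scan. The helpers init_elman and elman_forward and the parameter names are hypothetical; the library's actual parameterisation (biases, initialisation, readout) may differ.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def init_elman(key, idim, hdim, odim):
    """Hypothetical parameters for an Elman RNN with a linear readout (not the library's)."""
    k_in, k_h, k_out = jr.split(key, 3)
    return {
        "W_in": jr.normal(k_in, (hdim, idim)) / jnp.sqrt(idim),
        "W_h": jr.normal(k_h, (hdim, hdim)) / jnp.sqrt(hdim),
        "b": jnp.zeros(hdim),
        "W_out": jr.normal(k_out, (odim, hdim)) / jnp.sqrt(hdim),
    }


def elman_forward(params, x):
    """Unroll h_t = tanh(W_in x_t + W_h h_{t-1} + b), y_t = W_out h_t over x of shape (seq_len, idim)."""
    h0 = jnp.zeros(params["W_h"].shape[0])  # initial hidden state h_0 = 0

    def step(h, x_t):
        h_new = jnp.tanh(params["W_in"] @ x_t + params["W_h"] @ h + params["b"])
        return h_new, params["W_out"] @ h_new  # carry the new state, emit the readout

    _, ys = jax.lax.scan(step, h0, x)
    return ys  # shape (seq_len, odim)


params = init_elman(jr.key(0), idim=1, hdim=16, odim=1)
x = jr.normal(jr.key(1), (100, 1))  # (seq_len, idim)
ys = elman_forward(params, x)
print(ys.shape)  # (100, 1)
```

jax.lax.scan compiles the loop body once instead of unrolling 100 Python iterations, which is the usual way to write recurrences in JAX.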

For batched inputs, vmap the model over the batch dimension as follows:

x = jr.normal(key=data_key, shape=(64, 100, 1))  # (batch, seq_len, idim)
outs = eqx.filter_vmap(rnn)(x)  # map the single-sequence model over the leading batch axis
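The same pattern applies to any function written for a single (seq_len, idim) sequence. The sketch below uses plain jax.vmap rather than eqx.filter_vmap (which, by default, maps array arguments over their leading axis and broadcasts everything else). The function leaky_integrator is a hypothetical stand-in for a model, not part of rnn_jax; in_axes=(None, 0) shares the parameter across the batch while mapping over the batch axis of x.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def leaky_integrator(a, x):
    """Hypothetical single-sequence model: h_t = a * h_{t-1} + x_t for x of shape (seq_len, idim)."""
    def step(h, x_t):
        h = a * h + x_t
        return h, h  # carry and emit the new state

    _, hs = jax.lax.scan(step, jnp.zeros(x.shape[-1]), x)
    return hs  # shape (seq_len, idim)


x = jr.normal(jr.key(0), (64, 100, 1))  # (batch, seq_len, idim)
# Share the parameter `a` (in_axes=None) and map over the leading batch axis of x (in_axes=0).
batched = jax.vmap(leaky_integrator, in_axes=(None, 0))
hs = batched(0.9, x)
print(hs.shape)  # (64, 100, 1)
```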

Overview of the cell types (other types will likely be added)

State Space Models (SSM)

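State space models are a class of recurrent networks whose hidden state is updated by a linear recurrence. Because each step is linear, the states for the whole sequence can be computed with a parallel prefix scan instead of step by step, both in the forward pass and when backpropagating through time. In JAX this can be implemented efficiently using jax.lax.associative_scan.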
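As a minimal sketch of why linearity enables a parallel scan, assume a diagonal recurrence h_t = a_t * h_{t-1} + b_t (elementwise, with h_0 = 0), where a_t plays the role of a diagonal state matrix and b_t of the input term such as B x_t. Each step is an affine map h -> a h + b, and composing affine maps is associative, so jax.lax.associative_scan can combine them in logarithmic depth. The helper names below are hypothetical and this is not the library's SSM code; it only checks the parallel result against a sequential jax.lax.scan.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def combine(left, right):
    """Compose two affine maps h -> a*h + b, applying `left` first, then `right`."""
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r


def linear_recurrence_scan(a, b):
    """All states of h_t = a_t * h_{t-1} + b_t (h_0 = 0), computed in parallel over time."""
    _, h = jax.lax.associative_scan(combine, (a, b))
    return h


def linear_recurrence_loop(a, b):
    """Sequential reference implementation of the same recurrence."""
    def step(h, ab):
        a_t, b_t = ab
        h = a_t * h + b_t
        return h, h

    _, h = jax.lax.scan(step, jnp.zeros(a.shape[1:]), (a, b))
    return h


k_a, k_b = jr.split(jr.key(0))
a = jax.nn.sigmoid(jr.normal(k_a, (100, 16)))  # per-step decay in (0, 1), i.e. a diagonal state matrix
b = jr.normal(k_b, (100, 16))                  # per-step input contribution, e.g. B x_t
h_par = linear_recurrence_scan(a, b)
h_seq = linear_recurrence_loop(a, b)
print(jnp.allclose(h_par, h_seq, atol=1e-5))  # True
```

The two results agree up to floating-point reordering; the associative version exposes parallelism across time steps, which is what makes SSMs fast on accelerators.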

Third-Party Attributions

This project includes dataset files sourced from: reservoirpy (https://github.com/reservoirpy/reservoirpy.git) Copyright (c) Xavier Hinaut (2018) xavier.hinaut@inria.fr

The dataset retains its original MIT License, found in rnn_jax/datasets/_reservoirpy/LICENSE.md.

To-dos (roughly in order of importance)

  • code to integrate the reservoirpy datasets
  • implement some out-of-the-box training methods
  • modular layers (these would require cells with additional inputs, e.g. $\sigma(W_{in} x + W_{h} h + W_{m} m)$, where $m$ is the message from other modules)
  • message-passing neural networks (MPNNs) with recurrent cells; the modular layer could perhaps be viewed as an MPNN.
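Modular layers are not implemented yet, so the following is only a hypothetical sketch of the update $\sigma(W_{in} x + W_{h} h + W_{m} m)$ from the to-do list. It assumes $\sigma = \tanh$ and, as one possible choice, takes each module's message $m$ to be the other module's previous hidden state, which is also how it resembles a two-node message-passing network. The names init_module and modular_step are illustrative, not part of rnn_jax.

```python
import jax.numpy as jnp
import jax.random as jr


def init_module(key, idim, hdim, mdim):
    """Hypothetical parameters for one module: W_in, W_h and W_m from the formula."""
    k_in, k_h, k_m = jr.split(key, 3)
    return {
        "W_in": jr.normal(k_in, (hdim, idim)) / jnp.sqrt(idim),
        "W_h": jr.normal(k_h, (hdim, hdim)) / jnp.sqrt(hdim),
        "W_m": jr.normal(k_m, (hdim, mdim)) / jnp.sqrt(mdim),
    }


def modular_step(params, h, x, m):
    """h' = sigma(W_in x + W_h h + W_m m), with sigma = tanh as an assumption."""
    return jnp.tanh(params["W_in"] @ x + params["W_h"] @ h + params["W_m"] @ m)


idim, hdim = 1, 8
k_a, k_b, k_x = jr.split(jr.key(0), 3)
params_a = init_module(k_a, idim, hdim, mdim=hdim)
params_b = init_module(k_b, idim, hdim, mdim=hdim)
xs = jr.normal(k_x, (20, idim))  # (seq_len, idim), shared input for both modules

h_a = jnp.zeros(hdim)
h_b = jnp.zeros(hdim)
for x_t in xs:
    # Each module receives the other's previous hidden state as its message m;
    # the right-hand side is evaluated with the old h_a and h_b before assignment.
    h_a, h_b = modular_step(params_a, h_a, x_t, h_b), modular_step(params_b, h_b, x_t, h_a)

print(h_a.shape, h_b.shape)  # (8,) (8,)
```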



Download files


Source Distribution

rnn_jax-0.0.2.tar.gz (235.2 kB view details)

Uploaded Source

Built Distribution


rnn_jax-0.0.2-py3-none-any.whl (246.7 kB view details)

Uploaded Python 3

File details

Details for the file rnn_jax-0.0.2.tar.gz.

File metadata

  • Download URL: rnn_jax-0.0.2.tar.gz
  • Upload date:
  • Size: 235.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.12

File hashes

Hashes for rnn_jax-0.0.2.tar.gz
Algorithm Hash digest
SHA256 7ad5c9c27ac62e8a1fdc5837cec4234092f2ec7c939722ac16c865954109310b
MD5 5fdf41dd2a29cc8b9fa4e55e3d090341
BLAKE2b-256 fdaecae61d515d7d07704a728951a76d13819caa3828fa23f5991665ec221683


File details

Details for the file rnn_jax-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: rnn_jax-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 246.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.12

File hashes

Hashes for rnn_jax-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 629ff8526e92bb3ca1d5167420ec2fbf9aed679014283f9bdf6d39c54bfb7267
MD5 c6333dece31ee8e276e41157803e2b42
BLAKE2b-256 d7781703c429cfbd6de54ebb76432b4d4608cf0a6de1cb3c7ea0b15d82f7da65

