RNN-Jax
Implementation of various flavours of recurrent neural networks in Jax and Equinox
Usage
To start using the library, clone the repository and install the dependencies listed in pyproject.toml.
Example usage
Defining and running a model takes only a few lines:
```python
import jax
import equinox as eqx
import jax.random as jr
from rnn_jax.cells import ElmanRNNCell
from rnn_jax.layers import RNN

key = jr.key(0)  # PRNG key
model_key, data_key = jr.split(key, 2)  # separate keys for the model and the data
cell_key, out_key = jr.split(model_key, 2)  # separate keys for the cell and the output layer

rnn = RNN(cell=ElmanRNNCell(idim=1, hdim=16, key=cell_key), odim=1, key=out_key)
x = jr.normal(key=data_key, shape=(100, 1))  # (seq_len, idim)
outs = rnn(x)
```
For batched inputs, vmap the model over the batch dimension:

```python
x = jr.normal(key=data_key, shape=(64, 100, 1))  # (batch, seq_len, idim)
outs = eqx.filter_vmap(rnn)(x)
```
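For intuition about what a recurrent layer does internally, here is a minimal pure-JAX sketch. It is not the library's RNN implementation: it assumes an Elman-style cell with tanh, hypothetical parameter names (W_h, W_x, b, W_out), and a linear readout applied at every time step. The hidden state is carried through time with jax.lax.scan, and batching uses jax.vmap with the parameters shared across the batch.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def init_params(key, idim, hdim, odim):
    # Hypothetical parameter layout: Elman cell plus a linear readout
    k1, k2, k3 = jr.split(key, 3)
    return {
        "W_h": jr.normal(k1, (hdim, hdim)) / jnp.sqrt(hdim),
        "W_x": jr.normal(k2, (hdim, idim)) / jnp.sqrt(idim),
        "b": jnp.zeros(hdim),
        "W_out": jr.normal(k3, (odim, hdim)) / jnp.sqrt(hdim),
    }


def run_rnn(params, x):
    # x: (seq_len, idim) -> outputs: (seq_len, odim)
    def step(h, x_t):
        # h_{t+1} = tanh(W_h h_t + W_x x_{t+1} + b)
        h_next = jnp.tanh(params["W_h"] @ h + params["W_x"] @ x_t + params["b"])
        return h_next, params["W_out"] @ h_next  # (new carry, output at this step)

    h0 = jnp.zeros(params["b"].shape)
    _, outs = jax.lax.scan(step, h0, x)
    return outs


params = init_params(jr.key(0), idim=1, hdim=16, odim=1)
x = jr.normal(jr.key(1), (100, 1))
outs = run_rnn(params, x)  # (100, 1)

# Batching: map over the leading axis of the inputs only; params are shared
xb = jr.normal(jr.key(2), (64, 100, 1))
outs_b = jax.vmap(run_rnn, in_axes=(None, 0))(params, xb)  # (64, 100, 1)
```

Using lax.scan rather than a Python loop keeps the traced graph small, so compilation time does not grow with the sequence length.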
Overview of the cell types (other types will likely be added)
- Vanilla: standard RNNs, following an equation which is roughly equivalent to $h_{t+1} = \sigma(W_{h} h_t + W_{x}x_{t+1} + b)$
  - ElmanRNNCell: standard RNN (Elman, Finding Structure in Time, 1990)
  - indRNNCell: independent RNN, where $W_h$ is diagonal (Li et al., Independently Recurrent Neural Network (IndRNN): Building A Longer and Deeper RNN, 2018)
- Gated: gated RNNs, i.e. architectures with gates designed to adaptively forget past inputs
  - LongShortTermMemoryCell: LSTM cell (Hochreiter and Schmidhuber, Long Short-Term Memory, 1997)
  - GatedRecurrentUnit: GRU cell (Cho et al., Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation, 2014)
- Antisymmetric: architectures imposing an antisymmetric structure on the recurrence matrix $W_h$
  - AntiSymmetricRNNCell: antisymmetric RNN, where the update is roughly $h_{t+1} = h_t + \sigma((W_{h} - W_{h}^T) h_t + W_{x}x_{t+1} + b)$ (Chang et al., AntisymmetricRNN: A Dynamical System View on Recurrent Neural Networks, 2019)
  - GatedAntiSymmetricRNNCell: gated version of the antisymmetric RNN (same reference as above)
- Other recurrent models
  - ClockWorkRNNCell: Clockwork RNN, an architecture that processes inputs at different time scales (Koutník et al., A Clockwork RNN, 2014)
  - LipschitzRNNCell: Lipschitz RNN, an architecture grounded in continuous-time dynamical systems (Erichson et al., Lipschitz Recurrent Neural Networks, 2020)
  - UnitaryEvolutionRNNCell: a flavour of unitary RNN that parametrizes the recurrence matrix to be unitary through Fourier transforms and Householder reflectors (Arjovsky et al., Unitary Evolution Recurrent Neural Networks, 2016)
  - CoupledOscillatoryRNNCell: an RNN based on oscillator dynamical systems (Rusch and Mishra, Coupled Oscillatory Recurrent Neural Network (coRNN), 2023), and its heterogeneous variant (Ceni et al., Random Oscillators Network for Time Series Processing, 2024)
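The vanilla and antisymmetric cells differ mainly in how the recurrence matrix is constrained. The sketch below writes three of the updates as plain JAX functions; it is an illustration under stated assumptions, not the code behind ElmanRNNCell, indRNNCell or AntiSymmetricRNNCell. It assumes $\sigma = \tanh$, stores the IndRNN diagonal as a vector u, and adds a hypothetical step size eps to the antisymmetric update (eps=1.0 reproduces the formula in the list).

```python
import jax.numpy as jnp
import jax.random as jr


def elman_step(h, x, W_h, W_x, b):
    # Vanilla update with a full recurrence matrix: tanh(W_h h + W_x x + b)
    return jnp.tanh(W_h @ h + W_x @ x + b)


def indrnn_step(h, x, u, W_x, b):
    # Diagonal recurrence stored as a vector u: W_h h becomes u * h,
    # so each hidden unit only sees its own previous value
    return jnp.tanh(u * h + W_x @ x + b)


def antisymmetric_step(h, x, W, W_x, b, eps=1.0):
    # W - W^T is antisymmetric (its eigenvalues are purely imaginary);
    # eps is an assumed step size, eps=1.0 matches the formula in the list
    A = W - W.T
    return h + eps * jnp.tanh(A @ h + W_x @ x + b)


hdim, idim = 8, 3
kh, kx, kw, ku, kin = jr.split(jr.key(0), 5)
h = jr.normal(kh, (hdim,))
x = jr.normal(kx, (idim,))
W = jr.normal(kw, (hdim, hdim))
u = jr.uniform(ku, (hdim,))
W_x = jr.normal(kin, (hdim, idim))
b = jnp.zeros(hdim)

# An IndRNN step equals an Elman step whose recurrence matrix is diag(u)
same = jnp.allclose(indrnn_step(h, x, u, W_x, b),
                    elman_step(h, x, jnp.diag(u), W_x, b), atol=1e-5)
h_anti = antisymmetric_step(h, x, W, W_x, b)
```

Storing the IndRNN recurrence as a vector rather than a diagonal matrix replaces a matrix-vector product with an elementwise one, which is the practical payoff of the diagonal constraint.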
State Space Models (SSM)
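State space models are a class of recurrent networks whose recurrence is linear. Because the recurrence is linear, the hidden states for a whole sequence can be computed with a parallel scan instead of one step at a time, which speeds up both the forward and the backward pass through time. In JAX this can be implemented efficiently with jax.lax.associative_scan.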
- S5: simplified SSM, i.e. an SSM that uses a diagonal recurrence matrix (Smith et al., Simplified State Space Layers for Sequence Modeling, 2022).
- Linear Recurrent Unit: a model that adapts concepts from SSMs to RNNs, employing a linear recurrence and a diagonal transition matrix (Orvieto et al., Resurrecting Recurrent Neural Networks for Long Sequences, 2023).
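Both models rest on a diagonal linear recurrence $h_t = \lambda \odot h_{t-1} + B x_t$. The following minimal sketch (not the library's code) shows how jax.lax.associative_scan can evaluate it: each time step is treated as the affine map $h \mapsto a h + b$, and composing two such maps is associative, so the scan can combine them in parallel. The function names and the LRU-style parametrization $\lambda = \exp(-\exp(\nu) + i\theta)$, which keeps $|\lambda| < 1$, are illustrative assumptions; a sequential jax.lax.scan is included as a reference.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def binary_op(e1, e2):
    # Each element (a, b) represents the affine map h -> a * h + b.
    # Applying e1 and then e2 gives h -> a2 * (a1 * h + b1) + b2.
    a1, b1 = e1
    a2, b2 = e2
    return a2 * a1, a2 * b1 + b2


def linear_recurrence_parallel(lam, Bx):
    # lam: (hdim,) diagonal of the transition matrix (may be complex)
    # Bx:  (seq_len, hdim) projected inputs B x_t
    # Returns h_t = lam * h_{t-1} + Bx_t for every t, with h_{-1} = 0
    a = jnp.broadcast_to(lam, Bx.shape)
    _, h = jax.lax.associative_scan(binary_op, (a, Bx))
    return h


def linear_recurrence_sequential(lam, Bx):
    # Reference: the same recurrence, one step at a time
    def step(h, bx):
        h = lam * h + bx
        return h, h

    _, h = jax.lax.scan(step, jnp.zeros_like(Bx[0]), Bx)
    return h


hdim, seq_len = 4, 50
k1, k2, k3 = jr.split(jr.key(0), 3)
nu = jr.normal(k1, (hdim,))
theta = jr.uniform(k2, (hdim,), maxval=2 * jnp.pi)
lam = jnp.exp(-jnp.exp(nu) + 1j * theta)  # |lam| = exp(-exp(nu)) < 1
Bx = jr.normal(k3, (seq_len, hdim)).astype(jnp.complex64)

h_par = linear_recurrence_parallel(lam, Bx)
h_seq = linear_recurrence_sequential(lam, Bx)
```

The parallel version needs only O(log T) sequential combination steps instead of T, at the cost of some extra arithmetic, so its advantage shows mostly on long sequences and parallel hardware.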
Third-Party Attributions
This project includes dataset files sourced from reservoirpy (https://github.com/reservoirpy/reservoirpy.git), Copyright (c) Xavier Hinaut (2018) xavier.hinaut@inria.fr.
The dataset retains its original MIT License, found in rnn_jax/datasets/_reservoirpy/LICENSE.md.
To-dos (roughly in order of importance)
- code to integrate the reservoirpy datasets
- implement some out-of-the-box training methods
- modular layers (these would require models with additional inputs, e.g. $\sigma(W_{in} x + W_{h} h + W_{m} m)$, where $m$ is the message from other modules)
- message-passing neural networks (MPNNs) with recurrent cells; the modular layer could perhaps be viewed as an MPNN.
Download files
File details
Details for the file rnn_jax-0.0.2.tar.gz.
File metadata
- Download URL: rnn_jax-0.0.2.tar.gz
- Upload date:
- Size: 235.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.12 (macOS)
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 7ad5c9c27ac62e8a1fdc5837cec4234092f2ec7c939722ac16c865954109310b |
| MD5 | 5fdf41dd2a29cc8b9fa4e55e3d090341 |
| BLAKE2b-256 | fdaecae61d515d7d07704a728951a76d13819caa3828fa23f5991665ec221683 |
File details
Details for the file rnn_jax-0.0.2-py3-none-any.whl.
File metadata
- Download URL: rnn_jax-0.0.2-py3-none-any.whl
- Upload date:
- Size: 246.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.12 (macOS)
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 629ff8526e92bb3ca1d5167420ec2fbf9aed679014283f9bdf6d39c54bfb7267 |
| MD5 | c6333dece31ee8e276e41157803e2b42 |
| BLAKE2b-256 | d7781703c429cfbd6de54ebb76432b4d4608cf0a6de1cb3c7ea0b15d82f7da65 |