Memax - Sequence and Memory Modeling in JAX

Memax is a library for efficient recurrent models. Drawing on category theory, it offers a single, simple interface that works for nearly all recurrent models, along with fast recurrent state resets across a sequence, so you can train on batches of variable-length sequences without truncation or zero-padding.
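
For example, a length-3 and a length-2 sequence can be packed into one length-5 sequence using boolean start flags. A minimal sketch (variable names are illustrative, but the (x, start) input pair is the one used throughout the examples below):

import jax.numpy as jnp

# Two subsequences (lengths 3 and 2) packed into one length-5 sequence.
# A True flag marks the first timestep of each subsequence, where the
# recurrent state is reset -- no truncation or zero-padding required.
xs = jnp.concatenate([jnp.ones((3, 4)), 2 * jnp.ones((2, 4))])  # (time, feature)
starts = jnp.array([True, False, False, True, False])
inputs = (xs, starts)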

Table of Contents

  1. Recurrent Models
  2. Datasets
  3. Getting Started
  4. Documentation
  5. Citing our Work

Recurrent Models

We implement recurrent models with both linear and logarithmic parallel time complexity.

| Name | Parallel Time Complexity | Paper | Code |
| --- | --- | --- | --- |
| Linear Recurrent Unit | $O(\log{n})$ | [paper] | [code] |
| Selective State Space Model (S6) | $O(\log{n})$ | [paper] | [code] |
| Linear Recurrent Neural Network | $O(\log{n})$ | [paper] | [code] |
| Fast Autoregressive Transformer | $O(\log{n})$ | [paper] | [code] |
| Fast and Forgetful Memory | $O(\log{n})$ | [paper] | [code] |
| Rotational RNN (RotRNN) | $O(\log{n})$ | [paper] | [code] |
| Fast Weight Programmer | $O(\log{n})$ | [paper] | [code] |
| DeltaNet | $O(\log{n})$ | [paper] | [code] |
| Gated DeltaNet | $O(\log{n})$ | [paper] | [code] |
| DeltaProduct | $O(\log{n})$ | [paper] | [code] |
| Attention | $O(\log{n})$ | [paper] | [code] |
| RoPE-Attention | $O(\log{n})$ | [paper] | [code] |
| ALiBi-Attention | $O(\log{n})$ | [paper] | [code] |
| Elman Network | $O(n)$ | [paper] | [code] |
| Gated Recurrent Unit | $O(n)$ | [paper] | [code] |
| Minimal Gated Unit | $O(n)$ | [paper] | [code] |
| Long Short-Term Memory Unit | $O(n)$ | [paper] | [code] |
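
The $O(\log{n})$ rows are models whose recurrence can be evaluated with a parallel associative scan. As a rough sketch of why (not the memax internals): a diagonal linear recurrence $h_t = a_t h_{t-1} + b_t$ composes into another step of the same form, so the whole sequence can be combined in logarithmic parallel depth with jax.lax.associative_scan:

import jax
import jax.numpy as jnp

# Sketch, not the memax internals: evaluate h_t = a_t * h_{t-1} + b_t
# for all t in O(log n) parallel depth. Composing two steps yields
# another (a, b) pair, so the combine below is associative.
def combine(left, right):
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r

T, D = 8, 4
a = jnp.full((T, D), 0.9)                          # per-step decay
b = jax.random.normal(jax.random.key(0), (T, D))   # per-step input
_, h = jax.lax.associative_scan(combine, (a, b))   # h[t]: state after step t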

Datasets

We provide datasets to test our recurrent models.

Sequential MNIST [HuggingFace] [Code]

The recurrent model receives an MNIST image pixel by pixel, and must predict the digit class.

Sequence Lengths: [784]
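
Concretely, one image unrolls into the model input as follows (a minimal sketch; variable names are illustrative):

import jax.numpy as jnp

# One 28x28 image becomes a length-784 sequence of single pixels;
# only the first pixel starts a new subsequence.
image = jnp.zeros((28, 28))
pixels = image.reshape(784, 1)                       # (time, feature)
starts = jnp.zeros(784, dtype=bool).at[0].set(True)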

MNIST Math [HuggingFace] [Code]

The recurrent model receives a sequence of MNIST images and operators, pixel by pixel, and must predict the percentile of the operators applied to the MNIST image classes.

Sequence Lengths: [784 * 5, 784 * 100, 784 * 1_000, 784 * 10_000, 784 * 1_000_000]

Continuous Localization [HuggingFace] [Code]

The recurrent model receives a sequence of translation and rotation vectors in the local coordinate frame, and must predict the corresponding position and orientation in the global coordinate frame.

Sequence Lengths: [20, 100, 1_000]
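
For intuition, the underlying computation is dead reckoning: each local motion is rotated into the global frame and accumulated. A minimal 2D sketch of that computation (the dataset itself may use a different state convention):

import jax
import jax.numpy as jnp

# Hypothetical 2D illustration of the task: integrate local
# (dtheta, dx, dy) motions into a global (theta, x, y) pose.
def step(pose, motion):
    theta, x, y = pose
    dtheta, dx, dy = motion
    new_pose = jnp.array([
        theta + dtheta,
        x + dx * jnp.cos(theta) - dy * jnp.sin(theta),
        y + dx * jnp.sin(theta) + dy * jnp.cos(theta),
    ])
    return new_pose, new_pose

motions = (jnp.zeros(20), jnp.ones(20), jnp.zeros(20))  # 20 unit steps forward
_, poses = jax.lax.scan(step, jnp.zeros(3), motions)    # (20, 3) global trajectory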

Getting Started

Install memax with pip (from git) for your framework of choice:

pip install "memax[equinox]@git+https://github.com/smorad/memax"
pip install "memax[flax]@git+https://github.com/smorad/memax"

If you want to use our dataset and training scripts, install via:

pip install "memax[dataset,equinox]@git+https://github.com/smorad/memax"
pip install "memax[dataset,flax]@git+https://github.com/smorad/memax"

Equinox Quickstart

import jax
import jax.numpy as jnp
from equinox import filter_jit, filter_vmap

from memax.equinox.train_utils import add_batch_dim, get_residual_memory_models

T, F = 5, 6  # time and feature dims

model = get_residual_memory_models(
    input=F, hidden=8, output=1, num_layers=2,
    models=["LRU"], key=jax.random.key(0)
)["LRU"]

# True marks the first timestep of each packed subsequence
starts = jnp.array([True, False, False, True, False])
xs = jnp.zeros((T, F))
hs, ys = filter_jit(model)(model.initialize_carry(), (xs, starts))
last_h = filter_jit(model.latest_recurrent_state)(hs)

# With a batch dim, vmap the model over the leading axis
B = 4
starts = jnp.zeros((B, T), dtype=bool)
xs = jnp.zeros((B, T, F))
hs_0 = add_batch_dim(model.initialize_carry(), B)
hs, ys = filter_jit(filter_vmap(model))(hs_0, (xs, starts))
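
From here, training follows the standard equinox pattern. A hedged sketch of one gradient step, reusing the batched setup above and assuming optax for the optimizer and a made-up mean-squared-error objective (targets is a hypothetical (B, T, 1) array; nothing here beyond the forward call is memax API):

import equinox as eqx
import optax

optim = optax.adam(1e-3)
opt_state = optim.init(eqx.filter(model, eqx.is_array))

@eqx.filter_jit
def train_step(model, opt_state, hs_0, xs, starts, targets):
    def loss_fn(model):
        _, ys = filter_vmap(model)(hs_0, (xs, starts))
        return jnp.mean((ys - targets) ** 2)  # assumed MSE objective
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss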

Running Baselines

You can compare various recurrent models on our datasets with a single command:

python run_equinox_experiments.py # equinox framework
python run_linen_experiments.py # flax linen framework

Custom Architectures

memax is built on the equinox neural network library. See the semigroups directory for fast recurrent models that use an associative scan. We also provide a beta flax.linen API. This example focuses on equinox.

import equinox as eqx
import jax
import jax.numpy as jnp

from memax.equinox.set_actions.gru import GRU
from memax.equinox.models.residual import ResidualModel
from memax.equinox.semigroups.lru import LRU, LRUSemigroup
from memax.utils import debug_shape

# You can pack multiple subsequences into a single sequence using the start flag
sequence_starts = jnp.array([True, False, False, True, False])
x = jnp.zeros((5, 3))
inputs = (x, sequence_starts)

# Initialize a multi-layer recurrent model
key = jax.random.key(0)
make_layer_fn = lambda recurrent_size, key: LRU(
    hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key
)
model = ResidualModel(
    make_layer_fn=make_layer_fn,
    input_size=3,
    recurrent_size=16,
    output_size=4,
    num_layers=2,
    key=key,
)

# Note: We also have layers if you want to build your own model
layer = LRU(hidden_size=16, recurrent_size=16, key=key)
# Or semigroups/set actions (scanned functions) if you want to build your own layer
sg = LRUSemigroup(recurrent_size=16)

# Run the model! All models are jit-capable, using equinox.filter_jit
h = eqx.filter_jit(model.initialize_carry)()
# Unlike most other libraries, we output ALL recurrent states h, not just the most recent
h, y = eqx.filter_jit(model)(h, inputs)
# Since we have two layers, the recurrent state has shape
print(debug_shape(h))
# ((5, 16),  # Recurrent states of the first layer
#  (5,),     # Start carries for the first layer
#  (5, 16),  # Recurrent states of the second layer
#  (5,))     # Start carries for the second layer
# Do your prediction
prediction = jax.nn.softmax(y)

# If you want to continue rolling out the RNN from h[-1]
# you should use the following helper function to extract
# h[-1] from the nested recurrent state
latest_h = eqx.filter_jit(model.latest_recurrent_state)(h)
# Continue rolling out as you please! You can use a single timestep
# or another sequence.
last_h, last_y = eqx.filter_jit(model)(latest_h, inputs)

# We can use a similar approach with RNNs
make_layer_fn = lambda recurrent_size, key: GRU(
    recurrent_size=recurrent_size, key=key
)
model = ResidualModel(
    make_layer_fn=make_layer_fn,
    input_size=3,
    recurrent_size=16,
    output_size=4,
    num_layers=2,
    key=jax.random.key(0),
)
h = eqx.filter_jit(model.initialize_carry)()
h, y = eqx.filter_jit(model)(h, inputs)
prediction = jax.nn.softmax(y)
latest_h = eqx.filter_jit(model.latest_recurrent_state)(h)
h, y = eqx.filter_jit(model)(latest_h, inputs)

Creating Custom Recurrent Models

All recurrent cells should follow the GRAS interface. Each recurrent cell is built around an Algebra: roughly, the Algebra is the function that updates the recurrent state, while the GRAS wraps the Algebra together with its associated MLPs and gates. You may reuse our Algebras in your custom GRAS, or write your own Algebra.

To implement your own Algebra and GRAS, we suggest starting from our existing code, such as the LRNN for a Semigroup or the Elman Network for a SetAction.
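
To see what an Algebra buys you, here is a conceptual sketch (deliberately not the memax class interface) of a semigroup for a diagonal linear recurrence, with the start-flag resets folded into the same associative combine:

import jax
import jax.numpy as jnp

# Conceptual sketch only -- not the memax Algebra/GRAS classes.
# For h_t = A_t * h_{t-1} + b_t, the algebra element is the pair
# (A, b); composing two steps yields another pair, so the combine
# is associative. Zeroing the decay wherever a subsequence starts
# makes the scan erase all earlier state at that point.
def combine(left, right):
    A_l, b_l = left
    A_r, b_r = right
    return A_r * A_l, A_r * b_l + b_r

def scan_with_resets(A, b, starts):
    A = jnp.where(starts[:, None], 0.0, A)  # reset: forget the past
    return jax.lax.associative_scan(combine, (A, b))

T, D = 5, 8
A = jnp.full((T, D), 0.9)
b = jax.random.normal(jax.random.key(0), (T, D))
starts = jnp.array([True, False, False, True, False])
_, hs = scan_with_resets(A, b, starts)  # hs[3] depends only on b[3]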

Documentation

Full documentation is available here.

Citing our Work

Please cite the library as:

@misc{morad_memax_2025,
	title = {memax},
	url = {https://github.com/smorad/memax},
	author = {Morad, Steven and Toledo, Edan and Kortvelesy, Ryan and He, Zhe},
	month = jun,
	year = {2025},
}

If you use the recurrent state resets (sequence_starts) with the log-complexity memory models, please cite:

@article{morad2024recurrent,
  title={Recurrent reinforcement learning with memoroids},
  author={Morad, Steven and Lu, Chris and Kortvelesy, Ryan and Liwicki, Stephan and Foerster, Jakob and Prorok, Amanda},
  journal={Advances in Neural Information Processing Systems},
  volume={37},
  pages={14386--14416},
  year={2024}
}
