A unified JAX/Flax framework for memory-augmented reinforcement learning with support for RNNs, SSMs, and Transformers
🧠 Memorax
A unified reinforcement learning framework featuring memory-augmented algorithms and POMDP environment implementations. This repository provides modular components for building, configuring, and running a variety of RL algorithms on classic and memory-intensive environments.
✨ Features
- 🤖 Memory-RL: JAX implementations of DQN, PPO (Discrete & Continuous), SAC (Discrete & Continuous), PQN, IPPO, R2D2, and their memory-augmented variants with burn-in support for recurrent networks.
- 📦 Pure JAX Episode Buffer: A fully JAX-native episode buffer implementation enabling efficient storage and sampling of complete episodes for recurrent training, with support for Prioritized Experience Replay.
- 🔁 Recurrent Cells: Support for multiple RNN cells and memory architectures, including LSTM, GRU, GPT2, GTrXL, FFM, xLSTM, SHM, S5, LRU, RetNet, Mamba, MinGRU, and Linear Transformer.
- 🧬 Networks: MLP, CNN, and ViT encoders with support for RoPE and ALiBi positional embeddings, and Mixture of Experts (MoE) for horizontal scaling.
- 🎮 Environments: Support for Gymnax, PopJym, PopGym Arcade, Navix, Craftax, Brax, MuJoCo, gxm, XMiniGrid, and JaxMARL.
- 📊 Logging & Sweeps: Support for a CLI dashboard, Weights & Biases, TensorBoardX, and Neptune.
- 🔧 Easy to Extend: Clear directory structure for adding new networks, algorithms, or environments.
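Burn-in, as introduced for R2D2, warms up a recurrent state on the first part of a stored sequence without backpropagating through it and then trains only on the remainder. The sketch below illustrates the idea with a toy tanh RNN unrolled by `jax.lax.scan`; the cell, the helper names (`rnn_step`, `unroll`, `unroll_with_burn_in`), and the `burn_in` split are hypothetical and are not Memorax's API.

```python
import jax
import jax.numpy as jnp


def rnn_step(params, h, x):
    """One step of a toy tanh RNN (hypothetical cell, for illustration only)."""
    h = jnp.tanh(x @ params["w_x"] + h @ params["w_h"] + params["b"])
    return h, h


def unroll(params, h0, xs):
    """Unroll over a [time, features] sequence; returns (final carry, [time, hidden] states)."""
    return jax.lax.scan(lambda h, x: rnn_step(params, h, x), h0, xs)


def unroll_with_burn_in(params, h0, xs, burn_in):
    """Warm up the carry on xs[:burn_in] without gradients, then unroll on xs[burn_in:]."""
    h_burn, _ = unroll(params, h0, xs[:burn_in])
    h_burn = jax.lax.stop_gradient(h_burn)  # block gradients into the burn-in segment
    return unroll(params, h_burn, xs[burn_in:])


k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
features, hidden, time = 3, 4, 10
params = {
    "w_x": 0.1 * jax.random.normal(k1, (features, hidden)),
    "w_h": 0.1 * jax.random.normal(k2, (hidden, hidden)),
    "b": jnp.zeros(hidden),
}
xs = jax.random.normal(k3, (time, features))
h0 = jnp.zeros(hidden)

# Hidden states are returned only for the 6 post-burn-in (training) steps
h_final, hs = unroll_with_burn_in(params, h0, xs, burn_in=4)
```

Because `stop_gradient` is the identity in the forward pass, the hidden states match an ordinary unroll of the full sequence; only the gradients differ.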
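ALiBi replaces learned positional embeddings with a fixed, head-specific linear penalty on the query-key distance that is added to the attention logits. A minimal causal version in JAX, assuming a power-of-two number of heads and the geometric slopes from the ALiBi paper; the function names are hypothetical and this is not Memorax's `ALiBi` module:

```python
import jax
import jax.numpy as jnp


def alibi_slopes(num_heads):
    """Per-head slopes 2^(-8h/num_heads) for h = 1..num_heads (power-of-two head counts)."""
    h = jnp.arange(1, num_heads + 1)
    return 2.0 ** (-8.0 * h / num_heads)


def alibi_bias(num_heads, seq_len):
    """Causal bias [heads, query, key]: -slope * (i - j) where j <= i, -inf where j > i."""
    i = jnp.arange(seq_len)[:, None]  # query positions
    j = jnp.arange(seq_len)[None, :]  # key positions
    distance = (i - j).astype(jnp.float32)
    bias = -alibi_slopes(num_heads)[:, None, None] * distance[None]
    return jnp.where(j <= i, bias, -jnp.inf)


def alibi_attention(q, k, v):
    """Scaled dot-product attention with an ALiBi bias; q, k, v have shape [heads, time, head_dim]."""
    num_heads, seq_len, head_dim = q.shape
    logits = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(head_dim)
    logits = logits + alibi_bias(num_heads, seq_len)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("hqk,hkd->hqd", weights, v)
```

Since the bias depends only on relative distance, the same function works for any sequence length without retraining a position table.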
📥 Installation
Install Memorax using pip:
```bash
pip install memorax
```
Or using uv:
```bash
uv add memorax
```
Optionally, add CUDA support with the following command (the quotes keep shells such as zsh from expanding the brackets):
```bash
pip install "memorax[cuda]"
```
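To confirm that the CUDA extra is picked up (assuming it installs a CUDA-enabled `jaxlib`), you can ask JAX which backend it selected. This uses only standard JAX calls and does not depend on Memorax:

```python
import jax

# "gpu" on a working CUDA setup, otherwise typically "cpu"
print(jax.default_backend())
# The devices JAX will place computations on
print(jax.devices())
```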
Optionally, log in to Weights & Biases to enable W&B logging:
```bash
wandb login
```
🚀 Quick Start
From the root of a clone of the repository (where `examples/` lives), run a default DQN experiment on CartPole:
```bash
uv run examples/dqn_gymnax.py
```
💻 Usage
```python
import jax
import optax
from memorax.algorithms import PPO, PPOConfig
from memorax.environments import environment
from memorax.networks import (
    MLP, FFN, ALiBi, FeatureExtractor, GatedResidual, Network,
    PreNorm, SegmentRecurrence, SelfAttention, Stack, heads,
)

# Create the Gymnax CartPole-v1 environment and its parameters
env, env_params = environment.make("gymnax::CartPole-v1")

# PPO hyperparameters
cfg = PPOConfig(
    name="PPO-GTrXL",
    num_envs=8,
    num_eval_envs=16,
    num_steps=128,
    gamma=0.99,
    gae_lambda=0.95,
    num_minibatches=4,
    update_epochs=4,
    normalize_advantage=True,
    clip_coef=0.2,
    clip_vloss=True,
    ent_coef=0.01,
    vf_coef=0.5,
)

# GTrXL-style torso: gated residual, pre-norm self-attention with ALiBi and
# segment-level recurrence, followed by a gated residual feed-forward block
features, num_heads, num_layers = 64, 4, 2
feature_extractor = FeatureExtractor(observation_extractor=MLP(features=(features,)))
attention = GatedResidual(PreNorm(SegmentRecurrence(
    SelfAttention(features, num_heads, context_length=128, positional_embedding=ALiBi(num_heads)),
    memory_length=64, features=features,
)))
ffn = GatedResidual(PreNorm(FFN(features=features, expansion_factor=4)))
torso = Stack(blocks=(attention, ffn) * num_layers)  # num_layers (attention, ffn) pairs

# Actor (categorical policy head) and critic (value head) built from the same modules
actor_network = Network(feature_extractor, torso, heads.Categorical(env.action_space(env_params).n))
critic_network = Network(feature_extractor, torso, heads.VNetwork())

# Clip gradients at global norm 1.0, then Adam with learning rate 3e-4;
# the same optimizer is passed once for each network
optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(3e-4))
agent = PPO(cfg, env, env_params, actor_network, critic_network, optimizer, optimizer)

# Initialize, then train; each call returns an updated PRNG key and training state
key, state = agent.init(jax.random.key(0))
key, state, transitions = agent.train(key, state, num_steps=10_000)
```
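The `gamma` and `gae_lambda` fields of `PPOConfig` parameterize generalized advantage estimation (GAE). Below is a minimal sketch of the standard backward GAE recursion with `jax.lax.scan`, written from the GAE definition rather than from Memorax's source; the function name, the single-environment `[time]` layout, and the `dones` convention are assumptions.

```python
import jax
import jax.numpy as jnp


def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """Generalized advantage estimation over a [time] rollout.

    dones[t] = 1.0 if the episode ended after step t (no bootstrapping past it).
    last_value is V(s_T), used to bootstrap the final step.
    Returns (advantages, returns), both of shape [time].
    """
    next_values = jnp.append(values[1:], last_value)

    def step(next_advantage, inputs):
        reward, value, next_value, done = inputs
        not_done = 1.0 - done
        delta = reward + gamma * next_value * not_done - value  # one-step TD error
        advantage = delta + gamma * gae_lambda * not_done * next_advantage
        return advantage, advantage

    # Scan backwards in time, starting from zero advantage beyond the rollout
    _, advantages = jax.lax.scan(
        step, jnp.zeros_like(values[0]), (rewards, values, next_values, dones), reverse=True
    )
    return advantages, advantages + values


rewards = jnp.array([1.0, 0.0, 1.0, 1.0])
values = jnp.array([0.5, 0.4, 0.6, 0.7])
dones = jnp.array([0.0, 0.0, 1.0, 0.0])
advantages, returns = compute_gae(rewards, values, dones, last_value=jnp.array(0.8))
```

Setting `gae_lambda=1` gives Monte Carlo returns minus the value baseline, while `gae_lambda=0` reduces to the one-step TD error.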
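The `clip_coef`, `clip_vloss`, `ent_coef`, `vf_coef`, and `normalize_advantage` settings correspond to the terms of the usual clipped PPO objective. The following sketches that standard loss in the style of common single-file PPO implementations, assuming per-sample log-probabilities, values, and entropies are already computed for a minibatch; it is not taken from Memorax, and the argument names are hypothetical.

```python
import jax.numpy as jnp


def ppo_loss(new_logp, old_logp, advantages, new_values, old_values, returns, entropy,
             clip_coef=0.2, clip_vloss=True, ent_coef=0.01, vf_coef=0.5,
             normalize_advantage=True):
    """Clipped-surrogate PPO loss over 1-D minibatch arrays; returns a scalar to minimize."""
    if normalize_advantage:
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # Probability ratio pi_new / pi_old; the surrogate is maximized, so it is negated here
    ratio = jnp.exp(new_logp - old_logp)
    pg_loss1 = -advantages * ratio
    pg_loss2 = -advantages * jnp.clip(ratio, 1.0 - clip_coef, 1.0 + clip_coef)
    pg_loss = jnp.maximum(pg_loss1, pg_loss2).mean()

    # Squared-error value loss, optionally clipped around the old value predictions
    v_loss_unclipped = (new_values - returns) ** 2
    if clip_vloss:
        v_clipped = old_values + jnp.clip(new_values - old_values, -clip_coef, clip_coef)
        v_loss = 0.5 * jnp.maximum(v_loss_unclipped, (v_clipped - returns) ** 2).mean()
    else:
        v_loss = 0.5 * v_loss_unclipped.mean()

    # Entropy bonus (subtracted) encourages exploration
    return pg_loss - ent_coef * entropy.mean() + vf_coef * v_loss
```

Taking the elementwise maximum of the two negated surrogates is what removes the incentive to push the ratio outside `[1 - clip_coef, 1 + clip_coef]`.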
📂 Project Structure
```
memorax/
├─ examples/            # Small runnable scripts (e.g., DQN CartPole)
└─ memorax/
   ├─ algorithms/       # DQN, PPO, SAC, PQN, ...
   ├─ networks/         # MLP, CNN, ViT, RNN, heads, ...
   ├─ environments/     # Gymnax / PopGym / Brax / ...
   ├─ buffers/          # Custom flashbax buffers
   ├─ loggers/          # CLI, WandB, TensorBoardX integrations
   └─ utils/
```
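The Prioritized Experience Replay support in the episode buffer can be illustrated with the standard proportional scheme: sample whole episodes with probability proportional to priority^alpha and correct the resulting bias with importance-sampling weights. This is a self-contained JAX sketch of that scheme; the storage layout (episodes padded to a fixed length), `alpha`, `beta`, and the function name are assumptions, and it uses neither flashbax nor Memorax's buffer API.

```python
import jax
import jax.numpy as jnp


def sample_prioritized(key, priorities, batch_size, alpha=0.6, beta=0.4):
    """Proportional prioritized sampling over stored episodes.

    priorities: [num_episodes] non-negative scores (e.g. a per-episode TD-error statistic).
    Returns (indices, importance_weights), each of shape [batch_size].
    """
    probs = priorities ** alpha
    probs = probs / probs.sum()
    indices = jax.random.choice(key, priorities.shape[0], shape=(batch_size,), p=probs)
    # Importance-sampling weights (N * P(i))^-beta, scaled so the largest in the batch is 1
    weights = (priorities.shape[0] * probs[indices]) ** (-beta)
    return indices, weights / weights.max()


# Hypothetical storage: 5 episodes padded to 8 steps with 3 observation features
observations = jnp.zeros((5, 8, 3))
priorities = jnp.array([1.0, 0.5, 2.0, 0.0, 1.5])
idx, w = sample_prioritized(jax.random.key(0), priorities, batch_size=4)
batch = observations[idx]  # [4, 8, 3]: complete episodes for recurrent training
```

Sampling complete episodes keeps the time axis intact, which is what recurrent and attention-based torsos need for burn-in and sequence training.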
📄 License
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
📚 Citation
If you use Memory-RL in your work, please cite:
```bibtex
@software{memoryrl2025github,
  title = {Memory-RL: A Unified Framework for Memory-Augmented Reinforcement Learning},
  author = {Noah Farr},
  year = {2025},
  url = {https://github.com/memory-rl/memorax}
}
```
🙏 Acknowledgments
Special thanks to @huterguier for the valuable discussions and advice on the API design.
File details
Details for the file memorax-1.0.1.tar.gz.
File metadata
- Download URL: memorax-1.0.1.tar.gz
- Size: 78.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.26
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8ace7a8f0482f95b54b1e72e5947daf20dae6e9d613e187d792bc3056efb6056 |
| MD5 | 19b7368460400ee5e85814e22d9e4f4f |
| BLAKE2b-256 | 53fff7861f86d092aa7a40506d5aac8079e515ae2b48a96dc62ec5815ae3675f |
File details
Details for the file memorax-1.0.1-py3-none-any.whl.
File metadata
- Download URL: memorax-1.0.1-py3-none-any.whl
- Size: 99.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.26
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ac1b5a5e029a11fc10a97595ab148c36dc766c7447e69bc531d642acec0f5f1f |
| MD5 | 2d45dd0f717c05f8d8c6671a266b8c78 |
| BLAKE2b-256 | c6fbca4254c14821bdf238e6d1ada78e6553ceb66357977cd94c7e6814183523 |