A unified JAX/Flax framework for memory-augmented reinforcement learning with support for RNNs, SSMs, and Transformers
Memorax
A unified reinforcement learning framework featuring memory-augmented algorithms and POMDP environment implementations. This repository provides modular components for building, configuring, and running a variety of RL algorithms on classic and memory-intensive environments.
Features
- Memory-RL: JAX implementations of DQN, PPO (Discrete & Continuous), SAC (Discrete & Continuous), PQN, and their memory-augmented variants.
- Recurrent Cells: Support for multiple RNN cells and memory architectures, including LSTM, GRU, GPT2, GTrXL, FFM, xLSTM, SHM, S5, LRU, RetNet, Mamba, MinGRU, and Linear Transformer.
- Networks: MLP, CNN, and ViT encoders with support for RoPE and ALiBi positional embeddings, and Mixture of Experts (MoE) for horizontal scaling.
- Environments: Support for Gymnax, PopJym, PopGym Arcade, Navix, Craftax, Brax, MuJoCo, gxm, and XMiniGrid.
- Logging & Sweeps: Support for a CLI dashboard, Weights & Biases, TensorBoardX, and Neptune.
- Easy to Extend: Clear directory structure for adding new networks, algorithms, or environments.
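Many of the memory architectures listed above are linear recurrences that can be trained with a parallel scan instead of a step-by-step loop. As one illustration, here is a minimal sketch of a MinGRU-style cell following the published minGRU equations. It is written in plain JAX and is not Memorax's implementation; the function `min_gru` and the parameter names `w_z` and `w_h` are hypothetical.

```python
import jax
import jax.numpy as jnp


def min_gru(params, x, h0, parallel=True):
    """MinGRU-style recurrence over a sequence x of shape (T, input_dim).

    z_t       = sigmoid(x_t @ w_z)   update gate (does not depend on h_{t-1})
    h_tilde_t = x_t @ w_h            candidate state
    h_t       = (1 - z_t) * h_{t-1} + z_t * h_tilde_t
    """
    z = jax.nn.sigmoid(x @ params["w_z"])
    h_tilde = x @ params["w_h"]
    # Each step is an affine map h -> a_t * h + b_t.
    a = 1.0 - z
    b = z * h_tilde

    if not parallel:
        # Reference version: T sequential steps.
        def step(h, ab):
            a_t, b_t = ab
            h = a_t * h + b_t
            return h, h

        _, hs = jax.lax.scan(step, h0, (a, b))
        return hs

    # Composing two affine maps yields another affine map, so the recurrence
    # can be evaluated with an associative (parallel prefix) scan.
    def combine(earlier, later):
        a_e, b_e = earlier
        a_l, b_l = later
        # Apply `earlier` first, then `later`: h -> a_l * (a_e * h + b_e) + b_l
        return a_l * a_e, a_l * b_e + b_l

    a_cum, b_cum = jax.lax.associative_scan(combine, (a, b))
    return a_cum * h0 + b_cum


# Usage with random (hypothetical) parameters.
key_x, key_z, key_h = jax.random.split(jax.random.PRNGKey(0), 3)
T, input_dim, hidden_dim = 16, 3, 4
params = {
    "w_z": jax.random.normal(key_z, (input_dim, hidden_dim)),
    "w_h": jax.random.normal(key_h, (input_dim, hidden_dim)),
}
x = jax.random.normal(key_x, (T, input_dim))
h0 = jnp.zeros(hidden_dim)

hs_sequential = min_gru(params, x, h0, parallel=False)
hs_parallel = min_gru(params, x, h0, parallel=True)
```

Both paths compute the same hidden states. The associative scan needs O(log T) sequential depth rather than O(T), which is what makes this family of cells attractive for long sequences.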
Installation
- Clone the repository:

  ```bash
  git clone https://github.com/memorax/memorax.git
  cd memorax
  ```

- Install the Python dependencies:

  ```bash
  uv sync
  ```

  Optionally, add CUDA support with:

  ```bash
  uv sync --extra cuda
  ```

- Optional: set up Weights & Biases for logging by logging in:

  ```bash
  wandb login
  ```
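To confirm that the CUDA extra took effect, one option is to ask JAX which backend it selected. This sketch uses standard JAX calls only and is not a Memorax command. Run it inside the project environment (for example with `uv run python`). If it still reports `cpu` on a CUDA machine, the GPU-enabled jaxlib was probably not installed.

```python
import jax

# Name of the default backend JAX selected, e.g. "cpu" or "gpu".
backend = jax.default_backend()
# Devices visible to that backend.
devices = jax.devices()

print("JAX backend:", backend)
print("Devices:", devices)
```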
Quick Start
Run a default DQN experiment on CartPole:

```bash
uv run examples/dqn_gymnax.py
```
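DQN trains a Q-network toward a bootstrapped one-step target computed from a separate target network. The sketch below shows the textbook form of that target and a squared TD-error loss in plain JAX. It is written independently of Memorax, so the function names are illustrative, and many implementations (possibly including this one) use a Huber loss instead of the mean squared error used here for simplicity.

```python
import jax
import jax.numpy as jnp


def dqn_td_target(reward, done, next_q_target, gamma=0.99):
    """One-step DQN target: r + gamma * (1 - done) * max_a Q_target(s', a).

    reward, done: shape (batch,); next_q_target: shape (batch, num_actions).
    """
    max_next_q = jnp.max(next_q_target, axis=-1)
    # Targets are treated as constants when differentiating the loss.
    return jax.lax.stop_gradient(reward + gamma * (1.0 - done) * max_next_q)


def dqn_loss(q_values, actions, targets):
    """Mean squared TD error on the Q-values of the actions actually taken."""
    q_taken = jnp.take_along_axis(q_values, actions[:, None], axis=-1).squeeze(-1)
    return jnp.mean((q_taken - targets) ** 2)


# Two transitions; the second one ends its episode, so it is not bootstrapped.
reward = jnp.array([1.0, 1.0])
done = jnp.array([0.0, 1.0])
next_q_target = jnp.array([[0.5, 2.0], [3.0, 1.0]])
targets = dqn_td_target(reward, done, next_q_target, gamma=0.9)  # [2.8, 1.0]

q_values = jnp.array([[2.0, 0.0], [0.0, 1.5]])
actions = jnp.array([0, 1])
loss = dqn_loss(q_values, actions, targets)  # mean of (-0.8)^2 and 0.5^2 = 0.445
```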
Usage
```python
import jax
import optax

from memorax.algorithms import PPO, PPOConfig
from memorax.environments import environment
from memorax.networks import (
    MLP, FFN, ALiBi, FeatureExtractor, GatedResidual, Network,
    PreNorm, SegmentRecurrence, SelfAttention, Stack, heads,
)

# CartPole from Gymnax.
env, env_params = environment.make("gymnax::CartPole-v1")

# PPO hyperparameters.
cfg = PPOConfig(
    name="PPO-GTrXL",
    num_envs=8,
    num_eval_envs=16,
    num_steps=128,
    gamma=0.99,
    gae_lambda=0.95,
    num_minibatches=4,
    update_epochs=4,
    normalize_advantage=True,
    clip_coef=0.2,
    clip_vloss=True,
    ent_coef=0.01,
    vf_coef=0.5,
)

features, num_heads, num_layers = 64, 4, 2

# GTrXL-style torso: gated, pre-normalized self-attention (with segment
# recurrence and ALiBi positional bias) alternating with feed-forward blocks.
feature_extractor = FeatureExtractor(observation_extractor=MLP(features=(features,)))
attention = GatedResidual(PreNorm(SegmentRecurrence(
    SelfAttention(features, num_heads, context_length=128, positional_embedding=ALiBi(num_heads)),
    memory_length=64, features=features,
)))
ffn = GatedResidual(PreNorm(FFN(features=features, expansion_factor=4)))
torso = Stack(blocks=(attention, ffn) * num_layers)

# Actor with a categorical policy head; critic with a state-value head.
actor_network = Network(feature_extractor, torso, heads.Categorical(env.action_space(env_params).n))
critic_network = Network(feature_extractor, torso, heads.VNetwork())

optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(3e-4))

agent = PPO(cfg, env, env_params, actor_network, critic_network, optimizer, optimizer)

# Initialize the agent state, then train.
key, state = agent.init(jax.random.key(0))
key, state, transitions = agent.train(key, state, num_steps=10_000)
```
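The `PPOConfig` fields above map onto the standard PPO objective: `gamma` and `gae_lambda` drive Generalized Advantage Estimation, while `clip_coef`, `vf_coef`, `ent_coef` and `normalize_advantage` shape the loss. The sketch below shows the textbook form of both pieces in plain JAX. It is written independently of Memorax, so the function names and the done-flag convention (`dones[t] = 1` when the episode ends at step `t`) are assumptions, and value-loss clipping (`clip_vloss`) is left out for brevity.

```python
import jax
import jax.numpy as jnp


def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """Generalized Advantage Estimation for a single rollout of length T.

    rewards, values, dones: shape (T,); dones[t] = 1.0 if the episode ended at step t.
    last_value: V(s_T), used to bootstrap after the final step.
    """
    next_values = jnp.append(values[1:], last_value)

    def step(gae, inputs):
        reward, value, next_value, done = inputs
        not_done = 1.0 - done
        # One-step TD error; the bootstrap is cut at episode boundaries.
        delta = reward + gamma * next_value * not_done - value
        # A_t = delta_t + gamma * lambda * A_{t+1}, also reset at boundaries.
        gae = delta + gamma * gae_lambda * not_done * gae
        return gae, gae

    # reverse=True iterates t = T-1, ..., 0 but returns outputs in time order.
    _, advantages = jax.lax.scan(
        step,
        jnp.zeros((), dtype=rewards.dtype),
        (rewards, values, next_values, dones),
        reverse=True,
    )
    returns = advantages + values  # regression targets for the critic
    return advantages, returns


def ppo_loss(new_log_probs, old_log_probs, advantages, new_values, returns, entropy,
             clip_coef=0.2, vf_coef=0.5, ent_coef=0.01, normalize_advantage=True):
    """Clipped-surrogate PPO loss (value-loss clipping omitted)."""
    if normalize_advantage:
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    ratio = jnp.exp(new_log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = jnp.clip(ratio, 1.0 - clip_coef, 1.0 + clip_coef) * advantages
    # Pessimistic (minimum) surrogate, negated so it can be minimized.
    policy_loss = -jnp.mean(jnp.minimum(unclipped, clipped))
    value_loss = 0.5 * jnp.mean((new_values - returns) ** 2)
    # The entropy bonus is subtracted to encourage exploration.
    return policy_loss + vf_coef * value_loss - ent_coef * jnp.mean(entropy)


# Usage on a toy rollout where the episode ends after step 1.
rewards = jnp.array([1.0, 0.0, 1.0, 1.0])
values = jnp.array([0.5, 0.4, 0.3, 0.2])
dones = jnp.array([0.0, 1.0, 0.0, 0.0])
advantages, returns = compute_gae(rewards, values, dones, jnp.array(0.1))
```

In a full training loop, `compute_gae` would typically be applied per environment (for example with `jax.vmap` over the `num_envs` axis) and `ppo_loss` evaluated per minibatch under `jax.grad`.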
Project structure
```
memorax/
├─ examples/          # Small runnable scripts (e.g., DQN CartPole)
└─ memorax/
   ├─ algorithms/     # DQN, PPO, SAC, PQN, ...
   ├─ networks/       # MLP, CNN, ViT, RNN, heads, ...
   ├─ environments/   # Gymnax / PopGym / Brax / ...
   ├─ buffers/        # Custom flashbax buffers
   ├─ loggers/        # CLI, WandB, TensorBoardX integrations
   └─ utils/
```
License
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
Citation
If you use Memory-RL in your work, please cite:
```bibtex
@software{memoryrl2025github,
  title  = {Memory-RL: A Unified Framework for Memory-Augmented Reinforcement Learning},
  author = {Noah Farr},
  year   = {2025},
  url    = {https://github.com/memory-rl/memorax}
}
```
Download files
Source Distribution
Built Distribution
File details
Details for the file memorax-1.0.0.tar.gz.
File metadata
- Download URL: memorax-1.0.0.tar.gz
- Upload date:
- Size: 69.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.26
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 976a5cc3756e7d06428e34cc694815ed136c55afa71eb10bd9cb4aef36ea5c4b |
| MD5 | 2e6ccc8bda22607650b5bc078abbd38e |
| BLAKE2b-256 | 89368236d5d8e0b997ce007e217276ba44c39e28d0cbe2200f686e1bb5e016ee |
File details
Details for the file memorax-1.0.0-py3-none-any.whl.
File metadata
- Download URL: memorax-1.0.0-py3-none-any.whl
- Upload date:
- Size: 90.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.26
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e86f4a62dac979a1f422dc12668946393d95c978c4a9ad19b9090efa92e3bdc8 |
| MD5 | db77817b385d353f7a356371b94515cd |
| BLAKE2b-256 | ee27e1e64ae7f7b4c5f36a968e6f6a600a37c5e57779e59aa3e8ed55f80696fc |