
A unified JAX/Flax framework for memory-augmented reinforcement learning with support for RNNs, SSMs, and Transformers


🧠 Memorax

A unified reinforcement learning framework featuring memory-augmented algorithms and POMDP environment implementations. This repository provides modular components for building, configuring, and running a variety of RL algorithms on classic and memory-intensive environments.

✨ Features

📥 Installation

Install Memorax using pip:

pip install memorax

Or using uv:

uv add memorax

Optionally, add CUDA support with (quoted so that shells such as zsh do not expand the brackets):

pip install "memorax[cuda]"
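After installing the CUDA extra, you can check which backend JAX actually picked up. This is a minimal sketch that assumes the `cuda` extra pulls in a CUDA-enabled `jaxlib` (an assumption; check the package metadata). `jax.devices()` and `jax.default_backend()` are standard JAX functions, not Memorax APIs.

```python
import jax

# List the devices JAX can use; GPU entries indicate the CUDA backend is active.
print(jax.devices())

# Prints "gpu" if JAX is using CUDA, or "cpu" if it fell back to the CPU backend.
print(jax.default_backend())
```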

Optionally, log in to Weights & Biases (W&B) to enable experiment logging:

wandb login

🚀 Quick Start

From a clone of the repository, run the default DQN example on CartPole:

uv run examples/dqn_gymnax.py

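💻 Usage

The example below trains PPO with a GTrXL-style transformer torso (gated, pre-normalized self-attention with segment-level recurrence) on CartPole: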

import jax
import optax
from memorax.algorithms import PPO, PPOConfig
from memorax.environments import environment
from memorax.networks import (
    MLP, FFN, ALiBi, FeatureExtractor, GatedResidual, Network,
    PreNorm, SegmentRecurrence, SelfAttention, Stack, heads,
)

# Create the Gymnax CartPole-v1 environment and its default parameters.
env, env_params = environment.make("gymnax::CartPole-v1")

# PPO hyperparameters.
cfg = PPOConfig(
    name="PPO-GTrXL",
    num_envs=8,
    num_eval_envs=16,
    num_steps=128,
    gamma=0.99,
    gae_lambda=0.95,
    num_minibatches=4,
    update_epochs=4,
    normalize_advantage=True,
    clip_coef=0.2,
    clip_vloss=True,
    ent_coef=0.01,
    vf_coef=0.5,
)

# GTrXL-style torso: gated, pre-normalized self-attention (ALiBi positional bias,
# segment-level recurrence) interleaved with gated feed-forward blocks.
features, num_heads, num_layers = 64, 4, 2
feature_extractor = FeatureExtractor(observation_extractor=MLP(features=(features,)))
attention = GatedResidual(PreNorm(SegmentRecurrence(
    SelfAttention(features, num_heads, context_length=128, positional_embedding=ALiBi(num_heads)),
    memory_length=64, features=features,
)))
ffn = GatedResidual(PreNorm(FFN(features=features, expansion_factor=4)))
torso = Stack(blocks=(attention, ffn) * num_layers)

# Actor and critic reuse the same feature extractor and torso definitions
# but end in different heads (categorical policy vs. state-value).
actor_network = Network(feature_extractor, torso, heads.Categorical(env.action_space(env_params).n))
critic_network = Network(feature_extractor, torso, heads.VNetwork())
# Global-norm gradient clipping followed by Adam, passed for both actor and critic.
optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(3e-4))

agent = PPO(cfg, env, env_params, actor_network, critic_network, optimizer, optimizer)
# Initialize the agent state, then run training with num_steps=10_000.
key, state = agent.init(jax.random.key(0))
key, state, transitions = agent.train(key, state, num_steps=10_000)
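The `PPOConfig` fields above use standard PPO terminology. To illustrate what `gamma`, `gae_lambda`, `normalize_advantage`, `clip_coef`, and `clip_vloss` typically control, here is a sketch of Generalized Advantage Estimation and the clipped PPO losses in plain JAX, following the common CleanRL-style formulation. It is not Memorax's internal implementation: the helper names `compute_gae` and `ppo_losses`, the `(num_steps, num_envs)` array layout, and the `dones` convention are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


# Hypothetical helpers for illustration; not part of the Memorax API.
def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """Generalized Advantage Estimation for arrays shaped (num_steps, num_envs).

    dones[t] is 1.0 if the episode ended at step t, which stops bootstrapping.
    last_value has shape (num_envs,) and is V(s) after the final step.
    """
    # V(s_{t+1}) for every step: shift values by one and append the bootstrap value.
    next_values = jnp.concatenate([values[1:], last_value[None]], axis=0)
    # One-step TD errors: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t).
    deltas = rewards + gamma * next_values * (1.0 - dones) - values

    def step(next_advantage, inputs):
        delta, done = inputs
        # A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
        advantage = delta + gamma * gae_lambda * (1.0 - done) * next_advantage
        return advantage, advantage

    # Scan backwards in time; outputs are still stacked in forward time order.
    _, advantages = jax.lax.scan(
        step, jnp.zeros_like(last_value), (deltas, dones), reverse=True
    )
    # Second output: value targets (returns) for the critic.
    return advantages, advantages + values


def ppo_losses(new_logp, old_logp, advantages, new_values, old_values, returns,
               clip_coef=0.2, clip_vloss=True, normalize_advantage=True):
    """Clipped PPO policy and value losses (both to be minimized)."""
    if normalize_advantage:
        # Zero-mean, unit-variance advantages within the batch.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    # Probability ratio pi_new(a|s) / pi_old(a|s).
    ratio = jnp.exp(new_logp - old_logp)
    # Pessimistic (clipped) surrogate objective, negated for gradient descent.
    pg_loss = jnp.maximum(
        -advantages * ratio,
        -advantages * jnp.clip(ratio, 1.0 - clip_coef, 1.0 + clip_coef),
    ).mean()
    if clip_vloss:
        # Limit how far the new value prediction can move from the old one.
        clipped_values = old_values + jnp.clip(new_values - old_values, -clip_coef, clip_coef)
        v_loss = 0.5 * jnp.maximum(
            (new_values - returns) ** 2, (clipped_values - returns) ** 2
        ).mean()
    else:
        v_loss = 0.5 * ((new_values - returns) ** 2).mean()
    return pg_loss, v_loss
```

In this formulation the total minimized loss would be `pg_loss - ent_coef * entropy + vf_coef * v_loss`, which is where `ent_coef` and `vf_coef` come in; Memorax may combine the terms differently.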

📂 Project Structure

memorax/
├─ examples/          # Small runnable scripts (e.g., DQN CartPole)
└─ memorax/
   ├─ algorithms/     # DQN, PPO, SAC, PQN, ...
   ├─ networks/       # MLP, CNN, ViT, RNN, heads, ...
   ├─ environments/   # Gymnax / PopGym / Brax / ...
   ├─ buffers/        # Custom Flashbax buffers
   ├─ loggers/        # CLI, WandB, TensorBoardX integrations
   └─ utils/

📄 License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.

📚 Citation

If you use Memorax in your work, please cite:

@software{memoryrl2025github,
  title   = {Memory-RL: A Unified Framework for Memory-Augmented Reinforcement Learning},
  author  = {Noah Farr},
  year    = {2025},
  url     = {https://github.com/memory-rl/memorax}
}

🙏 Acknowledgments

Special thanks to @huterguier for the valuable discussions and advice on the API design.



Download files


Source Distribution

memorax-1.0.1.tar.gz (78.5 kB)

Uploaded Source

Built Distribution


memorax-1.0.1-py3-none-any.whl (99.2 kB)

Uploaded Python 3

File details

Details for the file memorax-1.0.1.tar.gz.

File metadata

  • Download URL: memorax-1.0.1.tar.gz
  • Upload date:
  • Size: 78.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.26 (uv publish, macOS)

File hashes

Hashes for memorax-1.0.1.tar.gz
Algorithm    Hash digest
SHA256       8ace7a8f0482f95b54b1e72e5947daf20dae6e9d613e187d792bc3056efb6056
MD5          19b7368460400ee5e85814e22d9e4f4f
BLAKE2b-256  53fff7861f86d092aa7a40506d5aac8079e515ae2b48a96dc62ec5815ae3675f
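To check a downloaded file against the published SHA256 digest, a short sketch using only the Python standard library is enough. The helper names `sha256_of_file` and `verify_sha256` are hypothetical; `hashlib` and `hmac` are real standard-library modules.

```python
import hashlib
import hmac


def sha256_of_file(path, chunk_size=1 << 16):
    """Return the hex SHA-256 digest of a file, reading it in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read until f.read() returns b"" (end of file) to keep memory use constant.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path, expected_hex):
    """True if the file's SHA-256 matches the expected hex digest (case-insensitive)."""
    # compare_digest avoids short-circuiting on the first mismatching character.
    return hmac.compare_digest(sha256_of_file(path), expected_hex.lower())
```

For example, `verify_sha256("memorax-1.0.1.tar.gz", "8ace7a8f0482f95b54b1e72e5947daf20dae6e9d613e187d792bc3056efb6056")` should return `True` for an unmodified copy of the source distribution. pip can also enforce digests itself through its hash-checking mode (`pip install --require-hashes -r requirements.txt`), in which every requirement, including dependencies, must be pinned with a `--hash=sha256:...` entry.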


File details

Details for the file memorax-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: memorax-1.0.1-py3-none-any.whl
  • Upload date:
  • Size: 99.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.26 (uv publish, macOS)

File hashes

Hashes for memorax-1.0.1-py3-none-any.whl
Algorithm    Hash digest
SHA256       ac1b5a5e029a11fc10a97595ab148c36dc766c7447e69bc531d642acec0f5f1f
MD5          2d45dd0f717c05f8d8c6671a266b8c78
BLAKE2b-256  c6fbca4254c14821bdf238e6d1ada78e6553ceb66357977cd94c7e6814183523

