A unified JAX/Flax framework for memory-augmented reinforcement learning with support for RNNs, SSMs, and Transformers

Memorax

A unified reinforcement learning framework featuring memory-augmented algorithms and POMDP environment implementations. This repository provides modular components for building, configuring, and running a variety of RL algorithms on classic and memory-intensive environments.

Features

Installation

  1. Clone the repository:
git clone https://github.com/memorax/memorax.git
cd memorax
  2. Install the Python dependencies:
uv sync

For optional CUDA support, use:

uv sync --extra cuda
  3. Optional: log in to Weights & Biases to enable experiment logging:
wandb login

Quick Start

Run a default DQN experiment on CartPole:

uv run examples/dqn_gymnax.py
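
The DQN example trains a Q-network by regressing it onto bootstrapped one-step targets. As a rough illustration of that objective, here is a minimal JAX sketch of the standard DQN target and squared TD loss. The function names are hypothetical, this is not Memorax's implementation, and the example script's actual loss and target-network handling may differ.

```python
import jax.numpy as jnp


def dqn_td_target(reward, done, next_q_target, gamma=0.99):
    """One-step DQN target: r + gamma * (1 - done) * max_a' Q_target(s', a')."""
    max_next_q = jnp.max(next_q_target, axis=-1)
    # Terminal transitions (done = 1.0) do not bootstrap from the next state.
    return reward + gamma * (1.0 - done) * max_next_q


def dqn_loss(q_values, actions, targets):
    """Mean squared TD error on the Q-values of the actions actually taken."""
    q_taken = jnp.take_along_axis(q_values, actions[:, None], axis=-1).squeeze(-1)
    return jnp.mean((q_taken - targets) ** 2)


# Usage on a batch of two transitions; the second one ends its episode.
targets = dqn_td_target(
    reward=jnp.array([1.0, 1.0]),
    done=jnp.array([0.0, 1.0]),
    next_q_target=jnp.array([[0.5, 2.0], [3.0, 1.0]]),
)  # targets ≈ [2.98, 1.0]
loss = dqn_loss(jnp.array([[1.0, 2.0], [0.0, 1.0]]), jnp.array([1, 0]), targets)
```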

Usage

The example below trains PPO with a GTrXL-style gated transformer torso on CartPole:

import jax
import optax
from memorax.algorithms import PPO, PPOConfig
from memorax.environments import environment
from memorax.networks import (
    MLP, FFN, ALiBi, FeatureExtractor, GatedResidual, Network,
    PreNorm, SegmentRecurrence, SelfAttention, Stack, heads,
)

env, env_params = environment.make("gymnax::CartPole-v1")

cfg = PPOConfig(
    name="PPO-GTrXL",
    num_envs=8,
    num_eval_envs=16,
    num_steps=128,
    gamma=0.99,
    gae_lambda=0.95,
    num_minibatches=4,
    update_epochs=4,
    normalize_advantage=True,
    clip_coef=0.2,
    clip_vloss=True,
    ent_coef=0.01,
    vf_coef=0.5,
)

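# GTrXL-style torso: each layer is a pre-norm self-attention block with
# segment-level recurrence (memory_length=64) followed by a pre-norm
# feed-forward block, each wrapped in a gated residual connection.
features, num_heads, num_layers = 64, 4, 2
feature_extractor = FeatureExtractor(observation_extractor=MLP(features=(features,)))
attention = GatedResidual(PreNorm(SegmentRecurrence(
    SelfAttention(features, num_heads, context_length=128, positional_embedding=ALiBi(num_heads)),
    memory_length=64, features=features,
)))
ffn = GatedResidual(PreNorm(FFN(features=features, expansion_factor=4)))
torso = Stack(blocks=(attention, ffn) * num_layers)

# Actor and critic use the same architecture with different output heads:
# a categorical policy over the discrete actions and a scalar value head.
actor_network = Network(feature_extractor, torso, heads.Categorical(env.action_space(env_params).n))
critic_network = Network(feature_extractor, torso, heads.VNetwork())
# Global-norm gradient clipping (1.0) followed by Adam (learning rate 3e-4),
# used for both the actor and the critic.
optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(3e-4))

agent = PPO(cfg, env, env_params, actor_network, critic_network, optimizer, optimizer)
key, state = agent.init(jax.random.key(0))
key, state, transitions = agent.train(key, state, num_steps=10_000)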
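
The gamma, gae_lambda, clip_coef and normalize_advantage fields of PPOConfig presumably map onto the standard PPO ingredients: Generalized Advantage Estimation (GAE) and the clipped surrogate objective. The following is a minimal JAX sketch of those two pieces under the conventional formulation; the function names are hypothetical, this is not Memorax's internal code, and it omits the value-loss clipping (clip_vloss) and the entropy and value weighting (ent_coef, vf_coef).

```python
import jax
import jax.numpy as jnp


def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """Generalized Advantage Estimation for one environment's rollout of length T.

    rewards, values, dones: float arrays of shape (T,); dones[t] = 1.0 means the
    episode terminated after step t. last_value bootstraps the final step.
    """
    next_values = jnp.append(values[1:], last_value)
    # One-step TD errors; no bootstrapping across episode boundaries.
    deltas = rewards + gamma * next_values * (1.0 - dones) - values

    def step(gae, inputs):
        delta, done = inputs
        gae = delta + gamma * gae_lambda * (1.0 - done) * gae
        return gae, gae

    # Scan backwards in time; outputs come back in the original time order.
    _, advantages = jax.lax.scan(step, jnp.zeros(()), (deltas, dones), reverse=True)
    returns = advantages + values  # regression targets for the critic
    return advantages, returns


def ppo_policy_loss(log_prob, old_log_prob, advantages, clip_coef=0.2,
                    normalize_advantage=True):
    """PPO clipped surrogate objective, negated so it can be minimized."""
    if normalize_advantage:
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    ratio = jnp.exp(log_prob - old_log_prob)
    unclipped = ratio * advantages
    clipped = jnp.clip(ratio, 1.0 - clip_coef, 1.0 + clip_coef) * advantages
    # Pessimistic bound: take the smaller of the clipped and unclipped terms.
    return -jnp.mean(jnp.minimum(unclipped, clipped))


# Usage: a three-step rollout whose episode ends after the second step.
advantages, returns = compute_gae(
    rewards=jnp.array([1.0, 1.0, 1.0]),
    values=jnp.zeros(3),
    dones=jnp.array([0.0, 1.0, 0.0]),
    last_value=jnp.array(0.0),
)  # advantages ≈ [1.9405, 1.0, 1.0]
```

The backward scan is the usual way to express the GAE recursion in JAX, since it stays jit-compatible and avoids a Python loop over time steps.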

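The attention block in the example uses positional_embedding=ALiBi(num_heads). ALiBi (Attention with Linear Biases) replaces positional embeddings with a per-head penalty proportional to the distance between query and key positions. Below is a minimal JAX sketch of that bias following the standard formulation from the ALiBi paper, with a causal mask, applied to a single segment. The function names are hypothetical; Memorax's ALiBi module may differ, for example in how it chooses slopes for non-power-of-two head counts or how the bias extends over the SegmentRecurrence memory.

```python
import jax
import jax.numpy as jnp


def alibi_slopes(num_heads):
    """Per-head slopes 2^(-8k/num_heads), k = 1..num_heads (power-of-two head counts)."""
    return 2.0 ** (-8.0 * jnp.arange(1, num_heads + 1) / num_heads)


def alibi_bias(num_heads, seq_len):
    """Additive attention bias of shape (num_heads, seq_len, seq_len), causally masked."""
    pos = jnp.arange(seq_len)
    distance = pos[:, None] - pos[None, :]  # query position i minus key position j
    bias = -alibi_slopes(num_heads)[:, None, None] * distance
    # Keys in the future (j > i) are masked out entirely.
    return jnp.where(distance >= 0, bias, -jnp.inf)


def alibi_attention(q, k, v):
    """Scaled dot-product attention with ALiBi; q, k, v have shape (heads, seq, dim)."""
    num_heads, seq_len, head_dim = q.shape
    scores = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(head_dim)
    scores = scores + alibi_bias(num_heads, seq_len)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("hqk,hkd->hqd", weights, v)


# Usage with the example's num_heads=4 on a short three-step segment.
q = k = jnp.zeros((4, 3, 16))
v = jnp.ones((4, 3, 16))
out = alibi_attention(q, k, v)
```

Because the penalty grows linearly with distance, heads with larger slopes focus on recent steps while heads with small slopes can attend further back, which is why ALiBi is often paired with long or recurrent contexts.
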
Project structure

memorax/
├─ examples/          # Small runnable scripts (e.g., DQN CartPole)
└─ memorax/
   ├─ algorithms/     # DQN, PPO, SAC, PQN, ...
   ├─ networks/       # MLP, CNN, ViT, RNN, heads, ...
   ├─ environments/   # Gymnax / PopGym / Brax / ...
   ├─ buffers/        # Custom flashbax buffers
   ├─ loggers/        # CLI, WandB, TensorBoardX integrations
   └─ utils/

License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.

Citation

If you use Memory-RL in your work, please cite:

@software{memoryrl2025github,
  title   = {Memory-RL: A Unified Framework for Memory-Augmented Reinforcement Learning},
  author  = {Noah Farr},
  year    = {2025},
  url     = {https://github.com/memory-rl/memorax}
}


Download files

Source Distribution

memorax-1.0.0.tar.gz (69.1 kB)

Uploaded Source

Built Distribution

memorax-1.0.0-py3-none-any.whl (90.2 kB)

Uploaded Python 3

File details

Details for the file memorax-1.0.0.tar.gz.

File metadata

  • Download URL: memorax-1.0.0.tar.gz
  • Size: 69.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.26 (publish, macOS)

File hashes

Hashes for memorax-1.0.0.tar.gz
Algorithm Hash digest
SHA256 976a5cc3756e7d06428e34cc694815ed136c55afa71eb10bd9cb4aef36ea5c4b
MD5 2e6ccc8bda22607650b5bc078abbd38e
BLAKE2b-256 89368236d5d8e0b997ce007e217276ba44c39e28d0cbe2200f686e1bb5e016ee

File details

Details for the file memorax-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: memorax-1.0.0-py3-none-any.whl
  • Size: 90.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.26 (publish, macOS)

File hashes

Hashes for memorax-1.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 e86f4a62dac979a1f422dc12668946393d95c978c4a9ad19b9090efa92e3bdc8
MD5 db77817b385d353f7a356371b94515cd
BLAKE2b-256 ee27e1e64ae7f7b4c5f36a968e6f6a600a37c5e57779e59aa3e8ed55f80696fc
