
Differentiable episodic memory for reinforcement learning.

Project description

hippotorch


Differentiable episodic memory for RL. Retrieves what matters. Forgets what doesn't.

hippotorch is a PyTorch library that replaces standard replay buffers with a learnable memory system. It uses reward-aware contrastive learning to organize experiences and hybrid sampling to retrieve them, targeting the temporal credit assignment problem in sparse-reward, long-horizon tasks.

Install: pip install hippotorch

  • Fast retrieval (FAISS): pip install "hippotorch[faiss]"
  • Gymnasium wrappers/examples: pip install "hippotorch[envs]"

Key Hyperparameters

| Parameter | Default | Guidance |
|---|---|---|
| mixture_ratio | 0.5 | Start low (0.2), ramp to 0.5 after warmup |
| momentum | 0.995 | Higher = more stable keys, slower adaptation |
| temperature | 0.07 | Lower = sharper retrieval (try 0.05 for sparse rewards) |
| reward_weight | 0.5 | Higher = clusters by reward; lower = clusters by time |
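To build intuition for the guidance above, the plain-Python sketch below isolates three of these knobs: a linear warmup ramp for `mixture_ratio`, a momentum (EMA) update of key-encoder weights, and a temperature-scaled softmax over similarity scores. It is a hypothetical illustration: the helper names are not part of hippotorch, and the EMA form assumes a MoCo-style momentum encoder, which the library may implement differently.

```python
import math


def mixture_ratio_at(step, warmup_steps, start=0.2, end=0.5):
    """Hypothetical schedule: linearly ramp the semantic share of each batch."""
    if warmup_steps <= 0 or step >= warmup_steps:
        return end
    return start + (end - start) * step / warmup_steps


def ema_update(key_params, query_params, momentum=0.995):
    """Assumed momentum update: key weights drift slowly toward the query encoder."""
    return [momentum * k + (1.0 - momentum) * q for k, q in zip(key_params, query_params)]


def softmax_with_temperature(scores, temperature=0.07):
    """Turn similarity scores into probabilities; a lower temperature sharpens them."""
    scaled = [s / temperature for s in scores]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


print([mixture_ratio_at(s, 10_000) for s in (0, 5_000, 20_000)])
sims = [0.9, 0.8, 0.1]
print(softmax_with_temperature(sims, 0.07), softmax_with_temperature(sims, 0.5))
```

With the same scores, the best match receives a much larger share of probability at 0.07 than at 0.5, which is what "sharper retrieval" means in practice.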

When to Use (and When Not To)

Hippotorch adds overhead (an encoder, periodic consolidation, and retrieval at sampling time). Use it where episodic structure matters.

| Scenario | Benefit | Recommendation |
|---|---|---|
| Sparse rewards (Montezuma, long corridors) | ✅ High | Use hippotorch: retrieval surfaces rare successes |
| Partial observability (POMDPs, visual RL) | ✅ High | Use hippotorch: pattern completion reconstructs context |
| Long-horizon tasks (100+ steps to goal) | ✅ High | Use hippotorch: bridges the temporal credit gap |
| Curriculum / transfer learning | ✅ High | Use hippotorch: retains skills across task stages |
| Dense rewards, full observability | ⚠️ Low | Use standard replay: uniform sampling is sufficient |
| Short episodes (<20 steps) | ⚠️ Low | Use standard replay: no retrieval advantage |

Rule of Thumb

If your agent "forgets" how to solve early tasks, or struggles to connect actions to delayed rewards, hippotorch can help. If training already converges well with a standard buffer, you don't need it.


Installation

pip install hippotorch

# Optional: Fast retrieval for large memories (1M+ episodes)
pip install "hippotorch[faiss]"

# Optional: Gymnasium + wrappers/examples
pip install "hippotorch[envs]"

Requirements: Python ≥3.9, PyTorch ≥2.0
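Because both extras are optional, code that prefers FAISS-backed retrieval may want to detect whether it is installed and fall back otherwise. A minimal sketch, assuming the `[faiss]` extra provides the `faiss` module and `[envs]` provides `gymnasium` (the exact contents of each extra are not listed here):

```python
import importlib.util


def has_module(name):
    """Return True if a top-level module can be imported in this environment."""
    return importlib.util.find_spec(name) is not None


# Assumed module names for the optional extras
extras = {"faiss": has_module("faiss"), "gymnasium": has_module("gymnasium")}
for module, available in extras.items():
    print(f"{module}: {'available' if available else 'not installed'}")
```

`find_spec` only locates the module without importing it, so the check is cheap and has no side effects.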


Quickstart

Create a dual encoder + memory store, attach a consolidator, and use the hybrid replay buffer.

import torch
from hippotorch import Episode, DualEncoder, MemoryStore, Consolidator, HippocampalReplayBuffer

state_dim, action_dim = 4, 1
input_dim = state_dim + action_dim + 1  # +1 for reward

# 1) Encoder and memory
encoder = DualEncoder(input_dim=input_dim, embed_dim=128, momentum=0.995)
memory = MemoryStore(embed_dim=128, capacity=50_000)

# 2) Reward-aware consolidator (sleep phase optimizer)
consolidator = Consolidator(encoder, temperature=0.07, reward_weight=0.5)

# 3) Hybrid replay buffer (semantic + uniform)
buffer = HippocampalReplayBuffer(memory=memory, encoder=encoder, mixture_ratio=0.3,
                                 consolidator=consolidator)

# 4) Record an episode (toy tensors)
T = 32
states = torch.randn(T, state_dim)
actions = torch.randn(T, action_dim)
rewards = torch.randn(T)
dones = torch.zeros(T, dtype=torch.bool)
dones[-1] = True  # mark the final step as terminal
episode = Episode(states=states, actions=actions, rewards=rewards, dones=dones)
buffer.add_episode(episode)

# 5) Sample with semantic + uniform mixing
query_state = torch.cat([states[0], torch.zeros(action_dim), rewards[0].unsqueeze(0)])
batch = buffer.sample(batch_size=64, query_state=query_state, top_k=5)
# agent.update(batch)  # your learner's update step; `agent` is not defined in this snippet

# 6) Periodic consolidation ("sleep")
metrics = buffer.consolidate(steps=50, batch_size=64, report_quality=True)
print(metrics)
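Step 5 mixes semantically retrieved samples with uniform ones according to `mixture_ratio`. The plain-Python sketch below illustrates one way such a split can work; it is a hypothetical illustration of the idea, not hippotorch's sampler. `hybrid_indices` and `ranked_ids` are names introduced here, with `ranked_ids` standing in for memory indices already sorted by similarity to the query.

```python
import random


def hybrid_indices(ranked_ids, memory_size, batch_size, mixture_ratio, top_k, rng=None):
    """Draw a batch that is part semantic (from the top-k matches) and part uniform."""
    if rng is None:
        rng = random.Random()
    n_semantic = round(batch_size * mixture_ratio)
    pool = ranked_ids[:top_k]  # only the best matches are eligible for semantic picks
    semantic = [rng.choice(pool) for _ in range(n_semantic)] if pool else []
    n_uniform = batch_size - len(semantic)
    uniform = [rng.randrange(memory_size) for _ in range(n_uniform)]
    return semantic + uniform


batch = hybrid_indices(ranked_ids=[7, 3, 9, 1, 4, 8], memory_size=100, batch_size=64,
                       mixture_ratio=0.3, top_k=5, rng=random.Random(0))
print(len(batch), batch[:5])
```

Both parts sample with replacement, so a small `top_k` pool is revisited often; the uniform share keeps the rest of memory in circulation, which is why ramping `mixture_ratio` up gradually is the safer default.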

SB3 users can keep their rollout API unchanged with the adapter:

from hippotorch import SB3ReplayBufferWrapper, TerminalSegmenter
sb3_buffer = SB3ReplayBufferWrapper(buffer, segmenter=TerminalSegmenter())
# sb3_buffer.add(obs, next_obs, action, reward, done)
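The `TerminalSegmenter` name suggests that a flat stream of transitions is grouped into episodes at terminal flags. The sketch below illustrates that idea in plain Python; it is an assumption about the segmenter's behavior rather than its implementation, and `segment_on_done` is a hypothetical helper.

```python
def segment_on_done(transitions):
    """Split a flat stream of (obs, action, reward, done) tuples into episodes.

    An episode closes after every transition whose `done` flag is True;
    trailing transitions without a terminal flag stay in an open episode.
    """
    episodes, current = [], []
    for transition in transitions:
        current.append(transition)
        if transition[3]:  # done flag closes the episode
            episodes.append(current)
            current = []
    return episodes, current


stream = [(0, 1, 0.0, False), (1, 0, 1.0, True), (2, 1, 0.0, False)]
closed, open_episode = segment_on_done(stream)
print(len(closed), len(open_episode))
```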

Recall While Acting (query() + wrappers)

Use the read-only query API for inference-time recall:

from hippotorch import query
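# obs: the current observation (state_dim values); action and reward slots are zero-filled
query_vec = torch.cat([torch.as_tensor(obs, dtype=torch.float32), torch.zeros(action_dim), torch.zeros(1)])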
result = query(query_vec, buffer=buffer, top_k=5)
print(result.episode_ids, result.scores)
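`query()` returns episode IDs alongside scores. As an illustration of what top-k similarity retrieval computes, here is a brute-force cosine-similarity search in plain Python. It assumes cosine scoring over stored key embeddings, which may not match how hippotorch (or its FAISS backend) scores results, and `top_k_cosine` is a hypothetical helper.

```python
import heapq
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k_cosine(query_vec, keys, top_k=5):
    """Return (episode_ids, scores) for the top_k stored keys most similar to the query."""
    scored = ((cosine(query_vec, key), episode_id) for episode_id, key in keys.items())
    best = heapq.nlargest(top_k, scored)
    return [episode_id for _, episode_id in best], [score for score, _ in best]


keys = {"ep0": [1.0, 0.0], "ep1": [0.7, 0.7], "ep2": [0.0, 1.0]}
ids, scores = top_k_cosine([1.0, 0.1], keys, top_k=2)
print(ids, scores)
```

A linear scan like this is fine for small memories; the FAISS extra exists for the large-memory case where scanning every key per query becomes the bottleneck.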

To feed retrieval features into an online policy, wrap the Gymnasium env:

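from stable_baselines3 import PPO
from hippotorch import HippotorchMemoryWrapper

def build_query(obs):  # [state, zero action, zero reward], matching the encoder input layout
    return torch.cat([torch.as_tensor(obs, dtype=torch.float32), torch.zeros(action_dim), torch.zeros(1)])

wrapped_env = HippotorchMemoryWrapper(env, buffer, query_state_fn=build_query)  # env: your Gymnasium env
# SB3 tip: use MultiInputPolicy because the wrapper returns a Dict observation
model = PPO("MultiInputPolicy", wrapped_env, verbose=1)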

See examples/query_inference_demo.py and examples/minigrid_memory_wrapper.py for runnable snippets.

Portable Brains (Hub)

Export a trained memory so another agent can load it instantly:

from hippotorch import (
    DualEncoder,
    HippocampalReplayBuffer,
    MemoryStore,
    push_memory_to_hub,
    load_memory_from_hub,
)

obs_dim = 42
encoder = DualEncoder(input_dim=obs_dim, embed_dim=128)
memory = MemoryStore(embed_dim=128, capacity=2048)

# Push to hub (requires real hub backend, e.g., huggingface_hub)
push_memory_to_hub(memory, repo_id="user/fetch-reach-expert", private=False)

# Later, load
restored = load_memory_from_hub("user/fetch-reach-expert")

# Or operate via the buffer convenience wrappers
buffer = HippocampalReplayBuffer(memory=memory, encoder=encoder)
buffer.save_to_hub("user/fetch-reach-expert")
restored_memory = buffer.load_memory_from_hub("user/fetch-reach-expert")

Note

  • The hub utilities in this repo are minimal stubs for testing. To use the Hub in production, integrate a real backend (e.g., huggingface_hub) by adapting hippotorch.utils.hub.
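Since the hub helpers are stubs, one way to structure a real integration is a small backend interface plus a local-directory implementation for tests, which a `huggingface_hub`-backed class could later replace. Everything below (`MemoryBackend`, `LocalDirBackend`, the `push`/`pull` methods) is hypothetical and not part of `hippotorch.utils.hub`; it only moves already-serialized files around.

```python
import shutil
import tempfile
from pathlib import Path
from typing import Protocol


class MemoryBackend(Protocol):
    """Minimal interface a hub backend would need to provide (hypothetical)."""

    def push(self, local_file: Path, repo_id: str) -> None: ...

    def pull(self, repo_id: str, dest_dir: Path) -> Path: ...


class LocalDirBackend:
    """Test backend: each repo_id maps to a folder under `root`."""

    def __init__(self, root: Path):
        self.root = Path(root)

    def _repo_dir(self, repo_id: str) -> Path:
        return self.root / repo_id.replace("/", "__")  # flatten "user/name" into one folder

    def push(self, local_file: Path, repo_id: str) -> None:
        repo_dir = self._repo_dir(repo_id)
        repo_dir.mkdir(parents=True, exist_ok=True)
        shutil.copy2(local_file, repo_dir / Path(local_file).name)

    def pull(self, repo_id: str, dest_dir: Path) -> Path:
        files = sorted(self._repo_dir(repo_id).iterdir())
        if not files:
            raise FileNotFoundError(f"nothing pushed to {repo_id!r}")
        dest = Path(dest_dir) / files[0].name
        shutil.copy2(files[0], dest)
        return dest


with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    checkpoint = tmp / "memory.bin"
    checkpoint.write_bytes(b"serialized memory")
    backend: MemoryBackend = LocalDirBackend(tmp / "hub")
    backend.push(checkpoint, "user/fetch-reach-expert")
    (tmp / "out").mkdir()
    print(backend.pull("user/fetch-reach-expert", tmp / "out").read_bytes())
```

Keeping serialization separate from transport means a production backend only has to implement `push` and `pull`.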

Quick Experiment Scripts

Convenience scripts in scripts/ run short, repeatable experiments:

  • Rank‑weighted consolidation ablation:
    • bash scripts/run_rank_ablation.sh
  • Consolidation micro‑bench (synthetic):
    • bash scripts/run_consolidation_micro.sh
  • CartPole parity (short run with logging):
    • bash scripts/quick_cartpole.sh
  • Zero‑noise corridor (Amnesiac):
    • bash scripts/corridor_multiseed_zn.sh
    • Faster: SEEDS=2 EPISODES=150 CONS_EVERY=10 CONS_STEPS=50 bash scripts/corridor_multiseed_zn.sh
  • Curriculum corridor (progressively increases length):
    • bash scripts/corridor_curriculum.sh
  • TensorBoard embedding snapshot (PCA):
    • bash scripts/log_tb_embedding.sh, then tensorboard --logdir runs/hippo_tb
  • FAISS vs. torch retrieval benchmark:
    • python scripts/bench_retrieval.py --sizes 10000 100000 500000
  • MiniGrid memory baseline sweep + plot:
    • python scripts/minigrid_memory_benchmark.py --steps 8000 --seeds 3
  • Retrieval heatmap diagnostic:
    • python scripts/retrieval_heatmap.py --memory-checkpoint ... --encoder-checkpoint ...

See docs/benchmarks.md and docs/curriculum.md for details.



Download files


Source Distribution

hippotorch-0.4.0.tar.gz (47.5 kB)

Uploaded Source

Built Distribution


hippotorch-0.4.0-py3-none-any.whl (47.5 kB)

Uploaded Python 3

File details

Details for the file hippotorch-0.4.0.tar.gz.

File metadata

  • Download URL: hippotorch-0.4.0.tar.gz
  • Upload date:
  • Size: 47.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for hippotorch-0.4.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 459090a468f68341f419211bda2a34f163cf880e16fd57ce536d7b8d2fb1cc67 |
| MD5 | 349d91753924066ed1842e7d6e55f883 |
| BLAKE2b-256 | 424cde405f87c63d3f91ca8c9bb487caf68ffd71a74e6045344ac7e5bd97503e |
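To check a downloaded file against these digests, hash it locally and compare. A minimal sketch using only the standard library; `sha256_of` is a name introduced for illustration, and the file path assumes the archive sits in the current directory.

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 20):
    """Stream a file through SHA256 so large downloads are not read into memory at once."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


EXPECTED = "459090a468f68341f419211bda2a34f163cf880e16fd57ce536d7b8d2fb1cc67"  # hippotorch-0.4.0.tar.gz
archive = Path("hippotorch-0.4.0.tar.gz")
if archive.exists():
    print("OK" if sha256_of(archive) == EXPECTED else "MISMATCH")
```

pip can also enforce digests itself in `--require-hashes` mode, with requirement lines of the form `hippotorch==0.4.0 --hash=sha256:<digest>` listing the digest of each file it may download.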


Provenance

The following attestation bundles were made for hippotorch-0.4.0.tar.gz:

Publisher: workflow.yml on domezsolt/hippotorch

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file hippotorch-0.4.0-py3-none-any.whl.

File metadata

  • Download URL: hippotorch-0.4.0-py3-none-any.whl
  • Upload date:
  • Size: 47.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for hippotorch-0.4.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | ef25a6281659fd274208631299e26391e6ae9ee19a14e3875006673a4f5430dd |
| MD5 | c42e91e58454d9f9c9261b0248be232e |
| BLAKE2b-256 | d2d6df876d25abc81d4660d04354a3e9371f936d1aee7eb4cf4123e4cdbec59a |


Provenance

The following attestation bundles were made for hippotorch-0.4.0-py3-none-any.whl:

Publisher: workflow.yml on domezsolt/hippotorch

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
