Differentiable episodic memory for reinforcement learning.
hippotorch
Differentiable episodic memory for RL. Retrieves what matters. Forgets what doesn't.
hippotorch is a PyTorch library that replaces standard replay buffers with a learnable memory system. It uses reward-aware contrastive learning to organize experiences and hybrid sampling to retrieve them, targeting the temporal credit assignment problem in sparse-reward, long-horizon tasks.
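To make "hybrid sampling" concrete, here is a minimal sketch in plain PyTorch. It assumes that `mixture_ratio` (used in the Quickstart) is the fraction of each batch drawn by nearest-neighbor retrieval around a query embedding, with the rest drawn uniformly. The function `hybrid_sample_indices` and its arguments are hypothetical illustrations, not hippotorch's API.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def hybrid_sample_indices(
    memory_keys: torch.Tensor,
    query: torch.Tensor,
    batch_size: int,
    mixture_ratio: float = 0.3,
    top_k: int = 5,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Hypothetical hybrid sampler: part nearest-neighbor retrieval, part uniform.

    memory_keys: (N, D) stored embeddings; query: (D,) query embedding.
    """
    n = memory_keys.shape[0]
    n_semantic = int(round(mixture_ratio * batch_size))
    n_uniform = batch_size - n_semantic

    # Semantic part: cosine similarity to the query, keep the top_k neighbors,
    # then draw n_semantic indices (with replacement) from that neighborhood.
    sims = F.normalize(memory_keys, dim=-1) @ F.normalize(query, dim=-1)
    k = min(top_k, n)
    neighbors = sims.topk(k).indices
    semantic_idx = neighbors[torch.randint(0, k, (n_semantic,), generator=generator)]

    # Uniform part: ordinary replay-buffer sampling over all N entries.
    uniform_idx = torch.randint(0, n, (n_uniform,), generator=generator)
    return torch.cat([semantic_idx, uniform_idx])


# Usage: 100 random 16-d keys, querying with the embedding of entry 7.
g = torch.Generator().manual_seed(0)
keys = torch.randn(100, 16, generator=g)
idx = hybrid_sample_indices(keys, keys[7], batch_size=64, mixture_ratio=0.3, top_k=5, generator=g)
print(idx[:19])  # the 19 semantic draws all come from entry 7's 5 nearest neighbors
```

The uniform share keeps coverage of the whole buffer, while the semantic share concentrates updates on experiences similar to the current query.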
Install: pip install hippotorch
Key Hyperparameters
| Parameter | Default | Guidance |
|---|---|---|
| mixture_ratio | 0.5 | Start low (0.2), ramp to 0.5 after warmup |
| momentum | 0.995 | Higher = more stable keys, slower adaptation |
| temperature | 0.07 | Lower = sharper retrieval (try 0.05 for sparse rewards) |
| reward_weight | 0.5 | Higher = clusters by reward; lower = clusters by time |
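The guidance above implies two mechanics: a warmup schedule for `mixture_ratio` and a temperature that controls how peaked retrieval is. The sketch below assumes that "ramp to 0.5 after warmup" means holding 0.2 during warmup and then increasing linearly, and that retrieval scores pass through a temperature-scaled softmax as in standard contrastive setups. `mixture_ratio_schedule`, `retrieval_weights`, `warmup_steps`, and `ramp_steps` are hypothetical names.

```python
import torch


def mixture_ratio_schedule(step: int, warmup_steps: int, ramp_steps: int,
                           start: float = 0.2, end: float = 0.5) -> float:
    """Hypothetical schedule: hold `start` during warmup, then ramp linearly to `end`."""
    if step < warmup_steps:
        return start
    progress = min(1.0, (step - warmup_steps) / max(1, ramp_steps))
    return start + progress * (end - start)


def retrieval_weights(similarities: torch.Tensor, temperature: float) -> torch.Tensor:
    """Softmax over similarity scores; a lower temperature gives a sharper distribution."""
    return torch.softmax(similarities / temperature, dim=-1)


# Usage: the same similarity scores at the default and the "sparse reward" temperature.
sims = torch.tensor([0.9, 0.8, 0.1])
w_default = retrieval_weights(sims, temperature=0.07)
w_sharp = retrieval_weights(sims, temperature=0.05)
print(w_default, w_sharp)  # the top entry gets more weight at 0.05
```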
When to Use (and When Not To)
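Hippotorch adds overhead compared with a standard replay buffer (encoding, similarity-based retrieval, and periodic consolidation). Use it where episodic structure matters.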
| Scenario | Benefit | Recommendation |
|---|---|---|
| Sparse rewards (Montezuma, long corridors) | ✅ High | Use hippotorch—retrieval surfaces rare successes |
| Partial observability (POMDPs, visual RL) | ✅ High | Use hippotorch—pattern completion reconstructs context |
| Long-horizon tasks (100+ steps to goal) | ✅ High | Use hippotorch—bridges temporal credit gap |
| Curriculum / transfer learning | ✅ High | Use hippotorch—retains skills across task stages |
| Dense rewards, full observability | ⚠️ Low | Use standard replay—uniform sampling is sufficient |
| Short episodes (<20 steps) | ⚠️ Low | Use standard replay—no retrieval advantage |
Rule of Thumb
If your agent "forgets" how to solve early tasks, or struggles to connect actions to delayed rewards, hippotorch can help. If training already converges well with a standard buffer, you don't need it.
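One way to apply the table above before committing to hippotorch is to measure reward density and episode length from a few logged rollouts. This is a minimal sketch: `episode_stats`, `suggest_buffer`, and the 5% `sparse_threshold` are hypothetical choices, the 20-step and 100-step cut-offs come from the table, and the heuristic ignores partial observability and curriculum settings.

```python
from typing import Dict, Sequence


def episode_stats(episode_rewards: Sequence[Sequence[float]]) -> Dict[str, float]:
    """Summarize mean episode length and reward density from logged per-step rewards."""
    lengths = [len(ep) for ep in episode_rewards]
    total_steps = sum(lengths)
    nonzero = sum(1 for ep in episode_rewards for r in ep if r != 0.0)
    return {
        "mean_length": total_steps / len(lengths),
        "reward_density": nonzero / total_steps,  # fraction of steps with a nonzero reward
    }


def suggest_buffer(stats: Dict[str, float], sparse_threshold: float = 0.05) -> str:
    """Heuristic: short episodes -> standard replay; sparse or long-horizon -> hippotorch."""
    if stats["mean_length"] < 20:
        return "standard replay"
    if stats["reward_density"] < sparse_threshold or stats["mean_length"] >= 100:
        return "hippotorch"
    return "standard replay"


# Usage: a sparse long corridor versus short, dense episodes.
sparse = [[0.0] * 149 + [1.0] for _ in range(4)]
dense = [[1.0] * 10 for _ in range(4)]
print(suggest_buffer(episode_stats(sparse)), suggest_buffer(episode_stats(dense)))
```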
Installation
pip install hippotorch
# Optional: Fast retrieval for large memories (1M+ episodes)
pip install faiss-gpu
Requirements: Python ≥3.9, PyTorch ≥2.0
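Both faiss-gpu and faiss-cpu install a module named `faiss`. Whether hippotorch switches to it automatically is not specified here, so the sketch below only illustrates the kind of operation faiss accelerates: exact top-k search by inner product, using `faiss.IndexFlatIP` when faiss is importable and plain PyTorch otherwise. `top_k_inner_product` is a hypothetical helper. `IndexFlatIP` is an exact index; faiss also offers approximate indexes for larger memories.

```python
import importlib.util

import torch


def top_k_inner_product(keys: torch.Tensor, queries: torch.Tensor, k: int) -> torch.Tensor:
    """Exact top-k indices by inner product; uses faiss if installed, else plain torch."""
    if importlib.util.find_spec("faiss") is not None:
        import faiss  # provided by the faiss-gpu or faiss-cpu package
        import numpy as np  # faiss works on float32 NumPy arrays

        index = faiss.IndexFlatIP(keys.shape[1])
        index.add(np.ascontiguousarray(keys.detach().cpu().numpy(), dtype=np.float32))
        _, idx = index.search(np.ascontiguousarray(queries.detach().cpu().numpy(), dtype=np.float32), k)
        return torch.from_numpy(idx)
    # Fallback: dense similarity matrix followed by torch.topk.
    return (queries @ keys.T).topk(k, dim=-1).indices


# Usage: four one-hot keys; the query matches key 2.
keys = torch.eye(4)
queries = torch.tensor([[0.0, 0.0, 1.0, 0.0]])
print(top_k_inner_product(keys, queries, k=1))
```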
Quickstart
Create a dual encoder and a memory store, attach a consolidator, and sample from the hybrid replay buffer.
import torch
from hippotorch import Episode, DualEncoder, MemoryStore, Consolidator, HippocampalReplayBuffer
state_dim, action_dim = 4, 1
input_dim = state_dim + action_dim + 1  # +1 for reward; query vectors below are [state, action, reward]
# 1) Encoder and memory
encoder = DualEncoder(input_dim=input_dim, embed_dim=128, momentum=0.995)
memory = MemoryStore(embed_dim=128, capacity=50_000)
# 2) Reward-aware consolidator (sleep-phase optimizer)
consolidator = Consolidator(encoder, temperature=0.07, reward_weight=0.5)
# 3) Hybrid replay buffer (semantic + uniform)
buffer = HippocampalReplayBuffer(
    memory=memory,
    encoder=encoder,
    mixture_ratio=0.3,
    consolidator=consolidator,
)
# 4) Record an episode (toy tensors)
T = 32
states = torch.randn(T, state_dim)
actions = torch.randn(T, action_dim)
rewards = torch.randn(T)
dones = torch.zeros(T, dtype=torch.bool)
dones[-1] = True  # mark the final step as terminal
episode = Episode(states=states, actions=actions, rewards=rewards, dones=dones)
buffer.add_episode(episode)
# 5) Sample with semantic + uniform mixing; the query must have length input_dim
query_state = torch.cat([states[0], torch.zeros(action_dim), rewards[0].unsqueeze(0)])
batch = buffer.sample(batch_size=64, query_state=query_state, top_k=5)
# agent.update(batch)  # replace with your own agent's update step (not part of hippotorch)
# 6) Periodic consolidation ("sleep")
metrics = buffer.consolidate(steps=50, batch_size=64, report_quality=True)
print(metrics)
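The `momentum`, `temperature`, and `reward_weight` settings resemble a MoCo-style setup: a momentum (EMA) key encoder plus a temperature-scaled InfoNCE loss. The sketch below shows one way a reward-aware variant could work during a consolidation step, assuming reward similarity is folded into soft contrastive targets so that a higher `reward_weight` pulls embeddings with similar rewards together. `reward_aware_info_nce`, `momentum_update`, and the exp(-|Δr|) reward kernel are hypothetical; hippotorch's Consolidator may compute its loss differently.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def momentum_update(online: torch.nn.Module, target: torch.nn.Module, momentum: float = 0.995) -> None:
    """EMA update of the key (target) encoder: target = m * target + (1 - m) * online."""
    for p_online, p_target in zip(online.parameters(), target.parameters()):
        p_target.mul_(momentum).add_(p_online, alpha=1.0 - momentum)


def reward_aware_info_nce(
    queries: torch.Tensor,  # (B, D) online-encoder embeddings
    keys: torch.Tensor,     # (B, D) momentum-encoder embeddings of the same steps
    rewards: torch.Tensor,  # (B,) reward attached to each step
    temperature: float = 0.07,
    reward_weight: float = 0.5,
) -> torch.Tensor:
    """InfoNCE with soft targets: each row mixes its own key (identity positive)
    with keys of similar reward. Hypothetical weighting, not hippotorch's formula."""
    q = F.normalize(queries, dim=-1)
    k = F.normalize(keys, dim=-1)
    logits = q @ k.T / temperature                  # (B, B) scaled cosine similarities
    identity = torch.eye(q.shape[0], device=q.device)
    r = rewards.view(-1, 1)
    reward_sim = torch.exp(-(r - r.T).abs())        # 1 for equal rewards, toward 0 as they diverge
    reward_targets = reward_sim / reward_sim.sum(dim=1, keepdim=True)
    targets = (1.0 - reward_weight) * identity + reward_weight * reward_targets  # rows sum to 1
    return -(targets * F.log_softmax(logits, dim=1)).sum(dim=1).mean()


# Usage: one toy "sleep" step with linear encoders.
torch.manual_seed(0)
online = torch.nn.Linear(6, 16)
target = torch.nn.Linear(6, 16)
target.load_state_dict(online.state_dict())
optimizer = torch.optim.SGD(online.parameters(), lr=0.1)
x = torch.randn(8, 6)
rewards = torch.randn(8)
loss = reward_aware_info_nce(online(x), target(x).detach(), rewards)
optimizer.zero_grad()
loss.backward()
optimizer.step()
momentum_update(online, target, momentum=0.995)
```

Only the online encoder receives gradients; the key encoder trails it slowly, which is why a higher `momentum` gives more stable keys but slower adaptation.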
Stable-Baselines3 (SB3) users can keep their existing rollout code unchanged by wrapping the buffer in the adapter:
from hippotorch import SB3ReplayBufferWrapper, TerminalSegmenter
sb3_buffer = SB3ReplayBufferWrapper(buffer, segmenter=TerminalSegmenter())
# sb3_buffer.add(obs, next_obs, action, reward, done)
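SB3 adds single transitions, while hippotorch stores whole episodes, so the adapter needs a segmenter. Assuming `TerminalSegmenter` splits the stream at `done` flags, as its name suggests, the idea looks roughly like the sketch below. `SimpleTerminalSegmenter` is a hypothetical stand-in that handles a single environment and does not distinguish truncation from termination.

```python
from typing import Any, List, Tuple

Transition = Tuple[Any, Any, Any, float, bool]  # (obs, next_obs, action, reward, done)


class SimpleTerminalSegmenter:
    """Hypothetical stand-in: groups a transition stream into episodes at done=True."""

    def __init__(self) -> None:
        self._current: List[Transition] = []

    def push(self, transition: Transition) -> List[List[Transition]]:
        """Add one transition; return the episodes it completes (zero or one)."""
        self._current.append(transition)
        if transition[4]:  # done flag ends the current episode
            episode, self._current = self._current, []
            return [episode]
        return []


# Usage: seven transitions with terminals at t=2 and t=6 -> episodes of length 3 and 4.
seg = SimpleTerminalSegmenter()
finished: List[List[Transition]] = []
for t in range(7):
    finished += seg.push((t, t + 1, 0, 0.0, t in (2, 6)))
print([len(ep) for ep in finished])
```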
Portable Brains (Hub)
Export a trained memory so another agent can load it instantly:
from hippotorch import (
DualEncoder,
HippocampalReplayBuffer,
MemoryStore,
push_memory_to_hub,
load_memory_from_hub,
)
obs_dim = 42
encoder = DualEncoder(input_dim=obs_dim, embed_dim=128)
memory = MemoryStore(embed_dim=128, capacity=2048)
# Push to hub (requires real hub backend, e.g., huggingface_hub)
push_memory_to_hub(memory, repo_id="user/fetch-reach-expert", private=False)
# Later, load
restored = load_memory_from_hub("user/fetch-reach-expert")
# Or operate via the buffer convenience wrappers
buffer = HippocampalReplayBuffer(memory=memory, encoder=encoder)
buffer.save_to_hub("user/fetch-reach-expert")
restored_memory = buffer.load_memory_from_hub("user/fetch-reach-expert")
Note
- The hub utilities in this repo are minimal stubs for testing. To use the Hub in production, integrate a real backend (e.g., huggingface_hub) by adapting hippotorch.utils.hub.
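As a starting point for such an adapter, the sketch below uses a local directory as the "hub", assuming the memory can be represented as a dictionary of tensors (the actual MemoryStore layout is not documented here). `save_memory_local` and `load_memory_local` are hypothetical; a production backend could upload and fetch the same file with huggingface_hub (for example `HfApi.upload_file` and `hf_hub_download`).

```python
import os
import tempfile
from typing import Dict

import torch


def save_memory_local(tensors: Dict[str, torch.Tensor], directory: str) -> str:
    """Hypothetical local 'hub': write memory tensors to <directory>/memory.pt."""
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, "memory.pt")
    torch.save(tensors, path)
    return path


def load_memory_local(directory: str) -> Dict[str, torch.Tensor]:
    """Load tensors written by save_memory_local (tensors only, no arbitrary objects)."""
    return torch.load(os.path.join(directory, "memory.pt"), weights_only=True)


# Usage: round-trip toy keys and rewards through a temporary directory.
keys = torch.randn(8, 128)
rewards = torch.randn(8)
with tempfile.TemporaryDirectory() as tmp:
    save_memory_local({"keys": keys, "rewards": rewards}, tmp)
    restored = load_memory_local(tmp)
print(restored["keys"].shape)
```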
Quick Experiment Scripts
Convenience scripts in scripts/ run short, repeatable experiments:
- Rank-weighted consolidation ablation:
  bash scripts/run_rank_ablation.sh
- Consolidation micro-bench (synthetic):
  bash scripts/run_consolidation_micro.sh
- CartPole parity (short run with logging):
  bash scripts/quick_cartpole.sh
- Zero-noise corridor (Amnesiac):
  bash scripts/corridor_multiseed_zn.sh
  - Faster:
    SEEDS=2 EPISODES=150 CONS_EVERY=10 CONS_STEPS=50 bash scripts/corridor_multiseed_zn.sh
- Curriculum corridor (progressively increases length):
  bash scripts/corridor_curriculum.sh
- TensorBoard embedding snapshot (PCA):
  bash scripts/log_tb_embedding.sh
  then: tensorboard --logdir runs/hippo_tb
See docs/benchmarks.md and docs/curriculum.md for details.
Download files
File details
Details for the file hippotorch-0.3.1.tar.gz.
File metadata
- Download URL: hippotorch-0.3.1.tar.gz
- Size: 41.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | bfd50e77a8e3a3b8a76267a60b063ba49b2a936c6106ca67c4794ced348fe952 |
| MD5 | f871e6c9e91857058f6508e1f4c1825e |
| BLAKE2b-256 | 474f16a0642cfc634f966341ff1f4526959962e25818165c60adfa33369839c9 |
Provenance
The following attestation bundles were made for hippotorch-0.3.1.tar.gz:
- Publisher: workflow.yml on domezsolt/hippotorch
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: hippotorch-0.3.1.tar.gz
- Subject digest: bfd50e77a8e3a3b8a76267a60b063ba49b2a936c6106ca67c4794ced348fe952
- Sigstore transparency entry: 813999138
- Permalink: domezsolt/hippotorch@23f38356035dc91e4b4acb8cf5152bc9988a2846
- Branch / Tag: refs/tags/v0.3.1
- Owner: https://github.com/domezsolt
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: workflow.yml@23f38356035dc91e4b4acb8cf5152bc9988a2846
- Trigger Event: release
File details
Details for the file hippotorch-0.3.1-py3-none-any.whl.
File metadata
- Download URL: hippotorch-0.3.1-py3-none-any.whl
- Size: 42.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ed9f0256ff2f9cc7a31f3c805ff8abb050e14b044c47b0c1c95b58c13822b059 |
| MD5 | 34f992b9ed31235f0f92e2b96b4ed43e |
| BLAKE2b-256 | e38a10d3bc7df6efa3ea77efeb2b8495289bc89438423da2966500d519aa90a0 |
Provenance
The following attestation bundles were made for hippotorch-0.3.1-py3-none-any.whl:
- Publisher: workflow.yml on domezsolt/hippotorch
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: hippotorch-0.3.1-py3-none-any.whl
- Subject digest: ed9f0256ff2f9cc7a31f3c805ff8abb050e14b044c47b0c1c95b58c13822b059
- Sigstore transparency entry: 813999141
- Permalink: domezsolt/hippotorch@23f38356035dc91e4b4acb8cf5152bc9988a2846
- Branch / Tag: refs/tags/v0.3.1
- Owner: https://github.com/domezsolt
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: workflow.yml@23f38356035dc91e4b4acb8cf5152bc9988a2846
- Trigger Event: release