
Differentiable episodic memory for reinforcement learning.


hippotorch


Tested on: Ubuntu 22.04, 24.04

Differentiable episodic memory for reinforcement learning. Retrieves what matters. Forgets what doesn't.

Changelog

Hippotorch is a drop-in upgrade for replay buffers. It keeps experiences in a learnable memory so agents can remember rare successes, connect distant cause and effect, and transfer knowledge between similar worlds. Under the hood it uses reward-aware contrastive learning, but you mostly interact with a friendly API.
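This page doesn't spell out Hippotorch's exact training objective. As a rough illustration of what "reward-aware contrastive learning" can mean, the sketch below computes a standard InfoNCE loss over cosine similarities. Each anchor's own partner is its positive, and the other partners in the batch serve as negatives. Each anchor's term is then scaled by a softmax over episode returns. The return weighting and the function names are assumptions for illustration, not the library's implementation.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv + 1e-12)

def reward_weighted_info_nce(anchors, positives, returns, temperature=0.1):
    """InfoNCE where positives[i] matches anchors[i] and every other positive in
    the batch acts as a negative. Each anchor's term is scaled by a softmax over
    episode returns (a hypothetical 'reward-aware' weighting)."""
    n = len(anchors)
    # Hypothetical weighting: higher-return episodes contribute more to the loss.
    max_r = max(returns)
    exp_r = [math.exp(r - max_r) for r in returns]
    z = sum(exp_r)
    weights = [n * e / z for e in exp_r]  # weights average to 1
    total = 0.0
    for i in range(n):
        logits = [cosine(anchors[i], positives[j]) / temperature for j in range(n)]
        m = max(logits)
        log_denom = m + math.log(sum(math.exp(l - m) for l in logits))
        total += weights[i] * (log_denom - logits[i])  # -log softmax of the true pair
    return total / n
```

With equal returns this reduces to plain InfoNCE. Raising one episode's return makes its term dominate the loss, which is one way to pull useful episodes into tighter clusters.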


Highlights

  • Memory that adapts with you. Dual encoders organize episodes by usefulness instead of mere recency.
  • Semantic + uniform sampling. A single buffer can surface hard-to-find wins while still covering the full state space.
  • Production-friendly extras. Hugging Face Hub export, FAISS retrieval, Gymnasium wrappers, and health reports ship in the box.
  • Batteries included. Dozens of scripts and docs show exactly how to benchmark, visualize, and share results.

If you already converge with a plain replay buffer, keep it. Hippotorch shines when agents forget early lessons, face sparse rewards, or operate in partially observed environments.


Installation

pip install hippotorch              # minimal setup
pip install "hippotorch[faiss]"     # fast nearest-neighbor retrieval
pip install "hippotorch[envs]"      # Gymnasium helpers + examples
pip install "hippotorch[hub]"       # Hugging Face Hub + safetensors
pip install "hippotorch[umap]"      # projector UMAP export

Requirements: Python ≥3.9, PyTorch ≥2.0 (CI enforces ≥80% test coverage)

Dev install: pip install -e ".[dev]"
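To see which optional extras are usable in the current environment, importlib.util.find_spec can probe for each module without importing it. The extra-to-module mapping below (faiss, gymnasium, huggingface_hub, safetensors, umap) is an assumption based on the descriptions above. It is not read from hippotorch's packaging metadata.

```python
import importlib.util

# Assumed mapping from pip extras to the top-level modules they provide.
EXTRA_MODULES = {
    "faiss": ["faiss"],
    "envs": ["gymnasium"],
    "hub": ["huggingface_hub", "safetensors"],
    "umap": ["umap"],
}

def available_extras(mapping=EXTRA_MODULES):
    """Return {extra: bool}, True when every module for that extra can be found."""
    return {
        extra: all(importlib.util.find_spec(mod) is not None for mod in modules)
        for extra, modules in mapping.items()
    }
```

Calling available_extras() returns something like {"faiss": False, "envs": True, ...}, depending on what is installed.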


Quick Tour

Create an encoder + memory, add episodes, then mix semantic and uniform samples:

import torch
from hippotorch import Episode, DualEncoder, MemoryStore, HippocampalReplayBuffer

state_dim, action_dim = 4, 1
encoder = DualEncoder(input_dim=state_dim + action_dim + 1, embed_dim=128)
memory = MemoryStore(embed_dim=128, capacity=50_000)
buffer = HippocampalReplayBuffer(memory=memory, encoder=encoder, mixture_ratio=0.3)

states = torch.randn(32, state_dim)
actions = torch.randn(32, action_dim)
rewards = torch.randn(32)
buffer.add_episode(Episode(states=states, actions=actions, rewards=rewards))

# Query-aware sampling
query_state = torch.cat([states[0], torch.zeros(action_dim), rewards[:1]])
batch = buffer.sample(batch_size=64, query_state=query_state, top_k=5)

# Sleep/consolidate occasionally
metrics = buffer.consolidate(steps=50, batch_size=64, report_quality=True)
print(metrics["loss"])
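The mixture_ratio argument controls how a batch is split between semantic and uniform draws. The sketch below shows one plausible reading. It assumes mixture_ratio is the fraction of the batch drawn from the top_k entries most similar to the query, measured by cosine similarity over stored embeddings, with the remainder sampled uniformly. This is a conceptual illustration with hypothetical names, not HippocampalReplayBuffer's actual code.

```python
import math
import random

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)) + 1e-12)

def mixed_sample(embeddings, query, batch_size, mixture_ratio, top_k, rng=random):
    """Return batch indices: round(batch_size * mixture_ratio) drawn with replacement
    from the top_k entries most similar to `query`, the rest drawn uniformly."""
    n_semantic = round(batch_size * mixture_ratio)
    ranked = sorted(range(len(embeddings)),
                    key=lambda i: cosine(embeddings[i], query), reverse=True)
    nearest = ranked[:top_k]
    semantic = [rng.choice(nearest) for _ in range(n_semantic)]
    uniform = [rng.randrange(len(embeddings)) for _ in range(batch_size - n_semantic)]
    return semantic + uniform
```

The uniform share keeps covering the whole buffer, while the semantic share repeatedly surfaces the few entries closest to the query.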

Using Stable Baselines3 or Gymnasium? Wrap your existing replay buffer with SB3ReplayBufferWrapper, or use HippotorchMemoryWrapper. For semantic replay, the SB3 wrapper passes a query observation (by default, the most recent observation) and accepts a custom query-building hook. Note: the SB3 wrapper currently targets single-environment rollouts; vectorized environments (VecEnv) are not yet supported.
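Semantic replay needs a query laid out like the stored transitions. In the Quick Tour, the encoder input is state, then action, then reward (input_dim = state_dim + action_dim + 1). A query built from an observation therefore zero-fills the action slot and appends a reward. The function below is a hypothetical hook in that spirit. Its name and signature are assumptions, so check docs/usage.md, which covers the wrappers, for the interface the SB3 wrapper actually expects.

```python
from typing import List, Optional, Sequence

def build_query(observation: Sequence[float], action_dim: int,
                last_reward: Optional[float] = None) -> List[float]:
    """Hypothetical query hook: lay out [state | action | reward] to match the
    Quick Tour encoder input. The action slot is zero-filled because the next
    action is not known yet; a missing reward defaults to 0.0."""
    reward = 0.0 if last_reward is None else float(last_reward)
    return [float(x) for x in observation] + [0.0] * action_dim + [reward]
```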

Need hyperparameter guidance? Start with docs/hyperparameter_guide.md for recommended ranges, then see docs/diagnostics.md for health checks and docs/curriculum.md for training tips.


Everyday Tools

Recall While Acting

  • Use the lightweight read API: from hippotorch import query.
  • Pipe query(..., top_k=5) results into policies or logging code.
  • Gymnasium adapter emits dict observations so SB3 policies can consume retrieval features alongside pixels.
  • Examples: examples/query_inference_demo.py, examples/minigrid_memory_wrapper.py.
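One simple way to feed retrieval results to a policy is to summarize the top-k neighbors into a fixed-size vector. That vector can then be returned next to the raw observation in a dict, which is the kind of observation SB3's MultiInputPolicy consumes. The sketch assumes retrieval yields (embedding, score) pairs. The keys "observation" and "memory" and the similarity-weighted average are illustrative choices, not what hippotorch's adapter emits.

```python
from typing import Dict, List, Sequence, Tuple

def memory_features(neighbors: Sequence[Tuple[Sequence[float], float]],
                    embed_dim: int) -> List[float]:
    """Similarity-weighted mean of retrieved embeddings; zeros if nothing usable."""
    total = sum(max(score, 0.0) for _, score in neighbors)
    if total == 0.0:
        return [0.0] * embed_dim
    feats = [0.0] * embed_dim
    for emb, score in neighbors:
        w = max(score, 0.0) / total  # negative similarities are ignored
        for d in range(embed_dim):
            feats[d] += w * emb[d]
    return feats

def augment_observation(obs: Sequence[float],
                        neighbors: Sequence[Tuple[Sequence[float], float]],
                        embed_dim: int) -> Dict[str, list]:
    """Pack the raw observation and the retrieval summary into a dict observation."""
    return {"observation": list(obs), "memory": memory_features(neighbors, embed_dim)}
```

A fixed-size summary keeps the observation space constant even when fewer than top_k neighbors come back.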

Portable Brains

  • Share trained memories with push_memory_to_hub / load_memory_from_hub.
  • Choose local folders for offline passes or Hugging Face Hub for team-wide reuse.
  • scripts/hub_roundtrip_smoke.py is a 30-second sanity check.
  • Docs: docs/hub.md.

Glass-Box Diagnostics

  • buffer.health_report() returns retrievability, staleness, collapse indicators, and alignment scores.
  • Log with report.to_tensorboard(writer, step) or report.to_wandb(run).
  • See docs/diagnostics.md for visuals.
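The exact formulas behind health_report() aren't given here. As an illustration of what such indicators can measure, the sketch computes a collapse score as the mean pairwise cosine similarity of stored embeddings. Values near 1.0 suggest the embeddings have collapsed onto one direction. It also computes a staleness score as the mean age of entries in environment steps. Both definitions are assumptions for illustration.

```python
import itertools
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)) + 1e-12)

def collapse_score(embeddings):
    """Mean pairwise cosine similarity; close to 1.0 means embeddings are collapsing."""
    pairs = list(itertools.combinations(embeddings, 2))
    if not pairs:
        return 0.0
    return sum(cosine(u, v) for u, v in pairs) / len(pairs)

def staleness(insert_steps, current_step):
    """Mean age (in environment steps) of stored entries; 0.0 for an empty store."""
    if not insert_steps:
        return 0.0
    return sum(current_step - s for s in insert_steps) / len(insert_steps)
```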

Batch Retrieval for Low Latency

  • buffer.query_batch(query_vecs, top_k=K) handles [B,T,D] tensors in one go.
  • Returns the same results as running each query individually, without a Python-level loop.
  • Works with both torch and FAISS backends.
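A batched query can match single-query results because a [B,T,D] tensor is just B*T independent D-dimensional queries. You flatten them, score all of them against memory in one pass, take the top-k per row, and reshape back to [B,T,K]. The pure-Python sketch below makes that shape bookkeeping explicit and scores by dot product. hippotorch's query_batch presumably does the same with vectorized torch or FAISS operations. The function names here are hypothetical.

```python
def dot(u, v):
    """Dot-product similarity between two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def topk_single(query, memory, k):
    """Indices of the k memory rows with the highest score for one query."""
    scores = [dot(query, m) for m in memory]
    return sorted(range(len(memory)), key=lambda i: scores[i], reverse=True)[:k]

def topk_batch(queries, memory, k):
    """queries has shape [B][T][D]; returns indices with shape [B][T][k]."""
    B, T = len(queries), len(queries[0])
    flat = [q for row in queries for q in row]                # [B*T, D]
    scores = [[dot(q, m) for m in memory] for q in flat]      # [B*T, N] score matrix
    idx = [sorted(range(len(memory)), key=row.__getitem__, reverse=True)[:k]
           for row in scores]                                 # [B*T, k]
    return [idx[b * T:(b + 1) * T] for b in range(B)]         # reshape to [B, T, k]
```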

Multi-GPU Encoding

  • Set multi_gpu=True on DualEncoder/VisualEpisodeEncoder or Consolidator to enable torch.nn.DataParallel when multiple GPUs are present.
  • Snapshots handle the "module." prefix that DataParallel adds to parameter names, so save/load works across single- and multi-GPU runs.
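torch.nn.DataParallel keeps the wrapped network under a .module attribute, so its state_dict keys start with "module.". A loader that works across single- and multi-GPU runs therefore has to strip or add that prefix. The helpers below show the idea on a plain dict. They are a sketch, not hippotorch's snapshot code.

```python
PREFIX = "module."

def strip_module_prefix(state_dict):
    """Remove the 'module.' prefix that DataParallel adds to parameter names."""
    return {(k[len(PREFIX):] if k.startswith(PREFIX) else k): v
            for k, v in state_dict.items()}

def add_module_prefix(state_dict):
    """Add the 'module.' prefix so a single-GPU checkpoint loads into a DataParallel model."""
    return {(k if k.startswith(PREFIX) else PREFIX + k): v
            for k, v in state_dict.items()}
```

Both helpers are idempotent, so applying them to an already-converted checkpoint leaves it unchanged.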

Ready-to-Run Samples

Pick a script, set a seed, and you get a reproducible snapshot:

  • Benchmarks & diagnostics
    • Retrieval perf: python scripts/bench_retrieval.py --sizes 10000 100000
    • Visualization: python scripts/export_projector_embeddings.py --snapshot run.pt
    • Retrieval heatmap: python scripts/retrieval_heatmap.py --memory-checkpoint ...
  • Environments
    • CartPole smoke: bash scripts/quick_cartpole.sh
    • Corridor curriculum/oracle: bash scripts/corridor_curriculum.sh, bash scripts/corridor_oracle_zn.sh
    • MiniGrid sweeps: python scripts/minigrid_memory_benchmark.py --steps 8000 --seeds 3
    • FetchReach benchmark: bash scripts/fetchreach_benchmark.sh
    • HER comparison (FetchReach): bash scripts/her_comparison.sh
    • Intrinsic curiosity example: python -m examples.intrinsic_demo --episodes 20
  • Ablations & studies
    • Rank-weighted consolidation: bash scripts/run_rank_ablation.sh
    • Consolidation micro bench: bash scripts/run_consolidation_micro.sh
    • Visual MiniGrid clustering: python -m examples.minigrid_visual --steps 2000
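Reproducible runs depend on seeding every random number generator that a run touches. A minimal helper, assuming Python's random module plus NumPy and PyTorch when they are installed, could look like this. The scripts' own seed handling may differ.

```python
import os
import random

def seed_everything(seed: int) -> None:
    """Seed Python's RNG, plus NumPy and PyTorch if they are importable."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)  # only affects subprocesses started afterwards
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)  # seeds the RNG on all devices, including CUDA
    except ImportError:
        pass
```

Gymnasium environments are seeded separately, by passing seed=... to env.reset().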

All scripts keep runtime under a couple of minutes unless stated otherwise. Longer jobs (corridor oracle full run, curriculum sweeps) note their expected duration in the script header.


Learn More

  • docs/benchmarks.md – retrieval setups, FAISS parity, and profiling tips.
  • docs/curriculum.md – how to stage corridor tasks and measure regret.
  • docs/usage.md – wrappers, segmenters, and rollout recipes.
  • docs/hub.md – how to move memories between machines or teammates.
  • Getting started notebook: docs/tutorials/getting_started.ipynb
  • API Reference (MkDocs): build locally with make docs and open site/index.html (source: docs/api.md). Hosted docs: https://domezsolt.gitlab.io/hippotorch
  • Sparse Atari pilot (Montezuma’s Revenge): bash scripts/atari_pilot.sh or python -u scripts/atari_sparse_pilot.py --env ALE/MontezumaRevenge-v5 --steps 10000 (requires optional extras: pip install "gymnasium[atari]" autorom then run AutoROM --accept-license). See docs/atari_pilot.md.

Problems or ideas? Open an issue or send a Merge Request on GitLab.



Download files


Source Distribution

hippotorch-0.4.3.tar.gz (86.6 kB)

Uploaded Source

Built Distribution


hippotorch-0.4.3-py3-none-any.whl (64.8 kB)

Uploaded Python 3

File details

Details for the file hippotorch-0.4.3.tar.gz.

File metadata

  • Download URL: hippotorch-0.4.3.tar.gz
  • Upload date:
  • Size: 86.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.14

File hashes

Hashes for hippotorch-0.4.3.tar.gz
Algorithm Hash digest
SHA256 c82b810c6d307802cadf4537bddc60fee4709751d9a9c4dc07433123190032c0
MD5 b58ac597c1c51f5b97e6245cec9492c0
BLAKE2b-256 23862d18b15150b68581e402905467c6ff28517d7c0bd96d110ad338722da970
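To check a downloaded archive against the SHA256 digest listed above, hash the file in chunks with Python's hashlib and compare the result. The path passed to sha256_of is wherever you saved the download.

```python
import hashlib

def sha256_of(path, chunk_size=1 << 20):
    """Stream a file through SHA-256 so large files need not fit in memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# SHA256 digest of hippotorch-0.4.3.tar.gz, copied from the table above.
EXPECTED_SDIST = "c82b810c6d307802cadf4537bddc60fee4709751d9a9c4dc07433123190032c0"

# Usage: sha256_of("hippotorch-0.4.3.tar.gz") == EXPECTED_SDIST
```

pip can enforce the same check at install time with --require-hashes and a requirements line such as hippotorch==0.4.3 --hash=sha256:<digest>.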


File details

Details for the file hippotorch-0.4.3-py3-none-any.whl.

File metadata

  • Download URL: hippotorch-0.4.3-py3-none-any.whl
  • Upload date:
  • Size: 64.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.14

File hashes

Hashes for hippotorch-0.4.3-py3-none-any.whl
Algorithm Hash digest
SHA256 f131f8cc5188a3e8bd68b2bcfa321f91dee87d3f22d17c6028f06af4dc600a9e
MD5 a3f3b325165785adaadebe679fd1a9e9
BLAKE2b-256 aa23d00db6b81b6c52c57e58b5197ed1fe8c4474fb5f8e9c1d51b028f9eed72a

