Simple Replay Buffer for RL

memmap-replay-buffer

An easy-to-use numpy memmap replay buffer for RL and other sequence-based learning tasks.
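
For readers new to memory-mapped arrays, the following is a minimal sketch of the underlying idea using plain `numpy.memmap`. It is not this library's implementation; the file name `state.dat`, the sizes, and the `(episode, timestep, ...)` layout are assumptions chosen only for illustration.

```python
import tempfile
from pathlib import Path

import numpy as np

# Hypothetical sizes, chosen only for illustration
max_episodes, max_timesteps = 4, 10
state_shape = (3, 16, 16)
full_shape = (max_episodes, max_timesteps, *state_shape)

folder = Path(tempfile.mkdtemp())
path = folder / 'state.dat'  # hypothetical file name

# Create a disk-backed array; mode 'w+' creates (or overwrites) a zero-filled file
states = np.memmap(path, dtype = np.float32, mode = 'w+', shape = full_shape)

# Writing looks like ordinary array assignment
states[0, 0] = np.ones(state_shape, dtype = np.float32)
states.flush()  # push pending changes to disk
del states

# Reopen the same file read-only, e.g. later or from another process
reloaded = np.memmap(path, dtype = np.float32, mode = 'r', shape = full_shape)

first_step_sum = float(reloaded[0, 0].sum())  # the timestep written above
untouched_sum = float(reloaded[1].sum())      # an episode that was never written
```

Because the data lives in a file rather than in process memory, a buffer built this way can exceed available RAM and survive restarts, which is presumably what makes rehydrating from a folder possible.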

Install

$ pip install memmap-replay-buffer

Usage

import torch
from memmap_replay_buffer import ReplayBuffer

# initialize buffer

buffer = ReplayBuffer(
    './replay_data',
    max_episodes = 1000,
    max_timesteps = 500,
    fields = dict(
        state = ('float', (3, 16, 16), 0.),    # type, shape, and optional default value
        action = ('int', 2),
        reward = 'float'                       # default shape is ()
    ),
    meta_fields = dict(
        task_id = 'int'
    ),
    circular = True,
    overwrite = True
)

# store 4 episodes

for _ in range(4):
    with buffer.one_episode(task_id = 1):
        for _ in range(100):
            buffer.store(
                state = torch.randn(3, 16, 16),
                action = torch.randint(0, 4, (2,)).numpy(),
                reward = 1.0
            )

# rehydrate from disk

buffer_rehydrated = ReplayBuffer.from_folder('./replay_data')
assert buffer_rehydrated.num_episodes == 4

# train 2 episodes at a time

dataloader = buffer.dataloader(
    batch_size = 2,
    to_named_tuple = ('state', 'action', 'reward', 'task_id', '_lens')
)

for state, action, reward, task_id, lens in dataloader:
    assert state.shape   == (2, 100, 3, 16, 16)
    assert action.shape  == (2, 100, 2)
    assert reward.shape  == (2, 100)
    assert task_id.shape == (2,)
    assert lens.shape    == (2,)

# load individual timesteps instead of whole episodes

dataloader = buffer.dataloader(
    batch_size = 8,
    filter_meta = dict(
        task_id = 1
    ),
    to_named_tuple = ('state', 'action'),
    timestep_level = True,
    drop_last = True
)

for state, action in dataloader:
    assert state.shape == (8, 3, 16, 16)
    assert action.shape == (8, 2)
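
In the episode-level example above, every episode has exactly 100 timesteps, so `lens` is not strictly needed. When episode lengths differ, one common pattern is to turn `lens` into a boolean mask so that padded timesteps do not contribute to a loss or statistic. The sketch below uses only plain PyTorch; `lens_to_mask` and `masked_mean` are hypothetical helpers, not part of this library, and it assumes shorter episodes are padded up to the longest length in the batch.

```python
import torch

def lens_to_mask(lens, max_len):
    # Hypothetical helper: True where a timestep index falls inside the episode
    positions = torch.arange(max_len, device = lens.device)
    return positions[None, :] < lens[:, None]

def masked_mean(values, mask):
    # Average only over valid (unpadded) timesteps
    total = (values * mask).sum()
    return total / mask.sum().clamp(min = 1)

# Stand-in batch: two episodes, the second only 60 timesteps long
reward = torch.ones(2, 100)
reward[1, 60:] = 99.  # junk in the padded region, to show it is ignored
lens = torch.tensor([100, 60])

mask = lens_to_mask(lens, reward.shape[1])
mean_reward = masked_mean(reward, mask)
```

The same mask can be broadcast against per-timestep losses of shape `(batch, time)` before reducing.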
