Simple Replay Buffer for RL

Project description

memmap-replay-buffer

An easy-to-use replay buffer backed by NumPy memory-mapped arrays (np.memmap), for RL and other sequence-based learning tasks.
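A memory-mapped array (`np.memmap`) keeps its contents in a file on disk and pages them in on access, so a buffer can be much larger than RAM and survive a restart. The sketch below illustrates that general idea under stated assumptions: one preallocated file per field shaped `(max_episodes, max_timesteps, *shape)`, a separate array of episode lengths, and a reopen step standing in for rehydration. `open_storage` and the file names are hypothetical; this is not memmap-replay-buffer's implementation, though preallocation of this kind is plausibly why such a buffer asks for `max_episodes` and `max_timesteps` up front.

```python
import os
import tempfile

import numpy as np

# Hypothetical helper, not part of memmap-replay-buffer: one on-disk array
# per field, preallocated as (max_episodes, max_timesteps, *shape), plus a
# small array holding each episode's length.
def open_storage(folder, max_episodes, max_timesteps, shape, mode):
    data = np.memmap(
        os.path.join(folder, 'state.dat'),
        dtype = np.float32,
        mode = mode,
        shape = (max_episodes, max_timesteps, *shape)
    )
    lens = np.memmap(
        os.path.join(folder, 'lens.dat'),
        dtype = np.int64,
        mode = mode,
        shape = (max_episodes,)
    )
    return data, lens

folder = tempfile.mkdtemp()

# mode 'w+' creates zero-filled files of the full preallocated size
data, lens = open_storage(folder, max_episodes = 10, max_timesteps = 50, shape = (3,), mode = 'w+')

# write 2 episodes of 20 timesteps, one timestep at a time
rng = np.random.default_rng(0)
for episode in range(2):
    for t in range(20):
        data[episode, t] = rng.standard_normal(3)
    lens[episode] = 20

# push pending writes to disk, then drop the maps
data.flush()
lens.flush()
del data, lens

# "rehydrate": reopen the same files read-only with the same shapes and dtypes
data_ro, lens_ro = open_storage(folder, 10, 50, (3,), mode = 'r')
num_episodes = int((lens_ro > 0).sum())
first_episode = np.asarray(data_ro[0, :lens_ro[0]])   # only the real timesteps
```

A raw `np.memmap` file stores only the array bytes, so the reader must already know the shape and dtype. This sketch simply passes them again; a `from_folder`-style loader would have to recover them from somewhere, such as metadata saved next to the data.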

Install

$ pip install memmap-replay-buffer

Usage

import torch
from memmap_replay_buffer import ReplayBuffer

# initialize buffer

buffer = ReplayBuffer(
    './replay_data',                           # folder that holds the stored data
    max_episodes = 1000,
    max_timesteps = 500,
    fields = dict(
        state = ('float', (3, 16, 16), 0.),    # type, shape, and optional default value
        action = ('int', 2),
        reward = 'float'                       # default shape is ()
    ),
    meta_fields = dict(                        # one value per episode, not per timestep
        task_id = 'int'
    ),
    circular = True,
    overwrite = True
)

# store 4 episodes of 100 timesteps each, all tagged with task_id = 1

for _ in range(4):
    with buffer.one_episode(task_id = 1):
        for _ in range(100):
            buffer.store(
                state = torch.randn(3, 16, 16),
                action = torch.randint(0, 4, (2,)).numpy(),
                reward = 1.0
            )

# rehydrate from disk

buffer_rehydrated = ReplayBuffer.from_folder('./replay_data')
assert buffer_rehydrated.num_episodes == 4

# train 2 episodes at a time

dataloader = buffer.dataloader(
    batch_size = 2,
    return_mask = True,
    to_named_tuple = ('state', 'action', 'reward', 'task_id', '_mask', '_lens')
)

for state, action, reward, task_id, mask, lens in dataloader:
    # time axis is 100 (the length of each stored episode), not max_timesteps = 500
    assert state.shape   == (2, 100, 3, 16, 16)
    assert action.shape  == (2, 100, 2)
    assert reward.shape  == (2, 100)
    assert task_id.shape == (2,)

    assert lens.shape    == (2,)
    assert mask.shape    == (2, 100)

# load individual timesteps instead of whole episodes

dataloader = buffer.dataloader(
    batch_size = 8,
    filter_meta = dict(
        task_id = 1                            # keep only episodes stored with task_id = 1
    ),
    to_named_tuple = ('state', 'action', 'task_id'),
    timestep_level = True,
    drop_last = True                           # drop a final batch smaller than batch_size
)

for state, action, task_id in dataloader:
    assert state.shape   == (8, 3, 16, 16)
    assert action.shape  == (8, 2)
    assert task_id.shape == (8,)
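The episode-level dataloader returns `_mask` and `_lens` alongside the fields, which suggests the usual scheme for batching variable-length episodes: pad each episode along time to the longest one in the batch and keep a boolean mask of which timesteps are real. In the example above every episode has exactly 100 steps, so the mask is all `True`. The NumPy sketch below shows the scheme with unequal lengths; `pad_episodes` is a hypothetical helper, and right-padding with zeros is an assumption, not documented library behavior.

```python
import numpy as np

# Hypothetical helper: right-pad variable-length episodes with zeros.
def pad_episodes(episodes):
    # episodes: list of arrays shaped (T_i, *feature_shape)
    lens = np.array([len(e) for e in episodes])
    max_len = int(lens.max())
    feature_shape = episodes[0].shape[1:]

    batch = np.zeros((len(episodes), max_len, *feature_shape), dtype = episodes[0].dtype)
    for i, episode in enumerate(episodes):
        batch[i, :len(episode)] = episode

    # mask[i, t] is True for real timesteps, False for padding
    mask = np.arange(max_len)[None, :] < lens[:, None]
    return batch, mask, lens

rng = np.random.default_rng(0)
episodes = [rng.standard_normal((100, 2)), rng.standard_normal((60, 2))]

actions, mask, lens = pad_episodes(episodes)   # (2, 100, 2), (2, 100), (2,)

# average a per-timestep quantity over real timesteps only
per_step = np.abs(actions).sum(axis = -1)      # (2, 100), zero on padding
masked_mean = (per_step * mask).sum() / mask.sum()

# timestep-level view: drop padding and merge the batch and time axes
flat_actions = actions[mask]                   # (160, 2)
```

Dividing by `mask.sum()` rather than the padded size keeps padding from diluting averages such as losses. Indexing with the mask recovers only real timesteps, which is one way to picture `timestep_level = True`, although the library's own sampling may work differently.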

Download files

Source Distribution

memmap_replay_buffer-0.0.23.tar.gz (17.2 kB)

Built Distribution

memmap_replay_buffer-0.0.23-py3-none-any.whl (14.9 kB)

File details

Details for the file memmap_replay_buffer-0.0.23.tar.gz.

File metadata

  • Download URL: memmap_replay_buffer-0.0.23.tar.gz
  • Upload date:
  • Size: 17.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.25

File hashes

Hashes for memmap_replay_buffer-0.0.23.tar.gz
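| Algorithm   | Hash digest                                                      |
| ----------- | ---------------------------------------------------------------- |
| SHA256      | 3bc082e0bec2ebc3f900559279603ca23e09065d0bed31571d38bc9a189e2cf9 |
| MD5         | bf14f491c7a373f1ddfb607d10367a46                                 |
| BLAKE2b-256 | 9d46f6a67df71a726324c7d4fdaa73099f4b273b24354ce4aadbeb209f77ed47 |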
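A published digest is only useful if you compare it with the file you actually downloaded. A minimal sketch using Python's standard `hashlib`; `sha256_of` and `matches_sdist` are illustrative helpers, and the path is wherever you saved the archive (for example with `pip download`).

```python
import hashlib

# SHA256 digest published for memmap_replay_buffer-0.0.23.tar.gz
EXPECTED_SDIST_SHA256 = '3bc082e0bec2ebc3f900559279603ca23e09065d0bed31571d38bc9a189e2cf9'

def sha256_of(path, chunk_size = 1 << 20):
    # hash the file in 1 MiB chunks instead of reading it into memory at once
    digest = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()

def matches_sdist(path):
    # True if the local file's SHA256 equals the published digest
    return sha256_of(path) == EXPECTED_SDIST_SHA256
```

For example, `matches_sdist('memmap_replay_buffer-0.0.23.tar.gz')` should return `True` for an untampered download. pip can also enforce this automatically in hash-checking mode (`--require-hashes`, with `--hash=sha256:...` entries in a requirements file).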

File details

Details for the file memmap_replay_buffer-0.0.23-py3-none-any.whl.

File hashes

Hashes for memmap_replay_buffer-0.0.23-py3-none-any.whl
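| Algorithm   | Hash digest                                                      |
| ----------- | ---------------------------------------------------------------- |
| SHA256      | dad33b5510fed6c06c7c5d111d0a0e226a78c4d068e6bd8d6308273f64b67d27 |
| MD5         | 0b6f230ef0b9b1cf4512366a8a3c45ca                                 |
| BLAKE2b-256 | 998f57370fae45d54f8cbf3cc7c0569a4f15e6c12b399465d25be71391d7923a |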
