Simple Replay Buffer for RL

memmap-replay-buffer

An easy-to-use replay buffer backed by NumPy memmap files, for RL and other sequence-based learning tasks.
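
For readers unfamiliar with memory-mapped arrays, the sketch below shows the plain NumPy mechanism the description refers to: an array backed by a file on disk, written in place and reopened later. It is an illustration only and not this package's implementation; the file name `reward.dat`, the sizes, and the `open_rewards` helper are hypothetical.

```python
import tempfile
from pathlib import Path

import numpy as np

MAX_EPISODES, MAX_TIMESTEPS = 4, 5  # hypothetical sizes for this sketch


def open_rewards(folder, mode):
    # Hypothetical helper: a float32 array of shape (episodes, timesteps)
    # whose contents live in a file rather than in process memory.
    return np.memmap(
        Path(folder) / 'reward.dat',
        dtype=np.float32,
        mode=mode,
        shape=(MAX_EPISODES, MAX_TIMESTEPS),
    )


with tempfile.TemporaryDirectory() as folder:
    rewards = open_rewards(folder, 'w+')    # create the file, zero-filled
    rewards[0, :3] = [1.0, 0.5, 0.25]       # write three timesteps of episode 0
    rewards.flush()                         # push the changes to disk
    del rewards                             # drop the mapping

    reloaded = open_rewards(folder, 'r')    # reopen the same file read-only
    first_episode = np.array(reloaded[0])   # copy one episode out of the mapping
    del reloaded                            # release the file before cleanup
```

Because the data lives in a file, it can exceed available RAM and survives a process restart, which is the property that makes reloading a buffer from its folder possible.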

Install

$ pip install memmap-replay-buffer

Usage

import torch
from memmap_replay_buffer import ReplayBuffer

# initialize buffer

buffer = ReplayBuffer(
    './replay_data',
    max_episodes = 1000,
    max_timesteps = 500,
    fields = dict(
        state = ('float', (3, 16, 16), 0.),    # type, shape, and optional default value
        action = ('int', 2),
        reward = 'float'                       # default shape is ()
    ),
    meta_fields = dict(
        task_id = 'int'
    ),
    circular = True,
    overwrite = True
)

# store 4 episodes of 100 timesteps each, all tagged with task_id = 1

for _ in range(4):
    with buffer.one_episode(task_id = 1):
        for _ in range(100):
            buffer.store(
                state = torch.randn(3, 16, 16),
                action = torch.randint(0, 4, (2,)).numpy(),
                reward = 1.0
            )

# rehydrate from disk

buffer_rehydrated = ReplayBuffer.from_folder('./replay_data')
assert buffer_rehydrated.num_episodes == 4

# train on 2 episodes at a time; each field is batched as (batch, timesteps, ...)

dataloader = buffer.dataloader(
    batch_size = 2,
    return_mask = True,
    to_named_tuple = ('state', 'action', 'reward', 'task_id', '_mask', '_lens')
)

for state, action, reward, task_id, mask, lens in dataloader:
    assert state.shape   == (2, 100, 3, 16, 16)
    assert action.shape  == (2, 100, 2)
    assert reward.shape  == (2, 100)
    assert task_id.shape == (2,)

    assert lens.shape    == (2,)
    assert mask.shape    == (2, 100)

# load individual timesteps instead of whole episodes, filtered on task_id = 1

dataloader = buffer.dataloader(
    batch_size = 8,
    filter_meta = dict(
        task_id = 1
    ),
    to_named_tuple = ('state', 'action', 'task_id'),
    timestep_level = True,
    drop_last = True
)

for state, action, task_id in dataloader:
    assert state.shape   == (8, 3, 16, 16)
    assert action.shape  == (8, 2)
    assert task_id.shape == (8,)
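
In the episode-level example above, the loader is asked for `_mask` and `_lens` alongside the stored fields. A minimal sketch of how such a pair is typically used follows, assuming that `_lens` holds each episode's length and `_mask` is true at valid (non-padded) timesteps; the source does not state this, so treat it as an assumption. The helpers `mask_from_lens` and `masked_mean` are hypothetical and written in plain NumPy rather than with the library.

```python
import numpy as np


def mask_from_lens(lens, max_len):
    # True where the timestep index is below the episode length.
    return np.arange(max_len)[None, :] < np.asarray(lens)[:, None]


def masked_mean(values, mask):
    # Per-episode mean over valid timesteps only; padded entries are zeroed out.
    values = np.where(mask, values, 0.0)
    return values.sum(axis=1) / mask.sum(axis=1)


lens = np.array([3, 5])          # two episodes of different length
reward = np.ones((2, 5))         # padded to the longer episode
reward[0, 3:] = 99.0             # stand-in for padding that must be ignored

mask = mask_from_lens(lens, max_len=5)
mean_reward = masked_mean(reward, mask)
```

Applying the mask before any reduction over the time dimension keeps padded timesteps from leaking into losses or statistics when episodes in a batch differ in length.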
