Simple Replay Buffer for RL

memmap-replay-buffer

An easy-to-use replay buffer backed by NumPy memmap files, for RL and other sequence-based learning tasks.
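For readers unfamiliar with memory-mapped arrays, here is a minimal sketch of the underlying idea using plain `np.memmap`. It is not this library's implementation; the file name `state.dat`, the sizes, and the episode-by-timestep layout are assumptions chosen only for illustration.

```python
import os
import tempfile

import numpy as np

# Hypothetical sizes, kept small for illustration.
max_episodes, max_timesteps = 4, 5
state_shape = (3,)

folder = tempfile.mkdtemp()
path = os.path.join(folder, "state.dat")  # hypothetical file name

# Create a disk-backed array: one row per episode, one slot per timestep.
states = np.memmap(
    path,
    dtype=np.float32,
    mode="w+",
    shape=(max_episodes, max_timesteps, *state_shape),
)

# Write a 2-step episode into row 0, then flush the changes to disk.
states[0, :2] = np.array([[1., 2., 3.], [4., 5., 6.]], dtype=np.float32)
states.flush()
del states

# Reopen read-only. The raw file stores neither shape nor dtype,
# so both have to be supplied again when reopening.
reloaded = np.memmap(
    path,
    dtype=np.float32,
    mode="r",
    shape=(max_episodes, max_timesteps, *state_shape),
)
first_step = reloaded[0, 0].tolist()
```

Because the array lives in a file, data written in one process can be read back later without holding the whole buffer in RAM, which is presumably what makes rehydrating a buffer from a folder possible.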

Install

$ pip install memmap-replay-buffer

Usage

import torch
from memmap_replay_buffer import ReplayBuffer

# initialize buffer

buffer = ReplayBuffer(
    './replay_data',
    max_episodes = 1000,
    max_timesteps = 500,
    fields = dict(
        state = ('float', (3, 16, 16), 0.),    # type, shape, and optional default value
        action = ('int', 2),
        reward = 'float'                       # default shape is ()
    ),
    meta_fields = dict(
        task_id = 'int'
    ),
    circular = True,
    overwrite = True
)

# store 4 episodes

for _ in range(4):
    with buffer.one_episode(task_id = 1):
        for _ in range(100):
            buffer.store(
                state = torch.randn(3, 16, 16),
                action = torch.randint(0, 4, (2,)).numpy(),
                reward = 1.0
            )

# rehydrate from disk

buffer_rehydrated = ReplayBuffer.from_folder('./replay_data')
assert buffer_rehydrated.num_episodes == 4

# iterate over whole episodes, 2 per batch; fields come back as (batch, timesteps, ...)

dataloader = buffer.dataloader(
    batch_size = 2,
    to_named_tuple = ('state', 'action', 'reward', 'task_id', '_lens')
)

for state, action, reward, task_id, lens in dataloader:
    assert state.shape   == (2, 100, 3, 16, 16)
    assert action.shape  == (2, 100, 2)
    assert reward.shape  == (2, 100)
    assert task_id.shape == (2,)
    assert lens.shape    == (2,)

# load individual timesteps instead of whole episodes, keeping only episodes with task_id == 1

dataloader = buffer.dataloader(
    batch_size = 8,
    filter_meta = dict(
        task_id = 1
    ),
    to_named_tuple = ('state', 'action'),
    timestep_level = True,
    drop_last = True
)

for state, action in dataloader:
    assert state.shape == (8, 3, 16, 16)
    assert action.shape == (8, 2)
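The episode-level example above requests `_lens` alongside the data fields. A common use for per-episode lengths is masking out padded timesteps when episodes differ in length. The sketch below assumes that `_lens` holds the number of valid timesteps per episode and that shorter episodes are padded along the time axis; the source only shows equal-length episodes, so treat both as assumptions. `mask_from_lens` is a hypothetical helper, not part of the library, and NumPy is used here so the snippet runs on its own (the same broadcasting works on torch tensors).

```python
import numpy as np

def mask_from_lens(lens, max_len):
    """Return a (batch, max_len) boolean mask, True where timestep < episode length."""
    positions = np.arange(max_len)                      # shape (max_len,)
    return positions[None, :] < np.asarray(lens)[:, None]

# Hypothetical batch: two episodes of lengths 3 and 5, padded to 5 timesteps.
lens = np.array([3, 5])
reward = np.ones((2, 5), dtype=np.float32)

mask = mask_from_lens(lens, max_len=reward.shape[1])

# Mean reward per episode, counting valid timesteps only.
mean_reward = (reward * mask).sum(axis=1) / lens
```

Applying the mask before any reduction over the time axis keeps padded slots from contributing to losses or statistics.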
