A simple replay buffer implementation in Python for sampling n-step trajectories.

ReplayTables

Benchmarks

Getting started

Installation:

pip install ReplayTables-andnp

Basic usage:

from typing import NamedTuple

import numpy as np
from ReplayTables.ReplayBuffer import ReplayBuffer

class Data(NamedTuple):
    x: np.ndarray
    a: np.ndarray
    r: np.ndarray

buffer = ReplayBuffer(
    max_size=100_000,
    structure=Data,
    rng=np.random.default_rng(0),
)

# x is one observation vector of shape (d,); a and r are the action and reward for that step
buffer.add(Data(x, a, r))

batch = buffer.sample(32)
print(batch.x.shape) # -> (32, d)
print(batch.a.shape) # -> (32, )
print(batch.r.shape) # -> (32, )
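
To make the shapes above concrete, here is a minimal sketch of a fixed-capacity uniform buffer. It is not the ReplayTables implementation: the class name `TinyUniformBuffer`, the oldest-first overwrite policy, and the per-field stacking are all assumptions made for illustration.

```python
from typing import NamedTuple

import numpy as np


class Data(NamedTuple):
    x: np.ndarray
    a: np.ndarray
    r: np.ndarray


class TinyUniformBuffer:
    """Hypothetical illustration only; not part of ReplayTables."""

    def __init__(self, max_size, rng):
        self._max_size = max_size
        self._rng = rng
        self._storage = []
        self._next = 0  # slot that the next add() will write to

    def add(self, item):
        # Assumption: once full, the oldest entry is overwritten (ring buffer).
        if len(self._storage) < self._max_size:
            self._storage.append(item)
        else:
            self._storage[self._next] = item
        self._next = (self._next + 1) % self._max_size

    def size(self):
        return len(self._storage)

    def sample(self, n):
        # Draw n indices uniformly with replacement, then stack field by field
        # so each field of the batch gains a leading batch dimension.
        idxs = self._rng.integers(0, len(self._storage), size=n)
        items = [self._storage[i] for i in idxs]
        return type(items[0])(*(np.stack(field) for field in zip(*items)))


buf = TinyUniformBuffer(max_size=3, rng=np.random.default_rng(0))
for t in range(5):
    buf.add(Data(x=np.full(4, t, dtype=np.float64), a=np.int64(t % 2), r=np.float64(t)))

batch = buf.sample(32)
print(batch.x.shape, batch.a.shape, batch.r.shape)
```

With `d = 4`, the batch fields have shapes `(32, 4)`, `(32,)`, and `(32,)`, matching the pattern in the usage example above.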

Prioritized Replay

An implementation of prioritized experience replay from

Schaul, Tom, et al. "Prioritized experience replay." ICLR (2016).

The defaults for this implementation strictly adhere to the defaults from the original work, though several configuration options are available.

from typing import NamedTuple

import numpy as np
from ReplayTables.PER import PERConfig, PrioritizedReplay

class Data(NamedTuple):
    a: float
    b: float

# all configurables are optional.
config = PERConfig(
    # can also use "mean" mode to place new samples in the middle of the distribution
    # or "given" mode, which requires giving the priority when the sample is added
    new_priority_mode='max',
    # the sampling distribution is a mixture between uniform sampling and the priority
    # distribution. This specifies the weight given to the uniform sampler.
    # Setting to 1 reverts this back to an inefficient form of standard uniform replay.
    uniform_probability=1e-3,
    # this implementation assumes priorities are positive. Priorities can be scaled by raising
    # them to some power. Default is `priority**(1/2)`
    priority_exponent=0.5,
    # if `new_priority_mode` is 'max', then the buffer tracks the highest seen priority.
    # this can cause accidental saturation if outlier priorities are observed. This provides
    # an exponential decay of the max in order to prevent permanent saturation.
    max_decay=1,
)

# if no config is given, defaults to original PER parameters
buffer = PrioritizedReplay(
    max_size=100_000,
    structure=Data,
    rng=np.random.default_rng(0),
    config=config,
)

buffer.add(Data(a=1, b=2))

# if `new_priority_mode` is 'given':
buffer.add(Data(a=1, b=2), priority=1.3)

batch = buffer.sample(32)
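
The config comments describe the sampling distribution as a mixture of a uniform distribution and a priority distribution, with priorities raised to `priority_exponent`. The sketch below shows one way such probabilities could be computed. The helper `mixture_probabilities` is hypothetical and not part of ReplayTables, and the order of operations (exponentiate, normalize, then mix) is an assumption.

```python
import numpy as np


def mixture_probabilities(priorities, uniform_probability=1e-3, priority_exponent=0.5):
    """Hypothetical helper: per-item sampling probabilities under the mixture."""
    priorities = np.asarray(priorities, dtype=np.float64)
    # Scale the (positive) priorities, then normalize them into a distribution.
    scaled = priorities ** priority_exponent
    prioritized = scaled / scaled.sum()
    # Uniform component: every stored item is equally likely.
    uniform = np.full_like(prioritized, 1.0 / prioritized.size)
    # Convex combination weighted by `uniform_probability`.
    return uniform_probability * uniform + (1.0 - uniform_probability) * prioritized


rng = np.random.default_rng(0)
probs = mixture_probabilities([1.0, 4.0, 9.0, 16.0])
idxs = rng.choice(len(probs), size=32, p=probs)
print(probs)
```

In this sketch, setting `uniform_probability=1` makes every item equally likely, which corresponds to the uniform-replay behaviour noted in the config comments.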
