A simple replay buffer implementation in Python for sampling n-step trajectories
ReplayTables
Getting started
Installation:
pip install ReplayTables-andnp
Basic usage:
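import numpy as np
from typing import NamedTuple
from ReplayTables.ReplayBuffer import ReplayBuffer

# the fields stored for each transition
class Data(NamedTuple):
    x: np.ndarray
    a: np.ndarray
    r: np.ndarray

buffer = ReplayBuffer(
    max_size=100_000,
    structure=Data,
    rng=np.random.default_rng(0),
)

# x, a, r are one transition's observation (a vector of length d), action, and reward
buffer.add(Data(x, a, r))

# sample a batch of 32 transitions; each field is stacked along the first axis
batch = buffer.sample(32)
print(batch.x.shape) # -> (32, d)
print(batch.a.shape) # -> (32, )
print(batch.r.shape) # -> (32, )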
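To make the behaviour behind `add` and `sample` concrete, here is a minimal sketch of a uniform ring buffer that stores `NamedTuple` transitions and returns stacked batches. It is an illustration only, not how ReplayTables is implemented: the `TinyReplay` class, its `size` method, and the toy data are all hypothetical.

```python
from typing import NamedTuple
import numpy as np

class Data(NamedTuple):
    x: np.ndarray
    a: np.ndarray
    r: np.ndarray

class TinyReplay:
    """Hypothetical illustration only; not the ReplayTables implementation."""

    def __init__(self, max_size, structure, rng):
        self._max_size = max_size
        self._structure = structure
        self._rng = rng
        self._storage = []  # stored transitions
        self._next = 0      # slot to overwrite once the buffer is full

    def size(self):
        return len(self._storage)

    def add(self, transition):
        if len(self._storage) < self._max_size:
            self._storage.append(transition)
        else:
            # buffer is full: overwrite the oldest transition
            self._storage[self._next] = transition
        self._next = (self._next + 1) % self._max_size

    def sample(self, n):
        # uniform sampling with replacement over the stored transitions
        idxs = self._rng.integers(0, len(self._storage), size=n)
        # regroup by field, then stack each field along a new first axis
        fields = zip(*(self._storage[i] for i in idxs))
        return self._structure(*(np.stack(f) for f in fields))

rng = np.random.default_rng(0)
buffer = TinyReplay(max_size=5, structure=Data, rng=rng)
for t in range(8):
    buffer.add(Data(x=np.full(3, float(t)), a=np.array(t % 2), r=np.array(1.0)))

batch = buffer.sample(32)
print(batch.x.shape, batch.a.shape, batch.r.shape)
```

Because the buffer holds at most five transitions here, the three oldest ones are overwritten and can no longer appear in a sampled batch.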
Prioritized Replay
An implementation of prioritized experience replay from
Schaul, Tom, et al. "Prioritized experience replay." ICLR (2016).
The default settings follow those of the original paper, and several configuration options are available.
import numpy as np
from typing import NamedTuple
from ReplayTables.PER import PERConfig, PrioritizedReplay

class Data(NamedTuple):
    a: float
    b: float

# all configurables are optional.
config = PERConfig(
    # can also use "mean" mode to place new samples in the middle of the distribution
    # or "given" mode, which requires giving the priority when the sample is added
    new_priority_mode='max',
    # the sampling distribution is a mixture between uniform sampling and the priority
    # distribution. This specifies the weight given to the uniform sampler.
    # Setting to 1 reverts this back to an inefficient form of standard uniform replay.
    uniform_probability=1e-3,
    # this implementation assumes priorities are positive. Priorities can be scaled by
    # raising them to some power. Default is `priority**(1/2)`
    priority_exponent=0.5,
    # if `new_priority_mode` is 'max', then the buffer tracks the highest seen priority.
    # This can cause accidental saturation if outlier priorities are observed. This option
    # provides an exponential decay of the max in order to prevent permanent saturation.
    max_decay=1,
)

# if no config is given, defaults to original PER parameters
buffer = PrioritizedReplay(
    max_size=100_000,
    structure=Data,
    rng=np.random.default_rng(0),
    config=config,
)

buffer.add(Data(a=1, b=2))

# if `new_priority_mode` is 'given':
buffer.add(Data(a=1, b=2), priority=1.3)

batch = buffer.sample(32)
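The configuration comments above describe a mixture of uniform and priority-based sampling, a priority exponent, and a decaying maximum. The sketch below shows one plausible reading of those three options in plain NumPy. It is an assumption about the arithmetic, not the library's internals: `mixture_probabilities` and `update_max` are hypothetical helpers that only reuse the parameter names from `PERConfig`.

```python
import numpy as np

def mixture_probabilities(priorities, uniform_probability=1e-3, priority_exponent=0.5):
    """Hypothetical helper: mix a uniform distribution with a priority-proportional one."""
    priorities = np.asarray(priorities, dtype=np.float64)
    n = len(priorities)
    uniform = np.full(n, 1.0 / n)
    # scale the (assumed non-negative) priorities by raising them to a power
    scaled = priorities ** priority_exponent
    total = scaled.sum()
    # fall back to uniform sampling when every priority is zero
    prioritized = scaled / total if total > 0 else uniform
    return uniform_probability * uniform + (1.0 - uniform_probability) * prioritized

def update_max(current_max, new_priority, max_decay=1.0):
    """Hypothetical helper: running max that decays geometrically when max_decay < 1."""
    return max(new_priority, current_max * max_decay)

rng = np.random.default_rng(0)
probs = mixture_probabilities([1.0, 4.0, 9.0])
idxs = rng.choice(len(probs), size=32, p=probs)

# with max_decay=0.5, an outlier priority of 100 fades over later updates
tracked = 0.0
for p in [1.0, 100.0, 2.0, 3.0]:
    tracked = update_max(tracked, p, max_decay=0.5)
```

Under this reading, `uniform_probability=1` makes every index equally likely, and `max_decay=1` keeps the largest priority ever seen, which matches the saturation behaviour the comments warn about.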