Extra buffer classes for Stable-Baselines3
Reason this release was yanked:
Numba JIT Caching issue
sb3-extra-buffers
Unofficial implementation of extra Stable-Baselines3 buffer classes, currently mostly a proof of concept. Main goal: reduce the memory consumption of buffers in reinforcement learning.
Motivation:
Reinforcement learning is memory-hungry because of its massive buffers, so this project avoids storing raw frames in full np.float32 and stores something smaller instead. For input data that is sparse and contains large contiguous regions of repeated values, lossless compression can reduce the memory footprint.
Applicable Input Types:
- Semantic Segmentation masks (1 color channel)
- Color Palette game frames from retro video games
- Grayscale game frames from retro video games
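To put the motivation in numbers, here is a rough sketch of the raw storage cost of a frame buffer at two element types. The frame size (84x84, single channel) and buffer capacity (100,000 steps) are illustrative assumptions, not values taken from this project, and `buffer_bytes` is a hypothetical helper.

```python
import numpy as np

def buffer_bytes(n_steps: int, obs_shape: tuple, dtype) -> int:
    """Bytes needed to store n_steps observations of obs_shape as dtype."""
    return n_steps * int(np.prod(obs_shape)) * np.dtype(dtype).itemsize

obs_shape = (84, 84)   # assumed single-channel frame size
n_steps = 100_000      # assumed buffer capacity

as_float32 = buffer_bytes(n_steps, obs_shape, np.float32)
as_uint8 = buffer_bytes(n_steps, obs_shape, np.uint8)

print(f"float32: {as_float32 / 1e9:.2f} GB")  # float32: 2.82 GB
print(f"uint8:   {as_uint8 / 1e9:.2f} GB")    # uint8:   0.71 GB
```

Casting to a one-byte element type alone cuts storage by 4x; lossless compression of repeated values can reduce it further, depending on the data.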
Installation
To install with isal and numba support:
pip install "sb3_extra_buffers[fast]"
Other install options:
pip install sb3_extra_buffers # only installs minimum requirements
pip install "sb3_extra_buffers[isal]" # only installs python-isal
pip install "sb3_extra_buffers[numba]" # only installs numba
pip install "sb3_extra_buffers[atari]" # installs gymnasium, ale-py
pip install "sb3_extra_buffers[vizdoom]" # installs gymnasium, vizdoom
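After installing, it can be handy to check which optional backends are importable in the current environment. The sketch below uses only the standard library; `available_optional_methods` is a hypothetical helper and not part of this package (the package exposes its own HAS_NUMBA flag, shown later).

```python
from importlib.util import find_spec

# Importable module name -> compression method that depends on it
OPTIONAL_BACKENDS = {"numba": "rle-jit", "isal": "igzip"}

def available_optional_methods() -> dict:
    """Report, per optional compression method, whether its backend is installed."""
    return {
        method: find_spec(module) is not None
        for module, method in OPTIONAL_BACKENDS.items()
    }

print(available_optional_methods())
```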
Project Structure
sb3_extra_buffers
|- compressed
| |- CompressedRolloutBuffer: RolloutBuffer with compression
| |- CompressedReplayBuffer: ReplayBuffer with compression
|
|- recording
  |- RecordBuffer: A buffer for recording game states
  |- FramelessRecordBuffer: RecordBuffer that does not record game frames
  |- DummyRecordBuffer: Dummy RecordBuffer, records nothing
Compressed Buffers
Defined in sb3_extra_buffers.compressed
Implemented Compression Methods:
- rle: Uses Run-Length Encoding for compression.
- rle-jit: JIT-compiled version of rle, uses the numba library.
- gzip: Compression via gzip.
- igzip: Compression via isal.igzip, uses the python-isal library.
- none: No compression other than casting to elem_type.
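For readers unfamiliar with run-length encoding, the following is a minimal NumPy sketch of the idea, not this library's implementation. The function names `rle_encode` and `rle_decode` are hypothetical; the `elem_type` and `runs_type` parameters are assumed to play the same role as the dtypes of the same names used elsewhere in this README (one dtype for stored values, one for run lengths).

```python
import numpy as np

def rle_encode(arr, elem_type=np.uint8, runs_type=np.uint16):
    """Encode an array as (values, run lengths) over its flattened form."""
    flat = np.asarray(arr).ravel()
    if flat.size == 0:
        return np.empty(0, dtype=elem_type), np.empty(0, dtype=runs_type)
    # A new run starts at index 0 and wherever the value differs from its predecessor
    change = np.flatnonzero(flat[1:] != flat[:-1]) + 1
    starts = np.concatenate(([0], change))
    lengths = np.diff(np.concatenate((starts, [flat.size])))
    if lengths.max() > np.iinfo(runs_type).max:
        raise ValueError("runs_type is too small for the longest run")
    return flat[starts].astype(elem_type), lengths.astype(runs_type)

def rle_decode(values, lengths, shape, dtype=np.float32):
    """Rebuild the original array from values and run lengths."""
    return np.repeat(values, lengths.astype(np.intp)).astype(dtype).reshape(shape)

# A small segmentation-style mask: mostly zeros with one band of class 3
mask = np.zeros((8, 8), dtype=np.uint8)
mask[2:5, :] = 3

values, lengths = rle_encode(mask)
restored = rle_decode(values, lengths, mask.shape, dtype=np.uint8)

print(values, lengths)                          # [0 3 0] [16 24 24]
print(values.nbytes + lengths.nbytes, "bytes")  # 9 bytes, versus 64 for the raw uint8 mask
```

The gain depends entirely on how long the runs are: noisy frames with few repeated values can end up larger after encoding than before.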
JIT Before Multi-Processing:
When using rle-jit, remember to trigger JIT compilation before any multi-processing code is executed.
# Code for other stuff...
import numpy as np
from stable_baselines3.common.vec_env import SubprocVecEnv
from sb3_extra_buffers.compressed.compression_methods import HAS_NUMBA

# Compressed-buffer-related settings
compression_method = "rle-jit"
storage_dtypes = dict(elem_type=np.uint8, runs_type=np.uint16)

# Pre-JIT Numba to avoid fork issues
if HAS_NUMBA and "jit" in compression_method:
    from sb3_extra_buffers.compressed.compression_methods.compression_methods_numba import init_jit
    init_jit(**storage_dtypes)

# Now, safe to initialize multi-processing environments!
# make_env is assumed to be an environment factory defined elsewhere
env = SubprocVecEnv([make_env for _ in range(4)])
Example Usage:
import numpy as np
import gymnasium as gym
from stable_baselines3 import PPO
from sb3_extra_buffers.compressed import CompressedRolloutBuffer, find_smallest_dtype
env = gym.make("CartPole-v1", render_mode="human")
flatten_obs_shape = np.prod(env.observation_space.shape)
buffer_dtypes = dict(elem_type=np.uint8, runs_type=find_smallest_dtype(flatten_obs_shape))
model = PPO("MlpPolicy", env, verbose=1, rollout_buffer_class=CompressedRolloutBuffer,
            rollout_buffer_kwargs=dict(dtypes=buffer_dtypes, compression_method="rle"))
model.learn(total_timesteps=10_000)
vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)
    vec_env.render()
env.close()
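The example above sizes runs_type from the flattened observation length. A plausible reason, stated here as an assumption, is that a single run can never be longer than the flattened observation, so the run-length dtype only needs to hold that value. The sketch below shows one way such a choice could be made; `smallest_uint_dtype` is a hypothetical stand-in and may differ from the library's find_smallest_dtype.

```python
import numpy as np

def smallest_uint_dtype(max_value):
    """Return the narrowest unsigned integer dtype that can represent max_value."""
    max_value = int(max_value)  # also accepts NumPy integers such as the result of np.prod
    if max_value < 0:
        raise ValueError("max_value must be non-negative")
    for candidate in (np.uint8, np.uint16, np.uint32, np.uint64):
        if max_value <= np.iinfo(candidate).max:
            return candidate
    raise ValueError("max_value does not fit in any unsigned integer dtype")

print(smallest_uint_dtype(np.prod((4,))))            # <class 'numpy.uint8'>
print(smallest_uint_dtype(np.prod((84, 84))))        # <class 'numpy.uint16'>
print(smallest_uint_dtype(np.prod((210, 160, 3))))   # <class 'numpy.uint32'>
```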
Recording Buffers
Defined in sb3_extra_buffers.recording
Mainly used in combination with SegDoom to record game states.
Download files
File details
Details for the file sb3_extra_buffers-0.1.6.tar.gz.
File metadata
- Download URL: sb3_extra_buffers-0.1.6.tar.gz
- Upload date:
- Size: 12.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 395bc332014e80707f46b9897985d2b7a841048eafcb2321f5bc8e6fac8e5701 |
| MD5 | ee843995cceeecdc1c473b7164e27c22 |
| BLAKE2b-256 | 888c4b4aa3bab64f1b5e10a5b75132e437204a72797a4cda3f16728df13aa710 |
File details
Details for the file sb3_extra_buffers-0.1.6-py3-none-any.whl.
File metadata
- Download URL: sb3_extra_buffers-0.1.6-py3-none-any.whl
- Upload date:
- Size: 13.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 0b15a34886e0bd17a12aba4c043384adde81d3b76bfd15bdcbc16f0fa31dd39c |
| MD5 | 67ad69cfb093f8f38c974efb81b1ea30 |
| BLAKE2b-256 | 253a1ca1011d6447a68d3d706d9fd387499fe514438922a859c50cace7bb0b08 |