Extra buffer classes for Stable-Baselines3 that reduce memory usage with minimal overhead.
sb3-extra-buffers
Unofficial implementation of extra Stable-Baselines3 buffer classes. Aims to reduce memory usage drastically with minimal overhead.
Current Version: 0.2.2
Links:
- This Project on PyPI
- Stable Baselines3
- SB3 Contrib (experimental features for SB3)
- SBX (SB3 + JAX, uses SB3 buffers so can also benefit from compressed buffers here)
- RL Baselines3 Zoo (training framework for SB3)
Main Goal: Reduce the memory consumption of memory buffers in Reinforcement Learning while adding minimal overhead.
TO-DO List:
- Compressed Rollout / Replay Buffer
- Compressed variant for every buffer in SB3
- Compressed variant for every buffer in SB3-Contrib
- Compressed Array (might make porting easier)
- Recording Buffers for game episodes
- Compressed Recording Buffers
- Buffer warm-up and model evaluation utility functions
- Example Atari train / eval scripts with compressed buffers
- Report results for example Atari train / eval scripts
- Example ViZDoom train / eval scripts with compressed buffers
- Report results for example ViZDoom train / eval scripts
- Report memory saving
- Documentation & better readme
- Define a standard bytes-out (compress) bytes-in (decompress) interface and store compressed obs in np.ndarray[bytes]
Motivation:
Reinforcement Learning is quite memory-hungry due to massive buffer sizes, so let's tackle that by not storing raw frames in full np.float32 or np.uint8 directly and finding something smaller instead. For any input data that is sparse and contains large contiguous regions of repeating values, lossless compression techniques can be applied to reduce the memory footprint.
Applicable Input Types:
- Semantic Segmentation masks (1 color channel)
- Color Palette game frames from retro video games
- Grayscale game frames from retro video games
- RGB (Color) game frames from retro video games
- For noisy input with a lot of variation (mostly RGB), using gzip1 or igzip0 is recommended; run-length encoding won't work as well and can potentially even increase memory usage.
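To make the last point concrete, here is a minimal pure-Python sketch of run-length encoding. It is not the library's vectorized implementation; the helper names (rle_encode, rle_decode, rle_nbytes) are hypothetical, and the byte sizes assume one byte per element and two bytes per run length, mirroring the elem_type=np.uint8, runs_type=np.uint16 pairing used in this README.

```python
import random
from itertools import groupby


def rle_encode(values):
    """Return (elements, run_lengths) for a flat sequence of ints."""
    elems, runs = [], []
    for value, group in groupby(values):
        elems.append(value)
        runs.append(sum(1 for _ in group))
    return elems, runs


def rle_decode(elems, runs):
    """Rebuild the flat sequence from elements and run lengths."""
    out = []
    for value, length in zip(elems, runs):
        out.extend([value] * length)
    return out


def rle_nbytes(elems, runs, elem_bytes=1, run_bytes=2):
    """Approximate storage cost (assumed: uint8 elements, uint16 runs)."""
    return len(elems) * elem_bytes + len(runs) * run_bytes


width = height = 84  # illustrative frame size

# Sparse input: a segmentation-style mask with one 20x20 block of class 3.
sparse = [3 if 30 <= r < 50 and 30 <= c < 50 else 0
          for r in range(height) for c in range(width)]

# Noisy input: uniformly random bytes, so almost every run has length 1.
rng = random.Random(0)
noisy = [rng.randrange(256) for _ in range(width * height)]

for name, frame in (("sparse", sparse), ("noisy", noisy)):
    elems, runs = rle_encode(frame)
    assert rle_decode(elems, runs) == frame  # lossless round trip
    print(name, "raw bytes:", len(frame), "rle bytes:", rle_nbytes(elems, runs))
```

On the sparse mask the encoded form is a small fraction of the raw size, while on the noisy frame nearly every element becomes its own run, so the encoded form ends up larger than the raw data.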
Implemented Compression Methods:
- rle: Vectorized Run-Length Encoding for compression.
- rle-jit: JIT-compiled version of rle, uses the numba library.
- gzip: Gzip compression via gzip.
- igzip: Intel-accelerated variant via isal.igzip, uses the python-isal library.
- none: No compression other than casting to elem_type and storing as bytes.
- gzip supports compression levels 0-9; 0 is no compression, 1 is least compression.
- igzip supports compression levels 0-3; 0 is least compression.
- Shorthands are supported, e.g. igzip3 = igzip at level 3.
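The level semantics and the shorthand convention can be sketched with the standard library alone. The split_method helper below is hypothetical and only illustrates how a name like igzip3 could be read; it is not how the library parses method names. The size comparison uses Python's built-in gzip module rather than isal.

```python
import gzip
import re


def split_method(name):
    """Split a shorthand such as "igzip3" into ("igzip", 3).

    Hypothetical helper for illustration only. Names without a trailing
    level, such as "rle-jit", return None as the level.
    """
    match = re.fullmatch(r"(.*?)(\d+)?", name)
    base, level = match.group(1), match.group(2)
    return base, (int(level) if level is not None else None)


# An all-zero 84x84 uint8 frame is highly compressible.
frame = bytes(84 * 84)

stored = gzip.compress(frame, compresslevel=0)  # level 0: stored, no compression
fast = gzip.compress(frame, compresslevel=1)    # level 1: least compression

assert gzip.decompress(stored) == frame
assert gzip.decompress(fast) == frame

print(split_method("igzip3"), split_method("rle-jit"))
print("raw:", len(frame), "level 0:", len(stored), "level 1:", len(fast))
```

At level 0 the output is slightly larger than the input because gzip still adds its header and trailer, which is why level 1 is the lowest level that actually saves memory.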
Installation
Install via PyPI:
pip install "sb3-extra-buffers[fast,extra]"
Other install options:
pip install "sb3-extra-buffers" # only installs minimum requirements
pip install "sb3-extra-buffers[extra]" # installs extra dependencies for SB3
pip install "sb3-extra-buffers[fast]" # installs python-isal and numba
pip install "sb3-extra-buffers[isal]" # only installs python-isal
pip install "sb3-extra-buffers[numba]" # only installs numba
pip install "sb3-extra-buffers[vizdoom]" # installs vizdoom
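After installing, it can be handy to check which optional backends are importable before picking a compression method. This is a small sketch using only the standard library; the mapping below assumes the usual import names (isal for python-isal, numba for numba), and available_optional_methods is a hypothetical helper. The package also ships its own has_numba helper, used later in this README.

```python
from importlib.util import find_spec

# Assumed import names for the optional extras.
OPTIONAL_MODULES = {"igzip": "isal", "rle-jit": "numba"}


def available_optional_methods():
    """Map each optional compression method to whether its backend is importable."""
    return {method: find_spec(module) is not None
            for method, module in OPTIONAL_MODULES.items()}


print(available_optional_methods())
```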
Current Project Structure
sb3_extra_buffers
|- compressed
| |- CompressedRolloutBuffer: RolloutBuffer with compression
| |- CompressedReplayBuffer: ReplayBuffer with compression
|
|- recording
| |- RecordBuffer: A buffer for recording game states
| |- FramelessRecordBuffer: RecordBuffer but not recording game frames
| |- DummyRecordBuffer: Dummy RecordBuffer, records nothing
|
|- training_utils
|- eval_model: Evaluate models in vectorized environment
|- warmup: Perform buffer warmup for off-policy algorithms
Example Scripts
Example scripts are included and have been tested to ensure they work properly.
Evaluation results for example training scripts:
PPO on PongNoFrameskip-v4, trained for 10M steps using rle-jit, framestack: None
(Best ) Evaluated 10000 episodes, mean reward: 21.0 +/- 0.00
Q1: 21 | Q2: 21 | Q3: 21 | Relative IQR: 0.00 | Min: 21 | Max: 21
(Final) Evaluated 10000 episodes, mean reward: 21.0 +/- 0.02
Q1: 21 | Q2: 21 | Q3: 21 | Relative IQR: 0.00 | Min: 20 | Max: 21
DQN on MsPacmanNoFrameskip-v4, trained for 10M steps using rle-jit, framestack: 4
(Best ) Evaluated 10000 episodes, mean reward: 3300.0 +/- 770.79
Q1: 2490 | Q2: 4020 | Q3: 4020 | Relative IQR: 0.38 | Min: 2460 | Max: 4020
(Final) Evaluated 10000 episodes, mean reward: 3379.2 +/- 453.78
Q1: 2690 | Q2: 3400 | Q3: 3880 | Relative IQR: 0.35 | Min: 1230 | Max: 4090
Compressed Buffers
Defined in sb3_extra_buffers.compressed
JIT Before Multi-Processing:
When using rle-jit, remember to trigger JIT compilation before any multi-processing code is executed.
import numpy as np
from stable_baselines3.common.vec_env import SubprocVecEnv
from sb3_extra_buffers.compressed.compression_methods import has_numba

# Other code...

# Compressed-buffer-related settings
compression_method = "rle-jit"
storage_dtypes = dict(elem_type=np.uint8, runs_type=np.uint16)

# Pre-JIT Numba to avoid fork issues
if has_numba() and "jit" in compression_method:
    from sb3_extra_buffers.compressed.compression_methods.compression_methods_numba import init_jit
    init_jit(**storage_dtypes)

# Now, safe to initialize multi-processing environments!
env = SubprocVecEnv(...)
Example Usage:
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback
from sb3_extra_buffers.compressed import CompressedRolloutBuffer, find_smallest_dtype
from sb3_extra_buffers.training_utils.atari import make_env

ATARI_GAME = "MsPacmanNoFrameskip-v4"

if __name__ == "__main__":
    flatten_obs_shape = np.prod(make_env(env_id=ATARI_GAME, n_envs=1, framestack=4).observation_space.shape)
    buffer_dtypes = dict(elem_type=np.uint8, runs_type=find_smallest_dtype(flatten_obs_shape))

    env = make_env(env_id=ATARI_GAME, n_envs=4, framestack=4)
    eval_env = make_env(env_id=ATARI_GAME, n_envs=1, framestack=4)

    # Create PPO model using CompressedRolloutBuffer
    model = PPO("CnnPolicy", env, verbose=1, rollout_buffer_class=CompressedRolloutBuffer,
                rollout_buffer_kwargs=dict(dtypes=buffer_dtypes, compression_method="rle"))

    # Evaluation callback (optional)
    eval_callback = EvalCallback(eval_env, log_path=f"./logs/{ATARI_GAME}",
                                 best_model_save_path=f"./logs/{ATARI_GAME}/best_model")

    # Training
    model.learn(total_timesteps=100_000, callback=eval_callback, progress_bar=True)

    # Save the final model
    model.save(f"ppo-{ATARI_GAME}.zip")

    # Cleanup
    env.close()
    eval_env.close()
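The example passes the flattened observation size to find_smallest_dtype when choosing runs_type. Presumably this is because a single run can be at most as long as the flattened observation, so the run-length type must be able to hold that value. The arithmetic can be sketched as below; smallest_uint_bits is a hypothetical stand-in and not the library function, and the (4, 84, 84) shape is an assumption about what a typical Atari preprocessing pipeline with framestack=4 produces.

```python
from math import prod


def smallest_uint_bits(max_value):
    """Smallest standard unsigned integer width (in bits) that can hold max_value."""
    for bits in (8, 16, 32, 64):
        if max_value <= 2**bits - 1:
            return bits
    raise ValueError("value does not fit in 64 bits")


obs_shape = (4, 84, 84)      # assumed observation shape, for illustration
flat_len = prod(obs_shape)   # longest possible run: the whole observation

print(flat_len, smallest_uint_bits(flat_len))  # prints: 28224 16
```

Under these assumptions a 16-bit run-length type is enough, which matches the runs_type=np.uint16 setting shown in the JIT example.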
Recording Buffers
Defined in sb3_extra_buffers.recording
Mainly used in combination with SegDoom to record game episodes.
WIP
Training Utils
Defined in sb3_extra_buffers.training_utils
Buffer warm-up and model evaluation
WIP